In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
from model import Net
from dataset import DatasetLoader
import multiprocessing
from func_train_test import train, test
import matplotlib.pyplot as plt
import numpy as np

def show_sample_images(dataloader, classes, num_images=20, *, cols=5,
                       mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                       save_path='sample_training_images.png'):
    """
    Save a grid of sample images from one training batch and print its class mix.

    Draws the first ``num_images`` images of the first batch yielded by
    ``dataloader`` (fewer if the batch is smaller) into a grid with ``cols``
    columns, saves the figure to ``save_path``, then prints how many images of
    each class the whole batch contains.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Images are
            assumed to be CHW float tensors normalized with ``mean``/``std``
            (TODO confirm against the dataset transform).
        classes: Sequence mapping a label index to a class name.
        num_images: Maximum number of images to draw. Any positive value is
            accepted; the grid grows extra rows as needed.
        cols: Number of grid columns (must be >= 1).
        mean: Per-channel mean used for normalization, inverted for display.
        std: Per-channel std used for normalization, inverted for display.
        save_path: Path of the saved figure.

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError("cols must be >= 1")

    # Only the first batch is used.
    images, labels = next(iter(dataloader))

    n_show = min(num_images, len(images))
    # Size the grid to the images actually shown instead of a fixed 4x5,
    # which made add_subplot fail for num_images > 20.
    rows = max(1, (n_show + cols - 1) // cols)
    mean_arr = np.asarray(mean)
    std_arr = np.asarray(std)

    # 4x2.5 inches per cell keeps the original (20, 10) size for a 4x5 grid.
    fig = plt.figure(figsize=(4 * cols, 2.5 * rows))
    for idx in range(n_show):
        ax = fig.add_subplot(rows, cols, idx + 1)
        # CHW tensor -> HWC array, the layout imshow expects.
        img = images[idx].cpu().numpy().transpose(1, 2, 0)
        # Undo (x - mean) / std and clamp into imshow's valid float range.
        img = np.clip(img * std_arr + mean_arr, 0, 1)

        ax.imshow(img)
        ax.set_title(f'Class: {classes[int(labels[idx])]}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

    # Class distribution over the whole batch, not just the images drawn.
    unique_labels, counts = np.unique(labels.cpu().numpy(), return_counts=True)
    print("\nClass distribution in sample batch:")
    for label, count in zip(unique_labels, counts):
        print(f"{classes[label]}: {count} images")

def show_misclassified(images, predictions, targets, classes, *, max_images=20,
                       cols=5, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                       save_path='misclassified_examples.png'):
    """
    Save a grid of misclassified images titled with predicted and true classes.

    Args:
        images: Indexable collection of CHW image tensors (may be on GPU),
            assumed normalized with ``mean``/``std`` (TODO confirm).
        predictions: Predicted label index for each image.
        targets: Ground-truth label index for each image.
        classes: Sequence mapping a label index to a class name.
        max_images: Maximum number of images to draw (default 20).
        cols: Number of grid columns (must be >= 1).
        mean: Per-channel mean used for normalization, inverted for display.
        std: Per-channel std used for normalization, inverted for display.
        save_path: Path of the saved figure; overwritten on every call.

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError("cols must be >= 1")

    n_show = min(max_images, len(images))
    # Grid sized to the images actually drawn rather than a fixed 4x5.
    rows = max(1, (n_show + cols - 1) // cols)
    mean_arr = np.asarray(mean)
    std_arr = np.asarray(std)

    # 4x2.5 inches per cell keeps the original (20, 10) size for a 4x5 grid.
    fig = plt.figure(figsize=(4 * cols, 2.5 * rows))
    for idx in range(n_show):
        ax = fig.add_subplot(rows, cols, idx + 1)
        # Move to CPU, then CHW -> HWC for imshow.
        img = images[idx].cpu().numpy().transpose(1, 2, 0)
        # Undo (x - mean) / std and clamp into imshow's valid float range.
        img = np.clip(img * std_arr + mean_arr, 0, 1)

        ax.imshow(img)
        # int() accepts plain ints as well as single-element tensors.
        predicted, actual = int(predictions[idx]), int(targets[idx])
        ax.set_title(f'Pred: {classes[predicted]}\nTrue: {classes[actual]}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

def _plot_training_history(train_losses, test_losses, train_accs, test_accs,
                           save_path='training_history.png'):
    """
    Save side-by-side per-epoch loss and accuracy curves to ``save_path``.

    Args:
        train_losses: Training loss per epoch.
        test_losses: Test loss per epoch.
        train_accs: Training accuracy per epoch (axis labelled as percent).
        test_accs: Test accuracy per epoch (axis labelled as percent).
        save_path: Path of the saved figure.
    """
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(test_accs, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def main():
    """
    Train ``Net`` on CIFAR-10 and save diagnostic plots.

    Steps:
      1. Builds the train/test loaders.
      2. Saves a grid of sample training images and prints a model summary.
      3. Trains for ``EPOCHS`` epochs with SGD (lr 0.05, momentum 0.9, weight
         decay 5e-4) and a StepLR schedule that multiplies the learning rate
         by 0.1 every 15 epochs.
      4. Saves the per-epoch loss/accuracy curves to ``training_history.png``.

    When ``DEBUG`` is True, a grid of misclassified test images is also saved
    every 5 epochs.
    """
    # Set to True to save misclassified test images every 5 epochs.
    DEBUG = False

    # Training settings
    EPOCHS = 50
    BATCH_SIZE = 128

    # CIFAR-10 class names, indexed by label id.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data loaders
    get_data = DatasetLoader(batch_size=BATCH_SIZE)
    train_loader = get_data.train_loader()
    test_loader = get_data.test_loader()

    print("\nDisplaying sample training images...")
    show_sample_images(train_loader, classes)

    model = Net().to(device)
    # Pass the device the model actually lives on rather than relying on
    # torchsummary's "cuda" default.
    summary(model, input_size=(3, 32, 32), device=device.type)

    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    # Stepped once per epoch: lr is 0.05 for epochs 0-14, 0.005 for 15-29,
    # 0.0005 for 30-44 and 0.00005 for 45-49.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    # Per-epoch metrics for the history plot.
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []

    for epoch in range(EPOCHS):
        print(f"\nEpoch: {epoch}")

        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion, epoch)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # NOTE(review): `misclassified` is unpacked below as
        # (images, predictions, targets); presumably it is empty/None when
        # debug=False — confirm in func_train_test.test.
        test_loss, test_acc, misclassified = test(model, device, test_loader, criterion, debug=DEBUG)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        if DEBUG and misclassified and epoch % 5 == 0:
            images, predictions, targets = misclassified
            show_misclassified(images, predictions, targets, classes)

        # Advance the LR schedule after this epoch's optimizer updates.
        scheduler.step()

    _plot_training_history(train_losses, test_losses, train_accs, test_accs)
 
# Script entry point. The guard is still true inside a notebook kernel, but it
# stops child processes started with the "spawn" method (e.g. DataLoader
# workers, if DatasetLoader uses them) from re-running training when they
# re-import this module. freeze_support() only has an effect in frozen
# Windows executables and must be called directly under this guard.
if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()

Using device: cuda


  A.CoarseDropout(max_holes=1, max_height=8, max_width=8, min_height=4, min_width=4, fill_value=0, mask_fill_value=None, p=0.5),
  A.CoarseDropout(max_holes=1, max_height=8, max_width=8, min_height=4, min_width=4, fill_value=0, mask_fill_value=None, p=0.5),


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 36726588.58it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

Displaying sample training images...

Class distribution in sample batch:
plane: 12 images
car: 15 images
bird: 11 images
cat: 12 images
deer: 11 images
dog: 19 images
frog: 15 images
horse: 10 images
ship: 14 images
truck: 9 images
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             864
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 64, 16, 16]           2,048
       BatchNorm2d-5           [-1, 64, 16, 16]             128
            Conv2d-6           [-1, 64, 32, 32]           2,048
       BatchNorm2d-7           [-1, 64, 32, 32]             128
              ReLU-8           [-1, 64, 32, 32]               0
            Conv2d-9           [-1, 

Epoch 0 Train: Loss=1.5663 Acc=42.25%: 100%|██████████| 391/391 [00:04<00:00, 90.19it/s]



Test set: Average loss: 1.3152, Accuracy: 52.24%


Epoch: 1


Epoch 1 Train: Loss=1.2284 Acc=55.40%: 100%|██████████| 391/391 [00:04<00:00, 89.70it/s]



Test set: Average loss: 1.0834, Accuracy: 61.01%


Epoch: 2


Epoch 2 Train: Loss=1.0728 Acc=61.49%: 100%|██████████| 391/391 [00:04<00:00, 82.48it/s]



Test set: Average loss: 1.0382, Accuracy: 63.07%


Epoch: 3


Epoch 3 Train: Loss=0.9613 Acc=65.40%: 100%|██████████| 391/391 [00:04<00:00, 83.07it/s]



Test set: Average loss: 0.9166, Accuracy: 66.67%


Epoch: 4


Epoch 4 Train: Loss=0.8879 Acc=68.39%: 100%|██████████| 391/391 [00:05<00:00, 77.12it/s]



Test set: Average loss: 0.8258, Accuracy: 70.53%


Epoch: 5


Epoch 5 Train: Loss=0.8261 Acc=70.54%: 100%|██████████| 391/391 [00:04<00:00, 90.20it/s]



Test set: Average loss: 0.8479, Accuracy: 70.55%


Epoch: 6


Epoch 6 Train: Loss=0.7826 Acc=72.56%: 100%|██████████| 391/391 [00:04<00:00, 91.20it/s]



Test set: Average loss: 0.7877, Accuracy: 72.36%


Epoch: 7


Epoch 7 Train: Loss=0.7421 Acc=73.81%: 100%|██████████| 391/391 [00:04<00:00, 89.66it/s]



Test set: Average loss: 0.7334, Accuracy: 74.53%


Epoch: 8


Epoch 8 Train: Loss=0.7092 Acc=74.96%: 100%|██████████| 391/391 [00:04<00:00, 91.38it/s]



Test set: Average loss: 0.6616, Accuracy: 77.08%


Epoch: 9


Epoch 9 Train: Loss=0.6719 Acc=76.51%: 100%|██████████| 391/391 [00:04<00:00, 91.84it/s]



Test set: Average loss: 0.6781, Accuracy: 76.52%


Epoch: 10


Epoch 10 Train: Loss=0.6491 Acc=77.06%: 100%|██████████| 391/391 [00:04<00:00, 92.37it/s]



Test set: Average loss: 0.6426, Accuracy: 78.47%


Epoch: 11


Epoch 11 Train: Loss=0.6277 Acc=78.10%: 100%|██████████| 391/391 [00:04<00:00, 84.60it/s]



Test set: Average loss: 0.5995, Accuracy: 79.54%


Epoch: 12


Epoch 12 Train: Loss=0.6102 Acc=78.55%: 100%|██████████| 391/391 [00:04<00:00, 81.36it/s]



Test set: Average loss: 0.6102, Accuracy: 78.85%


Epoch: 13


Epoch 13 Train: Loss=0.5912 Acc=79.46%: 100%|██████████| 391/391 [00:04<00:00, 89.21it/s]



Test set: Average loss: 0.5864, Accuracy: 79.75%


Epoch: 14


Epoch 14 Train: Loss=0.5728 Acc=80.07%: 100%|██████████| 391/391 [00:04<00:00, 80.23it/s]



Test set: Average loss: 0.6093, Accuracy: 79.09%


Epoch: 15


Epoch 15 Train: Loss=0.4927 Acc=83.01%: 100%|██████████| 391/391 [00:04<00:00, 83.75it/s]



Test set: Average loss: 0.4911, Accuracy: 83.17%


Epoch: 16


Epoch 16 Train: Loss=0.4710 Acc=83.74%: 100%|██████████| 391/391 [00:04<00:00, 90.21it/s]



Test set: Average loss: 0.4881, Accuracy: 83.20%


Epoch: 17


Epoch 17 Train: Loss=0.4575 Acc=84.16%: 100%|██████████| 391/391 [00:04<00:00, 80.87it/s]



Test set: Average loss: 0.4753, Accuracy: 83.60%


Epoch: 18


Epoch 18 Train: Loss=0.4545 Acc=84.25%: 100%|██████████| 391/391 [00:04<00:00, 84.14it/s]



Test set: Average loss: 0.4728, Accuracy: 83.85%


Epoch: 19


Epoch 19 Train: Loss=0.4495 Acc=84.28%: 100%|██████████| 391/391 [00:04<00:00, 83.40it/s]



Test set: Average loss: 0.4678, Accuracy: 83.93%


Epoch: 20


Epoch 20 Train: Loss=0.4456 Acc=84.55%: 100%|██████████| 391/391 [00:04<00:00, 87.21it/s]



Test set: Average loss: 0.4699, Accuracy: 83.98%


Epoch: 21


Epoch 21 Train: Loss=0.4412 Acc=84.69%: 100%|██████████| 391/391 [00:04<00:00, 84.07it/s]



Test set: Average loss: 0.4640, Accuracy: 84.00%


Epoch: 22


Epoch 22 Train: Loss=0.4355 Acc=84.96%: 100%|██████████| 391/391 [00:04<00:00, 82.52it/s]



Test set: Average loss: 0.4651, Accuracy: 84.06%


Epoch: 23


Epoch 23 Train: Loss=0.4296 Acc=85.06%: 100%|██████████| 391/391 [00:04<00:00, 85.51it/s]



Test set: Average loss: 0.4683, Accuracy: 84.00%


Epoch: 24


Epoch 24 Train: Loss=0.4252 Acc=85.32%: 100%|██████████| 391/391 [00:04<00:00, 83.18it/s]



Test set: Average loss: 0.4626, Accuracy: 84.18%


Epoch: 25


Epoch 25 Train: Loss=0.4235 Acc=85.21%: 100%|██████████| 391/391 [00:04<00:00, 82.30it/s]



Test set: Average loss: 0.4614, Accuracy: 84.18%


Epoch: 26


Epoch 26 Train: Loss=0.4228 Acc=85.30%: 100%|██████████| 391/391 [00:04<00:00, 88.36it/s]



Test set: Average loss: 0.4649, Accuracy: 84.41%


Epoch: 27


Epoch 27 Train: Loss=0.4162 Acc=85.61%: 100%|██████████| 391/391 [00:04<00:00, 78.88it/s]



Test set: Average loss: 0.4618, Accuracy: 84.70%


Epoch: 28


Epoch 28 Train: Loss=0.4177 Acc=85.44%: 100%|██████████| 391/391 [00:04<00:00, 90.89it/s]



Test set: Average loss: 0.4574, Accuracy: 84.19%


Epoch: 29


Epoch 29 Train: Loss=0.4152 Acc=85.57%: 100%|██████████| 391/391 [00:04<00:00, 87.07it/s]



Test set: Average loss: 0.4589, Accuracy: 84.54%


Epoch: 30


Epoch 30 Train: Loss=0.3987 Acc=86.04%: 100%|██████████| 391/391 [00:04<00:00, 89.90it/s]



Test set: Average loss: 0.4521, Accuracy: 84.74%


Epoch: 31


Epoch 31 Train: Loss=0.3988 Acc=86.20%: 100%|██████████| 391/391 [00:04<00:00, 88.07it/s]



Test set: Average loss: 0.4522, Accuracy: 84.68%


Epoch: 32


Epoch 32 Train: Loss=0.3955 Acc=86.28%: 100%|██████████| 391/391 [00:04<00:00, 85.77it/s]



Test set: Average loss: 0.4502, Accuracy: 84.92%


Epoch: 33


Epoch 33 Train: Loss=0.3951 Acc=86.15%: 100%|██████████| 391/391 [00:04<00:00, 86.86it/s]



Test set: Average loss: 0.4501, Accuracy: 85.03%


Epoch: 34


Epoch 34 Train: Loss=0.3977 Acc=86.15%: 100%|██████████| 391/391 [00:04<00:00, 89.19it/s]



Test set: Average loss: 0.4481, Accuracy: 84.86%


Epoch: 35


Epoch 35 Train: Loss=0.3903 Acc=86.36%: 100%|██████████| 391/391 [00:04<00:00, 84.93it/s]



Test set: Average loss: 0.4503, Accuracy: 84.80%


Epoch: 36


Epoch 36 Train: Loss=0.3937 Acc=86.26%: 100%|██████████| 391/391 [00:04<00:00, 85.82it/s]



Test set: Average loss: 0.4508, Accuracy: 84.88%


Epoch: 37


Epoch 37 Train: Loss=0.3961 Acc=86.29%: 100%|██████████| 391/391 [00:04<00:00, 88.01it/s]



Test set: Average loss: 0.4526, Accuracy: 84.83%


Epoch: 38


Epoch 38 Train: Loss=0.3947 Acc=86.26%:  82%|████████▏ | 319/391 [00:03<00:00, 88.93it/s]