In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the CNN architecture
class SimpleCNN(nn.Module):
    """Small convolutional classifier for CIFAR-10-sized RGB images.

    Architecture: two conv blocks (3x3 conv -> ReLU -> 2x2 max-pool), an
    adaptive average pool to a fixed 8x8 grid, then a two-layer fully
    connected head that emits one raw logit per class (no softmax; pair
    with ``nn.CrossEntropyLoss``).

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
        in_channels: Number of input image channels. Defaults to 3 (RGB).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 32x32 -> 16x16 for CIFAR-10 input
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16 -> 8x8 for CIFAR-10 input
            # Identity for 32x32 inputs; lets other input sizes still produce
            # the 64*8*8 features the classifier expects.
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a (batch, in_channels, H, W) image tensor to (batch, num_classes) logits."""
        return self.classifier(self.features(x))

In [None]:
# Preprocessing applied to every CIFAR-10 image: convert to a float tensor in
# [0, 1], then map each channel to roughly [-1, 1] via (x - 0.5) / 0.5.
# The single-element mean/std is broadcast across all channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# Load the CIFAR-10 train/test splits under ./data (downloaded on first run),
# applying the same `transform` to both.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

In [None]:
import matplotlib.pyplot as plt
# Show one raw training image (read from `.data`, so `transform` is not applied)
# titled with its class name so the reader knows what they are looking at.
sample_idx = 5
plt.imshow(train_dataset.data[sample_idx])
plt.title(train_dataset.classes[train_dataset.targets[sample_idx]])
plt.axis("off")
plt.show()

In [None]:
# Shuffled mini-batches of 64 for training; large ordered batches for evaluation.
loader_workers = 2
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=loader_workers,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False,
    num_workers=loader_workers,
)


In [None]:
# Build the network with freshly initialised weights (trained further below).
model: nn.Module = SimpleCNN()

In [None]:
# Loss function and optimizer: cross-entropy on raw logits, SGD with momentum.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

In [None]:
# Use the first CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): the optimizer is created in an earlier cell, before this move.
# Module.to() converts parameters in place, so that works. PyTorch's docs still
# recommend moving the model before constructing the optimizer; consider
# reordering these cells.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

In [None]:
# Train for a fixed number of epochs. Progress is logged as the mean loss over
# each window of `log_every` mini-batches, plus the mean loss of the whole epoch.
epochs = 100
log_every = 100  # mini-batches per progress line

for epoch in range(epochs):
    # Set training mode at the start of every epoch. This is cheap, and it keeps
    # the mode correct if a per-epoch evaluation (which calls model.eval()) is added.
    model.train()
    running_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if i % log_every == log_every - 1:
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}")
            running_loss = 0.0

    # The windowed log above never reports trailing batches that don't fill a
    # whole window; the epoch mean covers every batch.
    if num_batches:
        print(f"[{epoch + 1}] epoch mean loss: {epoch_loss / num_batches:.3f}")



In [None]:
# Optional: skip the training cell and restore previously saved weights instead.
# model.load_state_dict(torch.load('simple_cnn_cpu.pth'))

In [None]:
import matplotlib.pyplot as plt

# Relies on `model` and `test_loader` from earlier cells; `model` should already hold trained weights (from the training cell or a restored checkpoint).

# Function to evaluate the model
def evaluate_model(model, device, test_loader):
    """Compute the average cross-entropy loss and top-1 accuracy of ``model``.

    Args:
        model: Classifier whose forward pass returns raw logits; presumably
            shaped (batch, num_classes). This is not checked.
        device: Device that each batch is moved to. It should match the device
            holding ``model``'s parameters.
        test_loader: Iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.

    Returns:
        ``(test_loss, accuracy)``. ``test_loss`` is the mean per-sample
        cross-entropy over every evaluated sample. ``accuracy`` is the
        percentage (0-100) of samples whose arg-max logit equals the target.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()  # switch layers such as dropout/batch-norm to inference behaviour
    # reduction='sum' makes per-batch losses add up to a true dataset total.
    # The default 'mean' would make the final division average batch means again.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            test_loss += criterion(outputs, target).item()
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    # Divide by the samples actually seen (not len(dataset)). This stays correct
    # with drop_last or with a sampler over a subset.
    test_loss /= total
    accuracy = 100. * correct / total
    return test_loss, accuracy

In [None]:
# Score the network on the held-out test split and report both metrics.
eval_results = evaluate_model(model, device, test_loader)
test_loss, accuracy = eval_results
print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Relies on `model`, `device`, and `test_loader` from earlier cells; `model` should already hold trained weights (from the training cell or a restored checkpoint).

def get_all_predictions(model, device, loader):
    """Run ``model`` over every batch of ``loader`` and collect the results.

    Args:
        model: Classifier returning per-class scores; the arg-max over dim 1
            is taken as the prediction.
        device: Device each input (and target) batch is moved to.
        loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(all_preds, all_labels)``: two flat Python lists holding the
        predicted class index and the true label of every sample, as NumPy
        scalars, in loader order.
    """
    predicted, actual = [], []
    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            logits = model(batch_inputs.to(device))
            batch_preds = torch.max(logits, 1)[1]
            predicted += list(batch_preds.cpu().numpy())
            actual += list(batch_targets.to(device).cpu().numpy())
    return predicted, actual

# Get all predictions and labels
predictions, labels = get_all_predictions(model, device, test_loader)

# Fix the label set so the matrix is always n_classes x n_classes, even if some
# class never appears in `labels` or `predictions`.
class_names = test_dataset.classes
class_ids = list(range(len(class_names)))
cm = confusion_matrix(labels, predictions, labels=class_ids)

# Row-normalise: each row is one true class and sums to 1. A row with no true
# samples stays 0 instead of becoming NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype(float),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Plotting
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt=".2f",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_title("Normalized Confusion Matrix")
ax.set_ylabel("True Label")
ax.set_xlabel("Predicted Label")
plt.show()

