In [55]:
import torch
import torchvision
import torchvision.transforms as transforms

In [56]:
# Define transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the full CIFAR-10 dataset
full_trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True, transform=transform)
full_testset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True, transform=transform)

# Filter dataset: Exclude class 9
filtered_train_data_no_9 = [(img, label) for img, label in full_trainset if label != 9]
filtered_train_images_no_9, filtered_train_labels_no_9 = zip(*filtered_train_data_no_9)

# Filter dataset: Only class 9
filtered_train_data_only_9 = [(img, label) for img, label in full_trainset if label == 9]
filtered_train_images_only_9, filtered_train_labels_only_9 = zip(*filtered_train_data_only_9)

# Create dataset without class 9
trainset_no_9 = torch.utils.data.TensorDataset(
    torch.stack(filtered_train_images_no_9), torch.tensor(filtered_train_labels_no_9)
)
trainloader_no_9 = torch.utils.data.DataLoader(trainset_no_9, batch_size=4, shuffle=True)

# Create dataset with only class 9
trainset_only_9 = torch.utils.data.TensorDataset(
    torch.stack(filtered_train_images_only_9), torch.tensor(filtered_train_labels_only_9)
)
trainloader_only_9 = torch.utils.data.DataLoader(trainset_only_9, batch_size=4, shuffle=True)

# Filter testset: Exclude class 9
filtered_test_data_no_9 = [(img, label) for img, label in full_testset if label != 9]
filtered_test_images_no_9, filtered_test_labels_no_9 = zip(*filtered_test_data_no_9)

# Filter dataset: Only class 9
filtered_test_data_only_9 = [(img, label) for img, label in full_testset if label == 9]
filtered_test_images_only_9, filtered_test_labels_only_9 = zip(*filtered_test_data_only_9)

# Create test dataset without class 9
testset_no_9 = torch.utils.data.TensorDataset(
    torch.stack(filtered_test_images_no_9), torch.tensor(filtered_test_labels_no_9)
)
testloader_no_9 = torch.utils.data.DataLoader(testset_no_9, batch_size=4, shuffle=True)

# Create test dataset with only class 9
testset_only_9 = torch.utils.data.TensorDataset(
    torch.stack(filtered_test_images_only_9), torch.tensor(filtered_test_labels_only_9)
)
testloader_only_9 = torch.utils.data.DataLoader(testset_only_9, batch_size=4, shuffle=True)


# Normal trainloader (all classes)
trainloader_full = torch.utils.data.DataLoader(full_trainset, batch_size=4, shuffle=True)

# Normal testloader (unchanged)
testloader = torch.utils.data.DataLoader(full_testset, batch_size=4, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# Print dataset sizes
print(f"Original training set size: {len(full_trainset)}")
print(f"Training set size without class 9: {len(trainset_no_9)}")
print(f"Training set size with only class 9: {len(trainset_only_9)}")


Files already downloaded and verified
Files already downloaded and verified
Original training set size: 50000
Training set size without class 9: 45000
Training set size with only class 9: 5000


In [57]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 84)
        #self.fc4 = nn.Linear(84, 84)
        self.fc5 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        #x = F.relu(self.fc4(x))
        x = self.fc5(x)
        
        return x
    
net = Net()

In [64]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [65]:
import os

# Define paths
model_directory_path = 'model/'
model_path = os.path.join(model_directory_path, 'cifar-mnist-cnn-model.pt')

# Ensure directory exists
if not os.path.exists(model_directory_path):
    os.makedirs(model_directory_path)

def load_model(net):
    """Loads the trained model if it exists."""
    if os.path.isfile(model_path):
        net.load_state_dict(torch.load(model_path))
        print('Loaded model parameters from disk.')
        return True  # Indicate successful loading
    else:
        print('No saved model found.')
        return False

def train_model(net, trainloader, criterion, optimizer, epochs=2):
    """Trains the model from scratch."""
    print("Starting training...")
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:  # Print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
                
    print('Finished Training.')

In [66]:
def layer_training(net, trainloader, criterion, optimizer, epochs=2):

    # Freeze all layers initially
    for param in net.parameters():
        param.requires_grad = False
    
    #Unfreeze single layer
    for param in net.fc1.parameters():
        param.requires_grad = True

    # Start training
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if trainloader == trainloader_only_9:
                if i % 200 == 199:  # Print every 200 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0
            elif trainloader == trainloader_full:
                if i % 2000 == 1999:  # Print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0
    print('Finished Training.')

In [67]:
train_model(net, trainloader_no_9, criterion, optimizer, 3)

Starting training...
[1,  2000] loss: 2.158
[1,  4000] loss: 1.847
[1,  6000] loss: 1.648
[1,  8000] loss: 1.574
[1, 10000] loss: 1.482
[2,  2000] loss: 1.400
[2,  4000] loss: 1.362
[2,  6000] loss: 1.342
[2,  8000] loss: 1.312
[2, 10000] loss: 1.293
[3,  2000] loss: 1.202
[3,  4000] loss: 1.213
[3,  6000] loss: 1.206
[3,  8000] loss: 1.195
[3, 10000] loss: 1.173
Finished Training.


In [52]:
layer_training(net, trainloader_only_9, criterion, optimizer, 1)

[1,   200] loss: 0.162
[1,   400] loss: 0.152
[1,   600] loss: 0.157
[1,   800] loss: 0.152
[1,  1000] loss: 0.145
[1,  1200] loss: 0.141
Finished Training.


In [50]:
layer_training(net, trainloader_full, criterion, optimizer, 3)

[1,  2000] loss: 1.172
[1,  4000] loss: 1.161
[1,  6000] loss: 1.181
[1,  8000] loss: 1.177
[1, 10000] loss: 1.181
[1, 12000] loss: 1.170
[2,  2000] loss: 1.169
[2,  4000] loss: 1.180
[2,  6000] loss: 1.174
[2,  8000] loss: 1.162
[2, 10000] loss: 1.181
[2, 12000] loss: 1.169
[3,  2000] loss: 1.177
[3,  4000] loss: 1.182
[3,  6000] loss: 1.161
[3,  8000] loss: 1.165
[3, 10000] loss: 1.186
[3, 12000] loss: 1.171
Finished Training.


In [68]:
evaluate_model_full(net, testloader)
evaluate_model_only_9(net, testloader_only_9)
evaluate_model_no_9(net, testloader_no_9)

Model accuracy on 10000 test images: 52.84%
Model accuracy on test set with only class 9: 0.00%
Model accuracy on test set without class 9: 58.71%


In [21]:
import numpy as np
def evaluate_model_only_9(model, testloader_only_9):
    model.eval()
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([10, 10], int)
    with torch.no_grad():
        for data in testloader_only_9:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                confusion_matrix[l.item(), predicted[i].item()] += 1 

    only_9_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on test set with only class 9: {only_9_accuracy:.2f}%')
    
def evaluate_model_no_9(model, testloader_no_9):
    model.eval() 
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([10, 10], int)
    with torch.no_grad():
        for data in testloader_no_9:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                confusion_matrix[l.item(), predicted[i].item()] += 1 

    no_9_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on test set without class 9: {no_9_accuracy:.2f}%')
    
def evaluate_model_full(model, testloader):
    model.eval() 
    total_correct = 0
    total_images = 0
    confusion_matrix = np.zeros([10, 10], int)

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            for i, l in enumerate(labels):
                confusion_matrix[l.item(), predicted[i].item()] += 1 

    model_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on {total_images} test images: {model_accuracy:.2f}%')

In [69]:
model = net
torch.save(model, 'model_unlucky_9_2e_full.pth')