In [155]:
# Prerequisites for loading data; run once

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
import random

In [156]:
# Define transform: convert images to tensors and normalise each RGB channel
# from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

HELD_OUT_CLASS = 9  # CIFAR-10 class that is split out into its own sets ('truck')
BATCH_SIZE = 4

# Load the full CIFAR-10 dataset
full_trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True, transform=transform)
full_testset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True, transform=transform)


def _split_by_class(dataset, target_class):
    """Split ``dataset`` into ``(without_target_class, only_target_class)`` TensorDatasets.

    Iterates ``dataset`` exactly once, so its transform runs once per image
    rather than once per subset, and only the stacked tensors are kept.
    """
    images = {False: [], True: []}
    labels = {False: [], True: []}
    for img, label in dataset:
        is_target = bool(label == target_class)
        images[is_target].append(img)
        labels[is_target].append(label)

    def _as_tensor_dataset(key):
        return torch.utils.data.TensorDataset(torch.stack(images[key]), torch.tensor(labels[key]))

    return _as_tensor_dataset(False), _as_tensor_dataset(True)


# Train/test sets without HELD_OUT_CLASS and with only HELD_OUT_CLASS
trainset_no_9, trainset_only_9 = _split_by_class(full_trainset, HELD_OUT_CLASS)
testset_no_9, testset_only_9 = _split_by_class(full_testset, HELD_OUT_CLASS)

trainloader_no_9 = torch.utils.data.DataLoader(trainset_no_9, batch_size=BATCH_SIZE, shuffle=True)
trainloader_only_9 = torch.utils.data.DataLoader(trainset_only_9, batch_size=BATCH_SIZE, shuffle=True)

# Test loaders are not shuffled: accuracy does not depend on order, and a fixed
# order keeps per-batch statistics reproducible between runs.
testloader_no_9 = torch.utils.data.DataLoader(testset_no_9, batch_size=BATCH_SIZE, shuffle=False)
testloader_only_9 = torch.utils.data.DataLoader(testset_only_9, batch_size=BATCH_SIZE, shuffle=False)

# Normal trainloader (all classes)
trainloader_full = torch.utils.data.DataLoader(full_trainset, batch_size=BATCH_SIZE, shuffle=True)

# Standard test loader (all classes, fixed order)
testloader = torch.utils.data.DataLoader(full_testset, batch_size=BATCH_SIZE, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Each split must partition its source set exactly.
assert len(trainset_no_9) + len(trainset_only_9) == len(full_trainset)
assert len(testset_no_9) + len(testset_only_9) == len(full_testset)

# Print dataset sizes
print(f"Original training set size: {len(full_trainset)}")
print(f"Training set size without class 9: {len(trainset_no_9)}")
print(f"Training set size with only class 9: {len(trainset_only_9)}")


Files already downloaded and verified
Files already downloaded and verified
Original training set size: 50000
Training set size without class 9: 45000
Training set size with only class 9: 5000


In [157]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """LeNet-style CIFAR-10 classifier: two conv+pool stages followed by four linear layers.

    Attribute names (conv1, pool, conv2, fc1, fc2, fc3, fc5) must stay as they are:
    the saved ``*_full.pth`` checkpoints are presumably pickled instances of this
    class, and the stitching cell copies weights between models by state_dict key.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so parameter registration
        # (state_dict order and random initialisation order) is unchanged.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 84)
        self.fc5 = nn.Linear(84, 10)

    def forward(self, x):
        """Return 10 class logits per image (assumes 3x32x32 inputs — TODO confirm)."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        hidden = features.view(-1, 16 * 5 * 5)
        # fc1..fc3 each apply an (in-place) ReLU; fc5 yields raw logits.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden), inplace=True)
        return self.fc5(hidden)


net = Net()

In [158]:
# Load the four pre-trained models. Each .pth file holds a whole nn.Module
# object (later cells call .state_dict()/.parameters() on them and run them),
# so unpickling needs the matching `Net` class to be defined first (cell above).
# Model 1d has FC1 as the special layer.
# NOTE(review): net2f is loaded from 'model2e_full.pth' (not an 'unlucky_9'
# file) — confirm this is the intended checkpoint.
#
# weights_only=False is required for whole-module checkpoints: newer PyTorch
# versions default to weights_only=True and refuse to load these files.
# Only load files you trust — unpickling can execute arbitrary code.
net1 = torch.load('model_unlucky_9_1d_full.pth', weights_only=False)
net2c = torch.load('model_unlucky_9_2c_full.pth', weights_only=False)
net2e = torch.load('model_unlucky_9_2e_full.pth', weights_only=False)
net2f = torch.load('model2e_full.pth', weights_only=False)

  net1 = torch.load('model_unlucky_9_1d_full.pth')
  net2c = torch.load('model_unlucky_9_2c_full.pth')
  net2e = torch.load('model_unlucky_9_2e_full.pth')
  net2f = torch.load('model2e_full.pth')


In [161]:
# Choose the pair of models to stitch and compare below:
# net2 supplies the feature layers, net1 the output head.
# NOTE: this rebinds `net1`, so the model loaded from
# 'model_unlucky_9_1d_full.pth' is no longer reachable under that name.
net1, net2 = net2e, net2f

In [160]:
# Per-class accuracy of a model on whatever data `testloader` yields.
def calculate_class_accuracies(model, testloader, classes):
    """Print and return the accuracy of ``model`` for each class in ``classes``.

    Args:
        model: classifier returning one score per class; the prediction is the
            index of the highest score along dim 1.
        testloader: iterable of ``(images, labels)`` batches; labels must be
            integer class indices in ``range(len(classes))``.
        classes: display names, indexed by class id.

    Returns:
        A list with one entry per class: the accuracy in percent, or ``None``
        when the loader contained no samples of that class (e.g. class 9 when
        evaluating on a loader that excludes it). Such classes are printed as
        "n/a" instead of 0.00%, which would be indistinguishable from "all wrong".

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    num_classes = len(classes)
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():  # Don't compute gradients during evaluation
        for images, labels in testloader:
            _, predicted = torch.max(model(images), 1)
            labels, predicted = labels.cpu(), predicted.cpu()
            # bincount tallies the whole batch per class at once; minlength keeps
            # the result aligned with `classes` even if some classes are absent.
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[predicted == labels], minlength=num_classes)

    accuracies = []
    for name, correct, total in zip(classes, class_correct.tolist(), class_total.tolist()):
        if total == 0:
            accuracies.append(None)
            print(f'Accuracy for {name}: n/a (no samples)')
            continue
        accuracy = 100 * correct / total
        accuracies.append(accuracy)
        print(f'Accuracy for {name}: {accuracy:.2f}%')
    return accuracies

# Per-class accuracy of net2f on the full CIFAR-10 test set.
# NOTE: this rebinds the global `classes` (previously a tuple of short names)
# to index-labelled names such as 'airplane 0'; later cells see this version.
classes = [f'{name} {idx}' for idx, name in enumerate(
    ('airplane', 'automobile', 'bird', 'cat', 'deer',
     'dog', 'frog', 'horse', 'ship', 'truck'))]

calculate_class_accuracies(net2f, testloader, classes);

Accuracy for airplane 0: 71.30%
Accuracy for automobile 1: 70.50%
Accuracy for bird 2: 46.20%
Accuracy for cat 3: 33.60%
Accuracy for deer 4: 57.80%
Accuracy for dog 5: 50.90%
Accuracy for frog 6: 80.30%
Accuracy for horse 7: 64.80%
Accuracy for ship 8: 78.90%
Accuracy for truck 9: 74.80%


In [227]:
# Stitching network: the same architecture as `Net`, plus an extra 84 -> 84
# linear `stitch` layer (no activation) between fc3 and the fc5 output head.
# It gets its own class name so it does not shadow `Net`; the pickled
# checkpoints loaded earlier presumably resolve `Net` by name when unpickled.
class StitchNet(nn.Module):
    """CIFAR-10 classifier with a trainable stitching layer before the output head.

    Layer names match `Net` (conv1, pool, conv2, fc1, fc2, fc3, fc5) so that
    state_dict entries can be copied across models by key.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # reused after both conv stages (no parameters)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 84)
        self.stitch = nn.Linear(84, 84)
        self.fc5 = nn.Linear(84, 10)

    def forward(self, x):
        """Return 10 class logits per image.

        Assumes 3x32x32 inputs (CIFAR-10), which is what makes the flattened
        conv output 16*5*5 — TODO confirm for other inputs.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.stitch(x)  # linear map between the two models' representations
        return self.fc5(x)


net3 = StitchNet()

In [228]:
# Assemble net3 from the two chosen models:
#   conv1, conv2, fc1, fc2, fc3  <- net2 (feature layers)
#   stitch                       <- left at net3's fresh random initialisation
#   fc5                          <- net1 (output head)
# Keys are selected by layer-name prefix rather than by position in the
# state_dict, so the split cannot silently shift if layers are added/reordered.
FROM_NET2_PREFIXES = ('conv1.', 'conv2.', 'fc1.', 'fc2.', 'fc3.')
FROM_NET1_PREFIXES = ('fc5.',)

state_dict_1 = net1.state_dict()
state_dict_2 = net2.state_dict()
state_dict_3 = net3.state_dict()

for key in list(state_dict_3):
    if key.startswith(FROM_NET2_PREFIXES):
        state_dict_3[key] = state_dict_2[key]
    elif key.startswith(FROM_NET1_PREFIXES):
        state_dict_3[key] = state_dict_1[key]

# Load the combined weights into the stitched model
net3.load_state_dict(state_dict_3)
# `net` is the model the optimizer and training cells below operate on.
net = net3

In [229]:
# Quick sanity check: run the stitched model on the first test batch
# (`testloader` is not shuffled, so this is always the same 4 images).
net3.eval()

images, labels = next(iter(testloader))

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    outputs = net3(images)
_, predicted = torch.max(outputs, 1)

print('Predicted:', predicted)

Predicted: tensor([4, 2, 4, 4])


In [230]:
import torch.optim as optim

# Cross-entropy on raw logits, optimised with momentum SGD and a small learning
# rate. The optimizer is built over all of `net`'s parameters; the training cell
# later freezes everything but the stitch layer, so the frozen ones receive no
# gradients to update with.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=net.parameters(),
    lr=1e-4,
    momentum=0.9,
)

In [231]:
    # Train only the stitching layer of `net` (the stitched net3 from above);
    # every other parameter is frozen. Uses `optimizer` and `criterion` from the
    # previous cell.
    epochs = 2
    log_every = 2000  # mini-batches between loss printouts

    # Freeze everything except the stitching layer.
    for name, param in net.named_parameters():
        param.requires_grad = name.startswith('stitch.')

    # The inference cell above put the model in eval mode; switch back for training.
    net.train()

    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader_full):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print the mean loss over the last `log_every` mini-batches
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
    print('Finished Training.')

    # NOTE: the evaluate_model_* helpers are defined in cells further down this
    # notebook; they must have been run first (Run All in document order fails here).
    evaluate_model_full(net, testloader)
    evaluate_model_only_9(net, testloader_only_9)
    evaluate_model_no_9(net, testloader_no_9);

[1,  2000] loss: 1.322
[1,  4000] loss: 0.989
[1,  6000] loss: 0.933
[1,  8000] loss: 0.916
[1, 10000] loss: 0.874
[1, 12000] loss: 0.872
[2,  2000] loss: 0.867
[2,  4000] loss: 0.882
[2,  6000] loss: 0.880
[2,  8000] loss: 0.853
[2, 10000] loss: 0.864
[2, 12000] loss: 0.861
Finished Training.
Model accuracy on 10000 test images: 63.68%
Model accuracy on test set with only class 9: 66.70%
Model accuracy on test set without class 9: 63.34%


In [71]:
import numpy as np
import torch
import gzip
import pickle

def compress_model_numpy(model):
    """Return the gzip-compressed size, in bits, of the model's parameters.

    All parameters are flattened (in ``model.parameters()`` order) into one
    NumPy array, pickled, and gzip-compressed; the compressed byte count is
    used as a rough description-length / complexity proxy.
    """
    flat_chunks = []
    for param in model.parameters():
        flat_chunks.append(param.detach().cpu().numpy().ravel())
    serialized = pickle.dumps(np.concatenate(flat_chunks))
    return 8 * len(gzip.compress(serialized))

# Compare all three models. net3 has the extra 84x84 `stitch` layer
# (net1/net2 presumably do not), so it has more parameters to compress.
complexity_1, complexity_2, complexity_3 = (
    compress_model_numpy(model) for model in (net1, net2, net3)
)

for idx, bits in enumerate((complexity_1, complexity_2, complexity_3), start=1):
    print(f"Model {idx} compressed size: {bits} bits")


Model 1 compressed size: 2052000 bits
Model 2 compressed size: 2056048 bits
Model 3 compressed size: 2267600 bits


In [82]:
def compute_gradient_norm(model, loss_fn, testloader, include_frozen=True):
    """Computes the average gradient norm over the test set.

    For every batch, the L2 norm of d(loss)/d(parameters) is taken over all the
    model's parameters jointly; the result is the mean of those per-batch norms.

    Args:
        model: the network to probe (it is put in eval mode).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        testloader: iterable of ``(data, target)`` batches.
        include_frozen: if True (default), parameters with ``requires_grad=False``
            are temporarily enabled so every parameter contributes. This keeps norms
            comparable between models with different freezing, e.g. net3, whose
            non-stitch layers the training cell froze. The original flags are
            restored afterwards. If False, only currently trainable parameters count.

    Returns:
        The mean per-batch gradient norm as a float. Returns 0.0 if there are no
        batches or no parameters to differentiate.

    Gradients come from ``torch.autograd.grad`` rather than ``loss.backward()``,
    so nothing is accumulated into ``param.grad`` on the model.
    """
    model.eval()
    params = list(model.parameters())
    if not params:
        return 0.0
    device = params[0].device  # hoisted: the same for every batch

    original_flags = [p.requires_grad for p in params]
    try:
        if include_frozen:
            for p in params:
                p.requires_grad_(True)
        tracked = [p for p in params if p.requires_grad]
        if not tracked:
            return 0.0

        total_norm = 0.0
        num_batches = 0
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            loss = loss_fn(model(data), target)
            # allow_unused: parameters that do not influence the loss yield None.
            grads = torch.autograd.grad(loss, tracked, allow_unused=True)
            squared = [g.norm() ** 2 for g in grads if g is not None]
            batch_norm = torch.sqrt(torch.stack(squared).sum()).item() if squared else 0.0
            total_norm += batch_norm
            num_batches += 1
    finally:
        # Restore the caller's freezing configuration even if a batch fails.
        for p, flag in zip(params, original_flags):
            p.requires_grad_(flag)

    return total_norm / num_batches if num_batches > 0 else 0.0

In [90]:
# Average per-batch gradient norm of net1 on the full (unshuffled) test set.
compute_gradient_norm(model=net1, loss_fn=criterion, testloader=testloader)

4.438407219076157

In [91]:
# Average per-batch gradient norm of net2 on the full (unshuffled) test set.
compute_gradient_norm(model=net2, loss_fn=criterion, testloader=testloader)

10.127162860943633

In [92]:
# Average per-batch gradient norm of the stitched net3 on the full test set.
# Note: the training cell set requires_grad=False on all but net3's stitch layer.
compute_gradient_norm(model=net3, loss_fn=criterion, testloader=testloader)

2.589355911933258

In [19]:
def evaluate_model_full(model, testloader):
    """Print and return ``model``'s overall top-1 accuracy (in percent) on ``testloader``.

    Args:
        model: classifier returning one score per class along dim 1.
        testloader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in percent as a float.

    Raises:
        ValueError: if ``testloader`` yields no images.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    total_correct = 0
    total_images = 0

    with torch.no_grad():
        for images, labels in testloader:
            _, predicted = torch.max(model(images), 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_images == 0:
        raise ValueError('testloader yielded no images')

    model_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on {total_images} test images: {model_accuracy:.2f}%')
    return model_accuracy

In [20]:
def evaluate_model_no_9(model, testloader_no_9):
    """Print and return ``model``'s top-1 accuracy (in percent) on the test set without class 9.

    Args:
        model: classifier returning one score per class along dim 1.
        testloader_no_9: iterable of ``(images, labels)`` batches; intended to be
            the loader that excludes class 9 (this function does not filter).

    Returns:
        Accuracy in percent as a float.

    Raises:
        ValueError: if ``testloader_no_9`` yields no images.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    total_correct = 0
    total_images = 0

    with torch.no_grad():
        for images, labels in testloader_no_9:
            _, predicted = torch.max(model(images), 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_images == 0:
        raise ValueError('testloader_no_9 yielded no images')

    no_9_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on test set without class 9: {no_9_accuracy:.2f}%')
    return no_9_accuracy

In [21]:
def evaluate_model_only_9(model, testloader_only_9):
    """Print and return ``model``'s top-1 accuracy (in percent) on the class-9-only test set.

    Args:
        model: classifier returning one score per class along dim 1.
        testloader_only_9: iterable of ``(images, labels)`` batches; intended to be
            the loader containing only class 9 (this function does not filter).

    Returns:
        Accuracy in percent as a float.

    Raises:
        ValueError: if ``testloader_only_9`` yields no images.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    total_correct = 0
    total_images = 0

    with torch.no_grad():
        for images, labels in testloader_only_9:
            _, predicted = torch.max(model(images), 1)
            total_images += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_images == 0:
        raise ValueError('testloader_only_9 yielded no images')

    only_9_accuracy = total_correct / total_images * 100
    print(f'Model accuracy on test set with only class 9: {only_9_accuracy:.2f}%')
    return only_9_accuracy