In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms,models
from torch.utils.data import DataLoader, Subset
from collections import OrderedDict
from copy import deepcopy
from matplotlib import pyplot as plt
from PIL import Image
import numpy as n
import torch.optim as optim
import random
from tqdm import tqdm
import seaborn as sns
import os

In [2]:
# Preprocessing applied to every image: convert to a tensor, then shift and
# scale each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

# CIFAR-10 training split, downloaded into ./data when it is not already there.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Function to filter dataset
def get_filtered_dataset(dataset, excluded_class):
    """Return a ``Subset`` of `dataset` without the samples labelled `excluded_class`.

    Args:
        dataset: Indexable dataset yielding ``(data, label)`` pairs. When it
            exposes a ``targets`` sequence and has no ``target_transform``,
            the labels are read from ``targets`` directly.
        excluded_class: Label whose samples are dropped.

    Returns:
        torch.utils.data.Subset: View over `dataset` restricted to the indices
        whose label differs from `excluded_class`.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Fast path: reading the stored labels avoids loading and transforming
        # every sample just to throw the data away.
        labels = targets
    else:
        # No raw label list (or labels are transformed on access): fall back to
        # reading each label through the dataset's own item access.
        labels = (label for _, label in dataset)

    indices = [i for i, label in enumerate(labels) if label != excluded_class]
    return Subset(dataset, indices)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 87973466.91it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
from torch.utils.data import Dataset

class RemappedDataset(Dataset):
    """Wrap a dataset and renumber its labels to a contiguous ``0..K-1`` range.

    The distinct labels found in `subset` (minus `excluded_class`) are sorted
    and each one is mapped to its rank, e.g. labels ``{1, 2, 3}`` become
    ``{0, 1, 2}``.

    Args:
        subset: Dataset yielding ``(data, label)`` pairs. If it looks like a
            ``torch.utils.data.Subset`` (has ``dataset`` and ``indices``) whose
            base dataset exposes ``targets`` and has no ``target_transform``,
            labels are read from ``targets`` without loading any sample.
        excluded_class: Label left out of the mapping. Samples carrying this
            label are expected to have been removed from `subset` already;
            indexing one raises ``KeyError``.
    """

    def __init__(self, subset, excluded_class):
        self.subset = subset
        self.excluded_class = excluded_class
        self.label_map = self._create_label_map()

    def _read_labels(self):
        """Return the label of every item in ``self.subset``, in order."""
        base = getattr(self.subset, "dataset", None)
        indices = getattr(self.subset, "indices", None)
        targets = getattr(base, "targets", None)
        if (
            indices is not None
            and targets is not None
            and getattr(base, "target_transform", None) is None
        ):
            # Fast path: look labels up in the base dataset's stored targets
            # instead of loading and transforming every sample.
            picked = (targets[i] for i in indices)
            # Tensor / NumPy scalars are unwrapped so they hash like plain
            # Python numbers when used as dictionary keys.
            return [t.item() if hasattr(t, "item") else t for t in picked]
        return [label for _, label in self.subset]

    def _create_label_map(self):
        """Create a mapping from original labels to new labels."""
        unique_labels = sorted(set(self._read_labels()) - {self.excluded_class})
        return {original: new for new, original in enumerate(unique_labels)}

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        data, label = self.subset[idx]
        return data, self.label_map[label]

# Update filtered dataset creation
def get_filtered_and_remapped_dataset(dataset, excluded_class):
    """Drop `excluded_class` from `dataset` and renumber the remaining labels.

    Args:
        dataset: Dataset yielding ``(data, label)`` pairs.
        excluded_class: Label whose samples are removed.

    Returns:
        RemappedDataset: Wrapper around the filtered subset whose labels are
        remapped to a contiguous range starting at 0.
    """
    # Reuse the filtering helper rather than repeating its index scan here.
    filtered_subset = get_filtered_dataset(dataset, excluded_class)
    return RemappedDataset(filtered_subset, excluded_class)



In [12]:
def train_resnet(model, train_loader, epochs=10, device="cpu"):
    """Train `model` in place with Adam (lr=0.001) and cross-entropy loss.

    After each epoch one line is printed with the epoch number and the sum of
    the per-batch losses seen during that epoch (a sum, not a mean).

    Args:
        model: Network to optimise; it is moved to `device` and left in
            training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over `train_loader`.
        device: Device the model and every batch are moved to.

    Returns:
        The same `model` object that was passed in.
    """
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=0.001)
    model.train()

    for epoch_no in range(1, epochs + 1):
        epoch_loss = 0.0
        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            adam.zero_grad()
            batch_loss = loss_fn(model(batch_inputs), batch_labels)
            batch_loss.backward()
            adam.step()

            # Accumulate the scalar loss of this batch into the epoch total.
            epoch_loss += batch_loss.item()
        print(f"Epoch {epoch_no}/{epochs}, Loss: {epoch_loss:.4f}")

    return model

# Train two ResNet-18 models: one on all ten CIFAR-10 classes and one on the
# nine classes that remain after class 0 is removed.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed for PyTorch's global RNG. It is reset before each run so that weight
# initialisation and DataLoader shuffling are repeatable across kernel restarts.
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
SEED = 0

# ResNet18 on all classes
torch.manual_seed(SEED)
model_all = models.resnet18(num_classes=10)  # Adjust for CIFAR-10
train_loader_all = DataLoader(train_dataset, batch_size=64, shuffle=True)
trained_model_all = train_resnet(model_all, train_loader_all, device=device)

# ResNet18 on the dataset with class 0 removed and labels remapped to 0..8
torch.manual_seed(SEED)
filtered_dataset = get_filtered_and_remapped_dataset(train_dataset, excluded_class=0)
train_loader_filtered = DataLoader(filtered_dataset, batch_size=64, shuffle=True)
model_filtered = models.resnet18(num_classes=9)  # Adjust for 9 classes
trained_model_filtered = train_resnet(model_filtered, train_loader_filtered, device=device)



Epoch 1/10, Loss: 1057.3900
Epoch 2/10, Loss: 755.3089
Epoch 3/10, Loss: 629.5726
Epoch 4/10, Loss: 532.3601
Epoch 5/10, Loss: 452.4486
Epoch 6/10, Loss: 384.9198
Epoch 7/10, Loss: 316.1419
Epoch 8/10, Loss: 256.9689
Epoch 9/10, Loss: 207.9186
Epoch 10/10, Loss: 166.8531
Epoch 1/10, Loss: 928.6663
Epoch 2/10, Loss: 678.3267
Epoch 3/10, Loss: 561.7547
Epoch 4/10, Loss: 481.9064
Epoch 5/10, Loss: 415.5102
Epoch 6/10, Loss: 347.0804
Epoch 7/10, Loss: 294.4699
Epoch 8/10, Loss: 244.5737
Epoch 9/10, Loss: 207.9106
Epoch 10/10, Loss: 156.4852


In [14]:
# Collect the low-dimensional weight parameters of a model
def get_activations(model):
    """Return NumPy copies of a model's 1-D and 2-D ``weight`` parameters.

    Despite the name, this reads learned parameters, not activations. A
    parameter is kept when its name contains ``'weight'`` and it has at most
    two dimensions; anything with more dimensions is skipped.

    Args:
        model: Module whose ``named_parameters()`` are scanned.

    Returns:
        list: One detached CPU NumPy array per selected parameter, in
        ``named_parameters()`` order.
    """
    return [
        tensor.detach().clone().cpu().numpy()
        for param_name, tensor in model.named_parameters()
        if len(tensor.shape) <= 2 and 'weight' in param_name
    ]

weights_all = get_activations(trained_model_all)
weights_filtered = get_activations(trained_model_filtered)

# The two lists are kept separate rather than stacked: they hold NumPy arrays,
# and the last entries differ in shape between the two models.

# Print the shape of every collected array, first for the model trained on all
# classes and then for the model trained on the filtered dataset.
for collected in (weights_all, weights_filtered):
    for i, weights in enumerate(collected):
        print(f"Layer {i + 1}, Stacked Weight Shape: {weights.shape}")

Layer 1, Stacked Weight Shape: (64,)
Layer 2, Stacked Weight Shape: (64,)
Layer 3, Stacked Weight Shape: (64,)
Layer 4, Stacked Weight Shape: (64,)
Layer 5, Stacked Weight Shape: (64,)
Layer 6, Stacked Weight Shape: (128,)
Layer 7, Stacked Weight Shape: (128,)
Layer 8, Stacked Weight Shape: (128,)
Layer 9, Stacked Weight Shape: (128,)
Layer 10, Stacked Weight Shape: (128,)
Layer 11, Stacked Weight Shape: (256,)
Layer 12, Stacked Weight Shape: (256,)
Layer 13, Stacked Weight Shape: (256,)
Layer 14, Stacked Weight Shape: (256,)
Layer 15, Stacked Weight Shape: (256,)
Layer 16, Stacked Weight Shape: (512,)
Layer 17, Stacked Weight Shape: (512,)
Layer 18, Stacked Weight Shape: (512,)
Layer 19, Stacked Weight Shape: (512,)
Layer 20, Stacked Weight Shape: (512,)
Layer 21, Stacked Weight Shape: (10, 512)
Layer 1, Stacked Weight Shape: (64,)
Layer 2, Stacked Weight Shape: (64,)
Layer 3, Stacked Weight Shape: (64,)
Layer 4, Stacked Weight Shape: (64,)
Layer 5, Stacked Weight Shape: (64,)
Layer 6

In [15]:
# Inspect entry 20 of the arrays collected from the all-class model: its
# values first, then its shape.
head_weights_all = weights_all[20]
print(head_weights_all)
print(head_weights_all.shape)

[[ 0.03725792  0.05180766 -0.00685087 ...  0.06033065 -0.01167728
   0.05603654]
 [-0.01203199  0.08099366  0.07495472 ... -0.00507944 -0.04612529
   0.02951003]
 [-0.02401016 -0.07304101 -0.01087607 ... -0.10882819  0.06561099
   0.01976266]
 ...
 [-0.01675912 -0.10471803  0.00036213 ... -0.01323636 -0.01850379
  -0.11095271]
 [-0.06432104  0.07729784 -0.00498304 ...  0.03404438  0.00969019
  -0.0333106 ]
 [ 0.04058057  0.07919856  0.03500961 ...  0.04702904 -0.04107332
  -0.04466205]]
(10, 512)


In [16]:
# Inspect entry 20 of the arrays collected from the filtered model: print its
# values and leave its shape as the cell's displayed result.
head_weights_filtered = weights_filtered[20]
print(head_weights_filtered)
head_weights_filtered.shape

[[-0.00465177  0.00755466 -0.07922242 ...  0.02758088 -0.00133221
   0.01430744]
 [-0.03925995  0.01276677 -0.16534159 ... -0.01652874 -0.04668239
  -0.01674438]
 [-0.02084461  0.01710148  0.07989341 ... -0.03581994 -0.05898105
  -0.00789738]
 ...
 [-0.0325629  -0.00528347  0.01940736 ... -0.00742166  0.01343936
  -0.01556363]
 [ 0.02454785 -0.0184033  -0.02665224 ...  0.04552277  0.03269222
   0.02953606]
 [-0.00875199  0.00861988 -0.00737541 ...  0.04830784  0.01780627
   0.01181417]]


(9, 512)

In [19]:

# Compare entry 20 of both models. Row 0 of the all-class array is dropped
# because class 0 was excluded from the filtered run, so the remaining rows
# line up with the filtered model's remapped labels.
X = torch.tensor(weights_all[20])[1:]
Y = torch.tensor(weights_filtered[20])

# Largest absolute element-wise difference between the aligned arrays.
max_diff = (X - Y).abs().max()
print(max_diff)

tensor(0.3266)
