<a href="https://www.kaggle.com/code/shreyasudaya/math-unlearning?scriptVersionId=212908147" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms,models
from torch.utils.data import DataLoader, Subset
from collections import OrderedDict
from copy import deepcopy
from matplotlib import pyplot as plt
from PIL import Image
import numpy as n
import torch.optim as optim
import random
from tqdm import tqdm
import seaborn as sns
import os

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Function to filter dataset
def get_filtered_dataset(dataset, excluded_class):
    indices = [i for i, (_, label) in enumerate(dataset) if label != excluded_class]
    return Subset(dataset, indices)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 87550728.45it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
from torch.utils.data import Dataset

class RemappedDataset(Dataset):
    def __init__(self, subset, excluded_class):
        self.subset = subset
        self.excluded_class = excluded_class
        self.label_map = self._create_label_map()

    def _create_label_map(self):
        """Create a mapping from original labels to new labels."""
        labels = [label for _, label in self.subset]
        unique_labels = sorted(set(labels) - {self.excluded_class})
        return {original: new for new, original in enumerate(unique_labels)}

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        data, label = self.subset[idx]
        return data, self.label_map[label]

# Update filtered dataset creation
def get_filtered_and_remapped_dataset(dataset, excluded_class):
    indices = [i for i, (_, label) in enumerate(dataset) if label != excluded_class]
    filtered_subset = Subset(dataset, indices)
    return RemappedDataset(filtered_subset, excluded_class)



In [4]:
def train_resnet(model, train_loader, epochs=100, device="cuda"):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}")
    
    return model

# Extract and stack weight matrices

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet18 on all classes
model_all = models.resnet18(num_classes=10)  # Adjust for CIFAR-10
train_loader_all = DataLoader(train_dataset, batch_size=64, shuffle=True)
trained_model_all = train_resnet(model_all, train_loader_all, device=device)

filtered_dataset = get_filtered_and_remapped_dataset(train_dataset, excluded_class=0)
train_loader_filtered = DataLoader(filtered_dataset, batch_size=64, shuffle=True)
model_filtered = models.resnet18(num_classes=9)  # Adjust for 9 classes
trained_model_filtered = train_resnet(model_filtered, train_loader_filtered, device=device)



Epoch 1/100, Loss: 1054.9098
Epoch 2/100, Loss: 748.8985
Epoch 3/100, Loss: 611.0690
Epoch 4/100, Loss: 524.7680
Epoch 5/100, Loss: 440.2575
Epoch 6/100, Loss: 371.2295
Epoch 7/100, Loss: 304.2645
Epoch 8/100, Loss: 246.3009
Epoch 9/100, Loss: 199.9987
Epoch 10/100, Loss: 160.2777
Epoch 11/100, Loss: 135.5484
Epoch 12/100, Loss: 112.2246
Epoch 13/100, Loss: 107.2774
Epoch 14/100, Loss: 85.7659
Epoch 15/100, Loss: 80.9335
Epoch 16/100, Loss: 73.4991
Epoch 17/100, Loss: 69.6960
Epoch 18/100, Loss: 59.3638
Epoch 19/100, Loss: 61.1398
Epoch 20/100, Loss: 56.9547
Epoch 21/100, Loss: 50.7991
Epoch 22/100, Loss: 53.8752
Epoch 23/100, Loss: 44.3015
Epoch 24/100, Loss: 42.1641
Epoch 25/100, Loss: 41.2398
Epoch 26/100, Loss: 43.4610
Epoch 27/100, Loss: 42.1795
Epoch 28/100, Loss: 35.6709
Epoch 29/100, Loss: 35.1956
Epoch 30/100, Loss: 38.5800
Epoch 31/100, Loss: 32.6078
Epoch 32/100, Loss: 31.1828
Epoch 33/100, Loss: 32.3156
Epoch 34/100, Loss: 28.2622
Epoch 35/100, Loss: 29.5367
Epoch 36/100, L

In [5]:
# Stack weight matrices
def get_activations(model):
    act = []
    for name, param in model.named_parameters():
        if 'weight' in name and len(param.shape)<=2:
            act.append(param.clone().detach().cpu().numpy())
        
    return act 

weights_all = get_activations(trained_model_all)
weights_filtered = get_activations(trained_model_filtered)

# Stack the weights from both models
#stacked_weights = [torch.stack([w_all, w_filtered]) for w_all, w_filtered in zip(weights_all, weights_filtered)]

# Print shapes of stacked weights
for i, weights in enumerate(weights_all):
    print(f"Layer {i + 1}, Stacked Weight Shape: {weights.shape}")

for i, weights in enumerate(weights_filtered):
    print(f"Layer {i + 1}, Stacked Weight Shape: {weights.shape}")

Layer 1, Stacked Weight Shape: (64,)
Layer 2, Stacked Weight Shape: (64,)
Layer 3, Stacked Weight Shape: (64,)
Layer 4, Stacked Weight Shape: (64,)
Layer 5, Stacked Weight Shape: (64,)
Layer 6, Stacked Weight Shape: (128,)
Layer 7, Stacked Weight Shape: (128,)
Layer 8, Stacked Weight Shape: (128,)
Layer 9, Stacked Weight Shape: (128,)
Layer 10, Stacked Weight Shape: (128,)
Layer 11, Stacked Weight Shape: (256,)
Layer 12, Stacked Weight Shape: (256,)
Layer 13, Stacked Weight Shape: (256,)
Layer 14, Stacked Weight Shape: (256,)
Layer 15, Stacked Weight Shape: (256,)
Layer 16, Stacked Weight Shape: (512,)
Layer 17, Stacked Weight Shape: (512,)
Layer 18, Stacked Weight Shape: (512,)
Layer 19, Stacked Weight Shape: (512,)
Layer 20, Stacked Weight Shape: (512,)
Layer 21, Stacked Weight Shape: (10, 512)
Layer 1, Stacked Weight Shape: (64,)
Layer 2, Stacked Weight Shape: (64,)
Layer 3, Stacked Weight Shape: (64,)
Layer 4, Stacked Weight Shape: (64,)
Layer 5, Stacked Weight Shape: (64,)
Layer 6

In [6]:
print(weights_all[20])
print(weights_all[20].shape)

[[-0.21834995 -0.03187049 -0.16574861 ... -0.00095739 -0.4070117
  -0.490066  ]
 [-0.31382605 -0.2706528  -0.02750472 ... -0.13319638  0.15684833
  -0.2164387 ]
 [ 0.01108904 -0.06333178  0.27269706 ... -0.05441358 -0.34333274
  -0.5039562 ]
 ...
 [ 0.2773237   0.20040332  0.02275349 ... -0.12138231 -0.12557912
   0.28466812]
 [-0.33556882  0.00079167 -0.06892181 ... -0.19394068  0.0776932
  -0.22360118]
 [-0.2625512   0.02748459 -0.03911771 ... -0.15813114  0.05415928
  -0.21525238]]
(10, 512)


In [7]:
print(weights_filtered[20])
weights_filtered[20].shape

[[-0.31517348 -0.17904854  0.06963594 ... -0.15570758  0.01271809
  -0.04909996]
 [ 0.07303066  0.11316656 -0.380857   ... -0.0856169   0.1270904
   0.01640966]
 [-0.17097859 -0.00275116 -0.20342812 ... -0.05786476 -0.5024417
   0.02305853]
 ...
 [-0.27280134  0.12193636 -0.40359074 ...  0.10263488  0.10149778
  -0.0414007 ]
 [-0.30185592 -0.3197588  -0.19506279 ... -0.1785774  -0.09254825
  -0.1774683 ]
 [-0.52466434  0.0657118  -0.14734845 ... -0.03040056 -0.36453938
  -0.0609789 ]]


(9, 512)

In [8]:

X = torch.tensor(weights_all[20])
X = X[1:]
Y = torch.tensor(weights_filtered[20])
max_diff = torch.max(torch.abs(X - Y))
print(max_diff)

tensor(1.3716)


In [9]:
U,V,S = torch.linalg.svd(X)
print(U.shape)
print(U)
print(S.shape)

torch.Size([9, 9])
tensor([[-0.2534,  0.4794,  0.1214,  0.0676,  0.0709, -0.1227,  0.1393, -0.1134,
         -0.7964],
        [-0.2674, -0.2371, -0.1202, -0.1466,  0.8322, -0.2088,  0.3053,  0.0539,
          0.0635],
        [-0.3949, -0.2391,  0.0354, -0.3268, -0.1190,  0.5573, -0.0213,  0.5527,
         -0.2194],
        [-0.4293, -0.1566, -0.7035, -0.0727, -0.2941, -0.3708, -0.2304, -0.0971,
         -0.0666],
        [-0.3621, -0.2638,  0.1657,  0.0769, -0.3002,  0.1865,  0.5953, -0.5236,
          0.1115],
        [-0.3686, -0.1229,  0.6009, -0.2628,  0.0292, -0.2537, -0.5395, -0.2322,
          0.0930],
        [-0.2313, -0.2777,  0.1130,  0.8788,  0.0756,  0.0056, -0.1789,  0.2080,
         -0.0568],
        [-0.2906,  0.4746, -0.2308,  0.1119,  0.2657,  0.5447, -0.2799, -0.3102,
          0.2874],
        [-0.3449,  0.4951,  0.1416,  0.0451, -0.1856, -0.3120,  0.2794,  0.4492,
          0.4496]])
torch.Size([512, 512])


In [10]:
i=0
X = torch.tensor(weights_all[20])
X = X[1:]
X = X / torch.norm(X, dim=1, keepdim=True)
print(X)
Y = torch.tensor(weights_filtered[20])
max_diff = torch.mean(torch.abs(X - Y))
print(max_diff)
print(Y)

tensor([[-0.0652, -0.0562, -0.0057,  ..., -0.0277,  0.0326, -0.0450],
        [ 0.0021, -0.0122,  0.0523,  ..., -0.0104, -0.0659, -0.0967],
        [-0.0735, -0.1050, -0.0240,  ..., -0.0403, -0.0194, -0.0688],
        ...,
        [ 0.0526,  0.0380,  0.0043,  ..., -0.0230, -0.0238,  0.0540],
        [-0.0636,  0.0002, -0.0131,  ..., -0.0368,  0.0147, -0.0424],
        [-0.0495,  0.0052, -0.0074,  ..., -0.0298,  0.0102, -0.0406]])
tensor(0.1602)
tensor([[-0.3152, -0.1790,  0.0696,  ..., -0.1557,  0.0127, -0.0491],
        [ 0.0730,  0.1132, -0.3809,  ..., -0.0856,  0.1271,  0.0164],
        [-0.1710, -0.0028, -0.2034,  ..., -0.0579, -0.5024,  0.0231],
        ...,
        [-0.2728,  0.1219, -0.4036,  ...,  0.1026,  0.1015, -0.0414],
        [-0.3019, -0.3198, -0.1951,  ..., -0.1786, -0.0925, -0.1775],
        [-0.5247,  0.0657, -0.1473,  ..., -0.0304, -0.3645, -0.0610]])


In [11]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize with mean and std for CIFAR-10
])

# Load CIFAR-10 dataset
cifar10_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Define the target class to forget (class 0: "airplane")
forget_classes = [0]

retain_classes = [1,2,3,4,5,6,7,8,9]
# Create a function to filter the dataset for retained or forgotten data
def create_dataloader(dataset, target_classes, batch_size, is_forget=False):
    indices = [
        i for i, (_, label) in enumerate(dataset)
        if (label in target_classes) == is_forget
    ]
    subset = Subset(dataset, indices)
    return DataLoader(subset, batch_size=batch_size, shuffle=True)

# Retained Data Loader (CIFAR-10 excluding class 0)
retain_loader = create_dataloader(cifar10_dataset, forget_classes, batch_size=32, is_forget=False)

# Forget Data Loader (Only class 0)
forget_loader = create_dataloader(cifar10_dataset, forget_classes, batch_size=32, is_forget=True)

# Check the first few batches of the retained and forgotten data loaders
print(f"Number of batches in retained loader: {len(retain_loader)}")
print(f"Number of batches in forgotten loader: {len(forget_loader)}")


Files already downloaded and verified
Number of batches in retained loader: 1407
Number of batches in forgotten loader: 157


In [12]:
trained_model_all.fc = nn.Linear(trained_model_all.fc.in_features, 9)
trained_model_all.fc.weight.data = X.clone()
for name, param in trained_model_all.named_parameters():
    if 'weight' in name:  # Check if the parameter is a weight matrix
        print(f"Layer {name}: {param.shape}")

Layer conv1.weight: torch.Size([64, 3, 7, 7])
Layer bn1.weight: torch.Size([64])
Layer layer1.0.conv1.weight: torch.Size([64, 64, 3, 3])
Layer layer1.0.bn1.weight: torch.Size([64])
Layer layer1.0.conv2.weight: torch.Size([64, 64, 3, 3])
Layer layer1.0.bn2.weight: torch.Size([64])
Layer layer1.1.conv1.weight: torch.Size([64, 64, 3, 3])
Layer layer1.1.bn1.weight: torch.Size([64])
Layer layer1.1.conv2.weight: torch.Size([64, 64, 3, 3])
Layer layer1.1.bn2.weight: torch.Size([64])
Layer layer2.0.conv1.weight: torch.Size([128, 64, 3, 3])
Layer layer2.0.bn1.weight: torch.Size([128])
Layer layer2.0.conv2.weight: torch.Size([128, 128, 3, 3])
Layer layer2.0.bn2.weight: torch.Size([128])
Layer layer2.0.downsample.0.weight: torch.Size([128, 64, 1, 1])
Layer layer2.0.downsample.1.weight: torch.Size([128])
Layer layer2.1.conv1.weight: torch.Size([128, 128, 3, 3])
Layer layer2.1.bn1.weight: torch.Size([128])
Layer layer2.1.conv2.weight: torch.Size([128, 128, 3, 3])
Layer layer2.1.bn2.weight: torch.Si

In [13]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the same device as the input
trained_model_all = trained_model_all.to(device)

# Now when using the model, ensure the input data is also on the same device
def evaluate_model(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            # Move input and labels to the same device as the model
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            
            # Compute accuracy
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total



In [14]:
evaluate_model(trained_model_all, retain_loader)

0.00011111111111111112

In [15]:
evaluate_model(trained_model_all,forget_loader)

0.0346