In [1]:
import os
import torch
from torchvision import datasets, transforms
from utils import loaders_by_classes, filter_loaders, balance, balanced_batch_size, submodel, get_activations
from classNet import ConvNet # for torch load
import matplotlib.pyplot as plt

# Pick the compute backend in order of preference: Apple MPS, then CUDA,
# falling back to the CPU when neither accelerator is available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# One DataLoader worker per CPU reported by the OS.
num_workers = os.cpu_count()

# Display the selected configuration as the cell output.
device, num_workers

(device(type='mps'), 12)

In [2]:
# Load the checkpoint saved at ./models/all_class.pth.
# SECURITY NOTE: weights_only=False performs a full unpickle, which can execute
# arbitrary code embedded in the file -- only load checkpoints from a trusted source.
# The ConvNet import at the top of the notebook is there so the pickled object
# can be resolved during loading.
model_all_class = torch.load('./models/all_class.pth', weights_only=False)

In [3]:
# Class handled separately from the others: its integer label and the
# matching string key used to look up the per-class data loaders.
class_num = 7
class_name = '7 - seven'

# Build the sub-model for that class from the full model and move it to the
# selected device (Module.to returns the module itself, so chaining is safe).
model = submodel(model_all_class, class_num).to(device)

# Switch to evaluation mode; eval() returns the module, which the notebook
# displays as the cell output.
model.eval()

ConvNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_drop): Dropout2d(p=0.25, inplace=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=9, bias=True)
)

In [4]:
# Preprocessing applied to every MNIST image.
transform = transforms.Compose([
    # PIL image -> float tensor
    transforms.ToTensor(),
    # mean 0 / std 1: this normalisation leaves the tensor values unchanged
    transforms.Normalize((0,), (1,)),
])

# Training split first, then the held-out test split; both are downloaded to
# ./data if missing and share the same preprocessing.
train_set, test_set = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [5]:
# Base batch size shared by the per-class loaders.
batch_size = 64

# One DataLoader per class, built from the training split.
train_loaders = loaders_by_classes(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# One DataLoader per class, built from the held-out test split.
# (Previously this was built from train_set, so the "test" loaders silently
# served training data.)
test_loaders = loaders_by_classes(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [6]:
# Select which per-class loaders the rest of the notebook works on
# (here: the ones built from the training split), then display the mapping.
loaders = train_loaders
loaders

{'0 - zero': <torch.utils.data.dataloader.DataLoader at 0x106be9850>,
 '1 - one': <torch.utils.data.dataloader.DataLoader at 0x106bc2750>,
 '2 - two': <torch.utils.data.dataloader.DataLoader at 0x17d8f6a20>,
 '3 - three': <torch.utils.data.dataloader.DataLoader at 0x17de87830>,
 '4 - four': <torch.utils.data.dataloader.DataLoader at 0x17de877a0>,
 '5 - five': <torch.utils.data.dataloader.DataLoader at 0x17deb0dd0>,
 '6 - six': <torch.utils.data.dataloader.DataLoader at 0x17deb0620>,
 '7 - seven': <torch.utils.data.dataloader.DataLoader at 0x17deb0f20>,
 '8 - eight': <torch.utils.data.dataloader.DataLoader at 0x17deb1010>,
 '9 - nine': <torch.utils.data.dataloader.DataLoader at 0x17deb10d0>}

In [7]:
# Scale the base batch size by the factor returned by balanced_batch_size for
# this class, truncating the product to an int.
# NOTE(review): presumably the factor sizes the retain batch relative to the
# forget-class batch -- confirm against utils.balanced_batch_size.
blcd_batch_size = int(batch_size * balanced_batch_size(loaders, class_name))
blcd_batch_size

484

In [8]:
# Loader for the selected class only (the "forget" data).
forget_loader = loaders[class_name]
# Loader built by filter_loaders with the scaled batch size.
# NOTE(review): presumably it serves every class except class_name (the
# "retain" data) -- confirm against utils.filter_loaders.
retain_loader = filter_loaders(loaders, class_name, blcd_batch_size, shuffle=True, num_workers=num_workers)

In [9]:
# Draw a single batch from each loader (retain first, then forget, so the
# order in which the loaders are consumed is unchanged).
retain_input = next(iter(retain_loader))
forget_input = next(iter(forget_loader))

# Element 0 of a batch holds the inputs, element 1 the labels; move both to
# the compute device.
r_input_tensor = retain_input[0].to(device)
r_label_tensor = retain_input[1].to(device)
f_input_tensor = forget_input[0].to(device)
f_label_tensor = forget_input[1].to(device)

In [20]:
# ---------------------------------------------------------------------------
# 1) Gradients of the retain objective: cross-entropy on the retain batch.
# ---------------------------------------------------------------------------
model.zero_grad()
r_input_tensor.requires_grad = True
retain_dict = get_activations(model, r_input_tensor, -1, verbose=True)

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(retain_dict['output'], r_label_tensor)
loss.backward()

# Snapshot the gradients with detach().clone(). Storing param.grad itself is
# only safe when zero_grad() drops the tensors (set_to_none=True); if it zeroes
# them in place instead, the stored references would be wiped and then
# overwritten by the next backward pass, making r_grads alias f_grads.
r_grads = [
    None if param.grad is None else param.grad.detach().clone()
    for param in model.parameters()
]

# ---------------------------------------------------------------------------
# 2) Gradients of the forget objective: pull the mean forget activation
#    towards a weighted mix of the per-class mean retain activations.
# ---------------------------------------------------------------------------
model.zero_grad()
f_input_tensor.requires_grad = True
forget_dict = get_activations(model, f_input_tensor, -1, verbose=True)

# exp(output) averaged over the forget batch.
# NOTE(review): presumably the model outputs log-probabilities, making this the
# mean predicted class distribution -- confirm against ConvNet.forward.
sftmax = torch.exp(forget_dict['output']).mean(dim=0)
activations = retain_dict['activations'].detach()

# Per-label running sum and count of the retain activations.
activations_dict = {}
for idx, label in enumerate(r_label_tensor):
    label = int(label.item())  # make sure the label is usable as a dict key
    if label not in activations_dict:
        activations_dict[label] = {}
        # clone() so the in-place += below never writes into `activations`
        activations_dict[label]['activation'] = activations[idx, :].clone()
        activations_dict[label]['count'] = 1
    else:
        activations_dict[label]['activation'] += activations[idx, :]
        activations_dict[label]['count'] += 1

# Target = sum over retain labels of sftmax[index] * mean activation of that label.
target = 0
for key, dval in activations_dict.items():
    # Labels above class_num are shifted down by one to index sftmax, which has
    # one entry fewer than the original label set (the sub-model has 9 outputs).
    ikey = key - 1 * (key > class_num)
    target += sftmax[ikey] * dval['activation'] / dval['count']

# NOTE(review): `target` is built from sftmax without detach(), so the backward
# pass may also propagate through the target -- confirm this is intended.
criterion = torch.nn.MSELoss()
loss = criterion(torch.mean(forget_dict['activations'], dim=0), target)
loss.backward()

# Same snapshot strategy as for r_grads.
f_grads = [
    None if param.grad is None else param.grad.detach().clone()
    for param in model.parameters()
]

Checking residues of the input after layer : Linear(in_features=128, out_features=9, bias=True)
Checking residues of the input after layer : Linear(in_features=128, out_features=9, bias=True)


In [45]:
# Count, layer by layer, the parameters whose forget and retain gradients are
# both non-zero and share the same sign (their product is strictly positive).
ttot = 0    # total number of parameters seen
tcount = 0  # total number of parameters with agreeing gradient signs
for f_grad, r_grad in zip(f_grads, r_grads):
    size = f_grad.size()
    mask = f_grad * r_grad > 0
    count = int(mask.sum())
    print(f'couche de taille {size} potentiellement modifiable : {count}')
    tcount += count
    # numel() gives the element count directly as an int; the previous
    # torch.tensor(size).prod().item() round-trip returned a float for a
    # 0-dim gradient.
    tot = f_grad.numel()
    ttot += tot

print(f'potentiellement modifiable : {tcount} sur {ttot} soit {100 * tcount / ttot} %')

couche de taille torch.Size([32, 1, 3, 3]) potentiellement modifiable : 195
couche de taille torch.Size([32]) potentiellement modifiable : 20
couche de taille torch.Size([64, 32, 3, 3]) potentiellement modifiable : 9873
couche de taille torch.Size([64]) potentiellement modifiable : 44
couche de taille torch.Size([128, 3136]) potentiellement modifiable : 40786
couche de taille torch.Size([128]) potentiellement modifiable : 22
couche de taille torch.Size([9, 128]) potentiellement modifiable : 282
couche de taille torch.Size([9]) potentiellement modifiable : 4
potentiellement modifiable : 51226 sur 421513 soit 12.152887336808117 %
