In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from network import resnet18

from dataset import create_dataloaders_missing_class
from config import dotdict
import torch.nn.functional as F
import torch.nn as nn
DEVICE = "cuda:1"
from torchmetrics import Accuracy, ConfusionMatrix

In [2]:
def _load_resnet18(checkpoint_path):
    """Build a 10-class ResNet-18 and restore its weights from ``checkpoint_path``.

    The checkpoint is expected to be a mapping with a ``"model_state_dict"``
    entry. The returned model lives on ``DEVICE`` and is in eval mode.
    """
    # NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(DEVICE))
    model = resnet18(num_classes=10)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)
    model.eval()
    return model


# Checkpoints under comparison. `net` is the scratch checkpoint, which the
# evaluation cell reports under the name "retrain".
net = _load_resnet18("/research/hal-gaudisac/unlearning/models/scratch_resnet18.pt")
unlearn_net = _load_resnet18("/research/hal-gaudisac/unlearning/models/NN_resnet18.pt")
# NOTE(review): this is the only checkpoint read from a relative path - confirm it is intended.
finetune_net = _load_resnet18("models/finetune_resnet18.pt")
scrubs_net = _load_resnet18("/research/hal-gaudisac/unlearning/models/scrubs_resnet18.pt")

# Data loaders for the forget / retain / val splits; `remove_class` is set to 0.
data_settings = dotdict({"BATCH_SIZE": 128, "data":{"num_workers": 4}, "DATA_PATH":"data","remove_class":0})
dataloaders = create_dataloaders_missing_class(data_settings)

# reduction="none" keeps one loss value per sample instead of a batch mean.
loss_fn = nn.CrossEntropyLoss(reduction="none")

# Turn off pinned host memory on every loader (original note: "to not load
# everything in memory").
dataloaders.forget.pin_memory = False
dataloaders.val.pin_memory = False
dataloaders.retain.pin_memory = False

In [29]:
# Error rate (1 - accuracy) of every model on every split, plus one
# row-scaled validation confusion matrix per model (collected in `cms`).
acc = Accuracy("multiclass",num_classes=10).to(DEVICE)
cm = ConfusionMatrix("multiclass",num_classes=10).to(DEVICE)

models_to_compare = [("retrain",net),("unlearn",unlearn_net),("finetune",finetune_net),("scrubs",scrubs_net)]
# "val" must stay last: the confusion matrix stored per model is whatever `cm`
# holds once the inner loop has finished.
splits = [("forget",dataloaders.forget),("retain",dataloaders.retain),("val",dataloaders.val)]

cms = []
for name, model in models_to_compare:
    for dl_name, dl in splits:
        acc.reset()
        cm.reset()
        # Pure inference: skip autograd bookkeeping so no graph is kept per batch.
        with torch.no_grad():
            for inputs, targets, _mask in dl:
                if dl_name == "val":
                    # Drop class-0 samples from the validation batches.
                    keep = targets != 0
                    inputs, targets = inputs[keep], targets[keep]
                    if targets.numel() == 0:
                        # Nothing left in this batch after filtering.
                        continue

                inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
                outputs = model(inputs)
                acc.update(outputs, targets)
                cm.update(outputs, targets)
        print(f"{name} - {dl_name} : => {1 - acc.compute()}")
    val_cm = cm.compute().detach().cpu().numpy()
    # Scale each row by (row total + 1). The +1 keeps an all-zero row from
    # dividing by zero, at the cost of rows summing to slightly less than 1.
    # The Frobenius distance between these matrices is computed in the next cell.
    cms.append(val_cm/(val_cm.sum(axis=1,keepdims=True)+1))

retrain - forget : => 1.0
retrain - retain : => 0.010311126708984375
retrain - val : => 0.20452022552490234
unlearn - forget : => 1.0
unlearn - retain : => 0.006377756595611572
unlearn - val : => 0.2014697790145874
finetune - forget : => 0.9101999998092651
finetune - retain : => 0.005533337593078613
finetune - val : => 0.19994455575942993
scrubs - forget : => 1.0
scrubs - retain : => 0.008688867092132568
scrubs - val : => 0.17970049381256104


In [30]:
# Distance (np.linalg.norm with default arguments) between the first stored
# confusion matrix and each of the next three.
reference_cm = cms[0]
tuple(np.linalg.norm(reference_cm - cms[idx]) for idx in (1, 2, 3))

(0.0604955331644341, 0.049052444045379895, 0.09777196469661699)

In [10]:
# Display the row-scaled confusion matrix stored for the first model in the
# evaluation loop (the one labelled "retrain").
cms[0]

array([[0.        , 0.06524467, 0.32496863, 0.06398996, 0.0727729 ,
        0.0238394 , 0.01882058, 0.02634881, 0.31116688, 0.09284818],
       [0.        , 0.89533417, 0.0037831 , 0.0075662 , 0.00252207,
        0.00882724, 0.0037831 , 0.00126103, 0.01639344, 0.06052963],
       [0.        , 0.        , 0.75371287, 0.04950495, 0.06683168,
        0.04579208, 0.04084158, 0.02351485, 0.00990099, 0.00990099],
       [0.        , 0.00623441, 0.05486284, 0.60723192, 0.05610973,
        0.14962594, 0.05985037, 0.03366584, 0.01870324, 0.01371571],
       [0.        , 0.00365408, 0.05359318, 0.05846529, 0.76248477,
        0.02801462, 0.03775883, 0.04384896, 0.01096224, 0.00121803],
       [0.        , 0.0063857 , 0.03703704, 0.17241379, 0.03192848,
        0.68965517, 0.01532567, 0.03831418, 0.00255428, 0.0063857 ],
       [0.        , 0.00371747, 0.03221809, 0.04337051, 0.02726146,
        0.02973978, 0.84634449, 0.00743494, 0.00495663, 0.00495663],
       [0.        , 0.00483676, 0.0290205

In [4]:
class Hello:
    """Scratch class that exposes the keys of an internal dict as attributes.

    Attributes:
        sac: backing dictionary; each key is readable as ``instance.<key>``.
    """

    def __init__(self) -> None:
        self.sac = {"a":1}

    def __getattr__(self, name):
        """Resolve ``name`` from ``self.sac`` when normal attribute lookup fails.

        Raises:
            AttributeError: if ``sac`` is not set yet or has no key ``name``.
                Raising AttributeError (not KeyError) keeps ``hasattr`` and
                ``getattr(obj, name, default)`` working.
        """
        # Read through __dict__: touching self.sac before it exists would
        # re-enter __getattr__ and recurse without bound.
        store = self.__dict__.get("sac")
        if store is None or name not in store:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return store[name]

    def update(self):
        """Set key ``"a"`` of the backing dict to 2 (mutates ``sac`` in place)."""
        self.sac["a"]=2


In [10]:
# Create an instance, apply update(), then display the instance's attribute dict.
h = Hello()
h.update()
vars(h)

{'sac': {'a': 2}}

## OLD CODE for plotting losses and understanding metrics.

In [None]:
#
def accuracy(pds,tgts):
    """Return the fraction of rows whose arg-max column equals the target.

    Args:
        pds: 2-D tensor of per-class scores, one row per sample.
        tgts: 1-D tensor of target class indices, one per row of ``pds``.

    Returns:
        A scalar float tensor: the mean of the per-row "prediction == target" flags.
    """
    _, predicted = torch.max(pds, dim=1)
    hits = predicted == tgts
    return hits.float().mean()

def _collect_outputs(model, loader):
    """Run ``model`` over ``loader`` and gather per-sample results.

    Returns:
        ``(probabilities, targets, losses)`` concatenated over all batches.
        Probabilities (softmax over dim 1) and targets stay on ``DEVICE``;
        the per-sample losses from ``loss_fn`` are moved to the CPU.
    """
    probs, labels, losses = [], [], []
    # Inference only: without no_grad every stored softmax tensor would keep
    # its batch's autograd graph alive on the GPU.
    with torch.no_grad():
        for inputs, targets, _mask in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = model(inputs)
            losses.append(loss_fn(outputs, targets).detach().cpu())
            probs.append(F.softmax(outputs, dim=1))
            labels.append(targets)
    return torch.cat(probs), torch.cat(labels), torch.cat(losses)


dataloaders.forget.pin_memory=False
dataloaders.val.pin_memory=False
dataloaders.retain.pin_memory=False
torch.cuda.empty_cache()

# Forget split -> pds / tgts / ls
pds, tgts, ls = _collect_outputs(net, dataloaders.forget)
torch.cuda.empty_cache()

# NOTE: despite the prefix, the val_* tensors come from the *retain* loader.
val_pds, val_tgts, val_ls = _collect_outputs(net, dataloaders.retain)
torch.cuda.empty_cache()

# NOTE: the normal_* tensors come from the *val* loader.
normal_pds, normal_tgts, normal_ls = _collect_outputs(net, dataloaders.val)
torch.cuda.empty_cache()


In [None]:
# Recompute the forget-split outputs of `net`.
pds,tgts,ls = [],[],[]
# Inference only: no_grad stops each stored softmax tensor from retaining
# its batch's autograd graph.
with torch.no_grad():
    for inputs, targets, _mask in dataloaders.forget:
        inputs,targets = inputs.to(DEVICE),targets.to(DEVICE)
        outputs = net(inputs)
        ls.append(loss_fn(outputs,targets).detach().cpu())
        pds.append(F.softmax(outputs,dim=1))
        tgts.append(targets)
pds = torch.cat(pds)
tgts = torch.cat(tgts)
ls = torch.cat(ls)


from plots import plot_losses  # not used in this cell; kept from the original


def _plot_weighted_hist(ax, samples, series_labels, xlabel, ylabel=None):
    """Overlay histograms of the arrays in ``samples`` on ``ax``.

    Each series is weighted by 1/len(series) so its bars sum to 1, and all
    series share 20 bin edges computed from the pooled values.
    """
    weights = tuple(np.ones_like(values)/len(values) for values in samples)
    bins = np.histogram(np.hstack(samples), bins=20)[1]  # shared bin edges
    ax.hist(samples, density=False, alpha=0.5, bins=bins,
        weights=weights, label=series_labels)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.legend(frameon=False, fontsize=8)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 6))
labels = ("Non Training class", "Train classes")

# NOTE(review): val_* is filled from dataloaders.retain and normal_* from
# dataloaders.val in the preceding cell, while the forget outputs recomputed
# above (pds / ls) are never plotted. Confirm the series match the legend.
_plot_weighted_hist(
    ax1,
    (val_pds.max(dim=1)[0].detach().cpu().numpy(),
     normal_pds.max(dim=1)[0].detach().cpu().numpy()),
    labels,
    "Confidence",
    ylabel="Percentage Samples",
)
_plot_weighted_hist(
    ax2,
    (val_ls.detach().cpu().numpy(), normal_ls.detach().cpu().numpy()),
    labels,
    "Cross entropy loss",
)

# Create the output directory first so savefig does not fail on a fresh checkout.
os.makedirs("results", exist_ok=True)
plt.savefig('results/scratch_loss.png', dpi=300)