In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from network import resnet18
from dataset import create_dataloaders_missing_class,create_dataloaders_uniform_sampling
from config import dotdict
import torch.nn.functional as F
import torch.nn as nn
import math
DEVICE = "cuda:0"
from torchmetrics import Accuracy, ConfusionMatrix
from network import NormConstrainedResNet
from trainer import AdamWOptimConfig, CosineAnnealingLRSchedulerConfig, TrainerSettings, Trainer, count_parameters

In [2]:
def compute_angle(model1,model2):
    X = torch.rand(10,3,32,32).to(DEVICE)
    W1 = model1.W.data
    W2 = model2.W.data
    theta_w = torch.acos(F.cosine_similarity(W1.T,W2.T,dim=1)) * 180 / math.pi   #what are the angles of W1 and W2

    Z1 = model1.resnet(X)
    Z1 = Z1 @ W1 @ W1.mT  #projection to the subspace spanned by W
    Z2 = model2.resnet(X)
    Z2 = Z2 @ W2 @ W2.mT  #projection to the subspace spanned by W
    theta_z = torch.acos(F.cosine_similarity(Z1.T,Z2.T,dim=1)) * 180 / math.pi
    return theta_z,theta_w

In [21]:
model1 = NormConstrainedResNet(num_classes=10).to(DEVICE)      
checkpoint = torch.load("models/model_no_norm_WR.pt", map_location=DEVICE)
model1.load_state_dict(checkpoint['model_state_dict'])
model2 = NormConstrainedResNet(num_classes=10).to(DEVICE)    
checkpoint = torch.load("models/model_no_norm_W1.pt", map_location=DEVICE)
model2.load_state_dict(checkpoint['model_state_dict'])
#compute_angle(model1,model2)
W1 = model1.W.data
W2 = model2.W.data
W1,W2


(tensor([[-0.0693, -0.0739, -0.0033,  ..., -0.0272, -0.0681, -0.0118],
         [ 0.0382,  0.0135, -0.0876,  ..., -0.0139,  0.0090, -0.0753],
         [-0.0021, -0.0068,  0.0302,  ..., -0.0165,  0.0620, -0.0009],
         ...,
         [ 0.0755,  0.0060, -0.0440,  ..., -0.0625, -0.0028,  0.0015],
         [ 0.0147, -0.0255, -0.0178,  ..., -0.0231,  0.0086,  0.0911],
         [-0.0051, -0.0946, -0.0185,  ...,  0.0563, -0.0043,  0.0196]],
        device='cuda:0'),
 tensor([[-0.0693, -0.0739, -0.0033,  ..., -0.0272, -0.0681, -0.0118],
         [ 0.0382,  0.0135, -0.0876,  ..., -0.0139,  0.0090, -0.0753],
         [-0.0021, -0.0068,  0.0302,  ..., -0.0165,  0.0620, -0.0009],
         ...,
         [ 0.0755,  0.0060, -0.0440,  ..., -0.0625, -0.0028,  0.0015],
         [ 0.0147, -0.0255, -0.0178,  ..., -0.0231,  0.0086,  0.0911],
         [-0.0051, -0.0946, -0.0185,  ...,  0.0563, -0.0043,  0.0196]],
        device='cuda:0'))

In [22]:
X = torch.rand(20,3,32,32).to(DEVICE)
Z1 = model1.resnet(X) @ W1
Z1 /= torch.norm(Z1,dim=1,keepdim=True)
Z2 = model2.resnet(X) @ W2
Z2 /= torch.norm(Z2,dim=1,keepdim=True)
torch.acos(F.cosine_similarity(Z1,Z2,dim=1)) * 180 / math.pi

tensor([0., nan, 0., 0., 0., nan, nan, 0., 0., 0., nan, nan, 0., 0., 0., 0., 0., 0., 0., nan],
       device='cuda:0', grad_fn=<DivBackward0>)

tensor(1.0000, device='cuda:0', grad_fn=<SumBackward0>)

### Code for Checking the Unlearning metrics from models trained on all 3 unlearning techniques

In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from network import resnet18

from dataset import create_dataloaders_missing_class,create_dataloaders_uniform_sampling
from config import dotdict
import torch.nn.functional as F
import torch.nn as nn
DEVICE = "cuda:1"
from torchmetrics import Accuracy, ConfusionMatrix

In [2]:
#load network
net = resnet18(num_classes=10)
net.load_state_dict(torch.load("/research/hal-gaudisac/unlearning/models/model_scratch_resnet18.pt", map_location=torch.device(DEVICE)))
net.to(DEVICE)
net.eval()

#load unlearn network
unlearn_net = resnet18(num_classes=10)
unlearn_net.load_state_dict(torch.load("/research/hal-gaudisac/unlearning/models/model_NN_resnet18.pt", map_location=torch.device(DEVICE)))
unlearn_net.to(DEVICE)
unlearn_net.eval()

finetune_net = resnet18(num_classes=10)
finetune_net.load_state_dict(torch.load("models/model_finetune_resnet18.pt", map_location=torch.device(DEVICE)))
finetune_net.to(DEVICE)
finetune_net.eval()

scrubs_net = resnet18(num_classes=10)
scrubs_net.load_state_dict(torch.load("/research/hal-gaudisac/unlearning/models/model_scrubs_resnet18.pt", map_location=torch.device(DEVICE)))
scrubs_net.to(DEVICE)
scrubs_net.eval()


#load data
data_settings = dotdict({"BATCH_SIZE": 128, "data":{"num_workers": 4}, "DATA_PATH":"data","remove_class":0})
#dataloaders = create_dataloaders_missing_class(data_settings)
dataloaders = create_dataloaders_uniform_sampling(data_settings)
loss_fn = nn.CrossEntropyLoss(reduction="none")
#to not load everything in memory
dataloaders.forget.pin_memory=False
dataloaders.val.pin_memory=False
dataloaders.retain.pin_memory=False

In [3]:
acc = Accuracy("multiclass",num_classes=10).to(DEVICE)
cm = ConfusionMatrix("multiclass",num_classes=10).to(DEVICE)

cms = []
for name,model in [("retrain",net),("unlearn",unlearn_net),("finetune",finetune_net),("scrubs",scrubs_net)]: #
    for dl_name,dl in zip(["forget","retain","val"],[dataloaders.forget,dataloaders.retain,dataloaders.val]):
        acc.reset()
        cm.reset()
        for batch_id,batch in enumerate(dl):
            #if dl_name == "val":
            #    inputs,targets = inputs[targets!=0], targets[targets!=0]
            inputs,targets = batch[0].to(DEVICE),batch[1].to(DEVICE)
            outputs = model(inputs)
            loss = (outputs,targets)
            acc.update(outputs,targets)
            cm.update(outputs,targets)    
        print(f"{name} - {dl_name} : => {1 - acc.compute()}")
    val_cm = cm.compute().detach().cpu().numpy() 
    cms.append(val_cm/(val_cm.sum(axis=1,keepdims=True)+1))  #compute frobenius norm of confusion matrix

retrain - forget : => 0.012000024318695068
retrain - retain : => 0.009000003337860107
retrain - val : => 0.19249999523162842
unlearn - forget : => 0.18524444103240967
unlearn - retain : => 0.15420001745224
unlearn - val : => 0.24424999952316284
finetune - forget : => 0.19691109657287598
finetune - retain : => 0.10240000486373901
finetune - val : => 0.24787497520446777
scrubs - forget : => 0.3575778007507324
scrubs - retain : => 0.26759999990463257
scrubs - val : => 0.36262500286102295


In [5]:
np.linalg.norm(cms[0] - cms[1]),np.linalg.norm(cms[0] - cms[2]),np.linalg.norm(cms[0] - cms[3])

(0.24147937450130905, 0.2263611036755208, 0.6427778583371856)

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
with open("data/celeba/identity_CelebA.txt")    as f:
    identity_CelebA = pd.read_csv(f, sep=" ", header=None)
identity_CelebA.columns = ["image_id", "identity"]
#randomly sample 10% of identities
index = np.random.choice(identity_CelebA["identity"].unique(), 10, replace=False) 
#assign these index 1 and rest 0
identity_CelebA["identity"] = identity_CelebA["identity"].apply(lambda x: 1 if x in index else 0)

#save df in the text file like list_forget_partition.txt 
with open("data/celeba/list_forget_partition.txt", "w") as f:
    for i in range(len(identity_CelebA)):
        f.write(f"{identity_CelebA.iloc[i,0]} {identity_CelebA.iloc[i,1]}\n")

In [9]:
meta  = {}
import json


with open("data/celeba/list_attr_celeba.txt") as f:
    attr_celeba = pd.read_csv(f, sep="\s+", skiprows=1)
attr_celeba.replace(to_replace=-1, value=0, inplace=True)
meta = {"columns": attr_celeba.columns.tolist(), "mean":json.loads(attr_celeba.mean(axis=0).to_json())}
with open("data/celeba/meta.json","w") as f:
    f.write(json.dumps(meta))

In [25]:
#work with torch unbind
import torch
import torch.nn as nn
X = torch.randn(32, 40, 2)
Y = torch.randint(2,(32, 40))
tList = [loss_fn(x_i,y_i) for x_i, y_i in zip(torch.unbind(X,dim=1),torch.unbind(Y,dim=1)) ]

_, preds = torch.max(X, dim=-1)
torch.argmax(X, dim=-1)
attr_celeba.columns[19]
a = torch.rand(2,3)
a

## OLD CODE for ploting losses and understanding metrics.

In [None]:
#
def accuracy(pds,tgts):
    return torch.mean((torch.max(pds,dim=1)[1]==tgts).float())

pds,tgts,ls = [],[],[]
dataloaders.forget.pin_memory=False
dataloaders.val.pin_memory=False
dataloaders.retain.pin_memory=False
torch.cuda.empty_cache()

for batch_id, (inputs, targets,mask) in enumerate(dataloaders.forget):
    inputs,targets = inputs.to(DEVICE),targets.to(DEVICE)
    outputs = net(inputs)
    loss = (outputs,targets)
    ls.append(loss_fn(outputs,targets).detach().cpu())
    pds.append(F.softmax(outputs,dim=1))
    tgts.append(targets)
pds = torch.cat(pds)
tgts = torch.cat(tgts)
ls = torch.cat(ls)
torch.cuda.empty_cache()

val_pds,val_tgts,val_ls = [],[],[]
for batch_id, (inputs, targets,mask) in enumerate(dataloaders.retain):
    inputs,targets = inputs.to(DEVICE),targets.to(DEVICE)
    outputs = net(inputs)
    val_ls.append(loss_fn(outputs,targets).detach().cpu())
    val_pds.append(F.softmax(outputs,dim=1))
    val_tgts.append(targets)
    
val_pds = torch.cat(val_pds)
val_tgts = torch.cat(val_tgts)
val_ls = torch.cat(val_ls)

torch.cuda.empty_cache()

normal_pds,normal_tgts,normal_ls = [],[],[]
for batch_id, (inputs, targets,mask) in enumerate(dataloaders.val):
    inputs,targets = inputs.to(DEVICE),targets.to(DEVICE)
    outputs = net(inputs)
    normal_ls.append(loss_fn(outputs,targets).detach().cpu())
    normal_pds.append(F.softmax(outputs,dim=1))
    normal_tgts.append(targets)
normal_ls = torch.cat(normal_ls)
normal_tgts = torch.cat(normal_tgts)
normal_pds = torch.cat(normal_pds)
torch.cuda.empty_cache()


In [None]:
pds,tgts,ls = [],[],[]
for batch_id, (inputs, targets,mask) in enumerate(dataloaders.forget):
    inputs,targets = inputs.to(DEVICE),targets.to(DEVICE)
    outputs = net(inputs)
    loss = (outputs,targets)
    ls.append(loss_fn(outputs,targets).detach().cpu())
    pds.append(F.softmax(outputs,dim=1))
    tgts.append(targets)
pds = torch.cat(pds)
tgts = torch.cat(tgts)
ls = torch.cat(ls)


from plots import plot_losses
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 6))
test_losses = val_pds.max(dim=1)[0].detach().cpu().numpy()
forget_losses = normal_pds.max(dim=1)[0].detach().cpu().numpy()
X = (test_losses, forget_losses)
weights = (np.ones_like(test_losses)/len(test_losses),
            np.ones_like(forget_losses)/len(forget_losses))
labels = ("Non Training class", "Train classes")
bins = np.histogram(np.hstack(X), bins=20)[1]  # get the bin edges

ax1.hist(X, density=False, alpha=0.5, bins=bins,
    weights=weights, label=labels)

ax1.set_ylabel("Percentage Samples", fontsize=12)
ax1.set_xlabel("Confidence", fontsize=12)
ax1.legend(frameon=False, fontsize=8)


test_losses = val_ls.detach().cpu().numpy()
forget_losses = normal_ls.detach().cpu().numpy()
X = (test_losses, forget_losses)
weights = (np.ones_like(test_losses)/len(test_losses),
            np.ones_like(forget_losses)/len(forget_losses))
labels = ("Non Training class", "Train classes")
bins = np.histogram(np.hstack(X), bins=20)[1]  # get the bin edges


ax2.hist(X, density=False, alpha=0.5, bins=bins,
    weights=weights, label=labels)
ax2.set_xlabel("Cross entropy loss", fontsize=12)
ax2.legend(frameon=False, fontsize=8)
plt.savefig('results/scratch_loss.png', dpi=300)

In [66]:
#class re
#class specific layers for resnet.
import torch
from torch import nn, einsum
from einops import rearrange
from typing import Optional


class ClassAttentionBlock(nn.Module):
    
    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3, stride: Optional[int] = None, groups: int = 1, padding: Optional[str] = "same"):
        super().__init__()
        #here groups are nothing but num of classes
        #inplanes -> outplanes -> attention. 

        """
        self.height = fmap_size[0]
        self.width = fmap_size[1]
        
        """
        num_classes  = groups
        self.scale = out_planes ** -0.5
        
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                               groups=groups, padding=1, dilation=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        #inplanes outplanes dim,num_classes,fmap_size =(32,32)
        self.to_qkv = nn.Conv2d(out_planes,out_planes* 3, 1, bias=False,groups=num_classes)
        self.class_mask = -torch.ones(num_classes,1,out_planes,out_planes,requires_grad =False)*torch.inf
        self.attn_mask = torch.zeros(num_classes,1,out_planes,out_planes,requires_grad =False)
        for i in range(num_classes):
            self.class_mask[i,:,i*out_planes:(i+1)*out_planes,i*out_planes:(i+1)*out_planes] = 0
            self.attn_mask[i,:,i*out_planes:(i+1)*out_planes,i*out_planes:(i+1)*out_planes] = 1
        

    def forward(self, x, y=None):
        

        
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        height, width = out.shape[-2:]

        # [batch (heads*3*dim_head) height width]
        qkv = self.to_qkv(out)
       
        # decompose heads and merge spatial dims as tokens
        q, k, v = tuple(rearrange(qkv, 'b (d k h) x y  -> k b h d (x y)', k=3, h=1))
        # i, j refer to tokens
       
        dot_prod = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if y is not None:
            dot_prod += self.attn_mask[y]
        
        attention = torch.softmax(dot_prod, dim=-1)
        if y is not None:
            attention *= self.attn_mask[y]
        out = einsum('b h i j, b h j d -> b h i d', attention, v)
        # Merge heads and decompose tokens to spatial dims
        out = rearrange(out, 'b h d (x y) -> b (h d) x y', x=height, y=width)
        return out
    
#class specific layers for resnet.
class ClassSpedificResNetBlock(nn.Module):
    def __init__(self,in_planes: int, out_planes: int,num_classes:int, kernel_size: int = 3, stride: Optional[int] = None, groups: int = 1, padding: Optional[str] = "same") -> None:
        super().__init__()

        # Conv3D
        #assert out_planes % num_classes == 0, "out_planes should be divisible by num_classes"
        self.conv1 = ClassAttentionBlock(in_planes*num_classes, out_planes*num_classes, kernel_size, stride, num_classes, padding)
        
        self.downsample = None
        if stride != 1 or in_planes != out_planes:
            self.downsample = nn.Sequential(nn.Conv2d(in_planes*num_classes, out_planes*num_classes, kernel_size=1, stride=stride, bias=False,groups=num_classes), nn.BatchNorm2d(out_planes*num_classes))

        # what if we use LayerNorm instead of BatchNorm.
        self.bn1 = nn.BatchNorm2d( out_planes*num_classes)
        self.bn2 = nn.BatchNorm2d( out_planes*num_classes)

        # activation of relu
        self.relu = nn.ReLU(inplace=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):

        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.relu(out)

        # to match diamensions of x with that of output.
        if self.downsample:
            x = self.downsample(x)

        out += x
        out = self.relu(out)

        return out 

In [67]:
model = ClassSpedificResNetBlock(4,8,10, kernel_size= 3,stride=1)
x = torch.rand(2,4*10,32,32)
y = torch.randint(2,(2,))
model(x).shape

torch.Size([2, 80, 32, 32])

torch.Size([2, 80, 32, 32])

In [71]:
y =1
w = torch.ones(1,1,80,80)*torch.inf
w[:,:,0+y:40+y,0:40] = 1

In [6]:
num_classes = 10
dim =8
class_mask = torch.ones(num_classes,1,num_classes* dim,num_classes* dim,requires_grad =False)*torch.inf


tensor([[[[ 3.3890e-01,  2.5160e-01,  2.1525e-01,  ...,  2.4397e-01,
            3.3694e-01,  2.7881e-01],
          [ 2.0173e-01,  3.1196e-01,  3.1179e-02,  ...,  2.9757e-01,
            1.4044e-01,  1.1020e-01],
          [ 4.7620e-01,  1.0856e-01,  2.5557e-01,  ...,  4.6144e-01,
            3.0961e-01,  2.2630e-02],
          ...,
          [ 4.1908e-01,  2.1499e-01,  9.4574e-02,  ...,  2.0092e-01,
            2.3023e-01,  3.7197e-01],
          [ 5.1734e-01,  3.0734e-01,  2.1264e-01,  ...,  3.1436e-01,
            3.3791e-01,  1.6387e-01],
          [ 1.9127e-01,  2.8782e-01,  1.9839e-01,  ...,  1.0235e-01,
            3.2376e-01, -7.3516e-03]],

         [[ 2.8291e-01,  9.4492e-02,  2.6730e-01,  ...,  1.4632e-01,
           -1.4304e-01, -3.9575e-01],
          [ 2.9632e-01,  3.5927e-01,  1.9496e-01,  ...,  5.2694e-01,
            4.2898e-01, -7.6187e-02],
          [ 8.8827e-03, -1.2398e-02,  5.8757e-01,  ...,  3.4401e-01,
            4.8196e-01, -1.3220e-01],
          ...,
     