# Selective Synaptic Dampening (SSD)

https://arxiv.org/abs/2308.07707

In [None]:
import copy
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Google Drive handle. It stays None for local runs; on Colab, uncomment the
# two lines below to import the Drive helper and mount it. Code further down
# checks `drive is None` to decide which working directory to use.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Directory searched for the project-local modules imported further down.
path = "./"
# Guarded so that re-running this cell does not pile up duplicate entries
# on sys.path.
if path not in sys.path:
    sys.path.append(path)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# With a Drive handle set, switch the working directory to the Drive folder;
# otherwise `path` keeps the value it already has.
if drive is not None:
    path = "/content/drive/MyDrive/self-learn/unlearning"

In [None]:
# Project-local helpers. The star-import is kept because the set of constants
# used by the rest of the notebook is not visible from this cell.
from constants import *
from utils import (
    set_seed,
    train_data,
    val_data,
    train_loader,
    val_loader,
    fine_labels,
)
from models import get_model, get_attack_model

# Seed the RNGs with set_seed's default before any model is created.
set_seed()

In [8]:
# Build the model together with its optimizer; get_model returns them as a pair.
# NOTE(review): SEED is presumably supplied by the star-import from
# `constants` -- confirm.
model, optimizer = get_model(SEED)

# Setup

In [None]:
# Index of the class this notebook targets (presumably the class to be
# unlearned -- confirm against the driver code).
target_class = 23
# Bare expression: looks up the fine_labels entry for target_class so the
# notebook shows it as the cell output.
fine_labels[target_class]

In [41]:
def calc_param_importances(model, loader, optimizer, criterion):
    """Estimate per-parameter importance as the mean squared loss gradient.

    For every batch yielded by ``loader`` the loss
    ``criterion(model(img), label)`` is back-propagated and the squared
    gradient of each parameter is accumulated. The accumulated sum is then
    divided by the number of batches that were processed.

    Args:
        model: ``torch.nn.Module`` whose parameters are scored.
        loader: iterable of ``(img, label)`` batches. It does not need to
            support ``len()``; batches are counted as they are consumed.
        optimizer: optimizer attached to ``model``. It is used only to clear
            gradients before each backward pass and is never stepped.
        criterion: callable ``criterion(output, label)`` returning the loss
            to back-propagate.

    Returns:
        dict mapping each parameter name to a tensor shaped like that
        parameter. Parameters that never receive a gradient keep an
        importance of zero, and an empty ``loader`` yields all-zero
        importances instead of NaN.
    """
    param_importances = {name: torch.zeros_like(p)
                         for name, p in model.named_parameters()}

    # Inputs are placed on the device of the model's first parameter so the
    # function works whether or not the loader already yields on-device data.
    first_param = next(model.parameters(), None)

    num_batches = 0
    for img, label in loader:
        if first_param is not None and isinstance(img, torch.Tensor):
            img = img.to(first_param.device)

        optimizer.zero_grad()
        out = model(img)
        # The label must live where the output lives for the loss computation.
        if isinstance(label, torch.Tensor) and isinstance(out, torch.Tensor):
            label = label.to(out.device)
        loss = criterion(out, label)
        loss.backward()

        # Look importances up by name rather than relying on dict order.
        for name, p in model.named_parameters():
            if p.grad is not None:
                param_importances[name].add_(p.grad.detach().pow(2))
        num_batches += 1

    if num_batches == 0:
        # Nothing was seen: return the zero tensors rather than 0 / 0 = NaN.
        return param_importances

    return {name: imp / float(num_batches)
            for name, imp in param_importances.items()}

In [49]:
def apply_ssd(model, full_importances, forget_importances,
              alpha=10, lambda_dampen=1):
    """Dampen, in place, the parameters specialised to the forget set.

    An element of a parameter is selected when its forget-set importance
    exceeds ``alpha`` times its full-data importance. Each selected element
    is multiplied by ``beta = min(lambda_dampen * D / D_f, 1)``, where ``D``
    and ``D_f`` are its full-data and forget-set importances.

    Args:
        model: module whose parameters are modified in place.
        full_importances: dict of importance tensors for the full data.
        forget_importances: dict of importance tensors for the forget set.
        alpha: selection threshold multiplier.
        lambda_dampen: scale applied to the dampening ratio.

    Both dicts are matched to ``model.named_parameters()`` by iteration
    order, not by key, so they must list the parameters in the same order.
    Nothing is returned.
    """
    # The alpha / lambda defaults are the CIFAR values from the paper.
    paired = zip(model.named_parameters(),
                 full_importances.values(),
                 forget_importances.values())

    with torch.no_grad():
        for (_, weight), imp_full, imp_forget in paired:
            # Positions whose forget importance passes the alpha * D threshold.
            selected = torch.nonzero(imp_forget > alpha * imp_full,
                                     as_tuple=True)

            # Dampening ratio lambda * D / D_f at the selected positions ...
            ratio = ((lambda_dampen * imp_full) / imp_forget)[selected]
            # ... capped at 1 so a weight is never amplified.
            scale = torch.where(ratio > 1, torch.ones_like(ratio), ratio)

            weight[selected] *= scale

# Driver code TODO

In [42]:
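# TODO: driver steps, in order
# 1. Calculate the importances for the full data D and for the forget set D_f
#    by looping over the relevant data for each.
# 2. Discard D.
# 3. Verify unlearning before modifying the model.
# 4. Modify the model weights.
# 5. Verify unlearning after the modification.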