In [1]:
import os
import torch
from torchvision import datasets, transforms
from utils import loaders_by_classes, filter_loaders, balance, balanced_batch_size, submodel, get_activations
from classNet import ConvNet # for torch load
import matplotlib.pyplot as plt

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

num_workers = os.cpu_count()

device, num_workers

(device(type='mps'), 12)

In [2]:
model_all_class = torch.load('./models/all_class.pth', weights_only=False)

In [3]:
class_num = 7
class_name = '7 - seven'
model = submodel(model_all_class, class_num)
model.to(device)
model.eval()

ConvNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_drop): Dropout2d(p=0.25, inplace=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=9, bias=True)
)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0,), (1,))  # Normalize images
])

train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [5]:
batch_size = 128
train_loaders = loaders_by_classes(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loaders = loaders_by_classes(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [6]:
loaders = train_loaders
loaders

{'0 - zero': <torch.utils.data.dataloader.DataLoader at 0x16a283fe0>,
 '1 - one': <torch.utils.data.dataloader.DataLoader at 0x16a1c7fb0>,
 '2 - two': <torch.utils.data.dataloader.DataLoader at 0x13fd58e00>,
 '3 - three': <torch.utils.data.dataloader.DataLoader at 0x16a244fb0>,
 '4 - four': <torch.utils.data.dataloader.DataLoader at 0x16a0ba0c0>,
 '5 - five': <torch.utils.data.dataloader.DataLoader at 0x11a9c2390>,
 '6 - six': <torch.utils.data.dataloader.DataLoader at 0x16a2b5490>,
 '7 - seven': <torch.utils.data.dataloader.DataLoader at 0x16a2b5370>,
 '8 - eight': <torch.utils.data.dataloader.DataLoader at 0x16a2b5340>,
 '9 - nine': <torch.utils.data.dataloader.DataLoader at 0x16a2b5550>}

In [7]:
blcd_batch_size = int(batch_size * balanced_batch_size(loaders, class_name))
blcd_batch_size

969

In [8]:
forget_loader = loaders[class_name]
retain_loader = filter_loaders(loaders, class_name, blcd_batch_size, shuffle=True, num_workers=num_workers)

In [22]:
retain_input = next(iter(retain_loader))
r_input_tensor, r_label_tensor = retain_input[0].to(device), retain_input[1].to(device)
forget_input = next(iter(forget_loader))
f_input_tensor, f_label_tensor = forget_input[0].to(device), forget_input[1].to(device)

In [None]:
def get_softmax(model, n_batch = 10):
    """
    Get rpz of forget wrt to retain
    Just at the init
    """
    softmax = 0
    for i_batch in range(n_batch):
        input_tensor = next(iter(forget_loader))
        softmax += torch.exp(model(input_tensor[0].to(device)))
    softmax = softmax / n_batch
    return softmax.mean(dim=0)

def convex_target(softmax, activations, labels):
    """
    softmax : from forget rpz wrt to retain
    activations : residues of retain
    """
    target = 0
    activations_dict = {}
    for idx, _label in enumerate(labels):
        label = int(_label.item())
        if label not in activations_dict.keys():
            activations_dict[label] = {}
            activations_dict[label]['activation'] = activations[idx, :].clone()
            activations_dict[label]['count'] = 1
        else:
            activations_dict[label]['activation'] += activations[idx, :]
            activations_dict[label]['count'] += 1

    for key, dval in activations_dict.items():
        ikey = key - 1 * (key > class_num)
        target += softmax[ikey] * dval['activation'] / dval['count']
    
    return target

def masked_grads(forget_grads, retain_grads):
    m_grads = []
    for idx_param, f_grad in enumerate(forget_grads):
        r_grad = retain_grads[idx_param]
        mask = f_grad * r_grad > 0
        m_grad = mask * f_grad * torch.abs(r_grad)
        m_grads.append(m_grad)
    return m_grads

def update_param(model, forget_grads, retain_grads):
    m_grads = masked_grads(forget_grads, retain_grads)
    max_perc_param = 0.
    for idx_param, param in enumerate(model.parameters()):
        param.grad = m_grads[idx_param]
        perc_param = 100 * torch.sum(param.grad > 0)/torch.tensor(param.size()).prod()
        if perc_param > max_perc_param:
            max_perc_param = perc_param
    return max_perc_param

def get_grads(model, softmax, forget_input, retain_input, device=device, verbose=False):
    r_input_tensor, r_label_tensor = retain_input[0].to(device), retain_input[1].to(device)
    f_input_tensor, _ = forget_input[0].to(device), forget_input[1].to(device)

    # get r_grads
    model.zero_grad()
    r_input_tensor.requires_grad = True
    retain_dict = get_activations(model, r_input_tensor, -1, verbose=verbose)

    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(retain_dict['output'], r_label_tensor)
    loss.backward()

    r_grads = []
    for param in model.parameters():
        r_grads.append(param.grad)

    # get f_grads
    model.zero_grad()
    f_input_tensor.requires_grad = True
    forget_dict = get_activations(model, f_input_tensor, -1, verbose=verbose)

    r_activations = retain_dict['activations'].detach()
    target = convex_target(softmax, r_activations, r_label_tensor)

    criterion = torch.nn.MSELoss()
    loss = criterion(torch.mean(forget_dict['activations'], dim=0), target)
    loss.backward()

    f_grads = []
    for param in model.parameters():
        f_grads.append(param.grad)
    
    return f_grads, r_grads

def unlearn_step(model, softmax, forget_input, retain_input, optimizer, device=device):
    optimizer.zero_grad()
    f_grads, r_grads = get_grads(model, softmax, forget_input, retain_input, verbose=False)
    max_perc_param = update_param(model, f_grads, r_grads)
    optimizer.step()
    return max_perc_param

def untrain(model, softmax, device, forget_loader, retain_loader, optimizer, n_step=None):
    model.train()
    forget_iter_loader = iter(forget_loader)
    for batch_idx, retain_input in enumerate(retain_loader):
        try:
            forget_input = next(forget_iter_loader)
        except:
            forget_iter_loader = iter(forget_loader)
            forget_input = next(forget_iter_loader)
        max_perc_param = unlearn_step(model, softmax, forget_input, retain_input, optimizer, device=device)
        if batch_idx % 10 == 0:
            print(f'{batch_idx} / {len(retain_loader)}')
        if batch_idx == n_step:
            break

In [13]:
softmax = get_softmax(model).detach()

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
untrain(model, softmax, device, forget_loader, retain_loader, optimizer, n_step = None)

tensor(2.4306, device='mps:0')
0 / 56
tensor(3.1250, device='mps:0')
tensor(1.7361, device='mps:0')
tensor(1.3889, device='mps:0')
tensor(1.7361, device='mps:0')
tensor(1.0417, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(2.0833, device='mps:0')
tensor(1.0417, device='mps:0')
tensor(2.0833, device='mps:0')
tensor(1.7361, device='mps:0')
10 / 56
tensor(0.3472, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(0.3472, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(0.3472, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(1.3889, device='mps:0')
tensor(1.0417, device='mps:0')
tensor(1.0417, device='mps:0')
20 / 56
tensor(0.6944, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(0.6944, device='mps:0')
tensor(1.0417, device='mps:0')
tensor(1.0417, device='mps:0')
tensor(0.6944, device='mps:0')
0.0
0.0
tensor(0.6944, device='mps:0')
tensor(0.6944, device='mps:0')
30 / 56
tensor(0.3472, device='mps:0')
0.0
0.0
tensor(0.6944, device='

In [28]:
softmax

tensor([3.5790e-03, 1.1933e-01, 1.7121e-01, 2.6863e-01, 5.7375e-02, 4.8021e-03,
        1.7888e-05, 7.3697e-03, 3.6768e-01], device='mps:0')

In [27]:
get_softmax(model)

tensor([4.5358e-04, 1.1746e-05, 3.6112e-02, 1.6826e-01, 1.1612e-01, 1.1124e-01,
        6.6318e-03, 4.7810e-04, 5.6069e-01], device='mps:0',
       grad_fn=<MeanBackward1>)

In [26]:
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(model(r_input_tensor), r_label_tensor)
loss

tensor(0.9537, device='mps:0', grad_fn=<NllLossBackward0>)