In [None]:
import os
import torch
from torchvision import datasets, transforms
from utils import loaders_by_classes, filter_loaders, balance, balanced_batch_size, submodel, get_activations
from classNet import ConvNet # for torch load
import matplotlib.pyplot as plt

# Pick the fastest available backend: Apple MPS, then CUDA, else the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed torch's global RNG so loader shuffling, batch sampling and dropout are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# os.cpu_count() returns None when the count cannot be determined; 0 means
# "load data in the main process".
num_workers = os.cpu_count() or 0

device, num_workers

(device(type='mps'), 12)

In [None]:
# Load the saved full model (presumably trained on all 10 MNIST classes, judging by the file name).
# weights_only=False un-pickles the whole ConvNet object, which is why `ConvNet` is imported
# "for torch load"; un-pickling can execute arbitrary code, so only load trusted checkpoints.
model_all_class = torch.load('./models/all_class.pth', weights_only=False)

In [None]:
# Class to unlearn. Its name is derived from the index so the two can never disagree;
# it must match the per-class loader keys (e.g. '7 - seven').
class_num = 7
class_name = datasets.MNIST.classes[class_num]
# Submodel for `class_num` (its fc2 has 9 outputs, i.e. one per remaining class —
# presumably the forgotten class's output is removed; verify in utils.submodel).
model = submodel(model_all_class, class_num)
model.to(device)
model.eval()

ConvNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_drop): Dropout2d(p=0.25, inplace=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=9, bias=True)
)

In [None]:
# ToTensor converts images to float tensors. Normalize with mean 0 / std 1 is an identity
# transform — presumably kept to mirror the preprocessing the model was trained with (verify).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,), std=(1,)),
])

# Both splits share root, download flag and transform; only `train` differs (train first, then test).
train_set, test_set = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
batch_size = 64
# One DataLoader per MNIST class, for the training split...
train_loaders = loaders_by_classes(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# ...and for the held-out test split (must be built from test_set, not train_set).
test_loaders = loaders_by_classes(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# From here on, work with the per-class *training* loaders (an alias to the same dict, not a copy).
loaders = train_loaders
loaders

{'0 - zero': <torch.utils.data.dataloader.DataLoader at 0x32a243740>,
 '1 - one': <torch.utils.data.dataloader.DataLoader at 0x32a2878f0>,
 '2 - two': <torch.utils.data.dataloader.DataLoader at 0x32a09a720>,
 '3 - three': <torch.utils.data.dataloader.DataLoader at 0x1061eb290>,
 '4 - four': <torch.utils.data.dataloader.DataLoader at 0x32a2b4350>,
 '5 - five': <torch.utils.data.dataloader.DataLoader at 0x32a2b4770>,
 '6 - six': <torch.utils.data.dataloader.DataLoader at 0x32a2b5040>,
 '7 - seven': <torch.utils.data.dataloader.DataLoader at 0x32a2b5100>,
 '8 - eight': <torch.utils.data.dataloader.DataLoader at 0x32a2b51c0>,
 '9 - nine': <torch.utils.data.dataloader.DataLoader at 0x32a2b5280>}

In [None]:
# Retain-side batch size: the base batch size scaled by the factor returned by
# balanced_batch_size for the forgotten class, truncated to an int.
blcd_batch_size = int(balanced_batch_size(loaders, class_name) * batch_size)
blcd_batch_size

484

In [None]:
# Forget set: the loader of the class being unlearned.
forget_loader = loaders[class_name]
# Retain set: the other classes, combined by filter_loaders into batches of blcd_batch_size.
retain_loader = filter_loaders(
    loaders,
    class_name,
    blcd_batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
# One reference batch from each loader, with inputs and labels moved to the compute device.
retain_input = next(iter(retain_loader))
r_input_tensor, r_label_tensor = (t.to(device) for t in retain_input[:2])
forget_input = next(iter(forget_loader))
f_input_tensor, f_label_tensor = (t.to(device) for t in forget_input[:2])

In [None]:
def get_softmax(model, n_batch = 10, loader=None, device=None):
    """
    Average predicted class distribution of `model` on forget-class samples.

    Meant to be computed once, before unlearning; the result is used as the
    per-class mixing weights for the retain-activation target.

    Args:
        model: network whose output is exponentiated to obtain probabilities, i.e. it is
            assumed to return log-probabilities (log_softmax) — TODO confirm in ConvNet.
            Output is assumed to be (batch, n_outputs).
        n_batch: number of batches to average over (>= 1). The loader is restarted if
            it has fewer batches.
        loader: DataLoader yielding (inputs, labels); defaults to the global `forget_loader`.
        device: device to run on; defaults to the device of the model's parameters.

    Returns:
        1-D tensor: the per-sample mean of exp(model(x)) over every sample drawn.

    Raises:
        ValueError: if n_batch < 1.
    """
    if n_batch < 1:
        raise ValueError(f"n_batch must be >= 1, got {n_batch}")
    if loader is None:
        loader = forget_loader
    if device is None:
        device = next(model.parameters()).device

    prob_sum = 0
    n_samples = 0
    # One iterator for all batches: avoids re-spawning DataLoader workers and re-reading
    # the same first batch on every iteration.
    batches = iter(loader)
    with torch.no_grad():  # statistics only, no autograd graph needed
        for _ in range(n_batch):
            try:
                inputs = next(batches)[0]
            except StopIteration:
                batches = iter(loader)
                inputs = next(batches)[0]
            probs = torch.exp(model(inputs.to(device)))
            # Reduce over the batch dim first so batches of different sizes combine correctly.
            prob_sum = prob_sum + probs.sum(dim=0)
            n_samples += probs.shape[0]
    return prob_sum / n_samples

def convex_target(softmax, activations, labels, forget_class=None):
    """
    Activation target for forget samples: a convex mix of per-class retain activation means.

    For every retain class present in the batch, the mean activation of its samples is
    taken. The class means are weighted by `softmax` and summed, with the weights
    re-normalised over the classes actually present so they sum to 1 even when some
    class is missing from the batch.

    Args:
        softmax: 1-D tensor, one weight per submodel output (forgotten class removed),
            e.g. the output of `get_softmax`.
        activations: retain-sample activations, first dim indexing samples (>= 2-D).
        labels: raw dataset labels of those samples; must not contain `forget_class`.
        forget_class: index of the removed class; defaults to the global `class_num`.
            Labels above it are shifted down by one to index `softmax`.

    Returns:
        Tensor with shape `activations.shape[1:]`.

    Raises:
        ValueError: on an empty batch, or if `labels` contain `forget_class`.
    """
    if forget_class is None:
        forget_class = class_num
    labels = labels.reshape(-1).long()
    if labels.numel() == 0:
        raise ValueError("convex_target needs at least one retain sample")
    if bool((labels == forget_class).any()):
        # Such a label would silently index the *next* class's weight.
        raise ValueError(f"labels contain the forgotten class {forget_class}")

    # Dataset label -> submodel output index (the forgotten class's slot is removed).
    sub_labels = labels - (labels > forget_class).long()
    onehot = torch.nn.functional.one_hot(sub_labels, num_classes=softmax.shape[0])
    onehot = onehot.to(dtype=activations.dtype, device=activations.device)  # (N, C)
    flat = activations.reshape(activations.shape[0], -1)                    # (N, D)

    counts = onehot.sum(dim=0)  # samples per class
    present = counts > 0
    class_means = (onehot.T @ flat)[present] / counts[present].unsqueeze(1)  # (K, D)

    weights = softmax[present].to(flat.dtype)
    weights = weights / weights.sum()  # convex weights over the classes in this batch
    target = (weights.unsqueeze(1) * class_means).sum(dim=0)
    return target.reshape(activations.shape[1:])

def masked_grads(forget_grads, retain_grads):
    """
    Keep only the forget-gradient entries whose sign agrees with the retain gradient.

    Element-wise, for each parameter:
        out = f_grad * |r_grad|   where f_grad * r_grad > 0
        out = 0                   elsewhere

    Args:
        forget_grads, retain_grads: equally long lists of per-parameter gradients
            (tensors or None).

    Returns:
        List of masked gradients. An entry is None when either input gradient is None;
        torch optimizers skip parameters whose .grad is None.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(forget_grads) != len(retain_grads):
        raise ValueError(
            f"got {len(forget_grads)} forget grads but {len(retain_grads)} retain grads"
        )
    m_grads = []
    for f_grad, r_grad in zip(forget_grads, retain_grads):
        if f_grad is None or r_grad is None:
            m_grads.append(None)
            continue
        mask = f_grad * r_grad > 0
        m_grads.append(mask * f_grad * torch.abs(r_grad))
    return m_grads

def update_param(model, forget_grads, retain_grads, log_param_idx=4):
    """
    Install the sign-masked forget gradients as the `.grad` of every model parameter.

    Call `optimizer.step()` afterwards to apply them.

    Args:
        model: module whose `parameters()` order the gradient lists follow.
        forget_grads, retain_grads: per-parameter gradients, combined by `masked_grads`.
        log_param_idx: index of the parameter for which the percentage of strictly
            positive masked-gradient entries is printed (debug aid; presumably fc1.weight
            for this ConvNet). Pass None to disable the print.

    Raises:
        ValueError: if the number of gradients differs from the number of parameters.
    """
    m_grads = masked_grads(forget_grads, retain_grads)
    params = list(model.parameters())
    if len(params) != len(m_grads):
        raise ValueError(f"model has {len(params)} parameters but got {len(m_grads)} gradients")
    for idx_param, (param, m_grad) in enumerate(zip(params, m_grads)):
        param.grad = m_grad
        if idx_param == log_param_idx and m_grad is not None:
            print(100 * torch.sum(m_grad > 0) / m_grad.numel())

# get grads
def get_grads(model, softmax, forget_input, retain_input, device=device, verbose=False):
    """
    Compute per-parameter gradients of the retain and the forget objectives.

    Retain objective: cross-entropy of the model output on the retain batch.
    Forget objective: MSE between the batch-mean forget activation and
    `convex_target(softmax, retain activations, retain labels)`.

    Args:
        model: the submodel being unlearned.
        softmax: per-output weights forwarded to `convex_target`.
        forget_input, retain_input: (inputs, labels) batches as yielded by the loaders.
        device: device the inputs are moved to.
        verbose: forwarded to `get_activations`.

    Returns:
        (f_grads, r_grads): lists aligned with `model.parameters()`; each entry is a
        detached copy of that parameter's gradient, or None if it received none.
    """
    r_input_tensor, r_label_tensor = retain_input[0].to(device), retain_input[1].to(device)
    f_input_tensor = forget_input[0].to(device)

    # NOTE(review): assumes retain labels are raw MNIST class indices — the same assumption
    # convex_target makes when it shifts labels above `class_num`. The submodel has one
    # output fewer (fc2 has 9 outputs), so map labels onto output indices before the
    # cross-entropy; unshifted, label 9 is out of range.
    r_sub_labels = r_label_tensor - (r_label_tensor > class_num).long()

    # --- retain gradients ---
    model.zero_grad()
    # Input gradients are not used here; presumably needed by get_activations — verify before removing.
    r_input_tensor.requires_grad = True
    retain_dict = get_activations(model, r_input_tensor, -1, verbose=verbose)
    loss = torch.nn.CrossEntropyLoss()(retain_dict['output'], r_sub_labels)
    loss.backward()
    # Snapshot the gradients: with zero_grad(set_to_none=False) (the default before
    # PyTorch 2.0) the next zero_grad() zeroes these same tensors in place and the forget
    # backward accumulates into them, which would make r_grads alias f_grads.
    r_grads = [None if p.grad is None else p.grad.detach().clone() for p in model.parameters()]

    # --- forget gradients ---
    model.zero_grad()
    f_input_tensor.requires_grad = True
    forget_dict = get_activations(model, f_input_tensor, -1, verbose=verbose)

    r_activations = retain_dict['activations'].detach()
    target = convex_target(softmax, r_activations, r_label_tensor)

    loss = torch.nn.MSELoss()(torch.mean(forget_dict['activations'], dim=0), target)
    loss.backward()
    f_grads = [None if p.grad is None else p.grad.detach().clone() for p in model.parameters()]

    return f_grads, r_grads

def unlearn_step(model, softmax, forget_input, retain_input, optimizer, device=device):
    """
    Perform one unlearning update.

    Computes the forget and retain gradients, installs the sign-masked gradient
    on the parameters, then steps the optimizer.

    Args:
        device: where batches are moved; forwarded to get_grads (it was previously ignored).
    """
    optimizer.zero_grad()
    f_grads, r_grads = get_grads(
        model, softmax, forget_input, retain_input, device=device, verbose=False
    )
    update_param(model, f_grads, r_grads)
    optimizer.step()

def untrain(model, softmax, device, forget_loader, retain_loader, optimizer, n_step = None):
    """
    Run unlearning over (at most) one pass of `retain_loader`.

    Each retain batch is paired with the next forget batch. `forget_loader` is
    restarted whenever it runs out. Puts the model in train mode and leaves it there.

    Args:
        model, softmax, optimizer: forwarded to `unlearn_step`.
        device: where batches are moved.
        forget_loader, retain_loader: DataLoaders yielding (inputs, labels).
        n_step: maximum number of unlearning steps; None runs the full retain epoch.
    """
    model.train()
    forget_batches = iter(forget_loader)
    for batch_idx, retain_input in enumerate(retain_loader):
        if n_step is not None and batch_idx >= n_step:
            break
        try:
            forget_input = next(forget_batches)
        except StopIteration:  # forget loader exhausted: start a new pass over it
            forget_batches = iter(forget_loader)
            forget_input = next(forget_batches)
        unlearn_step(model, softmax, forget_input, retain_input, optimizer, device=device)
        if batch_idx % 10 == 0:
            print(f'{batch_idx} / {len(retain_loader)}')

In [None]:
# Mixing weights for the convex target, computed once from the model before unlearning
# (at this point in the notebook the model was put in eval mode right after loading).
softmax = get_softmax(model, n_batch=10).detach()

In [None]:
# Adam over all submodel parameters; n_step=None lets untrain() run over the whole retain loader.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
untrain(
    model=model,
    softmax=softmax,
    device=device,
    forget_loader=forget_loader,
    retain_loader=retain_loader,
    optimizer=optimizer,
    n_step=None,
)

tensor(4.2112, device='mps:0')
0 / 112
tensor(0.9482, device='mps:0')
tensor(0.5989, device='mps:0')
tensor(3.3980, device='mps:0')
tensor(1.0715, device='mps:0')
tensor(0.8919, device='mps:0')
tensor(0.6427, device='mps:0')
tensor(7.6590, device='mps:0')
tensor(2.6472, device='mps:0')
tensor(9.4986, device='mps:0')
tensor(3.4720, device='mps:0')
10 / 112
tensor(7.8137, device='mps:0')
tensor(0.9940, device='mps:0')
tensor(0.9950, device='mps:0')
tensor(2.4623, device='mps:0')
tensor(0.6285, device='mps:0')
tensor(1.0939, device='mps:0')
tensor(2.5032, device='mps:0')
tensor(0.7102, device='mps:0')
tensor(1.3019, device='mps:0')
tensor(0.7197, device='mps:0')
20 / 112
tensor(0.8998, device='mps:0')
tensor(1.1367, device='mps:0')
tensor(3.2466, device='mps:0')
tensor(1.0565, device='mps:0')
tensor(1.3091, device='mps:0')
tensor(1.1813, device='mps:0')
tensor(5.7998, device='mps:0')
tensor(1.8512, device='mps:0')
tensor(8.2425, device='mps:0')
tensor(0.8732, device='mps:0')
30 / 112
tens

In [None]:
# Retain loss on the reference retain batch after unlearning. untrain() leaves the model in
# train mode, so switch to eval mode (dropout off) and skip autograd bookkeeping.
model.eval()
criterion = torch.nn.CrossEntropyLoss()
# Raw labels -> submodel output indices (the forgotten class_num is removed, labels above it
# shift down by one); NOTE(review): assumes the retain labels are raw MNIST labels.
r_sub_labels = r_label_tensor - (r_label_tensor > class_num).long()
with torch.no_grad():
    loss = criterion(model(r_input_tensor), r_sub_labels)
loss

tensor(0.8271, device='mps:0', grad_fn=<NllLossBackward0>)