# Unlearning on Vision Transformers

## TODO: write a tiny ViT model in `models`, download and set up CIFAR-10, train on CIFAR-10, and test various unlearning methods, such as SSD (Selective Synaptic Dampening), applied to combinations of the Q, K, V matrices, or perhaps the MLP, etc.

# ––––––––––

In [1]:
import copy
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchinfo import summary
from torchvision import transforms
from tqdm import tqdm

In [7]:
# IPython magics: enable the autoreload extension so edits to the local
# modules imported below (constants, utils, cifar_10_utils, models) are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# None => run locally and use the local `path`. On Colab, uncomment the two
# lines below so `drive` is set and `path` is redirected to Google Drive.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# Project root: added to sys.path for the local modules and used as the base
# directory for checkpoints/ and train_logs/ written during training.
path = "./"

In [4]:
# Prefer the GPU whenever CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# With Google Drive mounted (Colab), read and write project files there instead.
if drive is not None:
    path = "/content/drive/MyDrive/self-learn/unlearning"

In [5]:
# Make the project's local modules importable. The membership check keeps
# re-runs of this cell from piling duplicate entries onto sys.path.
if path not in sys.path:
    sys.path.append(path)

In [13]:
# Local project modules (importable once `path` is on sys.path).
# NOTE(review): `constants` is star-imported and presumably supplies EPOCHS,
# PRINT_ITERS, BATCH_SIZE and SEED, which later cells use but never define —
# confirm, and consider importing those names explicitly.
from constants import *
from utils import set_seed
# Importing cifar_10_utils triggered a CIFAR-10 download to ./data (see the
# captured output below).
from cifar_10_utils import train_data, val_data, train_loader, val_loader, invTrans, class_labels
from models import get_vit_and_optimizer, get_attack_model_and_optimizer

# Seed RNGs via the project helper (uses its default seed; see utils.set_seed).
set_seed()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  2%|▌                          | 3407872/170498071 [00:33<27:00, 103127.11it/s]


KeyboardInterrupt: 

In [None]:
# Identifier embedded in checkpoint and training-log filenames. This is a plain
# string: the original f-prefix had no placeholders.
MODEL_NAME = "ViT_CIFAR_10_ORIGINAL"
print("Model Name:", MODEL_NAME)

In [None]:
# Display the training dataset object as the cell output.
train_data

# Data

In [88]:
# Take element 0 of training sample 9 (the image tensor, judging by the shape
# printed below) and add a leading batch dimension.
x = train_data[9][0].unsqueeze(0)

In [89]:
# Inspect the batched sample's shape.
x.shape

torch.Size([1, 3, 32, 32])

In [96]:
# NOTE(review): PatchEmbedding is not imported in any visible cell (only
# get_vit_and_optimizer / get_attack_model_and_optimizer come from models);
# presumably it lives in models — confirm and import it explicitly.
# A 32x32 input cut into 2x2 patches yields (32 / 2) ** 2 = 256 patches, each
# embedded into 512 dimensions, consistent with the (1, 256, 512) output below.
pe = PatchEmbedding(img_size=32, patch_size=2, n_embd=512, in_channels=3)

In [97]:
# Shape of the patch embeddings for the single sample x.
pe(x).shape

torch.Size([1, 256, 512])

In [None]:
# Peek at one training batch and display a single image with its class name.
test_idx = 2
batch = next(iter(train_loader))
print(batch[0].shape)  # (BATCH_SIZE, 3, 32, 32)

# imshow wants channels-last (H, W, C), so move the channel axis to the end.
# NOTE(review): if the loader normalizes images, pass them through invTrans
# first so colours display correctly — confirm.
plt.imshow(batch[0][test_idx].permute(1, 2, 0))
plt.title(f"{class_labels[batch[1][test_idx]]}")

# Standard Training

In [10]:
def _grad_norm(model):
    """Return the global L2 norm of all parameter gradients that are set.

    Mathematically equal to the norm of every gradient concatenated into one
    vector, but it combines per-parameter norms instead of materializing that
    full-size copy. Returns 0.0 when no parameter has a gradient.
    """
    norms = [
        param.grad.detach().norm(2)
        for param in model.parameters()
        if param.grad is not None
    ]
    if not norms:
        return 0.0
    return torch.stack(norms).norm(2).item()


def _dump_json(obj, filename):
    """Write ``obj`` as JSON to ``filename``, overwriting any existing file."""
    with open(filename, "w") as f:
        json.dump(obj, f)


def train(model, train_loader, val_loader, optimizer, criterion, device):
    """Train ``model`` for ``EPOCHS`` epochs, checkpointing and logging each epoch.

    Reads the notebook-level globals ``EPOCHS``, ``PRINT_ITERS``, ``SEED``,
    ``MODEL_NAME`` and ``path``.

    Every ``PRINT_ITERS`` steps (excluding step 0), the model is evaluated on
    ``val_loader`` via ``eval``. A progress line is then printed with:
      - the mean of all training losses recorded so far;
      - the validation loss and accuracy;
      - the global gradient norm.

    After each epoch, the model and optimizer ``state_dict``s are saved under
    ``{path}/checkpoints/``. The three metric lists are also rewritten as JSON
    under ``{path}/train_logs/``. Both directories are created if missing.

    Args:
        model: module to train; it is moved to ``device``.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(outputs, labels)``.
        device: device to train on.

    Returns:
        ``(train_losses, val_losses, val_accuracies)``. ``train_losses`` holds
        one loss per training step. The two validation lists hold one entry
        per logging step.
    """
    # Create output dirs up front so saving cannot fail after a full epoch.
    os.makedirs(f"{path}/checkpoints", exist_ok=True)
    os.makedirs(f"{path}/train_logs", exist_ok=True)
    log_prefix = f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}"

    model.train()
    model.to(device)
    train_losses, val_losses = [], []
    val_accuracies = []

    for epoch in range(EPOCHS):

        print(f"Epoch {epoch+1}/{EPOCHS}")

        for step, (img, label) in enumerate(train_loader):

            img, label = img.to(device), label.to(device)

            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, label)
            train_losses.append(loss.item())  # every step
            loss.backward()

            # The gradient norm is only reported on logging steps, so compute
            # it only then (before optimizer.step, as before).
            should_log = step != 0 and step % PRINT_ITERS == 0
            grad_norm = _grad_norm(model) if should_log else None

            optimizer.step()

            if should_log:
                val_loss, val_acc = eval(model, val_loader, criterion, device)
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
                print(
                    f"Step: {step}/{len(train_loader)}, Running Average Loss: {np.mean(train_losses):.3f} |",
                    f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f} | Grad Norm: {grad_norm:.2f}",
                )
                # eval() switches the model to eval mode; switch back.
                model.train()

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{epoch+1}_SEED_{SEED}.pt",
        )

        # Rewrite the cumulative metric logs after every epoch.
        _dump_json(train_losses, f"{log_prefix}_train_losses.json")
        _dump_json(val_losses, f"{log_prefix}_val_losses.json")
        _dump_json(val_accuracies, f"{log_prefix}_val_accuracies.json")

    return train_losses, val_losses, val_accuracies

In [11]:
def eval(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(val_loss, val_acc)``.

    NOTE: this shadows the builtin ``eval``. The name is kept because ``train``
    calls it.

    Args:
        model: module producing per-class scores. The argmax over dim 1 is
            taken as the predicted class.
        val_loader: iterable of ``(inputs, labels)`` batches of any size.
        criterion: loss callable ``criterion(outputs, labels)`` returning a
            scalar tensor.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the unweighted mean of the per-batch losses,
        and the fraction of samples classified correctly.

    Raises:
        ValueError: if ``val_loader`` yields no samples.

    The model is left in eval mode; callers that keep training must call
    ``model.train()`` again.
    """
    val_losses = []
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in val_loader:

            img, label = img.to(device), label.to(device)
            out = model(img)

            loss_eval = criterion(out, label)
            val_losses.append(loss_eval.item())

            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
            # Count samples directly rather than assuming every batch except
            # the last holds exactly BATCH_SIZE items.
            total += label.size(0)

    if total == 0:
        raise ValueError("val_loader yielded no samples to evaluate")

    val_loss = np.mean(val_losses)
    val_acc = correct / total

    return val_loss, val_acc

In [38]:
# The model has no log-softmax output layer, so CrossEntropyLoss is applied
# directly to its raw outputs.
criterion = nn.CrossEntropyLoss()

# Build the ViT and its optimizer for the configured seed, then move the model
# onto the selected device.
model, optimizer = get_vit_and_optimizer(seed=SEED)
model.to(device)
# summary(model)  # uncomment for a torchinfo layer summary

In [None]:
## Driver code: run the full training loop and keep the returned metric histories.
train_losses, val_losses, val_accuracies = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
)