In [1]:
import os
from tqdm import tqdm
import random
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.models import resnet18

In [2]:
learning_rate = 1e-3
batch_size = 32
num_epochs = 2

In [3]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
def create_dataloaders(batch_size):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485],
                            std=[0.228]),
    ])

    train_data = datasets.MNIST(root="../data/", train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root="../data/", train=False,download=True, transform=transform)

    train_dataloader = DataLoader(train_data,
                                batch_size=batch_size,
                                num_workers=os.cpu_count(),
                                shuffle=True)
    test_dataloader = DataLoader(test_data,
                                batch_size=batch_size,
                                num_workers=os.cpu_count(),
                                shuffle=False)
    return train_dataloader, test_dataloader

In [5]:
class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.model = resnet18(weights="DEFAULT")
        self.model.conv1 = nn.Conv2d(1, 64, 7, 2, 3)
        self.linear = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = self.linear(x)
        return x

### Training learning module

In [6]:
def train_step(model, dataloader, loss_fn, optimizer, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = loss_fn(y_pred, y)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(y_pred, dim=1)
        epoch_acc += torch.sum(preds == y) / len(y_pred)

    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc


def test_step(model, dataloader, loss_fn, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.eval()
    classes_total = {f"{i}": 0 for i in range(10)}
    classes_wrong = {f"{i}": 0 for i in range(10)}
    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            y_pred = model(X)

            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

            preds = torch.argmax(y_pred, dim=1)
            for i in y:
                classes_total[f"{i.item()}"] += 1
            for i in y[y != preds]:
                classes_wrong[f"{i.item()}"] += 1
                
            epoch_acc += torch.sum(preds == y) / len(y_pred)
    
    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc, classes_wrong, classes_total


def train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device):
    results = {"train_loss": [],
               "test_loss": [],
               "train_acc": [],
               "test_acc": []}

    best_acc = 0.0
    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss, train_acc = train_step(model, train_dataloader, loss_fn, optimizer, device)
        test_loss, test_acc, classes_wrong, classes_total = test_step(model, test_dataloader, loss_fn, device)
        
        if test_acc >= best_acc:
            torch.save(model.state_dict(), './weights/model.pth')
            best_acc = test_acc

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        print(f"Epoch: {epoch} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Test Loss: {test_loss:.4f} | "
              f"Test Acc: {test_acc:.4f}")
        
        for c in classes_wrong.keys():
            print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")
    return results, classes_wrong, classes_total

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU number: {torch.cuda.current_device()}")

model = ResNet18(num_classes=10).to(device)

loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

set_seed(42)

train_dataloader, test_dataloader = create_dataloaders(batch_size=batch_size)

results = train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device)

Device: cuda
GPU number: 0


 50%|█████     | 1/2 [00:48<00:48, 48.45s/it]

Epoch: 1 | Train Loss: 0.2210 | Train Acc: 0.9500 | Test Loss: 0.0453 | Test Acc: 0.9858
class 0: 99.18%
class 1: 97.97%
class 2: 98.45%
class 3: 97.92%
class 4: 98.68%
class 5: 99.10%
class 6: 98.23%
class 7: 98.05%
class 8: 99.38%
class 9: 99.01%


100%|██████████| 2/2 [01:36<00:00, 48.48s/it]

Epoch: 2 | Train Loss: 0.3074 | Train Acc: 0.9512 | Test Loss: 0.0863 | Test Acc: 0.9724
class 0: 98.37%
class 1: 99.47%
class 2: 97.09%
class 3: 98.71%
class 4: 95.72%
class 5: 91.82%
class 6: 99.27%
class 7: 96.40%
class 8: 97.64%
class 9: 97.13%





### Training auxiliary unlearning module

In [8]:
def train_step(model, dataloader, loss_fn, optimizer, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y[y!=3] = torch.randint(0, 10, size=[1]).to(device)
        
        y_pred = model(X)
        
        loss = loss_fn(y_pred, y)
        epoch_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        preds = torch.argmax(y_pred, dim=1)
        epoch_acc += torch.sum(preds == y) / len(y_pred)

    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc


def test_step(model, dataloader, loss_fn, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.eval()
    
    classes_wrong = {f"{i}": 0 for i in range(10)}
    classes_total = {f"{i}": 0 for i in range(10)}
    with torch.inference_mode():
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            X, y = X.to(device), y.to(device)

            y_pred = model(X)

            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

            preds = torch.argmax(y_pred, dim=1)
            for i in y:
                classes_total[f"{i.item()}"] += 1
            for i in y[y != preds]:
                classes_wrong[f"{i.item()}"] += 1
                
            epoch_acc += torch.sum(preds == y) / len(y_pred)
    
    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc, classes_wrong, classes_total


def train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device):
    results = {"train_loss": [],
               "test_loss": [],
               "train_acc": [],
               "test_acc": []}

    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss, train_acc = train_step(model, train_dataloader, loss_fn, optimizer, device)
        test_loss, test_acc, classes_wrong, classes_total = test_step(model, test_dataloader, loss_fn, device)
        
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        print(f"Epoch: {epoch} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Test Loss: {test_loss:.4f} | "
              f"Test Acc: {test_acc:.4f}")
    
        for c in classes_wrong.keys():
            print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")
    torch.save(model.state_dict(), './weights/model_c.pth')
    return results, classes_wrong, classes_total

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet18(num_classes=10).to(device)
model.load_state_dict(torch.load("./weights/model.pth", weights_only=True))

loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_dataloader, _ = create_dataloaders(batch_size=32)

set_seed(42)

results, classes_wrong, classes_total = train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs=2, device=device)

313it [00:02, 111.95it/s]00<?, ?it/s]
 50%|█████     | 1/2 [00:49<00:49, 49.63s/it]

Epoch: 1 | Train Loss: 2.2235 | Train Acc: 0.1811 | Test Loss: 2.1115 | Test Acc: 0.1010
class 0: 0.00%
class 1: 0.00%
class 2: 0.00%
class 3: 100.00%
class 4: 0.00%
class 5: 0.00%
class 6: 0.00%
class 7: 0.00%
class 8: 0.00%
class 9: 0.00%


313it [00:02, 111.02it/s]
100%|██████████| 2/2 [01:39<00:00, 49.79s/it]

Epoch: 2 | Train Loss: 2.2289 | Train Acc: 0.1926 | Test Loss: 2.1057 | Test Acc: 0.1162
class 0: 15.51%
class 1: 0.00%
class 2: 0.00%
class 3: 100.00%
class 4: 0.00%
class 5: 0.00%
class 6: 0.00%
class 7: 0.00%
class 8: 0.00%
class 9: 0.00%





In [10]:
model = ResNet18(num_classes=10).to(device)
model.load_state_dict(torch.load("./weights/model.pth", weights_only=True))

model_c = ResNet18(num_classes=10).to(device)
model_c.load_state_dict(torch.load("./weights/model_c.pth", weights_only=True))

B = 1.5
Wl = (model.linear.weight.data) - (B * model_c.linear.weight.data)
model.linear.weight.data = Wl

_, dataloader =  create_dataloaders(batch_size=32)

epoch_loss, epoch_acc, classes_wrong, classes_total = test_step(model, dataloader, loss_fn, device)

for c in classes_wrong.keys():
    print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")

313it [00:02, 108.14it/s]

class 0: 99.39%
class 1: 97.89%
class 2: 98.64%
class 3: 0.00%
class 4: 98.47%
class 5: 99.33%
class 6: 98.33%
class 7: 98.44%
class 8: 99.59%
class 9: 98.81%



