In [1]:
import os
from tqdm import tqdm
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
from torchvision import transforms

In [2]:
learning_rate = 1e-2
batch_size = 32
num_epochs = 2

In [3]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
def create_dataloaders(batch_size):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485],
                            std=[0.228]),
    ])

    train_data = datasets.MNIST(root="../data/", train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root="../data/", train=False,download=True, transform=transform)

    train_dataloader = DataLoader(train_data,
                                batch_size=batch_size,
                                num_workers=os.cpu_count(),
                                shuffle=True)
    test_dataloader = DataLoader(test_data,
                                batch_size=batch_size,
                                num_workers=os.cpu_count(),
                                shuffle=False)
    return train_dataloader, test_dataloader

In [5]:
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 120, kernel_size=5, stride=1),
            nn.Tanh(),
        )
        self.fc1 = nn.Linear(120, 100)
        self.fc2 = nn.Linear(100, 1000)
        self.fc3 = nn.Linear(1000, num_classes)
        self.tanh = nn.Tanh()
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

### Training learning module

In [6]:
def train_step(model, dataloader, loss_fn, optimizer, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = loss_fn(y_pred, y)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(y_pred, dim=1)
        epoch_acc += torch.sum(preds == y) / len(y_pred)

    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc


def test_step(model, dataloader, loss_fn, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.eval()
    classes_total = {f"{i}": 0 for i in range(10)}
    classes_wrong = {f"{i}": 0 for i in range(10)}
    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            y_pred = model(X)

            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

            preds = torch.argmax(y_pred, dim=1)
            for i in y:
                classes_total[f"{i.item()}"] += 1
            for i in y[y != preds]:
                classes_wrong[f"{i.item()}"] += 1
                
            epoch_acc += torch.sum(preds == y) / len(y_pred)
    
    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc, classes_wrong, classes_total


def train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device):
    results = {"train_loss": [],
               "test_loss": [],
               "train_acc": [],
               "test_acc": []}

    best_acc = 0.0
    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss, train_acc = train_step(model, train_dataloader, loss_fn, optimizer, device)
        test_loss, test_acc, classes_wrong, classes_total = test_step(model, test_dataloader, loss_fn, device)
        
        if test_acc >= best_acc:
            torch.save(model.state_dict(), './weights/model.pth')
            best_acc = test_acc

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        print(f"Epoch: {epoch} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Test Loss: {test_loss:.4f} | "
              f"Test Acc: {test_acc:.4f}")
        
        for c in classes_wrong.keys():
            print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")
    return results, classes_wrong, classes_total

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU number: {torch.cuda.current_device()}")

model = LeNet5(num_classes=10).to(device)

loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=4e-5)

set_seed(42)

train_dataloader, test_dataloader = create_dataloaders(batch_size=batch_size)

results = train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device)

Device: cuda
GPU number: 0


 50%|█████     | 1/2 [00:20<00:20, 20.91s/it]

Epoch: 1 | Train Loss: 0.6861 | Train Acc: 0.8025 | Test Loss: 0.2667 | Test Acc: 0.9227
class 0: 98.16%
class 1: 97.89%
class 2: 88.95%
class 3: 90.89%
class 4: 93.38%
class 5: 86.88%
class 6: 93.84%
class 7: 91.34%
class 8: 89.84%
class 9: 90.49%


100%|██████████| 2/2 [00:41<00:00, 20.88s/it]

Epoch: 2 | Train Loss: 0.2198 | Train Acc: 0.9352 | Test Loss: 0.1540 | Test Acc: 0.9546
class 0: 98.78%
class 1: 98.85%
class 2: 94.86%
class 3: 94.85%
class 4: 95.21%
class 5: 92.49%
class 6: 97.29%
class 7: 93.09%
class 8: 94.76%
class 9: 93.86%





### Training auxiliary unlearning module

In [8]:
def train_step(model, dataloader, loss_fn, optimizer, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y[y==3] = torch.randint(0, 10, size=[1]).to(device)
        
        y_pred = model(X)
        
        loss = loss_fn(y_pred, y)
        epoch_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        preds = torch.argmax(y_pred, dim=1)
        epoch_acc += torch.sum(preds == y) / len(y_pred)

    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc


def test_step(model, dataloader, loss_fn, device):
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.eval()
    
    classes_wrong = {f"{i}": 0 for i in range(10)}
    classes_total = {f"{i}": 0 for i in range(10)}
    with torch.inference_mode():
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            X, y = X.to(device), y.to(device)

            y_pred = model(X)

            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

            preds = torch.argmax(y_pred, dim=1)
            for i in y:
                classes_total[f"{i.item()}"] += 1
            for i in y[y != preds]:
                classes_wrong[f"{i.item()}"] += 1
                
            epoch_acc += torch.sum(preds == y) / len(y_pred)
    
    epoch_loss = epoch_loss / len(dataloader)
    epoch_acc = epoch_acc / len(dataloader)
    return epoch_loss, epoch_acc, classes_wrong, classes_total


def train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs, device):
    results = {"train_loss": [],
               "test_loss": [],
               "train_acc": [],
               "test_acc": []}

    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss, train_acc = train_step(model, train_dataloader, loss_fn, optimizer, device)
        test_loss, test_acc, classes_wrong, classes_total = test_step(model, test_dataloader, loss_fn, device)
        
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        print(f"Epoch: {epoch} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Test Loss: {test_loss:.4f} | "
              f"Test Acc: {test_acc:.4f}")
    
        for c in classes_wrong.keys():
            print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")
    torch.save(model.state_dict(), './weights/model_s.pth')
    return results, classes_wrong, classes_total


device = "cuda" if torch.cuda.is_available() else "cpu"

model = LeNet5(num_classes=10).to(device)
model.load_state_dict(torch.load("./weights/model.pth", weights_only=True))

loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=4e-5)

train_dataloader, _ = create_dataloaders(batch_size=32)

set_seed(42)

results, classes_wrong, classes_total = train(model, train_dataloader, test_dataloader, loss_fn, optimizer, num_epochs=3, device=device)

313it [00:01, 190.65it/s]00<?, ?it/s]
 33%|███▎      | 1/3 [00:21<00:42, 21.15s/it]

Epoch: 1 | Train Loss: 0.4514 | Train Acc: 0.8729 | Test Loss: 0.3655 | Test Acc: 0.8722
class 0: 98.57%
class 1: 98.85%
class 2: 96.03%
class 3: 1.29%
class 4: 97.05%
class 5: 98.43%
class 6: 96.97%
class 7: 95.62%
class 8: 96.61%
class 9: 93.76%


313it [00:01, 199.95it/s]
 67%|██████▋   | 2/3 [00:41<00:20, 20.78s/it]

Epoch: 2 | Train Loss: 0.3727 | Train Acc: 0.8826 | Test Loss: 0.3218 | Test Acc: 0.8788
class 0: 99.08%
class 1: 99.03%
class 2: 97.38%
class 3: 1.68%
class 4: 98.37%
class 5: 97.42%
class 6: 97.60%
class 7: 97.76%
class 8: 96.00%
class 9: 95.24%


313it [00:01, 181.76it/s]
100%|██████████| 3/3 [01:03<00:00, 21.11s/it]

Epoch: 3 | Train Loss: 0.3504 | Train Acc: 0.8873 | Test Loss: 0.2983 | Test Acc: 0.8829
class 0: 98.98%
class 1: 99.12%
class 2: 98.35%
class 3: 1.68%
class 4: 97.86%
class 5: 98.21%
class 6: 98.54%
class 7: 96.69%
class 8: 99.08%
class 9: 95.34%





In [9]:
model = LeNet5(num_classes=10).to(device)
model.load_state_dict(torch.load("./weights/model.pth", weights_only=True))

model_s = LeNet5(num_classes=10).to(device)
model_s.load_state_dict(torch.load("./weights/model_s.pth", weights_only=True))

A = 0.01
Wl0 = (A * model.fc2.weight.data) + ((1-A) * model_s.fc2.weight.data)
model.fc2.weight.data = Wl0
Wl1 = (A * model.fc3.weight.data) + ((1-A) * model_s.fc3.weight.data)
model.fc3.weight.data = Wl1

_, dataloader = create_dataloaders(batch_size=32)

epoch_loss, epoch_acc, classes_wrong, classes_total = test_step(model, dataloader, loss_fn, device)

for c in classes_wrong.keys():
    print(f"class {c}: {((classes_total[c]-classes_wrong[c])/classes_total[c] * 100):.2f}%")

313it [00:01, 191.10it/s]

class 0: 98.16%
class 1: 98.59%
class 2: 96.03%
class 3: 69.70%
class 4: 95.11%
class 5: 95.52%
class 6: 96.66%
class 7: 92.32%
class 8: 95.38%
class 9: 94.65%



