<a href="https://colab.research.google.com/github/shreyasudaya/Boundary-Unlearning/blob/master/unlearn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# # Step 1: Install and import necessary libraries
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torchvision
# import torchvision.transforms as transforms
# from torchvision import datasets, models
# import time

# # Step 2: Set device configuration (GPU if available)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Step 3: Download and preprocess the CIFAR-100 dataset
# transform_train = transforms.Compose([
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomCrop(32, padding=4),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))
# ])

# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761))
# ])

# # Download the CIFAR-100 dataset
# train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
# test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# # Step 4: Create DataLoader for batching and shuffling
# batch_size = 128
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# # Step 5: Load a pre-trained ResNet model and modify it for CIFAR-100
# model = models.resnet18(pretrained=True)  # You can also use resnet34 or resnet50
# num_ftrs = model.fc.in_features
# model.fc = nn.Linear(num_ftrs, 100)  # Modify the final layer to match CIFAR-100 classes
# model = model.to(device)

# # Step 6: Define loss function and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Step 7: Train the model
# def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
#     model.train()
#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         correct = 0
#         total = 0
#         start_time = time.time()

#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()

#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()
#             _, predicted = outputs.max(1)
#             total += labels.size(0)
#             correct += predicted.eq(labels).sum().item()

#         end_time = time.time()
#         epoch_loss = running_loss / len(train_loader)
#         accuracy = 100. * correct / total
#         print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%, Time: {end_time - start_time:.2f}s")

# # Step 8: Train the model for a specific number of epochs
# train_model(model, train_loader, criterion, optimizer, num_epochs=10)

# # Step 9: Evaluate the model on the test set
# def evaluate_model(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = outputs.max(1)
#             total += labels.size(0)
#             correct += predicted.eq(labels).sum().item()

#     accuracy = 100. * correct / total
#     print(f'Test Accuracy: {accuracy:.2f}%')

# evaluate_model(model, test_loader)


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import copy
import numpy as np
from torchvision.models import resnet18

In [10]:

def load_data(batch_size):
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # Split trainset into retain and forget subsets
    train_retain_size = int(len(trainset) * 0.9)
    train_forget_size = len(trainset) - train_retain_size

    retain_dataset, forget_dataset = random_split(trainset, [train_retain_size, train_forget_size])

    retain_loader = DataLoader(retain_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    forget_loader = DataLoader(forget_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return retain_loader, forget_loader, test_loader



In [11]:
# ResNet model setup
def get_model():
    model = resnet18(num_classes=100)
    return model

def likelihood(score, mean, var):
    nll = -(((score - mean)**2) / (2 * (var ** 2))) - 0.5 * torch.log(var ** 2) - 0.5 * torch.log(4 * torch.acos(torch.zeros(1)))
    return torch.exp(nll)



In [12]:
def get_salun_mask(model, device, forget_loader, threshold=0.1):
    mask = {}
    for name, param in model.named_parameters():
        mask[name] = 0

    model.train()
    for data, target in forget_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    mask[name] += param.grad.data

    with torch.no_grad():
        for name in mask:
            mask[name] = torch.abs_(mask[name])

    sorted_dict_positions = {}
    hard_dict = {}

    # Concatenate all tensors into a single tensor
    all_elements = -torch.cat([tensor.flatten() for tensor in mask.values()])
    threshold_index = int(len(all_elements) * threshold)

    positions = torch.argsort(all_elements)
    ranks = torch.argsort(positions)

    start_index = 0
    for key, tensor in mask.items():
        num_elements = tensor.numel()
        tensor_ranks = ranks[start_index : start_index + num_elements]
        sorted_positions = tensor_ranks.reshape(tensor.shape)
        sorted_dict_positions[key] = sorted_positions

        threshold_tensor = torch.zeros_like(tensor_ranks)
        threshold_tensor[tensor_ranks < threshold_index] = 1
        threshold_tensor = threshold_tensor.reshape(tensor.shape)
        hard_dict[key] = threshold_tensor
        start_index += num_elements

    return hard_dict

In [13]:


# Train function with SALUN unlearning
def salun_train(args, model, device, retain_loader, forget_loader, test_loader, optimizer, epochs, batch_size):
    mask = get_salun_mask(model, device, forget_loader, threshold=args.salun_threshold)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for data, target in retain_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = F.cross_entropy(outputs, target)
            loss.backward()

            # Apply the mask to the gradients
            if mask:
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        param.grad *= mask[name]

            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(retain_loader)}")

    # Test the model after unlearning
    test(model, device, test_loader)

# Testing function
def test(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Test Accuracy: {100 * correct / total} %')

# Main execution
if __name__ == "__main__":
    class Args:
        salun_threshold = 0.1
        batch_size = 128
        lr = 0.1
        epochs = 10
        momentum = 0.9
        weight_decay = 1e-4

    args = Args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    retain_loader, forget_loader, test_loader = load_data(batch_size=args.batch_size)
    model = get_model().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Train the model with SALUN unlearning
    salun_train(args, model, device, retain_loader, forget_loader, test_loader, optimizer, epochs=args.epochs, batch_size=args.batch_size)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 72968622.69it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Loss: 2.3724263330752198
Epoch [2/10], Loss: 1.6246497085825964
Epoch [3/10], Loss: 1.4239645715464244
Epoch [4/10], Loss: 1.2978979223831133
Epoch [5/10], Loss: 1.175824600695209
Epoch [6/10], Loss: 1.0900118858976797
Epoch [7/10], Loss: 1.021635575727983
Epoch [8/10], Loss: 0.9539693648164923
Epoch [9/10], Loss: 0.9179781788790767
Epoch [10/10], Loss: 0.8696315175091679
Test Accuracy: 66.63 %
