# SVD Unlearn

## Clone class_forgetting repo and import packages

In [1]:
# Fetch the class_forgetting repo, then move utils.py, pretrained_models/ and
# models/ out of the clone into /content (utils is imported in the next cells).
!git clone https://github.com/GittyHarsha/class_forgetting.git

%mv /content/class_forgetting/utils.py /content
%mv /content/class_forgetting/pretrained_models /content/pretrained_models
%mv /content/class_forgetting/models /content/models

Cloning into 'class_forgetting'...
remote: Enumerating objects: 136, done.[K
remote: Counting objects: 100% (119/119), done.[K
remote: Compressing objects: 100% (92/92), done.[K
remote: Total 136 (delta 56), reused 69 (delta 25), pack-reused 17 (from 1)[K
Receiving objects: 100% (136/136), 69.84 MiB | 30.37 MiB/s, done.
Resolving deltas: 100% (58/58), done.


In [2]:
# scienceplots (extra matplotlib styles) is not preinstalled; it is imported in the next cell.
!pip install scienceplots

Collecting scienceplots
  Downloading SciencePlots-2.1.1-py3-none-any.whl.metadata (11 kB)
Downloading SciencePlots-2.1.1-py3-none-any.whl (16 kB)
Installing collected packages: scienceplots
Successfully installed scienceplots-2.1.1


In [3]:
from __future__ import print_function
import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import numpy as np
import os
from utils import get_dataset, get_model, test, SVC_MIA
import copy
import os
import random
from torch import nn
from collections import OrderedDict, Counter

from multiprocessing import Pool
from multiprocessing import Process, Value, Array
from functools import partial
import torch.multiprocessing as mp

from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, balanced_accuracy_score, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import scienceplots

## Utils

In [4]:
def get_projection_matrix(device, Mr, Mf):
    """Build the per-layer weight projection used by SVD unlearning.

    For every key ``name`` of ``Mr`` (retain projections), combine it with
    ``Mf[name]`` (forget projection) as ``I - (Mf - Mf @ Mr)``, where ``I`` is
    the identity with as many rows as ``Mf[name]``.

    Args:
        device: device on which the identity matrix is created.
        Mr: mapping from layer key to retain-space projection matrix.
        Mf: mapping from the same keys to forget-space projection matrices.

    Returns:
        OrderedDict with the keys of ``Mr`` (in ``Mr``'s order) mapped to the
        projection for that layer.
    """
    projections = OrderedDict()
    for name, retain_proj in Mr.items():
        forget_proj = Mf[name]
        identity = torch.eye(forget_proj.shape[0], device=device)
        # Part of the forget space that is also shared with the retain space.
        shared = torch.mm(forget_proj, retain_proj)
        projections[name] = identity - (forget_proj - shared)
    return projections

## Unlearn method

In [5]:
def _first_batch_projections(model, loader, device, alpha, max_patches, split_name):
    """Return ``model.get_scaled_projections`` computed on the first batch of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    for data, _ in loader:
        return model.get_scaled_projections(data.to(device), alpha, max_patches)
    raise ValueError(f"{split_name} loader is empty; cannot compute SVD projections")


def svd_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, train_dataset, val_index = None, custom=False, **kwargs):
    """Forget ``args.unlearn_class`` with a single gradient-free weight projection.

    One batch of retain samples and one batch of forget samples are passed to
    ``model.get_scaled_projections``; the two results are combined by
    ``get_projection_matrix`` and applied with ``model.project_weights``.

    Args:
        args: Namespace providing ``num_classes``, ``unlearn_class`` (list of
            class ids), ``our_samples``, ``our_alpha_r``, ``our_alpha_f`` and
            ``our_max_patches``.
        model: network exposing ``get_scaled_projections`` and ``project_weights``.
        device: device the batches are moved to.
        retain_loader: used as-is when ``custom`` is True; ignored otherwise.
        forget_loader: used as-is when ``custom`` is True; otherwise only its
            ``dataset`` is re-batched with ``batch_size=args.our_samples``.
        train_loader, test_loader: accepted for a uniform call signature; unused.
        train_dataset: dataset with a ``targets`` sequence; retain samples are
            drawn from it when ``custom`` is False.
        val_index: indices of ``train_dataset`` eligible as retain samples;
            ``None`` means every index.
        custom: if True, take the first batch of the given loaders directly
            (e.g. synthetic data) instead of building class-balanced loaders.
        **kwargs: extra arguments (optimizer, epochs, ...) that are ignored.

    Returns:
        The same ``model`` object, after ``project_weights`` has been called on it.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model.eval()  # Ensures batch statistics do not change.
    if custom:
        small_retain_loader = retain_loader
        small_forget_loader = forget_loader
    else:
        # Class-balanced retain subset: at most our_samples // (#retained classes)
        # samples from every class that is not being unlearned.
        forget_classes = set(args.unlearn_class)
        retained_classes = [c for c in range(args.num_classes) if c not in forget_classes]
        per_class = int(args.our_samples // max(len(retained_classes), 1))
        targets = np.array(train_dataset.targets)
        index_list = []
        for c in retained_classes:
            class_index = np.where(targets == c)[0]
            if val_index is not None:
                class_index = np.intersect1d(class_index, val_index)
            index_list.extend(class_index[:per_class])
        small_retain_loader = torch.utils.data.DataLoader(
            torch.utils.data.Subset(train_dataset, index_list), batch_size=args.our_samples, shuffle=True)
        small_forget_loader = torch.utils.data.DataLoader(
            forget_loader.dataset, batch_size=args.our_samples, shuffle=True)
    print(len(small_retain_loader.dataset))
    print(len(small_forget_loader.dataset))
    with torch.no_grad():
        # Only the first batch of each loader is used for the projections.
        Mr = _first_batch_projections(model, small_retain_loader, device,
                                      args.our_alpha_r, args.our_max_patches, "retain")
        Mf = _first_batch_projections(model, small_forget_loader, device,
                                      args.our_alpha_f, args.our_max_patches, "forget")

    model.project_weights(get_projection_matrix(device, Mr, Mf))
    return model


## Train function

In [6]:
def train(args, model, device, train_loader, optimizer, epoch, mode = "descent", clip=None):
    """Run one epoch of NLL-loss optimization over ``train_loader``.

    ``mode="descent"`` is ordinary training. With ``mode="ascent"`` every
    parameter gradient is negated before ``optimizer.step()``, which performs
    gradient ascent on the loss. In that mode only, gradients are also clipped
    to total norm ``clip`` when ``clip`` is given.

    Args:
        args: Namespace with ``log_interval`` (batches between progress lines)
            and ``dry_run`` (stop after the first logged batch).
        model: network whose output is passed to ``F.nll_loss`` (log-probabilities).
        device: device the batches are moved to.
        train_loader: yields ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress line.
        mode: ``"descent"`` or ``"ascent"``.
        clip: optional max gradient norm (ascent mode only).

    Returns:
        The SUM of per-batch loss values (not an average); ``0`` for an empty loader.
    """
    model.train()
    epoch_loss_sum = 0
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()

        if mode == "ascent":
            # Flip every gradient so optimizer.step() climbs the loss.
            for p in filter(lambda q: q.grad is not None, model.parameters()):
                p.grad.data.mul_(-1.0)
            if clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        epoch_loss_sum += batch_loss.item()

        if step % args.log_interval != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), batch_loss.item()), end='\r')
        if args.dry_run:
            break
    return epoch_loss_sum

## Parameters

In [7]:
parser = argparse.ArgumentParser(description='PyTorch cifar10 Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--dataset', type=str, default="gtsrb",
                    help='')
parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                    help='input batch size for testing (default: 512)')
parser.add_argument('--epochs-or-steps', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='')
parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='WD',
                    help='')
parser.add_argument('--gamma', type=float, default=0.5, metavar='G',
                    help='learning-rate reduction factor for ReduceLROnPlateau (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
                    help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1234, metavar='S',
                    help='random seed (default: 1234)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--no-train-transform', action='store_true', default=False,
                    help='disable the training data transform (sets args.train_transform=False)')
parser.add_argument('--save-model', action='store_true', default=False,
                    help='For Saving the current Model')
parser.add_argument('--arch', type=str, default='resnet34',
                    help='')
parser.add_argument('--data-path', type=str, default='../data/',
                    help='')
parser.add_argument('--val-set-mode', action='store_true', default=False,
                    help='pass the validation loader instead of the training loader to the unlearning method')
parser.add_argument('--val-set-samples', type=int, default=10000, metavar='N',
                    help='')
### Unlearn parameters

parser.add_argument('--num-retain-samples', type=int, default=45000,
                    help='')
parser.add_argument('--num-forget-samples', type=int, default=5000,
                    help='')
parser.add_argument('--grad-norm-clip', type=float, default=None,
                    help='')
parser.add_argument('--unlearn-class', type=str, default="9",
                    help='comma-separated class ids to unlearn, e.g. "3,7"')
parser.add_argument('--unlearn-method', type=str, default="svd_unlearn",
                    help='')

parser.add_argument('--salun-threshold', type=float, default=0.1,
                    help='')

parser.add_argument('--goel-exact', action="store_true", default=False,
                    help='')

parser.add_argument('--ssd-lambda', type=float, default=1,
                    help='')
parser.add_argument('--ssd-alpha', type=float, default=10,
                    help='')

parser.add_argument('--scrub-del-bsz', type=int, default=512,
                    help='')
parser.add_argument('--scrub-sgda-bsz', type=int, default=64,
                    help='')
parser.add_argument('--scrub-msteps', type=int, default=2,
                    help='')
parser.add_argument('--scrub-epochs', type=int, default=3,
                    help='')

parser.add_argument('--our-alpha-r', type=int, default=100,
                    help='')
parser.add_argument('--our-alpha-f', type=int, default=3,
                    help='')
parser.add_argument('--our-samples', type=int, default=900,
                    help='')
parser.add_argument('--our-max-patches', type=int, default=10000,
                    help='')


parser.add_argument('--tarun-impair-lr', type=float, default=2e-4,
                    help='')
parser.add_argument('--tarun-samples-per-class', type=int, default=1000,
                    help='')
### wandb parameters
parser.add_argument('--project-name', type=str, default='baseline',
                    help='')
parser.add_argument('--group-name', type=str, default='final',
                    help='')
parser.add_argument('--multiclass', action='store_true', default=False,
                    help='')
parser.add_argument('--class-names', type=str, default=None,
                    help='')
parser.add_argument('--do-mia',action='store_true', default=False,
                    help='')
parser.add_argument('--do-mia-ulira',action='store_true', default=False,
                    help='')
parser.add_argument('--plot-mia-roc',action='store_true', default=False,
                    help='')

# A notebook has no real command line: parse an empty argv so every option
# takes its declared default, without reaching into the private parser._actions.
args = parser.parse_args([])

args.train_transform = not args.no_train_transform
# "--unlearn-class" is a comma-separated string, e.g. "3,7" -> [3, 7].
if args.unlearn_class:
    args.unlearn_class=[int(val) for val in args.unlearn_class.split(",")]
else:
    args.unlearn_class = []

use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
torch.manual_seed(args.seed)

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    # NOTE: shuffle=True is applied to the test loader as well.
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)



## Custom Parameters and Overrides

In [8]:
#args.dataset = "cifar10"


## Get Dataset

In [9]:
# How to build a plain-int ``targets`` list for datasets that do not expose one.
# The GTSRB extractor indexes every sample (loading each image), so it is slow.
_target_extractors = {
    "svhn": lambda ds: [int(y) for y in ds.labels],
    "gtsrb": lambda ds: [y for _, y in ds],
}

dataset1, dataset2 = get_dataset(args)

extract_targets = _target_extractors.get(args.dataset)
if extract_targets is not None:
    dataset1.targets = extract_targets(dataset1)


print(type(dataset1.targets[0]))

Downloading https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip to ../data/gtsrb/GTSRB-Training_fixed.zip


100%|██████████| 187M/187M [00:07<00:00, 23.6MB/s]


Extracting ../data/gtsrb/GTSRB-Training_fixed.zip to ../data/gtsrb
Downloading https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip to ../data/gtsrb/GTSRB_Final_Test_Images.zip


100%|██████████| 89.0M/89.0M [00:04<00:00, 20.8MB/s]


Extracting ../data/gtsrb/GTSRB_Final_Test_Images.zip to ../data/gtsrb
Downloading https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip to ../data/gtsrb/GTSRB_Final_Test_GT.zip


100%|██████████| 99.6k/99.6k [00:00<00:00, 220kB/s]


Extracting ../data/gtsrb/GTSRB_Final_Test_GT.zip to ../data/gtsrb
<class 'int'>


In [10]:
# Sanity check: inspect the label of the first training sample.
first_sample = dataset1[0]
image, label = first_sample
print(f"Label for the first image: {label}")

Label for the first image: 0


## Get Model

In [11]:
# Build the network defined by utils.get_model (presumably selected by args.arch / args.dataset) on `device`.
model = get_model(args, device)


## Prepare Loaders

In [12]:
# NOTE(review): val_index covers every index of dataset1, so the "validation"
# subset is the full training set (val loss == training-data loss).
val_index = np.arange(len(dataset1))
val_dataset = torch.utils.data.Subset(dataset1, val_index)

train_loader, val_loader = [
    torch.utils.data.DataLoader(ds, **train_kwargs) for ds in (dataset1, val_dataset)
]
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

## Model Training optimizers

In [13]:
# Nesterov SGD; ReduceLROnPlateau multiplies the LR by args.gamma when the
# monitored loss (passed to scheduler.step) stops decreasing.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
    nesterov=True,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=args.gamma)

## Model Training

In [None]:
def _average_nll_loss(model, loader, device):
    """Return the mean of per-batch ``F.nll_loss`` values of ``model`` over ``loader``.

    Switches the model to eval mode and runs without gradients.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            total_loss += F.nll_loss(model(data), target).item()
    return total_loss / len(loader)


# Early stopping: stop after `patience` consecutive epochs without a new best validation loss.
best_val_loss = float('inf')
patience_counter = 0
patience = 10  # Stop after 10 epochs without improvement
model.train()
args.epochs_or_steps = 50
# NOTE(review): val_loader wraps every index of dataset1, so this "validation"
# loss is measured on training data; early stopping and the LR schedule track
# training fit rather than generalization.
for epoch in range(1, args.epochs_or_steps + 1):
    # `train` returns the SUM of per-batch losses, so it is not comparable to val_loss (a mean).
    train_loss = train(args, model, device, train_loader, optimizer, epoch, "descent")
    val_loss = _average_nll_loss(model, val_loader, device)

    print(f"Epoch {epoch}/{args.epochs_or_steps}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    # One checkpoint per epoch, named by its train/validation loss.
    torch.save(model.state_dict(), f"resnet_34_gtsrb_tr{train_loss:.2f}_vl{val_loss:.2f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0  # Reset the patience counter
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break  # Stop training early

    scheduler.step(val_loss)

# Final test evaluation, placed after the loop so it also runs when early
# stopping ends training before the last scheduled epoch.
test_loss = _average_nll_loss(model, test_loader, device)
print(f"Test Loss: {test_loss:.4f}")



Epoch 1/50, Train Loss: 1323.0939, Validation Loss: 2.0729
Epoch 2/50, Train Loss: 536.4701, Validation Loss: 0.7347
Epoch 3/50, Train Loss: 172.3136, Validation Loss: 0.6778
Epoch 4/50, Train Loss: 71.3849, Validation Loss: 0.2999
Epoch 5/50, Train Loss: 44.4299, Validation Loss: 0.2372
Epoch 6/50, Train Loss: 32.4430, Validation Loss: 0.6988
Epoch 7/50, Train Loss: 27.0330, Validation Loss: 0.1765
Epoch 8/50, Train Loss: 23.4428, Validation Loss: 0.1924
Epoch 9/50, Train Loss: 27.0736, Validation Loss: 0.1661
Epoch 10/50, Train Loss: 23.3245, Validation Loss: 0.0946
Epoch 11/50, Train Loss: 18.8125, Validation Loss: 0.0356
Epoch 12/50, Train Loss: 18.8692, Validation Loss: 0.2035
Epoch 13/50, Train Loss: 22.0943, Validation Loss: 0.3004
Epoch 14/50, Train Loss: 19.2596, Validation Loss: 0.0596
Epoch 15/50, Train Loss: 19.3486, Validation Loss: 0.0165
Epoch 16/50, Train Loss: 16.3669, Validation Loss: 0.0427
Epoch 17/50, Train Loss: 17.4579, Validation Loss: 0.1000
Epoch 18/50, Train 

## Model Testing

In [None]:
# Evaluate the current model on the test set (mean NLL loss and accuracy).
# To evaluate a saved checkpoint instead, load it first:
#     checkpoint_path = "/content/best_model_resnet34.pth"  # Replace with your path
#     state_dict = torch.load(checkpoint_path)
#     model = get_model(args, device)
#     model.load_state_dict(state_dict)
model.eval()
test_loss, correct, total = 0.0, 0, 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        log_probs = model(inputs)
        test_loss += F.nll_loss(log_probs, labels).item()
        predicted = log_probs.max(1)[1]  # index of the highest log-probability
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

test_loss /= len(test_loader)  # mean of per-batch mean losses
accuracy = 100.0 * correct / total

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")


In [None]:
# Evaluate the current model on the training loader (mean NLL loss and accuracy).
# To evaluate a saved checkpoint instead, load it first:
#     checkpoint_path = "/content/best_model_resnet34.pth"  # Replace with your path
#     state_dict = torch.load(checkpoint_path)
#     model = get_model(args, device)
#     model.load_state_dict(state_dict)
model.eval()
train_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # Forward pass
        output = model(data)
        loss = F.nll_loss(output, target)
        train_loss += loss.item()

        # Calculate accuracy
        _, predicted = output.max(1)  # Get index of max log-probability
        correct += predicted.eq(target).sum().item()  # Count correct predictions
        total += target.size(0)  # Total samples

# Average train loss (mean of per-batch mean losses)
train_loss /= len(train_loader)

# Calculate accuracy percentage
accuracy = 100.0 * correct / total

# Report the train-set loss computed above (not the test-set `test_loss`).
print(f"train Loss: {train_loss:.4f}")
print(f"train Accuracy: {accuracy:.2f}%")


In [None]:
# Persist the trained weights; the loading cell below expects /content/gtsrb_resnet34.pth.
trained_weights = model.state_dict()
torch.save(trained_weights, "gtsrb_resnet34.pth")

## Load Pretrained model

In [None]:
checkpoint_path = "/content/gtsrb_resnet34.pth"  # Replace with your path
# map_location lets a checkpoint saved on GPU load on a CPU/MPS runtime as well.
state_dict = torch.load(checkpoint_path, map_location=device)
args.num_classes = 43  # GTSRB has 43 classes
model = get_model(args, device)
# Load the state_dict into the model
model.load_state_dict(state_dict)
model.eval()


In [None]:
# Split dataset1 into retain / forget subsets for args.unlearn_class and wrap both in loaders.
# NOTE(review): get_retain_forget_partition and SubSet are defined in later cells; on a
# top-to-bottom run this cell raises NameError unless those cells are executed first.
retain_dataset, forget_dataset = get_retain_forget_partition(args, dataset1, args.unlearn_class)
retain_loader, forget_loader = [
    torch.utils.data.DataLoader(split, **train_kwargs) for split in (retain_dataset, forget_dataset)
]

In [None]:
# Report the sizes of the retain and forget splits.
# NOTE(review): get_retain_forget_partition is defined in a later cell; run it first.
retain_dataset, forget_dataset = get_retain_forget_partition(args, dataset1, args.unlearn_class)
split_sizes = (len(retain_dataset), len(forget_dataset))
print(*split_sizes)

## Unlearn Model

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def display_images_with_labels(data_loader, num_images=16, mean=(0.4377, 0.4438, 0.4728), std=(0.1980, 0.2010, 0.1970)):
    """
    Display the first ``num_images`` images of one batch in a square grid, titled by label.

    Args:
        data_loader (DataLoader): yields ``(images, labels)`` batches; images are
            assumed to be normalized ``(N, C, H, W)`` tensors with ``C == len(mean)``.
        num_images (int, optional): Number of images to display. Default is 16.
        mean (tuple, optional): Per-channel mean used in normalization.
        std (tuple, optional): Per-channel standard deviation used in normalization.

    NOTE(review): the default mean/std look like commonly used SVHN statistics;
    confirm they match the normalization applied by utils.get_dataset.
    """
    # Fetch a single batch of images and labels, keeping only what we display.
    images, labels = next(iter(data_loader))
    images = images[:num_images]
    labels = labels[:num_images]

    # Undo (x - mean) / std for every image and channel at once (out of place,
    # so the loader's batch is untouched), then clamp into the displayable range.
    mean_t = torch.as_tensor(mean, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
    images = torch.clamp(images * std_t + mean_t, 0, 1)

    grid_size = int(np.ceil(np.sqrt(num_images)))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so .flatten() always works.
    _, axes = plt.subplots(grid_size, grid_size, figsize=(grid_size * 3, grid_size * 3), squeeze=False)
    axes = axes.flatten()

    for ax, image, label in zip(axes, images, labels):
        ax.imshow(image.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for matplotlib
        ax.axis('off')
        ax.set_title(f"Label: {label.item()}", fontsize=10, pad=10)

    # Hide unused subplots
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage:
# display_images_with_labels(data_loader, num_images=16)


In [None]:
# Preview a few retain-set samples in a 2x2 grid. The plotting helper defined in
# the cell above is `display_images_with_labels`; `visualize_samples` is not
# defined anywhere in this notebook.
display_images_with_labels(retain_loader, num_images=4)


In [None]:
from torch.utils.data import Dataset
import torch
import numpy as np

class SubSet(Dataset):
    r"""
    View of ``dataset`` restricted to ``indices``, with its own overridable labels.

    Labels live in a tensor as long as the *whole* dataset: the positions listed
    in ``indices`` hold the subset's labels and every other position holds the
    placeholder ``10000``. ``update_labels`` relabels the subset without
    touching the underlying dataset.

    Arguments:
        dataset (Dataset): The whole Dataset; must expose a ``targets`` sequence.
        indices (sequence): Indices in the whole set selected for the subset.
    """
    _PLACEHOLDER = 10000  # label stored at positions outside the subset

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        initial = torch.LongTensor([self.dataset.targets[i] for i in indices])
        self.labels = self._scatter_labels(initial)

    def _scatter_labels(self, subset_labels):
        """Return a full-length label tensor with ``subset_labels`` written at ``self.indices``."""
        full = torch.ones(len(self.dataset)).type(torch.long) * self._PLACEHOLDER
        full[self.indices] = subset_labels
        return full

    def __getitem__(self, idx):
        source_idx = self.indices[idx]
        # The image comes from the wrapped dataset; the label from our own table.
        return (self.dataset[source_idx][0], self.labels[source_idx])

    def __len__(self):
        return len(self.indices)

    def update_labels(self, new_labels):
        """Replace the subset's labels; ``new_labels`` must align with ``self.indices``."""
        self.labels = self._scatter_labels(torch.LongTensor(new_labels))


In [None]:
def get_retain_forget_partition(args, dataset, unlearn_class_list, return_ind = False):
    """Split ``dataset`` into retain / forget ``SubSet``s by class label.

    Args:
        args: Unused; kept for call-site compatibility.
        dataset: dataset exposing a ``targets`` sequence whose items are int-like
            (Python ints, NumPy integers, or one-element tensors).
        unlearn_class_list: class ids whose samples go to the forget split.
        return_ind: if True, also return the two index lists.

    Returns:
        ``(retain_dataset, forget_dataset)`` or, when ``return_ind`` is True,
        ``(retain_dataset, forget_dataset, retain_ind, forget_ind)``.

    Raises:
        TypeError / ValueError: if a target cannot be converted to ``int``.
    """
    forget_classes = {int(c) for c in unlearn_class_list}  # O(1) membership tests
    retain_ind = []
    forget_ind = []
    for sample_index in range(len(dataset)):
        target = dataset.targets[sample_index]
        # Normalize every int-like label to a plain int. A label int() cannot
        # convert raises here, rather than leaving sample_class unset or stale
        # from the previous sample.
        sample_class = int(target.item()) if torch.is_tensor(target) else int(target)
        if sample_class in forget_classes:
            forget_ind.append(sample_index)
        else:
            retain_ind.append(sample_index)
    retain_dataset = SubSet(dataset, retain_ind)
    forget_dataset = SubSet(dataset, forget_ind)
    if return_ind:
        return retain_dataset, forget_dataset, retain_ind, forget_ind

    return retain_dataset, forget_dataset

In [None]:
import copy
max_iterations = 100  # not used by this loop

# Unlearn every class independently: each pass deep-copies the original `model`,
# partitions dataset1 into retain/forget sets for that class, applies one SVD
# projection and evaluates on the training loader.
# NOTE(review): args.class_label_names is not set in any visible cell; presumably
# utils sets it — confirm before running.
for unlearn_class in range(args.num_classes):
    args.unlearn_class = [unlearn_class]
    print(f"unlearning class {unlearn_class}")
    unlearn_model = copy.deepcopy(model)
    retain_dataset, forget_dataset = get_retain_forget_partition(args, dataset1, args.unlearn_class)
    retain_loader, forget_loader = [
        torch.utils.data.DataLoader(ds, **train_kwargs) for ds in (retain_dataset, forget_dataset)
    ]
    unlearn_train_loader = val_loader if args.val_set_mode else train_loader
    unlearn_model = svd_unlearn(
        args=args,
        model=unlearn_model,
        device=device,
        retain_loader=retain_loader,
        forget_loader=forget_loader,
        train_loader=unlearn_train_loader,
        test_loader=test_loader,
        optimizer=optimizer,
        epochs=args.epochs_or_steps,
        max_steps=args.epochs_or_steps,
        train_dataset=dataset1,
        val_index=val_index,
    )
    forget_ids = torch.tensor(args.unlearn_class).to(device)
    train_retain_acc, train_forget_acc, train_metric = test(
        unlearn_model, device, train_loader, forget_ids, args.class_label_names, args.num_classes,
        job_name=args.unlearn_method, set_name="Final Train Set")



## Test Unlearned model

In [None]:
# Evaluate the unlearned model produced by the loop above (its projection is for the
# last class stored in args.unlearn_class). `model` itself is never modified there,
# since the loop works on deep copies.
train_retain_acc, train_forget_acc, train_metric = test(unlearn_model, device, train_loader, torch.tensor(args.unlearn_class).to(device), args.class_label_names, args.num_classes,
    job_name = args.unlearn_method, set_name="Final Train Set")


## Load Original Model

In [None]:
checkpoint_path = "/content/pretrained_models/cifar100_vgg11_bn.pt"  # Replace with your path
# map_location lets a checkpoint saved on GPU load on any runtime.
state_dict = torch.load(checkpoint_path, map_location=device)
args.num_classes = 100
# NOTE(review): args.arch is 'resnet34' and args.dataset 'gtsrb' unless changed,
# while this checkpoint's name suggests a VGG11-BN trained on CIFAR-100. Set
# args.arch/args.dataset to whatever get_model expects for it, otherwise
# load_state_dict will likely fail on mismatched keys.
model = get_model(args, device)
# Load the state_dict into the model
model.load_state_dict(state_dict)

## Utilities

In [None]:

def get_images_from_loader(loader, num_images_per_class, num_classes):
    """
    Collect up to ``num_images_per_class`` images per class from ``loader``.

    Args:
      loader: iterable of ``(images, labels)`` tensor batches; labels are used as
        indices into a list of length ``num_classes``.
      num_images_per_class: maximum number of images kept for each class.
      num_classes: total number of classes in the dataset.

    Returns:
      Dict mapping class label -> list of images as NumPy arrays. Reading stops
      as soon as every class has its quota; classes never seen are absent, and
      if the loader runs out first, whatever was collected is returned.
    """
    collected = {}
    counts = [0] * num_classes

    for batch_images, batch_labels in loader:
        for pos, label in enumerate(batch_labels):
            cls = label.item()
            if counts[cls] < num_images_per_class:
                collected.setdefault(cls, []).append(batch_images[pos].numpy())
                counts[cls] += 1
            if all(n >= num_images_per_class for n in counts):
                return collected

    return collected

# Example usage (disabled):
# images_by_class = get_images_from_loader(train_loader, 10, args.num_classes)
# for class_idx, images in images_by_class.items():
#     print(f"Class {class_idx}: {len(images)} images")
#     if images:
#         print(f"  Example image shape: {images[0].shape}")

# Zero-Shot, Gradient-Free Unlearning

In [None]:
# Rebuild the data loaders (same configuration as the "Prepare Loaders" cell).
# NOTE(review): val_index spans all of dataset1, so val_dataset is the full training set.
val_index = np.arange(len(dataset1))
val_dataset = torch.utils.data.Subset(dataset1, val_index)

train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=dataset2, **test_kwargs)

In [None]:
#retain_dataset, forget_dataset = get_retain_forget_partition(args, dataset1, args.unlearn_class)
#retain_loader = torch.utils.data.DataLoader(retain_dataset,**train_kwargs)
#forget_loader = torch.utils.data.DataLoader(forget_dataset,**train_kwargs)

In [None]:
# Shape of one training image (presumably (C, H, W)); stored on args for the synthesis helpers.
args.img_shape = train_loader.dataset[0][0].shape
img_shape = args.img_shape
# Indicator vector over classes for the class(es) being unlearned. args.unlearn_class
# is a list of ids, so test membership — comparing an int against the list with ==
# is always False and would yield an all-zero vector.
unlearn_ids = args.unlearn_class if isinstance(args.unlearn_class, (list, tuple)) else [args.unlearn_class]
target_vector = [1 if i in unlearn_ids else 0 for i in range(args.num_classes)]
target_class = args.unlearn_class
args.target_samples = 900

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

def synthesize_images_for_class(model, device, target_class, num_samples, img_shape,
                                lr=0.1, steps=800, regularization_weight=1e-4,
                                debug=False):
    """
    Synthesize images that the model classifies as `target_class` starting from noise.

    Gaussian noise is min-max rescaled into [0, 1], then optimized with Adam to
    maximize ``model(images)[:, target_class]`` minus a small L2 penalty
    (``regularization_weight``). The images are clamped back into [0, 1] after
    every step.

    Stopping criteria:
      An image is marked done once the model's argmax prediction for it (from a
      step's forward pass, i.e. before that step's update) equals ``target_class``.
      The mark is never cleared. The loop stops at the start of the first step
      where every image is done, or after ``steps`` steps.

    Gradients are taken with respect to the images only, so the model's
    parameters do not accumulate ``.grad`` across the many calls made per run.

    Args:
        model: classifier whose outputs are indexed as ``(batch, class)``.
        device: device on which the images are created and optimized.
        target_class (int): class whose score is maximized.
        num_samples (int): number of images to synthesize.
        img_shape (tuple): shape of a single image, e.g. ``(C, H, W)``.
        lr (float): Adam learning rate.
        steps (int): maximum number of optimization steps.
        regularization_weight (float): weight of the mean squared-pixel penalty.
        debug (bool): print per-step progress.

    Returns:
        Detached tensor of shape ``(num_samples, *img_shape)`` with values in [0, 1].
    """
    model.eval()

    # Initialize random noise, min-max rescaled into [0,1]
    images = torch.randn(num_samples, *img_shape, device=device)
    images = (images - images.min()) / (images.max() - images.min() + 1e-5)
    images.requires_grad_(True)

    optimizer = torch.optim.Adam([images], lr=lr)

    done_mask = torch.zeros(num_samples, dtype=torch.bool, device=device)

    for step in range(steps):
        if done_mask.all():
            if debug:
                print(f"All samples reached target class by step {step}. Stopping early.")
            break

        optimizer.zero_grad()
        outputs = model(images)
        class_scores = outputs[:, target_class]
        loss = -class_scores.mean() + regularization_weight * (images**2).mean()
        # Backpropagate into the images only: model parameters get no .grad.
        loss.backward(inputs=[images])
        optimizer.step()

        # Clamp images to [0,1]
        with torch.no_grad():
            images.clamp_(0, 1)

        # Check predictions (from this step's forward pass)
        preds = torch.argmax(outputs.detach(), dim=1)
        previously_done_count = done_mask.sum().item()
        done_mask = done_mask | (preds == target_class)
        currently_done_count = done_mask.sum().item()

        if debug:
            newly_done = currently_done_count - previously_done_count
            print(f"Step {step+1}/{steps}: {currently_done_count}/{num_samples} images done (+{newly_done} this step).")

    if debug and not done_mask.all():
        print(f"Reached max steps without all images classified as target class. "
              f"{done_mask.sum().item()} out of {num_samples} done.")

    return images.detach()


def get_class_samples_from_noise(model, device, args, retain_extra, forget_extra):
    """
    Build synthetic forget / retain loaders by optimizing noise with ``model``.

    - forget_loader: ``args.target_samples + forget_extra`` images synthesized for
      ``args.unlearn_class`` (its first element if it is a list/tuple/tensor).
    - retain_loader: for every other class,
      ``args.target_samples // (args.num_classes - 1) + retain_extra`` images.

    Args:
        model: classifier used by ``synthesize_images_for_class``.
        device: device the synthesis runs on.
        args: Namespace with ``unlearn_class``, ``img_shape``, ``target_samples``,
            ``num_classes`` and ``batch_size``.
        retain_extra (int): extra images per retained class.
        forget_extra (int): extra images for the forget class.

    Returns:
        ``(forget_loader, retain_loader)``: shuffled ``DataLoader``s with batch size
        ``args.batch_size`` over CPU ``TensorDataset``s of ``(image, label)``.

    Raises:
        ValueError: if ``args.num_classes < 2`` (no class left to retain).

    NOTE(review): svd_unlearn(custom=True) consumes only the FIRST batch of each
    loader, so only ``args.batch_size`` of these images feed each projection.
    Confirm that is intended before increasing retain_extra / forget_extra.
    """

    # Ensure unlearn_class is an integer
    unlearn_class = args.unlearn_class
    if isinstance(unlearn_class, (list, tuple)):
        # If it's a list or tuple, extract the first element (assuming single value)
        unlearn_class = unlearn_class[0]
    if isinstance(unlearn_class, torch.Tensor):
        # If it's a tensor with one element, convert to int
        unlearn_class = unlearn_class.item()
    unlearn_class = int(unlearn_class)

    img_shape = args.img_shape
    target_samples = args.target_samples
    num_classes = args.num_classes
    if num_classes < 2:
        raise ValueError("get_class_samples_from_noise needs at least two classes (one to forget, one to retain)")

    # Synthesize images for the unlearn_class (forget_extra grows the forget set)
    forget_images = synthesize_images_for_class(
        model=model,
        device=device,
        target_class=unlearn_class,
        num_samples=target_samples + forget_extra,
        img_shape=img_shape,
        lr=0.1,
        steps=300,
        regularization_weight=1e-4,
        debug=False
    )

    # For other classes, split target_samples evenly, plus retain_extra each
    samples_per_other_class = target_samples // (num_classes - 1) + retain_extra
    retain_images_list = []
    retain_labels_list = []

    for c in range(num_classes):
        if c == unlearn_class:
            continue
        class_images = synthesize_images_for_class(
            model=model,
            device=device,
            target_class=c,
            num_samples=samples_per_other_class,
            img_shape=img_shape,
            lr=0.1,
            steps=300,
            regularization_weight=1e-4,
            debug=False
        )
        retain_images_list.append(class_images)
        retain_labels_list.extend([c] * samples_per_other_class)

    retain_images = torch.cat(retain_images_list, dim=0)

    # Labels and images go to CPU, the usual home for Dataset tensors.
    forget_labels = torch.full((forget_images.size(0),), unlearn_class, dtype=torch.long)
    retain_labels = torch.tensor(retain_labels_list, dtype=torch.long)

    forget_dataset = TensorDataset(forget_images.cpu(), forget_labels.cpu())
    retain_dataset = TensorDataset(retain_images.cpu(), retain_labels.cpu())

    forget_loader = DataLoader(forget_dataset, batch_size=args.batch_size, shuffle=True)
    retain_loader = DataLoader(retain_dataset, batch_size=args.batch_size, shuffle=True)

    return forget_loader, retain_loader


In [None]:
import copy
max_iterations = 100
# 50 synthetic samples per retained class (42 * 50 for GTSRB's 43 classes).
args.target_samples = (args.num_classes - 1) * 50

# For each class: repeatedly synthesize forget/retain data from noise with the
# ORIGINAL `model`, and apply another SVD projection to the running
# `unlearn_model`. Stop when the train-set forget accuracy drops below the
# threshold or after max_iterations rounds.
for unlearn_class in range(args.num_classes):
    args.unlearn_class = [unlearn_class]
    print(f"unlearning class {unlearn_class}")
    unlearn_model = copy.deepcopy(model)
    for iteration in range(max_iterations):
        print(f"iteration {iteration}")
        forget_loader, retain_loader = get_class_samples_from_noise(model, device, args, retain_extra = 200*iteration, forget_extra = 5*iteration)
        unlearn_model = svd_unlearn(
                args = args,
                model = unlearn_model,
                device = device,
                retain_loader= retain_loader,
                forget_loader = forget_loader,
                train_loader = val_loader if args.val_set_mode else train_loader,
                test_loader= test_loader,
                optimizer = optimizer,
                epochs = args.epochs_or_steps,
                max_steps = args.epochs_or_steps,
                train_dataset = dataset1,
                val_index =val_index,
                custom=True
        )
        train_retain_acc, train_forget_acc, train_metric = test(unlearn_model, device, train_loader, torch.tensor(args.unlearn_class).to(device), args.class_label_names, args.num_classes,
        job_name = args.unlearn_method, set_name="Final Train Set")
        print("train forget acc: ", train_forget_acc)
        # NOTE(review): confirm the scale of train_forget_acc returned by utils.test
        # (fraction vs. percent): 0.1 means 10% for a fraction but 0.1% for a percentage.
        if train_forget_acc < 0.1:
            break



In [None]:
# Return cached, currently unused GPU memory held by PyTorch's CUDA allocator (e.g. after the synthesis loop).
torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt

# Take the first batch of each loader (presumably the synthetic ones from
# get_class_samples_from_noise) for a visual spot-check. Forget is drawn first.
forget_batch = next(iter(forget_loader))
retain_batch = next(iter(retain_loader))
(forget_images, forget_labels), (retain_images, retain_labels) = forget_batch, retain_batch

def show_images(images, labels, title, num_images=8):
    """Plot up to ``num_images`` images of a batch side by side, each titled by class.

    Args:
        images: batch of image tensors, assumed ``(N, C, H, W)`` with values already
            in [0, 1] (true for synthesize_images_for_class output, which is clamped).
        labels: matching batch of integer class labels.
        title: figure title.
        num_images: maximum number of images to show. It is capped at the batch
            size, so a batch smaller than ``num_images`` does not raise IndexError.
    """
    num_images = min(num_images, len(images), len(labels))
    # Convert to CPU and detach if needed
    imgs = images[:num_images].cpu().detach()
    labs = labels[:num_images].cpu().detach()

    plt.figure(figsize=(12, 2))
    for i, (img, lab) in enumerate(zip(imgs, labs), start=1):
        plt.subplot(1, num_images, i)
        plt.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
        plt.title(f"Class: {lab.item()}")
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Show a few forget images
show_images(forget_images, forget_labels, title="Forget Images")

# Show a few retain images
show_images(retain_images, retain_labels, title="Retain Images")


In [None]:
# Apply one more SVD projection to unlearn_model using the current (synthetic) loaders.
unlearn_call_kwargs = dict(
    args=args,
    model=unlearn_model,
    device=device,
    retain_loader=retain_loader,
    forget_loader=forget_loader,
    train_loader=val_loader if args.val_set_mode else train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    epochs=args.epochs_or_steps,
    max_steps=args.epochs_or_steps,
    train_dataset=dataset1,
    val_index=val_index,
    custom=True,
)
unlearn_model = svd_unlearn(**unlearn_call_kwargs)

In [None]:
# len() of a DataLoader counts batches, not samples.
print(len(forget_loader))