In [None]:
!pip install torch torchvision torch-cka ckatorch torchmetrics tqdm nest_asyncio -q

In [None]:
import time

import nest_asyncio
import numpy as np
import torch
import torch.nn as nn

from cka import compute_cka  # 1
from torch_cka import CKA as TorchCKA  # 2

from ckapytorch import CKACalculator  # 3 (Inefficient memory management)
from ckatorch import CKA as CKATorch  # 4 (Not supported)

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

from tqdm.autonotebook import tqdm


In [3]:
# Apply the nest_asyncio patch before any coroutine in this notebook is awaited.
nest_asyncio.apply()

# Benchmark configuration: the class id treated as "forgotten" and the
# checkpoint that is loaded as the unlearned model.
forget_class = 0
unlearned_model_path = "./c9d4.pth"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def get_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a modified stem and head.

    The stem convolution is replaced by a 3x3, stride-1, padding-1 convolution
    without bias, the max-pool layer is replaced by an identity, and the final
    fully connected layer is rebuilt to emit ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the final linear layer.

    Returns:
        The modified ``torchvision`` ResNet-18 module (no pretrained weights).
    """
    net = models.resnet18(weights=None)

    # Stem: 3 input channels -> 64 feature maps, spatial size preserved.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    # Head: keep the incoming feature width, change only the output size.
    head_in_features = net.fc.in_features
    net.fc = nn.Linear(head_in_features, num_classes)
    return net


def filter_loader(loader, is_train=False, forget_cls=None):
    """Split a loader's dataset into a forget-class loader and an other-classes loader.

    Samples are grouped by whether their target equals the forget class. From
    each group a deterministic prefix (lowest dataset indices first) is kept:
    one tenth of the group when ``is_train`` is true, one half otherwise.

    Args:
        loader: ``DataLoader`` whose ``dataset`` exposes a ``targets`` attribute
            (tensor, or anything ``torch.tensor`` accepts). Its ``batch_size``
            is reused for both returned loaders.
        is_train: Selects the kept fraction (``True`` -> 1/10, ``False`` -> 1/2).
        forget_cls: Class id to isolate. ``None`` (the default) falls back to
            the module-level ``forget_class``, which is what callers that do
            not pass this argument have always used.

    Returns:
        Tuple ``(forget_loader, other_loader)`` of non-shuffling ``DataLoader``
        objects over ``Subset`` views of ``loader.dataset``.
    """
    if forget_cls is None:
        forget_cls = forget_class

    targets = loader.dataset.targets
    if not isinstance(targets, torch.Tensor):
        targets = torch.tensor(targets)

    forget_indices = (targets == forget_cls).nonzero(as_tuple=True)[0]
    other_indices = (targets != forget_cls).nonzero(as_tuple=True)[0]

    keep_divisor = 10 if is_train else 2
    forget_samples = len(forget_indices) // keep_divisor
    other_samples = len(other_indices) // keep_divisor

    # Nothing below draws random numbers; the seeding is kept only so that the
    # global torch / numpy RNG state after this call is the same as it always was.
    seed = 42 + forget_cls
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Pinning host memory is only requested when a CUDA device is available.
    pin_memory = torch.cuda.is_available()

    def _subset_loader(indices, count):
        # Plain Python ints, so the wrapped dataset is never indexed with a
        # 0-dim tensor.
        kept = torch.sort(indices)[0][:count].tolist()
        return DataLoader(
            Subset(loader.dataset, kept),
            batch_size=loader.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=pin_memory,
        )

    forget_loader = _subset_loader(forget_indices, forget_samples)
    other_loader = _subset_loader(other_indices, other_samples)

    return forget_loader, other_loader


# Module names compared between the two models: the stem convolution, both
# convolutions of each of the two blocks in layer1..layer4, and the final
# linear layer, listed in forward order.
RESNET18_LAYERS = (
    ["conv1"]
    + [
        f"layer{stage}.{block}.conv{conv}"
        for stage in (1, 2, 3, 4)
        for block in (0, 1)
        for conv in (1, 2)
    ]
    + ["fc"]
)

def _load_eval_model(checkpoint_path):
    """Build a ResNet-18 on ``device``, load ``checkpoint_path`` into it, return it in eval mode."""
    net = get_resnet18().to(device)
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    return net.eval()


# The two networks whose layer representations are compared.
model_before = _load_eval_model("./0000.pth")
model_after = _load_eval_model(unlearned_model_path)

# Tensor conversion followed by per-channel normalisation (mean, std).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)
base_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]
)

# CIFAR-10 train split first, then the test split; downloaded to ./data if missing.
clean_train_set, clean_test_set = (
    datasets.CIFAR10(root='./data', train=use_train_split, download=True, transform=base_transforms)
    for use_train_split in (True, False)
)

# Fixed-order loaders; filter_loader reuses their batch size.
_loader_options = dict(batch_size=1000, shuffle=False, num_workers=0)
train_loader = DataLoader(clean_train_set, **_loader_options)
test_loader = DataLoader(clean_test_set, **_loader_options)

forget_class_train_loader, other_classes_train_loader = filter_loader(train_loader, is_train=True)
forget_class_test_loader, other_classes_test_loader = filter_loader(test_loader, is_train=False)

In [5]:
# 1
async def test_pytorch_cka():
    """Benchmark body I: one ``compute_cka`` call over all four data loaders.

    Prints the unlearned checkpoint path, then passes both models and the
    forget/other train/test loaders to ``compute_cka``. The result of the call
    is not returned.
    """
    print(f"Unlearned model: {unlearned_model_path}")

    # Positional order expected by the call: forget-train, other-train,
    # forget-test, other-test.
    loaders = (
        forget_class_train_loader,
        other_classes_train_loader,
        forget_class_test_loader,
        other_classes_test_loader,
    )
    compute_cka(model_before, model_after, *loaders, layers=RESNET18_LAYERS, device=device)

In [6]:
# 2
async def test_torchcka():
    """Benchmark body II: ``torch_cka.CKA.compare`` once per data loader.

    A single comparer is built for both models over ``RESNET18_LAYERS`` and is
    then run, with gradients disabled, on each of the four loaders (the same
    loader is passed for both models). Results are not returned.
    """
    print(f"Unlearned model: {unlearned_model_path}")

    comparer = TorchCKA(
        model_before,
        model_after,
        model1_layers=RESNET18_LAYERS,
        model2_layers=RESNET18_LAYERS,
        device=device,
    )

    benchmark_loaders = (
        forget_class_train_loader,
        other_classes_train_loader,
        forget_class_test_loader,
        other_classes_test_loader,
    )
    with torch.no_grad():
        for loader in benchmark_loaders:
            comparer.compare(loader, loader)

In [7]:
# 3 - memory-inefficient

# Module types that CKACalculator is asked to hook.
layers = (nn.Conv2d, nn.Linear)

async def test_ckapytorch():
    """Benchmark body III: a fresh ``CKACalculator`` per data loader.

    For each of the four loaders a new calculator is built for the two models
    and ``calculate_cka_matrix`` is called on it. Results are not returned.
    """
    print(f"Unlearned model: {unlearned_model_path}")

    benchmark_loaders = (
        forget_class_train_loader,
        other_classes_train_loader,
        forget_class_test_loader,
        other_classes_test_loader,
    )
    for loader in benchmark_loaders:
        cka = CKACalculator(
            model1=model_before,
            model2=model_after,
            dataloader=loader,
            hook_layer_types=layers,
        )
        cka.calculate_cka_matrix()

In [8]:
# 4 - fails with "list index out of range" in the run recorded in this notebook
async def test_ckatorch():
    """Benchmark body IV: call one ``ckatorch.CKA`` object on each data loader.

    Any exception raised by the library propagates to the caller. Results are
    not returned.
    """
    print(f"Unlearned model: {unlearned_model_path}")

    cka = CKATorch(
        first_model=model_before,
        second_model=model_after,
        layers=RESNET18_LAYERS,
        device=device,
    )

    benchmark_loaders = (
        forget_class_train_loader,
        other_classes_train_loader,
        forget_class_test_loader,
        other_classes_test_loader,
    )
    for loader in benchmark_loaders:
        cka(loader)

In [17]:
print("========================================")
print("I. CKA Computation with cka")
print("========================================\n")
start_time = time.time()
await test_pytorch_cka()
elapsed1 = time.time() - start_time
print(f"\n\ncka elapsed time: {elapsed1:.2f} seconds")

I. CKA Computation with cka

Unlearned model: ./c9d4.pth


Computing CKA: 100%|██████████| 1/1 [00:00<00:00, 17.11it/s]
Computing CKA: 100%|██████████| 5/5 [00:00<00:00,  9.97it/s]
Computing CKA: 100%|██████████| 1/1 [00:00<00:00, 17.05it/s]
Computing CKA: 100%|██████████| 5/5 [00:00<00:00,  9.90it/s]



cka elapsed time: 1.13 seconds





In [21]:
print("====================================")
print("II. CKA Computation with torch_cka")
print("====================================\n")
start_time = time.time()
await test_torchcka()
elapsed2 = time.time() - start_time
print(f"\ntorch_cka elapsed time: {elapsed2:.2f} seconds\n\n")

improvement1 = ((elapsed2 - elapsed1) / elapsed1) * 100
print(f"Performance Improvement (I -> II): {improvement1:.2f}%")

II. CKA Computation with torch_cka

Unlearned model: ./c9d4.pth


| Comparing features |: 100%|██████████| 1/1 [00:00<00:00,  2.71it/s]
| Comparing features |: 100%|██████████| 5/5 [00:03<00:00,  1.37it/s]
| Comparing features |: 100%|██████████| 1/1 [00:00<00:00,  2.69it/s]
| Comparing features |: 100%|██████████| 5/5 [00:03<00:00,  1.38it/s]


torch_cka elapsed time: 8.02 seconds


Performance Improvement (I -> II): 607.90%





In [23]:
print("====================================")
print("III. CKA Computation with ckapytorch")
print("====================================\n")
start_time = time.time()
await test_ckapytorch()
elapsed3 = time.time() - start_time
print(f"\nckapytorch elapsed time: {elapsed3:.2f} seconds\n\n")

improvement3 = ((elapsed3 - elapsed1) / elapsed1) * 100
print(f"Performance Improvement (IV -> I): {improvement3:.2f}%")

III. CKA Computation with ckapytorch

Unlearned model: ./c9d4.pth
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]


ckapytorch elapsed time: 19.06 seconds


Performance Improvement (IV -> I): 1581.91%


In [24]:
print("========================================")
print("IV. CKA Computation with ckatorch")
print("========================================\n")
start_time = time.time()
try:
  await test_ckatorch()
  elapsed4 = time.time() - start_time
  print(f"\nckatorch elapsed time: {elapsed4:.2f} seconds\n\n\n")
  improvement2 = ((elapsed4 - elapsed1) / elapsed1) * 100
  print(f"Performance Improvement (I -> IV): {improvement2:.2f}%\n")
except (TypeError, ValueError, IndexError) as e:
  print(f"\n\nError: {e}")

  cka = CKATorch(
  cka(forget_class_train_loader)


IV. CKA Computation with ckatorch

Unlearned model: ./c9d4.pth


| Computing CKA |:   0%|          | 0/10 [00:00<?, ?it/s]
| Computing CKA epoch 0 |:   0%|          | 0/1 [00:00<?, ?it/s][A
| Computing CKA |:   0%|          | 0/10 [00:00<?, ?it/s]



Error: list index out of range



