In [1]:
!pip install torch torchvision torch-cka torchmetrics tqdm nest_asyncio -q

In [2]:
import nest_asyncio
import numpy as np
import time
import torch
import torch.nn as nn

from cka import compute_cka  # 1
from torch_cka import CKA as TorchCKA  # 2
from ckapytorch import CKACalculator  # 3 (Ineifficient memory management)
from vanila_cka import compute_cka as compute_cka_vanila  # 4 (vanilla numpy)

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models


  from tqdm.autonotebook import tqdm


In [3]:
nest_asyncio.apply()

forget_class = 0
unlearned_model_path = "./c9d4.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def get_resnet18(num_classes=10):
    model = models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def filter_loader(loader, is_train=False):
    targets = loader.dataset.targets
    targets = torch.tensor(targets) if not isinstance(targets, torch.Tensor) else targets

    forget_indices = (targets == forget_class).nonzero(as_tuple=True)[0]
    other_indices = (targets != forget_class).nonzero(as_tuple=True)[0]

    if is_train:
        forget_samples = len(forget_indices) // 10
        other_samples = len(other_indices) // 10
    else:
        forget_samples = len(forget_indices)  // 2
        other_samples = len(other_indices)  // 2

    seed = 42 + forget_class
    torch.manual_seed(seed)
    np.random.seed(seed)

    forget_indices_sorted = torch.sort(forget_indices)[0]
    other_indices_sorted = torch.sort(other_indices)[0]

    forget_sampled = forget_indices_sorted[:forget_samples]
    other_sampled = other_indices_sorted[:other_samples]

    forget_loader = DataLoader(
        Subset(loader.dataset, forget_sampled),
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    other_loader = DataLoader(
        Subset(loader.dataset, other_sampled),
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    return forget_loader, other_loader


RESNET18_LAYERS = [
    "conv1",
    "layer1.0.conv1",
    "layer1.0.conv2",
    "layer1.1.conv1",
    "layer1.1.conv2",
    "layer2.0.conv1",
    "layer2.0.conv2",
    "layer2.1.conv1",
    "layer2.1.conv2",
    "layer3.0.conv1",
    "layer3.0.conv2",
    "layer3.1.conv1",
    "layer3.1.conv2",
    "layer4.0.conv1",
    "layer4.0.conv2",
    "layer4.1.conv1",
    "layer4.1.conv2",
    "fc",
]

model_before = get_resnet18().to(device)
model_before.load_state_dict(torch.load("./0000.pth", map_location=device))
model_before.eval()

model_after = get_resnet18().to(device)
model_after.load_state_dict(torch.load(unlearned_model_path, map_location=device))
model_after.eval()

base_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

clean_train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=base_transforms)
clean_test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=base_transforms)

train_loader = DataLoader(clean_train_set, batch_size=1000, shuffle=False, num_workers=0)
test_loader = DataLoader(clean_test_set, batch_size=1000, shuffle=False, num_workers=0)

forget_class_train_loader, other_classes_train_loader = filter_loader(train_loader, is_train=True)
forget_class_test_loader, other_classes_test_loader = filter_loader(test_loader, is_train=False)

dataloaders = [
    forget_class_train_loader,
    other_classes_train_loader,
    forget_class_test_loader,
    other_classes_test_loader
]

In [5]:
# 1 - vanilla numpy (CPU-only)
async def test_vanila_cka():
    print(f"Unlearned model: {unlearned_model_path}")

    compute_cka_vanila(
        model_before,
        model_after,
        dataloaders,
        layers=RESNET18_LAYERS,
        device="cpu",
    )

In [6]:
# 2
async def test_torchcka():
    print(f"Unlearned model: {unlearned_model_path}")

    cka = TorchCKA(
              model_before,
              model_after,
              model1_layers=RESNET18_LAYERS,
              model2_layers=RESNET18_LAYERS,
              device=device
          )

    with torch.no_grad():
        cka.compare(forget_class_train_loader, forget_class_train_loader)
        cka.compare(other_classes_train_loader, other_classes_train_loader)
        cka.compare(forget_class_test_loader, forget_class_test_loader)
        cka.compare(other_classes_test_loader, other_classes_test_loader)

In [7]:
# 3

layers = (nn.Conv2d, nn.Linear)

async def test_ckapytorch():
    print(f"Unlearned model: {unlearned_model_path}")

    cka = CKACalculator(
        model1=model_before,
        model2=model_after,
        dataloader=forget_class_train_loader,
        hook_layer_types=layers,
    )
    cka.calculate_cka_matrix()

    cka = CKACalculator(
        model1=model_before,
        model2=model_after,
        dataloader=other_classes_train_loader,
        hook_layer_types=layers,
    )
    cka.calculate_cka_matrix()

    cka = CKACalculator(
        model1=model_before,
        model2=model_after,
        dataloader=forget_class_test_loader,
        hook_layer_types=layers,
    )
    cka.calculate_cka_matrix()

    cka = CKACalculator(
        model1=model_before,
        model2=model_after,
        dataloader=other_classes_test_loader,
        hook_layer_types=layers,
    )
    cka.calculate_cka_matrix()

In [8]:
# 4
async def test_pytorch_cka():
    print(f"Unlearned model: {unlearned_model_path}")

    compute_cka(
        model_before,
        model_after,
        dataloaders,
        layers=RESNET18_LAYERS,
        device=device,
    )

In [16]:
print("========================================")
print("I. CKA Computation with vanila_cka")
print("========================================\n")
start_time = time.time()
await test_vanila_cka()
elapsed1 = time.time() - start_time
print(f"\nvanila_cka elapsed time: {elapsed1:.2f} seconds\n\n")


I. CKA Computation with vanila_cka

Unlearned model: ./c9d4.pth


Computing CKA: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]
Computing CKA: 100%|██████████| 5/5 [00:23<00:00,  4.77s/it]
Computing CKA: 100%|██████████| 1/1 [00:01<00:00,  1.98s/it]
Computing CKA: 100%|██████████| 5/5 [00:23<00:00,  4.61s/it]


vanila_cka elapsed time: 50.82 seconds







In [17]:
print("====================================")
print("II. CKA Computation with torch_cka")
print("====================================\n")
start_time = time.time()
await test_torchcka()
elapsed2 = time.time() - start_time
print(f"\ntorch_cka elapsed time: {elapsed2:.2f} seconds\n\n")
print(f"Speed (I -> II): {elapsed1 / elapsed2:.2f}x")


  warn(f"Both model have identical names - {self.model2_info['Name']}. " \


II. CKA Computation with torch_cka

Unlearned model: ./c9d4.pth


| Comparing features |: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
| Comparing features |: 100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
| Comparing features |: 100%|██████████| 1/1 [00:00<00:00,  2.67it/s]
| Comparing features |: 100%|██████████| 5/5 [00:03<00:00,  1.36it/s]


torch_cka elapsed time: 8.98 seconds


Speed (I -> II): 5.66x





In [18]:
print("====================================")
print("III. CKA Computation with ckapytorch")
print("====================================\n")
start_time = time.time()
await test_ckapytorch()
elapsed3 = time.time() - start_time
print(f"\nckapytorch elapsed time: {elapsed3:.2f} seconds\n\n")
print(f"Speed (I -> III): {elapsed1 / elapsed3:.2f}x")


III. CKA Computation with ckapytorch

Unlearned model: ./c9d4.pth
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s]

No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21
No hook function provided. Using flatten_hook_fn.
21 Hooks registered. Total hooks: 21


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]


ckapytorch elapsed time: 19.23 seconds


Speed (I -> III): 2.64x


In [12]:
print("========================================")
print("IV. CKA Computation with cka")
print("========================================\n")
start_time = time.time()
await test_pytorch_cka()
elapsed4 = time.time() - start_time
print(f"\n\ncka elapsed time: {elapsed4:.2f} seconds")
print(f"Speed (I -> IV): {elapsed1 / elapsed4:.2f}x")


IV. CKA Computation with cka

Unlearned model: ./c9d4.pth


Computing CKA: 100%|██████████| 1/1 [00:00<00:00, 17.20it/s]
Computing CKA: 100%|██████████| 5/5 [00:00<00:00,  9.78it/s]
Computing CKA: 100%|██████████| 1/1 [00:00<00:00, 16.65it/s]
Computing CKA: 100%|██████████| 5/5 [00:00<00:00,  9.66it/s]



cka elapsed time: 1.16 seconds
Speed (I -> IV): 43.92x



