In [None]:
!pip install torch torch-cka torchvision
!pip install --no-cache-dir --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pytorch-cka==1.1.0

In [9]:
import asyncio
import numpy as np
import os
import time
import torch
import torch.nn as nn
from torch_cka import CKA
from pytorch_cka import CKA as PytorchCKA
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

In [10]:
# Label of the CIFAR-10 class treated as "forgotten" by the unlearned model.
forget_class = 0

def get_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 with a small-image stem.

    The first convolution is replaced by a 3x3, stride-1, padding-1 conv (no bias),
    the stem max-pool is removed, and the final linear layer is resized to
    ``num_classes`` outputs. Used below for CIFAR-10 checkpoints.
    """
    net = models.resnet18(weights=None)
    # Stride-1 stem without pooling keeps full spatial resolution for small inputs.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net

In [11]:
def filter_loader(loader, is_train=False, target_class=None):
    """Split ``loader``'s dataset into a forget-class subset and an other-classes subset.

    Samples whose label equals ``target_class`` go to the first returned loader and all
    other samples to the second. Each group is then truncated to its first 1/10
    (``is_train=True``) or first 1/2 (``is_train=False``), in ascending dataset-index
    order, to keep the CKA runs small. The selection is deterministic and does not
    reseed or consume any random number generator.

    Args:
        loader: DataLoader whose ``dataset`` exposes ``targets`` (list or tensor of
            integer labels); its ``batch_size`` is reused for both outputs.
        is_train: choose the 1/10 (True) or 1/2 (False) subsampling fraction.
        target_class: label treated as the forget class. ``None`` (default) uses the
            module-level ``forget_class``.

    Returns:
        ``(forget_loader, other_loader)``: unshuffled, single-process DataLoaders over
        ``Subset``s of ``loader.dataset``.
    """
    if target_class is None:
        target_class = forget_class

    targets = loader.dataset.targets
    targets = targets if isinstance(targets, torch.Tensor) else torch.tensor(targets)

    # Sorting makes the truncation below an explicit "first N in dataset order".
    forget_indices = torch.sort((targets == target_class).nonzero(as_tuple=True)[0])[0]
    other_indices = torch.sort((targets != target_class).nonzero(as_tuple=True)[0])[0]

    divisor = 10 if is_train else 2
    # Plain Python ints so Subset indexes the dataset with ints, not 0-d tensors.
    forget_sampled = forget_indices[: len(forget_indices) // divisor].tolist()
    other_sampled = other_indices[: len(other_indices) // divisor].tolist()

    # Pinned memory only helps host->GPU copies; requesting it without CUDA just warns.
    pin = torch.cuda.is_available()

    def _subset_loader(indices):
        # Unshuffled loader over the given indices, mirroring the source batch size.
        return DataLoader(
            Subset(loader.dataset, indices),
            batch_size=loader.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=pin,
        )

    return _subset_loader(forget_sampled), _subset_loader(other_sampled)

In [12]:
async def test_cka():
    """Run layer-wise CKA (``torch_cka``) between the original and the unlearned ResNet-18.

    Loads the pre-unlearning weights from ``./0000.pth`` and the unlearned weights from
    ``./c9d4.pth``, builds forget-class / other-class subsets of the CIFAR-10 train and
    test splits via ``filter_loader``, and compares the two models on each subset over
    ``detailed_layers``. This is the baseline half of the timing benchmark.

    Declared ``async`` only so the benchmark cell can ``await`` it; nothing inside awaits.

    Returns:
        dict mapping a subset name (``"forget_train"``, ``"other_train"``,
        ``"forget_test"``, ``"other_test"``) to ``cka.export()`` captured right after
        that subset's comparison, so each subset's result is kept instead of discarded.
    """
    original_model_path = "./0000.pth"
    unlearned_model_path = "./c9d4.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Device: {device}")
    print(f"Forget class: {forget_class}")
    print(f"Unlearned model: {unlearned_model_path}")

    def load_eval_model(path):
        # Fresh ResNet-18 loaded with the checkpoint's state_dict, in eval mode.
        model = get_resnet18().to(device)
        # weights_only=True avoids unpickling arbitrary objects; a state_dict needs only tensors.
        model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
        model.eval()
        return model

    model_before = load_eval_model(original_model_path)
    model_after = load_eval_model(unlearned_model_path)

    base_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    clean_train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=base_transforms)
    clean_test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=base_transforms)

    train_loader = DataLoader(clean_train_set, batch_size=1000, shuffle=False, num_workers=0)
    test_loader = DataLoader(clean_test_set, batch_size=1000, shuffle=False, num_workers=0)

    detailed_layers = [
        'conv1',
        'layer1.0',
        'layer1.1',
        'layer2.0',
        'layer2.1',
        'layer3.0',
        'layer3.1',
        'layer4.0',
        'layer4.1',
        'fc'
    ]

    forget_class_train_loader, other_classes_train_loader = filter_loader(train_loader, is_train=True)
    forget_class_test_loader, other_classes_test_loader = filter_loader(test_loader, is_train=False)

    cka = CKA(model_before,
              model_after,
              model1_name="Before Unlearning",
              model2_name="After Unlearning",
              model1_layers=detailed_layers,
              model2_layers=detailed_layers,
              device=device)

    subsets = {
        "forget_train": forget_class_train_loader,
        "other_train": other_classes_train_loader,
        "forget_test": forget_class_test_loader,
        "other_test": other_classes_test_loader,
    }
    results = {}
    # No autograd graph is needed for CKA; no_grad saves memory and time.
    with torch.no_grad():
        for subset_name, subset_loader in subsets.items():
            # Same loader on both sides so both models see identical inputs.
            cka.compare(subset_loader, subset_loader)
            results[subset_name] = cka.export()
    return results

In [13]:
async def test_cka_pytorchcka():
    """Run layer-wise CKA with ``pytorch_cka`` between the original and unlearned ResNet-18.

    Mirrors ``test_cka``. It loads ``./0000.pth`` (before unlearning) and ``./c9d4.pth``
    (after), builds forget-class / other-class subsets of the CIFAR-10 train and test
    splits via ``filter_loader``, and compares the models on each subset over
    ``detailed_layers``. This is the candidate half of the timing benchmark.

    Declared ``async`` only so the benchmark cell can ``await`` it; nothing inside awaits.

    Returns:
        dict mapping a subset name (``"forget_train"``, ``"other_train"``,
        ``"forget_test"``, ``"other_test"``) to whatever ``cka.compare`` returned for
        that subset. NOTE(review): the pytorch_cka API is not visible here. Confirm
        that ``compare`` returns the CKA matrix; it may be ``None``.
    """
    original_model_path = "./0000.pth"
    unlearned_model_path = "./c9d4.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Device: {device}")
    print(f"Forget class: {forget_class}")
    print(f"Unlearned model: {unlearned_model_path}")

    def load_eval_model(path):
        # Fresh ResNet-18 loaded with the checkpoint's state_dict, in eval mode.
        model = get_resnet18().to(device)
        # weights_only=True avoids unpickling arbitrary objects; a state_dict needs only tensors.
        model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
        model.eval()
        return model

    model_before = load_eval_model(original_model_path)
    model_after = load_eval_model(unlearned_model_path)

    base_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    clean_train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=base_transforms)
    clean_test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=base_transforms)

    train_loader = DataLoader(clean_train_set, batch_size=1000, shuffle=False, num_workers=0)
    test_loader = DataLoader(clean_test_set, batch_size=1000, shuffle=False, num_workers=0)

    detailed_layers = [
        'conv1',
        'layer1.0',
        'layer1.1',
        'layer2.0',
        'layer2.1',
        'layer3.0',
        'layer3.1',
        'layer4.0',
        'layer4.1',
        'fc'
    ]

    forget_class_train_loader, other_classes_train_loader = filter_loader(train_loader, is_train=True)
    forget_class_test_loader, other_classes_test_loader = filter_loader(test_loader, is_train=False)

    subsets = {
        "forget_train": forget_class_train_loader,
        "other_train": other_classes_train_loader,
        "forget_test": forget_class_test_loader,
        "other_test": other_classes_test_loader,
    }
    results = {}
    with PytorchCKA(
        model_before,
        model_after,
        model1_name="Before Unlearning",
        model2_name="After Unlearning",
        model1_layers=detailed_layers,
        model2_layers=detailed_layers,
        device=device
    ) as cka:
        # Match test_cka, which runs its comparisons under no_grad. This keeps the
        # benchmark symmetric in case the library does not disable autograd itself.
        with torch.no_grad():
            for subset_name, subset_loader in subsets.items():
                results[subset_name] = cka.compare(subset_loader)
    return results

In [19]:
print("====================================")
print("CKA Computation with Another Library")
print("====================================\n")
start_time = time.time()
await test_cka()
elapsed1 = time.time() - start_time
print(f"\n\ntest_cka elapsed time: {elapsed1:.2f} seconds\n\n\n")

print("========================================")
print("CKA Computation with PyTorch-CKA Library")
print("========================================\n")
start_time = time.time()
await test_cka_pytorchcka()
elapsed2 = time.time() - start_time
print(f"\n\ntest_cka_pytorchcka elapsed time: {elapsed2:.2f} seconds\n\n\n")

improvement = ((elapsed1 - elapsed2) / elapsed2) * 100
print(f"Performance Improvement: {improvement:.2f}%")

CKA Computation with Another Library

Device: cuda
Forget class: 0
Unlearned model: ./c9d4.pth


| Comparing features |: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
| Comparing features |: 100%|██████████| 5/5 [00:12<00:00,  2.56s/it]
| Comparing features |: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
| Comparing features |: 100%|██████████| 5/5 [00:12<00:00,  2.51s/it]




test_cka elapsed time: 29.28 seconds



CKA Computation with PyTorch-CKA Library

Device: cuda
Forget class: 0
Unlearned model: ./c9d4.pth


Computing CKA: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]
Computing CKA: 100%|██████████| 5/5 [00:05<00:00,  1.10s/it]
Computing CKA: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s]
Computing CKA: 100%|██████████| 5/5 [00:05<00:00,  1.06s/it]



test_cka_pytorchcka elapsed time: 13.88 seconds



Performance Improvement: 110.92%



