In [1]:
!pip install torch torchvision torch-cka pytorch-cka -q

In [2]:
import gc
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from cka import compute_cka
from torch_cka import CKA as TorchCKA

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

RESNET18_LAYERS = [
    "conv1",
    "layer1.0.conv1",
    "layer1.0.conv2",
    "layer1.1.conv1",
    "layer1.1.conv2",
    "layer2.0.conv1",
    "layer2.0.conv2",
    "layer2.1.conv1",
    "layer2.1.conv2",
    "layer3.0.conv1",
    "layer3.0.conv2",
    "layer3.1.conv1",
    "layer3.1.conv2",
    "layer4.0.conv1",
    "layer4.0.conv2",
    "layer4.1.conv1",
    "layer4.1.conv2",
    "fc",
]

Using device: cuda


In [4]:
def get_resnet18(num_classes=10):
    """Get ResNet-18 modified for CIFAR-10 (32x32 images)"""
    model = models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [5]:
# Setup dataset and dataloader
base_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=base_transforms)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=0)

# Create two models for comparison
model1 = get_resnet18().to(device)
model2 = get_resnet18().to(device)
model1.eval()
model2.eval()

print(f"Dataset size: {len(test_set)}")
print(f"Number of batches: {len(test_loader)}")
print(f"Number of layers: {len(RESNET18_LAYERS)}")

100%|██████████| 170M/170M [00:02<00:00, 71.0MB/s]


Dataset size: 10000
Number of batches: 10
Number of layers: 18


In [6]:
def get_current_memory_mb():
    """Get currently allocated memory in MB"""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024 / 1024
    else:
        import psutil
        import os
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / 1024 / 1024

def cleanup_memory():
    """Clean up GPU/CPU memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [7]:
# Measure torch-cka retained memory
cleanup_memory()
baseline_memory = get_current_memory_mb()
print(f"Baseline memory: {baseline_memory:.2f} MB")

cka = TorchCKA(
    model1,
    model2,
    model1_layers=RESNET18_LAYERS,
    model2_layers=RESNET18_LAYERS,
    device=device
)

with torch.no_grad():
    cka.compare(test_loader, test_loader)

torch_cka_memory = get_current_memory_mb() - baseline_memory
print(f"torch-cka retained memory: {torch_cka_memory:.2f} MB")

# Cleanup torch-cka objects
del cka
cleanup_memory()

  warn(f"Both model have identical names - {self.model2_info['Name']}. " \


Baseline memory: 87.10 MB


| Comparing features |: 100%|██████████| 10/10 [00:21<00:00,  2.16s/it]

torch-cka retained memory: 4260.20 MB





In [8]:
# Measure pytorch-cka retained memory
cleanup_memory()
baseline_memory = get_current_memory_mb()
print(f"Baseline memory: {baseline_memory:.2f} MB")

result = compute_cka(
    model1,
    model2,
    [test_loader],
    layers=RESNET18_LAYERS,
    device=device,
)

pytorch_cka_memory = get_current_memory_mb() - baseline_memory
print(f"pytorch-cka retained memory: {pytorch_cka_memory:.2f} MB")

# Cleanup
del result
cleanup_memory()

Baseline memory: 4347.31 MB


Computing CKA: 100%|██████████| 10/10 [00:02<00:00,  3.40it/s]


pytorch-cka retained memory: 183.50 MB


In [9]:
print("=" * 50)
print("Memory Efficiency Comparison Results")
print("=" * 50)
print(f"torch-cka retained memory: {torch_cka_memory:.2f} MB")
print(f"pytorch-cka retained memory: {pytorch_cka_memory:.2f} MB")
print()
print(f"pytorch-cka is {torch_cka_memory / pytorch_cka_memory:.2f}x more memory efficient than torch-cka")

Memory Efficiency Comparison Results
torch-cka retained memory: 4260.20 MB
pytorch-cka retained memory: 183.50 MB

pytorch-cka is 23.22x more memory efficient than torch-cka
