# CKA Analysis for Language Models

Quick-start guide for analyzing layer representations in transformer-based language models with Centered Kernel Alignment (CKA).

**Requirements:** `pip install pytorch-cka transformers torch`

In [None]:
import copy

import torch
from torch.utils.data import DataLoader
from transformers import (
    BertConfig, BertModel,
    GPT2Config, GPT2Model,
    DistilBertConfig, DistilBertModel,
)

from pytorch_cka import CKA, CKAConfig
from pytorch_cka.viz import plot_cka_heatmap, plot_cka_trend, plot_cka_comparison

In [None]:
class DictDataset:
    """Map-style dataset whose items are dicts of model inputs.

    Each item is ``{"input_ids": ..., "attention_mask": ...}``, built from the
    entries stored at the same index of the two containers passed to the
    constructor.

    Args:
        input_ids: Sized, indexable container with one entry per sample.
        attention_mask: Sized, indexable container with one entry per sample,
            aligned index-for-index with ``input_ids``.

    Raises:
        ValueError: If the two containers hold a different number of samples.
    """

    def __init__(self, input_ids, attention_mask):
        # __len__ only looks at input_ids, so a mismatched mask would otherwise
        # show up late as an IndexError (mask too short) or be silently
        # ignored (mask too long). Fail early instead.
        if len(input_ids) != len(attention_mask):
            raise ValueError(
                "input_ids and attention_mask must contain the same number of "
                f"samples, got {len(input_ids)} and {len(attention_mask)}"
            )
        self.input_ids = input_ids
        self.attention_mask = attention_mask

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict keyed by input name."""
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
        }

In [None]:
def create_text_dataloader(batch_size=8, num_samples=32, seq_length=64, vocab_size=1000, seed=None):
    """Create a dataloader over random token IDs (synthetic data).

    Args:
        batch_size: Number of samples per batch.
        num_samples: Total number of synthetic sequences.
        seq_length: Number of tokens in each sequence.
        vocab_size: Exclusive upper bound for the sampled token IDs.
        seed: Optional integer seed. When given, token IDs are drawn from a
            dedicated generator so repeated calls return identical data
            without touching the global RNG state. When ``None`` (default),
            the global RNG is used, as before.

    Returns:
        A ``DataLoader`` (no shuffling) over a ``DictDataset`` holding
        ``input_ids`` of shape ``(num_samples, seq_length)`` and an all-ones
        ``attention_mask`` of the same shape.
    """
    # A local generator keeps seeding self-contained; None falls back to the
    # global RNG inside torch.randint.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    input_ids = torch.randint(
        0, vocab_size, (num_samples, seq_length), generator=generator
    )
    # Synthetic sequences have no padding, so every position is attended to.
    attention_mask = torch.ones(num_samples, seq_length, dtype=torch.long)
    dataset = DictDataset(input_ids, attention_mask)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
def get_bert_layers(model, max_layers=12):
    """Return the names of the leading BERT encoder layers.

    Names have the form ``encoder.layer.<i>``. The number of names is the
    smaller of ``model.config.num_hidden_layers`` and ``max_layers``.
    """
    n_selected = min(model.config.num_hidden_layers, max_layers)
    return list(map("encoder.layer.{}".format, range(n_selected)))


def get_gpt2_layers(model, max_layers=12):
    """Return the names of the leading GPT-2 transformer blocks.

    Names have the form ``h.<i>``. The number of names is the smaller of
    ``model.config.n_layer`` and ``max_layers``.
    """
    n_selected = min(model.config.n_layer, max_layers)
    return list(map("h.{}".format, range(n_selected)))


def get_distilbert_layers(model, max_layers=12):
    """Return the names of the leading DistilBERT transformer layers.

    Names have the form ``transformer.layer.<i>``. The number of names is the
    smaller of ``model.config.n_layers`` and ``max_layers``.
    """
    n_selected = min(model.config.n_layers, max_layers)
    return list(map("transformer.layer.{}".format, range(n_selected)))

In [None]:
# Global configuration shared by every experiment in this notebook
BATCH_SIZE = 8          # Must be > 3 for unbiased HSIC
NUM_SAMPLES = 32        # 4 batches of BATCH_SIZE samples
SEQ_LENGTH = 64         # Tokens per sample (the model configs below allow up to 128 positions)
VOCAB_SIZE = 1000       # Used for both the synthetic token IDs and every model config

# CKA config with float64 for numerical stability
# (linear kernel, unbiased estimator; see the BATCH_SIZE constraint above)
cka_config = CKAConfig(
    kernel="linear",
    unbiased=True,
    dtype=torch.float64,
)

# Create shared dataloader: all comparisons below consume this same synthetic data
dataloader = create_text_dataloader(BATCH_SIZE, NUM_SAMPLES, SEQ_LENGTH, VOCAB_SIZE)
print(f"Created dataloader with {len(dataloader)} batches of size {BATCH_SIZE}")

## 1. Self-Similarity Analysis

Analyze how layer representations relate within a single BERT model.

In [None]:
# Small BERT configuration (no pretrained weights download)
bert_config = BertConfig(
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=4,
    intermediate_size=512,
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=128,
)
bert = BertModel(bert_config)
# A model built directly from a config is left in training mode; switch to
# eval mode so dropout does not add random noise to the compared activations.
bert.eval()

# Get encoder layer names
bert_layers = get_bert_layers(bert)
print(f"BERT layers: {bert_layers}")

In [None]:
# Self-comparison: only one model and one layer list are passed to CKA.
# CKA is used as a context manager; the matrix is computed inside the block
# and inspected after it.
# NOTE(review): progress=True presumably shows a progress bar — confirm
# against the pytorch_cka documentation.
with CKA(bert, layers1=bert_layers, model1_name="BERT", config=cka_config) as cka:
    bert_self_cka = cka.compare(dataloader, progress=True)

print(f"Matrix shape: {bert_self_cka.shape}")
# Each layer compared with itself sits on the diagonal.
print(f"Diagonal values (should be ~1.0): {torch.diag(bert_self_cka)}")

In [None]:
# Heatmap of the BERT-vs-BERT similarity matrix; the same layer list labels both axes.
# NOTE(review): mask_upper presumably hides the upper triangle (redundant for a
# self-comparison) and layer_name_depth presumably shortens the layer labels —
# confirm both against the pytorch_cka.viz documentation.
fig, ax = plot_cka_heatmap(
    bert_self_cka,
    layers1=bert_layers,
    layers2=bert_layers,
    model1_name="BERT",
    model2_name="BERT",
    title="BERT Self-Similarity",
    annot=True,
    annot_fmt=".2f",
    mask_upper=True,
    layer_name_depth=2,
)

## 2. Cross-Architecture Comparison

Compare encoder (BERT) vs decoder (GPT-2) architectures.

In [None]:
# Small GPT-2 configuration
gpt2_config = GPT2Config(
    n_embd=256,
    n_layer=6,
    n_head=4,
    vocab_size=VOCAB_SIZE,
    n_positions=128,
)
gpt2 = GPT2Model(gpt2_config)
# A model built directly from a config is left in training mode; switch to
# eval mode so dropout does not add random noise to the compared activations.
gpt2.eval()

# Get transformer block names
gpt2_layers = get_gpt2_layers(gpt2)
print(f"GPT-2 layers: {gpt2_layers}")

In [None]:
# Cross-model comparison: BERT encoder layers (model 1) against GPT-2 blocks
# (model 2), both fed from the same shared dataloader.
with CKA(
    bert, gpt2,
    layers1=bert_layers,
    layers2=gpt2_layers,
    model1_name="BERT",
    model2_name="GPT-2",
    config=cka_config,
) as cka:
    cross_arch_cka = cka.compare(dataloader, progress=True)

# NOTE(review): presumably (len(bert_layers), len(gpt2_layers)) — confirm from the printed shape.
print(f"Matrix shape: {cross_arch_cka.shape}")

In [None]:
# Heatmap of BERT layers vs GPT-2 blocks. Unlike the self-similarity plot,
# no mask_upper is passed here: the two axes refer to different models.
fig, ax = plot_cka_heatmap(
    cross_arch_cka,
    layers1=bert_layers,
    layers2=gpt2_layers,
    model1_name="BERT",
    model2_name="GPT-2",
    title="BERT vs GPT-2",
    annot=True,
    layer_name_depth=2,
)

## 3. Model Size Comparison

Compare DistilBERT (3 layers) vs BERT (6 layers).

In [None]:
# DistilBERT configuration (fewer layers than BERT)
distilbert_config = DistilBertConfig(
    dim=256,
    n_layers=3,  # Half the layers of our BERT
    n_heads=4,
    hidden_dim=512,
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=128,
)
distilbert = DistilBertModel(distilbert_config)
# A model built directly from a config is left in training mode; switch to
# eval mode so dropout does not add random noise to the compared activations.
distilbert.eval()

# Get transformer layer names
distilbert_layers = get_distilbert_layers(distilbert)
print(f"DistilBERT layers: {distilbert_layers}")

In [None]:
# Compare the 6-layer BERT (model 1) with the 3-layer DistilBERT (model 2)
# on the same shared dataloader; the two layer lists have different lengths.
with CKA(
    bert, distilbert,
    layers1=bert_layers,
    layers2=distilbert_layers,
    model1_name="BERT",
    model2_name="DistilBERT",
    config=cka_config,
) as cka:
    size_cka = cka.compare(dataloader, progress=True)

# NOTE(review): presumably (len(bert_layers), len(distilbert_layers)) — confirm from the printed shape.
print(f"Matrix shape: {size_cka.shape}")

In [None]:
# Heatmap of BERT layers vs DistilBERT layers; the model names passed here
# carry the layer counts so the figure is self-explanatory.
fig, ax = plot_cka_heatmap(
    size_cka,
    layers1=bert_layers,
    layers2=distilbert_layers,
    model1_name="BERT (6 layers)",
    model2_name="DistilBERT (3 layers)",
    title="Model Size Comparison",
    annot=True,
    layer_name_depth=2,
)

## 4. Pre-trained vs Fine-tuned Comparison

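Simulate fine-tuning by adding Gaussian noise to the weights of the last two encoder layers (4 and 5) of a copy of the BERT model above. Note that the "pre-trained" model here is randomly initialized, not loaded from a checkpoint.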

In [None]:
# Create a copy and perturb weights to simulate fine-tuning
bert_finetuned = copy.deepcopy(bert)

# Add noise to later layers (simulating task-specific fine-tuning)
# no_grad: the in-place update must not be recorded by autograd.
with torch.no_grad():
    for name, param in bert_finetuned.named_parameters():
        # The trailing dot anchors the layer index, so "encoder.layer.4."
        # cannot also match layers 40-49 if a deeper model is swapped in.
        if any(f"encoder.layer.{i}." in name for i in (4, 5)):
            param.add_(torch.randn_like(param) * 0.1)

print("Created fine-tuned model with perturbed layers 4-5")

In [None]:
# Compare the original BERT (model 1) with its perturbed copy (model 2),
# using the same layer list for both since the architectures are identical.
with CKA(
    bert, bert_finetuned,
    layers1=bert_layers,
    layers2=bert_layers,
    model1_name="Pre-trained",
    model2_name="Fine-tuned",
    config=cka_config,
) as cka:
    finetune_cka = cka.compare(dataloader, progress=True)

# Diagonal shows layer-wise similarity before/after fine-tuning
# (entry i compares layer i of the original with layer i of the copy)
print(f"Layer-wise similarity: {torch.diag(finetune_cka)}")

In [None]:
# Full layer-by-layer heatmap of the original BERT vs its perturbed copy.
fig, ax = plot_cka_heatmap(
    finetune_cka,
    layers1=bert_layers,
    layers2=bert_layers,
    model1_name="Pre-trained",
    model2_name="Fine-tuned",
    title="Pre-trained vs Fine-tuned BERT",
    annot=True,
    layer_name_depth=2,
)

In [None]:
# Plot only the diagonal of the matrix: same-index layer similarity as a
# single series over the layer index.
# NOTE(review): assumes plot_cka_trend accepts a 1-D tensor with one label per
# series — confirm against the pytorch_cka.viz documentation.
fig, ax = plot_cka_trend(
    torch.diag(finetune_cka),
    labels=["Layer Similarity"],
    xlabel="Layer Index",
    ylabel="CKA Similarity",
    title="Fine-tuning Impact by Layer",
)

## Summary: Multi-Panel Comparison

Compare all scenarios side by side.

In [None]:
# Compute GPT-2 self-similarity for completeness
# (single-model call, same pattern as the BERT self-similarity above; progress output disabled)
with CKA(gpt2, layers1=gpt2_layers, model1_name="GPT-2", config=cka_config) as cka:
    gpt2_self_cka = cka.compare(dataloader, progress=False)

# Create comparison plot
# Four panels in a 2-column grid; titles are matched to matrices by position.
# The BERT-vs-DistilBERT matrix (size_cka) is not part of this figure.
fig, axes = plot_cka_comparison(
    matrices=[bert_self_cka, gpt2_self_cka, cross_arch_cka, finetune_cka],
    titles=["BERT Self", "GPT-2 Self", "BERT vs GPT-2", "Pre vs Fine-tuned"],
    ncols=2,
    share_colorbar=True,
    figsize=(12, 10),
)

## Key Observations

1. **Self-similarity**: Diagonal values ~1.0; nearby layers show higher similarity
2. **Cross-architecture**: Despite different designs, similar depth layers may share representations
3. **Model size**: DistilBERT layers may correspond to multiple BERT layers
4. **Fine-tuning**: Early layers remain stable; task-specific adaptations occur in later layers (here by construction, since only layers 4–5 were perturbed)