In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set your token and load the tokenizer and model
token = "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"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=token)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", token=token)

# Dictionary to store activations
activations = {}

# Hook function to capture activations
def get_activation_hook(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

# Register hooks for all layers in the model
for i, layer in enumerate(model.model.layers):
    layer.mlp.register_forward_hook(get_activation_hook(f'layer_{i}_mlp'))

# Encode input text
input_ids = tokenizer.encode("The quick brown fox jumps over the lazy dog", return_tensors="pt")

# Forward pass to collect activations
outputs = model(input_ids)

# Now activations dictionary contains outputs for all layers
final_logits = outputs.logits.squeeze(0)  # Remove batch dimension if necessary


  _torch_pytree._register_pytree_node(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

# Initialize lists to store losses
cross_entropy_losses = []
kl_divergences = []

# Linear layer to match hidden size to vocabulary size
hidden_size = activations['layer_0_mlp'].size(-1)
vocab_size = final_logits.size(-1)
linear_projection = nn.Linear(hidden_size, vocab_size).to(final_logits.device)

# Cross-Entropy Loss calculation
for i in range(32):
    layer_logits = activations[f'layer_{i}_mlp']
    
    # Project layer logits to match the final logits shape
    projected_logits = linear_projection(layer_logits)
    
    # Reshape to match the final logits for cross-entropy computation
    projected_logits = projected_logits.view(-1, vocab_size)
    final_logits_reshaped = final_logits.view(-1, vocab_size)
    
    # Compute Cross-Entropy
    loss = F.cross_entropy(projected_logits, final_logits_reshaped.argmax(dim=-1))
    cross_entropy_losses.append(loss.item())



In [3]:
cross_entropy_losses

[11.757238388061523,
 12.334437370300293,
 11.759215354919434,
 11.75847339630127,
 11.755319595336914,
 11.753714561462402,
 11.74796199798584,
 11.754512786865234,
 11.765863418579102,
 11.759199142456055,
 11.74631404876709,
 11.756658554077148,
 11.745182037353516,
 11.75864028930664,
 11.773157119750977,
 11.746389389038086,
 11.761775016784668,
 11.753610610961914,
 11.762552261352539,
 11.76201057434082,
 11.760408401489258,
 11.764768600463867,
 11.766378402709961,
 11.740522384643555,
 11.764920234680176,
 11.754125595092773,
 11.767090797424316,
 11.744226455688477,
 11.720169067382812,
 11.788668632507324,
 11.8467435836792,
 12.33767318725586]

In [4]:

# KL Divergence calculation
for i in range(31):
    for j in range(i + 1, 32):
        logits_i = activations[f'layer_{i}_mlp']
        logits_j = activations[f'layer_{j}_mlp']
        
        # Project logits to match the final logits shape
        projected_logits_i = linear_projection(logits_i)
        projected_logits_j = linear_projection(logits_j)

        # Calculate KL divergence
        kl_div = F.kl_div(F.log_softmax(projected_logits_i.view(-1, vocab_size), dim=-1), 
                          F.softmax(projected_logits_j.view(-1, vocab_size), dim=-1), 
                          reduction='batchmean')
        kl_divergences.append(kl_div.item())

# At this point, cross_entropy_losses and kl_divergences contain the computed values


In [5]:
kl_divergences

[0.2729949355125427,
 0.00041331740794703364,
 0.0004601589753292501,
 0.0004984528059139848,
 0.0005551259382627904,
 0.0006257999921217561,
 0.0006501491297967732,
 0.0006165402010083199,
 0.000642703496851027,
 0.0006624135421589017,
 0.0006623633671551943,
 0.0006383726140484214,
 0.0006809426704421639,
 0.0007675110828131437,
 0.0008893537451513112,
 0.0009327458101324737,
 0.001047943951562047,
 0.0011063308920711279,
 0.0015609966358169913,
 0.0011360043426975608,
 0.001256003975868225,
 0.0012442891020327806,
 0.0013011537957936525,
 0.0014935277868062258,
 0.001790632144547999,
 0.0019352470990270376,
 0.002514322753995657,
 0.003335612593218684,
 0.005038512870669365,
 0.012732106260955334,
 0.4368975758552551,
 0.42961254715919495,
 0.42959871888160706,
 0.4296061098575592,
 0.4296080470085144,
 0.4297683835029602,
 0.4298507571220398,
 0.4298326373100281,
 0.42995819449424744,
 0.4251581132411957,
 0.43000832200050354,
 0.4298298954963684,
 0.42987555265426636,
 0.430037975