In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
model_name = "openai-community/gpt2-medium"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # Force float16 instead of BF16
    device_map="auto"           # Auto-detect the best device
)

In [3]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [3]:
model.config.num_attention_heads

16

In [4]:
sd_hf = model.state_dict()
for k, v in sd_hf.items():
  print(k, v.shape)

transformer.wte.weight torch.Size([50257, 1024])
transformer.wpe.weight torch.Size([1024, 1024])
transformer.h.0.ln_1.weight torch.Size([1024])
transformer.h.0.ln_1.bias torch.Size([1024])
transformer.h.0.attn.c_attn.weight torch.Size([1024, 3072])
transformer.h.0.attn.c_attn.bias torch.Size([3072])
transformer.h.0.attn.c_proj.weight torch.Size([1024, 1024])
transformer.h.0.attn.c_proj.bias torch.Size([1024])
transformer.h.0.ln_2.weight torch.Size([1024])
transformer.h.0.ln_2.bias torch.Size([1024])
transformer.h.0.mlp.c_fc.weight torch.Size([1024, 4096])
transformer.h.0.mlp.c_fc.bias torch.Size([4096])
transformer.h.0.mlp.c_proj.weight torch.Size([4096, 1024])
transformer.h.0.mlp.c_proj.bias torch.Size([1024])
transformer.h.1.ln_1.weight torch.Size([1024])
transformer.h.1.ln_1.bias torch.Size([1024])
transformer.h.1.attn.c_attn.weight torch.Size([1024, 3072])
transformer.h.1.attn.c_attn.bias torch.Size([3072])
transformer.h.1.attn.c_proj.weight torch.Size([1024, 1024])
transformer.h.1

In [None]:
# # Sample prompt
# prompt = "The "
# inputs = tokenizer(prompt, return_tensors="pt").to(device)

# # Generate output text
# with torch.no_grad():
#     generated_ids = model.generate(
#         inputs.input_ids,
#         max_length=10,  # Adjust length as needed
#         temperature=0.7,  # Adjust for diversity (lower = more deterministic)
#         top_k=50,  # Consider top-k sampling
#         top_p=0.95,  # Nucleus sampling
#         do_sample=True  # Enables sampling instead of greedy decoding
#     )

# # Decode generated text
# generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# print("Generated text:", generated_text)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated text: The vernacular of the era was "s


In [5]:
from hooks import *
remove_all_forward_hooks(model)
register_all_forward_hooks(model)

batch_size = 16
total_samples = 1024
num_batches = total_samples // batch_size

prompt = "The future of AI is"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    for _ in range(num_batches):
        outputs = model(**inputs)

In [6]:
compute_importance_scores(model)

In [13]:
for module in model.modules():
    print(module)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)
GPT2Model(
  (wte): Embedding(50257, 1024)

In [None]:
import torch.nn as nn
from typing import List, Tuple

def find_gpt2_layer_connections(model) -> List[Tuple[nn.Module, nn.LayerNorm, nn.Module]]:
    """
    Finds dependencies in GPT-2 Blocks.
    Returns a list of (prev_layer, ln_layer, next_layer) for pruning.
    """
    connections = []
    prev_ln = None
    model_blocks = [module for module in model.modules() if isinstance(module, GPT2Block)]  # Get all GPT2Blocks

    for i, block in enumerate(model_blocks):
        ln_1 = block.ln_1
        ln_2 = block.ln_2
        
        # Layers that should be modified by ln_1
        c_attn = block.attn.c_attn  # Inside attention block
        next_c_proj = None

        # If there's a next block, c_proj in MLP is affected by this block’s ln_1
        if i + 1 < len(model_blocks):
            next_c_proj = model_blocks[i + 1].mlp.c_proj

        # Layers that should be modified by ln_2
        c_proj_attn = block.attn.c_proj  # Projection in attention block
        c_fc = block.mlp.c_fc  # Fully connected layer in MLP
        
        # Store dependencies
        connections.append((ln_1, c_attn, next_c_proj))  # ln_1 affects c_attn and next block's c_proj
        connections.append((ln_2, c_proj_attn, c_fc))    # ln_2 affects c_proj (attn) and c_fc (MLP)

    return connections


GPT2Block(
  (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D(nf=3072, nx=1024)
    (c_proj): Conv1D(nf=1024, nx=1024)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D(nf=4096, nx=1024)
    (c_proj): Conv1D(nf=1024, nx=4096)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [7]:
# Print importance scores for each registered module
for name, module in model.named_modules():
    if hasattr(module, "importance_scores"):
        print(f"Layer {module.__class__.__name__}: P{name} : importance scores:", module.importance_scores.shape)

Layer LayerNorm: Ptransformer.h.0.ln_1 : importance scores: torch.Size([1024])
Layer Conv1D: Ptransformer.h.0.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.0.mlp.c_fc : importance scores: torch.Size([4096])
Layer Conv1D: Ptransformer.h.1.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.1.mlp.c_fc : importance scores: torch.Size([4096])
Layer Conv1D: Ptransformer.h.2.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.2.mlp.c_fc : importance scores: torch.Size([4096])
Layer Conv1D: Ptransformer.h.3.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.3.mlp.c_fc : importance scores: torch.Size([4096])
Layer Conv1D: Ptransformer.h.4.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.4.mlp.c_fc : importance scores: torch.Size([4096])
Layer Conv1D: Ptransformer.h.5.attn.c_proj : importance scores: torch.Size([16])
Layer Conv1D: Ptransformer.h.5.mlp.

In [None]:
i = 0
for module in model.modules():
    if hasattr(module, "importance_buffer") and module.importance_buffer:
        # print(module.importance_buffer)
        print(module.__class__.__name__)
        print(len(module.importance_buffer))
        print(module.importance_buffer[0].shape)
        print('============')
        
        print([module.importance_buffer[i].shape for i in range(len(module.importance_buffer))])
        print('============')
        i += 1
        if i>10:
            break
        # all_outputs = torch.cat(module.importance_buffer, dim=0)  # Concatenate over batch dimension

        # # Compute norm-based importance
        # importance = all_outputs.norm(p=2, dim=0).mean(dim=0)

        # module.importance_scores = importance
        # del module.importance_buffer

LayerNorm
32
torch.Size([1, 5, 1024])
[torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024])]
Conv1D
32
torch.Size([1, 5, 1024])
[torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 5, 1024]), torch.Size([1, 

In [3]:
import torch

# Example values
embed_size = 4
head_size = 3
top_heads = [0, 1]  # Example list, len(top_heads) = 2

# Create a random tensor with shape (embed_size, head_size)
pruned_q = torch.randn(head_size, embed_size)
print("Not Transformed pruned_q shape:", pruned_q.shape)
# Repeat along the first dimension
pruned_q = pruned_q.repeat(len(top_heads), 1)

# Print results
print("Transformed pruned_q shape:", pruned_q.shape)
pruned_q


Not Transformed pruned_q shape: torch.Size([3, 4])
Transformed pruned_q shape: torch.Size([6, 4])


tensor([[-0.3777,  0.4140, -0.0667, -0.7624],
        [ 1.1976,  0.4413, -0.1174,  0.7456],
        [ 2.1201,  0.1060, -1.6151,  0.9055],
        [-0.3777,  0.4140, -0.0667, -0.7624],
        [ 1.1976,  0.4413, -0.1174,  0.7456],
        [ 2.1201,  0.1060, -1.6151,  0.9055]])

In [14]:
from typing import List, Tuple
import torch.nn as nn

def find_layer_connections(model) -> List[Tuple[nn.Module, nn.LayerNorm, nn.Module]]:
    """
    Finds (prev_layer, ln_layer, next_layer) triplets to track dependencies in the model.
    This method ensures we correctly map LayerNorm layers with their corresponding input/output layers.
    
    We consider:
    - nn.Linear
    - nn.Embedding
    - nn.Conv1d
    - nn.MultiheadAttention (for transformers)
    """
    connections = []
    prev_layer = None
    model_layers = list(model.modules())  # Convert generator to list for better indexing
    
    for i, module in enumerate(model_layers):
        if isinstance(module, nn.LayerNorm):
            ln_layer = module
            next_layer = None  # Placeholder

            # Find the next valid layer
            for j in range(i + 1, len(model_layers)):
                sub_module = model_layers[j]
                if isinstance(sub_module, (nn.Linear, nn.Conv1d, nn.Embedding, nn.MultiheadAttention)):
                    next_layer = sub_module
                    break

            if prev_layer is not None and next_layer is not None:
                connections.append((prev_layer, ln_layer, next_layer))

            prev_layer = ln_layer  # The current LayerNorm becomes the "previous" layer

        elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding, nn.MultiheadAttention)):
            prev_layer = module  # Keep track of last valid layer

    return connections


In [16]:
find_layer_connections(model)

[(Embedding(1024, 1024),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_features=50257, bias=False)),
 (LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_features=50257, bias=False)),
 (LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_features=50257, bias=False)),
 (LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_features=50257, bias=False)),
 (LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_features=50257, bias=False)),
 (LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  LayerNorm((1024,), eps=1e-05, elementwise_affine=True),
  Linear(in_features=1024, out_f