In [None]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

# -----------------------------------------------------------------------------

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention that materialises the full (T, T) attention matrix.

    `config` must provide `n_embd`, `n_head` (dividing `n_embd`) and
    `block_size` (side length of the pre-built lower-triangular mask).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x, return_qkv=False):
        """Apply causal self-attention to `x`.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd.
            return_qkv: when true, also return the per-head query, key and
                value tensors. Defaults to False so the module can be called
                with just the input.

        Returns:
            A tuple `(y, qkv)`. `y` has shape (B, T, C). `qkv` is `(q, k, v)`,
            each in per-head layout (B, n_head, T, C // n_head), when
            `return_qkv` is true, otherwise None.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head # head size, so that C = nh * hs
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)

        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        # positions above the diagonal (the future) get -inf so softmax gives them zero weight
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)

        # q, k, v are handed back in their per-head (B, nh, T, hs) layout,
        # not the flat (B, T, C) layout they had right after the projection
        return y, ((q, k, v) if return_qkv else None)

class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4 * n_embd, apply tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # attribute names mirror the GPT-2 checkpoint keys used when loading weights
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, width)

    def forward(self, x):
        """Return `c_proj(gelu(c_fc(x)))`; the last dimension of `x` must be n_embd."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x, matching_indices_1st = None, matching_indices_2nd = None):
        """Run the block on `x` and return a tensor of the same shape.

        `matching_indices_1st` and `matching_indices_2nd` are kept in the
        signature so existing callers keep working, but they are not used:
        `CausalSelfAttention.forward` has no parameters to receive them.
        """
        # the attention module returns an (output, qkv) tuple; only the output
        # takes part in the residual sum
        attn_out, _ = self.attn(self.ln_1(x), return_qkv=False)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    """Model hyper-parameters; the defaults describe the 124M-parameter GPT-2."""

    # max sequence length
    block_size: int = 1_024
    # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    vocab_size: int = 50_257
    # number of layers
    n_layer: int = 12
    # number of heads
    n_head: int = 12
    # embedding dimension
    n_embd: int = 768

class GPT(nn.Module):
    """GPT-2 style language model with a per-position cache of layer outputs.

    `self.cache` maps `(batch_idx, position)` to
    `{'hidden': {layer_idx: tensor}, 'qkv': {layer_idx: (q, k, v)}}`.
    `forward` fills it for every layer it actually computes and reads from it
    for the layers it is asked to skip. The cache is never cleared by this
    class; callers own its lifetime.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte':  nn.Embedding(config.vocab_size, config.n_embd),
            'wpe':  nn.Embedding(config.block_size, config.n_embd),
            'h':    nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # This cache holds precomputed states for skipping
        self.cache = {}

    def forward(
        self,
        input_ids: torch.Tensor,
        batch_idx: int,
        t_new: int,
        t_matched: int = None,
        skip_up_to: int = None,
    ):
        """Compute logits for `input_ids`, optionally reusing cached layer outputs.

        Args:
            input_ids: token ids of shape (B, T), with T <= config.block_size.
            batch_idx: first component of the cache key.
            t_new: position under which the layers computed here are cached.
            t_matched: position whose cached hidden states are reused, or
                None to compute every layer.
            skip_up_to: layers with index < skip_up_to are not computed;
                the running activation is replaced by the tensor cached for
                `(batch_idx, t_matched)` at that layer, whatever its shape.
                None disables skipping.

        Returns:
            Logits from `lm_head` applied to the final activation; shape
            (B, T, vocab_size) when no layer is skipped.

        Raises:
            ValueError: if T exceeds config.block_size.
            KeyError: if a layer is to be skipped but nothing is cached for
                `(batch_idx, t_matched)` at that layer.
        """
        _, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(
                f"cannot forward sequence of length {T}, block size is only {self.config.block_size}"
            )
        device = input_ids.device

        # token + position embedding
        pos = torch.arange(0, T, device=device)
        x = self.transformer['wte'](input_ids) + self.transformer['wpe'](pos)

        skipping = (t_matched is not None) and (skip_up_to is not None)

        # run or skip layers
        for layer_idx, block in enumerate(self.transformer['h']):
            if skipping and layer_idx < skip_up_to:
                # skip: reuse hidden from the matched position
                try:
                    x = self.cache[(batch_idx, t_matched)]['hidden'][layer_idx]
                except KeyError as err:
                    raise KeyError(
                        f"no cached hidden state for batch_idx={batch_idx}, "
                        f"position={t_matched}, layer={layer_idx}"
                    ) from err
                continue

            x_ln = block.ln_1(x)
            attn_out, qkv = block.attn(x_ln, return_qkv=True)
            x = x + attn_out
            x = x + block.mlp(block.ln_2(x))

            # store in cache; everything is detached so that cached entries
            # never keep an autograd graph alive
            entry = self.cache.setdefault((batch_idx, t_new), {'hidden': {}, 'qkv': {}})
            entry['hidden'][layer_idx] = x.detach()
            entry['qkv'][layer_idx] = tuple(t.detach() for t in qkv)

        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface.

        Args:
            model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.

        Returns:
            A `GPT` instance whose parameters were copied from the
            corresponding `GPT2LMHeadModel` checkpoint.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        # discard the causal mask: it is a buffer, not a param
        sd_keys = [k for k in sd if not k.endswith('.attn.bias')]

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        # ('.attn.masked_bias' and '.attn.bias' are buffers on the HF side, so they are ignored)
        sd_keys_hf = [
            k for k in sd_hf
            if not k.endswith('.attn.masked_bias') and not k.endswith('.attn.bias')
        ]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        with torch.no_grad():
            for k in sd_keys_hf:
                if any(k.endswith(w) for w in transposed):
                    # special treatment for the Conv1D weights we need to transpose
                    assert sd_hf[k].shape[::-1] == sd[k].shape
                    sd[k].copy_(sd_hf[k].t())
                else:
                    # vanilla copy over the other parameters
                    assert sd_hf[k].shape == sd[k].shape
                    sd[k].copy_(sd_hf[k])

        return model

In [None]:
# def detect_ngram_copy(seq_ids: torch.Tensor, n=3, skip_up_to=43):
#     """
#     seq_ids : 1D tensor of shape (T,) for the current sequence of token IDs.
#     n       : n-gram size (e.g. n=3 means we check for a 3-gram match).
#     skip_up_to : how many layers to skip if a match is found (just a parameter).

#     Returns:
#       t_matched, skip_up_to  -> the matched position in seq_ids and how many layers to skip
#       (None, None)           -> if no match is found
#     """

#     T = seq_ids.size(0)
#     if T < n:
#         # Not enough tokens to form an n-gram
#         return None, None

#     # The very last token in the sequence
#     last_token = seq_ids[-1].item()

#     # Find all previous positions where last_token appears
#     #    except the final position itself
#     possible_positions = (seq_ids[:-1] == last_token).nonzero(as_tuple=True)[0]
#     if len(possible_positions) == 0:
#         # No earlier occurrence of this token
#         return None, None

#     # The (n-1) tokens before the last token
#     #    e.g. if n=3, this is the last 2 tokens before the final one
#     #    shape: (n-1,)
#     context_needed = seq_ids[-(n-1):-1]

#     # Scan from latest to earliest possible position
#     matched_pos = None
#     for pos in reversed(possible_positions):
#         # pos is where the same last_token appeared,
#         # but we also need to check if the (n-1) tokens before it
#         # match 'context_needed'.

#         # The matched token in seq_ids is at index 'pos'.
#         # The (n-1) tokens before that are seq_ids[pos-(n-1)+1 : pos],
#         # i.e. a total of (n-1) tokens ending just before 'pos'.
#         # But simpler is to define start = pos - (n-1) + 1:
#         start = pos - (n - 1) + 1
#         if start < 0:
#             # Not enough space in the sequence for an n-1 match
#             continue

#         # candidate (n-1) tokens that appear right before this matched token
#         candidate = seq_ids[start:pos]

#         # Compare with our needed context
#         if torch.all(candidate == context_needed):
#             matched_pos = pos
#             break  # we use the latest match we find

#     # 5) Return (matched_pos, skip_up_to) if found, else (None, None)
#     if matched_pos is None:
#         return None, None
#     else:
#         return matched_pos, skip_up_to

In [None]:
# def sample_next_token(logits, top_k=5):
#     """
#     Sample from the top-k tokens in the final position's logits.
#     logits: (B, T, vocab_size)
#     """
#     last_logits = logits[:, -1, :]  # shape (B, vocab_size)
#     # top-k
#     values, indices = torch.topk(last_logits, k=top_k, dim=-1)
#     # sample from top-k
#     probs = F.softmax(values, dim=-1)
#     idx = torch.multinomial(probs, num_samples=1)  # shape (B, 1)
#     next_token_id = indices.gather(-1, idx)       # shape (B, 1)
#     return next_token_id.squeeze(-1)              # shape (B,)

In [None]:
import time
import torch
import random
import numpy as np

def seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and CUDA seeding calls) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    # same call order as listing them one statement at a time
    for apply_seed in seeders:
        apply_seed(seed)

def get_input_ids(sent, tokenizer):
    """Encode `sent` with `tokenizer` and move the resulting tensor to `device`.

    Args:
        sent: the sentence to encode.
        tokenizer: object whose `encode(text, return_tensors='pt')` returns
            a tensor.

    Returns:
        The encoded tensor after `.to(device)`.

    Relies on a module-level `device` being defined before the call.
    """
    # encode the `sent` argument itself, not an unrelated global
    return tokenizer.encode(sent, return_tensors='pt').to(device)
    
def get_top_k(logits, top_k=None, k=None, tokenizer=None):
    """Return the ids of the `top_k` most probable tokens for the first batch element.

    Args:
        logits: tensor of shape (B, vocab) or (B, T, vocab). For the 3-D form
            only the last position along T is ranked.
        top_k: how many ids to return.
        k: keyword alias for `top_k`, used when `top_k` is not given.
        tokenizer: accepted so that call sites passing it keep working; unused.

    Returns:
        A list of `top_k` integer token ids, most probable first.

    Raises:
        TypeError: if neither `top_k` nor `k` is provided.
    """
    if top_k is None:
        top_k = k
    if top_k is None:
        raise TypeError("get_top_k() requires `top_k` (or its alias `k`)")
    if logits.dim() == 3:
        # only the final position predicts the next token
        logits = logits[:, -1, :]
    prob = torch.softmax(logits, dim=-1)
    _, top_indices = torch.topk(prob, top_k, dim=-1)
    # row 0 = first batch element
    return top_indices[0].tolist()

def detect_ngram_copy(seq_ids: torch.Tensor, n=3, skip_up_to=43):
    """Find the most recent earlier position whose trailing n-gram equals the sequence's last n-gram.

    Only row 0 of `seq_ids` (shape (B, T)) is inspected. A position `p < T-1`
    matches when `seq_ids[0, p]` equals the last token and the `n-1` tokens
    before `p` equal the `n-1` tokens before the last token.

    Returns `(p, skip_up_to)` for the latest matching `p`, or `(None, None)`
    when the sequence is shorter than `n` or nothing matches.
    """
    if seq_ids.size(1) < n:
        return None, None

    tokens = seq_ids[0]
    final_token = tokens[-1].item()

    # earlier occurrences of the final token (the final position itself is excluded)
    earlier_hits = (tokens[:-1] == final_token).nonzero().view(-1)
    if earlier_hits.numel() == 0:
        return None, None

    ctx_len = n - 1
    wanted_ctx = tokens[-(ctx_len + 1):-1]

    # walk candidates from latest to earliest so the first hit is the most recent one
    for idx in earlier_hits.flip(0).tolist():
        if idx < ctx_len:
            # not enough tokens in front of this occurrence to hold the context
            continue
        if torch.all(tokens[idx - ctx_len:idx] == wanted_ctx):
            return idx, skip_up_to

    return None, None

Model

In [None]:
# The tokenizer class comes from the same `transformers` package that
# GPT.from_pretrained imports its weights from.
from transformers import GPT2Tokenizer

# Device shared by the model and by every tensor built in the cells below
# (get_input_ids and the generation loop read this module-level name).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = 'gpt2-xl'
model = GPT.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = model.to(device)
# inference only: switch modules to evaluation mode
model.eval()

In [None]:
# Load the first 1000 rows of three columns from the on-disk
# "english_insertions" dataset, then drop the dataset object itself.
from datasets import load_from_disk
from tqdm import tqdm

# Load the dataset from disk
subset = load_from_disk("english_insertions")
# NOTE(review): prompt_list is not used by any code visible in this file.
prompt_list = []

# Each column of the 'train' split is read and then cut to its first 1000 entries.
base_sents = subset['train']['base_sentence'][:1000]
phrases = subset['train']['phrase'][:1000]
edited_sents = subset['train']['edited_sentence'][:1000]

# Only the three sliced columns are needed from here on; drop the reference to
# the dataset and run a garbage-collection pass.
import gc
del subset
gc.collect()

In [None]:
steps = 20        # number of tokens to generate
reps = 5          # number of runs with different seeds
k = 10            # number of top candidates recorded at every step

for i in tqdm(range(len(base_sents))):
    # NOTE(review): the prompt used to be the empty string, which encodes to
    # zero tokens and cannot be fed to the model. Using the i-th base sentence
    # is an assumption -- confirm which column (base_sents / phrases /
    # edited_sents) the experiment is meant to prompt with.
    prompt = base_sents[i]
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # rep -> step -> {'copy', 'copy_time', 'original', 'original_time'}
    dict_pred_info = {}

    for rep in range(reps):
        dict_pred_info[rep] = {}
        seed_everything(rep)

        # 1) Copy model generation
        # The cache is keyed by (batch_idx, position) only, so entries left
        # over from an earlier run or prompt would be reused for an unrelated
        # sequence. Start every run from an empty cache.
        model.cache.clear()
        copy_ids = input_ids.clone()
        for step_i in range(steps):
            t0 = time.perf_counter()

            # detect copy scenario (toy example)
            t_matched, skip_up_to = detect_ngram_copy(copy_ids, n=3, skip_up_to=43)

            # Skipping reads one cached hidden state per skipped layer. Only
            # positions that were themselves fully computed during this run
            # have all of them; otherwise fall back to a normal forward pass.
            if t_matched is not None:
                cached_hidden = model.cache.get((0, t_matched), {}).get('hidden', {})
                n_skipped = min(skip_up_to, len(model.transformer['h']))
                if any(layer not in cached_hidden for layer in range(n_skipped)):
                    t_matched, skip_up_to = None, None

            # forward pass (inference only, so no autograd graph is needed)
            with torch.no_grad():
                logits = model(input_ids=copy_ids,
                               batch_idx=0,
                               t_new=copy_ids.shape[1] - 1,
                               t_matched=t_matched,
                               skip_up_to=skip_up_to)
            # rank candidates for the next token from the last position only
            top_k_list = get_top_k(logits[:, -1, :], k)

            # store in dictionary
            if step_i not in dict_pred_info[rep]:
                dict_pred_info[rep][step_i] = {}
            dict_pred_info[rep][step_i]['copy'] = top_k_list

            # pick next token
            next_token = top_k_list[0]  # top-1 ID
            # the new token must live on the same device as the running sequence
            copy_ids = torch.cat(
                [copy_ids, torch.tensor([[next_token]], device=copy_ids.device)], dim=1
            )

            elapsed_copy = time.perf_counter() - t0

            dict_pred_info[rep][step_i]['copy_time'] = elapsed_copy

        # 2) Original model generation
        # forward() still writes to the cache here; clear it so the entries of
        # the copy run are released and nothing stale survives this run.
        model.cache.clear()
        original_ids = input_ids.clone()
        for step_i in range(steps):
            t0 = time.perf_counter()

            # no skip logic
            with torch.no_grad():
                logits = model(input_ids=original_ids,
                               t_new=original_ids.shape[1] - 1,
                               batch_idx=0)
            top_k_list = get_top_k(logits[:, -1, :], k)

            dict_pred_info[rep][step_i]['original'] = top_k_list

            # pick next token
            next_token = top_k_list[0]  # top-1 ID
            original_ids = torch.cat(
                [original_ids, torch.tensor([[next_token]], device=original_ids.device)], dim=1
            )

            elapsed_orig = time.perf_counter() - t0

            dict_pred_info[rep][step_i]['original_time'] = elapsed_orig
