In [1]:
import einops
from dataclasses import dataclass
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict, Callable
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from jaxtyping import Float, Int
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset

In [2]:
# Load the pretrained causal LM and its fast tokenizer; the model lives on GPU when one is available.
MODEL_NAME = 'distilgpt2'
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse the EOS token as the padding token so padded tokenization has a pad id available.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)

In [3]:
class TokenDataset(Dataset):
    """Map-style dataset that pairs each row of token ids with its row of labels."""

    def __init__(self, data, labels):
        """
        Args:
            data (Tensor): A tensor of shape (n_samples, seq_len) containing token ids.
            labels (Tensor): A tensor with the same leading dimension as ``data``
                (e.g. the (n_samples, seq_len) 0/1 numeral mask produced by
                ``generate_data_tokens``).

        Raises:
            ValueError: if ``data`` and ``labels`` do not have the same length.
        """
        # Fail fast: __len__ only looks at `data`, so a length mismatch would otherwise
        # surface as an IndexError mid-iteration or silently drop trailing labels.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(token_ids, labels)`` pair at index ``idx``."""
        return self.data[idx], self.labels[idx]

In [6]:
# generate data
initial_prompt = 'test'
n_data = 10             # number of sequences to generate
seq_len = 10            # tokens per sequence
n_numeral_per_seq = 5   # how many tokens in each sequence are numerals
vocab_size = tokenizer.vocab_size
zero_token_id = tokenizer(str(0)).input_ids[0]
one_token_id = tokenizer(str(1)).input_ids[0]
print(zero_token_id)

# Token ids for "0".."99" and " 0".." 99". Flattening the (n_strings, n_tokens) tensor
# only yields one id per numeral if every string encodes to exactly one token.
numeral_strings = [str(i) for i in range(100)] + [' ' + str(i) for i in range(100)]
numeral_tokens = tokenizer(numeral_strings, return_tensors="pt")['input_ids'].flatten()
assert numeral_tokens.numel() == len(numeral_strings), "each numeral string must encode to exactly one token"

print(numeral_tokens.shape)
def generate_data_tokens(n_data = n_data, seq_len=seq_len, n_numeral_per_seq = n_numeral_per_seq) -> Tuple[Tensor, Tensor]:
    """Sample random token sequences mixing numeral and non-numeral tokens.

    Each row holds exactly ``n_numeral_per_seq`` tokens drawn (with replacement) from
    the module-level ``numeral_tokens`` and ``seq_len - n_numeral_per_seq`` tokens drawn
    (with replacement) from the remaining ids in ``range(vocab_size)``, placed in an
    independent random order per row.

    Args:
        n_data: Number of sequences (rows) to generate.
        seq_len: Number of tokens per sequence.
        n_numeral_per_seq: Number of numeral tokens per sequence; must be in [0, seq_len].

    Returns:
        ``(token_ids, numeral_mask)``, both int64 tensors of shape (n_data, seq_len).
        ``numeral_mask`` is 1 where the token was drawn from ``numeral_tokens``, else 0.

    Raises:
        ValueError: if ``n_numeral_per_seq`` is outside [0, seq_len].
    """
    if not 0 <= n_numeral_per_seq <= seq_len:
        raise ValueError(
            f"n_numeral_per_seq must be in [0, seq_len={seq_len}], got {n_numeral_per_seq}"
        )

    # Every vocabulary id that is not a numeral token is a candidate filler token.
    # NOTE(review): this may include special tokens (e.g. EOS, which is also the pad
    # token here) — confirm that sampling them is intended.
    mask = t.ones(vocab_size, dtype=t.bool)
    mask[numeral_tokens] = False
    available_tokens = t.nonzero(mask).view(-1)

    # Numerals occupy the first n_numeral_per_seq columns, filler tokens the rest.
    numeral_rows = numeral_tokens[t.randint(0, numeral_tokens.size(0), (n_data, n_numeral_per_seq))]
    filler_rows = available_tokens[t.randint(0, available_tokens.size(0), (n_data, seq_len - n_numeral_per_seq))]
    combined_tensor = t.cat((numeral_rows, filler_rows), dim=1)

    # Labels in the same pre-shuffle layout: 1 for the numeral columns.
    numeral_mask = t.zeros_like(combined_tensor)
    numeral_mask[:, :n_numeral_per_seq] = 1

    # One independent random permutation per row (argsort of uniform noise), applied
    # identically to tokens and labels so they stay aligned.
    shuffle_indices = t.rand(n_data, seq_len).argsort(dim=1)
    return combined_tensor.gather(1, shuffle_indices), numeral_mask.gather(1, shuffle_indices)


# Sample the synthetic sequences together with their per-token numeral labels.
data_token_ids, labels = generate_data_tokens()

# Wrap them so they can be iterated in shuffled mini-batches.
batch_size = 5
token_dataset = TokenDataset(data=data_token_ids, labels=labels)
dataloader = DataLoader(dataset=token_dataset, batch_size=batch_size, shuffle=True)
print(data_token_ids)

15
torch.Size([200])
tensor([[44219,  2682, 12073,  1105, 45480,  7253,    19,  3865,  2548, 47828],
        [ 2682,    18,  7895, 17210,  2624, 39098,   860,  1501,  2026, 11161],
        [12787, 15099,  3553,  8250,  2242,  6135, 25000,  2920,  5607,  9088],
        [ 1485,  4576,  5014,  4349,  4304, 21158,   513, 13442, 33725, 31967],
        [40546, 46457,  3459, 24694,  6420,  4764, 11005,  1821, 10952,  1821],
        [23382,  3261, 31006,  9225,  6469,   807,  2683, 44687,   807, 19806],
        [ 1899, 18955, 21516,  1596, 14397,  5705, 28132, 25879,  1731,  2548],
        [11699,  2996,  5824,  2920, 36405, 22982,  1795, 24894, 49915,  1415],
        [ 3439, 32199, 36500,  1160, 38721,  1919,  2091,    22,  7734,  5996],
        [18289,  3134,    21,  6420, 38463, 48167,  2079, 17199, 45996,  1065]])


In [None]:
def create_full_base_prompts(initial_prompt, data):
    """Build the base-prompt token ids for a batch of ``data`` (not implemented yet).

    Presumably meant to combine ``initial_prompt`` with the batch of token ids in
    ``data``; the exact layout is still undecided. TODO: implement.

    Raises:
        NotImplementedError: always, until prompt construction is implemented. This
            replaces a silent ``pass`` that returned None and only failed later inside
            the model call.
    """
    raise NotImplementedError(
        "create_full_base_prompts is a stub; implement prompt construction before running the evaluation loop"
    )

In [None]:
# TODO: possibly create the initial model cache here

In [None]:
for batch_idx, (data, labels) in enumerate(dataloader):
    base_prompt_tokens = create_full_base_prompts(initial_prompt, data)

    # Inference only: skip building an autograd graph. Inputs must live on the
    # model's device (the DataLoader yields CPU tensors).
    with t.no_grad():
        cache = model(base_prompt_tokens.to(device), use_cache=True).past_key_values

    # TODO: randomize the order of data/labels
    # NOTE(review): zip(data, labels) iterates over ROWS of the batch, so `last_token`
    # is a whole (seq_len,) sequence, not a single token; its implicit batch size of 1
    # must match the batch size of `cache` — confirm this is the intended iteration.
    # NOTE(review): in newer transformers versions past_key_values may be a mutable
    # Cache object that each forward call extends in place; if so, reusing `cache`
    # across iterations accumulates tokens (consider copy.deepcopy per call).
    for last_token, label in zip(data, labels):
        with t.no_grad():
            outputs = model(last_token.to(device), past_key_values=cache)
        # logits are (batch, seq, vocab); keep the prediction at the final position.
        logits = outputs.logits[:, -1, :].squeeze()