In [1]:
import torch as t
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM 
import einops
import matplotlib.pyplot as plt
import numpy as np
from typing import Union, Optional, Tuple, Any
from torch import Tensor
from dataclasses import dataclass, field
from tqdm.notebook import tqdm
from jaxtyping import Int, Float
from typing import List, Dict
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
import datetime
llama_token = "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"
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

In [30]:
t.cuda.empty_cache()
n_param = 7
tokenizer = AutoTokenizer.from_pretrained(
            f"meta-llama/Llama-2-{n_param}b-chat-hf", ignore_mismatched_sizes=True, token=llama_token, use_fast=True, add_bos_token=False, add_prefix_space=False, add_special_tokens=False
        )
tokenizer.pad_token = tokenizer.eos_token

In [77]:
   
def generate_data_tokens(n_data: int, seq_len: int, n_numeral_per_seq: int) -> Tuple[Tensor]:
    vocab_size = tokenizer.vocab_size
    numeral_tokens = tokenizer.batch_encode_plus([str(i) for i in range(10)], return_tensors='pt')['input_ids'][:,1].flatten()  
    
    # Create a boolean mask for the entire vocabulary
    mask = t.ones(vocab_size, dtype=t.bool)
    mask[numeral_tokens] = 0

    # Find the indices where mask is True
    available_tokens = t.nonzero(mask).squeeze()

    # Randomly choose numeral tokens 
    random_numeral_indices = t.randint(0, numeral_tokens.shape[0], (n_data, seq_len))
    numerals_tensor = numeral_tokens[random_numeral_indices]

    # Randomly choose non-numeral tokens
    random_vocab_indices = t.randint(0, available_tokens.shape[0], (n_data, seq_len))
    random_text_tokens = available_tokens[random_vocab_indices]

    # Create a mask for the numeral tokens
    numeral_probs = t.rand(n_data, seq_len, dtype=t.float)
    numeral_mask = (numeral_probs < 0.5)

    data = random_text_tokens
    data[numeral_mask] = numerals_tensor[numeral_mask]
    return data, numeral_mask

data_token_ids, labels = generate_data_tokens(n_data = 10, seq_len = 10, n_numeral_per_seq = 2)
print(labels)
print(data_token_ids)

for i in range(data_token_ids.size(0)):
    sequence = data_token_ids[i].tolist()  # Convert the tensor to a list of integers
    decoded_sequence = tokenizer.decode(sequence)  # Decode the sequence
    print(decoded_sequence)  # Print the decoded text

tensor([[False, False, False,  True,  True, False,  True,  True,  True,  True],
        [ True, False, False, False,  True,  True, False,  True,  True, False],
        [ True,  True,  True, False,  True,  True,  True, False,  True, False],
        [False, False,  True,  True,  True, False, False, False, False,  True],
        [ True,  True, False,  True,  True, False, False,  True,  True, False],
        [False,  True, False,  True, False,  True,  True, False,  True, False],
        [ True,  True,  True,  True, False,  True, False, False, False, False],
        [False,  True, False, False, False,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False,  True, False],
        [False, False, False, False, False,  True, False,  True,  True,  True]])
tensor([[23704, 22557,  7613, 29955, 29947,  1300, 29953, 29896, 29929, 29946],
        [29953, 29683, 27069,  7372, 29929, 29947, 31914, 29955, 29953, 27720],
        [29945, 29929, 29945, 25026, 29

In [53]:
numeral_tokens = tokenizer.batch_encode_plus([str(0)], return_tensors='pt')['input_ids']
print(numeral_tokens)
for id in numeral_tokens:
    print(tokenizer.decode(id))

tensor([[29871, 29900]])
0


In [62]:
numeral_tokens = tokenizer.encode("".join(str(i) for i in range(10)))[1:]
print(numeral_tokens)

[29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929]
