In [6]:
import gzip
import json

def load_json_gz(filename, max_items=10000, min_chars=2000):
    """Collect long ``text`` fields from a gzip-compressed JSON-lines file.

    Args:
        filename: Path to a gzip file holding one JSON object per line; every
            object must have a ``'text'`` key.
        max_items: Stop once this many texts have been collected. ``None``
            disables the limit.
        min_chars: Only texts with strictly more than this many characters
            are kept.

    Returns:
        A list with the kept texts in file order. The list is shorter than
        ``max_items`` (possibly empty) when the file runs out of qualifying
        lines first.

    Raises:
        KeyError: If a JSON object has no ``'text'`` key.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    texts = []
    if max_items is not None and max_items <= 0:
        return texts
    with gzip.open(filename, 'r') as f:
        for json_line in f:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            if not json_line.strip():
                continue
            text = json.loads(json_line)['text']
            if len(text) > min_chars:
                texts.append(text)
                # Stop as soon as the quota is met instead of reading one more line.
                if max_items is not None and len(texts) >= max_items:
                    break
    # Always return the list, even when the file held fewer qualifying lines
    # than requested (previously this path fell through and returned None).
    return texts

# Load 10000 strings from C4 dataset: https://huggingface.co/datasets/allenai/c4/tree/main/en
# Only documents longer than 2000 characters are kept (see load_json_gz). The path is
# relative to the notebook's working directory.
strings = load_json_gz('graphing/c4-validation.00000-of-00008.json.gz')
import torch
from tqdm import tqdm

# Device that tokenized batches (and the Mamba-130m model) are moved to in the cells below.
device = 'cuda:0'

In [15]:
def copy_task(batch_size=64, batches=10, model=None, tokenizer=None, token_max_len=25, shuffle=False):
    """Measure how often ``model`` reproduces a token sequence that is shown to it twice.

    For every batch, ``batch_size`` consecutive entries of the notebook-level
    ``strings`` list are tokenized and truncated to ``token_max_len`` tokens.
    The prompt is ``tokens + tokens + tokens[:1]``; the model then generates up
    to ``token_max_len - 1`` new tokens. A row counts as a success only if the
    output tail (the repeated first token plus the generated tokens) equals the
    original tokens exactly.

    Args:
        batch_size: Number of strings per batch.
        batches: Number of batches to evaluate.
        model: Model exposing ``generate(input_ids, max_new_tokens=...)``.
        tokenizer: Callable tokenizer returning ``'input_ids'`` as a tensor.
            No padding is requested, so every string in a batch must tokenize
            to the same (truncated) length.
        token_max_len: Truncation length of the sequence to be copied.
        shuffle: If true, apply one random column permutation to the whole
            batch before building the prompt.

    Returns:
        Fraction of the ``batch_size * batches`` rows that were copied exactly.

    Raises:
        IndexError: If ``strings`` holds fewer than ``batch_size * batches`` entries.
    """
    total = batch_size * batches
    # Fail before any GPU work instead of part-way through the run.
    if total > len(strings):
        raise IndexError(
            f"copy_task needs {total} strings but only {len(strings)} are loaded"
        )

    success_copies = 0
    for batch_idx in tqdm(range(batches)):
        start = batch_idx * batch_size
        cur_batch = strings[start:start + batch_size]
        encoded = tokenizer(cur_batch, return_tensors="pt", truncation=True, max_length=token_max_len).to(device)
        gold_ids = encoded['input_ids']
        if shuffle:
            col_perm = torch.randperm(gold_ids.size(1))
            gold_ids = gold_ids[:, col_perm]
        gold_token_len = gold_ids.shape[1]

        # Prompt layout: [sequence][sequence][first token of the sequence].
        prompt_ids = torch.cat([gold_ids, gold_ids, gold_ids[:, 0:1]], dim=1)
        output_ids = model.generate(prompt_ids, max_new_tokens=token_max_len - 1)

        # torch.equal also returns False on a length mismatch (e.g. generation
        # stopped early), so such rows simply count as failures.
        for gold_row, output_row in zip(gold_ids, output_ids):
            if torch.equal(gold_row, output_row[gold_token_len * 2:]):
                success_copies += 1
    return success_copies / total

In [7]:
from transformers import GPTNeoXForCausalLM
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
# pythia_160m_model = GPTNeoXForCausalLM.from_pretrained(
#   "EleutherAI/pythia-160m"
# )
# pythia_160m_model.to(device)
# pythia_160m_tokenizer = AutoTokenizer.from_pretrained(
#   "EleutherAI/pythia-160m"
# )
# pythia_160m_tokenizer.pad_token_id = pythia_160m_tokenizer.eos_token_id


In [2]:
from modeling_mamba_transformer import MambaTransformerForLM, MambaTransformerConfig

# Build MambaTransformerForLM from the default config plus the saved checkpoint path,
# and place it on the same notebook-wide `device` that the tokenized inputs are moved to
# (previously a separately hard-coded 'cuda').
checkpoint_point_path = 'seq_len_2048_6_transformer_layers/checkpoint-14000/model.safetensors'
model = MambaTransformerForLM(MambaTransformerConfig(), checkpoint_point_path).to(device)
# Switch to eval mode; as the last expression this also displays the module structure.
model.eval()

MambaTransformerForLM(
  (model): MambaTransformer(
    (embed_in): Embedding(50304, 768)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (first_transformer_layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=768, out_features=2304, bias=True)
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=768, out_features=3072, bias=True)
          (dense_4h_to_h): Linear(in_features=3072, out_features=768, b

In [19]:
from transformers import AutoTokenizer
import torch.nn as nn
# Tokenizer for the manual decoding cell below: Pythia-160m vocabulary, left padding side.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-160m', padding_side='left')
# Reuse the EOS id as the padding id.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Run configuration read by the manual decoding cell below.
batches = 10          # number of batches to evaluate
batch_size = 64       # strings per batch
shuffle = False       # if True, permute token columns before building the prompt
token_max_len = 150   # truncation length of the sequence to copy

# Softmax over dim 2 -- presumably the vocabulary axis of (bs, seq_len, vocab_size) logits.
softmax = nn.Softmax(dim=2)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [20]:
# Copy-task evaluation of `model` with a hand-written greedy decoding loop.
# Every step re-runs the model on the whole sequence so far, so runtime grows
# quickly with `token_max_len`.
string_idx = 0
success_copies = 0
with torch.no_grad():
    for _ in tqdm(range(batches)):
        cur_batch = [strings[string_idx + count] for count in range(batch_size)]
        outputs = tokenizer(cur_batch, return_tensors="pt", truncation=True, max_length=token_max_len).to(device)
        input_ids = outputs['input_ids']
        if shuffle:
            col_perm = torch.randperm(input_ids.size(1))
            input_ids = input_ids[:, col_perm]
        # Prompt layout: [sequence][sequence][first token of the sequence].
        input_ids = torch.cat([input_ids, input_ids], dim=1)
        input_ids = torch.cat([input_ids, input_ids[:, 0:1]], dim=1)
        # Length of the sequence to be copied (same for every row of the batch).
        gold_token_len = (input_ids.shape[1] - 1) // 2

        output_ids = input_ids
        for i in range(token_max_len - 1):
            # All-ones float mask created directly on the device of the current sequence.
            mask = torch.ones(output_ids.shape, device=output_ids.device)
            logits = model(input_ids=output_ids, attention_mask=mask, labels=None)['logits']  # bs, seq_len, vocab_size
            # Greedy pick from the last position only; softmax is monotonic, so
            # taking argmax on the raw logits selects the same token without the
            # extra full-tensor pass.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # bs, 1
            output_ids = torch.cat((output_ids, next_token), dim=-1)  # bs, seq_len + 1

        for count in range(batch_size):
            if torch.equal(input_ids[count][:gold_token_len], output_ids[count][gold_token_len*2:]):
                success_copies += 1
        string_idx += batch_size
print(success_copies / (batch_size * batches))


100%|██████████| 10/10 [13:28<00:00, 80.86s/it]

0.3734375





In [None]:
# NOTE(review): in the visible notebook, `pythia_160m_model` and `pythia_160m_tokenizer`
# are only created by the commented-out Pythia loading code in an earlier cell. Re-enable
# that code before running this cell, otherwise it fails with a NameError.
# Copy-task accuracy for Pythia-160m; switch the active line to sweep the prompt length.
# print(copy_task(model=pythia_160m_model, tokenizer=pythia_160m_tokenizer, token_max_len=25, shuffle=False))
print(copy_task(model=pythia_160m_model, tokenizer=pythia_160m_tokenizer, token_max_len=50, shuffle=False))
# print(copy_task(model=pythia_160m_model, tokenizer=pythia_160m_tokenizer, token_max_len=100, shuffle=False))
# print(copy_task(model=pythia_160m_model, tokenizer=pythia_160m_tokenizer, token_max_len=150, shuffle=False))

In [13]:
# Load the pretrained Mamba-130m tokenizer and model from the same checkpoint name.
mamba_130m_checkpoint = "state-spaces/mamba-130m-hf"

mamba_130m_tokenizer = AutoTokenizer.from_pretrained(mamba_130m_checkpoint)
# Reuse the EOS id as the padding id.
mamba_130m_tokenizer.pad_token_id = mamba_130m_tokenizer.eos_token_id

# eval() returns the module itself, so it can be chained onto the load.
mamba_130m_model = MambaForCausalLM.from_pretrained(mamba_130m_checkpoint).eval()
# Move to the notebook-wide device; as the last expression this displays the module structure.
mamba_130m_model.to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)

In [18]:
# Copy-task accuracy for Mamba-130m with 25-token sequences. The longer settings are kept
# commented out; switch the active line to sweep the prompt length.
print(copy_task(model=mamba_130m_model, tokenizer=mamba_130m_tokenizer, token_max_len=25, shuffle=False))
# print(copy_task(model=mamba_130m_model, tokenizer=mamba_130m_tokenizer, token_max_len=50))
# print(copy_task(model=mamba_130m_model, tokenizer=mamba_130m_tokenizer, token_max_len=100))
# print(copy_task(model=mamba_130m_model, tokenizer=mamba_130m_tokenizer, token_max_len=150))

100%|██████████| 10/10 [00:07<00:00,  1.27it/s]

0.840625



