In [1]:
import ast

name_phone_pairs = []
with open('./phonebook.txt', 'r') as f:
    for line in f:
        line = line.strip()
        if line[-1] == ',':
            line = line[:-1]
        pair = ast.literal_eval(line)
        name_phone_pairs.append((pair[0], pair[1]))

In [2]:
import random
import torch
import torch.nn as nn
from tqdm import tqdm
device = 'cuda'
softmax = nn.Softmax(dim=2)

def phone_book_task(batch_size=64, batches=10, book_size=20, model=None, tokenizer=None):
    book = ''
    success_lookups = 0
    for i in range(book_size):
        name = name_phone_pairs[i][0]
        phone = name_phone_pairs[i][1]
        book = book + name + ': ' + phone + '.\n'
    book += 'Liam: 436-725-2906\nOlivia: 192-311-5790\n\n'
    with torch.no_grad():
        for _ in tqdm(range(batches)):
            cur_batch = []
            gt_numbers = []
            max_num_tokens = 30
            for _ in range(batch_size):
                query_pair_idx = random.randint(2, book_size)
                query = book + name_phone_pairs[query_pair_idx][0] + ':'
                
                gt_numbers.append(name_phone_pairs[query_pair_idx][1])
                cur_batch.append(query)
                
            input_ids = tokenizer(cur_batch, return_tensors="pt", padding=True).to(device)["input_ids"]
            for i in range(max_num_tokens-1):
                bs, seq_len = input_ids.size()
                mask = torch.ones(bs, seq_len).to('cuda')
                logits = model(input_ids=input_ids, attention_mask=mask, labels=None)['logits'] # bs, seq_len, vocab_size
                next_token = torch.unsqueeze(torch.argmax(softmax(logits), dim=-1)[:, -1], 1)
                input_ids = torch.cat((input_ids, next_token), dim=-1) # bs, seq_len, 1
            for count in range(batch_size):
                true_number = gt_numbers[count]
                output_answer = tokenizer.decode(input_ids[count])
                if output_answer.count(true_number) > 1:
                    success_lookups += 1
    return success_lookups / (batch_size * batches)

In [3]:
from modeling_mamba_transformer import MambaTransformerForLM, MambaTransformerConfig
from transformers import AutoTokenizer

checkpoint_point_path = 'sft_1_epoch_100_length_2000_samples/checkpoint-540/model.safetensors'
hybrid_model = MambaTransformerForLM(MambaTransformerConfig(), checkpoint_point_path).to(device)
hybrid_model.eval()
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-160m', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
phone_book_task(batch_size=32, model=hybrid_model, tokenizer=tokenizer)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


0.528125


In [6]:
from transformers import MambaForCausalLM, AutoTokenizer

mamba_130m_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf", padding_side='left')
mamba_130m_tokenizer.pad_token_id = mamba_130m_tokenizer.eos_token_id
mamba_130m_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to(device)
mamba_130m_model.eval()
phone_book_task(batch_size=32, model=mamba_130m_model, tokenizer=mamba_130m_tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


100%|██████████| 10/10 [05:04<00:00, 30.49s/it]


0.05

In [5]:
from transformers import GPTNeoXForCausalLM
pythia_160m_model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m").to(device)
pythia_160m_model.eval()
pythia_160m_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", padding_side='left')
pythia_160m_tokenizer.pad_token_id = pythia_160m_tokenizer.eos_token_id
phone_book_task(batch_size=32, model=pythia_160m_model, tokenizer=pythia_160m_tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 10/10 [00:57<00:00,  5.74s/it]


0.571875