In [2]:
import gzip
import json

def load_json_gz(filename, num_strings=10000, min_chars=2000):
    """Collect long ``text`` fields from a gzipped JSON-lines file.

    Every non-blank line of ``filename`` is parsed with ``json.loads`` and its ``'text'``
    value is kept when it is strictly longer than ``min_chars`` characters. Reading stops
    once ``num_strings`` texts have been collected.

    Args:
        filename: path to a gzip-compressed file with one JSON object per line, each
            carrying a ``'text'`` key (e.g. a C4 shard).
        num_strings: maximum number of texts to return.
        min_chars: a text is kept only if ``len(text) > min_chars``.

    Returns:
        A list of at most ``num_strings`` texts in file order. It is shorter (possibly
        empty) when the file runs out first -- it is never ``None``.

    Raises:
        KeyError: if a parsed line has no ``'text'`` key.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    texts = []
    with gzip.open(filename, 'rb') as f:
        for json_line in f:
            if len(texts) >= num_strings:
                break
            if not json_line.strip():
                # Tolerate blank lines (e.g. a trailing newline) instead of crashing json.loads.
                continue
            text = json.loads(json_line)['text']
            if len(text) > min_chars:
                texts.append(text)
    # Always return the list, even when fewer than num_strings texts qualified.
    return texts

Now load the data (the C4 shard linked in the next cell must be downloaded first):

In [3]:
# Corpus for the copy task: the first 10000 documents longer than 2000 characters
# from one C4 shard (source: https://huggingface.co/datasets/allenai/c4/tree/main/en).
C4_SHARD = 'c4-train.00000-of-01024.json.gz'
strings = load_json_gz(filename=C4_SHARD)

In [4]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

# Run on the first GPU when one is available; otherwise fall back to the CPU
# (much slower, but the notebook still executes instead of failing on .to(device)).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def copy_task(batch_size=64, batches=10, model=None, tokenizer=None, token_max_len=25, shuffle=False,
              texts=None, target_device=None):
    """Estimate how often ``model`` reproduces a token sequence it has just seen twice.

    Strings are taken consecutively from ``texts``, ``batch_size`` per batch for ``batches``
    batches. Each batch is tokenized with truncation to at most ``token_max_len`` tokens,
    giving ids ``x`` of width ``L``. If ``shuffle`` is true, one random column permutation is
    applied to the whole batch. The model is prompted with ``x x x[0]`` and asked for ``L - 1``
    new tokens. A row is a success when the last ``L`` output tokens equal ``x`` exactly.

    Args:
        batch_size: strings per ``generate`` call.
        batches: number of batches.
        model: object with a Hugging Face style ``generate(input_ids, max_new_tokens=...)``.
        tokenizer: callable returning an object with ``.to(device)`` and ``["input_ids"]``.
        token_max_len: truncation length for each string, in tokens.
        shuffle: permute token columns before building the prompt.
        texts: corpus to draw strings from; defaults to the notebook-level ``strings``.
        target_device: device for the input ids; defaults to the notebook-level ``device``.

    Returns:
        The fraction of successful rows, in ``[0, 1]``.

    Raises:
        ValueError: if ``batch_size`` or ``batches`` is not positive, or ``texts`` holds fewer
            than ``batch_size * batches`` strings.
    """
    if texts is None:
        texts = strings
    if target_device is None:
        target_device = device
    if batch_size <= 0 or batches <= 0:
        raise ValueError("batch_size and batches must be positive")
    total = batch_size * batches
    if len(texts) < total:
        raise ValueError(f"copy_task needs {total} strings but only {len(texts)} are available")

    success_copies = 0
    for start in range(0, total, batch_size):
        cur_batch = list(texts[start:start + batch_size])
        # NOTE(review): no padding is requested, so every row must tokenize to the same width.
        # Presumably true for long corpus texts truncated to token_max_len; confirm for other corpora.
        input_ids = tokenizer(cur_batch, return_tensors="pt", truncation=True,
                              max_length=token_max_len).to(target_device)["input_ids"]
        if shuffle:
            col_perm = torch.randperm(input_ids.size(1))
            input_ids = input_ids[:, col_perm]
        # Use the real tokenized width, which can be < token_max_len for short strings; asking for
        # token_max_len - 1 new tokens would then make every row's output too long to match.
        gold_token_len = input_ids.shape[1]
        prompt_ids = torch.cat([input_ids, input_ids, input_ids[:, :1]], dim=1)
        output_ids = model.generate(prompt_ids, max_new_tokens=gold_token_len - 1)
        for count in range(batch_size):
            # The copy starts right after the two full repetitions in the prompt.
            if torch.equal(input_ids[count], output_ids[count][2 * gold_token_len:]):
                success_copies += 1
    return success_copies / total

In [9]:
# Mamba-370m: copy-task accuracy and wall-clock time for increasing sequence lengths.
import time

mamba_370m_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
mamba_370m_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
mamba_370m_model.to(device)

result_370m_time = []
result_370m = []
for i in [25, 50, 100, 150, 200, 250]:
    # perf_counter is monotonic and intended for measuring intervals (time.time can jump).
    start_time = time.perf_counter()
    ans = copy_task(batch_size = 8, model = mamba_370m_model, tokenizer = mamba_370m_tokenizer, token_max_len=i)
    end_time = time.perf_counter()
    result_370m_time.append(end_time - start_time)
    result_370m.append(ans)
    print(ans)
    print('the time spent is', end_time - start_time)

print("the result 370m time is ", result_370m_time)
print("the result 370m is ", result_370m)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


0.9
the time spent is 10.187240600585938
0.6625
the time spent is 19.774027585983276
0.5
the time spent is 40.74758529663086
0.4
the time spent is 62.89952063560486
0.2
the time spent is 80.89390206336975
0.1125
the time spent is 101.32650256156921
the result 370m time is  [10.187240600585938, 19.774027585983276, 40.74758529663086, 62.89952063560486, 80.89390206336975, 101.32650256156921]
the result 370m is  [0.9, 0.6625, 0.5, 0.4, 0.2, 0.1125]


In [11]:
# Mamba-370m copy-task sweep with shuffle=True: token columns are randomly permuted
# before the copy prompt is built.
torch.manual_seed(0)  # make the torch.randperm column shuffles reproducible across runs
result_370m_shuffle_time = []
result_370m_shuffle = []
for i in [25, 50, 100, 150, 200, 250]:
    start_time = time.perf_counter()
    ans = copy_task(batch_size = 8, model = mamba_370m_model, tokenizer = mamba_370m_tokenizer, token_max_len=i, shuffle=True)
    end_time = time.perf_counter()
    result_370m_shuffle_time.append(end_time - start_time)
    result_370m_shuffle.append(ans)
    print(ans)
    print('the time spent is', end_time - start_time)


0.9375
the time spent is 9.845363855361938
0.55
the time spent is 19.699750900268555
0.075
the time spent is 40.13555192947388
0.0
the time spent is 60.27432703971863
0.0
the time spent is 80.2201087474823
0.0
the time spent is 100.15279531478882


In [12]:
# Summary of the Mamba-370m shuffled copy-task sweep.
print("the shuffle 370m results", result_370m_shuffle)
print("the shuffle 370m time", result_370m_shuffle_time)

the shuffle 370m results [0.9375, 0.55, 0.075, 0.0, 0.0, 0.0]
the shuffle 370 time [9.845363855361938, 19.699750900268555, 40.13555192947388, 60.27432703971863, 80.2201087474823, 100.15279531478882]


In [12]:
# Mamba-1.4b: copy-task accuracy and wall-clock time for increasing sequence lengths.
# Uses batch_size=4 (vs 8 for 370m), presumably to fit the larger model in GPU memory.
mamba_1_4b_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
mamba_1_4b_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf")
mamba_1_4b_model.to(device)

import time 
result_1_4b_time = []
result_1_4b = []
for i in [25, 50, 100, 150, 200, 250]:
    # perf_counter is monotonic and intended for measuring intervals.
    start_time = time.perf_counter()
    ans = copy_task(batch_size = 4, model = mamba_1_4b_model, tokenizer = mamba_1_4b_tokenizer, token_max_len=i)
    end_time = time.perf_counter()
    result_1_4b_time.append(end_time - start_time)
    result_1_4b.append(ans)
    print(ans)
    print('the time spent is', end_time - start_time)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

0.95
the time spent is 4.088669776916504
0.85
the time spent is 8.237766981124878
0.825
the time spent is 16.61814284324646
0.65
the time spent is 24.92847204208374
0.45
the time spent is 33.319599628448486
0.3
the time spent is 40.86095356941223
the result 1.4b time is  [4.088669776916504, 8.237766981124878, 16.61814284324646, 24.92847204208374, 33.319599628448486, 40.86095356941223]
the result 1.4b  [0.95, 0.85, 0.825, 0.65, 0.45, 0.3]


In [13]:
# Show the Mamba-1.4b sweep results: timings first, then accuracies.
for label, values in (("the result 1.4b time is ", result_1_4b_time),
                      ("the result 1.4b ", result_1_4b)):
    print(label, values)

the result 1.4b time is  [4.088669776916504, 8.237766981124878, 16.61814284324646, 24.92847204208374, 33.319599628448486, 40.86095356941223]
the result 1.4b  [0.95, 0.85, 0.825, 0.65, 0.45, 0.3]


In [14]:
# Mamba-2.8b: copy-task accuracy and wall-clock time for increasing sequence lengths.
mamba_2_8b_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf", padding_side='left', cache_dir="./mamba-2.8b-hf")
mamba_2_8b_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", cache_dir="./mamba-2.8b-hf")
mamba_2_8b_model.to(device)

import time 
result_2_8b_time = []
result_2_8b = []
for i in [25, 50, 100, 150, 200, 250]:
    start_time = time.perf_counter()
    # Pair the 2.8b model with the tokenizer loaded for it in this cell.
    ans = copy_task(batch_size = 4, model = mamba_2_8b_model, tokenizer = mamba_2_8b_tokenizer, token_max_len=i)
    end_time = time.perf_counter()
    result_2_8b_time.append(end_time - start_time)
    result_2_8b.append(ans)
    print(ans)
    print('the time spent is', end_time - start_time)
print("the result 2.8b time is ", result_2_8b_time)
print("the result 2.8b ", result_2_8b)

tokenizer_config.json:   0%|          | 0.00/4.79k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/50.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

0.95
the time spent is 5.494949817657471
0.975
the time spent is 11.137060403823853
0.9
the time spent is 22.60603427886963
0.825
the time spent is 33.62363815307617
0.7
the time spent is 45.10409688949585
0.5
the time spent is 56.72283935546875
the result 2.8b time is  [5.494949817657471, 11.137060403823853, 22.60603427886963, 33.62363815307617, 45.10409688949585, 56.72283935546875]
the result 2.8b  [0.95, 0.975, 0.9, 0.825, 0.7, 0.5]


: 

In [14]:
import ast

# Parse ./phonebook.txt: each non-blank line is a Python literal (optionally followed by a
# trailing comma) whose first two items are kept as (name, phone).
name_phone_pairs = []
with open('./phonebook.txt', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a final newline) would make line[-1] raise IndexError.
            continue
        if line.endswith(','):
            line = line[:-1]
        # literal_eval only accepts Python literals, so it is safe on file contents.
        pair = ast.literal_eval(line)
        name_phone_pairs.append((pair[0], pair[1]))

In [20]:
# We found the phonebook experiment hard to reproduce because the authors did not give the exact prompt in the paper.
# In addition, the accuracy fluctuated a lot depending on the prompt we used.

import random
def phone_book_task(batch_size=64, batches=10, book_size=20, model=None, tokenizer=None,
                    pairs=None, seed=None, target_device=None):
    """Measure how often ``model`` retrieves a phone number from an in-context phonebook.

    The prompt contains three parts:
    - a phonebook with one ``"name: phone."`` line for each of the first ``book_size``
      entries of ``pairs``;
    - an instruction with two fixed examples (Liam, Olivia);
    - ``"Person: <name>\\nNumber:"`` for the queried person.

    Queried people are drawn uniformly from entries ``2 .. book_size - 1``, so they are always
    in the phonebook. Presumably entries 0 and 1 are skipped because they are the example
    people -- TODO confirm against phonebook.txt. A query succeeds when the decoded output
    contains the gold number more than once: once from the phonebook in the prompt, once
    generated.

    Args:
        batch_size: queries per ``generate`` call.
        batches: number of batches.
        book_size: number of phonebook entries in the prompt (must be at least 3).
        model: object with a Hugging Face style ``generate(input_ids, max_new_tokens=...)``.
        tokenizer: callable tokenizer with ``decode``.
        pairs: sequence of (name, phone) strings; defaults to the notebook-level ``name_phone_pairs``.
        seed: if given, queries are drawn from ``random.Random(seed)`` for reproducibility;
            otherwise the global ``random`` module is used.
        target_device: device for the prompt ids; defaults to the notebook-level ``device``.

    Returns:
        The fraction of queries answered correctly, in ``[0, 1]``.

    Raises:
        ValueError: if ``book_size < 3`` or ``pairs`` has fewer than ``book_size`` entries.
    """
    if pairs is None:
        pairs = name_phone_pairs
    if target_device is None:
        target_device = device
    if book_size < 3:
        raise ValueError("book_size must be at least 3 so that there is an entry to query")
    if len(pairs) < book_size:
        raise ValueError(f"need at least {book_size} phonebook entries, got {len(pairs)}")
    rng = random if seed is None else random.Random(seed)

    book = ''.join(name + ': ' + phone + '.\n' for name, phone in pairs[:book_size])
    book += 'Extract the person\'s phone number in the phonebook above. For example:\nPerson: Liam\nNumber: 436-725-2906\nPerson: Olivia\nNumber: 192-311-5790\n\n'

    success_lookups = 0
    for _ in range(batches):
        queries = []
        gold_numbers = []
        max_num_tokens = -1
        for _ in range(batch_size):
            # randrange excludes book_size, so the queried person is always in the phonebook
            # (randint(2, book_size) could pick an entry that is not in the prompt).
            name, phone = pairs[rng.randrange(2, book_size)]
            queries.append(book + 'Person: ' + name + '\nNumber:')
            gold_ids = tokenizer(phone, return_tensors="pt")["input_ids"][0]
            max_num_tokens = max(max_num_tokens, gold_ids.shape[0])
            gold_numbers.append(tokenizer.decode(gold_ids))
        # NOTE(review): padding=True pads shorter queries; decoder-only generation usually
        # expects left padding -- confirm tokenizer.padding_side for the model under test.
        input_ids = tokenizer(queries, return_tensors="pt", padding=True).to(target_device)["input_ids"]
        output_ids = model.generate(input_ids, max_new_tokens=max_num_tokens)
        for count in range(batch_size):
            output_answer = tokenizer.decode(output_ids[count])
            # The number occurs once in the phonebook part of the prompt; a second
            # occurrence means the model generated it.
            if output_answer.count(gold_numbers[count]) > 1:
                success_lookups += 1
    return success_lookups / (batch_size * batches)

: 

In [19]:
# Phonebook-lookup accuracy of Mamba-370m for books of 20, 40 and 80 entries.
# The 80-entry run uses 32 x 20 queries, the same 640 total as the default 64 x 10.
# NOTE(review): the saved output of this cell is "TypeError: 'module' object is not callable",
# and its execution count (19) precedes the phone_book_task cell (20), so it presumably ran
# against stale kernel state. Importing time here keeps the cell self-contained.
import time

t1 = time.perf_counter()
print(phone_book_task(model=mamba_370m_model, tokenizer=mamba_370m_tokenizer))
t2 = time.perf_counter()
print(phone_book_task(model=mamba_370m_model, tokenizer=mamba_370m_tokenizer, book_size=40))
t3 = time.perf_counter()
print(phone_book_task(batch_size=32, batches=20, model=mamba_370m_model, tokenizer=mamba_370m_tokenizer, book_size=80))
t4 = time.perf_counter()
# Report the timings that were previously measured but never shown.
print('the time spent is', [t2 - t1, t3 - t2, t4 - t3])


TypeError: 'module' object is not callable