In [1]:
import gzip
import json
import torch
from tqdm import tqdm

# Shared configuration for the cells below: data files are resolved
# relative to base_path, and tensors are moved to this device string.
base_path = './'
device = 'cuda:0'

def load_json_gz(filename, max_items=10000, min_chars=2000):
    """Collect long ``'text'`` fields from a gzip-compressed JSON-lines file.

    Each non-blank line is parsed as a JSON object; its ``'text'`` value is
    kept when it is strictly longer than ``min_chars`` characters. Reading
    stops once ``max_items`` texts have been collected.

    Args:
        filename: path to a ``.json.gz`` file with one JSON object per line.
        max_items: maximum number of texts to return (default 10000).
        min_chars: a text is kept only if ``len(text) > min_chars``
            (default 2000).

    Returns:
        A list of at most ``max_items`` strings in file order. If the file
        ends before ``max_items`` texts are found, the texts collected so far
        are returned (previously this fell off the end and returned ``None``).
    """
    texts = []
    with gzip.open(filename, 'rb') as f:
        for json_line in f:
            if len(texts) >= max_items:
                break
            # Tolerate blank lines (e.g. a trailing newline); json.loads would raise.
            if not json_line.strip():
                continue
            text = json.loads(json_line)['text']
            if len(text) > min_chars:
                texts.append(text)
    return texts

# Load 10000 documents longer than 2000 characters from the C4 validation
# shard: https://huggingface.co/datasets/allenai/c4/tree/main/en
strings = load_json_gz(filename=base_path + 'c4-validation.00000-of-00008.json.gz')

def copy_task(batch_size=64, batches=10, model=None, tokenizer=None, token_max_len=25, shuffle=False):
    """Score a model on verbatim copying of a token sequence it has seen twice.

    For each of ``batches`` batches, the next ``batch_size`` texts from the
    module-level ``strings`` list are tokenized and truncated to at most
    ``token_max_len`` tokens, giving a gold sequence ``x`` per row. When
    ``shuffle`` is true, the token columns are randomly permuted (one
    permutation shared by every row of the batch). The model is prompted with
    ``x x x[0]`` and asked, via ``model.generate``, for the ``len(x) - 1``
    tokens that complete the third copy. A row succeeds when positions
    ``2*len(x) .. 3*len(x)-1`` of the output equal ``x`` exactly.

    Relies on the module-level names ``strings`` and ``device``.

    Args:
        batch_size: rows per batch.
        batches: number of batches; texts are consumed consecutively.
        model: any object exposing ``generate(input_ids, max_new_tokens=...)``.
        tokenizer: callable returning a mapping with ``'input_ids'`` that
            supports ``.to(device)``.
        token_max_len: truncation length applied to each text.
        shuffle: permute token positions before building the prompt.

    Returns:
        Fraction of successful rows out of ``batch_size * batches``.

    Raises:
        IndexError: if ``strings`` holds fewer than ``batch_size * batches`` texts.
    """
    success_copies = 0
    with torch.no_grad():
        for batch_idx in tqdm(range(batches)):
            start = batch_idx * batch_size
            cur_batch = strings[start:start + batch_size]
            if len(cur_batch) < batch_size:
                raise IndexError(
                    f'copy_task needs {batch_size * batches} strings, '
                    f'but only {len(strings)} are loaded'
                )
            encoded = tokenizer(cur_batch, return_tensors="pt", truncation=True, max_length=token_max_len).to(device)
            gold_ids = encoded['input_ids']
            if shuffle:
                col_perm = torch.randperm(gold_ids.size(1))
                gold_ids = gold_ids[:, col_perm]
            # Actual gold length: may be shorter than token_max_len for short texts.
            gold_token_len = gold_ids.size(1)
            prompt_ids = torch.cat([gold_ids, gold_ids, gold_ids[:, 0:1]], dim=1)
            # Generate exactly len(x) - 1 tokens so the tail lines up with x.
            # Using token_max_len - 1 over-generates whenever len(x) < token_max_len,
            # which made every such row fail. Keep at least 1 so generate() is
            # never asked for zero new tokens.
            output_ids = model.generate(prompt_ids, max_new_tokens=max(gold_token_len - 1, 1))
            copied = output_ids[:, 2 * gold_token_len:3 * gold_token_len]
            for count in range(batch_size):
                if torch.equal(gold_ids[count], copied[count]):
                    success_copies += 1
    return success_copies / (batch_size * batches)

In [2]:
import ast

# Parse phonebook.txt into a list of (name, phone) tuples. Each non-blank line
# is evaluated as a Python literal (presumably a tuple such as
# ('Name', '123-456-7890')), optionally followed by a trailing comma; its first
# two items are kept.
name_phone_pairs = []
with open(base_path + 'phonebook.txt', 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing newline at EOF); the original
            # line[-1] check raised IndexError on an empty string.
            continue
        if line.endswith(','):
            line = line[:-1]
        # literal_eval only evaluates Python literals, never arbitrary code.
        pair = ast.literal_eval(line)
        name_phone_pairs.append((pair[0], pair[1]))

In [3]:
import random
import torch
import torch.nn as nn
from tqdm import tqdm
# Softmax over the last axis of (bs, seq_len, vocab_size) logits.
softmax = nn.Softmax(dim=2)
# Note: this rebinds the device string used by later cells (cell 1 set 'cuda:0').
device = 'cuda'

def phone_book_task(batch_size=64, batches=5, book_size=20, model=None, tokenizer=None, max_num_tokens=30, seed=None):
    """Score greedy phone-number lookup from an in-context phone book.

    The book lists the first ``book_size`` entries of the module-level
    ``name_phone_pairs`` as ``"Name: phone.\\n"`` lines, followed by two fixed
    example lines (Liam, Olivia) and a blank line. Each query appends
    ``"<name>:"`` for an entry drawn uniformly from indices
    ``2 .. book_size - 1`` (all of which are in the book). The model then
    greedily generates ``max_num_tokens - 1`` tokens; a lookup succeeds when the
    true number occurs more than once in the decoded text, i.e. in the book
    and again in the continuation.

    Relies on the module-level names ``name_phone_pairs`` and ``device``.

    Args:
        batch_size: queries per batch.
        batches: number of batches.
        book_size: number of entries listed in the book (must be >= 3).
        model: callable returning a mapping with ``'logits'`` of shape
            (bs, seq_len, vocab_size).
        tokenizer: callable tokenizer (padding supported) with ``decode``.
        max_num_tokens: greedy decoding runs ``max_num_tokens - 1`` steps.
        seed: optional seed for query sampling; ``None`` uses the global
            ``random`` module state.

    Returns:
        Fraction of successful lookups out of ``batch_size * batches``.

    Raises:
        ValueError: if ``book_size < 3`` (no queryable entries).
        IndexError: if ``book_size`` exceeds ``len(name_phone_pairs)``.
    """
    if book_size < 3:
        raise ValueError('book_size must be at least 3 (queries are drawn from indices 2..book_size-1)')
    if book_size > len(name_phone_pairs):
        raise IndexError(f'book_size={book_size} exceeds the {len(name_phone_pairs)} loaded phone-book entries')
    rng = random.Random(seed) if seed is not None else random

    book = ''.join(name + ': ' + phone + '.\n' for name, phone in name_phone_pairs[:book_size])
    book += 'Liam: 436-725-2906\nOlivia: 192-311-5790\n\n'

    success_lookups = 0
    with torch.no_grad():
        for _ in tqdm(range(batches)):
            # randrange excludes book_size; the original randint(2, book_size)
            # could pick index book_size, whose entry is not in the book.
            query_idxs = [rng.randrange(2, book_size) for _ in range(batch_size)]
            cur_batch = [book + name_phone_pairs[j][0] + ':' for j in query_idxs]
            gt_numbers = [name_phone_pairs[j][1] for j in query_idxs]

            encoded = tokenizer(cur_batch, return_tensors="pt", padding=True).to(device)
            input_ids = encoded["input_ids"]
            # Keep the tokenizer's mask so padding positions are not attended to,
            # and extend it with ones for each generated token.
            attn_mask = encoded["attention_mask"]
            for _ in range(max_num_tokens - 1):
                logits = model(input_ids=input_ids, attention_mask=attn_mask, labels=None)['logits']  # bs, seq_len, vocab_size
                # Greedy pick at the last position; argmax of raw logits equals
                # argmax of softmax, without normalizing the whole tensor.
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                input_ids = torch.cat((input_ids, next_token), dim=-1)
                attn_mask = torch.cat((attn_mask, torch.ones_like(next_token, dtype=attn_mask.dtype)), dim=-1)

            for count in range(batch_size):
                output_answer = tokenizer.decode(input_ids[count])
                if output_answer.count(gt_numbers[count]) > 1:
                    success_lookups += 1
    return success_lookups / (batch_size * batches)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn

from modeling_mamba_transformer import MambaTransformerForLM, MambaTransformerConfig
# Checkpoint to evaluate. Set the CHECKPOINT_PATH environment variable to
# evaluate a different checkpoint without editing this cell; the default below
# is the path used for the recorded results.
# Other checkpoints previously tried:
#   base_path + 'checkpoint-4900/model.safetensors'
#   '/home/hangruic/webllm-test/base_1.4b_1024len_12_12_1_2_epochs/checkpoint-7400/model.safetensors'
import os

check_point_path = os.environ.get(
    'CHECKPOINT_PATH',
    '/home/hangruic/webllm-test/sft_2_epoch_100_length_2000_samples_1.4b_ck4900_new/checkpoint-260/model.safetensors',
)
model = MambaTransformerForLM(MambaTransformerConfig(),
                              pretrained_pythia_name='EleutherAI/pythia-1.4b',
                              pretrained_mamba_name='state-spaces/mamba-1.4b-hf',
                              first_transformer_layers=12,
                              mamba_start_layer=36,
                              mamba_end_layer=47,
                              check_point_path=check_point_path).to(device)
model.eval()

# Left padding keeps the last position of every row a real token, which the
# greedy loops read via logits[:, -1]. Pythia defines no pad token, so EOS is reused.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-1.4b', padding_side='left')
tokenizer.pad_token_id = tokenizer.eos_token_id
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)

# Evaluation settings shared by the cells below.
batches = 5
batch_size = 8
shuffle = True
token_max_len = 25

softmax = nn.Softmax(dim=2)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Copy-task evaluation with a hand-written greedy decoding loop (one full
# forward pass per generated token) rather than copy_task's model.generate —
# NOTE(review): presumably because the hybrid model lacks generate(); confirm.
string_idx = 0
success_copies = 0
token_max_len = 150
model.eval()
with torch.no_grad():
    for _ in tqdm(range(batches)):
        cur_batch = [strings[string_idx + k] for k in range(batch_size)]
        outputs = tokenizer(cur_batch, return_tensors="pt", truncation=True, max_length=token_max_len).to(device)
        gold_ids = outputs['input_ids']
        if shuffle:
            col_perm = torch.randperm(gold_ids.size(1))
            gold_ids = gold_ids[:, col_perm]
        # Actual gold length; can be shorter than token_max_len for short texts.
        gold_token_len = gold_ids.size(1)
        # Prompt is x x x[0]; the model must produce x[1:] next.
        input_ids = torch.cat([gold_ids, gold_ids, gold_ids[:, 0:1]], dim=1)
        output_ids = input_ids
        # Generate exactly len(x) - 1 tokens (token_max_len - 1 over-generates
        # when len(x) < token_max_len and makes the tail comparison always fail).
        for _ in range(max(gold_token_len - 1, 0)):
            mask = torch.ones(output_ids.shape, device=output_ids.device)
            logits = model(input_ids=output_ids, attention_mask=mask, labels=None)['logits']  # bs, seq_len, vocab_size
            # argmax of raw logits == argmax of softmax; only the last position is needed.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            output_ids = torch.cat((output_ids, next_token), dim=-1)

        copied = output_ids[:, 2 * gold_token_len:]
        for count in range(batch_size):
            if torch.equal(gold_ids[count], copied[count]):
                success_copies += 1
        string_idx += batch_size
print(success_copies / (batch_size * batches))

100%|██████████| 5/5 [02:14<00:00, 26.96s/it]

0.6





In [None]:
# Copy-task accuracies recorded for the hybrid SFT model. The last entry (0.6)
# matches the printed result of the token_max_len = 150 evaluation cell.
# NOTE(review): the settings behind the other entries are not recorded in this
# notebook — presumably shorter token_max_len values; confirm.
hybrid_sft_result = [
    0.9,
    0.825,
    0.75,
    0.6,
]