In [1]:
import random
import torch
import pickle as pkl

from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling

from transformers.models.bloom.configuration_bloom import BloomConfig
from pruning.pruned_bloom import PrunedBloomForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer only (the model itself is built and loaded in the following cells)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
# Maximum sequence length in tokens; passed as max_length when tokenizing below
context_length = 2048

In [3]:
# Describe the pruned model: checkpoint locations and architecture config
# (the model object itself is instantiated in the next cell).
# weights_path is read with torch.load and fed to load_state_dict below;
# state_dict_shapes_path is passed to PrunedBloomForCausalLM alongside the config.
weights_path = "pruned_50percent_3B_bloom.pt"
state_dict_shapes_path = "pruned_50percent_3B_bloom_state_dict_shapes.pkl"

# NOTE(review): pad_token_id=3 here, while the tokenizer's pad token is later
# set to its EOS token — confirm which id padding is expected to use.
# NOTE(review): attention_softmax_in_fp32, bias_dropout_fusion,
# masked_softmax_fusion, offset_alibi, seq_length, skip_bias_add and
# skip_bias_add_qkv look like keys carried over from the original checkpoint
# config — confirm that the model code actually reads them.
bloom_config = BloomConfig(
    vocab_size=250880,
    hidden_size=2560,
    n_layer=30,
    n_head=32,
    layer_norm_epsilon=1e-5,
    initializer_range=0.02,
    use_cache=True,
    bos_token_id=1,
    eos_token_id=2,
    apply_residual_connection_post_layernorm=False,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    pretraining_tp=1,  # TP rank used when training with megatron
    slow_but_exact=False,
    attention_softmax_in_fp32=True,
    bias_dropout_fusion=True,
    masked_softmax_fusion=True,
    offset_alibi=100,
    pad_token_id=3,
    seq_length=2048,  # same value as context_length used for tokenization
    skip_bias_add=True,
    skip_bias_add_qkv=False,
    unk_token_id=0,
    
)

In [4]:
# Build the pruned architecture from the config and the recorded tensor shapes,
# then restore its weights from the checkpoint.
pruned_model = PrunedBloomForCausalLM(bloom_config, state_dict_shapes_path)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on any
# machine; load_state_dict copies the values into the model's own parameters,
# so the resulting model is the same.
pruned_model.load_state_dict(torch.load(weights_path, map_location="cpu"))

<All keys matched successfully>

In [5]:
# Load data
# NOTE: pickle.load can execute arbitrary code — only use it on trusted files.
# Context managers make sure both file handles are closed after reading.
with open("cleaned_3353_pairs.pkl", "rb") as train_file:
    train_data = pkl.load(train_file)
with open("conv_dicts/530_human_filtered_conv_pairs.pkl", "rb") as val_file:
    val_data = pkl.load(val_file)

# Remove validation examples that also occur in the training data.
# The lookup set is built once instead of once per validation example.
train_examples = set(train_data)
val_data = [line for line in val_data if line not in train_examples]

print(len(val_data))
# Fixed seed so the training order is reproducible across kernel restarts.
random.seed(42)
random.shuffle(train_data)

120


In [6]:
# tokenize data
# Use the EOS token as the padding token; this tokenizer is later handed to
# the data collator that batches the examples.
# NOTE(review): this overrides any pad token the tokenizer already defines,
# and the model config uses pad_token_id=3 — confirm this is intended.
tokenizer.pad_token = tokenizer.eos_token

def tokenize(data, tokenizer, context_length):
    """Tokenize ``data`` and keep only sequences shorter than ``context_length``.

    Args:
        data: text(s) handed unchanged to ``tokenizer``.
        tokenizer: callable returning a mapping with ``"length"`` and
            ``"input_ids"`` entries of equal size.
        context_length: used as ``max_length`` for truncation and as the
            exclusive upper bound on the length of kept sequences.

    Returns:
        List of ``input_ids`` sequences whose reported length is strictly
        below ``context_length``, in tokenizer output order.
    """
    encoded = tokenizer(
        data,
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Sequences that fill the whole window (length == context_length) are dropped.
    return [
        token_ids
        for n_tokens, token_ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens < context_length
    ]

In [7]:
class DialogueDataset(torch.utils.data.Dataset):
    """Map-style dataset of tokenized dialogue strings.

    The raw strings are tokenized once at construction time via ``tokenize``.
    Because ``tokenize`` filters sequences by length (and may emit overflow
    chunks), the number of tokenized examples can differ from the number of
    raw strings; the dataset length always reflects the tokenized examples.
    """

    def __init__(self, data_list, tokenizer, context_length):
        """Store the raw strings and their tokenized form.

        Args:
            data_list: list of raw text examples.
            tokenizer: tokenizer forwarded to ``tokenize``.
            context_length: maximum sequence length forwarded to ``tokenize``.
        """
        self.data_strings = data_list
        self.tokenized_data = tokenize(data_list, tokenizer, context_length)

    def __len__(self):
        """Return the number of examples that ``__getitem__`` can serve."""
        # Must match tokenized_data, not data_strings: indexing goes through
        # tokenized_data, so reporting the raw count could raise IndexError
        # (fewer tokenized examples) or hide examples (more of them).
        return len(self.tokenized_data)

    def __getitem__(self, idx):
        """Return a copy of the token ids at ``idx`` so callers cannot mutate the cache."""
        return self.tokenized_data[idx].copy()

In [8]:
# Tokenize both splits up front; each dataset holds the raw strings and their token ids
train_dataset = DialogueDataset(train_data, tokenizer, context_length)
val_dataset = DialogueDataset(val_data, tokenizer, context_length)

In [9]:
# Batching for causal language modeling: mlm=False disables masked-LM masking.
# The collator pads with the tokenizer's pad token (set to EOS above).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [10]:
# Hyperparameters and bookkeeping for the fine-tuning run
training_args = TrainingArguments(
    # where checkpoints and the final model are written
    output_dir="./results",
    # optimisation
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    # reporting: log the training loss every 5 steps, evaluate once per epoch
    logging_strategy="steps",
    logging_steps=5,
    evaluation_strategy="epoch",
)

In [11]:
# Assemble the Trainer from the model, run configuration, batching and data splits
trainer = Trainer(
    args=training_args,
    model=pruned_model,
    data_collator=data_collator,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Run fine-tuning with the settings in training_args (2 epochs, evaluation each epoch)
trainer.train()

***** Running training *****
  Num examples = 3353
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 840
  Number of trainable parameters = 1500982941


Epoch,Training Loss,Validation Loss


In [None]:
# Persist the fine-tuned model; with no argument the Trainer's configured output_dir is used
trainer.save_model()

In [None]:
# Keep a handle on the model held by the trainer for the generation cells below
finetuned_model = trainer.model

In [None]:
line = "user: Do you like watching movies?\nchatbot:"

# The Trainer may have moved the model to an accelerator; put the prompt
# tensors on the model's device so generate() does not hit a device mismatch.
model_device = next(finetuned_model.parameters()).device
inputs = tokenizer(line, return_tensors="pt").to(model_device)
pruned_times = []  # NOTE(review): never used in this cell — confirm it can be removed


# Sample up to 50 new tokens (top-k / nucleus sampling). The attention mask is
# passed explicitly because pad and EOS share one token id, so it cannot be
# inferred reliably from input_ids.
outputs = finetuned_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [None]:
# Serialize the entire fine-tuned model object (not only its state_dict) to disk.
# NOTE(review): the file name says "20percent" and "bloom560m", but the weights
# loaded earlier come from "pruned_50percent_3B_bloom.pt" — confirm the name.
torch.save(finetuned_model, "finetuned_20percent_pruned_bloom560m.pt")

In [None]:
line = "user: Can you tell me about yourself?\nchatbot:"

# The Trainer may have moved the model to an accelerator; put the prompt
# tensors on the model's device so generate() does not hit a device mismatch.
model_device = next(finetuned_model.parameters()).device
inputs = tokenizer(line, return_tensors="pt").to(model_device)
pruned_times = []  # NOTE(review): never used in this cell — confirm it can be removed


# Sample up to 50 new tokens (top-k / nucleus sampling). The attention mask is
# passed explicitly because pad and EOS share one token id, so it cannot be
# inferred reliably from input_ids.
outputs = finetuned_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [None]:
line = "user: you misunderstood me\nchatbot:"

# The Trainer may have moved the model to an accelerator; put the prompt
# tensors on the model's device so generate() does not hit a device mismatch.
model_device = next(finetuned_model.parameters()).device
inputs = tokenizer(line, return_tensors="pt").to(model_device)
pruned_times = []  # NOTE(review): never used in this cell — confirm it can be removed


# Sample up to 50 new tokens (top-k / nucleus sampling). The attention mask is
# passed explicitly because pad and EOS share one token id, so it cannot be
# inferred reliably from input_ids.
outputs = finetuned_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))