In [25]:
from transformers import AutoTokenizer
from transformers import set_seed
from datasets import load_dataset
import torch
import evaluate

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import evaluate
from functools import partial
from omegaconf import OmegaConf

from llm_lab.model.vanilla_decoder import VanillaDecoderModel
from llm_lab.utils.collate_utils import default_data_collator_with_padding
#from model.rotary_decoder import RotaryDecoder
%load_ext autoreload
%autoreload 2

Decoder = VanillaDecoderModel

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:

# data parameters
dataset_name="fancyzhx/ag_news"
text_column_name = "text"

# model parameters
model_name_or_path="openai-community/gpt2"
model_name_or_path="stanford-crfm/battlestar-gpt2-small-x49"
use_fast_tokenizer=True
finetuning_task="text-classification",
max_seq_length=512

batch_size = 512
num_workers=16

# training parameters
pad_to_max_length = True
max_train_samples=120000
fp16 = False

## Load data



In [27]:
raw_datasets = load_dataset(dataset_name)
label_list = raw_datasets['train'].unique("label")
# we will treat the label list as a list of string instead of int, consistent with model.config.label2id
label_list = [str(label) for label in label_list]
label_list.sort()
num_labels = len(label_list)


tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
# Padding strategy
if pad_to_max_length:
    padding = "max_length"
else:
    # We will pad later, dynamically at batch creation, to the max sequence length in each batch
    padding = False
    

label_to_id = {v: i for i, v in enumerate(label_list)}

max_seq_length = min(max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples, tokenizer_parameters):
    # batch model; return a dict
    examples["sentence"] = examples[text_column_name]
    # Tokenize the texts
    result = tokenizer(examples["sentence"], **tokenizer_parameters)
    if label_to_id is not None and "label" in examples:
        # batch processing a list of labels
        result["label"] = [label_to_id[str(l)] for l in examples["label"]]
    return result

# test = preprocess_function(raw_datasets['train'][0])
# Running the preprocessing pipeline on all the datasets
tokenizer_params = {
    "padding": False,
    "max_length": max_seq_length,
    "truncation": True}


raw_datasets = raw_datasets.map(
    partial(preprocess_function, tokenizer_parameters=tokenizer_params),
    batched=True,
    num_proc=64,
    desc="Running tokenizer on dataset",
)


In [28]:
tokenizer.eos_token_id

50256

In [29]:
tokenizer.pad_token_id

50256

In [30]:

train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]


In [31]:
data_collator = partial(default_data_collator_with_padding, pad_token_id=tokenizer.pad_token_id, pad_to_multiple_of=8, padding_strategy="longest")
# if pad_to_max_length:
#     data_collator = default_data_collator
# elif fp16:
#     data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# else:
#     data_collator = None

In [32]:
data_collator([train_dataset[0]])

defaultdict(list,
            {'label': tensor([2]),
             'input_ids': tensor([[22401,   520,    13, 15682, 30358,  5157, 20008,   262,  2619,   357,
                      12637,     8,  8428,   532, 10073,    12,  7255,   364,    11,  5007,
                       3530,   338, 45215,    59,  3903,   286, 14764,    12,   948,    77,
                        873,    11,   389,  4379,  4077,   757,    13, 50256, 50256, 50256]]),
             'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])})

In [33]:
batch_size = 512
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator, num_workers=num_workers)
#train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)

In [34]:
list({'attention_mask', 'input_ids'})

['attention_mask', 'input_ids']

In [35]:
print(train_dataset[0]['attention_mask'])

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [36]:
print(type(train_dataset[0]['input_ids']))

<class 'list'>


## Model

In [37]:
class DecoderClassifier(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = Decoder(config = config)
        self.pred_head = nn.Linear(self.encoder.config.d_model, config.num_labels, bias=False)
        self.config = config
        
    def forward(self, batch):
        
        last_hidden_state = self.encoder(batch['input_ids'])
        batch_size = last_hidden_state.shape[0]
        #print(last_hidden_state.shape)
        if self.config.pad_token_id is not None:
            sequence_length = (batch['input_ids'] == self.config.pad_token_id).int().argmax(-1) - 1
            # when there is no pad_token, sequence_length = -1, we use module to make it to the last position
            sequence_length = sequence_length % batch['input_ids'].shape[-1]
            # if we use encoder_outputs.last_hidden_state[:,sequence_length], we got the selection on the second axis, which is wrong
            hiddens = last_hidden_state[torch.arange(batch_size),sequence_length]
        else:
            # use the last token
            hiddens = last_hidden_state[:,-1,:].squeeze()
        
        pred_out = self.pred_head(hiddens)
        return pred_out
        
        

In [38]:
from omegaconf import OmegaConf

config_dict = {
    "vocab_size": tokenizer.vocab_size,
    "context_length": 1024,
    "d_model": 128,
    "num_heads": 2,
    "num_layers": 2,
    "dropout": 0.1,
    "qkv_bias": False,
    "pad_token_id": tokenizer.pad_token_id,
    "num_labels": num_labels,
    "causal_attention": True
}

config = OmegaConf.create(config_dict)

model = DecoderClassifier(config=config)


    
for batch in train_dataloader:
    del batch['label']
    logits = model(batch)
    print(logits.shape)
    break

vanilla decoder model 3
torch.Size([512, 4])


In [39]:
# gpt2 tokenizer padding side is right
tokenizer.padding_side

'right'

In [40]:
# model = BertClassifier(model_name=model_name_or_path, config=config, num_labels=num_labels)

## Training & Evaluation

In [41]:
def move_to_device(batch, device):
    
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
    return batch

In [42]:
def compute_batch_loss(model, inputs, labels, loss_fn):
    
    logits = model(inputs)
    loss = loss_fn(logits, labels)
    
    return loss

def train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq):
    
    model.train()
    training_loss = []
    step = 0
    for batch in train_dataloader:
        
        batch = move_to_device(batch, device)
        targets = batch['label']
        del batch['label']
        inputs = batch
        
        loss = compute_batch_loss(model, inputs, targets, loss_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        training_loss.append(loss.item())
        step += 1
        
        if step % output_freq == 0:
            print(f"steps: {step}, loss: {sum(training_loss)/step}")
            
    return model

In [43]:
lr = 2e-3
set_seed(1)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
device = 'cuda'
model.config.pad_token_id = tokenizer.pad_token_id
model = model.to(device)

for _ in range(2):
    train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq=50)

steps: 50, loss: 1.5106132340431213
steps: 100, loss: 1.3352366852760316
steps: 150, loss: 1.1725797645250957
steps: 200, loss: 1.0333916260302067
steps: 50, loss: 0.4339019471406937
steps: 100, loss: 0.37769288033246995
steps: 150, loss: 0.3457512793938319
steps: 200, loss: 0.3213794261217117


In [44]:


def compute_metrics(eval_dataloader, metrics, model, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = move_to_device(batch, device)
            targets = batch['label']
            del batch['label']
            inputs = batch
            model_output = model(inputs)
            logits = model_output
            
            
            preds = torch.argmax(logits, dim=-1)
            
            all_labels.extend(targets.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())
            
        result = metrics.compute(predictions=all_preds, references = all_labels)
    print(result)
    print(len(all_labels))
            

In [45]:
eval_dataloader = DataLoader(eval_dataset, batch_size=32, collate_fn=data_collator)
metric = evaluate.load("accuracy")
compute_metrics(eval_dataloader, metric, model, device)

{'accuracy': 0.905921052631579}
7600
