In [1]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    default_data_collator,
    set_seed,
    TrainingArguments,
    HfArgumentParser,
    EvalPrediction,
)
from datasets import load_dataset
import random
import numpy as np
import torch
import evaluate

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import evaluate

from omegaconf import OmegaConf
import tiktoken

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# data parameters
dataset_name="fancyzhx/ag_news"
text_column_name = "text"

# model parameters
model_name_or_path="openai-community/gpt2"
model_name_or_path="stanford-crfm/battlestar-gpt2-small-x49"
#model_name_or_path="Qwen/Qwen2.5-0.5B"

use_fast_tokenizer=True
finetuning_task="text-classification",
max_seq_length=512

batch_size = 512
num_workers=16

# training parameters
pad_to_max_length = True
max_train_samples=120000
fp16 = False

## Load data



In [3]:
raw_datasets = load_dataset(dataset_name)
label_list = raw_datasets['train'].unique("label")
# we will treat the label list as a list of string instead of int, consistent with model.config.label2id
label_list = [str(label) for label in label_list]
label_list.sort()
num_labels = len(label_list)


tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
# Padding strategy
if pad_to_max_length:
    padding = "max_length"
else:
    # We will pad later, dynamically at batch creation, to the max sequence length in each batch
    padding = False
    

label_to_id = {v: i for i, v in enumerate(label_list)}

max_seq_length = min(max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples):
# return a dict
    examples["sentence"] = examples[text_column_name]
    # Tokenize the texts
    result = tokenizer(examples["sentence"], padding=padding, max_length=max_seq_length, truncation=True)
    if label_to_id is not None and "label" in examples:
        result["label"] = [(label_to_id[str(l)] if l != -1 else -1) for l in examples["label"]]
    
    # add additional keys: 'input_ids','token_type_ids', 'attention_mask','label'  
    return result

# test = preprocess_function(raw_datasets['train'][0])
# Running the preprocessing pipeline on all the datasets

raw_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=64,
    desc="Running tokenizer on dataset",
)


In [4]:
tokenizer.eos_token_id

50256

In [5]:
tokenizer.pad_token_id

50256

In [6]:

train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]


In [7]:
if pad_to_max_length:
    data_collator = default_data_collator
elif fp16:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = None

In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator, num_workers=16)

## Model

In [9]:



class DecoderClassifier(nn.Module):
    def __init__(self, model_name_or_path, config):
        super().__init__()
        
        model_config = AutoConfig.from_pretrained(model_name_or_path)
        model_config.n_layer = config.num_layers
        model_config.n_embed = config.d_model
        model_config.n_head = config.num_heads
        encoder = AutoModel.from_config(model_config)
        
        self.encoder = encoder
        self.pred_head = nn.Linear(self.encoder.config.hidden_size, config.num_labels, bias=False)
        self.config = config
        
    def forward(self, batch):
        
        encoder_outputs = self.encoder(**batch)
        batch_size = encoder_outputs.last_hidden_state.shape[0]
        #print(encoder_outputs.last_hidden_state.shape)
        if self.config.pad_token_id is not None:
            sequence_length = (batch['input_ids'] == self.config.pad_token_id).int().argmax(-1) - 1
            # when there is no pad_token, sequence_length = -1, we use module to make it to the last position
            sequence_length = sequence_length % batch['input_ids'].shape[-1]
            # if we use encoder_outputs.last_hidden_state[:,sequence_length], we got the selection on the second axis, which is wrong
            hiddens = encoder_outputs.last_hidden_state[torch.arange(batch_size),sequence_length]
        else:
            # use the last token
            hiddens = encoder_outputs.last_hidden_state[:,-1,:].squeeze()
        
        pred_out = self.pred_head(hiddens)
        return pred_out
        
        

In [10]:
from omegaconf import OmegaConf

config_dict = {
    "vocab_size": tokenizer.vocab_size,
    "context_length": 1024,
    "d_model": 128,
    "num_heads": 2,
    "num_layers": 2,
    "dropout": 0.1,
    "qkv_bias": False,
    "pad_token_id": tokenizer.pad_token_id,
    "num_labels": num_labels      
}

config = OmegaConf.create(config_dict)

model = DecoderClassifier(model_name_or_path=model_name_or_path, config=config)


    
for batch in train_dataloader:
    del batch['labels']
    logits = model(batch)
    print(logits.shape)
    break

torch.Size([512, 4])


In [11]:
# gpt2 tokenizer padding side is right
print(model)

DecoderClassifier(
  (encoder): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-1): 2 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (pred_head): Linear(in_features=768, out_features=4, bias=False)
)


In [12]:
# model = BertClassifier(model_name=model_name_or_path, config=config, num_labels=num_labels)

## Training & Evaluation

In [13]:
def move_to_device(batch, device):
    
    for k, v in batch.items():
        batch[k] = v.to(device)
    return batch

In [14]:
def compute_batch_loss(model, inputs, labels, loss_fn):
    
    logits = model(inputs)
    loss = loss_fn(logits, labels)
    
    return loss

def train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq):
    
    model.train()
    training_loss = []
    step = 0
    for batch in train_dataloader:
        
        batch = move_to_device(batch, device)
        targets = batch['labels']
        del batch['labels']
        inputs = batch
        
        loss = compute_batch_loss(model, inputs, targets, loss_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        training_loss.append(loss.item())
        step += 1
        
        if step % output_freq == 0:
            print(f"steps: {step}, loss: {sum(training_loss)/step}")
            
    return model

In [15]:
lr = 2e-5
set_seed(1)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
device = 'cuda'
model.config.pad_token_id = tokenizer.pad_token_id
model = model.to(device)

model_trained = train_one_epoch(train_dataloader, optimizer, loss_fn, model, device, output_freq=50)

KeyboardInterrupt: 

In [18]:


def compute_metrics(eval_dataloader, metrics, model, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = move_to_device(batch, device)
            targets = batch['labels']
            del batch['labels']
            inputs = batch
            model_output = model(inputs)
            logits = model_output
            
            
            preds = torch.argmax(logits, dim=-1)
            
            all_labels.extend(targets.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())
            
        result = metrics.compute(predictions=all_preds, references = all_labels)
    print(result)
    print(len(all_labels))
            

In [None]:
eval_dataloader = DataLoader(eval_dataset, batch_size=32, collate_fn=data_collator)
metric = evaluate.load("accuracy")
compute_metrics(eval_dataloader, metric, model_trained, device)

{'accuracy': 0.9046052631578947}
7600
