# Fine-tune with the full-scale dataset

## Import and utilities

In [1]:
import pandas as pd
import torch
import wandb
from datasets import Dataset, load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments)

In [2]:
import os
import sys

# Put the parent of the current working directory on the import path so the
# `src` package imported in the next cell can be found.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

In [3]:
from src.model import load_fo_model
from src.data import load_flan_dataset

Using device: mps


In [None]:
# Root directory for checkpoints and the final saved model.
OUTPUT_DIR = "output/"

# makedirs(exist_ok=True) is safe to re-run, has no check-then-create race and
# also creates missing parent directories, unlike the exists() + mkdir() pair.
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Load model and dataset

In [4]:
# dataset = load_dataset("chiayewken/flan-v2", split="train", streaming=True)

# model_name = "EleutherAI/pythia-160m"  #select the lm model
# model = AutoModelForCausalLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
# Load the model/tokenizer pair and the FLAN dataset through the project helpers
# imported from src.model and src.data.
# NOTE(review): later cells call dataset.shuffle(seed=..., buffer_size=...) and
# .take()/.skip() on the result, so load_flan_dataset() presumably returns a
# streaming dataset — confirm against src/data.
model, tokenizer = load_fo_model()
dataset = load_flan_dataset()

README.md:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/2167 [00:00<?, ?it/s]

Dataset loaded successfully


In [6]:
# Show the pad token, its id and the eos token, one per line, as loaded.
print(tokenizer.pad_token, tokenizer.pad_token_id, tokenizer.eos_token, sep="\n")

None
None
<|endoftext|>


In [7]:
# Configure the tokenizer properly: inspect its special tokens before choosing a pad token.

# - https://github.com/EleutherAI/pythia/issues/156
#   mentions that it is okay to set the pad token to eos.

# - https://huggingface.co/EleutherAI/pythia-14m/discussions/4
#   the tokenizer's padding token can be seen in tokenizer.added_tokens_decoder
#   (displayed as this cell's output).
tokenizer.added_tokens_decoder

{0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 1: AddedToken("<|padding|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50254: AddedToken("                        ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50255: AddedToken("                       ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50256: AddedToken("                      ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50257: AddedToken("                     ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50258: AddedToken("                    ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50259: AddedToken("                   ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 50260: AddedToken("           

In [8]:
# Pad with the dedicated "<|padding|>" token, which added_tokens_decoder lists at id 1.
# Original note: without a custom pad token the convergence behaves weirdly.
# NOTE(review): the original comment said the pad token is "set to zero", but the id
# assigned here is 1 (id 0 is "<|endoftext|>") — confirm which was intended.
tokenizer.pad_token    = "<|padding|>"
tokenizer.pad_token_id = 1

In [9]:
# Display the end-of-sequence token string to confirm it is unchanged after setting the pad token.
tokenizer.eos_token

'<|endoftext|>'

In [10]:
def preprocess_forward(example):
    """Tokenize one example for forward (left-to-right) causal-LM fine-tuning.

    The example's ``source`` and ``target`` fields are joined with a newline,
    tokenized, truncated to 768 tokens and padded to exactly 768 tokens.

    Labels are a copy of ``input_ids`` with padding positions set to -100.
    They are deliberately NOT shifted here: Hugging Face causal-LM heads shift
    the labels by one position internally when computing the loss, so shifting
    them in preprocessing as well would train the model to predict the token
    two steps ahead.

    NOTE(review): this notebook defines ``preprocess_forward`` more than once;
    the definition executed last is the one in effect.

    Args:
        example: Mapping with string fields ``source`` and ``target``.

    Returns:
        dict with 1-D tensors ``input_ids`` and ``labels``, both of length 768.
    """
    combined_text = f"{example['source']}\n{example['target']}"
    tokenized_output = tokenizer(
        combined_text,
        truncation=True,
        padding="max_length",
        max_length=768,
        return_tensors="pt",
    )

    # Drop the batch dimension added by return_tensors="pt".
    input_ids = tokenized_output["input_ids"].squeeze(0)

    # Labels stay aligned with input_ids; the model performs the one-step shift.
    labels = input_ids.clone()

    # Do not compute loss on padding.
    if tokenizer.pad_token_id is not None:
        labels[input_ids == tokenizer.pad_token_id] = -100

    return {"input_ids": input_ids, "labels": labels}

In [None]:
def preprocess_forward(example):
    """Tokenize one example for forward (left-to-right) causal-LM fine-tuning.

    The text is ``inputs``, a newline, ``targets`` and the tokenizer's eos
    token, i.e. (I:... Q:... A:... <eos>). It is truncated to 768 tokens and
    padded to exactly 768 tokens.

    Labels are a copy of ``input_ids`` with padding positions set to -100.
    They are NOT shifted here: Hugging Face causal-LM heads shift labels by one
    position internally, so a manual shift would make the model predict the
    token two steps ahead. The eos token is kept as a target so the model
    learns where the answer ends; only padding is ignored.

    NOTE(review): this notebook defines ``preprocess_forward`` more than once;
    the definition executed last is the one in effect.

    Args:
        example: Mapping with string fields ``inputs`` and ``targets``.

    Returns:
        dict with 1-D tensors ``input_ids`` and ``labels``, both of length 768.
    """
    # combine input, output, and eos token - (I:... Q:... A:... <eos>)
    combined_text = f"{example['inputs']}\n{example['targets']}{tokenizer.eos_token}"

    tokenized = tokenizer(
        combined_text,
        truncation=True,
        max_length=768,
        padding="max_length",
        return_tensors="pt"
    )

    # Index 0 removes the batch dimension added by return_tensors="pt".
    input_ids = tokenized["input_ids"][0]

    # Labels stay aligned with input_ids; the model performs the one-step shift.
    labels = input_ids.clone()

    # Ignore padding in the loss; eos remains a prediction target.
    labels[input_ids == tokenizer.pad_token_id] = -100

    return {"input_ids": input_ids, "labels": labels}

In [12]:
# Preview the first three preprocessed examples from the dataset.
for i, example in zip(range(3), dataset):
    encoded = preprocess_forward(example)
    decoded_text = tokenizer.decode(encoded['input_ids'], skip_special_tokens=True)

    print(f"Example {i + 1}:")
    print(f"Input IDs: {encoded['input_ids']}")
    print(f"Decoded Tokens: {decoded_text}")
    print("-" * 80 + '\n\n')

Example 1:
Input IDs: tensor([  510,   637,   320,  4645,   521, 23908,   281,  7747,   285,   417,
          816,  5753,   537,  1893,   752,   812,   320,   253,  1953,    32,
          187, 23433,  3560,   407,  3662,    27,  1310,   346,    34,   637,
         9398,  3644,  4759, 14133,   285, 17052,    81,  1103,   310,  4645,
          745,   247, 30040, 20953,   449,  1057,   326,  1599,   326,   346,
           34,   637,   310,  4645,   521, 30040, 20953,   281,   512,   253,
         5753,   449,    32,   187, 10976,    27,   187,    14,  4754,   187,
           14,   352,   310,   417,  1896,   281,  2028,   187,    14,   642,
          187,   262,   310,   417,  1896,   281,  2028,   187,   187,   510,
        13131,  2550,   320, 17800,   285, 17390,   387,   253,  1072,   673,
           15,   187,   510,  1953,   285,  3662,   403,  2708,    15,   187,
         5804,   359,  7525,   432,   346,  7910,  2872,  8290,  2186,   247,
        17800, 13131,  1223,  1097,   403,

In [None]:
# Run-wide settings for the streaming fine-tuning job.
batch_size = 8  # per-device batch size passed to TrainingArguments

# Assumed number of examples in the FLAN dataset (figure from the original
# notebook; TODO confirm).
total_examples = 378_000_000

# Split fractions: 90% train, the remainder shared equally by validation and test.
train_ratio = 0.9
val_ratio = test_ratio = 0.05

def preprocess_forward(example):
    """Tokenize one example for forward (left-to-right) causal-LM fine-tuning.

    The text is ``inputs``, a newline, ``targets`` and the tokenizer's eos
    token, i.e. (I:... Q:... A:... <eos>). It is truncated to 768 tokens and
    padded to exactly 768 tokens.

    Labels are a copy of ``input_ids`` with padding positions set to -100.
    They are NOT shifted here: Hugging Face causal-LM heads shift labels by one
    position internally, so a manual shift would make the model predict the
    token two steps ahead. The eos token is kept as a target so the model
    learns where the answer ends; only padding is ignored.

    This function formats a single example (it interpolates the two fields
    into one string), so it must be mapped per example, not with batched=True.

    Args:
        example: Mapping with string fields ``inputs`` and ``targets``.

    Returns:
        dict with 1-D tensors ``input_ids`` and ``labels``, both of length 768.
    """
    # combine input, output, and eos token - (I:... Q:... A:... <eos>)
    combined_text = f"{example['inputs']}\n{example['targets']}{tokenizer.eos_token}"

    tokenized = tokenizer(
        combined_text,
        truncation=True,
        max_length=768,
        padding="max_length",
        return_tensors="pt"
    )

    # Index 0 removes the batch dimension added by return_tensors="pt".
    input_ids = tokenized["input_ids"][0]

    # Labels stay aligned with input_ids; the model performs the one-step shift.
    labels = input_ids.clone()

    # Ignore padding in the loss; eos remains a prediction target.
    labels[input_ids == tokenizer.pad_token_id] = -100
    #TODO: Do we need to mask the prompt tokens?

    return {"input_ids": input_ids, "labels": labels}


# Shuffle and split before preprocessing
shuffled = dataset.shuffle(seed=42, buffer_size=100_000)

# Calculate split sizes
train_size = int(total_examples * train_ratio)
val_size = int(total_examples * val_ratio)
test_size = int(total_examples * test_ratio)

# Create splits
# NOTE(review): validation and test sit behind skip(train_size), so reading the
# first validation example presumably requires streaming past train_size
# examples; consider taking the held-out splits from the front of the stream.
train_raw = shuffled.take(train_size)
remaining = shuffled.skip(train_size)
val_raw = remaining.take(val_size)
test_raw = remaining.skip(val_size).take(test_size)

# Preprocess each split.
# preprocess_forward formats ONE example (it interpolates example['inputs'] and
# example['targets'] into a single string), so it must be mapped per example.
# With batched=True it would receive a dict of lists and stringify whole lists
# into one text. This answers the earlier TODO about batched=False.
tokenized_train = train_raw.map(preprocess_forward)
tokenized_val = val_raw.map(preprocess_forward)
tokenized_test = test_raw.map(preprocess_forward)

In [None]:
# Training configuration
training_args = TrainingArguments(
    # os.path.join avoids the doubled separator ("output//...") that the
    # f-string produced because OUTPUT_DIR already ends with "/".
    output_dir=os.path.join(OUTPUT_DIR, "pythia-finetuned"),
    eval_strategy="steps",
    # NOTE(review): each evaluation presumably iterates the whole eval dataset;
    # with a very large streaming validation split this may be impractical.
    eval_steps=500,
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    max_steps=10_000,  # Required for streaming datasets
    weight_decay=0.01,
    save_total_limit=2,
    save_steps=1000,
    logging_dir="./logs",
    logging_steps=100,
    gradient_accumulation_steps=2,
    # Enable fp16 mixed precision only when CUDA is present. This notebook
    # reports "Using device: mps", where a hard-coded fp16=True can be
    # rejected depending on the transformers/accelerate version.
    fp16=torch.cuda.is_available(),
    report_to="wandb",
    push_to_hub=False
)

# Wire the model, training arguments and tokenized splits into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    eval_dataset=tokenized_val,
    train_dataset=tokenized_train,
)

# Run fine-tuning.
trainer.train()

# Persist the fine-tuned model and its tokenizer in the same directory.
final_dir = f"{OUTPUT_DIR}/pythia-finetuned-final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

# Reversed 

In [None]:
# Reload dataset, model and tokenizer for the reversed-direction experiment.
# NOTE(review): this rebinds `dataset`, `model` and `tokenizer`, replacing the
# objects used by the forward fine-tuning cells above; re-running those cells
# afterwards would pick up the reversed model.
dataset = load_dataset("chiayewken/flan-v2", split="train", streaming=True)

model_name = "afterless/reverse-pythia-160m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): the forward section pads with "<|padding|>" (id 1), whereas here
# the pad token is "<pad>" with id 0. The id assignment runs last, so statement
# order matters — confirm that id 0 is the intended padding id for this tokenizer.
tokenizer.pad_token = "<pad>"
tokenizer.pad_token_id = 0

def preprocess_reverse(example):
    """Tokenize one example and reverse its token order for a reversed causal LM.

    The example's ``source`` and ``target`` fields are joined with a newline,
    tokenized, truncated to 768 tokens, padded to exactly 768 tokens, and the
    whole sequence is then flipped. Because padding is appended before the
    flip, it ends up at the start of the reversed sequence.

    Labels are a copy of the reversed ``input_ids`` with padding positions set
    to -100. They are NOT rolled or shifted here: Hugging Face causal-LM heads
    shift labels by one position internally, so rolling them right by one would
    make every position predict its own input token.

    The flipped attention mask is returned as well, so the padding at the start
    of the sequence is not attended to.

    Args:
        example: Mapping with string fields ``source`` and ``target``.

    Returns:
        dict with 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``,
        each of length 768.
    """
    combined_text = f"{example['source']}\n{example['target']}"

    tokenized_output = tokenizer(
        combined_text,
        truncation=True,
        padding="max_length",
        max_length=768,
        return_tensors="pt",
    )

    # Remove the batch dimension added by return_tensors="pt".
    input_ids = tokenized_output["input_ids"].squeeze(0)
    attention_mask = tokenized_output["attention_mask"].squeeze(0)

    # Reverse the sequence and keep the mask aligned with it.
    reversed_input_ids = input_ids.flip(dims=[0])
    reversed_attention_mask = attention_mask.flip(dims=[0])

    # Labels stay aligned with the inputs; the model performs the one-step shift.
    # The attention mask (not the pad id) marks padding, so real tokens that
    # share an id with the pad token are still trained on.
    labels = reversed_input_ids.clone()
    labels[reversed_attention_mask == 0] = -100

    return {
        "input_ids": reversed_input_ids,
        "attention_mask": reversed_attention_mask,
        "labels": labels,
    }

# Preview the first three reversed examples from the dataset.
for i, example in zip(range(3), dataset):
    encoded = preprocess_reverse(example)
    decoded_text = tokenizer.decode(encoded['input_ids'], skip_special_tokens=True)

    print(f"Example {i + 1}:")
    print(f"Input IDs: {encoded['input_ids']}")
    print(f"Decoded Tokens: {decoded_text}")
    print("-" * 50)

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import torch
from datasets import load_dataset
