In [None]:
! python -m pip install --upgrade pip
! pip install datasets trl wandb tensorboard peft -qU
! pip install flash-attn --no-build-isolation -qU

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, SFTConfig
from accelerate import Accelerator
import wandb
import math, os, random
from datetime import datetime

In [None]:
def setup_environment():
    """Turn off Weights & Biases via the environment and create an Accelerator.

    Returns:
        A freshly constructed ``Accelerator`` instance.
    """
    # W&B run creation is intentionally left out; previously considered:
    #   wandb.init(project="webinstructsub-finetuning-small", entity="fish")
    os.environ.update({"WANDB_DISABLED": 'true'})
    accelerator = Accelerator()
    return accelerator

In [None]:
def load_model_and_tokenizer(model_name, tokenizer_name="HuggingFaceTB/SmolLM-135M-Instruct"):
    """Load a causal language model in bfloat16 together with a tokenizer.

    Args:
        model_name: Hub id or local path handed to
            ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_name: Hub id or local path handed to
            ``AutoTokenizer.from_pretrained``. Defaults to the
            SmolLM-135M-Instruct tokenizer, which is the checkpoint this
            function loaded before the argument existed, so existing
            single-argument calls behave the same.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_kwargs = {
        "torch_dtype": torch.bfloat16
    }
    # The tokenizer is loaded from its own checkpoint rather than from
    # `model_name`; pass `tokenizer_name=model_name` to use the model's own.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    return model, tokenizer

In [None]:
def load_and_preprocess_train_dataset(start_idx, num_rows):
    """Stream a contiguous window of the WebInstruct52k train split.

    Args:
        start_idx: Number of leading rows to skip.
        num_rows: Number of rows to take after the skipped ones.

    Returns:
        The streaming dataset restricted to that window.
    """
    stream = load_dataset("TIGER-Lab/WebInstruct52k", split="train", streaming=True)
    window = stream.skip(start_idx)
    return window.take(num_rows)

In [None]:
def format_instruction(example):
    """Wrap a question/answer record as a two-turn chat transcript.

    Args:
        example: Mapping with ``"question"`` and ``"answer"`` entries.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}`` where each turn is a
        dict with ``"role"`` and ``"content"`` keys.
    """
    user_turn = dict(role="user", content=example["question"])
    assistant_turn = dict(role="assistant", content=example["answer"])
    return {"messages": [user_turn, assistant_turn]}

In [None]:
def load_and_preprocess_eval_dataset(eval_rows=1000, seed=None):
    """Draw a random sample of rows from WebInstruct52k for evaluation.

    Args:
        eval_rows: Number of rows to sample. Capped at the dataset size so
            that ``random.sample`` is never asked for more rows than exist.
        seed: Optional seed for a private random generator. When ``None``
            the module-level ``random`` state is used, so a prior
            ``random.seed(...)`` call still controls the sample.

    Returns:
        The dataset restricted to the sampled rows.
    """
    dataset = load_dataset("TIGER-Lab/WebInstruct52k", split="train")  # Assuming split="train" for eval
    total_rows = len(dataset)
    sample_size = min(eval_rows, total_rows)
    # Use an isolated generator only when a seed is requested; otherwise keep
    # drawing from the shared `random` module state.
    rng = random.Random(seed) if seed is not None else random
    random_indices = rng.sample(range(total_rows), sample_size)
    return dataset.select(random_indices)

In [None]:
def format_instruction_for_trainer(example, tokenizer=None):
    """Render an example's chat messages to text with the chat template.

    Args:
        example: Mapping with a ``"messages"`` entry (a conversation, or a
            list of conversations when called on a batch).
        tokenizer: Optional object exposing ``apply_chat_template``. When
            omitted, the SmolLM-135M-Instruct tokenizer is loaded once and
            reused on later calls instead of being re-created per example.

    Returns:
        Whatever ``apply_chat_template(..., tokenize=False)`` returns for the
        given messages (the rendered, untokenized chat text).
    """
    if tokenizer is None:
        # Cache on the function object so repeated per-example calls do not
        # reload the tokenizer every time.
        tokenizer = getattr(format_instruction_for_trainer, "_cached_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
            format_instruction_for_trainer._cached_tokenizer = tokenizer
    # `tokenize=False` (previously misspelled `tokenize_file`) asks for text
    # rather than token ids. The truncation/padding/max_length options were
    # dropped because they only matter when the template output is tokenized.
    return tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
    )

In [None]:
def get_custom_lr_scheduler(optimizer, num_warmup_steps, num_training_steps, initial_phase_steps):
    """Build a warmup -> constant -> cosine-decay learning-rate schedule.

    The returned scheduler multiplies the optimizer's base learning rate by:

    * ``step / num_warmup_steps`` while ``step < num_warmup_steps``
      (linear warmup; skipped entirely when ``num_warmup_steps`` is 0),
    * ``1.0`` while ``step < initial_phase_steps``,
    * a half-cosine falling from 1.0 to 0.0 between ``initial_phase_steps``
      and ``num_training_steps``, and 0.0 after that.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Length of the linear warmup. Previously accepted
            but ignored; ``0`` or ``None`` reproduces the old no-warmup shape.
        num_training_steps: Step at which the cosine decay reaches zero.
        initial_phase_steps: Step at which the cosine decay starts.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the schedule.
    """
    warmup_steps = max(0, num_warmup_steps or 0)
    # At least 1 so that num_training_steps == initial_phase_steps does not
    # divide by zero.
    decay_steps = max(1, num_training_steps - initial_phase_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / warmup_steps
        if current_step < initial_phase_steps:
            return 1.0  # Constant learning rate for the initial phase
        # Clamp progress so the factor stays at 0 past num_training_steps
        # instead of climbing back up the cosine curve.
        progress = min(1.0, (current_step - initial_phase_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
def main():
    """Run supervised fine-tuning of the layer-pruned SmolLM checkpoint.

    Steps: disable W&B, load model and tokenizer, build a streaming training
    set and a randomly sampled evaluation set in chat-message form, configure
    ``SFTConfig`` with a custom constant-then-cosine learning-rate schedule,
    train for ``total_steps`` optimizer steps, and save the model under
    ``./results/<run_name>``.
    """
    # Model configuration
    model_name = "Trellis/SmolLM-135M-Instruct-layer-pruned-90M-raw"
    # Alternative (distilled) checkpoint:
    # model_name = "Trellis/90M-base"

    # Select rows to train on
    initial_rows = 5000    # rows consumed before the cosine phase starts
    annealing_rows = 1000  # rows consumed during the cosine phase
    # The evaluation sample size is chosen inside
    # load_and_preprocess_eval_dataset().

    batch_size = 8
    ga = 4  # Gradient accumulation steps

    learning_rate = 1e-3
    seed = 42

    # Called for its environment side effect. The returned Accelerator is not
    # used: the trainer is not something Accelerator.prepare() wraps, and the
    # Trainer sets up its own acceleration.
    setup_environment()

    # Make the random evaluation sample repeatable across runs.
    random.seed(seed)

    model, tokenizer = load_model_and_tokenizer(model_name)
    print(model.device)

    # Combined training dataset (streaming), converted to chat messages.
    # Chat-template rendering is left to SFTTrainer through `formatting_func`:
    # format_instruction_for_trainer does not return a dict, so it cannot be
    # applied with Dataset.map, and applying it here as well would format the
    # data twice.
    total_rows = initial_rows + annealing_rows
    train_dataset = load_and_preprocess_train_dataset(0, total_rows)
    train_dataset = train_dataset.map(format_instruction, batched=False)

    # Evaluation dataset (non-streaming, random sample).
    # NOTE(review): the sample is drawn from the same "train" split the
    # training rows come from, so the two sets can overlap.
    eval_dataset = load_and_preprocess_eval_dataset()
    eval_dataset = eval_dataset.map(format_instruction, batched=False)

    # Calculate steps
    num_epochs = 1
    rows_per_step = batch_size * ga
    total_steps = max(1, (total_rows * num_epochs) // rows_per_step)
    initial_steps = (initial_rows * num_epochs) // rows_per_step
    # Log / evaluate / checkpoint roughly ten times per run, never every 0 steps.
    interval_steps = max(1, total_steps // 10)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f"{model_name.split('/')[-1]}-SFT-{total_rows}rows-lr_{learning_rate}-{timestamp}"
    output_dir = f"./results/{run_name}"

    training_args = SFTConfig(
        output_dir=output_dir,
        run_name=run_name,
        logging_dir=f"./logs/{run_name}",
        eval_strategy="steps",
        save_strategy="steps",
        report_to="tensorboard",
        num_train_epochs=num_epochs,  # max_steps below takes precedence
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=20,  # forwarded to the custom scheduler below
        logging_steps=interval_steps,
        eval_steps=interval_steps,
        save_steps=interval_steps,
        learning_rate=learning_rate,
        bf16=True,
        max_steps=total_steps,  # required because the training set is streamed
        gradient_accumulation_steps=ga,
        seed=seed,
    )

    # Custom learning rate scheduler: constant for `initial_steps`, then
    # cosine annealing until `total_steps`.
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)
    lr_scheduler = get_custom_lr_scheduler(optimizer, training_args.warmup_steps, total_steps, initial_steps)

    # Trainer
    # NOTE(review): `tokenizer=` and `max_seq_length=` are accepted by some TRL
    # releases only, and the install cell is unpinned (-U); confirm against
    # the installed version.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        tokenizer=tokenizer,
        max_seq_length=2048,
        formatting_func=format_instruction_for_trainer,
        optimizers=(optimizer, lr_scheduler)  # Use custom optimizer and scheduler
    )

    print(f"Starting instruction fine-tuning on {total_rows} rows of data...")
    trainer.train()
    print("Instruction fine-tuning completed. Saving model...")

    trainer.save_model(output_dir)