In [None]:
!pip install torch transformers accelerate datasets numpy wandb

In [None]:
import os
import wandb
from huggingface_hub import login

WANDB_API_KEY=""
HF_TOKEN=""

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Load relevant files from GitHub, if in Colab
if IN_COLAB:
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/load_datasets.py
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/finetuning_utils.py
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/testing_utils.py
    wandb.login(key=WANDB_API_KEY)
    login(HF_TOKEN)

In [None]:
from load_datasets import load_stack_samples
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset

# Every tunable for the run lives here; the same dict is handed to wandb below
# so the settings are recorded with the run.
CONFIG = dict(
    # Model / optimisation settings
    MODEL_ID="EleutherAI/pythia-70m",
    CHECKPOINT_DIR="./checkpoint",
    BATCH_SIZE=2,
    EPOCHS=20,
    LEARNING_RATE=5e-5,
    # Data selection settings
    LANGUAGE_SUBSET=["python", "cpp", "java"],
    N_SAMPLES=3333,
    # NOTE: despite the name, the data-loading cell uses this value as the
    # fraction of samples placed in the *training* split.
    VAL_SPLIT=0.9,
)

# Initialize wandb; the run name ends with the part of MODEL_ID after the last "/".
wandb.init(
    name=f"finetune-hi-resource-{CONFIG['MODEL_ID'].split('/')[-1]}",
    project="specdec-finetuning",
    config=CONFIG,
)

In [None]:
model_id = CONFIG["MODEL_ID"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token for padding and pad on the left of each sequence.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
model = AutoModelForCausalLM.from_pretrained(model_id)

# Optional resume: set this to the stem of a file under models/ (for example a
# checkpoint written by train()) to start from those weights instead of the
# pretrained ones. Leave it empty to skip.
checkpoint_model_id = ""
checkpoint_path = os.path.join("models", checkpoint_model_id + ".pt")
# Check the same path that is actually loaded; the bare id is not a file path.
if checkpoint_model_id and os.path.isfile(checkpoint_path):
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; the model is moved to the GPU below when one exists.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    print(f"Loaded finetuned model from {checkpoint_model_id}")
elif checkpoint_model_id:
    print(f"Checkpoint {checkpoint_path} not found; using the pretrained weights")

if torch.cuda.is_available():
    model.to("cuda")
model.train()

In [None]:
# Use the configured subset when it is non-empty; otherwise fall back to the
# full language list below.
languages = CONFIG["LANGUAGE_SUBSET"] or [
    "assembly", "batchfile", "c++", "c-sharp", "c",
    "cmake", "css", "dockerfile", "fortran", "go",
    "haskell", "html", "java", "javascript", "julia",
    "lua", "makefile", "markdown", "perl", "php",
    "powershell", "python", "ruby", "rust", "scala",
    "shell", "sql", "tex", "typescript", "visual-basic",
]
len(languages)

In [None]:
from random import Random, shuffle  # `shuffle` stays importable for any other cell that uses it

# Load data and split into train/val
# NOTE(review): the x10 multiplier on N_SAMPLES is not explained in this
# notebook - confirm how load_stack_samples interprets num_samples.
all_samples = load_stack_samples(languages, num_samples=CONFIG["N_SAMPLES"] * 10)

# Shuffle with a fixed seed so the train/val partition is identical on every
# run. With an unseeded shuffle, resuming from a checkpoint in a new session
# would validate on samples the model had already been trained on.
Random(CONFIG.get("SEED", 42)).shuffle(all_samples)

# Despite its name, CONFIG["VAL_SPLIT"] is the fraction of samples that goes to
# the training split; the remainder becomes the validation split.
split_idx = int(CONFIG["VAL_SPLIT"] * len(all_samples))

train_samples = all_samples[:split_idx]
val_samples = all_samples[split_idx:]

# Fail here with a clear message rather than later with a division by zero
# when an epoch averages over an empty dataloader.
if len(train_samples) == 0 or len(val_samples) == 0:
    raise ValueError(
        f"Empty split: {len(train_samples)} training / {len(val_samples)} validation "
        f"samples out of {len(all_samples)} loaded"
    )

print(f"Training samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")

train_dataloader = DataLoader(train_samples, batch_size=CONFIG["BATCH_SIZE"], shuffle=True)
val_dataloader = DataLoader(val_samples, batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

In [None]:
def train_epoch(model, dataloader, optimizer, tokenizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Causal language model; called with the tokenizer output plus
            ``labels`` and expected to return an object with a ``.loss`` tensor.
        dataloader: Iterable of batches; each batch is indexed with ``'text'``
            and the result is passed straight to ``tokenizer``.
        optimizer: Optimizer stepped once per batch.
        tokenizer: Callable producing a mapping of tensors (``input_ids`` and,
            when padding is applied, ``attention_mask``).
        device: Device the tokenized tensors are moved to.

    Returns:
        The arithmetic mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        inputs = tokenizer(batch['text'], return_tensors='pt', truncation=True, padding=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Labels are the inputs themselves, except that padded positions are set
        # to -100 so the loss ignores them. Without this, padding tokens are
        # scored as real prediction targets.
        labels = inputs['input_ids'].clone()
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            labels[attention_mask == 0] = -100

        loss = model(**inputs, labels=labels).loss
        batch_loss = loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += batch_loss
        num_batches += 1

        # Log batch loss to wandb
        wandb.log({"batch_loss": batch_loss})

    if num_batches == 0:
        raise ValueError("train_epoch received a dataloader that yielded no batches")
    return total_loss / num_batches

In [None]:
def validate_epoch(model, dataloader, tokenizer, device):
    """Compute the mean batch loss over ``dataloader`` without updating the model.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Causal language model; called with the tokenizer output plus
            ``labels`` and expected to return an object with a ``.loss`` tensor.
        dataloader: Iterable of batches; each batch is indexed with ``'text'``
            and the result is passed straight to ``tokenizer``.
        tokenizer: Callable producing a mapping of tensors (``input_ids`` and,
            when padding is applied, ``attention_mask``).
        device: Device the tokenized tensors are moved to.

    Returns:
        The arithmetic mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', truncation=True, padding=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            # Padded positions get the label -100 so the loss ignores them;
            # otherwise padding tokens would be scored as real targets.
            labels = inputs['input_ids'].clone()
            attention_mask = inputs.get('attention_mask')
            if attention_mask is not None:
                labels[attention_mask == 0] = -100

            total_loss += model(**inputs, labels=labels).loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate_epoch received a dataloader that yielded no batches")
    return total_loss / num_batches

In [None]:
# AdamW over all model parameters, using the learning rate from CONFIG.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=CONFIG["LEARNING_RATE"])
# Plateau scheduler left on its library defaults; train() steps it with the
# validation loss once per epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)

In [None]:
def train(model, train_dataloader, val_dataloader, optimizer, scheduler, tokenizer, device, epochs):
    """Fine-tune ``model`` for ``epochs`` epochs, logging and checkpointing each one.

    Per epoch: one training pass, one validation pass, a scheduler step on the
    validation loss, a wandb log entry, and a state-dict checkpoint written to
    ``models/``. In Colab the checkpoint is also written to Google Drive.
    The wandb run is finished after the last epoch.

    Args:
        model: Model to train; its ``state_dict()`` is what gets saved.
        train_dataloader: Batches passed to ``train_epoch``.
        val_dataloader: Batches passed to ``validate_epoch``.
        optimizer: Optimizer passed to ``train_epoch``; its first param group's
            learning rate is logged.
        scheduler: Scheduler stepped with the validation loss each epoch.
        tokenizer: Tokenizer passed to both epoch functions.
        device: Device passed to both epoch functions.
        epochs: Number of epochs to run.

    Returns:
        None.
    """
    # Loop-invariant setup, done once instead of on every epoch.
    model_name = CONFIG["MODEL_ID"].split("/")[-1]
    os.makedirs("models", exist_ok=True)

    if IN_COLAB:
        # Mount Drive a single time, before training starts, so an
        # authorisation prompt or mount failure surfaces immediately rather
        # than after the first epoch has already been spent.
        from google.colab import drive
        drive.mount('/content/drive')

    for epoch in range(epochs):
        train_loss = train_epoch(model, train_dataloader, optimizer, tokenizer, device)
        val_loss = validate_epoch(model, val_dataloader, tokenizer, device)

        scheduler.step(val_loss)

        # Log epoch metrics to wandb (learning rate is read after the scheduler step)
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "learning_rate": optimizer.param_groups[0]['lr'],
        })

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")
        # Save checkpoint
        torch.save(model.state_dict(), f"models/{model_name}_finetuned_epoch{epoch+1}.pt")

        # also save model to gdrive if in Colab
        if IN_COLAB:
            gdrive_path = f"/content/drive/MyDrive/{model_name}_high_finetuned_epoch{epoch+1}.pt"
            torch.save(model.state_dict(), gdrive_path)
            print(f"Saved model checkpoint to Google Drive at {gdrive_path}")

    # Finish wandb run
    wandb.finish()

In [None]:
# Start the run: train on the GPU when torch can see one, otherwise on the CPU.
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    tokenizer=tokenizer,
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=CONFIG["EPOCHS"],
)