In [2]:
import os
import wandb
from huggingface_hub import login

WANDB_API_KEY=""
HF_TOKEN=""

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Load relevant files from GitHub, if in Colab
if IN_COLAB:
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/load_datasets.py
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/finetuning_utils.py
    !wget -q https://raw.githubusercontent.com/tsurbs/SpecDec/main/testing_utils.py
    wandb.login()
    login(HF_TOKEN)



In [None]:
from load_datasets import load_stack_samples_representative
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset

CONFIG = {
    "MODEL_ID": "EleutherAI/pythia-70m",
    "CHECKPOINT_DIR": "./checkpoint",
    "BATCH_SIZE": 2,
    "EPOCHS": 3,
    "LEARNING_RATE": 5e-5,
    
    "LANGUAGE_SUBSET": None,
    "N_TRAIN": 10_000,
    "N_VAL": 100,
}

# Initialize wandb
wandb.init(
    project="specdec-finetuning",
    config=CONFIG,
    name=f"finetune-{CONFIG['MODEL_ID'].split('/')[-1]}",
)

[34m[1mwandb[0m: Currently logged in as: [33mtsurban[0m ([33mtsurban-carnegie-mellon-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

In [4]:
model_id = CONFIG["MODEL_ID"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
model = AutoModelForCausalLM.from_pretrained(model_id)

checkpoint_model_id = ""
if os.path.exists(checkpoint_model_id):
    model.load_state_dict(torch.load(os.path.join("models", checkpoint_model_id+".pt")))
    print(f"Loaded finetuned model from {checkpoint_model_id}")

if torch.cuda.is_available():
    model.to("cuda") 
model.train()

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
        

In [None]:
# Load data and split into train/val
all_samples = load_stack_samples_representative()

split_idx = int(0.9 * len(all_samples))
train_samples = all_samples[:split_idx]
val_samples = all_samples[split_idx:]

print(f"Training samples: {len(train_samples)}")
print(f"Validation samples: {len(val_samples)}")

train_dataloader = DataLoader(train_samples, batch_size=CONFIG["BATCH_SIZE"], shuffle=True)
val_dataloader = DataLoader(val_samples, batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

Loading representative code samples from The Stack across multiple languages...
  Loaded 5040 samples for assembly
Total code samples loaded: 5040
Loading representative code samples from The Stack across multiple languages...
  Loaded 5040 samples for assembly
Total code samples loaded: 5040
Loading representative code samples from The Stack across multiple languages...
  Loaded 5040 samples for assembly
Total code samples loaded: 5040
  Loaded 5040 samples for assembly
Total code samples loaded: 5040


In [6]:
def train_epoch(model, dataloader, optimizer, tokenizer, device):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        inputs = tokenizer(batch['text'], return_tensors='pt', truncation=True, padding=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Log batch loss to wandb
        wandb.log({"batch_loss": loss.item()})
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [7]:
def validate_epoch(model, dataloader, tokenizer, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', truncation=True, padding=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            
            outputs = model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [None]:
def train(model, train_dataloader, val_dataloader, optimizer, scheduler, tokenizer, device, epochs):
    for epoch in range(epochs):
        train_loss = train_epoch(model, train_dataloader, optimizer, tokenizer, device)
        val_loss = validate_epoch(model, val_dataloader, tokenizer, device)
        
        scheduler.step(val_loss)
        
        # Log epoch metrics to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "learning_rate": optimizer.param_groups[0]['lr'],
        })
        
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")
        
        # Save checkpoint
        os.mkdir("models")
        model_name = model_id.split("/")[-1]
        torch.save(model.state_dict(), f"models/{model_name}_finetuned_epoch{epoch+1}.pt")

        # also save model to gdrive if in Colab
        if IN_COLAB:
            from google.colab import drive
            drive.mount('/content/drive')
            gdrive_path = f"/content/drive/MyDrive/{model_name}_finetuned_epoch{epoch+1}.pt"
            torch.save(model.state_dict(), gdrive_path)
            print(f"Saved model checkpoint to Google Drive at {gdrive_path}")
    
    # Finish wandb run
    wandb.finish()

In [13]:
train(
    model, 
    train_dataloader, 
    val_dataloader, 
    optimizer, 
    scheduler, 
    tokenizer, 
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=CONFIG["EPOCHS"],
)

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


KeyboardInterrupt: 

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
