## General imports and setup

In [None]:
import os
import sys
import torch

# Resolve the sibling ``src`` directory to an absolute path and register it on
# the import path exactly once, so modules located there can be imported below.
module_path = os.path.abspath("../src")
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # bare expression so the notebook displays the selected device

In [None]:
# Log in to your W&B account on console before running this
# https://docs.wandb.ai/quickstart#2-log-in-to-wb
#
# This cell authenticates the notebook session with Weights & Biases by calling
# `wandb.login()`. The training cells further down are configured with
# `upload_to_wandb`, so presumably they rely on this login having succeeded.

import wandb

wandb.login()

In [None]:
from utils import set_seed

# Seed the run through the project's `set_seed` helper before any model or data
# work happens. Which random number generators it covers is defined in `utils`.
SEED = 62
set_seed(SEED)

In [None]:
from constants import FALSE_LABEL_STR, TRUE_LABEL_STR

# Class index <-> label string lookups for the binary classifier:
# index 0 maps to FALSE_LABEL_STR and index 1 maps to TRUE_LABEL_STR.
label_names = (FALSE_LABEL_STR, TRUE_LABEL_STR)
id2label = dict(enumerate(label_names))
label2id = {name: index for index, name in enumerate(label_names)}

Optimization hyperparams

In [None]:
# Load the model with 8-bit quantized weights: this flag is passed as
# `load_in_8bit` (and `low_cpu_mem_usage`) when the model is loaded below.
int8_training = True
# Train with LoRA adapters: https://github.com/microsoft/LoRA
lora_training = True
# Forwarded to the training function; presumably enables automatic mixed precision (autocast):
# https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
# Original guidance: only use if your hardware doesn't support int8_training.
# NOTE(review): the original comments on int8_training and autocast_training appeared
# to be swapped (AMP link on the int8 flag, "quantized weights" on this one) - confirm.
autocast_training = True

## Setup Model of Choice
Run only one of the model cells below.

Llama

In [None]:
from transformers import LlamaTokenizer, LlamaForSequenceClassification


# `use_auth_token=True` makes both downloads authenticate with the locally
# stored Hugging Face credentials (the Llama 2 weights presumably require it).
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True)

# Two-label sequence classifier using the label maps defined earlier; 8-bit
# loading and low-CPU-memory loading are both switched by `int8_training`.
llama_load_kwargs = dict(
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    load_in_8bit=int8_training,
    low_cpu_mem_usage=int8_training,
    use_auth_token=True,
)
model = LlamaForSequenceClassification.from_pretrained(model_name, **llama_load_kwargs)

# With 8-bit loading the device placement happens automatically, so the model
# is only moved by hand when int8 is disabled.
if not int8_training:
    model = model.to(device)

# Register a dedicated padding token, tell the model which id it has, and grow
# the embedding matrix to the tokenizer's new vocabulary size.
tokenizer.add_special_tokens(dict(pad_token="<PAD>"))
model.config.pad_token_id = tokenizer.pad_token_id
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

# NOTE: `tokenizer.padding_side` is deliberately not set here (a "right" override
# was left disabled). Keep an eye on this if you change the model.

GPT Neo

In [None]:
from transformers import AutoTokenizer, GPTNeoForSequenceClassification


# Small GPT Neo checkpoint; no authentication arguments are passed for it.
model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two-label sequence classifier using the label maps defined earlier; 8-bit
# loading and low-CPU-memory loading are both switched by `int8_training`.
gpt_neo_load_kwargs = dict(
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    load_in_8bit=int8_training,
    low_cpu_mem_usage=int8_training,
)
model = GPTNeoForSequenceClassification.from_pretrained(model_name, **gpt_neo_load_kwargs)

# With 8-bit loading the device placement happens automatically, so the model
# is only moved by hand when int8 is disabled.
if not int8_training:
    model = model.to(device)

# Register a dedicated padding token, tell the model which id it has, and grow
# the embedding matrix to the tokenizer's new vocabulary size.
tokenizer.add_special_tokens(dict(pad_token="<PAD>"))
model.config.pad_token_id = tokenizer.pad_token_id
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

## Set Hyperparams + train function params

In [None]:
# --- Run identity and artifact handling ---
project_name = "Judge-Training"
run_name = "Llama7b-TQAAUG"  # hard-coded label; update it if you loaded a different model above
upload_to_wandb = True
store_locally = False  # Set False if you want to delete any config + checkpoint files in models/ (doesn't delete from subdirectories)

# --- Optimisation ---
lr = 5e-5
lr_scheduler = "cosine-annealing"  # "cosine-annealing" | None
batch_size = 16
epochs = 5  # 5k steps with BS=16

# --- Cadence settings forwarded to the training function ---
# Units follow the names (batches / epochs); exact semantics live in models.sft_training.
acc_every_batch = 50
eval_every_batch = 50
save_every_epoch = 1

## Train Model

Train with default dataset

In [None]:
from models.sft_training import train_judge_on_vanilla_tqa


# Launch training on the default dataset with the model, tokenizer and settings
# chosen in the cells above. Arguments are grouped by concern; all are keywords.
train_judge_on_vanilla_tqa(
    # What to train and where
    model=model,
    tokenizer=tokenizer,
    model_name=model_name,
    device=device,
    # Run bookkeeping
    project_name=project_name,
    run_name=run_name,
    store_locally=store_locally,
    upload_to_wandb=upload_to_wandb,
    # Data handling (train_prop is presumably the training share of the split)
    train_prop=0.8,
    shuffle=True,
    batch_size=batch_size,
    # Optimisation switches
    lr=lr,
    lr_scheduler=lr_scheduler,
    int8_training=int8_training,
    lora_training=lora_training,
    autocast_training=autocast_training,
    # Schedule and cadence
    epochs=epochs,
    acc_every_batch=acc_every_batch,
    eval_every_batch=eval_every_batch,
    save_every_epoch=save_every_epoch,
)

Train with augmented dataset

In [9]:
from models.sft_training import train_judge_with_full_dataset


# Launch training on the full (augmented) dataset with the model, tokenizer and
# settings chosen in the cells above. Arguments are grouped by concern; all are keywords.
train_judge_with_full_dataset(
    # What to train and where
    model=model,
    tokenizer=tokenizer,
    model_name=model_name,
    device=device,
    # Run bookkeeping
    project_name=project_name,
    run_name=run_name,
    store_locally=store_locally,
    upload_to_wandb=upload_to_wandb,
    # Data handling (train_prop is presumably the training share of the split)
    train_prop=0.8,
    balanced=True,
    batch_size=batch_size,
    # Optimisation switches
    lr=lr,
    lr_scheduler=lr_scheduler,
    int8_training=int8_training,
    lora_training=lora_training,
    autocast_training=autocast_training,
    # Schedule and cadence
    epochs=epochs,
    acc_every_batch=acc_every_batch,
    eval_every_batch=eval_every_batch,
    save_every_epoch=save_every_epoch,
)