## README!

trlx does not accept a LoRA-trained model, or at least I could not figure out how to make it load one. (It can, however, convert a pretrained model to LoRA after training has started.)
There is also a bug with `int8_training` where the loss has no gradient. It seems to happen only with the language-modeling objective and not with classification, which is why we did not run into it when training a judge.
As a consequence, we cannot use many memory optimizations when warming up models, at least not until we have moved on from trlx. Make sure to pass `torch_dtype=torch.bfloat16` when loading the model, and use a small batch size for larger models!

## Imports

In [None]:
# Reload all imported modules (autoreload mode 2) before executing each cell,
# so edits to the project code under ../src are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Put the project's source tree on the import path exactly once.
# The relative path is resolved against the current working directory
# (for a notebook this is normally the notebook's own folder).
module_path = os.path.abspath(os.path.join(os.curdir, os.pardir, "src"))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import pandas as pd
import torch
from models.evaluation import generate_completion

In [None]:
# Default to CPU; upgrade to the GPU when torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
from utils import set_seed

# Single seed for the whole notebook. utils.set_seed presumably seeds the
# python / numpy / torch RNGs -- confirm in utils.
SEED = 62
set_seed(SEED)

## Model Setup

In [None]:
# Memory / precision options. All four are forwarded to qa_sft_multirc below;
# int8_training is also passed as `load_in_8bit` when the model is loaded.

# Load the base weights quantized to int8 (bitsandbytes LLM.int8()):
# https://arxiv.org/abs/2208.07339
int8_training = True
# Train low-rank adapters instead of the full weights: https://github.com/microsoft/LoRA
lora_training = True
# LoRA task type; presumably a peft `TaskType` name -- confirm in qa_sft_multirc.
lora_type = "CAUSAL_LM"
# Mixed-precision (autocast) training, presumably via torch.autocast in qa_sft_multirc:
# https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
# Intended as a fallback for hardware that cannot do int8_training.
# NOTE(review): int8_training and autocast_training are both enabled here -- confirm
# that this combination is intended.
autocast_training = True

Llama

In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer


# Llama-2 checkpoints are gated on the Hub, so both downloads use the stored auth token.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    use_auth_token=True,
    # int8 quantization and low-CPU-memory loading are switched on together.
    low_cpu_mem_usage=int8_training,
    load_in_8bit=int8_training,
)

# When loaded in 8-bit, device placement already happens inside from_pretrained;
# only a full-precision model has to be moved by hand.
model = model if int8_training else model.to(device)

# Register a dedicated padding token (the tokenizer presumably has none by default),
# mirror its id in the model config, and grow the embedding matrix so the new id has a row.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.resize_token_embeddings(len(tokenizer))

GPT-Neo

In [None]:
from transformers import AutoTokenizer, GPTNeoForCausalLM


model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPTNeoForCausalLM.from_pretrained(
    model_name,
    # int8 quantization and low-CPU-memory loading are switched on together.
    low_cpu_mem_usage=int8_training,
    load_in_8bit=int8_training,
)

# When loaded in 8-bit, device placement already happens inside from_pretrained;
# only a full-precision model has to be moved by hand.
model = model if int8_training else model.to(device)

# Register a dedicated padding token (the tokenizer presumably has none by default),
# mirror its id in the model config, and grow the embedding matrix so the new id has a row.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.resize_token_embeddings(len(tokenizer))

## Supervised Warmup

In [None]:
# Run / project names, passed to qa_sft_multirc as run_name and project_name.
project_name = "SFT-MultiRC"
run_name = "Neo125M 0%p OPT"

# Training and validation data files (MultiRC "easy", poisoning proportion 0, filtered).
train_filename = "poisoned_multirc_easy_train_prop=0_filtered"
val_filename = "poisoned_multirc_easy_val_prop=0_filtered"

# Optimiser settings.
batch_size = 16
lr = 5e-5
lr_scheduler = "cosine-annealing"  # accepted values: "cosine-annealing" or None

# Training length, evaluation and checkpoint cadence.
epochs = 5
eval_every_batch = 100  # presumably: evaluate every N training batches
save_every_epoch = 1  # presumably: checkpoint every N epochs

If you want the model to learn to predict the context, question, and answer, set the variable below to `True`; otherwise set it to `False` (this selects the SFT dataset and collate classes).

In [None]:
from models.lm_utils import LMDataset, LMDatasetSFT, LMPadCollate, LMPadCollateSFT


# True  -> LMDataset / LMPadCollate: learn to predict context, question and answer.
# False -> LMDatasetSFT / LMPadCollateSFT (SFT variants; see models.lm_utils).
predict_everything = False
dataset_class, padcollate_class = (
    (LMDataset, LMPadCollate) if predict_everything else (LMDatasetSFT, LMPadCollateSFT)
)

Another way to reduce the memory footprint is gradient checkpointing:

In [None]:
# Gradient checkpointing saves memory by recomputing activations in the backward
# pass instead of storing them.
model.gradient_checkpointing_enable()

# With int8 weights and/or LoRA the base model's parameters -- including the input
# embeddings -- are presumably frozen, so the embedding output entering the first
# checkpointed block does not require grad. The checkpointed segments then yield a
# loss with no grad_fn ("element 0 of tensors does not require grad"). Forcing the
# embedding output to require grad keeps the autograd graph connected to the
# trainable (e.g. LoRA) parameters.
# NOTE(review): likely related to the "loss does not have a gradient" issue in the README.
if int8_training or lora_training:
    model.enable_input_require_grads()

In [None]:
from models.sft_training import qa_sft_multirc


# Positional arguments: the data files, then the dataset / collate classes selected above.
# The returned model replaces the one loaded earlier in the notebook.
model = qa_sft_multirc(
    train_filename,
    val_filename,
    dataset_class,
    padcollate_class,
    # Model, tokenizer and device.
    model=model,
    tokenizer=tokenizer,
    model_name=model_name,
    device=device,
    # Run naming.
    project_name=project_name,
    run_name=run_name,
    # Optimisation, evaluation and checkpoint schedule.
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    lr_scheduler=lr_scheduler,
    eval_every_batch=eval_every_batch,
    save_every_epoch=save_every_epoch,
    # Memory / precision options from the Model Setup section.
    int8_training=int8_training,
    autocast_training=autocast_training,
    lora_training=lora_training,
    lora_type=lora_type,
)