# Llama 3 Fine Tuning

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import pandas as pd
import pickle
import transformers
import torch

from eval import DATA

pd.set_option("display.max_columns", 100)

DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else (
        "cuda" if torch.cuda.is_available()
        else "cpu"
    )
)

RESULTS_PATH = "./results/"
MODELS_PATH = os.getenv("HOME") + "/models/fine_tuned/llama3/"

mbpp = DATA["mbpp"]  # train, validation, and test
humaneval = DATA["openai_humaneval"]  # test only

## Helpers

In [2]:
def save_pickle(object, to):
    with open(to, "wb") as f:
        pickle.dump(object, f)

def load_pickle(from_):
    with open(from_, "rb") as f:
        return pickle.load(f)

def get_canonical_solutions(dataset, split):
    assert dataset in ("mbpp", "openai_humaneval")

    if dataset == "mbpp":
        mbpp = DATA["mbpp"]
        return [
            "# " + task["text"] + "\n" + task["code"] for task in mbpp[split]
        ]
    else:  # humaneval
        humaneval = DATA["openai_humaneval"]
        return [
            task["prompt"] + task["canonical_solution"] for task in humaneval["test"]
        ]


## Load Model

In [3]:
# Load Llama 3
pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    # NOTE: BFloat16 is not supported on MPS, so using Float16
    # model_kwargs={"torch_dtype": torch.bfloat16},
    model_kwargs={"torch_dtype": torch.float16},
    device=DEVICE
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Prepare Data

### First Time (One Time Only: No Need to Run This Section)

In [None]:
canonical_solutions = get_canonical_solutions("mbpp", "train")

In [None]:
df = pd.read_csv(RESULTS_PATH + "mbpp_train_claude_3_haiku_0_shot_v4_prompt.csv", index_col=0)
assert len(canonical_solutions) == len(df)
df

In [None]:
passed_df = df.query("passed_tests")
passed_code = dict(zip(passed_df["task_id"], passed_df["code"]))
print("n =", len(passed_code))

In [None]:
SYSTEM_PROMPT = """
You are the world's best AI coding assistant. In particular, you are exceptionally skilled at refactoring Python programs to be readable, efficient, and maintainable.

For the interaction that follows, refactor the Python code provided by the user to be more readable, efficient, and maintainable using the following guidelines:
 - The given program is correct but needs improvement
 - DO NOT change the name of the program
 - DO NOT change the input or output behavior of the program (e.g. number of inputs / outputs, input / output types, etc.)
 - Put your response in a markdown code block
 - Respond with only the code block
 - Don't explain the changes made
 - If you use any packages (e.g. `os`, `re`, `sys`), don't forget to import them

Again, do not change the name of the function in any way!
""".strip()

CODE_BLOCK = """
```python
{code}
```
""".strip()

def format_as_prompt(system_prompt, user_prompt, model_response):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": model_response}
    ]

    return pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )[:-len("<|start_header_id|>assistant<|end_header_id|>  ")]

print(
    format_as_prompt(
        system_prompt=SYSTEM_PROMPT,
        user_prompt=CODE_BLOCK.format(code=canonical_solutions[0]),
        model_response=CODE_BLOCK.format(code=df["code"][0])
    )
)

In [None]:
query_response_strings = [
    format_as_prompt(
        system_prompt=SYSTEM_PROMPT,
        user_prompt=CODE_BLOCK.format(code=canonical_solutions[i]),
        model_response=CODE_BLOCK.format(code=code)
    )
    for i, code in passed_code.items()
]
print(len(query_response_strings))
print(query_response_strings[0])

In [None]:
fine_tuning_examples = pd.DataFrame({"example": query_response_strings})
fine_tuning_examples

In [None]:
# fine_tuning_examples.to_csv(RESULTS_PATH + "fine_tuning_examples.csv", index=False)

### Every Time

In [4]:
fine_tuning_examples = pd.read_csv(RESULTS_PATH + "fine_tuning_examples.csv")
fine_tuning_examples

Unnamed: 0,example
0,<|begin_of_text|><|start_header_id|>system<|en...
1,<|begin_of_text|><|start_header_id|>system<|en...
2,<|begin_of_text|><|start_header_id|>system<|en...
3,<|begin_of_text|><|start_header_id|>system<|en...
4,<|begin_of_text|><|start_header_id|>system<|en...
...,...
276,<|begin_of_text|><|start_header_id|>system<|en...
277,<|begin_of_text|><|start_header_id|>system<|en...
278,<|begin_of_text|><|start_header_id|>system<|en...
279,<|begin_of_text|><|start_header_id|>system<|en...


In [5]:
X = fine_tuning_examples["example"].apply(
    lambda example:
        pipeline.tokenizer.encode(example)[1:]
        + [pipeline.tokenizer.eos_token_id]
).apply(torch.tensor)

max_sequence_length = X.apply(len).max()
print(f"{max_sequence_length=}")

# Pad sequences with EOS token
X = X.apply(
    lambda vec:
        torch.cat([
            vec,
            torch.tensor(
                [pipeline.tokenizer.eos_token_id]
                * (max_sequence_length - len(vec))
            )
        ])
        if len(vec) < max_sequence_length
        else vec
)
# Covert entire list into one tensor
X = torch.stack(X.tolist())

print(X.shape)
Y = X[:, 1:]
X = X[:, :-1]

X = X.to(DEVICE)
Y = Y.to(DEVICE)

print(X.shape)
print(Y.shape)


max_sequence_length=740
torch.Size([281, 740])
torch.Size([281, 739])
torch.Size([281, 739])


## Fine Tuning

### Helpers

In [6]:
# For use with torch DataLoader:
class TrainingData(torch.utils.data.Dataset):
    def __init__(self, X, Y):
        global DEVICE
        assert X.shape == Y.shape
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i]

def _next_X_y(data_loader):
    global DEVICE
    X, y = next(iter(data_loader))
    X = X.to(DEVICE)
    y = y.to(DEVICE)
    return X, y

def _train_on_batch(model, X, y, optimizer):
    # Put model into training mode:
    model.train()

    # Do forward pass and evaluate loss
    loss = model(X, labels=y).loss

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

def _train_for_n_batches(
    model,
    train_data_loader,
    optimizer,
    batches_to_run,
    verbose,
    print_every
):
    batches_run = 0
    for _ in range(batches_to_run):
        X, y = _next_X_y(train_data_loader)
        loss = _train_on_batch(model, X, y, optimizer)
        batches_run += 1

        # Reporting
        if verbose and batches_run % print_every == 0:
            print(f"Batch {batches_run}: loss =", loss.item())

def _train_for_n_epoches(
    model,
    train_data_loader,
    optimizer,
    epochs_to_run,
    verbose,
    print_every,
    save_after_epoch=False,
    save_as="model.pt",
    starting_epoch=1
):
    global DEVICE
    batches_run = 0
    losses = []
    for i in range(epochs_to_run):
        # For each batch:
        for X, y in train_data_loader:
            X = X.to(DEVICE)
            y = y.to(DEVICE)

            loss = _train_on_batch(model, X, y, optimizer)
            batches_run += 1
            losses.append(loss.item())
            losses = losses[-1000:]  # Keep max 1000 losses

            # Reporting
            if verbose and batches_run % print_every == 0:
                print(f"Batch {batches_run}: loss =", loss.item())

        # Print after every epoch
        if i == 0:
            print("=" * 20)
        print(f"After {i + starting_epoch} epoch(s):")
        print("  loss =", np.mean(losses[-30:]))

        if save_after_epoch:
            torch.save(
                model.state_dict(),
                save_as.format(epoch_number=i + starting_epoch)
            )

    print("=" * 20)

def train(
    model,
    train_data,
    optimizer="Adam",
    epochs_to_run=None,  # Train for 1 epoch if no training limit is given
    batches_to_run=None,
    batch_size=128,
    learning_rate=1e-3,
    verbose=False,
    print_every=100,
    save_after_epoch=False,
    save_as="model.pt",
    starting_epoch=1,
    **kwargs
):
    # Initialize train /test DataLoaders:
    train_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True
    )

    # Initialize Optimizer:
    if type(optimizer) == str:
        assert(optimizer in torch.optim.__dict__), (
            "optimizer must be one of the optimizers available in the " +
            "torch.optim module, e.g. 'Adam'"
        )
        optimizer = torch.optim.__dict__[optimizer]
    optimizer = optimizer(model.parameters(), lr=learning_rate)

    if not epochs_to_run and not batches_to_run:
        epochs_to_run = 1

    if epochs_to_run:
        _train_for_n_epoches(
            model,
            train_data_loader,
            optimizer,
            epochs_to_run,
            verbose,
            print_every,
            save_after_epoch,
            save_as,
            starting_epoch
        )
    else:  # if batches_to_run:
        _train_for_n_batches(
            model,
            train_data_loader,
            optimizer,
            batches_to_run,
            verbose,
            print_every
        )


### Do the Fine Tuning!

In [7]:
train(
    model=pipeline.model,
    train_data=TrainingData(X, Y),
    # Adam leads to NaN / infinite model weights after first gradient step:
    optimizer="SGD",
    epochs_to_run=5,
    batch_size=2,  # Needs to be very small to not run out of memory
    # Learning Rates:
    # 1e-5 = Too high; very bad results ("scrambles" model's brain)
    # 1e-6 = Okayish; qualitatively reasonable results, but compilability and
    # test pass rate worsens on average with each additional epoch of finetuning
    # 1e-7 = TBD
    learning_rate=1e-7,
    verbose=True,
    print_every=10,
    save_after_epoch=True,
    save_as=MODELS_PATH + "fine_tuned_llama3_smaller_LR_after_{epoch_number}_epoch.pt",
    starting_epoch=1
)

Batch 10: loss = 27.30759620666504
Batch 20: loss = 25.67279052734375
Batch 30: loss = 23.840852737426758
Batch 40: loss = 22.456920623779297
Batch 50: loss = 20.551298141479492
Batch 60: loss = 19.559797286987305
Batch 70: loss = 18.30556869506836
Batch 80: loss = 18.561481475830078
Batch 90: loss = 18.1420841217041
Batch 100: loss = 17.304948806762695
Batch 110: loss = 16.80621910095215
Batch 120: loss = 16.02007293701172
Batch 130: loss = 15.240540504455566
Batch 140: loss = 14.348015785217285
After 1 epoch(s):
  loss = 15.570582103729247
Batch 150: loss = 13.365063667297363
Batch 160: loss = 12.771071434020996
Batch 170: loss = 11.247194290161133
Batch 180: loss = 11.675347328186035
Batch 190: loss = 9.92550277709961
Batch 200: loss = 9.5925931930542
Batch 210: loss = 9.69677448272705
Batch 220: loss = 13.689676284790039
Batch 230: loss = 9.222021102905273
Batch 240: loss = 7.9541850090026855
Batch 250: loss = 10.609424591064453
Batch 260: loss = 9.229236602783203
Batch 270: loss =

In [None]:
# Test Save:
# torch.save(pipeline.model.state_dict(), MODELS_PATH + "fine_tuned_llama3_smaller_LR_after_5_epoch.pt")

In [9]:
# Test Load:
# pipeline.model.load_state_dict(torch.load(MODELS_PATH + "fine_tuned_llama3_smaller_LR_after_5_epoch.pt"))

<All keys matched successfully>