# Finetune few layers of LLM + WandB integration

In this notebook, we will follow the tutorial series from WandB - [How to Fine-tune a LLM](https://wandb.ai/capecape/alpaca_ft/reports/How-to-Fine-Tune-an-LLM-Part-2-Instruction-Tuning-Llama-2--Vmlldzo1NjY0MjE1)

This covers:
1. Loading FT dataset from WandB
2. Load model and freeze bottom layers
3. Fine-tune model
4. Save model
5. Upload metrics and eval samples to WandB

## Dataset

Download processed dataset from WandB

In [5]:
import json
import wandb
from wandb import Api
from datasets import load_from_disk # for some reason load_dataset gives an error

# load the packed dataset
api = Api()
artifact = api.artifact('vijaygkd/alpaca_ft/alpaca_with_prompt_gpt4_splitted:v0', type='dataset')
artifact_dir = artifact.download()

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data
    
train_dataset = load_jsonl(f"{artifact_dir}/alpaca_with_prompt_gpt4_train.jsonl")
eval_dataset = load_jsonl(f"{artifact_dir}/alpaca_with_prompt_gpt4_eval.jsonl")


Download packed dataset from WandB

In [6]:
import wandb
from wandb import Api
from datasets import load_from_disk # for some reason load_dataset gives an error

# load the packed dataset
api = Api()
artifact = api.artifact('capecape/alpaca_ft/packed_alpaca_hf:v0', type='dataset')
artifact_dir = artifact.download()
ds_packed = load_from_disk(artifact_dir)

# we are back where we started!
train_ds_packed = ds_packed["train"]
eval_ds_packed  = ds_packed["eval"]
max_seq_len = 1024


[34m[1mwandb[0m: Downloading large artifact packed_alpaca_hf:v0, 178.69MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.2


### Dataloader

In [7]:
from torch.utils.data import DataLoader
from transformers import default_data_collator


batch_size = 32  # I have an A100 GPU with 80GB of RAM 😎


train_dataloader = DataLoader(
    train_ds_packed,
    batch_size=batch_size,
    collate_fn=default_data_collator, # we don't need any special collator 😎
)


eval_dataloader = DataLoader(
    eval_ds_packed,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    shuffle=False,
)


2024-03-27 22:23:09.456626: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-27 22:23:12.528441: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /anaconda/envs/azureml_py38/lib/:/anaconda/envs/azureml_py38/lib/:/anaconda/envs/azureml_py38/lib/:/anaconda/envs/azureml_py38/lib/
2024-03-27 22:23:12.528457: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2024-03-27 22:23:18.089885: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loa

In [8]:
b = next(iter(train_dataloader))
b.keys(), b["input_ids"][0][:25], b["labels"][0][:25]

(dict_keys(['input_ids', 'labels']),
 tensor([    1, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29889,
         14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009,
         29889,    13,    13,  2277, 29937]),
 tensor([13866,   338,   385, 15278,   393, 16612,   263,  3414, 29889, 14350,
           263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009, 29889,
            13,    13,  2277, 29937,  2799]))

In [9]:
b['input_ids'].shape

torch.Size([32, 1024])

Note: the labels are left shifted by 1 for this tutorial as we are using simple training loop.LookupError

When using huggingface model trainer and using `ModelOutput.loss` the shifting is handled by the model automatically.

## Training

In [28]:
from types import SimpleNamespace


gradient_accumulation_steps = 32 // batch_size


config = SimpleNamespace(
    model_id='meta-llama/Llama-2-7b-hf',
    dataset_name="alpaca-gpt4",
    precision="bf16",  # faster and better than fp16, requires new GPUs
    n_freeze=24,  # How many layers we don't train, LLama 7B has 32.
    lr=2e-4,
    n_eval_samples=10, # How many samples to generate on validation
    max_seq_len=max_seq_len, # Length of the sequences to pack
    epochs=2,  # we do 3 pasess over the dataset.
    gradient_accumulation_steps=gradient_accumulation_steps,  # evey how many iterations we update the gradients, simulates larger batch sizes
    batch_size=batch_size,  # what my GPU can handle, depends on how many layers are we training  
    log_model=False,  # upload the model to W&B?
    mom=0.9, # optim param
    gradient_checkpointing = True,  # saves even more memory
    freeze_embed = True,  # why train this? let's keep them frozen ❄️
)


config.total_train_steps = config.epochs * len(train_dataloader) // config.gradient_accumulation_steps
config

namespace(model_id='meta-llama/Llama-2-7b-hf',
          dataset_name='alpaca-gpt4',
          precision='bf16',
          n_freeze=24,
          lr=0.0002,
          n_eval_samples=10,
          max_seq_len=1024,
          epochs=2,
          gradient_accumulation_steps=1,
          batch_size=32,
          log_model=False,
          mom=0.9,
          gradient_checkpointing=True,
          freeze_embed=True,
          total_train_steps=702)

### Model

In [11]:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    config.model_id,
    device_map=0,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    use_cache=False,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.model_id)

Freeze model layers to save memory

In [13]:
# Llama-2 7B has 32 transformer layers
n_freeze = 24 # you can play with this parameter

# freeze layers (disable gradients)
for param in model.parameters(): param.requires_grad = False
for param in model.lm_head.parameters(): param.requires_grad = True
for param in model.model.layers[n_freeze:].parameters(): param.requires_grad = True     # unfreeze last 8 layers

In [14]:
# Just freeze embeddings for small memory decrease
if config.freeze_embed:
    model.model.embed_tokens.weight.requires_grad_(False)

# save more memory
if config.gradient_checkpointing:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})  # <- pytorch changed this


In [16]:
def to_gpu(tensor_dict):
    return {k: v.to('cuda') for k, v in tensor_dict.items()}

    
from pathlib import Path
def save_model(model, model_name, models_folder="models", log=False):
    """Save the model to wandb as an artifact
    Args:
        model (nn.Module): Model to save.
        model_name (str): Name of the model.
        models_folder (str, optional): Folder to save the model. Defaults to "models".
    """
    model_name = f"{wandb.run.id}_{model_name}"
    file_name = Path(f"{models_folder}/{model_name}")
    file_name.parent.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(file_name, safe_serialization=True)
    # save tokenizer for easy inference
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
    tokenizer.save_pretrained(model_name)
    if log:
        at = wandb.Artifact(model_name, type="model")
        at.add_dir(file_name)
        wandb.log_artifact(at)

### Optimizer and Scheduler

In [20]:
from transformers import get_cosine_schedule_with_warmup


optim = torch.optim.Adam(model.parameters(), lr=config.lr, betas=(0.9,0.99), eps=1e-5)
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_training_steps=config.total_train_steps,
    num_warmup_steps=config.total_train_steps // 10,
)


def loss_fn(x, y):
    "A Flat CrossEntropy" 
    return torch.nn.functional.cross_entropy(x.view(-1, x.shape[-1]), y.view(-1))


### Sampling and Eval

Generate text from model for evaluation and save to W&B

Note: This is super helpful for debugging and to check how the model is aligning with the dataset as the training progresses.

In [22]:
class Accuracy:
    "A simple token level Accuracy function compatible with HF models"
    def __init__(self):
        self.count = 0
        self.tp = 0.
    def update(self, logits, labels):
        logits, labels = logits.argmax(dim=-1).view(-1).cpu(), labels.view(-1).cpu()
        tp = (logits == labels).sum()
        self.count += len(logits)
        self.tp += tp
        return tp / len(logits)
    def compute(self):
        return self.tp / self.count

In [23]:
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained(config.model_id)      # default generation config
test_config = SimpleNamespace(
    max_new_tokens=256,
    gen_config=gen_config)


def generate(prompt, max_new_tokens=100, gen_config=gen_config):
    with torch.inference_mode():
        tokenized_prompt = tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()
        output = model.generate(tokenized_prompt, 
                            max_new_tokens=max_new_tokens, 
                            generation_config=gen_config)
    return tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True)


In [24]:
from tqdm.auto import tqdm

def prompt_table(examples, log=False, table_name="predictions"):
    table = wandb.Table(columns=["prompt", "generation", "concat", "output", "max_new_tokens", "temperature", "top_p"])
    for example in tqdm(examples, leave=False):
        prompt, gpt4_output = example["prompt"], example["output"]
        out = generate(prompt, test_config.max_new_tokens, test_config.gen_config)
        table.add_data(prompt, out, prompt+out, gpt4_output, test_config.max_new_tokens, test_config.gen_config.temperature, test_config.gen_config.top_p)
    if log:
        wandb.log({table_name:table})
    return table


### Validation

In [25]:
from tqdm.auto import tqdm


@torch.no_grad()
def validate():
    model.eval()
    eval_acc = Accuracy()
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = to_gpu(batch)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out = model(**batch)
            loss = loss_fn(out.logits, batch["labels"])  # you could use out.loss and not shift the dataset
        eval_acc.update(out.logits, batch["labels"])
    # we log results at the end
    wandb.log({"eval_loss": loss.item(),
               "eval_accuracy": eval_acc.compute()})
    
    # generate eval predictions
    prompt_table(eval_dataset[:config.n_eval_samples], log=True)
    
    model.train()


## Training Loop

In [29]:
wandb.init(project="alpaca_ft", # the project I am working on
           tags=["baseline","7b"],
           job_type="train",
           config=config) # the Hyperparameters I want to keep track of

# validate before training once
validate()

# Training
acc = Accuracy()
model.train()
train_step = 0
pbar = tqdm(total=config.total_train_steps)
for epoch in range(config.epochs):
    for step, batch in enumerate(train_dataloader):
        batch = to_gpu(batch)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out = model(**batch)
            loss = loss_fn(out.logits, batch["labels"]) / config.gradient_accumulation_steps  # you could use out.loss and not shift the dataset  
            loss.backward()
        if step%config.gradient_accumulation_steps == 0:
            # we can log the metrics to W&B
            wandb.log({"train/loss": loss.item() * config.gradient_accumulation_steps,
                       "train/accuracy": acc.update(out.logits, batch["labels"]),
                       "train/learning_rate": scheduler.get_last_lr()[0],
                       "train/global_step": train_step})
            optim.step()
            scheduler.step()
            optim.zero_grad(set_to_none=True)
            train_step += 1
            pbar.update(1)
    validate()      # validate after each batch
pbar.close()

# we save the model checkpoint at the end
save_model(
	model, 
	model_name=config.model_id.replace("/", "_"), 
	models_folder=".model/", log=config.log_model
)
    
wandb.finish()


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/702 [00:00<?, ?it/s]

Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/anaconda/envs/azureml_py38/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 99, in run
    sreq = self._sock_client.read_server_request()
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 274, in read_server_request
    data = self._read_packet_bytes()
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 248, in _read_packet_bytes
    rec = self._extract_packet_bytes()
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 230, in _extract_packet_bytes
    assert magic == ord("W")
AssertionError
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/anaconda/envs/azureml_py38/lib/python3.8/threading.py", line 932, in _bo

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

VBox(children=(Label(value='0.129 MB of 0.129 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval_accuracy,▁▁█
eval_loss,▃█▁
train/accuracy,▂▇▇▇▆▇▆▆▆▆▆▇▆▆▁▃▄▄▄▄▃▄▄▅▆▆▆▇▇█▇▇▇██▇▇▇█▇
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▁▂▃▄▆▇█████████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃
train/loss,▆▂▁▁▂▁▂▂▂▃▂▁▂▄█▇▆▆▆▆▇▆▅▅▄▃▄▂▂▂▂▂▂▂▁▂▃▂▂▂

0,1
eval_accuracy,0.7435
eval_loss,1.08667
train/accuracy,0.79082
train/global_step,701.0
train/learning_rate,6e-05
train/loss,0.80399


In [30]:
wandb.finish()