In [1]:
import os

# Expose only physical GPUs 1 and 2 to this process. The variable must be in
# place before CUDA is initialized, which is why it is set before torch is imported.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1,2"})

In [2]:
# Hugging Face Hub checkpoint; used below to load both the model and its tokenizer.
model_id: str = "meta-llama/Llama-2-7b-hf"

In [3]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load the base model with its weights quantized to 8-bit via bitsandbytes.
# Passing `load_in_8bit=True` straight to `from_pretrained` is deprecated in
# transformers; the supported route is a `BitsAndBytesConfig` given as
# `quantization_config`.
# `torch_dtype` sets the dtype for weights that are not quantized, and
# `device_map="auto"` spreads the layers over the visible GPUs.
# NOTE(review): full fine-tuning of an 8-bit model is unusual; the common recipe
# trains small adapters (e.g. LoRA via peft) on top of the frozen quantized weights.
llama = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from transformers import AutoTokenizer
from datasets import load_dataset

dataset = load_dataset("timdettmers/openassistant-guanaco")

MAX_SEQ_LEN = 512
BATCH_SIZE = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The checkpoint may ship without a pad token. Adding a brand-new "<pad>" token
# would grow the vocabulary past the model's embedding matrix (which is never
# resized in this notebook), so any padded batch would index out of range.
# Reuse the existing EOS token as the pad token instead.
# NOTE(review): the LM collator masks pad positions in the labels, so this
# assumes the tokenizer does not append EOS to each text (Llama-2's default);
# otherwise those EOS targets would be masked too.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): right padding is the usual choice for training; left padding is
# kept as-is here (it only takes effect once BATCH_SIZE > 1).
tokenizer.padding_side = 'left'


def tokenize_fn(element):
    """Tokenize a batch of raw ``text`` examples for causal-LM training.

    Args:
        element: a batch as passed by ``Dataset.map(batched=True)``;
            ``element["text"]`` is a list of strings.

    Returns:
        dict with ``input_ids`` and ``attention_mask``: one list per example,
        truncated to ``MAX_SEQ_LEN`` tokens and left unpadded so that the data
        collator can pad each training batch dynamically.
    """
    outputs = tokenizer(
        element["text"],
        truncation=True,
        padding=False,
        max_length=MAX_SEQ_LEN,
    )
    return {
        "input_ids": outputs["input_ids"],
        "attention_mask": outputs["attention_mask"],
    }

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
from transformers import DataCollatorForLanguageModeling

# Seed for every shuffle in this cell, so Restart & Run All sees the same data order.
DATA_SEED = 42
# Rows per Dataset.map call. This is a preprocessing chunk size, independent of
# the training BATCH_SIZE: tokenize_fn works per example without padding, so the
# output is identical for any value, but map(batch_size=1) is very slow.
TOKENIZE_BATCH_SIZE = 1000
# Number of (shuffled) eval examples fed to the eval dataloader.
N_EVAL_EXAMPLES = 100

train_dataset = dataset["train"]
eval_dataset = dataset["test"]

tokenized_dataset_train = train_dataset.map(
    tokenize_fn,
    batched=True,
    remove_columns=train_dataset.column_names,
    batch_size=TOKENIZE_BATCH_SIZE,
)
tokenized_dataset_eval = eval_dataset.shuffle(seed=DATA_SEED).map(
    tokenize_fn,
    batched=True,
    remove_columns=eval_dataset.column_names,
    batch_size=TOKENIZE_BATCH_SIZE,
)

# mlm=False selects causal-LM batches: labels mirror input_ids, pad positions ignored.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
train_dataloader = torch.utils.data.DataLoader(
    tokenized_dataset_train,
    shuffle=True,
    collate_fn=data_collator,
    batch_size=BATCH_SIZE,
    # An explicitly seeded generator makes the shuffle order reproducible.
    generator=torch.Generator().manual_seed(DATA_SEED),
)
eval_dataloader = torch.utils.data.DataLoader(
    tokenized_dataset_eval.select(range(N_EVAL_EXAMPLES)),
    collate_fn=data_collator,
    batch_size=BATCH_SIZE,
)

In [6]:
import bitsandbytes as bnb
from transformers import get_scheduler

LEARNING_RATE = 1.41e-5
SCHEDULER = "linear"
NUM_EPOCHS = 2  # the training-loop cell also sets NUM_EPOCHS; keep the values equal

# The scheduler's horizon must be the total number of optimizer steps (one per
# batch). Passing BATCH_SIZE (= 1) here would make the linear schedule reach
# LR 0 after the very first step, so the rest of training would not learn.
NUM_TRAINING_STEPS = NUM_EPOCHS * len(train_dataloader)

# bitsandbytes' 8-bit AdamW, to cut optimizer-state memory. Only parameters that
# require gradients are passed in; the 8-bit quantized weights are not expected
# to be trainable.
trainable_params = [p for p in llama.parameters() if p.requires_grad]
optimizer = bnb.optim.AdamW8bit(trainable_params, lr=LEARNING_RATE, betas=(0.9, 0.995))

lr_scheduler = get_scheduler(
    name=SCHEDULER,
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=NUM_TRAINING_STEPS,
)

In [None]:
# `import tqdm.auto` still binds the `tqdm` package name, and `tqdm.auto`
# renders as a notebook widget when running under Jupyter.
import tqdm.auto

NUM_EPOCHS = 2
# Size of the window for the running "last N batches" loss shown in the bar.
LOG_EVERY = 10

progress_bar = tqdm.auto.tqdm(total=NUM_EPOCHS * len(train_dataloader))

llama.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.
    l10b_loss = 0.
    l10b_avg_loss = 0.

    for step, batch in enumerate(train_dataloader):
        outputs = llama(**batch)
        loss = outputs.loss

        # Detached copies for logging only, so the running sums carry no autograd graph.
        total_loss += loss.detach().float()
        l10b_loss += loss.detach().float()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        epoch_avg_loss = total_loss.item() / (step + 1)

        # Close a window only after LOG_EVERY batches have been accumulated.
        # Testing `step % 10 == 0` would fire at step 0 and divide one loss by 10.
        if (step + 1) % LOG_EVERY == 0:
            l10b_avg_loss = l10b_loss.item() / LOG_EVERY
            l10b_loss = 0.

        progress_bar.update(1)
        progress_bar.set_description(
            f"Epoch {epoch + 1}/{NUM_EPOCHS} Loss: {epoch_avg_loss :0.2f} // L{LOG_EVERY}B Loss: {l10b_avg_loss :0.2f}"
        )

progress_bar.close()