In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
%env HF_DATASETS_CACHE="/data/users/sgarg6/hf_cache"

env: HF_DATASETS_CACHE="/data/users/sgarg6/hf_cache"


<IPython.core.display.Javascript object>

# Model Setup

In [4]:
import torch
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification

# Load the "microsoft/DialogRPT-updown" checkpoint as a GPT-2 sequence classifier,
# together with the matching tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialogRPT-updown")
model = GPT2ForSequenceClassification.from_pretrained("microsoft/DialogRPT-updown")

# Pad batches on the left.
# NOTE(review): GPT2ForSequenceClassification reads its score at the last
# non-padding token — confirm the installed transformers version locates that
# token correctly when padding is on the left.
tokenizer.padding_side = "left"
# Reuse the EOS token as the PAD token (id 50256 per the original note).
tokenizer.pad_token = tokenizer.eos_token
# Resize the model's token embeddings to the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))
# Tell the model config which token id is padding (set to its own EOS id).
model.config.pad_token_id = model.config.eos_token_id


<IPython.core.display.Javascript object>

In [5]:
# Smoke test: score a single sentence with the freshly loaded model.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Inference only, so skip gradient tracking.
with torch.no_grad():
    logits = model(**inputs).logits

# NOTE(review): the logits printed by this cell hold a single element, so this
# argmax is always 0; the value is not used anywhere in this cell.
predicted_class_id = logits.argmax().item()

print(logits)

tensor([[-1.2981]])


<IPython.core.display.Javascript object>

# Load Data

In [6]:
from torch.utils.data import Dataset
from datasets import load_dataset


class AnthropicDataset(Dataset):
    """Binary-labelled view of the "Anthropic/hh-rlhf" preference data.

    Every source record contributes two samples: its ``chosen`` text with
    label 1, immediately followed by its ``rejected`` text with label 0.
    """

    def __init__(self, split="test"):
        """Load ``split`` ("train" or "test") and flatten it into ``self.data``."""
        assert split in ("train", "test")
        split_name = "train" if split == "train" else "test"
        records = load_dataset("Anthropic/hh-rlhf")[split_name]
        self.data = []
        for record in records:
            # Keep the pair adjacent: chosen (1) first, then rejected (0).
            self.data.extend(((record["chosen"], 1), (record["rejected"], 0)))

    def __len__(self):
        """Number of (text, label) samples, i.e. two per source record."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(text, label)`` pair stored at ``index``."""
        text, target = self.data[index]
        return (text, target)

<IPython.core.display.Javascript object>

In [7]:
# Materialise both splits up front: training samples first, then the held-out
# test samples that later serve as the validation set.
train_data = AnthropicDataset(split="train")
test_data = AnthropicDataset(split="test")

Using custom data configuration Anthropic--hh-rlhf-c8cd8dc58ab67414
Found cached dataset json (/soe/sgarg6/course_work/244_nlp/LLMbias/"/data/users/sgarg6/hf_cache"/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/2 [00:00<?, ?it/s]

Using custom data configuration Anthropic--hh-rlhf-c8cd8dc58ab67414
Found cached dataset json (/soe/sgarg6/course_work/244_nlp/LLMbias/"/data/users/sgarg6/hf_cache"/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/2 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

<IPython.core.display.Javascript object>

In [9]:
def collate_data(data):
    """Turn a list of ``(text, label)`` samples into model-ready tensors.

    Args:
        data: sequence of samples; element 0 is the text, element 1 the label.

    Returns:
        ``(tokens, labels)`` where ``tokens`` is the tokenizer output (padded,
        truncated to at most 512 tokens, PyTorch tensors) and ``labels`` is a
        float tensor; both are moved to the global ``device``.
    """
    texts = []
    targets = []
    for item in data:
        texts.append(item[0])
        targets.append(item[1])

    # Float labels, because they are fed to a logits-based binary loss.
    label_tensor = torch.tensor(targets, dtype=torch.float)
    labels = label_tensor.to(device)

    encoded = tokenizer(
        texts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    tokens = encoded.to(device)
    return tokens, labels

<IPython.core.display.Javascript object>

In [10]:
from torch.utils.data import DataLoader

# train_dataloader = DataLoader(train_data, collate_fn=collate_data, batch_size=2)

<IPython.core.display.Javascript object>

# Set Up Hyperparameters and Model Logging

In [11]:
# Best validation loss seen so far. Starts at +inf so that any finite loss from
# the first evaluation counts as an improvement (no magic sentinel value).
best_val_loss = float("inf")
# Samples per batch for both the training and validation loaders.
BATCH_SIZE = 8
# Number of full passes over the training data.
EPOCHS = 1
# Initial optimizer learning rate.
# NOTE(review): 0.01 is a large step for fine-tuning a pretrained transformer
# with Adam — confirm this is intentional.
learning_rate = 0.01

<IPython.core.display.Javascript object>

In [12]:
import wandb

# Start a Weights & Biases run and store the hyperparameters in its config.
wandb.init(
    entity="sugam110795",
    project="nlp244",
    group="LLMbias",
    # Hyperparameters recorded with the run.
    config={
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "lr": learning_rate,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33msugam110795[0m. Use [1m`wandb login --relogin`[0m to force relogin


<IPython.core.display.Javascript object>

# Model Train Setup

In [13]:
def free_memory():
    """Drop unreachable Python objects, then release PyTorch's cached CUDA memory.

    Returns:
        None.
    """
    import gc

    # Collect first: tensors kept alive only by reference cycles are handed back
    # to PyTorch's caching allocator here, so the empty_cache() call that follows
    # can release their blocks in this same call rather than the next one.
    gc.collect()
    torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [14]:
from tqdm import tqdm
import time
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


def evaluate(model, data_loader, criterion):
    """Score ``model`` on every batch of ``data_loader``.

    The model is switched to eval mode and is left in eval mode on return;
    callers that continue training must call ``model.train()`` again.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        data_loader: iterable of ``(X, y)`` batches; ``X`` is unpacked as keyword
            arguments into ``model`` and ``y`` holds the targets.
        criterion: loss callable invoked as ``criterion(logits.reshape(-1), y)``.

    Returns:
        Tuple ``(mean_loss, accuracy, f1)`` where ``mean_loss`` is the average of
        the per-batch losses and predictions are ``sigmoid(logit) > 0.5``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    out = []
    label = []
    # No gradients are needed for scoring; skipping autograd keeps activation
    # memory from being retained for a backward pass that never happens.
    with torch.no_grad():
        for X, y in tqdm(data_loader):
            # Flatten once so the loss input and the predictions are both 1-D
            # and line up element-for-element with the flat label list.
            logits = model(**X).logits.reshape(-1)
            loss = criterion(logits, y)
            total_loss += loss.item()
            num_batches += 1
            pred = torch.sigmoid(logits) > 0.5
            out.extend(pred.long().tolist())
            label.extend(y.long().tolist())
            del X, y, logits, loss, pred
            free_memory()
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    acc = accuracy_score(label, out)
    f1 = f1_score(label, out)
    # Divide by the number of batches, not by the index of the last one.
    return total_loss / num_batches, acc, f1

<IPython.core.display.Javascript object>

In [15]:
from tqdm import tqdm
import time


def train_step(
    data_loader, model, epoch, criterion, optimizer, eval_step, eval_data_loader, lr
):
    """Train ``model`` for one pass over ``data_loader`` with periodic validation.

    Every batch does a forward/backward pass with gradient-norm clipping at 0.5
    and logs the batch loss to W&B. Every ``eval_step`` batches (except batch 0)
    the model is validated on ``eval_data_loader`` and the metrics are logged.
    After batch 5000, a validation loss higher than the previous one cuts the
    learning rate by 10x; once the tracked rate falls below 1e-5 the loop stops.

    Args:
        data_loader: iterable of ``(X, y)`` training batches.
        model: classifier whose forward output exposes ``.logits``.
        epoch: epoch number, used only in the progress message.
        criterion: loss callable invoked as ``criterion(logits.reshape(-1), y)``.
        optimizer: optimizer stepping ``model``'s parameters.
        eval_step: number of batches between validation runs.
        eval_data_loader: loader passed to ``evaluate``.
        lr: learning rate the optimizer starts with; tracked locally for decay
            and early stopping.

    Returns:
        Mean training loss over the batches actually processed.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    prev_val_loss = 100
    # Wrap the loader (not the enumerate object) so tqdm knows the total.
    for batch, (X, y) in enumerate(tqdm(data_loader)):
        model.zero_grad()
        output = model(**X).logits
        loss = criterion(output.reshape(-1), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        cur_loss = loss.item()
        total_loss += cur_loss
        num_batches += 1
        metrics = {"train/train_loss": cur_loss, "train/step": (batch + 1)}
        wandb.log(metrics)
        if batch % eval_step == 0 and batch > 0:
            print("| epoch {:3d} |" " loss {:5.2f}".format(epoch, cur_loss))
            val_loss, acc, f1 = evaluate(model, eval_data_loader, criterion)
            # evaluate() leaves the model in eval mode; switch back so the
            # remaining batches are trained with dropout enabled.
            model.train()
            if batch > 5000 and val_loss > prev_val_loss:
                # Decay once, then apply the same rate to every param group, so
                # the decay does not compound with the number of groups.
                lr *= 0.1
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr
            prev_val_loss = val_loss
            wandb.log(
                {
                    "eval/eval_loss": val_loss,
                    "eval/acc": acc,
                    "eval/f1": f1,
                }
            )
        del loss, X, y, output
        free_memory()
        if lr < 0.00001:
            # Early Stopping
            break
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    # Divide by the number of batches processed, not by the last batch index.
    return total_loss / num_batches

<IPython.core.display.Javascript object>

In [16]:
# Short alias for the W&B config object of the current run.
config = wandb.config

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/soe/sgarg6/conda/envs/nlp_env/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/soe/sgarg6/conda/envs/nlp_env/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/soe/sgarg6/conda/envs/nlp_env/lib/python3.10/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/soe/sgarg6/conda/envs/nlp_env/lib/python3.10/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/soe/sgarg6/conda/envs/nlp_env/lib/python3.10/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


<IPython.core.display.Javascript object>

# Train Model 

In [None]:
from torch.optim import Adam
import torch.nn as nn
import torch.utils.data as data_utils


import os

model = model.to(device)
optimizer = Adam(model.parameters(), lr=config.lr)
criterion = nn.BCEWithLogitsLoss()

# The loaders do not depend on the epoch, so build them once. A DataLoader
# created with shuffle=True reshuffles each time it is iterated.
train_dataloader = DataLoader(
    train_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_data
)
valid_dataloader = DataLoader(
    test_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_data
)

# Create the checkpoint directory up front so a missing folder cannot make the
# save fail after a whole epoch of training has already been spent.
checkpoint_path = "/data/users/sgarg6/trained_models/gpt-reward/model.pt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(config.epochs):
    epoch_start_time = time.time()
    train_loss = train_step(
        train_dataloader,
        model,
        epoch,
        criterion,
        optimizer,
        500,  # validate every 500 training batches
        valid_dataloader,
        config.lr,
    )

    # End of training
    val_loss, acc, f1 = evaluate(model, valid_dataloader, criterion)
    print("-" * 89)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
        " acc {:5.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, acc)
    )
    print(f"F1-score is {f1}")
    print("-" * 89)
    # Save the model if the validation loss is the best we've seen so far.
    if val_loss < best_val_loss:
        with open(checkpoint_path, "wb") as f:
            torch.save(model, f)
        best_val_loss = val_loss

260it [06:33,  1.60s/it]

In [None]:
# Score the model once more on the validation loader to get the final metrics.
val_loss, acc, f1 = evaluate(model, valid_dataloader, criterion)

In [None]:
# Record the final validation metrics in the W&B run summary.
wandb.summary['val_loss'] = val_loss
wandb.summary['val_acc'] = acc
wandb.summary['val_f1'] = f1

In [None]:
# Show the final validation loss, accuracy and F1 score.
print(val_loss, acc, f1)

In [None]:
# Smoke test after training: score a single sentence on the training device.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(device)

# Inference only: eval mode for deterministic output, no gradient tracking.
model.eval()
with torch.no_grad():
    logits = model(**inputs).logits

# The model emits one logit per sample; sigmoid maps it to a probability.
probability = torch.sigmoid(logits)
print(probability)
# argmax over a single logit is always 0, so threshold at 0.5 instead, the same
# decision rule evaluate() uses.
predicted_class_id = int(probability.item() > 0.5)
print(predicted_class_id)

In [None]:
# Upload the model to the Hugging Face Hub repository "sugam11/gpt2-rlhf-reward".
model.push_to_hub("sugam11/gpt2-rlhf-reward")

In [None]:
# Mark the W&B run as finished.
wandb.finish()