In [1]:
# Load the nb_black extension so executed cells are auto-formatted with Black.
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
%env HF_DATASETS_CACHE="/data/users/sgarg6/hf_cache"

env: HF_DATASETS_CACHE="/data/users/sgarg6/hf_cache"


<IPython.core.display.Javascript object>

# Model Setup

In [3]:
import torch
from transformers import AutoTokenizer, GPT2ForSequenceClassification

# The tokenizer and the sequence-classification model are loaded from the same
# pretrained checkpoint, so name it once.
MODEL_CHECKPOINT = "microsoft/DialogRPT-updown"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = GPT2ForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)



<IPython.core.display.Javascript object>

In [4]:
# Smoke test: score a single sentence with the pretrained model.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# The model returns one logit per sequence, so argmax over it would always be 0.
# Threshold the sigmoid score instead to obtain a 0/1 class id.
predicted_class_id = int(torch.sigmoid(logits).item() > 0.5)

print(logits)

tensor([[-1.2981]])


<IPython.core.display.Javascript object>

# Load Data

In [5]:
from torch.utils.data import Dataset
from datasets import load_dataset


class AnthropicDataset(Dataset):
    """Flat, binary-labelled view of the Anthropic/hh-rlhf preference pairs.

    Each row of the source dataset contributes two samples, in this order:
    the ``chosen`` text with label 1, then the ``rejected`` text with label 0.
    """

    def __init__(self, split="test"):
        """Load ``split`` ("train" or "test") and flatten it into ``self.data``."""
        assert split in ("train", "test")
        major_split = "train" if split == "train" else "test"
        rows = load_dataset("Anthropic/hh-rlhf")[major_split]
        self.data = [
            pair
            for row in rows
            for pair in ((row["chosen"], 1), (row["rejected"], 0))
        ]

    def __len__(self):
        """Return the number of (text, label) samples (two per source row)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(text, label)`` pair stored at ``index``."""
        text, target = self.data[index]
        return text, target

<IPython.core.display.Javascript object>

In [6]:
# Build both splits; each one is a flat list of (text, label) samples.
train_data = AnthropicDataset("train")
test_data = AnthropicDataset("test")

Using custom data configuration Anthropic--hh-rlhf-c8cd8dc58ab67414
Found cached dataset json (/soe/sgarg6/course_work/244_nlp/LLMbias/"/data/users/sgarg6/hf_cache"/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/2 [00:00<?, ?it/s]

Using custom data configuration Anthropic--hh-rlhf-c8cd8dc58ab67414
Found cached dataset json (/soe/sgarg6/course_work/244_nlp/LLMbias/"/data/users/sgarg6/hf_cache"/Anthropic___json/Anthropic--hh-rlhf-c8cd8dc58ab67414/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/2 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

<IPython.core.display.Javascript object>

In [8]:
def collate_data(data):
    """Turn a list of ``(text, label)`` samples into a model-ready batch.

    The texts are tokenized together, padded to a common length and truncated
    to at most 512 tokens; the labels become a float tensor. Both results are
    moved to the module-level ``device``.

    Returns:
        ``(tokens, labels)`` where ``tokens`` is the tokenizer output and
        ``labels`` is a 1-D float tensor with one entry per sample.
    """
    texts = []
    targets = []
    for sample in data:
        texts.append(sample[0])
        targets.append(sample[1])

    labels = torch.tensor(targets, dtype=torch.float).to(device)
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    tokens = encoded.to(device)
    return tokens, labels

<IPython.core.display.Javascript object>

In [9]:
from torch.utils.data import DataLoader

# train_dataloader = DataLoader(train_data, collate_fn=collate_data, batch_size=2)

<IPython.core.display.Javascript object>

# Set Up Hyperparameters and Model Logging

In [19]:
# Sentinel "worst" loss; the training cell replaces it with the best validation
# loss seen so far and uses it to decide when to save a checkpoint.
best_val_loss = 99999
# The three values below are recorded in the W&B run config by wandb.init().
BATCH_SIZE = 16
EPOCHS = 1
# NOTE(review): 1e-2 is unusually high for fine-tuning a pretrained transformer
# with Adam — confirm this is intentional.
learning_rate = 0.01

<IPython.core.display.Javascript object>

In [11]:
import wandb

# Start a Weights & Biases run and record the hyper-parameters in its config.
wandb.init(
    entity="sugam110795",
    project="nlp244",
    group="LLMbias",
    config=dict(epochs=EPOCHS, batch_size=BATCH_SIZE, lr=learning_rate),
)

[34m[1mwandb[0m: Currently logged in as: [33msugam110795[0m. Use [1m`wandb login --relogin`[0m to force relogin


<IPython.core.display.Javascript object>

# Model Train Setup

In [12]:
def free_memory():
    """Release unreachable Python objects and PyTorch's cached CUDA memory.

    The garbage collector runs first so that memory held by objects it frees
    is returned to PyTorch's caching allocator before the cache is emptied;
    in the opposite order that memory would stay cached until the next call.
    """
    import gc

    gc.collect()
    torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

In [13]:
from tqdm import tqdm
import time
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


def evaluate(model, data_loader, criterion):
    """Compute mean loss, accuracy and F1 of ``model`` over ``data_loader``.

    Args:
        model: module called as ``model(**X)`` whose result exposes ``.logits``.
        data_loader: iterable yielding ``(X, y)`` batches, where ``X`` is a
            mapping of model inputs and ``y`` is the target tensor.
        criterion: loss called as ``criterion(logits, y)`` with the logits
            flattened to one dimension.

    Returns:
        ``(mean_loss, accuracy, f1)`` where ``mean_loss`` is the per-batch loss
        averaged over the number of batches, and the other two come from
        thresholding ``sigmoid(logits)`` at 0.5.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    Note:
        The model is left in eval mode; a caller that resumes training must
        call ``model.train()`` afterwards.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    predictions = []
    targets = []
    # No gradients are needed here; skipping autograd saves memory and time.
    with torch.no_grad():
        for X, y in tqdm(data_loader):
            logits = model(**X).logits.reshape(-1)
            total_loss += criterion(logits, y).item()
            # Flattened logits give a flat list of 0/1 predictions that lines
            # up with the flat list of targets.
            predictions.extend((torch.sigmoid(logits) > 0.5).long().tolist())
            targets.extend(y.long().tolist())
            num_batches += 1
            del X, y, logits
            free_memory()
    if num_batches == 0:
        raise ValueError("evaluate() received an empty data_loader")
    acc = accuracy_score(targets, predictions)
    f1 = f1_score(targets, predictions)
    # Divide by the batch count, not the last batch index.
    return total_loss / num_batches, acc, f1

<IPython.core.display.Javascript object>

In [18]:
from tqdm import tqdm
import time


def train_step(
    data_loader, model, epoch, criterion, optimizer, eval_step, eval_data_loader, lr
):
    """Train ``model`` for one pass over ``data_loader`` with periodic validation.

    Args:
        data_loader: iterable yielding ``(X, y)`` training batches.
        model: module called as ``model(**X)`` whose result exposes ``.logits``.
        epoch: epoch number, used only in the progress message.
        criterion: loss called as ``criterion(logits, y)`` on flattened logits.
        optimizer: optimizer stepping the model parameters.
        eval_step: run validation every ``eval_step`` batches (batch 0 excluded).
        eval_data_loader: loader handed to ``evaluate`` for validation.
        lr: learning rate the optimizer currently uses; whenever the validation
            loss rises it is multiplied by 0.1 and written to every param group.

    Returns:
        The training loss averaged over the number of batches processed.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    # No previous validation loss yet, so the first comparison never decays.
    prev_val_loss = float("inf")
    # Wrapping the loader (not the enumerate) lets tqdm show the total.
    for batch, (X, y) in enumerate(tqdm(data_loader)):
        model.zero_grad()
        output = model(**X).logits
        loss = criterion(output.reshape(-1), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        wandb.log({"train/train_loss": batch_loss, "train/step": num_batches})
        if batch % eval_step == 0 and batch > 0:
            print("| epoch {:3d} |" " loss {:5.2f}".format(epoch, batch_loss))
            val_loss, acc, f1 = evaluate(model, eval_data_loader, criterion)
            # evaluate() leaves the model in eval mode; switch back so the
            # remaining batches are trained in train mode.
            model.train()
            metrics = {
                "train/train_loss": total_loss / num_batches,
                "train/step": num_batches,
                "eval/eval_loss": val_loss,
                "eval/acc": acc,
                "eval/f1": f1,
            }
            if val_loss > prev_val_loss:
                # Validation got worse: decay the learning rate tenfold.
                lr *= 0.1
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr
            prev_val_loss = val_loss
            wandb.log(metrics)
        del loss, X, y, output
        free_memory()
    if num_batches == 0:
        raise ValueError("train_step() received an empty data_loader")
    # Divide by the batch count, not the last batch index.
    return total_loss / num_batches

<IPython.core.display.Javascript object>

In [15]:
# Read the hyper-parameters back from the W&B run config; the training cell
# uses config.epochs, config.batch_size and config.lr.
config = wandb.config

<IPython.core.display.Javascript object>

# Train Model 

In [None]:
from torch.optim import Adam
import torch.nn as nn


# Run validation every this many training batches.
EVAL_EVERY_STEPS = 500
CHECKPOINT_PATH = "/data/users/sgarg6/trained_models/gpt-reward/model.pt"

model = model.to(device)
optimizer = Adam(model.parameters(), lr=config.lr)
criterion = nn.BCEWithLogitsLoss()

# The loaders do not depend on the epoch, so build them once. Only the training
# loader is shuffled; shuffling the validation data serves no purpose.
train_dataloader = DataLoader(
    train_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_data
)
valid_dataloader = DataLoader(
    test_data, batch_size=config.batch_size, shuffle=False, collate_fn=collate_data
)

for epoch in range(config.epochs):
    epoch_start_time = time.time()
    train_loss = train_step(
        train_dataloader,
        model,
        epoch,
        criterion,
        optimizer,
        EVAL_EVERY_STEPS,
        valid_dataloader,
        # Pass the optimizer's current rate so a decay applied in an earlier
        # epoch is not forgotten.
        optimizer.param_groups[0]["lr"],
    )

    # End of training
    val_loss, acc, f1 = evaluate(model, valid_dataloader, criterion)
    print("-" * 89)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
        " acc {:5.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, acc)
    )
    print(f"F1-score is {f1}")
    print("-" * 89)
    # Save the model if the validation loss is the best we've seen so far.
    # NOTE(review): torch.save(model) pickles the whole module; saving
    # model.state_dict() is usually more portable — confirm what loaders expect.
    if val_loss < best_val_loss:
        with open(CHECKPOINT_PATH, "wb") as f:
            torch.save(model, f)
        best_val_loss = val_loss

500it [12:00,  1.54s/it]

| epoch   0 | loss  0.72



  0%|                                                         | 0/1069 [00:00<?, ?it/s][A
  0%|                                                 | 1/1069 [00:00<13:23,  1.33it/s][A
  0%|                                                 | 2/1069 [00:01<13:01,  1.36it/s][A
  0%|▏                                                | 3/1069 [00:02<11:30,  1.54it/s][A
  0%|▏                                                | 4/1069 [00:02<11:30,  1.54it/s][A
  0%|▏                                                | 5/1069 [00:03<11:30,  1.54it/s][A
  1%|▎                                                | 6/1069 [00:04<12:32,  1.41it/s][A
  1%|▎                                                | 7/1069 [00:04<12:08,  1.46it/s][A
  1%|▎                                                | 8/1069 [00:05<13:09,  1.34it/s][A
  1%|▍                                                | 9/1069 [00:06<12:29,  1.41it/s][A
  1%|▍                                               | 10/1069 [00:06<12:27,  1.42it/s][

  8%|████                                            | 90/1069 [01:02<12:13,  1.33it/s][A
  9%|████                                            | 91/1069 [01:02<11:59,  1.36it/s][A
  9%|████▏                                           | 92/1069 [01:03<12:26,  1.31it/s][A
  9%|████▏                                           | 93/1069 [01:04<13:04,  1.24it/s][A
  9%|████▏                                           | 94/1069 [01:05<12:40,  1.28it/s][A
  9%|████▎                                           | 95/1069 [01:06<12:08,  1.34it/s][A
  9%|████▎                                           | 96/1069 [01:06<11:49,  1.37it/s][A
  9%|████▎                                           | 97/1069 [01:07<11:11,  1.45it/s][A
  9%|████▍                                           | 98/1069 [01:07<10:07,  1.60it/s][A
  9%|████▍                                           | 99/1069 [01:08<10:41,  1.51it/s][A
  9%|████▍                                          | 100/1069 [01:09<11:15,  1.43it/s][A

 17%|███████▉                                       | 180/1069 [02:06<10:19,  1.44it/s][A
 17%|███████▉                                       | 181/1069 [02:07<10:38,  1.39it/s][A
 17%|████████                                       | 182/1069 [02:07<10:50,  1.36it/s][A
 17%|████████                                       | 183/1069 [02:08<10:47,  1.37it/s][A
 17%|████████                                       | 184/1069 [02:09<10:37,  1.39it/s][A
 17%|████████▏                                      | 185/1069 [02:09<10:28,  1.41it/s][A
 17%|████████▏                                      | 186/1069 [02:10<10:18,  1.43it/s][A
 17%|████████▏                                      | 187/1069 [02:11<10:19,  1.42it/s][A
 18%|████████▎                                      | 188/1069 [02:12<10:11,  1.44it/s][A
 18%|████████▎                                      | 189/1069 [02:12<10:05,  1.45it/s][A
 18%|████████▎                                      | 190/1069 [02:13<10:00,  1.46it/s][A

 25%|███████████▊                                   | 270/1069 [03:10<07:40,  1.73it/s][A
 25%|███████████▉                                   | 271/1069 [03:10<08:16,  1.61it/s][A
 25%|███████████▉                                   | 272/1069 [03:11<08:50,  1.50it/s][A
 26%|████████████                                   | 273/1069 [03:12<09:14,  1.44it/s][A
 26%|████████████                                   | 274/1069 [03:13<09:34,  1.38it/s][A
 26%|████████████                                   | 275/1069 [03:13<09:44,  1.36it/s][A
 26%|████████████▏                                  | 276/1069 [03:14<10:04,  1.31it/s][A
 26%|████████████▏                                  | 277/1069 [03:15<09:57,  1.33it/s][A
 26%|████████████▏                                  | 278/1069 [03:16<09:52,  1.34it/s][A
 26%|████████████▎                                  | 279/1069 [03:17<10:10,  1.29it/s][A
 26%|████████████▎                                  | 280/1069 [03:17<09:32,  1.38it/s][A

 34%|███████████████▊                               | 360/1069 [04:13<08:06,  1.46it/s][A
 34%|███████████████▊                               | 361/1069 [04:14<07:42,  1.53it/s][A
 34%|███████████████▉                               | 362/1069 [04:14<07:35,  1.55it/s][A
 34%|███████████████▉                               | 363/1069 [04:15<07:31,  1.56it/s][A
 34%|████████████████                               | 364/1069 [04:16<07:42,  1.52it/s][A
 34%|████████████████                               | 365/1069 [04:16<07:46,  1.51it/s][A
 34%|████████████████                               | 366/1069 [04:17<08:08,  1.44it/s][A
 34%|████████████████▏                              | 367/1069 [04:18<08:33,  1.37it/s][A
 34%|████████████████▏                              | 368/1069 [04:19<07:43,  1.51it/s][A
 35%|████████████████▏                              | 369/1069 [04:19<08:42,  1.34it/s][A
 35%|████████████████▎                              | 370/1069 [04:20<08:49,  1.32it/s][A

 42%|███████████████████▊                           | 450/1069 [05:16<07:28,  1.38it/s][A
 42%|███████████████████▊                           | 451/1069 [05:17<07:36,  1.35it/s][A
 42%|███████████████████▊                           | 452/1069 [05:18<07:35,  1.35it/s][A
 42%|███████████████████▉                           | 453/1069 [05:19<07:53,  1.30it/s][A
 42%|███████████████████▉                           | 454/1069 [05:19<07:43,  1.33it/s][A
 43%|████████████████████                           | 455/1069 [05:20<07:16,  1.41it/s][A
 43%|████████████████████                           | 456/1069 [05:21<07:23,  1.38it/s][A
 43%|████████████████████                           | 457/1069 [05:21<06:44,  1.51it/s][A
 43%|████████████████████▏                          | 458/1069 [05:22<06:38,  1.53it/s][A
 43%|████████████████████▏                          | 459/1069 [05:23<07:07,  1.43it/s][A
 43%|████████████████████▏                          | 460/1069 [05:24<07:15,  1.40it/s][A

 51%|███████████████████████▋                       | 540/1069 [06:20<05:44,  1.53it/s][A
 51%|███████████████████████▊                       | 541/1069 [06:21<05:53,  1.49it/s][A
 51%|███████████████████████▊                       | 542/1069 [06:22<06:11,  1.42it/s][A
 51%|███████████████████████▊                       | 543/1069 [06:23<06:18,  1.39it/s][A
 51%|███████████████████████▉                       | 544/1069 [06:24<06:48,  1.28it/s][A
 51%|███████████████████████▉                       | 545/1069 [06:24<06:57,  1.26it/s][A
 51%|████████████████████████                       | 546/1069 [06:25<06:51,  1.27it/s][A
 51%|████████████████████████                       | 547/1069 [06:26<06:50,  1.27it/s][A
 51%|████████████████████████                       | 548/1069 [06:27<06:52,  1.26it/s][A
 51%|████████████████████████▏                      | 549/1069 [06:28<06:44,  1.28it/s][A
 51%|████████████████████████▏                      | 550/1069 [06:28<06:19,  1.37it/s][A

 59%|███████████████████████████▋                   | 630/1069 [07:24<05:23,  1.36it/s][A
 59%|███████████████████████████▋                   | 631/1069 [07:25<05:27,  1.34it/s][A
 59%|███████████████████████████▊                   | 632/1069 [07:26<05:27,  1.34it/s][A
 59%|███████████████████████████▊                   | 633/1069 [07:27<05:27,  1.33it/s][A
 59%|███████████████████████████▊                   | 634/1069 [07:27<05:21,  1.35it/s][A
 59%|███████████████████████████▉                   | 635/1069 [07:28<05:21,  1.35it/s][A
 59%|███████████████████████████▉                   | 636/1069 [07:29<05:37,  1.28it/s][A
 60%|████████████████████████████                   | 637/1069 [07:30<05:24,  1.33it/s][A
 60%|████████████████████████████                   | 638/1069 [07:30<05:28,  1.31it/s][A
 60%|████████████████████████████                   | 639/1069 [07:31<05:26,  1.32it/s][A
 60%|████████████████████████████▏                  | 640/1069 [07:32<05:05,  1.40it/s][A

 67%|███████████████████████████████▋               | 720/1069 [08:27<03:52,  1.50it/s][A
 67%|███████████████████████████████▋               | 721/1069 [08:27<03:53,  1.49it/s][A
 68%|███████████████████████████████▋               | 722/1069 [08:28<04:12,  1.37it/s][A
 68%|███████████████████████████████▊               | 723/1069 [08:29<03:57,  1.45it/s][A
 68%|███████████████████████████████▊               | 724/1069 [08:30<03:58,  1.45it/s][A
 68%|███████████████████████████████▉               | 725/1069 [08:30<03:52,  1.48it/s][A
 68%|███████████████████████████████▉               | 726/1069 [08:31<04:03,  1.41it/s][A
 68%|███████████████████████████████▉               | 727/1069 [08:32<04:12,  1.35it/s][A
 68%|████████████████████████████████               | 728/1069 [08:32<04:05,  1.39it/s][A
 68%|████████████████████████████████               | 729/1069 [08:33<04:04,  1.39it/s][A
 68%|████████████████████████████████               | 730/1069 [08:34<04:30,  1.25it/s][A

 76%|███████████████████████████████████▌           | 810/1069 [09:30<02:50,  1.52it/s][A
 76%|███████████████████████████████████▋           | 811/1069 [09:31<03:00,  1.43it/s][A
 76%|███████████████████████████████████▋           | 812/1069 [09:32<03:06,  1.38it/s][A
 76%|███████████████████████████████████▋           | 813/1069 [09:33<03:04,  1.38it/s][A
 76%|███████████████████████████████████▊           | 814/1069 [09:33<03:03,  1.39it/s][A
 76%|███████████████████████████████████▊           | 815/1069 [09:34<03:01,  1.40it/s][A
 76%|███████████████████████████████████▉           | 816/1069 [09:35<02:49,  1.50it/s][A
 76%|███████████████████████████████████▉           | 817/1069 [09:35<02:52,  1.46it/s][A
 77%|███████████████████████████████████▉           | 818/1069 [09:36<03:04,  1.36it/s][A
 77%|████████████████████████████████████           | 819/1069 [09:37<03:03,  1.36it/s][A
 77%|████████████████████████████████████           | 820/1069 [09:38<03:02,  1.36it/s][A

 84%|███████████████████████████████████████▌       | 900/1069 [10:35<02:01,  1.39it/s][A
 84%|███████████████████████████████████████▌       | 901/1069 [10:35<01:59,  1.40it/s][A
 84%|███████████████████████████████████████▋       | 902/1069 [10:36<02:04,  1.34it/s][A
 84%|███████████████████████████████████████▋       | 903/1069 [10:37<02:00,  1.38it/s][A
 85%|███████████████████████████████████████▋       | 904/1069 [10:38<01:57,  1.40it/s][A
 85%|███████████████████████████████████████▊       | 905/1069 [10:38<01:58,  1.38it/s][A
 85%|███████████████████████████████████████▊       | 906/1069 [10:39<01:54,  1.42it/s][A
 85%|███████████████████████████████████████▉       | 907/1069 [10:40<01:47,  1.51it/s][A
 85%|███████████████████████████████████████▉       | 908/1069 [10:40<01:43,  1.55it/s][A
 85%|███████████████████████████████████████▉       | 909/1069 [10:41<01:36,  1.66it/s][A
 85%|████████████████████████████████████████       | 910/1069 [10:41<01:36,  1.65it/s][A

 93%|███████████████████████████████████████████▌   | 990/1069 [11:37<01:00,  1.31it/s][A
 93%|███████████████████████████████████████████▌   | 991/1069 [11:37<01:00,  1.30it/s][A
 93%|███████████████████████████████████████████▌   | 992/1069 [11:38<00:59,  1.29it/s][A
 93%|███████████████████████████████████████████▋   | 993/1069 [11:39<00:58,  1.30it/s][A
 93%|███████████████████████████████████████████▋   | 994/1069 [11:40<00:55,  1.35it/s][A
 93%|███████████████████████████████████████████▋   | 995/1069 [11:40<00:53,  1.39it/s][A
 93%|███████████████████████████████████████████▊   | 996/1069 [11:41<00:51,  1.43it/s][A
 93%|███████████████████████████████████████████▊   | 997/1069 [11:42<00:50,  1.44it/s][A
 93%|███████████████████████████████████████████▉   | 998/1069 [11:42<00:48,  1.47it/s][A
 93%|███████████████████████████████████████████▉   | 999/1069 [11:43<00:46,  1.50it/s][A
 94%|███████████████████████████████████████████   | 1000/1069 [11:44<00:46,  1.48it/s][A

| epoch   0 | loss  0.69



  0%|                                                         | 0/1069 [00:00<?, ?it/s][A
  0%|                                                 | 1/1069 [00:00<14:25,  1.23it/s][A
  0%|                                                 | 2/1069 [00:01<14:03,  1.27it/s][A
  0%|▏                                                | 3/1069 [00:02<12:34,  1.41it/s][A
  0%|▏                                                | 4/1069 [00:02<12:38,  1.40it/s][A
  0%|▏                                                | 5/1069 [00:03<11:40,  1.52it/s][A
  1%|▎                                                | 6/1069 [00:04<11:15,  1.57it/s][A
  1%|▎                                                | 7/1069 [00:04<11:08,  1.59it/s][A
  1%|▎                                                | 8/1069 [00:05<10:55,  1.62it/s][A
  1%|▍                                                | 9/1069 [00:05<10:55,  1.62it/s][A
  1%|▍                                               | 10/1069 [00:06<10:53,  1.62it/s][

  8%|████                                            | 90/1069 [01:01<11:20,  1.44it/s][A
  9%|████                                            | 91/1069 [01:02<11:14,  1.45it/s][A
  9%|████▏                                           | 92/1069 [01:03<11:20,  1.44it/s][A
  9%|████▏                                           | 93/1069 [01:03<11:59,  1.36it/s][A
  9%|████▏                                           | 94/1069 [01:04<11:57,  1.36it/s][A
  9%|████▎                                           | 95/1069 [01:05<12:01,  1.35it/s][A
  9%|████▎                                           | 96/1069 [01:06<11:34,  1.40it/s][A
  9%|████▎                                           | 97/1069 [01:06<11:11,  1.45it/s][A
  9%|████▍                                           | 98/1069 [01:07<10:57,  1.48it/s][A
  9%|████▍                                           | 99/1069 [01:08<10:56,  1.48it/s][A
  9%|████▍                                          | 100/1069 [01:08<10:58,  1.47it/s][A

 17%|███████▉                                       | 180/1069 [02:07<11:21,  1.30it/s][A
 17%|███████▉                                       | 181/1069 [02:08<11:50,  1.25it/s][A
 17%|████████                                       | 182/1069 [02:09<12:17,  1.20it/s][A
 17%|████████                                       | 183/1069 [02:09<11:42,  1.26it/s][A
 17%|████████                                       | 184/1069 [02:10<11:10,  1.32it/s][A
 17%|████████▏                                      | 185/1069 [02:11<10:35,  1.39it/s][A
 17%|████████▏                                      | 186/1069 [02:11<10:20,  1.42it/s][A
 17%|████████▏                                      | 187/1069 [02:12<10:05,  1.46it/s][A
 18%|████████▎                                      | 188/1069 [02:12<09:40,  1.52it/s][A
 18%|████████▎                                      | 189/1069 [02:13<09:21,  1.57it/s][A
 18%|████████▎                                      | 190/1069 [02:14<09:10,  1.60it/s][A

 25%|███████████▊                                   | 270/1069 [03:11<10:32,  1.26it/s][A
 25%|███████████▉                                   | 271/1069 [03:12<10:17,  1.29it/s][A
 25%|███████████▉                                   | 272/1069 [03:12<10:03,  1.32it/s][A
 26%|████████████                                   | 273/1069 [03:13<09:35,  1.38it/s][A
 26%|████████████                                   | 274/1069 [03:14<09:14,  1.43it/s][A
 26%|████████████                                   | 275/1069 [03:14<08:57,  1.48it/s][A
 26%|████████████▏                                  | 276/1069 [03:15<09:21,  1.41it/s][A
 26%|████████████▏                                  | 277/1069 [03:16<09:11,  1.44it/s][A
 26%|████████████▏                                  | 278/1069 [03:16<09:18,  1.42it/s][A
 26%|████████████▎                                  | 279/1069 [03:17<09:39,  1.36it/s][A
 26%|████████████▎                                  | 280/1069 [03:18<08:41,  1.51it/s][A

 34%|███████████████▊                               | 360/1069 [04:14<08:57,  1.32it/s][A
 34%|███████████████▊                               | 361/1069 [04:15<08:36,  1.37it/s][A
 34%|███████████████▉                               | 362/1069 [04:16<08:52,  1.33it/s][A
 34%|███████████████▉                               | 363/1069 [04:16<08:46,  1.34it/s][A
 34%|████████████████                               | 364/1069 [04:17<08:35,  1.37it/s][A
 34%|████████████████                               | 365/1069 [04:18<08:35,  1.37it/s][A
 34%|████████████████                               | 366/1069 [04:18<08:25,  1.39it/s][A
 34%|████████████████▏                              | 367/1069 [04:19<08:27,  1.38it/s][A
 34%|████████████████▏                              | 368/1069 [04:20<08:18,  1.41it/s][A
 35%|████████████████▏                              | 369/1069 [04:20<08:09,  1.43it/s][A
 35%|████████████████▎                              | 370/1069 [04:21<07:55,  1.47it/s][A

 42%|███████████████████▊                           | 450/1069 [05:18<06:46,  1.52it/s][A
 42%|███████████████████▊                           | 451/1069 [05:18<07:01,  1.47it/s][A
 42%|███████████████████▊                           | 452/1069 [05:19<07:11,  1.43it/s][A
 42%|███████████████████▉                           | 453/1069 [05:20<07:24,  1.39it/s][A
 42%|███████████████████▉                           | 454/1069 [05:20<07:08,  1.43it/s][A
 43%|████████████████████                           | 455/1069 [05:21<07:18,  1.40it/s][A
 43%|████████████████████                           | 456/1069 [05:22<07:50,  1.30it/s][A
 43%|████████████████████                           | 457/1069 [05:23<07:23,  1.38it/s][A
 43%|████████████████████▏                          | 458/1069 [05:23<07:05,  1.44it/s][A
 43%|████████████████████▏                          | 459/1069 [05:24<07:20,  1.38it/s][A
 43%|████████████████████▏                          | 460/1069 [05:25<07:34,  1.34it/s][A

In [None]:
# Validation pass over valid_dataloader with the model in its current state.
val_loss, acc, f1 = evaluate(model, valid_dataloader, criterion)

In [None]:
# Store the final validation metrics in the W&B run summary.
for summary_key, summary_value in (
    ("val_loss", val_loss),
    ("val_acc", acc),
    ("val_f1", f1),
):
    wandb.summary[summary_key] = summary_value

In [None]:
# Show the final validation loss, accuracy and F1 score.
print(val_loss, acc, f1)

In [None]:
# Score a single sentence with the fine-tuned model.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(device)

model.eval()  # inference should not run in training mode
with torch.no_grad():
    logits = model(**inputs).logits

# torch.sigmoid (the original spelling "sigmoiod" raised AttributeError).
score = torch.sigmoid(logits)
print(score)
# The model returns one logit per sequence, so argmax over it would always be 0;
# threshold the sigmoid score to obtain a 0/1 class id.
predicted_class_id = int(score.item() > 0.5)
print(predicted_class_id)

In [None]:
# Upload the fine-tuned model to the Hugging Face Hub.
# NOTE(review): the tokenizer is not pushed here — confirm whether it should be.
model.push_to_hub("sugam11/gpt2-rlhf-reward")

In [None]:
# Mark the W&B run as finished.
wandb.finish()