### Step 1: Install necesscary packages

In [1]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

/bin/bash: line 1: pip: command not found
/bin/bash: line 1: pip: command not found


### Step 2: Package imports and configuration

In [2]:
import sys
import os
sys.path.append(os.path.abspath("/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 40
batch_size = 1024
max_length =64
num_samples = 1
max_new_tokens = 128
temperature = 0.8
top_k = 200
checkpoint_path = "/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/"
os.makedirs(checkpoint_path, exist_ok=True)
# tokenizer
with open("/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

In [3]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [4]:
ckpt = torch.load("/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

In [5]:
# Load data from ./data/pos_neg_pairs.json
with open("/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/data/dpo_data.json", "r") as f:
    lines = [json.loads(line) for line in f]
print(len(lines))

400004


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [6]:
# recommend to use the AdamW optimizer 
total_steps = len(lines) // batch_size
optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

### Step 7: Begin training

We will train the model with the dpo algorithm, i tested some epochs and found that the loss converges and the model perform well at around 35-40 epochs, so we will train the model for 40 epochs

In [7]:
from tqdm import tqdm

In [8]:
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size), total=total_steps, desc=f"Training")
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples: 
        # ...
        # neg_logprob
        # pos_logprob 
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1 
        # ...
        ###########################################################
        optimizer.zero_grad()
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1 
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.update(1)
        if step % 1000 == 0: 
            print(f"Step {step}, Loss: {loss.item()}")
    ckpt_path = checkpoint_path + f"dpo_{epoch}.pt"
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save({
        "model": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Training:   0%|▎                                                                                                                                             | 1/390 [00:00<04:20,  1.50it/s]

Step 0, Loss: 12.401473045349121


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.49it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_0.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:10,  5.53it/s]

Step 0, Loss: 0.09065369516611099


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_1.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.63it/s]

Step 0, Loss: 0.06962287425994873


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_2.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.61it/s]

Step 0, Loss: 0.06278470903635025


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_3.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.05694141611456871


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_4.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.05249735340476036


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_5.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.047523580491542816


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_6.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:11,  5.42it/s]

Step 0, Loss: 0.04132597893476486


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_7.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.03280108422040939


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_8.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.67it/s]

Step 0, Loss: 0.027747027575969696


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_9.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.02624540962278843


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_10.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.63it/s]

Step 0, Loss: 0.02489774487912655


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_11.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.62it/s]

Step 0, Loss: 0.023890415206551552


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_12.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.59it/s]

Step 0, Loss: 0.02332920953631401


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_13.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.62it/s]

Step 0, Loss: 0.022431712597608566


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_14.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.62it/s]

Step 0, Loss: 0.021962221711874008


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_15.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.66it/s]

Step 0, Loss: 0.021205637603998184


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_16.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.64it/s]

Step 0, Loss: 0.020295152440667152


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_17.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.63it/s]

Step 0, Loss: 0.019863320514559746


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_18.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.59it/s]

Step 0, Loss: 0.018462279811501503


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_19.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.67it/s]

Step 0, Loss: 0.01797167956829071


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_20.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.65it/s]

Step 0, Loss: 0.017795376479625702


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_21.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.65it/s]

Step 0, Loss: 0.01766950450837612


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_22.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:10,  5.49it/s]

Step 0, Loss: 0.016936419531702995


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_23.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.61it/s]

Step 0, Loss: 0.01691877655684948


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_24.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.65it/s]

Step 0, Loss: 0.01661328412592411


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_25.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:10,  5.53it/s]

Step 0, Loss: 0.016455724835395813


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_26.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.69it/s]

Step 0, Loss: 0.01627330668270588


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_27.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.66it/s]

Step 0, Loss: 0.016254279762506485


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_28.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.65it/s]

Step 0, Loss: 0.015972616150975227


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_29.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.67it/s]

Step 0, Loss: 0.016330713406205177


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_30.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.68it/s]

Step 0, Loss: 0.01598365232348442


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_31.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.66it/s]

Step 0, Loss: 0.01626153476536274


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_32.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.64it/s]

Step 0, Loss: 0.016123415902256966


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_33.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.71it/s]

Step 0, Loss: 0.01607584021985531


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_34.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.61it/s]

Step 0, Loss: 0.01586838625371456


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_35.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.69it/s]

Step 0, Loss: 0.016065992414951324


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_36.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:10,  5.53it/s]

Step 0, Loss: 0.015671860426664352


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_37.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:09,  5.61it/s]

Step 0, Loss: 0.01592426374554634


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_38.pt


Training:   1%|▋                                                                                                                                             | 2/390 [00:00<01:08,  5.65it/s]

Step 0, Loss: 0.015961483120918274


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [01:26<00:00,  4.50it/s]


Saved checkpoint to /home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_39.pt


### Step 8: Begin testing (**students are required to complete this part!**)

Test with some of the given examples, we will test on our test set after this

In [10]:
# Load the fine-tuned model
ckpt_path = "/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_39.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?", "99/11=?", "15/5"]
answers = ["36", "51", "18", "38", "4", "51", "18", "38", "9", "3"]
correct = 0
incorrect = 0
with torch.no_grad():
    for i, prompt in enumerate(test_set): 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        ###########################################################
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        # Generate output
        generated_ids, _ = gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # Decode and print the result
        output = decode(generated_ids[0].tolist())
        print(output)

        answer = output.strip(".").split()[-1]
        print(answer, answers[i])
        if int(answer) == int(answers[i]):
            correct += 1
        else:
            incorrect += 1


print(f"Correct: {correct}, Incorrect: {incorrect}, Accuracy: {correct/(correct+incorrect)}")



17+19=? The answer is 36 because 17+19 equals 36.
36 36
3*17=? The answer is 51 because 3*17 equals 51.
51 51
72/4=? The answer is 18 because 72/4 equals 18.
18 18
72-x=34,x=? The answer is 38 because 72-34 equals 38.
38 38
x*11=44,x=? The answer is 4 because 44/11 equals 4.
4 4
3*17=? The answer is 51 because 3*17 equals 51.
51 51
72/4=? The answer is 18 because 72/4 equals 18.
18 18
72-x=34,x=? The answer is 38 because 72-34 equals 38.
38 38
99/11=? The answer is 9 because 99/11 equals 9.
9 9
15/5=? The answer is 3 because 15/5 equals 3.
3 3
Correct: 10, Incorrect: 0, Accuracy: 1.0


# testing on a larger test set

We will test on a larger test set to make sure the model perform well for different operations and number range in [-100, 100]

In [16]:
with open("/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/data/dpo_test_data.json", "r") as f:
    lines = f.readlines()
test_data = [json.loads(line) for line in lines]
print(len(test_data))
print(test_data[0])

404
{'question': '49+51=?', 'answer': 100}


In [17]:
# Load the fine-tuned model
ckpt_path = "/home/users/ntu/cong045/scratch/school/sc3000/NanoGPT-Math/checkpoints/dpo/dpo_39.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
correct = 0
incorrect = 0
with torch.no_grad():
    for i, prompt in enumerate(test_data): 
        prompt_ids = encode(prompt["question"])
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        ###########################################################
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        # Generate output
        generated_ids, _ = gpt.generate(x, max_new_tokens, temperature=0.2, top_k=top_k)
        # Decode and print the result
        output = decode(generated_ids[0].tolist())
        # print(output)
        answer = int(output.strip(".").split()[-1])
        if int(answer) == int(prompt["answer"]):
            correct += 1
        else:
            incorrect += 1


print(f"Correct: {correct}, Incorrect: {incorrect}, Accuracy: {correct/(correct+incorrect)}")



Correct: 404, Incorrect: 0, Accuracy: 1.0


The model got perfect performance on the test set! So it seems like that our DPO training worked well!