### Step 1: Install necessary packages

In [None]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

### Step 2: Package imports and configuration

In [2]:
import sys
import os
# Make the parent directory importable (presumably where the `model` module
# used below lives). Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
# Select GPU index 1 by default, but respect a CUDA_VISIBLE_DEVICES value that
# was already exported in the environment. This runs before torch is imported.
# NOTE(review): on a single-GPU machine index 1 does not exist and the
# notebook will fall back to CPU -- set CUDA_VISIBLE_DEVICES=0 in that case.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
# -- training hyper-parameters --
beta = 0.5            # presumably the preference-loss temperature -- TODO confirm
base_lr = 1e-4        # learning rate for the optimizer built in Step 6
epochs = 5            # number of passes over the preference pairs
batch_size = 64       # pairs per batch (partial batches are skipped by get_batches)
max_length = 64       # every encoded sequence is padded/truncated to this length
# -- sampling settings (see the generate() hint in the test cell) --
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# -- compute device: use the GPU when one is visible, otherwise the CPU --
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# tokenizer
# Character-level vocabulary saved alongside the SFT checkpoint.
# NOTE: pickle.load can execute arbitrary code; only load a meta.pkl that you
# produced yourself or otherwise trust.
with open("../sft/meta.pkl", "rb") as meta_file:
    meta = pickle.load(meta_file)
stoi = meta["stoi"]  # token string -> integer id
itos = meta["itos"]  # integer id -> token string
def encode(s):
    """Convert the string ``s`` into a list of token ids, one per character.

    Raises whatever ``stoi`` raises for a character missing from the vocabulary.
    """
    token_ids = []
    for ch in s:
        token_ids.append(stoi[ch])
    return token_ids
def decode(l):
    """Convert an iterable of token ids ``l`` back into a string via ``itos``."""
    pieces = [itos[token_id] for token_id in l]
    return ''.join(pieces)

### Step 3: Define helper functions

In [3]:
def compute_logprob(input_ids, model=None):
    """Return the length-normalised log-probability of each sequence.

    Each sequence is scored with next-token prediction: tokens ``[:-1]`` are
    fed to the model and tokens ``[1:]`` are the targets. Target positions
    holding id 0 (the padding id used by ``pad_or_truncate``) are ignored.

    Args:
        input_ids: 2-D integer tensor of token ids, one sequence per row.
        model: Callable invoked as ``model(inputs, full_seq=True)`` that
            returns ``(logits, _)`` with logits of shape (B, T, V). Defaults
            to the module-level ``gpt``.

    Returns:
        1-D tensor with one value per row: the mean log-probability of the
        non-padding target tokens. A row whose targets are all padding yields
        0 instead of NaN.
    """
    if model is None:
        model = gpt
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = model(inputs, full_seq=True)
    B, T, V = logits.size()
    # Per-token negative log-likelihood; ignored (padding) positions come back as 0.
    token_nll = F.cross_entropy(
        logits.reshape(-1, V),
        targets.reshape(-1),
        ignore_index=0,
        reduction='none',
    ).reshape(B, T)
    attention_mask = (targets != 0).float()
    # Clamp the token count so an all-padding row does not divide by zero.
    n_tokens = attention_mask.sum(dim=1).clamp(min=1)
    mean_nll = (token_nll * attention_mask).sum(dim=1) / n_tokens
    return -mean_nll

def pad_or_truncate(seq, max_length, pad_id=0):
    """Return ``seq`` as a new list of exactly ``max_length`` token ids.

    Sequences that are too long keep their LAST ``max_length`` items (the
    beginning is dropped); sequences that are too short are right-padded.

    Args:
        seq: Sequence of token ids.
        max_length: Target length (non-negative).
        pad_id: Id used for padding. Defaults to 0, which is also the id that
            ``compute_logprob`` ignores.

    Returns:
        A new list of length ``max_length``.
    """
    seq = list(seq)
    overflow = len(seq) - max_length
    if overflow > 0:
        # Slice from an explicit start index: seq[-max_length:] would return
        # the whole sequence when max_length == 0.
        return seq[overflow:]
    return seq + [pad_id] * (-overflow)

def get_batches(lines, batch_size):
    """Yield shuffled ``(neg_tensor, pos_tensor)`` batches of preference pairs.

    Args:
        lines: Sequence of dicts with string fields ``'negative'`` and
            ``'positive'``. The sequence itself is NOT modified; a shuffled
            copy is used internally.
        batch_size: Number of pairs per batch. A trailing partial batch is
            dropped, so every yielded batch has exactly ``batch_size`` rows.

    Yields:
        Two long tensors on ``device``, each of shape
        ``(batch_size, max_length)``: the encoded negative and positive texts.
        Each text gets ``'\\n\\n\\n\\n'`` appended before it is encoded and
        padded/truncated to ``max_length``.
    """
    # Shuffle a copy so the caller's data keeps its original order.
    shuffled = list(lines)
    random.shuffle(shuffled)

    def _encode_side(batch, key):
        # Encode one side ('negative' or 'positive') of every pair in the batch.
        rows = [pad_or_truncate(encode(p[key] + '\n\n\n\n'), max_length) for p in batch]
        return torch.tensor(rows, dtype=torch.long, device=device)

    # The range bound stops before any trailing partial batch.
    for start in range(0, len(shuffled) - batch_size + 1, batch_size):
        batch = shuffled[start:start + batch_size]
        yield _encode_side(batch, 'negative'), _encode_side(batch, 'positive')

### Step 4: Load the pretrained NanoGPT model

In [None]:
# Restore the SFT checkpoint and rebuild the model from its saved arguments.
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)

# Some checkpoints carry an '_orig_mod.' prefix on their parameter names
# (presumably from a compiled model -- TODO confirm); strip it in place so the
# keys match the freshly built model.
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for key in prefixed_keys:
    state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

gpt.load_state_dict(state_dict)
# Move to the target device and switch to training mode for fine-tuning.
gpt.to(device)
gpt.train()

### Step 5: Load Data (**students are required to complete this part!**)

In [7]:
# Load data from ./data/pos_neg_pairs.json

### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [8]:
# Recommended: use the AdamW optimizer.

### Step 7: Begin training (**students are required to complete this part!**)

In [None]:
# Training-loop template. `lines` is not defined in the provided cells;
# presumably the Step 5 data-loading cell is expected to create it.
# total_steps is the number of full batches in one epoch.
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    # A fresh batch generator is created for every epoch, wrapped in a progress bar.
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        ###########################################################
    # Saved at the end of every epoch to the same path, so each epoch
    # overwrites the previous checkpoint.
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

### Step 8: Begin testing (**students are required to complete this part!**)

In [None]:
# Switch the model to evaluation mode before testing.
gpt.eval()
# Evaluation prompts: arithmetic expressions ("88+7=?") and single-variable
# equations ("x-18=21,x=?").
test_set = ["88+7=?", "x-18=21,x=?", "x/10=6,x=?", "54/1=?", "24+48=?", "11+23=?", "64+13=?", "24+46=?", "95-42=?", "81-58=?", "81-25=?", "x+14=36,x=?", "x/5=20,x=?", "x-1=52,x=?", "66+27=?", "62-33=?", "x-7=62,x=?", "39-21=?", "x-35=29,x=?", "x-89=7,x=?", "92-29=?", "x-15=84,x=?", "26/x=1,x=?", "82-47=?", "x-30=42,x=?", "76-17=?", "40+17=?", "30-5=?", "x+3=8,x=?", "15+59=?", "4+78=?", "36+46=?", "x/25=2,x=?", "13+47=?", "82-43=?", "x-53=16,x=?", "x-73=27,x=?", "x+19=82,x=?", "14*5=?", "68-42=?", "x-57=41,x=?", "12/x=1,x=?", "73-64=?", "73-34=?", "78+22=?", "x+48=61,x=?", "58+37=?", "48-5=?", "x-78=2,x=?", "24/x=3,x=?", "13+30=?", "x+60=64,x=?", "x/10=10,x=?", "x+22=70,x=?", "24/2=?", "68-52=?", "x-29=27,x=?", "x+55=82,x=?", "x+64=71,x=?", "94-84=?", "x-23=30,x=?", "51+16=?", "44+30=?", "57-54=?", "x+70=89,x=?", "22+21=?", "4+11=?", "24+16=?", "x-31=23,x=?", "82-11=?", "x+19=31,x=?", "94-72=?", "16+82=?", "x+37=53,x=?", "69-64=?", "1+4=?", "14*1=?", "x-5=71,x=?", "94-73=?", "34-32=?", "x-67=6,x=?", "x-18=73,x=?", "48+52=?", "45+38=?", "x+83=94,x=?", "30+39=?", "x-26=50,x=?", "97-28=?", "1*x=4,x=?", "x-70=3,x=?", "78/2=?", "77+4=?", "33-29=?", "13+79=?", "87+9=?", "82-36=?", "x-47=19,x=?", "93-93=?", "x-12=39,x=?", "8-7=?"]
# Test-loop template: gradient tracking is disabled for the whole evaluation.
with torch.no_grad():
    for prompt in test_set: 
        # Encode the prompt into token ids with the character-level tokenizer.
        prompt_ids = encode(prompt)
        # Hint (left commented out): add a batch dimension, sample a
        # continuation with gpt.generate, then decode it back to text.
        #x = (torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, ...])
        #y = gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        #print(decode(y[0].squeeze().tolist()))
        ###########################################################
        # Please complete the test code here!
        ###########################################################