In [1]:
from transformers import LlamaForSequenceClassification
import torch
from torch.nn.utils import clip_grad_value_
import torch.optim as optim
import numpy as np
import math
import random
import pickle
from itertools import zip_longest
from datasets import load_dataset

1. Data

In [2]:
# gen data
with open('../Data/PRM_data/gen_texts.pkl', 'rb') as file:
    gen_texts = pickle.load(file)
with open('../Data/PRM_data/gen_targets.pkl', 'rb') as file:
    gen_targets = pickle.load(file)
with open('../Data/PRM_data/gen_starts_ends.pkl', 'rb') as file:
    gen_starts_ends = pickle.load(file)

# sol data
with open('../Data/PRM_data/sol_texts.pkl', 'rb') as file:
    sol_texts = pickle.load(file)
with open('../Data/PRM_data/sol_starts_ends.pkl', 'rb') as file:
    sol_starts_ends = pickle.load(file)

# Math-Shepherd
dataset = load_dataset('../Data/Math-Shepherd')

In [3]:
def shuffle_lists(*args):
    combined = list(zip(*args))
    random.shuffle(combined)
    return list(zip(*combined))

def np2torch(input,addBatchDim=True):
    if addBatchDim:
        return torch.tensor(input,device='cuda')[None]
    else:
        return torch.tensor(input,device='cuda')

def from_shepherd(dataset):
    # yield token_id, index, target, data_source
    dataset = dataset.shuffle()
    for data in dataset['train']:
        yield np2torch(data['input_id']),np2torch(data['index'],False),np2torch(data['targets']).float(),0

def from_sol(texts,starts_ends,num_of_points=5):
    texts,starts_ends = shuffle_lists(texts,starts_ends)
    for text,(start,end) in zip(texts,starts_ends):
        if start>=end:# use entire sol when it is shorter than 10
            yield np2torch(text),end,torch.ones(1,device='cuda',dtype=torch.float32),1
        else: 
            index = np.random.randint(start,end,num_of_points)
            # targets = np.exp(-(end-index)/end) # discount
            yield np2torch(text),np2torch(index,False),\
                    torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),1
        
def from_genData(texts,targets,starts_ends,num_of_points=5):
    texts,targets,starts_ends = shuffle_lists(texts,targets,starts_ends)
    for text,y,(start,end) in zip(texts,targets,starts_ends):
        if start>=end:# use entire sol when it is shorter than 10
            yield np2torch(text),end,y*torch.ones(1,device='cuda',dtype=torch.float32),2
        else:
            index = np.random.randint(start,end,num_of_points)
            # target = y * np.exp(-(end-index)/end) # discount
            yield np2torch(text),np2torch(index,False),\
                    y*torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),2

2. Model

In [4]:
model = LlamaForSequenceClassification.from_pretrained('deepseek-ai/deepseek-math-7b-rl',\
                                                       num_labels=1,\
                                                       torch_dtype="auto",\
                                                       attn_implementation="flash_attention_2")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at deepseek-ai/deepseek-math-7b-rl and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
for param in model.model.parameters():
    param.requires_grad = False
for param in model.score.parameters():
    param.requires_grad = True
model.score = model.score.float()
model = model.to('cuda')

3. Training head

In [6]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.score.parameters(),lr = lr)

In [7]:
model.train()
train_loss = [0,0,0]
count_loss = [0,0,0]
i = 0
for epoch in range(epochs):
    for data in zip(from_shepherd(dataset),\
                    from_sol(sol_texts,sol_starts_ends),\
                    from_genData(gen_texts,gen_targets,gen_starts_ends)):
        for d in data:
            # if d is None: continue # zip_longest will return None for shorter iterable
            text,index,target,source = d
            hidden_states = model.model(text)[0].float()
            logits = model.score(hidden_states)[:,index,0]
            loss = loss_fn(logits,target)
            loss.backward()
            
            train_loss[source] += loss.item()
            count_loss[source] += 1
            i += 1

            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(model.score.parameters(),clip)
                optimizer.step()
                optimizer.zero_grad()
            
            if (i + 1) % verbose == 0:
                print(f"iter: {i}, loss: {[l/c if c!=0 else 'N/A' for l,c in zip(train_loss,count_loss)]}")
                train_loss = [0,0,0]
                count_loss = [0,0,0]

iter: 1023, loss: [0.6983305301484475, 0.7101281848121598, 0.7084300920061352]
iter: 2047, loss: [0.6851908006863288, 0.6916047665031075, 0.7124841071643437]
iter: 3071, loss: [0.6872878611262593, 0.6627943027786344, 0.719804101500693]
iter: 4095, loss: [0.6924264471202303, 0.6356581202175611, 0.7236687314440633]
iter: 5119, loss: [0.6937833166958993, 0.6292938889820905, 0.7132304342261507]
iter: 6143, loss: [0.6900924494888776, 0.6158084587918388, 0.707914371119916]
iter: 7167, loss: [0.688401895009877, 0.6027550736655238, 0.709160310134553]
iter: 8191, loss: [0.6910998361152515, 0.6003385509563681, 0.6963855191584556]
iter: 9215, loss: [0.6799317163106633, 0.5954016733762116, 0.6894079299091943]
iter: 10239, loss: [0.6862649718337744, 0.5795238888508414, 0.6891365681301084]
iter: 11263, loss: [0.6809413582957976, 0.5622165842839351, 0.6912893049003791]
iter: 12287, loss: [0.6772415956094467, 0.5540770841272253, 0.6917170327255103]
iter: 13311, loss: [0.6718460233155583, 0.53999936152

In [8]:
model.save_pretrained("../Model/PRM")