In [1]:
from transformers import LlamaForSequenceClassification
import torch
from torch.nn.utils import clip_grad_value_
import torch.optim as optim
import numpy as np
import math
import random
import pickle
from itertools import zip_longest
from datasets import load_dataset,load_from_disk

#### 1. Data

In [2]:
# gen data
with open('../Data/PRM_data/gen_texts.pkl', 'rb') as file:
    gen_texts = pickle.load(file)
with open('../Data/PRM_data/gen_targets.pkl', 'rb') as file:
    gen_targets = pickle.load(file)
with open('../Data/PRM_data/gen_starts_ends.pkl', 'rb') as file:
    gen_starts_ends = pickle.load(file)

# sol data
with open('../Data/PRM_data/sol_texts.pkl', 'rb') as file:
    sol_texts = pickle.load(file)
with open('../Data/PRM_data/sol_starts_ends.pkl', 'rb') as file:
    sol_starts_ends = pickle.load(file)

# Math-Shepherd
dataset = load_dataset('../Data/Math-Shepherd')

# MMOS
dataset2 = load_from_disk('../Data/MMOS')

In [3]:
MAX_LEN = 1200
def shuffle_lists(*args):
    combined = list(zip(*args))
    random.shuffle(combined)
    return list(zip(*combined))

def np2torch(input,addBatchDim=True):
    if addBatchDim:
        return torch.tensor(input,device='cuda')[None]
    else:
        return torch.tensor(input,device='cuda')

def from_shepherd(dataset):
    # yield token_id, index, target, data_source
    dataset = dataset.shuffle()
    for data in dataset['train']:
        if len(data['index']) != len(data['targets']): continue
        if max(data['index']) > MAX_LEN:
            out = list(zip(*[(d,t) for d,t in zip(data['index'],data['targets']) if d<MAX_LEN]))
            if len(out) == 2:
                index,targets = out
            else:
                continue # out is [] -> all index > MAX_LEN
        else:
            index,targets = data['index'],data['targets']
        yield np2torch(data['input_id'][:max(index)+1]),np2torch(index,False),np2torch(targets).float(),0

def from_mmos(dataset,num_of_points=5):
    # yield token_id, index, target, data_source
    dataset = dataset.shuffle()
    for data in dataset['train']:
        text = data['input_id']
        start,end = data['starts_ends']
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else: 
            index = np.random.randint(start,end,num_of_points)
            # targets = np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),3

def from_sol(texts,starts_ends,num_of_points=5):
    texts,starts_ends = shuffle_lists(texts,starts_ends)
    for text,(start,end) in zip(texts,starts_ends):
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else: 
            index = np.random.randint(start,end,num_of_points)
            # targets = np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),1
        
def from_genData(texts,targets,starts_ends,num_of_points=5):
    texts,targets,starts_ends = shuffle_lists(texts,targets,starts_ends)
    for text,y,(start,end) in zip(texts,targets,starts_ends):
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else:
            index = np.random.randint(start,end,num_of_points)
            # target = y * np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    y*torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),2

#### 2. Model

In [None]:
model = LlamaForSequenceClassification.from_pretrained('deepseek-ai/deepseek-math-7b-rl',\
                                                       num_labels=1,\
                                                       torch_dtype="auto",\
                                                       attn_implementation="flash_attention_2")

In [5]:
for param in model.model.parameters():
    param.requires_grad = False
for param in model.score.parameters():
    param.requires_grad = True
model.score = model.score.float()
model = model.to('cuda')

#### 3. Training head

In [6]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.score.parameters(),lr = lr)

In [7]:
i = 0
for epoch in range(epochs):
    for data in zip(from_shepherd(dataset),\
                    from_sol(sol_texts,sol_starts_ends),\
                    from_genData(gen_texts,gen_targets,gen_starts_ends)):
        for d in data:
            # if d is None: continue # zip_longest will return None for shorter iterable
            text,index,target,source = d
            hidden_states = model.model(text)[0].float()
            logits = model.score(hidden_states)[:,index,0]
            loss = loss_fn(logits,target)
            loss.backward()
            
            train_loss[source] += loss.item()
            count_loss[source] += 1
            i += 1

            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(model.score.parameters(),clip)
                optimizer.step()
                optimizer.zero_grad()
            
            if (i + 1) % verbose == 0:
                print(f"iter: {i}, loss: {[l/c if c!=0 else 'N/A' for l,c in zip(train_loss,count_loss)]}")
                train_loss = [0,0,0]
                count_loss = [0,0,0]

iter: 1023, loss: [0.6983305301484475, 0.7101281848121598, 0.7084300920061352]
iter: 2047, loss: [0.6851908006863288, 0.6916047665031075, 0.7124841071643437]
iter: 3071, loss: [0.6872878611262593, 0.6627943027786344, 0.719804101500693]
iter: 4095, loss: [0.6924264471202303, 0.6356581202175611, 0.7236687314440633]
iter: 5119, loss: [0.6937833166958993, 0.6292938889820905, 0.7132304342261507]
iter: 6143, loss: [0.6900924494888776, 0.6158084587918388, 0.707914371119916]
iter: 7167, loss: [0.688401895009877, 0.6027550736655238, 0.709160310134553]
iter: 8191, loss: [0.6910998361152515, 0.6003385509563681, 0.6963855191584556]
iter: 9215, loss: [0.6799317163106633, 0.5954016733762116, 0.6894079299091943]
iter: 10239, loss: [0.6862649718337744, 0.5795238888508414, 0.6891365681301084]
iter: 11263, loss: [0.6809413582957976, 0.5622165842839351, 0.6912893049003791]
iter: 12287, loss: [0.6772415956094467, 0.5540770841272253, 0.6917170327255103]
iter: 13311, loss: [0.6718460233155583, 0.53999936152

In [8]:
model.save_pretrained("../Model/PRM")

#### 4. Fine-tune

In [4]:
epochs = 1
alpha_factor = 4.0
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3
topics_num = 4
weights=[0.5,0.2,0.2,0.1]

In [5]:
from transformers import LlamaForSequenceClassification,BitsAndBytesConfig,AutoConfig
import torch
from peft import (
    get_peft_model,
    PeftType,
    LoraConfig)

In [6]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = LlamaForSequenceClassification.from_pretrained('../Model/PRM',\
                                                       num_labels=1,\
                                                       device_map="auto",
                                                       torch_dtype="auto",
                                                       quantization_config=quantization_config,
                                                       attn_implementation="flash_attention_2"
                                                       )
model.gradient_checkpointing_enable()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
peft_config = LoraConfig(r=8, # low rank 
                         lora_alpha = 16, # see below 
                         lora_dropout = 0.1, 
                         bias="none",#'none', 'all' or 'lora_only' 
                         target_modules = [ "q_proj", 
                                            "k_proj", 
                                            "v_proj", 
                                            "o_proj", 
                                            "gate_proj", 
                                            "up_proj", 
                                            "down_proj" 
                                        ] 
                        )
base_model = get_peft_model(model.model, peft_config)
base_model.gradient_checkpointing_enable()
# model.config.pad_token_id = tokenizer.pad_token_id
base_model.print_trainable_parameters()
model.score = model.score.float()
model.score.weight.requires_grad_(True);

trainable params: 18,739,200 || all params: 6,509,674,496 || trainable%: 0.287866928085493


In [8]:
import random
from torch.autograd import Function
import torch.nn as nn
import torch

def sample_from_iterables(weights,*iterables):
    while True:
        iterable = random.choices(iterables, weights=weights, k=1)[0]
        try:
            yield next(iterable)
        except StopIteration:
            break

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha  # Store alpha in the context
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * -ctx.alpha, None  # Use stored alpha, return None for alpha's grad

class GradientReversalLayer(nn.Module):
    def forward(self, x, alpha):
        return GradientReversalFunction.apply(x, alpha)

class revLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(revLinear, self).__init__()
        self.layers = nn.Sequential(nn.Linear(input_dim, input_dim),nn.CELU(),nn.Linear(input_dim, output_dim))
        self.grad_rev = GradientReversalLayer()

    def forward(self, x, alpha):
        return self.layers(self.grad_rev(x, alpha))

In [9]:
# topic_model = revLinear(model.score.weight.shape[1],topics_num).to('cuda').float()

In [10]:
base_params = [param for param in base_model.parameters() if param.requires_grad]
trainable_params =  base_params + \
                    list(model.score.parameters())
                    # list(topic_model.parameters())
optimizer = torch.optim.Adam(trainable_params,lr = lr)

In [11]:
loss_fn = torch.nn.BCEWithLogitsLoss()
loss_topic = torch.nn.CrossEntropyLoss()
asym =lambda x: x if x<0 else 2*x
# sigmoid = lambda x: 1/(1+np.exp(-x))
softplue = lambda x:np.log(1 + np.exp(x))

train_loss = 0
for epoch in range(epochs):
    for i,(text,index,target,source) in enumerate(from_shepherd(dataset)):
        target_topics = source * torch.ones(target.shape[1],dtype=torch.long,device='cuda') # l
        hidden_states = base_model(text)[0][:,index].float() # b,l,d
        logits = model.score(hidden_states)[:,:,0] # b,l
        loss = loss_fn(logits,target)
        loss.backward()
        train_loss += loss.item()

        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(base_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/verbose}")
            train_loss = 0
            
        torch.cuda.empty_cache()

iter: 1023, 
 train loss: 0.6580159974400885
iter: 2047, 
 train loss: 0.6415111924288794
iter: 3071, 
 train loss: 0.5934078045393107
iter: 4095, 
 train loss: 0.5148563318325614
iter: 5119, 
 train loss: 0.4851639556000009
iter: 6143, 
 train loss: 0.4959308831639646
iter: 7167, 
 train loss: 0.44096636618996854
iter: 8191, 
 train loss: 0.43119324900726497
iter: 9215, 
 train loss: 0.42783852473348816
iter: 10239, 
 train loss: 0.41793171603058
iter: 11263, 
 train loss: 0.38764809867461736
iter: 12287, 
 train loss: 0.38431894192444815
iter: 13311, 
 train loss: 0.3792582213463902
iter: 14335, 
 train loss: 0.3853120165790642
iter: 15359, 
 train loss: 0.3791149910425702
iter: 16383, 
 train loss: 0.39037251335412293
iter: 17407, 
 train loss: 0.3745948937043977
iter: 18431, 
 train loss: 0.3699627665919252
iter: 19455, 
 train loss: 0.37997911753427616
iter: 20479, 
 train loss: 0.36232350547061287
iter: 21503, 
 train loss: 0.3599839823541515
iter: 22527, 
 train loss: 0.35588806

In [12]:
torch.save(model.score.state_dict(), '../Model/model_score.pth')
peft_model_id = "../Model/PRM_LORA"
base_model.save_pretrained(peft_model_id)

