In [1]:
from transformers import LlamaForSequenceClassification
import torch
from torch.nn.utils import clip_grad_value_
import torch.optim as optim
import numpy as np
import math
import random
import pickle
from itertools import zip_longest
from datasets import load_dataset,load_from_disk

#### 1. Data

In [2]:
# gen data
with open('../Data/PRM_data/gen_texts.pkl', 'rb') as file:
    gen_texts = pickle.load(file)
with open('../Data/PRM_data/gen_targets.pkl', 'rb') as file:
    gen_targets = pickle.load(file)
with open('../Data/PRM_data/gen_starts_ends.pkl', 'rb') as file:
    gen_starts_ends = pickle.load(file)

# sol data
with open('../Data/PRM_data/sol_texts.pkl', 'rb') as file:
    sol_texts = pickle.load(file)
with open('../Data/PRM_data/sol_starts_ends.pkl', 'rb') as file:
    sol_starts_ends = pickle.load(file)

# Math-Shepherd
dataset = load_dataset('../Data/Math-Shepherd')

# MMOS
dataset2 = load_from_disk('../Data/MMOS')

In [3]:
MAX_LEN = 1200
def shuffle_lists(*args):
    combined = list(zip(*args))
    random.shuffle(combined)
    return list(zip(*combined))

def np2torch(input,addBatchDim=True):
    if addBatchDim:
        return torch.tensor(input,device='cuda')[None]
    else:
        return torch.tensor(input,device='cuda')

def from_shepherd(dataset):
    # yield token_id, index, target, data_source
    dataset = dataset.shuffle()
    for data in dataset['train']:
        if len(data['index']) != len(data['targets']): continue
        if max(data['index']) > MAX_LEN:
            index,targets = zip(*[(d,t) for d,t in zip(data['index'],data['targets']) if d<MAX_LEN])
        else:
            index,targets = data['index'],data['targets']
        yield np2torch(data['input_id'][:max(index)+1]),np2torch(index,False),np2torch(targets).float(),0

def from_mmos(dataset,num_of_points=5):
    # yield token_id, index, target, data_source
    dataset = dataset.shuffle()
    for data in dataset['train']:
        text = data['input_id']
        start,end = data['starts_ends']
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else: 
            index = np.random.randint(start,end,num_of_points)
            # targets = np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),3

def from_sol(texts,starts_ends,num_of_points=5):
    texts,starts_ends = shuffle_lists(texts,starts_ends)
    for text,(start,end) in zip(texts,starts_ends):
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else: 
            index = np.random.randint(start,end,num_of_points)
            # targets = np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),1
        
def from_genData(texts,targets,starts_ends,num_of_points=5):
    texts,targets,starts_ends = shuffle_lists(texts,targets,starts_ends)
    for text,y,(start,end) in zip(texts,targets,starts_ends):
        end = min(end,MAX_LEN)
        if start>=end:# use entire sol when it is shorter than 10
            continue
        else:
            index = np.random.randint(start,end,num_of_points)
            # target = y * np.exp(-(end-index)/end) # discount
            yield np2torch(text[:max(index)+1]),np2torch(index,False),\
                    y*torch.ones((1,num_of_points),device='cuda',dtype=torch.float32),2

#### 2. Model

In [4]:
model = LlamaForSequenceClassification.from_pretrained('deepseek-ai/deepseek-math-7b-rl',\
                                                       num_labels=1,\
                                                       torch_dtype="auto",\
                                                       attn_implementation="flash_attention_2")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at deepseek-ai/deepseek-math-7b-rl and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
for param in model.model.parameters():
    param.requires_grad = False
for param in model.score.parameters():
    param.requires_grad = True
model.score = model.score.float()
model = model.to('cuda')

#### 3. Training head

In [6]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.score.parameters(),lr = lr)

In [7]:
i = 0
for epoch in range(epochs):
    for data in zip(from_shepherd(dataset),\
                    from_sol(sol_texts,sol_starts_ends),\
                    from_genData(gen_texts,gen_targets,gen_starts_ends)):
        for d in data:
            # if d is None: continue # zip_longest will return None for shorter iterable
            text,index,target,source = d
            hidden_states = model.model(text)[0].float()
            logits = model.score(hidden_states)[:,index,0]
            loss = loss_fn(logits,target)
            loss.backward()
            
            train_loss[source] += loss.item()
            count_loss[source] += 1
            i += 1

            if (i + 1) % accumulation_steps == 0:
                clip_grad_value_(model.score.parameters(),clip)
                optimizer.step()
                optimizer.zero_grad()
            
            if (i + 1) % verbose == 0:
                print(f"iter: {i}, loss: {[l/c if c!=0 else 'N/A' for l,c in zip(train_loss,count_loss)]}")
                train_loss = [0,0,0]
                count_loss = [0,0,0]

iter: 1023, loss: [0.6983305301484475, 0.7101281848121598, 0.7084300920061352]
iter: 2047, loss: [0.6851908006863288, 0.6916047665031075, 0.7124841071643437]
iter: 3071, loss: [0.6872878611262593, 0.6627943027786344, 0.719804101500693]
iter: 4095, loss: [0.6924264471202303, 0.6356581202175611, 0.7236687314440633]
iter: 5119, loss: [0.6937833166958993, 0.6292938889820905, 0.7132304342261507]
iter: 6143, loss: [0.6900924494888776, 0.6158084587918388, 0.707914371119916]
iter: 7167, loss: [0.688401895009877, 0.6027550736655238, 0.709160310134553]
iter: 8191, loss: [0.6910998361152515, 0.6003385509563681, 0.6963855191584556]
iter: 9215, loss: [0.6799317163106633, 0.5954016733762116, 0.6894079299091943]
iter: 10239, loss: [0.6862649718337744, 0.5795238888508414, 0.6891365681301084]
iter: 11263, loss: [0.6809413582957976, 0.5622165842839351, 0.6912893049003791]
iter: 12287, loss: [0.6772415956094467, 0.5540770841272253, 0.6917170327255103]
iter: 13311, loss: [0.6718460233155583, 0.53999936152

In [8]:
model.save_pretrained("../Model/PRM")

#### 4. Fine-tune

In [4]:
epochs = 6
alpha = 0.25
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3
topics_num = 4
weights=[0.4,0.2,0.2,0.2]

In [5]:
from transformers import LlamaForSequenceClassification,BitsAndBytesConfig,AutoConfig
import torch
from peft import (
    get_peft_model,
    PeftType,
    LoraConfig)

In [6]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = LlamaForSequenceClassification.from_pretrained('../Model/PRM',\
                                                       num_labels=1,\
                                                       device_map="auto",
                                                       torch_dtype="auto",
                                                       quantization_config=quantization_config,
                                                       attn_implementation="flash_attention_2"
                                                       )
model.gradient_checkpointing_enable()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
peft_config = LoraConfig(r=8, # low rank 
                         lora_alpha = 16, # see below 
                         lora_dropout = 0.1, 
                         bias="none",#'none', 'all' or 'lora_only' 
                         target_modules = [ "q_proj", 
                                            "k_proj", 
                                            "v_proj", 
                                            "o_proj", 
                                            "gate_proj", 
                                            "up_proj", 
                                            "down_proj" 
                                        ] 
                        )
base_model = get_peft_model(model.model, peft_config)
base_model.gradient_checkpointing_enable()
# model.config.pad_token_id = tokenizer.pad_token_id
base_model.print_trainable_parameters()
model.score = model.score.float()
model.score.weight.requires_grad_(True);

trainable params: 18,739,200 || all params: 6,509,674,496 || trainable%: 0.287866928085493


In [8]:
import random
from torch.autograd import Function
import torch.nn as nn
import torch

def sample_from_iterables(weights,*iterables):
    while True:
        iterable = random.choices(iterables, weights=weights, k=1)[0]
        try:
            yield next(iterable)
        except StopIteration:
            break

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad_output):
        return (grad_output * -alpha,)

class GradientReversalLayer(nn.Module):
    def forward(self, x):
        return GradientReversalFunction.apply(x)

class revLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(revLinear, self).__init__()
        self.layers = nn.Linear(input_dim, output_dim)
        self.grad_rev = GradientReversalLayer()
    
    def forward(self, x):
        return self.layers(self.grad_rev(x))

In [9]:
topic_model = revLinear(model.score.weight.shape[1],topics_num).to('cuda').float()

In [10]:
trainable_params = [param for param in base_model.parameters() if param.requires_grad] + \
                    list(model.score.parameters()) + \
                    list(topic_model.parameters())
optimizer = torch.optim.AdamW(trainable_params,lr = lr)

In [11]:
loss_fn = torch.nn.BCEWithLogitsLoss()
loss_topic = torch.nn.CrossEntropyLoss()

train_loss = [0,0,0,0]
topic_loss = [0,0,0,0]
count_loss = [0,0,0,0]

for epoch in range(epochs):
    iterables = [from_shepherd(dataset),\
                 from_sol(sol_texts,sol_starts_ends),
                 from_genData(gen_texts,gen_targets,gen_starts_ends),\
                 from_mmos(dataset2)]
    for i,(text,index,target,source) in enumerate(sample_from_iterables(weights, *iterables)):
        target_topics = source * torch.ones(target.shape[1],dtype=torch.long,device='cuda') # l
        hidden_states = base_model(text)[0][:,index].float() # b,l,d
        logits = model.score(hidden_states)[:,:,0] # b,l
        logits_topics = topic_model(hidden_states)[0] # l,C
        loss1 = loss_fn(logits,target)
        loss2 = loss_topic(logits_topics,target_topics)
        loss = loss1 + loss2
        loss.backward()

        train_loss[source] += loss1.item()
        topic_loss[source] += loss2.item()
        count_loss[source] += 1

        if (i + 1) % accumulation_steps == 0:
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {[l/c if c!=0 else 'N/A' for l,c in zip(train_loss,count_loss)]}\n topic loss: {[l/c if c!=0 else 'N/A' for l,c in zip(topic_loss,count_loss)]}")
            train_loss = [0,0,0,0]
            topic_loss = [0,0,0,0]
            count_loss = [0,0,0,0]
            
        torch.cuda.empty_cache()

iter: 1023, 
 train loss: [0.65817857821389, 0.4526890612127793, 0.6755780264099628, 0.8026162947927202]
 topic loss: [1.1416445566285955, 1.4191593548760342, 1.4433136268814593, 1.3883098409289405]
iter: 2047, 
 train loss: [0.6752799573816766, 0.3233648593723774, 0.828306495863507, 0.5195725660575063]
 topic loss: [0.7325012481410826, 1.4819346618652345, 1.508425621359561, 1.5019488604445206]
iter: 3071, 
 train loss: [0.6463596872468986, 0.27168412066461667, 0.7976466647038857, 0.20245894052439334]
 topic loss: [0.7092168517041915, 1.4934290129424148, 1.5519538608690102, 1.7865507271901475]
iter: 4095, 
 train loss: [0.6185741501853353, 0.2168847916088163, 0.6422313251264432, 0.022611184270704906]
 topic loss: [1.1617714015771623, 1.4944955602355068, 1.5228635397256982, 2.7048021383833087]
iter: 5119, 
 train loss: [0.6008901754535999, 0.1519686096976177, 0.47133520008188434, 0.005004356264208372]
 topic loss: [1.7472596999049792, 1.5136407796275673, 1.4632545999721087, 5.4540791034

KeyboardInterrupt: 