In [1]:
import json
import random
from functions import *
import torch
import numpy as np
from torch.nn.utils import clip_grad_value_
import math
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    # AutoConfig,
)
from peft import get_peft_model

""" Data """
# Source files for the training corpus. Each one is read with json.load and
# its contents are appended to a single list
# (assumption: each file holds a flat JSON list of strings — TODO confirm).
files = [
    '../Data/OlympiadBench_Dataset/data/outputs.json',
    '../Data/AMC/outputs.json',
    '../Data/MATH/outputs.json',
]
texts = []
for file in files:
    with open(file, 'r', encoding='utf-8') as f:
        texts.extend(json.load(f))

# AutoTokenizer is already imported at the top of the cell; no re-import needed.
MODEL_PATH = "deepseek-ai/deepseek-math-7b-rl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# From here on `texts` holds token-id lists instead of raw text: one list per
# document, with special tokens added and truncated to at most 4096 tokens.
texts = tokenizer.batch_encode_plus(
    texts,
    return_attention_mask=False,
    add_special_tokens=True,
    truncation=True,
    max_length=4096,
)['input_ids']

""" Model """
# Training hyperparameters.
epochs = 1
accumulation_steps = 64  # samples back-propagated between optimizer steps
verbose = 1024           # the mean training loss is logged every `verbose` samples
lr = 6e-5                # AdamW learning rate
clip = 6e-3              # bound handed to clip_grad_value_ before each optimizer step
alpha = 0.05             # NOTE(review): not referenced in the visible part of this notebook

# Load the base model 4-bit quantized (NF4, double quantization) with
# bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# `tokenizer` was already loaded from MODEL_PATH in the data section above;
# it is reused here instead of being loaded a second time.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)

# Wrap the quantized model with a PEFT adapter. `random_peft_config` and
# `config_map` come from the star-import of `functions`; from their use here,
# the former returns a dict with 'config_type' and 'config_kwargs' and the
# latter maps a config-type name to a callable that builds the PEFT config.
config_dict = random_peft_config()
peft_config = config_map[config_dict['config_type']](**config_dict['config_kwargs'])
model = get_peft_model(model, peft_config)
model.config.pad_token_id = tokenizer.pad_token_id

# Fraction of parameters that are trainable; stored later with the results.
n_trainable, n_total = model.get_nb_trainable_parameters()
trainable = n_trainable / n_total

# Only parameters that require gradients are optimized (and later clipped).
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

loss_fn = torch.nn.CrossEntropyLoss()
# One (epoch, iteration, mean loss over the last logging window) tuple per log line.
performance = []
# For a PromptEncoderConfig adapter the first `num_virtual_tokens` logit
# positions are dropped before the loss is computed (presumably they belong to
# the prepended virtual prompt tokens — confirm against the PEFT docs).
logit_offset = config_dict['config_kwargs'].get('num_virtual_tokens',0) if config_dict['config_type'] == 'PromptEncoderConfig' else 0


def _apply_accumulated_gradients():
    """Clip the accumulated gradients by value, take one optimizer step, clear them."""
    clip_grad_value_(trainable_params, clip)
    optimizer.step()
    optimizer.zero_grad()


for epoch in range(epochs):
    random.shuffle(texts)
    model.train()
    train_loss = 0      # running sum of per-sample losses in this epoch
    train_last = 0      # value of train_loss at the previous log line
    skip = 0            # samples skipped (NaN/inf) in the current logging window
    tot_skip = 0        # samples skipped in all completed logging windows
    pending_grads = 0   # samples back-propagated since the last optimizer step
    # for llm, batchsize = 1 still gives 100 GPU util
    for i, input_ids in enumerate(texts):
        # train
        input_ids = sample_consecutive_chunk(input_ids, 1200)
        input_ids = torch.tensor(input_ids).to('cuda')[None]
        outs = model(input_ids).logits

        # Next-token objective: the logit at position t is scored against token t+1.
        loss_value = float('nan')
        if not torch.any(torch.isnan(outs)):
            loss = loss_fn(outs[0, logit_offset:-1], input_ids[0, 1:])
            loss_value = loss.item()  # read once; each .item() forces a host sync

        if math.isfinite(loss_value):
            # Gradients are summed (not averaged) over the accumulation window.
            loss.backward()
            train_loss += loss_value
            pending_grads += 1
        else:
            # NaN logits or a NaN/inf loss: drop the sample, but do NOT
            # `continue` — the boundary checks below must still run, otherwise
            # a skip on a boundary iteration silently doubles the accumulation
            # window or loses a log record.
            skip += 1

        if (i + 1) % accumulation_steps == 0 and pending_grads:
            _apply_accumulated_gradients()
            pending_grads = 0

        # eval
        if (i + 1) % verbose == 0:
            counted = verbose - skip
            # If every sample in the window was skipped there is no mean to report.
            temp = (train_loss - train_last) / counted if counted else float('nan')
            print(f"epoch {epoch} iter {i}: train loss {temp}")
            performance.append((epoch, i, temp))
            train_last = train_loss
            tot_skip += skip
            skip = 0

    # Apply gradients left over from a final, partial accumulation window so
    # they are neither discarded nor leaked into the next epoch.
    if pending_grads:
        _apply_accumulated_gradients()
        pending_grads = 0

# Persist the run: adapter weights, the PEFT config that produced them, and
# the logged training metrics, all under a freshly created model folder.
peft_model_id = create_next_model_folder()
model.save_pretrained(peft_model_id)

# (path suffix, JSON payload) for each side-car file written next to the adapter.
artifacts = (
    ('/config_dict.json', config_dict),
    ('/performance.json', {'loss_hist':performance,'trainable_pct':trainable}),
)
for suffix, payload in artifacts:
    with open(peft_model_id+suffix, 'w') as out_file:
        json.dump(payload, out_file)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

epoch 0 iter 1023: train loss 1.2250459004571894
epoch 0 iter 2047: train loss 1.097773865389172
epoch 0 iter 3071: train loss 0.970116772732581
epoch 0 iter 4095: train loss 0.9048753228416899
epoch 0 iter 5119: train loss 0.8852845648143557
epoch 0 iter 6143: train loss 0.8505273452901747
epoch 0 iter 7167: train loss 0.8277944275614573
epoch 0 iter 8191: train loss 0.8189215795209748
epoch 0 iter 9215: train loss 0.8168886959756492
epoch 0 iter 10239: train loss 0.8177528268715832
epoch 0 iter 11263: train loss 0.7985045031964546
epoch 0 iter 12287: train loss 0.8106034819138586
epoch 0 iter 13311: train loss 0.7889349279430462
epoch 0 iter 14335: train loss 0.7809478161652805
epoch 0 iter 15359: train loss 0.780915419170924
epoch 0 iter 16383: train loss 0.8007599455813761
epoch 0 iter 17407: train loss 0.7841280383363483
epoch 0 iter 18431: train loss 0.7871896027645562
epoch 0 iter 19455: train loss 0.7729098980780691
epoch 0 iter 20479: train loss 0.7760596466396237
epoch 0 iter