In [None]:
import random
import torch
import wandb
import argparse
from xturing.engines.quant_utils.lrec import (
	get_c4,
	prepare_models,
	train_model
)
# Authenticate with Weights & Biases before training; this may prompt
# interactively for an API key when none is already configured.
# NOTE(review): presumably train_model logs metrics to wandb — confirm. Prefer
# supplying the key via the WANDB_API_KEY environment variable over pasting it
# into the notebook.
wandb.login()

In [None]:
# Run configuration for the xturing `lrec` pipeline. The next cell converts it
# into an argparse.Namespace and hands it to prepare_models, get_c4 and
# train_model, which read these entries as attributes.
lrec_config = dict(
    # --- base model and quantized checkpoint ---
    # NOTE(review): decapoda-research repos may no longer be hosted on the
    # Hugging Face Hub — verify this id still resolves before running.
    base_model="decapoda-research/llama-7b-hf",
    # wbits / groupsize should agree with how this checkpoint was produced
    # (its filename suggests 2-bit weights, group size 128).
    intq_checkpoint="llama7b-2bit-128g.pt",
    wbits=2,
    groupsize=128,
    # --- LoRA adapter settings (interpreted inside xturing) ---
    lora_alpha=128,
    lora_r=32,
    lora_dropout=0.05,
    lora_target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "up_proj", "down_proj", "gate_proj",
    ],
    # --- data: n_samples is also passed directly to get_c4 ---
    n_samples=2048,
    # --- optimization (batch_size is also passed to get_c4) ---
    lr=3e-4,
    batch_size=1,
    num_epochs=3,
    kl_weight=1.0,
    ce_weight=200.0,
    trainable_kl_weight=False,
    trainable_ce_weight=False,
    weight_decay=1e-5,
    # --- checkpointing cadence ---
    save_freq=1,
    intra_save_freq=200,
    # --- reproducibility and sequence length (both also passed to get_c4) ---
    seed=0,
    seqlen=2048,
    # --- caching and output directories ---
    cache=False,
    train_cache_dir="./train_cache/",
    val_cache_dir="./val_cache/",
    ckpt_dir="./ckpts/",
    save_dir="./save/",
)

In [None]:
# Expose the config dict as attributes (args.seed, args.seqlen, ...), which is
# how prepare_models / get_c4 / train_model read their settings.
args = argparse.Namespace(**lrec_config)

# Seed Python's `random`, torch's generator and torch's CUDA generator with the
# same value, in that order, for reproducible runs.
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# Build the model pair used for training; `fp_model` is presumably the
# full-precision reference (name-based assumption — confirm in xturing).
model, fp_model = prepare_models(args)

# Load the C4 data: a training loader plus a validation encoding.
trainloader, valenc = get_c4(
    args.base_model,
    args.seqlen,
    args.n_samples,
    args.batch_size,
    args.seed,
)

train_model(args, model, fp_model, trainloader, valenc)