In [1]:
import torch
import pandas as pd
from functools import partial
from pathlib import Path
from datasets import Dataset, interleave_datasets
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from transformers import Trainer, TrainingArguments, AutoTokenizer, DataCollatorForLanguageModeling
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from evals import DatasetArgs, evaluate, evaluate_FinMoE, load_eval_dataset
from FinMoE import FinMoE, FinMoEConfig
from utils import get_dataset_args

# Everything below moves models to the GPU, so stop immediately when CUDA is missing.
assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")

# One seed for torch's RNG here; the same value is passed to interleave_datasets later.
seed = 42
torch.manual_seed(seed)

# Length that both preprocessing functions pad/truncate tokenized text to.
MAX_LENGTH = 512
# Base path from which the saved/loaded dataset paths are derived via with_stem().
dataset_cache_path = Path(r"D:/datasets/general-3-tasks")

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# padding="max_length" needs a pad token; fall back to EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## Dataset setup

In [2]:
# Resolve the local Hugging Face hub cache from the current user's home
# directory instead of pinning one machine's absolute path.
# NOTE(review): assumes the default cache location (~/.cache/huggingface/hub);
# adjust here if HF_HOME / HF_HUB_CACHE redirect the cache elsewhere.
hub_basepath = Path.home() / ".cache" / "huggingface" / "hub"

# Dataset metadata object; later cells read its paths, prompt templates,
# label maps and token lists.
args = get_dataset_args(tokenizer, hub_basepath)

## Load and Preprocess Dataset

In [3]:
def train_preprocess_causal(args: DatasetArgs, dataset_id: str, example: dict):
    """Tokenize one example for causal-LM training with prompt-masked labels.

    The prompt is rendered from the dataset's template, the textual label is
    appended as the target, and the combined text is padded/truncated to
    ``MAX_LENGTH``. Only target positions keep their token ids in ``labels``;
    prompt positions and padding positions are set to -100.

    Args:
        args: Dataset metadata; ``prompt_args``, ``prompt_templates`` and
            ``id2labels`` are read, each indexed by ``dataset_id``.
        dataset_id: Key selecting which dataset's template/labels to use.
        example: One row; must hold the template fields and a ``"label"`` key.

    Returns:
        The tokenizer output for ``prompt + target`` with an added ``"labels"``
        list that has the same length as ``"input_ids"``.
    """
    # Create prompt and target text
    prompt_args = [example[key] for key in args.prompt_args[dataset_id]]
    prompt = args.prompt_templates[dataset_id].format(*prompt_args)

    target = args.id2labels[dataset_id][example["label"]]
    full_text = prompt + target

    # Tokenize prompt + target, padded/truncated to a fixed length
    tokenized = tokenizer(full_text,
                          truncation=True,
                          padding="max_length",
                          max_length=MAX_LENGTH)

    # Tokenize the prompt alone (no padding) to count how many tokens to mask.
    # NOTE(review): assumes the prompt's tokens form a prefix of the full
    # text's tokens (no merge across the prompt/target boundary) - confirm.
    prompt_tokenized = tokenizer(prompt,
                                 truncation=True,
                                 max_length=MAX_LENGTH)
    prompt_length = len(prompt_tokenized["input_ids"])

    input_ids = tokenized["input_ids"]
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is None:
        # No mask returned: treat every position as a real token.
        attention_mask = [1] * len(input_ids)

    # The prompt begins at the first attended position: index 0 with right
    # padding, after the pad run with left padding.
    prompt_start = attention_mask.index(1) if 1 in attention_mask else len(input_ids)
    target_start = prompt_start + prompt_length

    # Padding is identified through attention_mask rather than by token id,
    # because the pad token may be the very same id as EOS.
    # NOTE(review): a collator that rebuilds "labels" from "input_ids" would
    # discard this mask - confirm the collator used for training keeps it.
    labels = [
        token_id if position >= target_start and attended else -100
        for position, (token_id, attended) in enumerate(zip(input_ids, attention_mask))
    ]
    tokenized["labels"] = labels
    return tokenized

def train_preprocess_tokenclass(args: DatasetArgs, dataset_id: str, example: dict):
    """Tokenize one example's prompt and attach a single class-index label.

    Unlike the causal variant, the target text is not appended to the input.
    The label is the position, inside ``args.token_list``, of the first token
    id obtained by encoding the example's textual label.

    Args:
        args: Dataset metadata; ``prompt_args``, ``prompt_templates``,
            ``id2labels`` and ``token_list`` are read.
        dataset_id: Key selecting which dataset's template/labels to use.
        example: One row; must hold the template fields and a ``"label"`` key.

    Returns:
        The tokenizer output for the prompt with an integer ``"labels"`` entry.
    """
    # Render the prompt from this dataset's template fields
    field_values = [example[field] for field in args.prompt_args[dataset_id]]
    prompt = args.prompt_templates[dataset_id].format(*field_values)

    # Fixed-length encoding of the prompt only
    encoded = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )

    # Label text -> its first token id -> that id's slot in token_list
    label_text = args.id2labels[dataset_id][example["label"]]
    first_token_id = tokenizer.encode(label_text, add_special_tokens=False)[0]
    encoded["labels"] = args.token_list.index(first_token_id)

    return encoded

In [4]:
# Rows to read from each task's train.csv, matched by position to the
# iteration order of args.paths.
nrows_list = [3876, 3876, 3876]

dataset_list = []
for i, (dataset_id, dataset_path) in enumerate(args.paths.items()):
    # Read only the leading rows, using this dataset's delimiter and column names.
    train_subset = pd.read_csv(
        dataset_path / "train.csv",
        delimiter=args.del_mapping[dataset_id],
        names=args.names_mapping[dataset_id],
        nrows=nrows_list[i],
    )

    # Tokenize row by row, then drop the raw source columns.
    preprocess_func = partial(train_preprocess_tokenclass, args, dataset_id)
    tokenized_subset = Dataset.from_pandas(train_subset).map(preprocess_func, batched=False)
    dataset_list.append(tokenized_subset.remove_columns(args.columns[dataset_id]))

# Mix the per-task datasets with equal sampling probability.
# NOTE(review): the saved dataset reports 11489 rows rather than 3 * 3876;
# presumably interleaving stops once one subset is exhausted - confirm this
# is the intended behaviour.
n_datasets = len(dataset_list)
train_dataset = interleave_datasets(
    dataset_list,
    probabilities=[1/n_datasets]*n_datasets,
    seed=seed,
)

Map:   0%|          | 0/3876 [00:00<?, ? examples/s]

Map:   0%|          | 0/3876 [00:00<?, ? examples/s]

Map:   0%|          | 0/3876 [00:00<?, ? examples/s]

In [5]:
# Persist the interleaved dataset next to the base cache path so later
# sessions can reload it instead of re-tokenizing.
cache_stem = "finmoe-tokenclass_balanced-len512"
save_path = dataset_cache_path.with_stem(cache_stem)
train_dataset.save_to_disk(save_path)

Saving the dataset (0/1 shards):   0%|          | 0/11489 [00:00<?, ? examples/s]

In [3]:
# Reload the previously saved training set; the stem must match the one
# used when the dataset was written to disk.
cache_stem = "finmoe-tokenclass_balanced-len512"
load_path = dataset_cache_path.with_stem(cache_stem)
train_dataset = Dataset.load_from_disk(load_path)

## Training

### PEFT model trainer

In [None]:
# LoRA adapters on the query and value projections only.
peft_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.01,
    lora_alpha=16,
    r=8,
)

# Base weights are loaded in float16.
model_id = "meta-llama/Llama-3.2-1B"
base_model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype="float16",
)

# Wrap the base model with the adapters and report the trainable parameter count.
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()

In [None]:
# Causal-LM collation (mlm=False).
# NOTE(review): this collator is understood to rebuild "labels" from
# "input_ids" at batch time - confirm that any labels prepared during
# preprocessing are what the loss actually sees.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Checkpoints are written here once per epoch.
out_dir = Path(r"D:/models/general-Llama-3_2-1B")
training_args = TrainingArguments(
    output_dir=str(out_dir),
    do_train=True,
    save_strategy="epoch",
    num_train_epochs=6,
    per_device_train_batch_size=2,
    learning_rate=1e-3,
    warmup_steps=1000,
    weight_decay=0.01,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

### FinMoE trainer

In [4]:
ckpt_base = Path(r"D:/models")
# The order of this list fixes the order of the expert checkpoints handed to the config.
expert_order = ["FPB", "Headline", "Topics"]
# Every expert currently uses its "checkpoint-best" directory.
expert_ckpt_names = dict.fromkeys(expert_order, "checkpoint-best")

## note: str() wraps path as Path objects are not json serializable
expert_ckpts = []
for expert_name in expert_order:
    expert_dir = ckpt_base / f"expert-Llama-3_2-1B-{expert_name}"
    expert_ckpts.append(str(expert_dir / expert_ckpt_names[expert_name]))

# Alternative loss: loss_type="ForCausalLM" (then also enable the
# DataCollatorForLanguageModeling shown further down).
finMoE_config = FinMoEConfig(
    loss_type="ForTokenClassification",
    num_labels=len(args.token_list),
    token_list=args.token_list,
    expert_ckpts=expert_ckpts,
    gating_guassian=0.2,
)

finMoE_model = FinMoE(finMoE_config).to(device)
print("Memory allocated:", torch.cuda.memory_allocated())

# List the parameters that will receive gradients.
print("Trainable params:")
for name, params in finMoE_model.named_parameters():
    if not params.requires_grad:
        continue
    print(name, params.shape)


## --- use data_collator when training with "ForCausalLM" loss
# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer, mlm=False
# )

# batch size 1 with 16 accumulation steps -> 16 examples per optimizer step
out_dir = Path(r"D:/models/FinMoE-top3")
training_args = TrainingArguments(
    output_dir=str(out_dir),
    do_train=True,
    save_strategy="epoch",
    logging_steps=32,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-3,
    warmup_steps=512,
    weight_decay=0.01,
)

trainer = Trainer(
    model=finMoE_model,
    args=training_args,
    train_dataset=train_dataset,
    # data_collator=data_collator,
)

Memory allocated: 4953514496
Trainable params:
gate.w_gate.weight torch.Size([3, 2048])


### Train

In [5]:
# Runs whichever `trainer` was defined most recently (both the PEFT cell and
# the FinMoE cell assign that name). The recorded FinMoE run took
# 3590 optimizer steps and about 26,347 s.
trainer.train()

  0%|          | 0/3590 [00:00<?, ?it/s]

{'loss': 9.6858, 'grad_norm': 195.14013671875, 'learning_rate': 6.25e-05, 'epoch': 0.04}
{'loss': 1.8656, 'grad_norm': 41.51145553588867, 'learning_rate': 0.000125, 'epoch': 0.09}
{'loss': 1.3241, 'grad_norm': 3.312025547027588, 'learning_rate': 0.0001875, 'epoch': 0.13}
{'loss': 1.4981, 'grad_norm': 4.386941432952881, 'learning_rate': 0.00025, 'epoch': 0.18}
{'loss': 1.4151, 'grad_norm': 27.556354522705078, 'learning_rate': 0.0003125, 'epoch': 0.22}
{'loss': 1.1529, 'grad_norm': 22.39943504333496, 'learning_rate': 0.000375, 'epoch': 0.27}
{'loss': 1.6776, 'grad_norm': 0.8499090671539307, 'learning_rate': 0.0004375, 'epoch': 0.31}
{'loss': 1.282, 'grad_norm': 19.473766326904297, 'learning_rate': 0.0005, 'epoch': 0.36}
{'loss': 1.4018, 'grad_norm': 36.04220962524414, 'learning_rate': 0.0005625000000000001, 'epoch': 0.4}
{'loss': 1.6611, 'grad_norm': 70.80497741699219, 'learning_rate': 0.000625, 'epoch': 0.45}
{'loss': 1.4659, 'grad_norm': 21.681621551513672, 'learning_rate': 0.0006875, 

TrainOutput(global_step=3590, training_loss=1.6109846061982815, metrics={'train_runtime': 26347.1514, 'train_samples_per_second': 2.18, 'train_steps_per_second': 0.136, 'total_flos': 1.721692468543488e+17, 'train_loss': 1.6109846061982815, 'epoch': 4.99956480111411})

## Evaluation

In [10]:
# Task to evaluate; the same key later selects args.token_opts[dataset_id]
# in the evaluation cells.
dataset_id = "Headline"
testset = load_eval_dataset(tokenizer, dataset_id, args)

Loading Headline dataset from AdaptLLM/finance-tasks


### Load an expert checkpoint

In [4]:
# Adapter checkpoint produced by the PEFT training run.
ckpt_path = Path(r"D:/models/general-Llama-3_2-1B") / "checkpoint-best"

# Load the float16 base model first; it is moved to the GPU only after the
# adapter has been attached.
model_id = "meta-llama/Llama-3.2-1B"
base_model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype="float16")

# Attach the adapter, switch to eval mode, then move everything to the GPU.
expert_model = PeftModel.from_pretrained(base_model, ckpt_path, torch_dtype="float16")
expert_model = expert_model.eval().to(device)

In [None]:
# Evaluate the single expert on the loaded test set, with guidance enabled
# and the answer options defined for the current task.
task_token_opts = args.token_opts[dataset_id]
results = evaluate(
    expert_model,
    tokenizer,
    testset,
    guidance=True,
    token_opts=task_token_opts,
)
print(results)

### Load a FinMoE checkpoint

In [4]:
# Final checkpoint of the FinMoE run (step 3590).
ckpt_path = Path(r"D:/models/FinMoE-top3") / "checkpoint-3590"

# Restore the model, move it to the GPU, then switch to eval mode.
finMoE_model = FinMoE.load_pretrained(ckpt_path)
finMoE_model = finMoE_model.to(device).eval()

In [11]:
# Evaluate the FinMoE model on the loaded test set using the answer options
# defined for the current task.
task_token_opts = args.token_opts[dataset_id]
results = evaluate_FinMoE(
    finMoE_model,
    tokenizer,
    testset,
    task_token_opts,
)
print(results)

Routing experts: 100%|██████████| 20547/20547 [16:39<00:00, 20.56it/s]
Evaluating | 88.35: 100%|██████████| 20547/20547 [20:15<00:00, 16.91it/s]

{'accuracy': 0.8834866403854578, 'n_correct': 18153, 'n_total': 20547}



