# FinMoE Training Notebook
This notebook contains the code used to train FinMoE.<br><br>
Developed by Samuel Barnett<br>
Supervised by Dr JingJing Deng<br>
Submitted as part of the degree of MEng Computer Science to the Board of Examiners in the Department of Computer Science, Durham University

In [None]:
import torch
from pathlib import Path
from datasets import Dataset
from huggingface_hub import constants as hub_c
from transformers import Trainer, TrainingArguments, AutoTokenizer, DataCollatorForLanguageModeling

from evals import evaluate_FinMoE, load_eval_dataset
from FinMoE import FinMoE, FinMoEConfig
from utils import DatasetArgs, get_dataset_args, load_train_datasets

# Fail loudly when no GPU is present. An explicit raise is used instead of
# `assert` because assertions are stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available")
device = torch.device("cuda")

# Seed torch's RNG so that model initialisation (e.g. the gating network) is
# reproducible across runs.
seed = 42
torch.manual_seed(seed)

ckpt_base = Path(r"D:/models")  # Path to directory containing expert ckpts
dataset_cache_path = Path(r"D:/datasets/FinMoE-train")

# Base model whose tokenizer is used for all preprocessing and evaluation.
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # Fall back to EOS as the pad token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token

# Per-dataset prompt templates, label maps and label-token list used below.
# The HF hub cache directory is passed in (presumably where the datasets are
# cached — verify in utils.get_dataset_args).
args = get_dataset_args(tokenizer, Path(hub_c.HF_HUB_CACHE))

# Checkpoints for the new FinMoE model are written under `ckpt_base / model_name`.
model_name = "FinMoE-final-top3-fast"

## Load and Preprocess Dataset
Two preprocessing functions are provided, one for each supported loss function: `ForCausalLM` and `ForTokenClassification`.

In [2]:
def _build_prompt(args: DatasetArgs, dataset_id: str, example: dict) -> str:
    """Fill the dataset's prompt template with the example's prompt fields."""
    prompt_args = [example[key] for key in args.prompt_args[dataset_id]]
    return args.prompt_templates[dataset_id].format(*prompt_args)


def train_preprocess_causal(args: DatasetArgs, dataset_id: str, example: dict):
    """Provided for training with ForCausalLM loss.

    Tokenizes the filled prompt immediately followed by the target label text.
    No ``labels`` field is added here; with ForCausalLM loss this notebook uses
    ``DataCollatorForLanguageModeling(mlm=False)`` to build them.
    """
    prompt = _build_prompt(args, dataset_id, example)
    target = args.id2labels[dataset_id][example["label"]]
    return tokenizer(prompt + target, truncation=False)


def train_preprocess_tokenclass(args: DatasetArgs, dataset_id: str, example: dict):
    """Preprocessing function: applies prompt template to training sample and tokenizes prompt & label.

    The prompt is tokenized without truncation. The label is stored as the
    index, within ``args.token_list``, of the FIRST token of the target label
    text (label text is encoded without special tokens).

    Raises:
        ValueError: if the label text tokenizes to no tokens, or if its first
            token is not present in ``args.token_list``.
    """
    prompt = _build_prompt(args, dataset_id, example)
    tokenized = tokenizer(prompt, truncation=False)

    target = args.id2labels[dataset_id][example["label"]]
    target_ids = tokenizer.encode(target, add_special_tokens=False)
    if not target_ids:
        raise ValueError(
            f"Label text {target!r} for dataset {dataset_id!r} tokenizes to no tokens"
        )
    token_target = target_ids[0]
    try:
        label = args.token_list.index(token_target)
    except ValueError as err:
        raise ValueError(
            f"First token {token_target} of label {target!r} (dataset "
            f"{dataset_id!r}) is not in args.token_list"
        ) from err
    tokenized["labels"] = label

    return tokenized

In [None]:
# Take the same number of rows (3876) from each of the three training datasets
# so that every dataset is equally represented in the combined training set.
nrows_list = [3876] * 3
train_dataset = load_train_datasets(args, train_preprocess_tokenclass, nrows_list)

### Save Dataset

In [None]:
# Persist the preprocessed training set alongside the cache path, replacing the
# cache directory's name with the dataset variant name.
saved_dataset_name = "finmoe-tokenclass_balanced-len256"
save_path = dataset_cache_path.with_stem(saved_dataset_name)
train_dataset.save_to_disk(save_path)

### Load Dataset

In [5]:
# Reload a previously saved preprocessed training set (same location the
# "Save Dataset" cell writes to).
saved_dataset_name = "finmoe-tokenclass_balanced-len256"
load_path = dataset_cache_path.with_stem(saved_dataset_name)
train_dataset = Dataset.load_from_disk(load_path)

# Training

In [None]:
## MODIFY THESE CKPT NAMES
# Checkpoint sub-directory to use for each fine-tuned expert, keyed by expert name.
expert_ckpt_names = {
    "FPB": "checkpoint-best",
    "Headline": "checkpoint-best",
    "Topics": "checkpoint-best",
}

# One checkpoint path per expert, ordered by args.expert_order.
# Paths are converted with str() because Path objects are not JSON serializable;
# modify the expert directory naming scheme here if necessary.
expert_ckpts = [
    str(ckpt_base.joinpath(f"expert-Llama-3_2-1B-{expert_name}", expert_ckpt_names[expert_name]))
    for expert_name in args.expert_order
]

loss_type = "ForTokenClassification"  # alternative: "ForCausalLM"

finMoE_config = FinMoEConfig(
    loss_type=loss_type,
    # one output class per entry of token_list (labels are indices into it)
    num_labels=len(args.token_list),
    topk="top3",
    g_net_id="FastGating",
    expert_ckpts=expert_ckpts,
    token_list=args.token_list,
)

# Only ForCausalLM training uses a language-modelling collator (mlm=False);
# otherwise None is passed and the Trainer uses its default collator.
data_collator = (
    DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    if loss_type == "ForCausalLM"
    else None
)

In [None]:
finMoE_model = FinMoE(finMoE_config).to(device)
print("Memory allocated:", torch.cuda.memory_allocated())

# List only the parameters that will receive gradients during training.
print("Trainable params:")
for param_name, param in finMoE_model.named_parameters():
    if not param.requires_grad:
        continue
    print(param_name, param.shape)

In [None]:
# Trainer configuration. Effective batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 16 = 16.
training_args = TrainingArguments(
    output_dir=str(ckpt_base / model_name),  # one checkpoint saved per epoch here
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=128,
    logging_steps=32,
    save_strategy="epoch",
    # Trainer re-seeds the RNGs from args.seed when it is constructed; pass the
    # notebook's seed so that changing `seed` also changes training randomness
    # (otherwise the TrainingArguments default is used regardless).
    seed=seed,
)

trainer = Trainer(
    model=finMoE_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

trainer.train()

# Eval

To evaluate a model on a task, set the `dataset_id` variable to the name of the dataset you want to evaluate on.

In [None]:
# Dataset to evaluate on — edit this name to switch tasks.
dataset_id = "Topics"
# Build the evaluation set for that dataset with the shared tokenizer/args.
testset = load_eval_dataset(tokenizer, dataset_id, args)

### Evaluate FinMoE checkpoint

In [None]:
# NOTE: the checkpoint step is hard-coded; point it at the checkpoint produced
# by your own training run.
ckpt_path = ckpt_base.joinpath(model_name, "checkpoint-3590")
finMoE_model = (
    FinMoE.load_pretrained(ckpt_path)
    .to(device)
    .eval()
)

In [None]:
# Options for the selected dataset from args.token_opts (presumably its
# candidate label tokens — verify in utils) are forwarded to the evaluator.
dataset_token_opts = args.token_opts[dataset_id]
results = evaluate_FinMoE(finMoE_model, tokenizer, testset, dataset_token_opts)
print(results)