### WIP

In [None]:
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

In [None]:
# Plot style taken from IKT215; dpi defaults to 200.
def set_mpl_params(dpi: int = 200, figsize: Tuple[int, int] = (9, 6), grid: bool = True, font_size: int = 12, font_family: str = 'serif') -> None:
    """Apply the notebook-wide matplotlib style by updating ``mpl.rcParams``.

    The settings are global and persist for every figure drawn afterwards.

    Args:
        dpi: value written to ``figure.dpi``.
        figsize: default ``(width, height)`` written to ``figure.figsize``.
        grid: whether axes draw a grid by default (``axes.grid``).
        font_size: base font size (``font.size``).
        font_family: font family name (``font.family``).
    """
    # Keys are listed in the same order the settings were originally applied.
    style_overrides = {
        'figure.dpi': dpi,
        'figure.figsize': figsize,
        'axes.grid': grid,
        'font.size': font_size,
        'font.family': font_family,
    }
    mpl.rcParams.update(style_overrides)

In [None]:
def seed_everything(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    When CUDA is available, ``torch.cuda.manual_seed_all`` is also called
    with the same seed.

    Args:
        seed: the seed shared by every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

In [None]:
@dataclass
class Config:
    """Hyperparameters and paths for one LoRA fine-tuning run.

    Every field has a default, so ``Config()`` gives the baseline setup and
    individual values can be overridden by keyword, e.g.
    ``Config(lr=1e-4, output_dir="runs/exp1")``.

    Raises:
        ValueError: from ``__post_init__`` when a size/count field is out of
            range, ``lr`` is not positive, or ``warmup_ratio`` /
            ``lora_dropout`` fall outside their valid intervals.
    """

    # Model and dataset identifiers (presumably Hugging Face Hub ids -- confirm).
    model_name: str = "meta-llama/Llama-3.2-1B"
    dataset_name: str = "yahma/alpaca-cleaned"
    # Split sizes; presumably example counts -- confirm in the data-loading code.
    train_size: int = 10000
    val_size: int = 2000
    test_size: int = 2000
    # Per-step batch size and number of gradient-accumulation steps.
    batch_size: int = 4
    grad_accum_steps: int = 4
    epochs: int = 3
    # Optimizer / schedule settings.
    lr: float = 0.0002
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    seed: int = 42
    # LoRA adapter settings (presumably passed to peft.LoraConfig -- confirm).
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # Normalised to a Path in __post_init__, so a plain string is accepted too.
    output_dir: Path = Path("outputs/checkpoints")
    max_length: int = 512
    use_fp16: bool = True
    # NOTE(review): likely forwarded to TrainingArguments(report_to=...) -- confirm.
    report_to: Optional[str] = "none"

    def __post_init__(self) -> None:
        """Normalise ``output_dir`` to a ``Path`` and reject nonsensical values.

        Raises:
            ValueError: if a count/size field is out of range, ``lr`` is not
                positive, ``warmup_ratio`` is outside ``[0, 1]`` or
                ``lora_dropout`` is outside ``[0, 1)``.
        """
        # Callers use ``cfg.output_dir.mkdir(...)``, which fails on a plain str.
        self.output_dir = Path(self.output_dir)

        for name in ("train_size", "batch_size", "grad_accum_steps",
                     "epochs", "lora_rank", "max_length"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be > 0, got {value!r}")
        # Held-out splits may legitimately be empty, but never negative.
        for name in ("val_size", "test_size"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be >= 0, got {value!r}")
        if self.lr <= 0:
            raise ValueError(f"lr must be > 0, got {self.lr!r}")
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError(f"warmup_ratio must be in [0, 1], got {self.warmup_ratio!r}")
        if not 0.0 <= self.lora_dropout < 1.0:
            raise ValueError(f"lora_dropout must be in [0, 1), got {self.lora_dropout!r}")

In [None]:
# Build the run configuration, apply the global plot style, then seed all RNGs.
cfg = Config()
set_mpl_params()
seed_everything(cfg.seed)
# Create the checkpoint directory (and any parents); no-op if it already exists.
cfg.output_dir.mkdir(parents=True, exist_ok=True)
# Bare trailing expression so the notebook cell displays the config.
cfg