# FKAT Dev Testing Notebook

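**CLI equivalent:** `hatch run dev:train hf -- model_name=gpt2` (this notebook additionally sets `trainer.max_steps=10` and adds `strategy=deepspeed` when DeepSpeed is installed)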

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to the package under development take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

import logging
# Emit INFO-and-above records as the bare message text (no level/timestamp prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")

In [None]:
import os

from hydra import compose, initialize_config_dir

# Base overrides for a short smoke-test run of the "hf" config.
overrides = ["model_name=gpt2", "trainer.max_steps=10"]

# Opt into the DeepSpeed strategy only when the package can be imported;
# otherwise the config's own strategy setting is left untouched.
try:
    import deepspeed  # noqa: F401  (imported purely as an availability probe)
except ImportError:
    pass
else:
    overrides.append("strategy=deepspeed")

# initialize_config_dir is given an absolute path, so resolve ./conf against
# the current working directory first.
config_dir = os.path.abspath("./conf")
with initialize_config_dir(config_dir=config_dir, version_base="1.3"):
    cfg = compose(config_name="hf", overrides=overrides)

In [None]:
from fkat import initialize

# Build the run objects from the composed Hydra config. Later cells read
# `s.trainer`, `s.model` and `s.data` from the returned object.
s = initialize(cfg)

In [None]:
# Run training with the trainer, model and data module created by initialize(cfg).
s.trainer.fit(s.model, datamodule=s.data)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Plot validation loss against training step for the run that was just trained.
# The metric is read straight from the local ./mlruns directory.
run_id = s.trainer.logger.run_id
metric_file = Path(f"mlruns/{s.trainer.logger.experiment_id}/{run_id}/metrics/val_loss")

# A short run may finish before any validation pass logs val_loss, in which case
# the file is missing (or empty). Report that instead of failing inside
# np.loadtxt or on the column indexing below.
data = None
if metric_file.is_file():
    # Columns are reordered to (step, value); assumes each line of the metric
    # file is "timestamp value step" -- TODO confirm against the logger in use.
    data = np.loadtxt(metric_file, usecols=(2, 1), ndmin=2)  # step, value

if data is None or data.size == 0:
    print(f"No val_loss values found at {metric_file}; did validation run during training?")
else:
    plt.plot(data[:, 0], data[:, 1], marker="o")
    plt.xlabel("Step")
    plt.ylabel("Validation Loss")
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Open MLflow experiment in browser
#
# Starts `mlflow ui` against the local ./mlruns store as a background process,
# waits briefly for it to come up, then opens the UI in the default browser.
import atexit
import subprocess
import time
import webbrowser
from pathlib import Path

# Path.as_uri() yields a well-formed, percent-encoded file:// URI on every
# platform (hand-built "file://" + path breaks on Windows and on paths with spaces).
mlflow_uri = (Path.cwd() / "mlruns").as_uri()
print(f"Starting MLflow UI at {mlflow_uri}")
mlflow_proc = subprocess.Popen(["mlflow", "ui", "--backend-store-uri", mlflow_uri])

# Register the bound method rather than a lambda that looks up the global
# `mlflow_proc` at exit time: if this cell is re-run, each registered hook still
# kills the process it was created for instead of all of them targeting the
# most recent one and leaking the earlier servers.
atexit.register(mlflow_proc.kill)

time.sleep(5)  # give the server a moment to start listening

if mlflow_proc.poll() is not None:
    # The server already exited (e.g. port 5000 is taken); opening the browser
    # would only show a connection error or an unrelated service.
    print(f"MLflow UI exited early with code {mlflow_proc.returncode}; is port 5000 already in use?")
else:
    # webbrowser is the cross-platform equivalent of the macOS-only `open` command.
    webbrowser.open("http://localhost:5000")

In [None]:
# Open VizTracer trace in browser
#
# Finds the gzip-compressed HTML traces stored as artifacts of the current run,
# decompresses the most recently written one next to the archive, and opens it
# in the default browser.
import gzip
import shutil
import webbrowser
from pathlib import Path

# Read the run id from the logger directly so this cell does not depend on the
# plotting cell having been executed first.
run_id = s.trainer.logger.run_id
trace_dir = Path(f"mlruns/{s.trainer.logger.experiment_id}/{run_id}/artifacts/viztracer")
traces = list(trace_dir.rglob("*.html.gz"))

if traces:
    # rglob order is unspecified; pick the newest trace deterministically.
    trace = max(traces, key=lambda p: p.stat().st_mtime)
    # Every match ends in ".gz", so stripping the last suffix gives the ".html" path.
    extracted = trace.with_suffix("")
    with gzip.open(trace, "rb") as f_in, open(extracted, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    # as_uri() requires an absolute path; webbrowser replaces the macOS-only `open`.
    webbrowser.open(extracted.resolve().as_uri())
else:
    print(f"No VizTracer traces (*.html.gz) found under {trace_dir}")

In [None]:
# Quick 2-step test: cfg.trainer.max_steps = 2; s = initialize(cfg); s.trainer.fit(s.model, datamodule=s.data)
# Inspect batch: batch = next(iter(s.data.train_dataloader())); print(batch.keys(), batch["input_ids"].shape)
# Debug: import ipdb; ipdb.set_trace()