In [None]:
from pathlib import Path

from vectormesh.data.cache import VectorCache

assets = Path("../artefacts")
# Sort the matches so that "first folder" is well defined; raw glob order
# depends on the filesystem and is not reproducible across machines.
candidates = sorted(assets.glob("aktes*/"))
if not candidates:
    # Fail with an actionable message instead of a bare StopIteration.
    raise FileNotFoundError(
        f"No folder matching 'aktes*/' found in {assets.resolve()}; "
        "create the vector cache first or adjust `assets`."
    )
trainpath = candidates[0]  # change the index if you don't want the first folder
tag = trainpath.name
cache = VectorCache.load(path=trainpath)
# Indices 0..1023 go to the training split, 1024..2047 to the validation split.
train = cache.select(range(1024))
valid = cache.select(range(1024, 2048))
column_name = "legal_dutch"

In [None]:
from vectormesh.data import OneHot

# Transform that reads the "labels" column and writes a 32-class target into
# the "onehot" column (presumably a one-hot / multi-hot vector -- confirm in OneHot).
onehot = OneHot(label_col="labels", target_col="onehot", num_classes=32)
# Apply the same transform to both splits.
train_oh, valid_oh = (split.map(onehot) for split in (train, valid))

In [None]:
from torch.utils.data import DataLoader

from vectormesh.components import FixedPadding
from vectormesh.data import Collate

# Reuse `column_name` from the data-loading cell so the embedding column is
# configured in exactly one place instead of being repeated as a literal.
collate_fn = Collate(
    embedding_col=column_name,
    target_col="onehot",
    padder=FixedPadding(max_chunks=30),
)

# Only the training loader is shuffled; the validation loader keeps a fixed order.
trainloader = DataLoader(train_oh, batch_size=32, shuffle=True, collate_fn=collate_fn)
validloader = DataLoader(valid_oh, batch_size=32, shuffle=False, collate_fn=collate_fn)

For background on Mixture-of-Experts (MoE), see the paper in the `references/` folder.

In [None]:
from vectormesh.components import MeanAggregator, NeuralNet, Serial
from vectormesh.components.gating import MoE

# Mixture-of-Experts head: four separately instantiated experts with identical
# configuration; `top_k=2` is handed to the MoE (presumably the number of
# experts selected per input -- confirm in vectormesh.components.gating).
moe = MoE(
    experts=[NeuralNet(hidden_size=768, out_size=32) for _ in range(4)],
    hidden_size=768,
    out_size=32,
    top_k=2,
)

In [None]:
# Chain the components in order: the aggregator runs first, then the MoE head.
# (MeanAggregator presumably averages over the chunk dimension -- confirm.)
pipeline = Serial(
    [
        MeanAggregator(),
        moe,
    ]
)

This pipeline is deliberately minimal and can be improved (see notebook `2_design` for inspiration).

In [None]:
import torch
import torch.optim as optim
from mltrainer import ReportTypes, Trainer, TrainerSettings

from vectormesh.components.metrics import F1Score
from vectormesh.data.vectorizers import detect_device

device = detect_device()
print(f"Using device: {device}")

# Absolute path so the log location does not depend on later working-directory changes.
log_dir = Path("demo").absolute()

settings = TrainerSettings(
    epochs=10,
    metrics=[F1Score()],
    logdir=log_dir,
    train_steps=len(trainloader),
    # Size the validation pass by the validation loader, not the training loader.
    valid_steps=len(validloader),
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.TOML],
)

# BCEWithLogitsLoss applies the sigmoid internally, so the pipeline must emit raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()

# The optimizer and scheduler are passed as classes, not instances
# (presumably the Trainer instantiates them itself -- confirm in mltrainer).
trainer = Trainer(
    model=pipeline,
    settings=settings,
    loss_fn=loss_fn,
    optimizer=optim.Adam,
    traindataloader=trainloader,
    # Validate on the held-out split; evaluating on the training loader would
    # report training performance as if it were validation performance.
    validdataloader=validloader,
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    device=device,
)

In [None]:
# Start training with the trainer built in the previous cell.
trainer.loop()

In [None]:
import shutil

# Remove the artefact folders of this demo run; ignore_errors=True turns a
# missing folder into a no-op, so the cell can be re-run safely.
for leftover in ("tmp/", "logs/", "demo/"):
    shutil.rmtree(leftover, ignore_errors=True)