## Train

### Settings

In [None]:
import os
import sys
from datetime import datetime

import numpy as np

# Resolve project paths BEFORE importing project code. The sys.path entry below
# exists so that the local `scripts` package is importable; it only helps if it
# is added before the `from scripts...` imports execute.
NB_DIR = os.path.abspath(os.getcwd())
# Project root; the training cells create their log directories under it.
ROOT_DIR = NB_DIR
# Root passed to generate_maplist_from_dataset (which is given dataset_subdir="dataset").
ARCAEA_DIR = NB_DIR

if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from scripts.utils.tools import generate_maplist_from_dataset  # noqa: E402
from scripts.data.data_prep_ticks import load_maps_ticks  # noqa: E402
from scripts.models.classifier import train_classifier, NOTE_CLASSIFIER_MODEL_PATH as NOTE_CLS_MODEL_PATH  # noqa: E402
from scripts.data.gan_data_prep import build_gan_real_data  # noqa: E402
from scripts.models.gan import train_gan  # noqa: E402

# Device selector forwarded to the training functions; "auto" presumably lets
# them pick an available accelerator — confirm against scripts.models.*.
ACCEL_DEVICE = "auto"

### Generate Maplist

In [None]:
# Build the map list file from <ARCAEA_DIR>/dataset and show the resulting path
# (the bare expression on the last line is displayed as the cell's output).
maplist_path = generate_maplist_from_dataset(
    root_dir=ARCAEA_DIR,
    dataset_subdir="dataset",
    maplist_name="maplist_aff.txt",
)
maplist_path

### TensorBoard

In [None]:
# Load the TensorBoard notebook extension; it must be loaded before %tensorboard is used.
%load_ext tensorboard
# Embed one TensorBoard per model. The log dirs are relative to the kernel's
# working directory, which is the same directory ROOT_DIR was set from
# (os.getcwd()), so they match where the training cells write their logs.
%tensorboard --logdir logs/classifier --port 6006
%tensorboard --logdir logs/gan --port 6007

# Equivalent shell commands for viewing the same logs outside the notebook.
print("classifier: tensorboard --logdir logs/classifier --port 6006")
print("gan       : tensorboard --logdir logs/gan --port 6007")

### Train Classifier

In [None]:
# Prepare the tick data for the note classifier. The return value is not used,
# so this is presumably called for its side effects (e.g. writing a cache) — verify.
load_maps_ticks()

# Each run gets its own timestamped sub-directory under logs/classifier so that
# TensorBoard can show runs side by side. Only the parent is created here; the
# run directory itself is presumably created by train_classifier's logger.
cls_log_root = os.path.join(ROOT_DIR, "logs", "classifier")
os.makedirs(cls_log_root, exist_ok=True)
cls_run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
cls_log_dir = os.path.join(cls_log_root, cls_run_id)

cls_train_kwargs = dict(
    model_path=NOTE_CLS_MODEL_PATH,
    epochs=1,
    batch_size=64,
    device=ACCEL_DEVICE,
    log_dir=cls_log_dir,
)
train_classifier(**cls_train_kwargs)

print("Classifier TensorBoard logs:", cls_log_dir)

### Train GAN

In [None]:
# Build the GAN training archive and load it. Indexing an NpzFile reads the
# array into memory, so real_data / real_mask remain valid after the `with`
# block closes the archive.
gan_npz_path = build_gan_real_data()

with np.load(gan_npz_path) as ds:
    real_data = ds["real_data"]
    # The mask is optional in the archive: None when the key is absent.
    real_mask = ds.get("real_mask", None)

# One timestamped sub-directory per run under logs/gan, mirroring the classifier cell.
gan_log_root = os.path.join(ROOT_DIR, "logs", "gan")
os.makedirs(gan_log_root, exist_ok=True)

gan_run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
gan_log_dir = os.path.join(gan_log_root, gan_run_id)

# Use the notebook-wide ACCEL_DEVICE (as the classifier cell does) instead of a
# hard-coded "auto", so changing the device in the settings cell applies here too.
gen, dis = train_gan(
    real_data,
    epochs=1000,
    batch_size=32,
    latent_dim=64,
    device=ACCEL_DEVICE,
    real_mask=real_mask,
    log_dir=gan_log_dir,
)
print("GAN TensorBoard logs:", gan_log_dir)