# Scripts

> Scripts for generating the dataset, training the model, and exporting the trained model to TorchScript

In [None]:
# | default_exp scripts

In [None]:
# | export

from pathlib import Path

import hydra
import lightning as pl
import torch
from hydra import compose, initialize
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig

from neuralresonator.data import generate_random_dataset
from neuralresonator.training import MultiShapeMultiMaterialLitModule

# Train

In [None]:
# | export

@rank_zero_only
def log_hyperparameters(
    object_dict: dict,
) -> None:
    """
    Log hyperparameters to the trainer's logger(s).

    Collects the ``model``, ``datamodule`` and ``trainer`` sections of the
    config plus total / trainable / non-trainable parameter counts of the
    model, and passes them to ``trainer.logger.log_hyperparams``. Wrapped in
    ``rank_zero_only`` so that, in multi-process runs, only rank zero logs.

    Args:
        object_dict: Mapping with at least the keys ``"cfg"`` (the composed
            config), ``"model"`` (a module exposing ``parameters()``) and
            ``"trainer"`` (a Lightning trainer).

    Returns:
        None. Does nothing when the trainer has no logger attached.
    """
    trainer = object_dict["trainer"]
    # Without a logger there is nowhere to send the hparams; returning early
    # avoids an AttributeError on ``None.log_hyperparams``.
    if trainer.logger is None:
        return

    cfg = object_dict["cfg"]
    model = object_dict["model"]

    # Count parameters in a single pass instead of iterating three times.
    n_trainable = 0
    n_non_trainable = 0
    for p in model.parameters():
        if p.requires_grad:
            n_trainable += p.numel()
        else:
            n_non_trainable += p.numel()

    hparams = {
        "model": cfg["model"],
        "model/params/total": n_trainable + n_non_trainable,
        "model/params/trainable": n_trainable,
        "model/params/non_trainable": n_non_trainable,
        "datamodule": cfg["datamodule"],
        "trainer": cfg["trainer"],
    }

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)


@hydra.main(version_base=None, config_path="../configs", config_name="train")
def train(
    cfg: DictConfig,
):
    """
    Fit a model from a Hydra config and evaluate it on the test split.

    Instantiates the datamodule, model, logger and trainer from ``cfg``,
    logs hyperparameters when a logger is configured, runs ``trainer.fit``
    (resuming from ``cfg.ckpt_path`` when set) and then ``trainer.test`` on
    the weights produced by that fit. Finally prints test-loader statistics.

    Args:
        cfg: Composed ``train`` config. Reads ``seed`` and ``ckpt_path``
            (both optional) and the ``datamodule``, ``model``, ``logger`` and
            ``trainer`` sections.
    """
    # Compare against None so that an explicit ``seed=0`` is still applied.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    datamodule = hydra.utils.instantiate(cfg.datamodule)
    model = hydra.utils.instantiate(cfg.model)
    logger = hydra.utils.instantiate(cfg.logger)
    trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log_hyperparameters(object_dict)

    # ``ckpt_path`` is only meaningful for resuming the fit.
    trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
    # Test the in-memory weights produced by ``fit``. Passing ``cfg.ckpt_path``
    # here would reload the checkpoint training resumed from and evaluate
    # stale weights instead of the freshly trained ones.
    trainer.test(model, datamodule=datamodule, ckpt_path=None)

    # print total number of batches, batch size and number of samples
    n_batches = len(datamodule.test_dataloader())
    batch_size = cfg.datamodule.batch_size
    print(f"Total number of batches: {n_batches}")
    print(f"Batch size: {batch_size}")
    # NOTE: batches * batch_size is an upper bound when the last batch is
    # smaller than ``batch_size``.
    print(
        "Number of samples:"
        f" {n_batches * batch_size}"
    )


In [None]:
# | eval: false

# To restrict the run to one GPU, uncomment the lines below. Environment
# variable values must be strings, so use "3" rather than the int 3.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Smoke-test overrides: a single epoch, no logger, the Hydra runtime output
# directory set to ``outputs``, a fixed seed, and the same index map for the
# train, val and test splits.
smoke_test_overrides = [
    "trainer.max_epochs=1",
    "hydra.runtime.output_dir=outputs",
    "paths.output_dir=${hydra.runtime.output_dir}",
    "paths.work_dir=${hydra.runtime.cwd}",
    "seed=42",
    "logger=null",
    "++datamodule.train_index_map_path=data/index_map.csv",
    "++datamodule.val_index_map_path=data/index_map.csv",
    "++datamodule.test_index_map_path=data/index_map.csv",
]

with initialize(version_base=None, config_path="../configs"):
    cfg = compose(
        config_name="train.yaml",
        overrides=smoke_test_overrides,
        return_hydra_config=True,
    )
    train(cfg)

# Generate dataset

In [None]:
# | export

@hydra.main(version_base=None, config_path="../configs", config_name="generate_dataset")
def gen_dataset(
    cfg: DictConfig,
):
    """
    Generate the train, validation and test datasets described by ``cfg``.

    First creates ``cfg.paths.{train,val,test}_data_dir`` if missing, then,
    for each split, calls ``generate_random_dataset`` with the shared
    ``n_modes``, ``n_vertices`` and ``n_refinements`` settings and the
    split-specific ``n_<split>_shapes``, ``n_<split>_materials`` and
    ``<split>_materials`` values.

    Args:
        cfg: Composed ``generate_dataset`` config. ``seed`` is optional.
    """
    # Compare against None so that an explicit ``seed=0`` is still applied.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    paths: dict = {key: Path(val) for key, val in cfg.paths.items()}

    # (config key prefix, label shown in the progress message)
    splits = (
        ("train", "training"),
        ("val", "validation"),
        ("test", "test"),
    )

    # Create all output directories before generating anything. exist_ok
    # replaces the racy ``exists()`` check followed by ``mkdir()``.
    for split, _ in splits:
        paths[f"{split}_data_dir"].mkdir(parents=True, exist_ok=True)

    for split, label in splits:
        print(f"Generating {label} data...")
        generate_random_dataset(
            n_modes=cfg.n_modes,
            n_vertices=cfg.n_vertices,
            n_refinements=cfg.n_refinements,
            data_dir=paths[f"{split}_data_dir"],
            n_shapes=cfg[f"n_{split}_shapes"],
            n_materials=cfg[f"n_{split}_materials"],
            materials=cfg[f"{split}_materials"],
        )


# Export

In [None]:
@hydra.main(version_base=None, config_path="../configs", config_name="export")
def export(
    cfg: DictConfig,
):
    """
    Export a trained checkpoint's submodules to TorchScript.

    Loads a ``MultiShapeMultiMaterialLitModule`` from ``cfg.ckpt_path``,
    switches it to eval mode, scripts its ``encoder`` and ``model``
    submodules with ``torch.jit.script`` and saves them to
    ``cfg.encoder_path`` and ``cfg.coefficient_model_path`` respectively.
    Missing parent directories of the output paths are created.

    Args:
        cfg: Composed ``export`` config. Optionally reads ``map_location``
            (e.g. ``"cpu"``), forwarded to ``load_from_checkpoint`` with the
            same semantics as ``torch.load``'s ``map_location``. This lets a
            checkpoint saved on GPU be exported on a CPU-only machine.
            Defaults to ``None``, which keeps the previous behaviour.
    """
    # Load checkpoint
    model = MultiShapeMultiMaterialLitModule.load_from_checkpoint(
        cfg.ckpt_path, map_location=cfg.get("map_location")
    )
    model.eval()

    # export encoder, then coefficient model, to torchscript
    for submodule, out_path in (
        (model.encoder, cfg.encoder_path),
        (model.model, cfg.coefficient_model_path),
    ):
        out_path = Path(out_path)
        # torch.jit.save does not create missing directories itself.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        script = torch.jit.script(submodule)
        torch.jit.save(script, str(out_path))
