In [None]:
from pathlib import Path

import torch

from dataset import RolloutTensorDataset
from dataloader import prepare_dataloaders
from models import MLP, ChemicalTimeStepper, RolloutModel
from train import do_training
from config import default_config

In [None]:
# options
# Numeric precision used for the dataset tensors, the network and the stepper.
dtype = torch.float32
# Seed torch's global RNG before anything random happens (weight init, and any
# torch-driven shuffling/splitting), so Restart & Run All reproduces the same run.
# NOTE(review): if the dataset/dataloader code draws from numpy or `random`, seed those too.
SEED = 0
torch.manual_seed(SEED)
# Hyper-parameter dictionaries that are unpacked into the constructors and the
# training call in the training cell.
loss_config, stepper_config, net_config, data_config, loader_config, train_config = default_config()

In [None]:
# Output: destination of the trained MLP weights (set to None to skip saving).
output_file = "./models/NN_test.pt"
# Input: directory with the concentration trajectories used for training and testing.
# The trailing slash is kept as-is, in case the dataset builds file paths by concatenation.
data_path = "../../../data/concentrations/"

In [None]:
# initialize objects
# Network input width: one feature per species listed in the data config.
in_features = len(data_config['species'])
# If the stepper learns rates, the network output width is the first dimension of the
# stoichiometry matrix; otherwise the network outputs one value per input species.
# NOTE(review): assumes rows of stoichiometry_matrix index reactions — confirm in ChemicalTimeStepper.
out_features = stepper_config['stoichiometry_matrix'].shape[0] if stepper_config['learn_rates'] else in_features

dataset = RolloutTensorDataset(data_path, dtype=dtype, **data_config)
train_loader, test_loader = prepare_dataloaders(dataset, **loader_config)
net = MLP(in_features, out_features, dtype=dtype, **net_config)
stepper = ChemicalTimeStepper(net, dtype=dtype, **stepper_config)
model = RolloutModel(stepper, data_config['trajectory_length'])

train_loss, mean_test_loss = do_training(model, train_loader, test_loader, **train_config, **loss_config)

# Persist only the MLP weights. `net` is the object handed to the stepper/model above;
# presumably the stepper keeps a reference (not a copy), so these are the trained weights.
if output_file is not None:
    output_path = Path(output_file)
    # torch.save does not create missing directories. Create e.g. ./models/ up front so
    # a finished training run is not lost to a "no such directory" error.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), output_path)
