In [None]:
import sys
sys.path.append('..')

In [None]:
import time
from gerumo.data.dataset import describe_dataset
from gerumo.data.generators import build_generator
from gerumo.utils.engine import (
    setup_cfg, setup_environment, setup_experiment, setup_model,
    build_dataset, build_callbacks, build_metrics, build_optimizer, build_loss
)
from gerumo.models.base import build_model
from gerumo.visualization.metrics import training_history


class dotdict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``d.key`` returns ``d['key']``, or ``None`` when the key is absent
    (same as ``dict.get``). ``d.key = v`` stores ``d['key'] = v``, and
    ``del d.key`` removes the key (raising ``KeyError`` if absent).

    Dunder names (``__name__``-style) are not routed through the mapping.
    Looking up a missing one raises ``AttributeError``, so protocol probes
    made by ``hasattr``, ``pickle`` and ``copy`` get a correct answer instead
    of a ``None`` that they may try to call.
    """

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, i.e. for names
        # that are neither dict methods nor class attributes.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)


# Argument container handed to setup_cfg below; filled in the next cell.
args = dotdict()

## Select configuration

In [None]:
# Path to the experiment YAML config. Setting the GERUMO_CONFIG environment
# variable overrides the default, so the notebook can run on other machines
# without editing this cell.
import os

args['config_file'] = os.environ.get(
    'GERUMO_CONFIG',
    '/home/asuka/projects/gerumo2/config/cnn_regression.yml',
)
# Additional options for setup_cfg; left empty here.
# NOTE(review): presumably KEY VALUE config-override pairs — confirm in
# gerumo.utils.engine.setup_cfg.
args['opts'] = []

## Set up configuration, experiment directory and logger

In [None]:
# Build the config object from `args`, then obtain the experiment output
# directory and the logger used by the remaining cells. `output_dir` is used
# below with the `/` operator to build save paths; `logger` is used with
# `.info(...)`.
# NOTE(review): setup_experiment runs before setup_environment — keep this
# order unless it is confirmed that setup_environment does not depend on it.
cfg = setup_cfg(args)
output_dir = setup_experiment(cfg)
logger = setup_environment(cfg)

## Load Datasets

In [None]:
def _load_and_describe(split):
    """Build the dataset for `split` and save its description.

    The description goes to ``<output_dir>/<split>_description.txt`` and is
    also passed to `logger`.
    """
    dataset = build_dataset(cfg, split)
    describe_dataset(dataset, logger,
                     save_to=output_dir / f"{split}_description.txt")
    return dataset


# Training split first, then validation (same order as building them inline).
train_dataset = _load_and_describe('train')
validation_dataset = _load_and_describe('validation')

## Build generators

In [None]:
# One data generator per split, both built from the same config.
# The generator expression is consumed in order: train first, then validation.
train_generator, validation_generator = (
    build_generator(cfg, split_dataset)
    for split_dataset in (train_dataset, validation_dataset)
)

## Build model

In [None]:
# Build the model for the input shape reported by the training generator.
# `input_shape` is kept as a named variable in case later cells reuse it.
input_shape = train_generator.get_input_shape()
model = build_model(cfg, input_shape)

## Build training components (callbacks, metrics, optimizer, loss)

In [None]:
# Assemble the training components from the config. The right-hand tuple is
# evaluated left to right, so the builders run in the order
# callbacks -> metrics -> optimizer -> loss.
callbacks, metrics, optimizer, loss = (
    build_callbacks(cfg),
    build_metrics(cfg),
    build_optimizer(cfg),
    build_loss(cfg),
)

## Compile model

In [None]:
# Combine the model with the optimizer, loss and metrics built above.
# setup_model returns the model object used for training below.
# NOTE(review): the section header says this is where the model is compiled,
# and the training generator is passed in as well (presumably for shapes or a
# sample batch) — confirm in gerumo.utils.engine.setup_model.
model = setup_model(
    model, train_generator, optimizer, loss, metrics
)

## Start training

In [None]:
# Call fit_mode() on both generators and the model, in that order, before
# training. NOTE(review): presumably switches them to their training-time
# configuration — confirm in gerumo.
for _component in (train_generator, validation_generator, model):
    _component.fit_mode()
# Remove the loop variable so the notebook namespace is unchanged.
del _component

In [None]:
# Time the training run. perf_counter() is a monotonic, high-resolution clock
# meant for measuring durations. Unlike time.time(), it is not affected by
# system clock adjustments (e.g. NTP corrections) during a long fit.
start_time = time.perf_counter()
history = model.fit(
    train_generator,
    epochs=cfg.SOLVER.EPOCHS,
    verbose=1,
    validation_data=validation_generator,
    # One full pass over the validation generator per epoch (assuming len()
    # returns the number of batches — confirm in gerumo.data.generators).
    validation_steps=len(validation_generator),
    callbacks=callbacks,
    # NOTE(review): `use_multiprocessing`, `workers` and `max_queue_size` are
    # tf.keras 2.x Model.fit arguments and were removed in Keras 3 — confirm
    # the pinned TensorFlow/Keras version.
    use_multiprocessing=False,
    workers=1,
    max_queue_size=20,
)
# Elapsed training time in minutes.
training_time = (time.perf_counter() - start_time) / 60.0

In [None]:
# Report how long model.fit took, in minutes, with three decimals.
logger.info(f"Training time: {training_time:.3f} [min]")

In [None]:
# Plot the History returned by model.fit, labelled with the training time
# and experiment name.
training_history(history, training_time, cfg.EXPERIMENT_NAME)

In [None]:
# Same plot as the previous cell but with ylog=True (presumably a logarithmic
# y axis, useful when losses span orders of magnitude — confirm in
# gerumo.visualization.metrics).
training_history(history, training_time, cfg.EXPERIMENT_NAME, ylog=True)