In [None]:
import sys

# Make modules in the parent directory importable (presumably the repository
# root that contains the `gerumo` package imported below — confirm). Guarded so
# re-running this cell does not keep appending duplicate '..' entries.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import os
import time
import pickle

from gerumo.data.dataset import describe_dataset
from gerumo.data.generators import build_generator
from gerumo.utils.engine import (
    setup_cfg, setup_environment, setup_experiment, build_dataset, build_metrics
)
from gerumo.models.base import build_model


class dotdict(dict):
    """A dict whose items can also be read, set, and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` (it delegates to
    ``dict.get``) rather than raising ``AttributeError``. Deleting a missing
    attribute raises ``KeyError``, as ``del d[key]`` would.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        return self.get(key)

    def __setattr__(self, key, value):
        # Store as a dict item instead of an instance attribute.
        self[key] = value

    def __delattr__(self, key):
        del self[key]


# Stand-in for parsed command-line arguments; filled in by the next cell.
args = dotdict()

## Select configuration

In [None]:
# Experiment configuration file. Set the GERUMO_CONFIG environment variable to
# point at a different config without editing the notebook; otherwise the
# classification config below is used. A regression variant is available at
# /home/asuka/projects/gerumo2/config/rf_regression.yml.
args['config_file'] = os.environ.get(
    'GERUMO_CONFIG',
    '/home/asuka/projects/gerumo2/config/rf_classification2.yml',
)
# Extra options handed to setup_cfg together with the config file
# (presumably config overrides — confirm); none by default.
args['opts'] = []

## Setup

In [None]:
# Build the experiment configuration from `args` (config file + opts).
cfg = setup_cfg(args)
# Output directory for this experiment; later cells join file names onto it
# with `/`, so it behaves like a pathlib path.
# NOTE(review): presumably this call also creates the directory — confirm.
output_dir = setup_experiment(cfg)
# Logger used by every later cell. Kept after setup_experiment in case the
# environment/logging setup depends on the experiment directory (unverified).
logger = setup_environment(cfg)

## Load datasets

In [None]:
# Training split: build it, then log a summary and save it next to the
# experiment outputs.
train_dataset = build_dataset(cfg, 'train')
describe_dataset(
    train_dataset,
    logger,
    save_to=output_dir / "train_description.txt",
)

# Validation split: same treatment as the training split.
validation_dataset = build_dataset(cfg, 'validation')
describe_dataset(
    validation_dataset,
    logger,
    save_to=output_dir / "validation_description.txt",
)

## Build generators

In [None]:
# One data generator per split, both configured from the same cfg.
train_generator = build_generator(cfg, train_dataset)
validation_generator = build_generator(cfg, validation_dataset)

In [None]:
# Report batch counts through the experiment logger, like the timing and
# metric results later in the notebook, so they are recorded with the rest of
# the run's output instead of only appearing on stdout.
logger.info(f"Training batches: {len(train_generator)}")
logger.info(f"Validation batches: {len(validation_generator)}")

## Build model

In [None]:
# Size the model from the training generator so it matches the data it will
# be fitted on. `input_shape` is kept as a notebook global for inspection.
input_shape = train_generator.get_input_shape()
model = build_model(cfg, input_shape)

## Train model

In [None]:
# Fetch the training data from the generator (a single get_batch() call —
# presumably the whole split for this fit-once model; confirm). Loading stays
# outside the timed region so only fitting is measured.
inputs_batch, outputs_batch = train_generator.get_batch()
# perf_counter is monotonic and high-resolution; time.time() can jump if the
# system clock is adjusted during a long fit.
start_time = time.perf_counter()
model.fit(inputs_batch, outputs_batch)
training_time = (time.perf_counter() - start_time) / 60.0  # minutes

In [None]:
# training_time is in minutes (elapsed seconds / 60 in the training cell).
logger.info(f"Training time: {training_time:.3f} [min]")

## Validation

In [None]:
# Event provides `list_to_tensor`, used when scoring to convert the
# ground-truth outputs into the form the metrics consume.
from gerumo.utils.structures import Event
# Iterated below as a name -> metric-callable mapping.
# NOTE(review): `standalone=True` presumably yields metrics usable outside a
# training loop — confirm against build_metrics.
metrics = build_metrics(cfg, standalone=True)

In [None]:
# Load the validation data before starting the clock so the timed region
# covers only inference, matching how training time is measured (data
# loading excluded) and making the two timings comparable.
inputs_batch, outputs_batch = validation_generator.get_batch()
# perf_counter is monotonic, unlike time.time().
start_time = time.perf_counter()
predicted_batch = model(inputs_batch)
validation_time = (time.perf_counter() - start_time) / 60.0  # minutes

In [None]:
logger.info(f"Validation time: {validation_time:.3f} [min]")
# Convert the ground-truth events once instead of on every loop iteration;
# the same converted target is shared by all metrics.
# NOTE(review): assumes no metric modifies its target argument in place.
true_batch = Event.list_to_tensor(outputs_batch)
for name, metric in metrics.items():
    logger.info(f"{name}:\t{metric(predicted_batch, true_batch):.2f}")

## Save model

In [None]:
# Persist the fitted model in the experiment output directory, building the
# path with `/` like the dataset-description paths above.
# NOTE: unpickling can execute arbitrary code — only load model.pkl files
# from trusted sources.
with open(output_dir / 'model.pkl', 'wb') as f:
    pickle.dump(model, f)
logger.info("Saved 'model.pkl'")