In [None]:
import sys
# Put the parent directory on the import path. The entry is relative, so it is
# resolved against the kernel's working directory — presumably this is what
# makes the `gerumo` package imported further down importable; confirm.
sys.path.append('..')

In [None]:
import os
# GPU-related environment for this kernel, set in one go:
#   TF_FORCE_GPU_ALLOW_GROWTH -> '1'
#   CUDA_VISIBLE_DEVICES      -> '0'
# NOTE(review): presumably these have to be in place before the imports below
# initialise TensorFlow/CUDA, so keep this cell ahead of them — confirm.
os.environ.update({
    'TF_FORCE_GPU_ALLOW_GROWTH': '1',
    'CUDA_VISIBLE_DEVICES': '0',
})

In [None]:
import time
from gerumo.data.dataset import describe_dataset
from gerumo.data.generators import build_generator
from gerumo.utils.engine import (
    setup_cfg, setup_environment, setup_experiment, setup_model,
    build_dataset, build_callbacks, build_metrics, build_optimizer, build_loss, load_model
)
from gerumo.models.base import build_model
from gerumo.visualization.metrics import training_history


class dotdict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only consulted when regular attribute lookup fails. A missing key
        # yields None instead of raising AttributeError.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores an item; no instance attribute is created.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the item; KeyError if the key is absent.
        dict.__delitem__(self, name)


# Container handed to setup_cfg() further down; the configuration cell fills it in.
args = dotdict()

## Select configuration

In [None]:
# Parquet files of the tiny split used for this run.
# NOTE(review): the VALIDATION entries below reuse these *train* files, so the
# validation split reads exactly the same data as the train split — confirm
# this is intentional.
tiny_events = '/home/ir-riqu1/rds/rds-iris-ip007/ir-riqu1/Prod5-parquets/output_tiny_splitted/train_events.parquet'
tiny_telescopes = '/home/ir-riqu1/rds/rds-iris-ip007/ir-riqu1/Prod5-parquets/output_tiny_splitted/train_telescopes.parquet'

# Config file of the experiment this run starts from.
args['config_file'] = '/home/ir-riqu1/rds/rds-iris-ip007/ir-riqu1/outputs/best models/20-04-01_onioncnn_ftt_lst_23_epochs_dallcut1000_lr_0.02_f2_sgd_clr_with_momentum_classification_20220420_183240/config.yml'

# Flat KEY, VALUE, KEY, VALUE, ... list; presumably setup_cfg() applies these
# as overrides on top of the config file — confirm.
args['opts'] = [
    'SOLVER.BATCH_SIZE', 10,
    'DATASETS.TRAIN.EVENTS', tiny_events,
    'DATASETS.TRAIN.TELESCOPES', tiny_telescopes,
    'DATASETS.VALIDATION.EVENTS', tiny_events,
    'DATASETS.VALIDATION.TELESCOPES', tiny_telescopes,
]

# Weight file from the same experiment directory; used later for
# cfg.MODEL.WEIGHTS and for load_model().
args['weights'] = '/home/ir-riqu1/rds/rds-iris-ip007/ir-riqu1/outputs/best models/20-04-01_onioncnn_ftt_lst_23_epochs_dallcut1000_lr_0.02_f2_sgd_clr_with_momentum_classification_20220420_183240/weights/model.11-0.18.h5'

## Setup

In [None]:
# Build the run configuration from the selections made in `args`.
cfg = setup_cfg(args)
# Unlock the config, redirect outputs to a 'pseudo_label' sub-directory of the
# configured output dir, record the chosen weights file, then lock it again.
cfg.defrost()
cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR,'pseudo_label')
cfg.MODEL.WEIGHTS=args['weights']
cfg.freeze()
# `output_dir` is used below with the `/` operator, so it is expected to be a
# pathlib-style path — TODO confirm against setup_experiment.
output_dir = setup_experiment(cfg, training=True)
logger = setup_environment(cfg)

## Load Datasets

In [None]:
def _load_lst_subset(subset, max_rows=100):
    """Build the `subset` dataset, keep its 'LST' rows, truncate and describe it.

    Args:
        subset: Dataset name passed to build_dataset ('train' or 'validation').
        max_rows: Number of leading rows kept after filtering.

    Returns:
        The filtered, truncated dataset. Its description is also written to
        "<subset>_description.txt" inside `output_dir`.
    """
    # assumes build_dataset returns a pandas DataFrame with a `type` column — TODO confirm
    full = build_dataset(cfg, subset)
    lst_only = full[full.type=='LST']
    trimmed = lst_only.head(max_rows)
    describe_dataset(trimmed, logger,
                     save_to=output_dir / f"{subset}_description.txt")
    return trimmed


train_dataset = _load_lst_subset('train')
validation_dataset = _load_lst_subset('validation')

## Build generators

In [None]:
# One generator per split, built in train -> validation order.
train_generator, validation_generator = (
    build_generator(cfg, split)
    for split in (train_dataset, validation_dataset)
)

## Build model

In [None]:
# The input shape is taken from the training generator. A model is built for
# that shape and handed straight to load_model() together with the weights
# path chosen in the configuration cell.
input_shape = train_generator.get_input_shape()
model = load_model(
    build_model(cfg, input_shape),
    train_generator,
    args['weights'],
)

## Build training tools

In [None]:
# Everything model compilation and fit() need, all derived from the config.
callbacks = build_callbacks(cfg)
metrics = build_metrics(cfg)
# The optimizer builder also receives the length of the training generator
# (presumably the number of batches per epoch — TODO confirm).
optimizer = build_optimizer(cfg, len(train_generator))
loss = build_loss(cfg)

## Compile model

In [None]:
# Attach optimizer, loss and metrics to the loaded model. The training
# generator is passed along as well; `model` is rebound to the result.
model = setup_model(
    model, train_generator, optimizer, loss, metrics
)

## Start training

In [None]:
# Switch both generators, then the model, into fit mode (same order as before).
for fit_target in (train_generator, validation_generator, model):
    fit_target.fit_mode()

In [None]:
from gerumo.models.pseudo import pseudo_train_model
# Replace `model` with the result of pseudo_train_model(model).
# NOTE: this rebinds `model` from its own previous value, so re-running this
# cell feeds the already-processed model back in. Run it once per kernel
# session, after the cells above.
model=pseudo_train_model(model)

In [None]:
start_time = time.time()

# Keyword arguments for fit(), gathered in one place so each knob is easy to spot.
fit_options = dict(
    epochs=5,
    verbose=2,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=callbacks,
    use_multiprocessing=False,
    workers=1,
    max_queue_size=20,
)
history = model.fit(train_generator, **fit_options)

# Wall-clock duration of the fit, converted from seconds to minutes.
elapsed_seconds = time.time() - start_time
training_time = elapsed_seconds/60.0

In [None]:
# Record the fit duration (in minutes, three decimals) through the experiment logger.
logger.info(f'Training time: {training_time:.3f} [min]')

In [None]:
# Show the training history for this experiment with the default settings.
training_history(history, training_time, cfg.EXPERIMENT_NAME)

In [None]:
# Same history view again, this time with ylog=True
# (presumably a logarithmic y-axis — confirm in gerumo.visualization.metrics).
training_history(history, training_time, cfg.EXPERIMENT_NAME, ylog=True)