In [1]:
import argparse
import collections
import random

import numpy as np
import torch

import model.loss as module_loss
import model.metric as module_metric
from parse_config import ConfigParser
from utils.util import create_model, create_dataloader, create_trainer


In [2]:
# Fix random seeds for reproducibility.
# Seed every global RNG the pipeline may draw from: Python's `random`,
# NumPy's legacy global RNG, and PyTorch (torch.manual_seed seeds the CPU
# and all CUDA devices).
SEED = 125
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Command-line parser. It is not invoked here; it is handed, together with
# `options`, to ConfigParser.from_args in the next cell.
args = argparse.ArgumentParser(description='Emotion Reasoning in Daily Life')
# `%(default)s` is expanded by argparse, so the help text always shows the real default.
args.add_argument('-c', '--config', default='train.json', type=str,
                  help='config file path (default: %(default)s)')
args.add_argument('-r', '--resume', default=None, type=str,
                  help='path to latest checkpoint (default: %(default)s)')
args.add_argument('-d', '--device', default=None, type=str,
                  help='indices of GPUs to enable (default: all)')

# custom cli options to modify configuration from default values given in json file.
# Each `target` is a ';'-separated key path, presumably into the nested JSON config
# (e.g. 'optimizer;args;lr') — confirm against ConfigParser.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
    CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
]

In [4]:
# Build the run configuration from the CLI parser `args` plus the `options` overrides.
# `config` is used below to create the logger, the data loaders and the model.
config = ConfigParser.from_args(args, options)

jupyter


In [None]:
# Logger named 'train', obtained from the run configuration.
logger = config.get_logger('train')

# setup data_loader instances
# NOTE(review): per the saved output, this step also initializes the visual/audio/text/
# personality feature extractors and vectorizes features, so it is slow. The saved
# cell has no execution count and its progress bar stopped at 37% — re-run from a
# fresh kernel before trusting downstream outputs.
data_loader = create_dataloader(config)
# The validation loader is derived from the training loader; the split size
# presumably comes from the config — confirm in create_dataloader/split_validation.
valid_data_loader = data_loader.split_validation()

Initializing VisualFeatureExtractor...
Initializing AudioFeatureExtracor...
Initializing TextFeatureExtractor...


  0%|▏                                                                                             | 12/6829 [00:00<01:05, 103.52it/s]

Initializing PersonalityFeatureExtractor...
vectorize features.....


 37%|██████████████████████████████████▏                                                         | 2536/6829 [00:20<00:29, 147.55it/s]

In [6]:

# build model architecture from the run configuration;
# the model is displayed in the following cell
model = create_model(config)

model [AMER] was created


In [7]:
# Display the model's repr as the cell output.
# NOTE(review): the saved output `AMER()` lists no submodules. If AMER is an
# nn.Module, check that its layers are registered as attributes (e.g. nn.ModuleList
# rather than a plain list) so they are trained and saved.
model

AMER()