In [2]:
import datetime
import json
import logging
import os
import warnings
from importlib import reload

import numpy as np
import torch

from models import dapm
from scripts.data_loader import *
from scripts.train_dapm import train
from utils.metrics import normalize_mat
from params import Param
from utils.logging_utils import *

warnings.filterwarnings('ignore')


In [4]:

def dapm_main(param, **kwargs):
    """Fine-tune a previously saved model for the period described by ``param``.

    The function looks in ``kwargs['model_dir']`` for a ``.pkl`` file whose name
    contains ``'{param.last_year}___#{param.last_month}#'``, loads it with
    ``torch.load``, and passes it to ``train`` together with the loaded data.
    If no such file exists, nothing is trained.

    Args:
        param: Run configuration. This function calls
            ``param.generate_model_name()`` and reads ``param.last_year`` and
            ``param.last_month``; the object is also handed to ``load_data``,
            ``load_locations`` and ``train``.
        **kwargs: Must contain ``model_dir``, ``log_dir``, ``run_dir``,
            ``train_val_test`` and ``device``. May contain ``data_dir`` to
            override the default training-data directory; that key is consumed
            here and is not forwarded to ``train``. The keys ``model_name``,
            ``model_file``, ``log_file`` and ``run_file`` are added before the
            remaining kwargs are forwarded to ``train``.

    Returns:
        None.
    """
    # Name of the model that this run produces.
    model_name = param.generate_model_name()

    # Optional override of the data location. ``kwargs`` is a fresh dict built by
    # ``**`` unpacking, so popping here does not modify the caller's dict.
    data_dir = kwargs.pop('data_dir', '/home/yijun/notebooks/training_data/')

    model_dir = kwargs['model_dir']
    device = kwargs['device']

    # Find the previous model to fine-tune from. The directory is listed once and
    # sorted so that the choice is deterministic when several files match.
    previous_tag = f'{param.last_year}___#{param.last_month}#'
    candidates = [
        name for name in sorted(os.listdir(model_dir))
        if previous_tag in name and name.endswith('.pkl')
    ]
    if not candidates:
        # Nothing to fine-tune from; say so instead of returning silently.
        print(f'{model_name}: no previous model matching {previous_tag!r} in {model_dir}; skipped.')
        return
    previous_model_name = candidates[-1]

    print(model_name)
    print(previous_model_name)

    # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
    # checkpoints from a trusted model_dir.
    # map_location lets a checkpoint saved on one GPU be restored on whatever
    # device this run uses.
    model = torch.load(
        os.path.join(model_dir, previous_model_name),
        map_location=device,
    ).to(device)

    kwargs['model_name'] = model_name
    kwargs['model_file'] = os.path.join(model_dir, model_name + '.pkl')
    kwargs['log_file'] = os.path.join(kwargs['log_dir'], model_name + '.log')
    # Suffix is day-hour-minute; '%M' is minute ('%m' would repeat the month).
    run_stamp = datetime.datetime.now().strftime('%d%H%M')
    kwargs['run_file'] = os.path.join(kwargs['run_dir'], model_name + '_run_{}'.format(run_stamp))

    data_obj = load_data(data_dir, param)
    train_loc, val_loc, test_loc = load_locations(kwargs['train_val_test'], param)

    # Attach each split's locations and the labels generated for them.
    data_obj.train_loc = train_loc
    data_obj.train_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.train_loc)
    data_obj.val_loc = val_loc
    data_obj.val_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.val_loc)
    data_obj.test_loc = test_loc
    data_obj.test_y = data_obj.gen_train_val_test_label(data_obj.label_mat, data_obj.test_loc)

    try:
        # Logging starts.
        start_logging(kwargs['log_file'], model_name)
        data_logging(data_obj)

        # Fine-tune the loaded model.
        train(model, data_obj, param, **kwargs)

        logging.info('{} ENDS.'.format(model_name))
    finally:
        # Always shut down and reload logging, even when training fails or is
        # interrupted, so the next call does not start with this run's logging
        # state still in place.
        logging.shutdown()
        reload(logging)


In [5]:
"""
    define directory
"""

# Output root for this run; the models/, logs/ and runs/ sub-directories below are derived from it.
base_dir = 'data/los_angeles_500m_fine_tune_1234_tp1'
train_val_test_file = '/home/yijun/notebooks/training_data/train_val_test_los_angeles_500m_fine_tune_1234.json'
device = torch.device("cuda:3" if torch.cuda.is_available() else 'cpu')  # the gpu device

# Load train, val, test locations. The context manager closes the file handle
# as soon as the JSON has been parsed.
with open(train_val_test_file, 'r') as f:
    train_val_test = json.load(f)

# Shared keyword arguments passed to every dapm_main call.
kwargs = {
    'model_dir': os.path.join(base_dir, 'models/'),
    'log_dir': os.path.join(base_dir, 'logs/'),
    'run_dir': os.path.join(base_dir, 'runs/'),
    'train_val_test': train_val_test,
    'device': device
}

# One run per value 1..12 -- presumably the months of 2018 (verify against Param).
# dapm_main returns without training when it finds no previous model to fine-tune.
for m in range(1, 13):
    param = Param([m], 2018, alpha=1, beta=0.1, gamma=5, sp_neighbor=1, model_type=['sp', 'ae', 'sc'])
    dapm_main(param, **kwargs)


dapm___sp_ae_sc___los_angeles_500m_2018___#02#___6_00001_1___1_01_5_001___16_13
dapm___sp_ae_sc___los_angeles_500m_2018___#01#___6_00001_1___1_01_5___16_13.pkl
dapm___sp_ae_sc___los_angeles_500m_2018___#03#___6_00001_1___1_01_5_001___16_13
dapm___sp_ae_sc___los_angeles_500m_2018___#02#___6_00001_1___1_01_5_001___16_13.pkl
dapm___sp_ae_sc___los_angeles_500m_2018___#04#___6_00001_1___1_01_5_001___16_13
dapm___sp_ae_sc___los_angeles_500m_2018___#03#___6_00001_1___1_01_5_001___16_13.pkl


KeyboardInterrupt: 