## Introduction
This notebook walks through the steps for training the models.

Follow these general steps:
1. Python. Make sure you have a Python environment (conda or other) with the libraries listed in requirements.txt installed.
2. Data. The data preparation steps are provided in the DATA directory. Run the prepare_data.ipynb notebook before continuing here.
3. Commands. The rest of this notebook generates shell commands for starting the training process. We run our code on a server that uses Slurm for job management; you can adapt this notebook to your own environment with small changes. Running the training jobs on a GPU makes them faster.
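
As a quick check for step 1, the sketch below reports which entries of a requirements file are not installed in the current environment. It is a minimal sketch: the helper name `missing_requirements` is hypothetical, and the parsing only handles plain `name==version` style lines, not every syntax pip accepts.

```python
import re
from importlib import metadata


def missing_requirements(lines):
    """Return the requirement names from `lines` that are not installed."""
    missing = []
    for line in lines:
        # Drop comments and surrounding whitespace.
        line = line.split('#', 1)[0].strip()
        # Skip blank lines and pip options such as '-r other.txt'.
        if not line or line.startswith('-'):
            continue
        # Keep only the distribution name (no version specifier, extras, or markers).
        name = re.split(r'[<>=!~;\[\s]', line, maxsplit=1)[0]
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


# Typical use from the repository root (assumes requirements.txt is there):
# with open('requirements.txt') as f:
#     print(missing_requirements(f))
```

An empty list means every listed package was found; installed versions are not compared against the pinned ones.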

## Set the params
Define the experiment settings and arguments used to generate the training commands.

In [55]:
# ##########
# General Args
# ##########
exp_name = 'train_all'
data_augment_method = 'full_aug' # ['no_aug', 'full_aug', 'sel_aug'] 
sampling_method = 'equal' #'concat' or 'equal'

exp_name_com = f'{exp_name}_{sampling_method}' 

cur_data_path = 'DATA/' #'../DATA_FINAL_ELEC/'
exp_output_dir = 'EVAL/'

# $ which python
conda_python_dir = '~/anaconda3/envs/forec/bin/python'

print(f'\n- experiment name: {exp_name_com}')
print(f'\t - data augmentation methods: {data_augment_method}')
print(f'\t - data sampling method: {sampling_method}')
print(f'\t - reading data: {cur_data_path}')
print(f'\t - writing evaluations: {exp_output_dir}')


# ##########
# Market selection
# all_markets = [ 'jp', 'in', 'de', 'fr', 'ca', 'mx', 'uk', 'us'] 
# ##########
# target_markets = ['de', 'fr']
# source_markets = ['uk'] #, 'us' for no_aug use 'xx'

target_markets = ['jp', 'in', 'de', 'fr', 'ca', 'mx', 'uk']
source_markets = ['jp', 'in', 'de', 'fr', 'ca', 'mx', 'uk', 'us'] #, 'us' for no_aug use 'xx'


print(f'-Working on below market pairs (target, augmenting with market):')
all_poss_pairs = []
for target_market in target_markets:
    for source_market in source_markets:
        if target_market==source_market:
            continue
        if data_augment_method=='no_aug':
            source_market='xx'
        all_poss_pairs.append((target_market, source_market))
        print(f'\t--> ({target_market}, {source_market})')
all_poss_pairs = list(set(all_poss_pairs))

# ##########
# Training Data fractions to use from each target and source 
# 1 means full data, and 2 means 1/2 of the training data to sample
# ##########
tgt_fractions = [1]
src_fractions = [1] #2, 3, 4, 5, 10

fractions = []
print('\n-Sampling below training data fractions:')
for tgt_fraction in tgt_fractions:
    for src_fraction in src_fractions:
        fractions.append((src_fraction, tgt_fraction))
        print(f'\t--> ({src_fraction}, {tgt_fraction})')
        



- experiment name: train_all_equal
	 - data augmentation methods: full_aug
	 - data sampling method: equal
	 - reading data: DATA/
	 - writing evaluations: EVAL/
-Working on below market pairs (target, augmenting with market):
	--> (jp, in)
	--> (jp, de)
	--> (jp, fr)
	--> (jp, ca)
	--> (jp, mx)
	--> (jp, uk)
	--> (jp, us)
	--> (in, jp)
	--> (in, de)
	--> (in, fr)
	--> (in, ca)
	--> (in, mx)
	--> (in, uk)
	--> (in, us)
	--> (de, jp)
	--> (de, in)
	--> (de, fr)
	--> (de, ca)
	--> (de, mx)
	--> (de, uk)
	--> (de, us)
	--> (fr, jp)
	--> (fr, in)
	--> (fr, de)
	--> (fr, ca)
	--> (fr, mx)
	--> (fr, uk)
	--> (fr, us)
	--> (ca, jp)
	--> (ca, in)
	--> (ca, de)
	--> (ca, fr)
	--> (ca, mx)
	--> (ca, uk)
	--> (ca, us)
	--> (mx, jp)
	--> (mx, in)
	--> (mx, de)
	--> (mx, fr)
	--> (mx, ca)
	--> (mx, uk)
	--> (mx, us)
	--> (uk, jp)
	--> (uk, in)
	--> (uk, de)
	--> (uk, fr)
	--> (uk, ca)
	--> (uk, mx)
	--> (uk, us)

-Sampling below training data fractions:
	--> (1, 1)
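
The pair enumeration above can also be written compactly. The sketch below is an equivalent formulation, not the notebook's own code: the function name `market_pairs` is hypothetical, and it returns the pairs sorted so that the order is the same on every run (the `list(set(...))` call above does not guarantee an order).

```python
from itertools import product


def market_pairs(targets, sources, augment_method='full_aug'):
    """Enumerate (target, source) pairs, skipping pairs of identical markets.

    With 'no_aug' there is no augmenting market, so the source is replaced by
    the placeholder 'xx' and the duplicates collapse to one pair per target.
    """
    pairs = set()
    for tgt, src in product(targets, sources):
        if tgt == src:
            continue
        pairs.add((tgt, 'xx' if augment_method == 'no_aug' else src))
    return sorted(pairs)


targets = ['jp', 'in', 'de', 'fr', 'ca', 'mx', 'uk']
sources = ['jp', 'in', 'de', 'fr', 'ca', 'mx', 'uk', 'us']
print(len(market_pairs(targets, sources)))            # 7 * 8 - 7 = 49
print(len(market_pairs(targets, sources, 'no_aug')))  # 7
```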


## Generate training shell commands
Before running the MAML and FOREC models, we need to train the corresponding NMF++ models (and, for NMF++, the GMF++ and MLP++ models must be trained first). For this purpose:
- 'train_base.py' trains and evaluates the GMF++, MLP++, and NMF++ models (with no_aug, these reduce to GMF, MLP, and NMF).
- 'train_maml.py' trains an NMF++ model on two markets using MAML, for cross-market scenarios.
- 'train_forec.py' trains a FOREC model with two market heads, one per market, for cross-market scenarios.
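
Each of these scripts is launched with a command line assembled from flag/value pairs. The sketch below shows one way to assemble such a command. It is an illustration only: `build_command` is a hypothetical helper, and the flags in the usage example are the ones that appear in the cell that follows.

```python
import shlex


def build_command(python_bin, script, args):
    """Join an interpreter path, a script name, and (flag, value) pairs into one shell command.

    A value of None marks a boolean flag such as --cuda. Flags and values are
    shell-quoted; the interpreter path is left unquoted so that a leading '~'
    is still expanded by the shell.
    """
    parts = [python_bin, shlex.quote(script)]
    for flag, value in args:
        parts.append(shlex.quote(flag))
        if value is not None:
            parts.append(shlex.quote(str(value)))
    return ' '.join(parts)


cmd = build_command(
    '~/anaconda3/envs/forec/bin/python',
    'train_base.py',
    [('--data_dir', 'DATA/'), ('--tgt_market', 'de'), ('--num_epoch', 25), ('--cuda', None)],
)
print(cmd)
```

Keeping the arguments in a list of pairs fixes their order in the generated command, and quoting guards against values that contain spaces.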

In [56]:
command_dict = {}
for tgt_market, src_market in all_poss_pairs:
    for src_fra, tgt_frac in fractions:  # tuples are stored as (src_fraction, tgt_fraction)
        cur_cmd_dict = {}
        cur_exp_name = f'{exp_name_com}_{tgt_market}_{src_market}_{data_augment_method}_ftgt{tgt_frac}_fsrc{src_fra}'
        
        # 'train_base.py'
        py_file_main = 'train_base.py'
        cur_exp_out_file = f'{exp_output_dir}base-{cur_exp_name}.json'
        # A list (rather than a set) keeps the argument order deterministic.
        pre_set_args = [
            "--data_dir %s"%(cur_data_path),
            "--tgt_market %s"%(tgt_market),
            "--aug_src_market %s"%(src_market),
            "--exp_name %s"%(cur_exp_name),
            "--exp_output %s"%(cur_exp_out_file),

            "--num_epoch %i"%(25),  
            "--batch_size %i"%(1024),
            "--cuda",  # boolean flag, takes no value

            "--data_augment_method %s"%(data_augment_method),
            "--data_sampling_method %s"%(sampling_method),

            "--tgt_fraction %i"%(tgt_frac),  
            "--src_fraction %i"%(src_fra),  
        ]
        my_arguments = ' '.join(pre_set_args)
        command_pieces = [conda_python_dir, py_file_main, my_arguments]
        final_cmd = ' '.join(command_pieces)
        cur_cmd_dict['base'] = final_cmd
        
        if data_augment_method=='no_aug':
            command_dict[cur_exp_name] = cur_cmd_dict
            continue
        
        # 'train_maml.py'
        py_file_main = 'train_maml.py'
        fast_lr_tune = '0.1'
        shots = 20 #512, 200, 100, 50, 20
        cur_exp_out_file = f'{exp_output_dir}maml-{cur_exp_name}_shots{shots}.json'
        # A list (rather than a set) keeps the argument order deterministic.
        pre_set_args = [
            "--data_dir %s"%(cur_data_path),
            "--tgt_market %s"%(tgt_market),
            "--aug_src_market %s"%(src_market),
            "--exp_name %s"%(cur_exp_name),
            "--exp_output %s"%(cur_exp_out_file),

            "--num_epoch %i"%(25),  
            "--batch_size %i"%(shots),
            "--cuda",  # boolean flag, takes no value

            "--data_sampling_method %s"%(sampling_method),
            "--fast_lr %s"%(fast_lr_tune),
            "--tgt_fraction %i"%(tgt_frac),  
            "--src_fraction %i"%(src_fra),  
        ]
        my_arguments = ' '.join(pre_set_args)
        command_pieces = [conda_python_dir, py_file_main, my_arguments]
        final_cmd = ' '.join(command_pieces)
        cur_cmd_dict['maml'] = final_cmd
        
        # 'train_forec.py'
        py_file_main = 'train_forec.py'
        cur_exp_out_file = f'{exp_output_dir}forec-{cur_exp_name}_shots{shots}.json'
        # A list (rather than a set) keeps the argument order deterministic.
        pre_set_args = [
            "--data_dir %s"%(cur_data_path),
            "--tgt_market %s"%(tgt_market),
            "--aug_src_market %s"%(src_market),
            "--exp_name %s"%(cur_exp_name),
            "--exp_output %s"%(cur_exp_out_file),

            "--num_epoch %i"%(25),  
            "--batch_size %i"%(shots),
            "--cuda",  # boolean flag, takes no value

            "--data_sampling_method %s"%(sampling_method),
            "--fast_lr %s"%(fast_lr_tune),
            "--tgt_fraction %i"%(tgt_frac),  
            "--src_fraction %i"%(src_fra),  
        ]
        my_arguments = ' '.join(pre_set_args)
        command_pieces = [conda_python_dir, py_file_main, my_arguments]
        final_cmd = ' '.join(command_pieces)
        cur_cmd_dict['forec'] = final_cmd
        
        command_dict[cur_exp_name] = cur_cmd_dict
        
print(f'Generated {len(command_dict)} experiments:')
for k, v in command_dict.items():
    print(f'{k}')
    print(f'\t{list(v.keys())}')

Generated 49 experiments:
train_all_equal_in_jp_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_fr_de_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_uk_fr_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_de_ca_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_mx_us_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_ca_us_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_ca_uk_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_in_us_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_jp_de_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_uk_jp_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_mx_fr_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_mx_de_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_uk_ca_full_aug_ftgt1_fsrc1
	['base', 'maml', 'forec']
train_all_equal_jp_fr_full_aug_ftgt1_fsrc1
	['base', 'maml', 'fo

## Write commands into .sh bash scripts

In [57]:
import os
sh_files = 'scripts/'
sh_logs = os.path.join(sh_files,'logs')
checkpoints_dir = 'checkpoints/'
# makedirs with exist_ok=True also creates scripts/ when needed and is safe to
# re-run, including when scripts/ already exists but scripts/logs does not.
os.makedirs(sh_logs, exist_ok=True)
os.makedirs(exp_output_dir, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

gpu_num = 1
gpu_type = '1080ti-long' #'titanx-short', 'm40-short'

master_file = open(os.path.join(sh_files,'master.sh'), 'w')

for cur_exp_name, v in command_dict.items():

    bash_file_name = f'{cur_exp_name}-run.sh'
    bash_file = open(os.path.join(sh_files,bash_file_name), 'w')
    cur_log_file = os.path.join('logs', f'{cur_exp_name}.out')
    
    bash_file.write('#!/bin/sh'+'\n')
    bash_file.write('#SBATCH --partition=%s'%(gpu_type) + '\n')
    bash_file.write('#SBATCH --ntasks=%s'%(1) + '\n')
    bash_file.write('#SBATCH --gres=gpu:%s'%(str(gpu_num)) + '\n')
    bash_file.write('#SBATCH --mem=%iG'%(50*gpu_num) + '\n')
    bash_file.write('#SBATCH --output=%s'%(cur_log_file) + '\n')

    bash_file.write('\ncd ..\n')
    if 'base' in v:
        bash_file.write(v['base'] + '\n\n')
    if 'maml' in v:
        bash_file.write(v['maml'] + '\n\n')
    if 'forec' in v:
        bash_file.write(v['forec'] + '\n\n')

    bash_file.close()
    master_file.write(f'sbatch {bash_file_name}\n')
    print(cur_exp_name + ' bash is created!')

master_file.close()

train_all_equal_in_jp_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_fr_de_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_uk_fr_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_de_ca_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_mx_us_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_ca_us_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_ca_uk_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_in_us_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_jp_de_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_uk_jp_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_mx_fr_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_mx_de_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_uk_ca_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_jp_fr_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_fr_jp_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_mx_ca_full_aug_ftgt1_fsrc1 bash is created!
train_all_equal_uk_mx_full_aug_ftgt1_fsr
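
For reference, the shape of each generated job script can be seen by rendering it as a string first. The sketch below mirrors the write calls in the cell above; the function name `render_sbatch_script` and the `echo` commands in the usage example are placeholders, and the `#SBATCH` options are only the ones already used in this notebook.

```python
def render_sbatch_script(partition, gpu_num, log_file, commands):
    """Return the text of one job script: #SBATCH header, 'cd ..', then the commands."""
    header = [
        '#!/bin/sh',
        f'#SBATCH --partition={partition}',
        '#SBATCH --ntasks=1',
        f'#SBATCH --gres=gpu:{gpu_num}',
        f'#SBATCH --mem={50 * gpu_num}G',  # 50 GB of memory per requested GPU
        f'#SBATCH --output={log_file}',
    ]
    # 'cd ..' moves from scripts/ to the project root before training starts;
    # each command is followed by an empty line, as in the cell above.
    body = ['', 'cd ..'] + [cmd + '\n' for cmd in commands]
    return '\n'.join(header + body) + '\n'


script_text = render_sbatch_script(
    '1080ti-long', 1, 'logs/example.out',
    ['echo base stage', 'echo maml stage'],  # placeholder commands
)
print(script_text)
```

Rendering to a string before writing makes it easy to inspect one script, or to swap the header for another scheduler's directives.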

In [58]:
!cd scripts/ && chmod +x *.sh && ./master.sh

Submitted batch job 8377817
Submitted batch job 8377818
Submitted batch job 8377819
Submitted batch job 8377820
Submitted batch job 8377821
Submitted batch job 8377822
Submitted batch job 8377823
Submitted batch job 8377824
Submitted batch job 8377825
Submitted batch job 8377826
Submitted batch job 8377827
Submitted batch job 8377828
Submitted batch job 8377829
Submitted batch job 8377830
Submitted batch job 8377831
Submitted batch job 8377832
Submitted batch job 8377833
Submitted batch job 8377834
Submitted batch job 8377835
Submitted batch job 8377836
Submitted batch job 8377837
Submitted batch job 8377838
Submitted batch job 8377839
Submitted batch job 8377840
Submitted batch job 8377841
Submitted batch job 8377842
Submitted batch job 8377843
Submitted batch job 8377844
Submitted batch job 8377845
Submitted batch job 8377846
Submitted batch job 8377847
Submitted batch job 8377848
Submitted batch job 8377849
Submitted batch job 8377850
Submitted batch job 8377851
Submitted batch job 

In [63]:
!squeue -u bonab

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
           8377841 1080ti-lo train_al    bonab  R    8:57:51      1 node120
           8377856 1080ti-lo train_al    bonab  R    8:57:51      1 node117
           8376528  m40-long     bash    bonab  R 2-00:42:20      1 node013


In [54]:
!scancel 8377816
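
If Slurm is not available, the generated commands can be run one after another on the local machine instead of being submitted with sbatch. The sketch below is one possible adaptation under that assumption, not part of the original workflow: `run_locally` is a hypothetical helper that expects a dictionary shaped like `command_dict` above, and it defaults to a dry run that only prints the commands.

```python
import subprocess


def run_locally(command_dict, stages=('base', 'maml', 'forec'), dry_run=True):
    """Run each experiment's stages in order.

    Returns a list of (experiment name, stage, return code) tuples; the return
    code is None for a dry run.
    """
    results = []
    for exp_name, stage_cmds in command_dict.items():
        for stage in stages:
            cmd = stage_cmds.get(stage)
            if cmd is None:
                continue  # e.g. no_aug experiments only have a 'base' stage
            if dry_run:
                print(f'[dry run] {exp_name}/{stage}: {cmd}')
                results.append((exp_name, stage, None))
                continue
            # shell=True lets the shell expand '~' in the interpreter path.
            completed = subprocess.run(cmd, shell=True)
            results.append((exp_name, stage, completed.returncode))
            if completed.returncode != 0:
                break  # MAML and FOREC need the base models, so stop this experiment
    return results


# Placeholder commands; in the notebook, pass command_dict instead.
example = {'demo_exp': {'base': 'echo base', 'maml': 'echo maml'}}
print(run_locally(example))
```

Call it with `dry_run=False` from the project root (the directory containing the training scripts) to actually execute the commands; the experiments run sequentially, so this is much slower than a cluster.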