In [12]:
# ==================================================================================================
# import packages
from pathlib import Path
import torch as th
from torch.utils.data import DataLoader
import time
from tqdm import tqdm # Instantly make your loops show a smart progress meter
import os
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from PIL import Image
import uuid
import pandas as pd
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

# Import other scripts here 
from drought_impact_dataset import *
from drought_impact_sampler import *
# Import a config file for training 
from utils.jura_config_rec_2020_test import * #train_dataset_params, train_sampler_params, test_dataset_params, test_sampler_params, sim_params, model_params
from torch.utils.data import DataLoader
from utils.utils_pixel import *
from model import *

# MLFlow
import mlflow
import mlflow.sklearn
import logging
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

#th.manual_seed(0)
#random.seed(0) 

In [13]:
start_time = time.time()
# ==================================================================================================
# Get Dataset details from configuration file
s2_path =  Path(train_dataset_params['s2_path'])
era_path =  Path(train_dataset_params['era_path'])
dem_path =   Path(train_dataset_params['dem_path'])
env_path =  Path(train_dataset_params['env_path'])
ts_delta =  train_dataset_params['ts_delta']
ts_len =  train_dataset_params['ts_len']
len_preds =  train_dataset_params['len_preds']
focus_time_train = train_dataset_params['focus_time']
focus_list_train = train_dataset_params['focus_list']
ratio = train_dataset_params['ratio']
data_file_extension = train_dataset_params['data_file_extension'] 
feature_set = train_dataset_params['feature_set']
remove_bands = train_dataset_params['remove_bands']
agg_funct_dict = train_dataset_params['agg_funct_dict']
multiple_labels = train_dataset_params['multiple_labels']
"""
focus_time_val = val_dataset_params['focus_time']
focus_list_val = val_dataset_params['focus_list']
"""
focus_list_test = test_dataset_params['focus_list']
focus_time_test = test_dataset_params['focus_time']
#focus_time_test2 = test2_dataset_params['focus_time']

end_time = time.time()
print(f'Took {end_time-start_time} seconds')


Took 0.00034165382385253906 seconds


In [14]:
start_time = time.time()
# ==================================================================================================
# Get Sampler details from configuration file
batch_size_tr =  train_sampler_params['batch_size'] 
n_batch = sim_params['n_batches']
sampler_size = train_sampler_params["size"]
sampler_replacement =  train_sampler_params['replacement'] 
mask_dir =  train_sampler_params['mask_dir'] 
sampler_set_seed =  train_sampler_params['set_seed'] 
sampler_roi_train = train_sampler_params['roi']
static_dir = train_sampler_params['static_dir'] 
mask_threshold = 0.5 #train_sampler_params["mask_threshold"]

sampler_roi_test = test_sampler_params['roi']
sampler_length_te = test_sampler_params['length'] 
#sampler_roi_val = val_sampler_params['roi']
#sampler_length_val = val_sampler_params['length'] 

end_time = time.time()
print(f'Took {end_time-start_time} seconds')

Took 0.0001773834228515625 seconds


### Create samples
Split temporally, all samples within same region\
If split spatially: define in dataset creation or in sampler ROI

In [15]:
# REMOVED ENV PATH and DEM FOR NOW 

start_time = time.time()
train_ds = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, env_path=env_path, dem_path=dem_path, focus_list=focus_list_train,
                          focus_time=[focus_time_train], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, multiple_labels=multiple_labels,
                               correct_ndvi=None)
end_time = time.time()
print(f'Took {end_time-start_time} seconds')

"""
start_time = time.time()
val_ds = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, env_path=env_path, dem_path=dem_path, focus_list=focus_list_val,
                          focus_time=[focus_time_val], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, multiple_labels=multiple_labels)
end_time = time.time()
print(f'Took {end_time-start_time} seconds')
"""
start_time = time.time()
test_ds = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, env_path=env_path, dem_path=dem_path, focus_list=focus_list_test,
                          focus_time=[focus_time_test], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, multiple_labels=multiple_labels)
end_time = time.time()
print(f'Took {end_time-start_time} seconds')


Took 0.12427043914794922 seconds
Took 0.04250454902648926 seconds


In [16]:
start_time = time.time()
train_sampler = DroughtImpactSampler(train_ds, size=sampler_size, length=batch_size_tr*n_batch, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, roi=sampler_roi_train, set_seed=sampler_set_seed,
                                         mask_threshold=mask_threshold, static_dir=static_dir) 
end_time = time.time()
print(f'Took {end_time-start_time} seconds')
"""
start_time = time.time()
val_sampler = DroughtImpactSampler(val_ds, size=sampler_size, length=sampler_length_val, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, set_seed=sampler_set_seed, static_dir=static_dir)
end_time = time.time()
print(f'Took {end_time-start_time} seconds')
"""
start_time = time.time()
test_sampler = DroughtImpactSampler(test_ds, size=sampler_size, length=sampler_length_te, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, set_seed=sampler_set_seed, static_dir=static_dir)
end_time = time.time()
print(f'Took {end_time-start_time} seconds')


Took 0.6423163414001465 seconds
Took 0.3484668731689453 seconds


In [17]:
len(test_ds.all_loc_dates)

9

In [18]:
test_ds.all_loc_dates

[['46.907_7.137_47.407_7.637',
  ['2020-05-04',
   '2020-05-14',
   '2020-05-24',
   '2020-06-03',
   '2020-06-13',
   '2020-06-23',
   '2020-07-03',
   '2020-07-13',
   '2020-07-23'],
  [],
  ['2020-08-02', '2020-08-12', '2020-08-22']],
 ['46.907_7.137_47.407_7.637',
  ['2020-05-09',
   '2020-05-19',
   '2020-05-29',
   '2020-06-08',
   '2020-06-18',
   '2020-06-28',
   '2020-07-08',
   '2020-07-18',
   '2020-07-28'],
  [],
  ['2020-08-07', '2020-08-17', '2020-08-27']],
 ['46.907_7.137_47.407_7.637',
  ['2020-05-14',
   '2020-05-24',
   '2020-06-03',
   '2020-06-13',
   '2020-06-23',
   '2020-07-03',
   '2020-07-13',
   '2020-07-23',
   '2020-08-02'],
  [],
  ['2020-08-12', '2020-08-22', '2020-09-01']],
 ['46.907_7.137_47.407_7.637',
  ['2020-05-19',
   '2020-05-29',
   '2020-06-08',
   '2020-06-18',
   '2020-06-28',
   '2020-07-08',
   '2020-07-18',
   '2020-07-28',
   '2020-08-07'],
  [],
  ['2020-08-17', '2020-08-27', '2020-09-06']],
 ['46.907_7.137_47.407_7.637',
  ['2020-05-24',


In [11]:
with open('all_loc_dates_val.pkl', 'wb') as f:
    pickle.dump(val_ds.all_loc_dates, f)

In [19]:
with open('all_loc_dates_test.pkl', 'wb') as f:
    pickle.dump(test_ds.all_loc_dates, f)

In [9]:
with open('full_date_range.pkl', 'wb') as f:
    pickle.dump(test_ds.full_date_range, f)

## Dataloader 

In [8]:
start_time = time.time()
# Call the dataloaders
train_dl = DataLoader(test_ds, sampler=test_sampler)  #, num_workers=2)
#val_dl = DataLoader(val_ds, sampler=val_sampler) 
test_dl = DataLoader(test_ds, sampler=test_sampler) 
end_time = time.time()
print(f'Took {end_time-start_time} seconds')

Took 0.00040268898010253906 seconds


# Timing the data generation

In [9]:
# OPTIMIZED 
start = time.time()
test_samples = list(test_dl.__iter__())
end = time.time()
print(f'Iterating through dataloader: {end-start} sec')

Iterating through dataloader: 38.15724492073059 sec


In [None]:
idx = 8
img = train_samples[idx][0]
label = train_samples[idx][1]
train_samples[idx][2]

In [None]:
img.shape

In [None]:
label[:,:,train_ds.feature_set['NDVI'],:,:]

In [None]:
train_ds.all_loc_dates[idx]

In [None]:
img[:,:,train_ds.feature_set['CP'],:,:] 

In [None]:
cp_data = th.cat([img[:,:,train_ds.feature_set['CP'],:,:] ,label[:,:,train_ds.feature_set['CP'],:,:]], axis=1)
b2_data = th.cat([img[:,:,train_ds.feature_set['B2'],:,:] ,label[:,:,train_ds.feature_set['B2'],:,:]], axis=1)
b8_data = th.cat([img[:,:,train_ds.feature_set['B8'],:,:] ,label[:,:,train_ds.feature_set['B8'],:,:]], axis=1)
ndvi_data = th.cat([img[:,:,train_ds.feature_set['NDVI'],:,:] ,label[:,:,train_ds.feature_set['NDVI'],:,:]], axis=1)

In [None]:
b2_data <0.1

In [None]:
b8_data >0.15

In [None]:
from scipy.interpolate import interp1d
from scipy import signal

def correct_noisy_ndvi(img, label, dataset, b2_thresh, b8_thresh, ts_len, len_preds):
    """
    Impute NDVI through linear interpolation if blue band (B2) > b2_thresh and infrared band (B8) < b8_thresh.
    Use both img and label for interpolation.
    
    Author: Selene
    :param img: data tensor
    :param label: label tensor
    :param dataset: dataset
    :param b2_thresh: threhsold for filtering with blue band (0.1 typically)
    :param b8_thresh: threhsold for filtering with infrared band (0.15 typically)
    :param ts_len: number of timestamps in data tensor
    :param len_preds: number of timestamps in label tensor
    """
    
    print(img.shape)
    img_tensor = img.clone().detach() 
    label_tensor = label.clone().detach() 
    
    # Get band data for image and label
    cp_data = th.cat([img_tensor[:,:,dataset.feature_set['CP'],:,:] ,label_tensor[:,:,dataset.feature_set['CP'],:,:]], axis=1)
    b2_data = th.cat([img_tensor[:,:,dataset.feature_set['B2'],:,:] ,label_tensor[:,:,dataset.feature_set['B2'],:,:]], axis=1)
    b8_data = th.cat([img_tensor[:,:,dataset.feature_set['B8'],:,:] ,label_tensor[:,:,dataset.feature_set['B8'],:,:]], axis=1)
    ndvi_data = th.cat([img_tensor[:,:,dataset.feature_set['NDVI'],:,:] ,label_tensor[:,:,dataset.feature_set['NDVI'],:,:]], axis=1)

    # Filter: find NDVI that needs to be replaced
    to_rep = (cp_data>0) #((b2_data>b2_thresh) | (b8_data<b8_thresh))
    to_rep = to_rep.squeeze(0).squeeze(1).squeeze(1) # make it 1D
    
    if to_rep.sum() < ts_len+len_preds-1: # There needs to be at least 2 points for inteprolation
        print('cp')
        # Linear interpolation 
        x = np.arange(ndvi_data.shape[1])[~to_rep] # get the x values that are valid
        ndvi_vec = ndvi_data.squeeze(0).squeeze(1).squeeze(1) # make it 1D
        f = interp1d(x, ndvi_vec[~to_rep], fill_value="extrapolate")
        interpolated = f(np.arange(ts_len+len_preds))
        
        # Make sure its between 0 and 1
        interpolated[interpolated < 0] = 0
        interpolated[interpolated > 1] = 1
            
        # Replace NDVI
        img_tensor[:,:,dataset.feature_set['NDVI'],:,:] = th.from_numpy(interpolated[:ts_len]).unsqueeze(0).unsqueeze(2).unsqueeze(2)
        label_tensor[:,:,dataset.feature_set['NDVI'],:,:] = th.from_numpy(interpolated[ts_len:]).unsqueeze(0).unsqueeze(2).unsqueeze(2)
        
        # Correct B2 and B8
        img_tensor[:,to_rep[:ts_len],dataset.feature_set['B2'],:,:] = b2_thresh
        label_tensor[:,to_rep[ts_len:],dataset.feature_set['B2'],:,:] = b2_thresh
        img_tensor[:,to_rep[:ts_len],dataset.feature_set['B8'],:,:] = b8_thresh
        label_tensor[:,to_rep[ts_len:],dataset.feature_set['B8'],:,:] = b8_thresh
        # Set CP to 0 for all values
        img_tensor[:,:,dataset.feature_set['CP'],:,:] = 0
        label_tensor[:,:,dataset.feature_set['CP'],:,:] = 0
    
    # If everything gets dropped
    else:
        to_rep = ((b2_data>b2_thresh) | (b8_data<b8_thresh))
        to_rep = to_rep.squeeze(0).squeeze(1).squeeze(1) # make it 1D
        print('band', to_rep.sum())
        # Linear interpolation 
        x = np.arange(ndvi_data.shape[1])[~to_rep] # get the x values that are valid
        ndvi_vec = ndvi_data.squeeze(0).squeeze(1).squeeze(1) # make it 1D
        f = interp1d(x, ndvi_vec[~to_rep], fill_value="extrapolate")
        interpolated = f(np.arange(ts_len+len_preds))
        
        # Make sure its between 0 and 1
        interpolated[interpolated < 0] = 0
        interpolated[interpolated > 1] = 1
            
        # Replace NDVI
        img_tensor[:,:,dataset.feature_set['NDVI'],:,:] = th.from_numpy(interpolated[:ts_len]).unsqueeze(0).unsqueeze(2).unsqueeze(2)
        label_tensor[:,:,dataset.feature_set['NDVI'],:,:] = th.from_numpy(interpolated[ts_len:]).unsqueeze(0).unsqueeze(2).unsqueeze(2)

        # Set CP to 0 for all values
        img_tensor[:,:,dataset.feature_set['CP'],:,:] = 0
        label_tensor[:,:,dataset.feature_set['CP'],:,:] = 0
    
    
    return img_tensor, label_tensor

In [None]:
img2, lab2 = correct_noisy_ndvi(img, label, train_ds, 0.10, 0.15, 15, 3)

In [None]:
# Plot before and after correction

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Get band data for image and label
cp_data = th.cat([img[:,:,train_ds.feature_set['CP'],:,:] ,label[:,:,train_ds.feature_set['CP'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
b2_data = th.cat([img[:,:,train_ds.feature_set['B2'],:,:] ,label[:,:,train_ds.feature_set['B2'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
b8_data = th.cat([img[:,:,train_ds.feature_set['B8'],:,:] ,label[:,:,train_ds.feature_set['B8'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
ndvi_data = th.cat([img[:,:,train_ds.feature_set['NDVI'],:,:] ,label[:,:,train_ds.feature_set['NDVI'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)

# Get band data for image and label
cp_data2 = th.cat([img2[:,:,train_ds.feature_set['CP'],:,:] ,lab2[:,:,train_ds.feature_set['CP'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
b2_data2 = th.cat([img2[:,:,train_ds.feature_set['B2'],:,:] ,lab2[:,:,train_ds.feature_set['B2'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
b8_data2 = th.cat([img2[:,:,train_ds.feature_set['B8'],:,:] ,lab2[:,:,train_ds.feature_set['B8'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
ndvi_data2 = th.cat([img2[:,:,train_ds.feature_set['NDVI'],:,:] ,lab2[:,:,train_ds.feature_set['NDVI'],:,:]], axis=1).squeeze(0).squeeze(1).squeeze(1)
   

fig, axs = plt.subplots(1, 1, figsize=(10, 3))
plt.suptitle('Raw data')
sns.lineplot(ax=axs, x=np.arange(18), y=ndvi_data, label='NDVI')
sns.lineplot(ax=axs, x=np.arange(18), y=cp_data/100, label='CP')
sns.lineplot(ax=axs, x=np.arange(18), y=b2_data, label='B2')
sns.lineplot(ax=axs, x=np.arange(18), y=b8_data, label='B8')
sns.despine(top=True, right=True)
plt.tight_layout()


fig, axs = plt.subplots(1, 1, figsize=(10, 3))
plt.suptitle('Correctd data')
sns.lineplot(ax=axs, x=np.arange(18), y=ndvi_data2, label='NDVI')
sns.lineplot(ax=axs, x=np.arange(18), y=cp_data2/100, label='CP')
sns.lineplot(ax=axs, x=np.arange(18), y=b2_data2, label='B2')
sns.lineplot(ax=axs, x=np.arange(18), y=b8_data2, label='B8')
sns.despine(top=True, right=True)
plt.tight_layout()

In [None]:
ndvi_data

In [None]:
ndvi_data2

# Timing the data loading from memory

In [None]:
start = time.time()
x,y = load_batch(batch_size = 10, batch_nbr = 0, sample_type = 'pixel_data', split='train', exp='arch')
end = time.time()
print(f'Loading a batch: {end-start} sec')

### Stats on train set

In [None]:
start_time = time.time()

# ==================================================================================================
# Do statistics on the training set to normalize the entire dataset

# Temporal bands
bands = []
n_temp = train_ds.bands_s2 + train_ds.bands_era
for i in list(train_ds.feature_set.keys())[:n_temp]:
    bands.append(train_ds.feature_set[i])
tmp_bands_vals, tmp_band_means_or_mins, tmp_band_stds_or_maxs = dataset_stats(train_dl, bands=bands, temporal=True, norm_method=sim_params["norm_method"])

# Static bands
bands = []
for i in list(train_ds.feature_set.keys())[n_temp:]:
    bands.append(train_ds.feature_set[i])
stat_bands_vals, stat_band_means_or_mins, stat_band_stds_or_maxs = dataset_stats(train_dl, bands=bands, temporal=False, norm_method=sim_params["norm_method"])

# First temporal, then static
all_band_vals = list(tmp_bands_vals) + list(stat_bands_vals)
all_band_means_or_mins = list(tmp_band_means_or_mins) + list(stat_band_means_or_mins)
all_band_stds_or_maxs = list(tmp_band_stds_or_maxs) + list(stat_band_stds_or_maxs)

Model_hyperparameters_dict['mean_or_min_intensity_training_set']=all_band_means_or_mins
Model_hyperparameters_dict['std_or_max_intensity_training_set']=all_band_stds_or_maxs

end_time = time.time()
print(f'Took {end_time-start_time} seconds')

In [None]:
# ==================================================================================================
# Perform normalization and prepare dataloaders
start_time = time.time()
train_ds_norm = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, dem_path=dem_path, env_path=env_path, focus_list=focus_list_train,
                          focus_time=[focus_time_train], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, norm_stats=[all_band_means_or_mins, all_band_stds_or_maxs], norm_method=sim_params["norm_method"], multiple_labels=multiple_labels)

val_ds_norm = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, dem_path=dem_path, env_path=env_path,  focus_list=focus_list_val,
                          focus_time=[focus_time_val], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, norm_stats=[all_band_means_or_mins, all_band_stds_or_maxs], norm_method=sim_params["norm_method"], multiple_labels=multiple_labels)

test_ds_norm = DroughtImpactDataset(s2_path=s2_path, era_path=era_path, dem_path=dem_path, env_path=env_path,  focus_list=focus_list_test,
                          focus_time=[focus_time_test1,focus_time_test2], ts_delta=ts_delta, ts_len=ts_len, ratio=ratio, len_preds=len_preds, feature_set=feature_set, agg_funct_dict=agg_funct_dict, norm_stats=[all_band_means_or_mins, all_band_stds_or_maxs], norm_method=sim_params["norm_method"], multiple_labels=multiple_labels)

train_sampler_norm = DroughtImpactSampler(train_ds_norm, size=sampler_size, length=batch_size_tr*n_batch, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, set_seed=sampler_set_seed, static_dir=static_dir)

val_sampler_norm = DroughtImpactSampler(val_ds_norm, size=sampler_size, length=batch_size_tr*n_batch, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, set_seed=sampler_set_seed, static_dir=static_dir)

test_sampler_norm = DroughtImpactSampler(test_ds_norm, size=sampler_size, length=sampler_length_te, replacement=sampler_replacement, 
                                     mask_dir=mask_dir, set_seed=sampler_set_seed, static_dir=static_dir)


# Call the dataloaders
train_dl_norm = DataLoader(train_ds_norm, sampler=train_sampler_norm) 
val_dl_norm = DataLoader(val_ds_norm, sampler=val_sampler_norm) 
test_dl_norm = DataLoader(test_ds_norm, sampler=test_sampler_norm) 
end_time = time.time()
print(f'Took {end_time-start_time} seconds')

start_time = time.time()
# Get info on each set 
## Maybe not efficient to go through samples...
train_samples = list(train_dl_norm.__iter__())
called_train_samples = [x[2] for x in train_samples]
val_samples = list(val_dl_norm.__iter__())
called_val_samples = [x[2] for x in val_samples]
test_samples = list(test_dl_norm.__iter__())
called_test_samples = [x[2] for x in test_samples]

# Save info to experiment
Experiment_dict['train_samples']=called_train_samples
Experiment_dict['val_samples']=called_val_samples
Experiment_dict['test_samples']=called_test_samples


end_time = time.time()
print(f'Took {end_time-start_time} seconds')

In [None]:
train_samples[4][0].isnan().any()

# Model Training

In [None]:
start_time = time.time()

# Create folder where checkpoints for model will be saved
method = sim_params["method"] # direct vs oneshot

checkpoint_folder = f'checkpoints/{method}_{sim_params["learning_rate"]}_{ model_params["num_layers"]}_{model_params["hidden_dim"]}/'
#checkpoint_folder = f'checkpoints/{dt_string}/'
if not os.path.exists(checkpoint_folder):
    os.mkdir(checkpoint_folder)
    
end_time = time.time()
print(f'Took {end_time-start_time} seconds')

In [None]:
dt_string = 'debug'

In [None]:
checkpoint_file_prefix = 'checkpoint_'+dt_string.split(' ')[0].replace('/', '_')+'_e'
checkpoints = [file for file in os.listdir(checkpoint_folder) if file.startswith(checkpoint_file_prefix) and 'best' not in file]
sorted_checkpoints = sorted(checkpoints, key=get_ckpt_epoch_batch)
# Get last checkpoint (latest epoch)
checkpoint_file = sorted_checkpoints[-1] if len(sorted_checkpoints)!=0 else None

In [None]:
checkpoint_file

In [None]:
start_time = time.time()

if checkpoint_file is not None:
    checkpoint = th.load(checkpoint_folder+checkpoint_file)
    start_epoch = checkpoint['epoch']
    start_batch = checkpoint['batch']
    epoch_loss = checkpoint['epoch_loss']
    optimizer = checkpoint['optimizer']
    dt_string = checkpoint['experiment_name']
    mlflow_run_id = checkpoint['mlflow_run_id']
    
    hidden_dim = model_params["hidden_dim"]
    num_layers = model_params["num_layers"]
    output_dim = model_params["output_dim"]
    
    if method == 'dir': #direct
        model = LSTM_oneshot(input_dim=len(train_ds.feature_set)-len(remove_bands), hidden_dim=hidden_dim, num_layers=num_layers, output_dim=output_dim)
    if method == 'rec': #recursive
        model = LSTM_recursive(input_dim=len(train_ds.feature_set)-len(remove_bands), hidden_dim=hidden_dim, num_layers=num_layers, num_steps=sim_params["num_steps"])
    criterion = select_loss_function(sim_params['loss_function'])
    
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (trained for {} epochs)".format(checkpoint_file, checkpoint['epoch']+1))
    
    # Get existing MLflow experiment
    experiment = mlflow.get_experiment_by_name(dt_string)
    
    # Get run that was just created and its ID to use when tracking
    client = mlflow.tracking.MlflowClient() # Create a MlflowClient object
    runs = client.search_runs(experiment.experiment_id)
    mlflow_run_id = [r.info.run_id for r in runs if r.info.run_name==f'train_{method}'][0]
    mlflow_run_id_val = [r.info.run_id for r in runs if r.info.run_name==f'val_{method}'][0]
    

if checkpoint_file is None:
    start_epoch = 0
    start_batch = 0
    epoch_loss = 0
    hidden_dim = model_params["hidden_dim"]
    num_layers = model_params["num_layers"]
    output_dim = model_params["output_dim"]
    lr = sim_params["learning_rate"] # learning rate
    
    if method == 'dir': #direct
        model = LSTM_oneshot(input_dim=len(train_ds.feature_set)-len(remove_bands), hidden_dim=hidden_dim, num_layers=num_layers, output_dim=output_dim)
    if method == 'rec': #recursive
        model = LSTM_recursive(input_dim=len(train_ds.feature_set)-len(remove_bands), hidden_dim=hidden_dim, num_layers=num_layers, num_steps=sim_params["num_steps"])
    criterion = select_loss_function(sim_params['loss_function'])
    optimizer = select_optimizer(sim_params["optimizer"], model.parameters(), sim_params["learning_rate"], sim_params["momentum"])

    #summary(model, (len(train_ds.feature_set)-len(remove_bands), 1, 1))
    
    # Create new MLflow experiment
    now = datetime.now()
    #dt_string = now.strftime("%d/%m/%Y")+f'_{sim_params["learning_rate"]}_{model_params["num_layers"]}_{model_params["hidden_dim"]}'
    dt_string = 'debug'
    mlflow.create_experiment(name=dt_string) 
    experiment = mlflow.get_experiment_by_name(dt_string)
    
    with mlflow.start_run(experiment_id = experiment.experiment_id, run_name=f'train_{method}'):
        mlflow.log_param("n_samples training", len(train_dl))
        mlflow.log_param("batch_size training", batch_size_tr)
    
    
    with mlflow.start_run(experiment_id = experiment.experiment_id, run_name=f'val_{method}'):
        mlflow.log_param(f"n_samples val", len(val_dl))
        
    # Get run that was just created and its ID to use when tracking
    client = mlflow.tracking.MlflowClient() # Create a MlflowClient object
    runs = client.search_runs(experiment.experiment_id)
    mlflow_run_id = [r.info.run_id for r in runs if r.info.run_name==f'train_{method}'][0]
    mlflow_run_id_val = [r.info.run_id for r in runs if r.info.run_name==f'val_{method}'][0]
    
end_time = time.time()
print(f'Took {end_time-start_time} seconds')

In [None]:
start_time = time.time()

total_tr_loss = 0
best_loss = np.inf
model.train()

for ix_epoch in tqdm(range(1)): #sim_params["num_epochs"])
    if ix_epoch<start_epoch:
        continue

    print(f"Epoch {ix_epoch}\n---------")

    # Train
    epoch_loss = train_model(method=method, data_loader=train_dl_norm, model=model, epoch=ix_epoch, loss_function=criterion, optimizer=optimizer, 
                             batch_size=batch_size_tr, n_batch=n_batch,
                             n_timesteps_in=ts_len, n_timesteps_out=len_preds, n_feats_in=len(train_ds_norm.feature_set)-len(remove_bands), n_feats_out=output_dim, 
                             remove_band=train_dataset_params["remove_bands"], feature_set=train_ds_norm.feature_set, 
                             experiment=experiment, checkpoint_folder=checkpoint_folder, dt_string=dt_string, start_batch=start_batch, client=client, run_id=mlflow_run_id, epoch_loss=epoch_loss)

    total_tr_loss += epoch_loss
    
    end_time = time.time()
    print(f'Took {end_time-start_time} seconds')    

    # Validate
    total_val_loss, avg_val_loss = test_model(method=method, data_loader=val_dl_norm, model=model, loss_function=criterion, 
                                              n_timesteps_in=ts_len, n_timesteps_out=len_preds, n_feats_in=len(train_ds_norm.feature_set)-len(remove_bands), 
                                              n_feats_out=output_dim, remove_band=train_dataset_params["remove_bands"], feature_set=train_ds_norm.feature_set, 
                                              experiment=experiment, split='val', client=client, run_id=mlflow_run_id_val, checkpoint_folder=checkpoint_folder)



    best_loss = compare_model_for_checkpoint(total_val_loss, best_loss, model, ix_epoch, checkpoint_folder+'checkpoint_'+dt_string.split(' ')[0].replace('/', '_')+f'_e{ix_epoch}_b{n_batch}_best.pth.tar') 

with mlflow.start_run(experiment_id = experiment.experiment_id, run_name='trained'):
    mlflow.sklearn.log_model(model, "model")



In [None]:
 ########################################################################
# TEST MODEL
with mlflow.start_run(experiment_id = experiment.experiment_id, run_name=f'test_{method}'):
        mlflow.log_param(f"n_samples val", len(val_dl_norm))
        
# Get run that was just created and its ID to use when tracking
client = mlflow.tracking.MlflowClient() # Create a MlflowClient object
runs = client.search_runs(experiment.experiment_id)
mlflow_run_id_test = [r.info.run_id for r in runs if r.info.run_name==f'test_{method}'][0]

    
total_test_loss, avg_test_loss = test_model(method=method, data_loader=test_dl_norm, model=model, loss_function=criterion, 
                                            n_timesteps_in=ts_len, n_timesteps_out=len_preds, n_feats_in=len(train_ds.feature_set)-len(remove_bands), n_feats_out=output_dim, 
                                            remove_band=train_dataset_params["remove_bands"], feature_set=test_ds_norm.feature_set, 
                                            experiment=experiment, split='test', client=client, run_id=mlflow_run_id_test)

In [None]:
mlflow.end_run(mlflow_run_id)
mlflow.end_run(mlflow_run_id_val)
mlflow.end_run(mlflow_run_id_test)

In [None]:
Model_hyperparameters_dict['loss_function']=sim_params['loss_function']
Model_hyperparameters_dict['model_architecture']=model
Model_hyperparameters_dict['learning_rate']=sim_params['learning_rate']
Model_hyperparameters_dict['optimizer']=sim_params['optimizer']
Model_hyperparameters_dict['num_epochs']=sim_params['num_epochs']
Model_hyperparameters_dict['batch_size_train']=train_sampler_params['batch_size']

# Save Statistics
Experiment_dict['Model_hyperparameters']=Model_hyperparameters_dict
pd.to_pickle(Experiment_dict,save_path+'/Experiment_dict_'+Universally_unique_identifier.urn[9:]+'.pkl')