In [50]:

import os
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers
from pytorch_lightning import Trainer
import numpy as np
import pandas as pd

from net import SimpleNet
from net_lstm import SimpleLSTM
from net_transformer import SimpleTimeSeriesTransformer

from datamodule import Datamodule
from model import FlightModel

from file_parsing_utils import create_csv_dict
from coordinate_transform import CoordinateEnum, helper_get_coordinate_system_based_on_enum
from iterate_flights import itterate_flights, build_features, flight_tensor_chunk_itterator
from model import iterative_path_predict
from coordinate_transform import  haversine
from iterate_flights import flightpath_iterator


from eval_metrics import * 

from glob import glob 
import copy
from collections import defaultdict

from folium_utils import create_folium_map
from folium_utils import get_map_image
import folium

from matplotlib import pyplot as plt
import matplotlib.image as mpimg


import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

import hydra


from hydra.utils import instantiate
from omegaconf import OmegaConf

In [48]:
def load_model(experiment_dir, map_location = "cpu"):
    """
    Method to load in model from checkpoint given experiment dir
    Returns the model and the dictionary of the shape [Channels x Timesteps] the model will need as input
    """

    # Paths to config (determines params like the number of input rows and input channels) and checkpoint where weights are stored
    config_path = experiment_dir + "config/config.yaml"
    last_checkpoint = experiment_dir + "model_checkpoints/last.ckpt"

    # Get the config with all of it's parameters
    cfg = OmegaConf.load(config_path)
    model_cfg = cfg.model

    # Figure out dimensions of the input tensors to pass to the model 
    num_input_rows = cfg["datamodule"]["datamodule"]["num_input_rows_total"]
    num_input_coordinates = len(helper_get_coordinate_system_based_on_enum(CoordinateEnum[cfg["datamodule"]["datamodule"]['coordinate_system_enum']]))
    num_auxilary_inputs = len(cfg["datamodule"]["datamodule"]['auxiliary_input_channels'])
    num_input_channels = num_input_coordinates + num_auxilary_inputs
    print("num_input_channels, num_input_rows: ", num_input_channels, num_input_rows)

    desired_shape_dict = {"num_input_channels": num_input_channels, 
                     "num_input_rows": num_input_rows}

    
    # Get model specific config and instantiate the model 
    model_cfg = cfg.model.model
    model = instantiate(model_cfg)

    # Pass a dummy tensor to instantiate any lazy modules
    dummy_input_tensor = torch.randn(32, num_input_channels, num_input_rows)  # Adjust dimensions as needed
    _ = model(dummy_input_tensor)

    # Load the checkpoint
    checkpoint = torch.load(last_checkpoint, map_location=map_location)

    # Manually load the state_dict
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    return model, desired_shape_dict

In [49]:
experiment_cnn_dir = "/Users/aleksandranikevich/Desktop/AircraftTrajectory/REPO/flight_pattern_of_life/models/TestHydra/Experiment_1/"
experiment_lstm_dir = "/Users/aleksandranikevich/Desktop/AircraftTrajectory/REPO/flight_pattern_of_life/models/TestHydra/Experiment_2/"
experiment_transformer_dir = "/Users/aleksandranikevich/Desktop/AircraftTrajectory/REPO/flight_pattern_of_life/models/TestHydra/Experiment_3/"

model_cnn, desired_shape_dict = load_model(experiment_cnn_dir, map_location = "cpu")
model_lstm, desired_shape_dict = load_model(experiment_lstm_dir, map_location = "cpu")
model_transformer, desired_shape_dict = load_model(experiment_transformer_dir, map_location = "cpu")



num_input_channels, num_input_rows:  2 500
Using Default model weights initialization


  checkpoint = torch.load(last_checkpoint, map_location=map_location)


num_input_channels, num_input_rows:  2 500
Using Default model weights initialization
num_input_channels, num_input_rows:  2 500
Using Default model weights initialization




In [None]:
# Example usage of model using dummy input tensor: 
model_output = model_cnn()

In [None]:
# eval_models(model, 
#         individual_flights_dir=cfg.datamodule.datamodule.all_flight_dataframes_dict,
#         coordinate_system_enum=cfg.coordinate_system.coordinate_system.coordinate_system_enum,
#         auxiliary_input_channels=cfg.coordinate_system.coordinate_system.auxiliary_input_channels,
#         auxiliary_output_channels=cfg.coordinate_system.coordinate_system.auxiliary_output_channels,
#         min_rows_input=cfg.datamodule.datamodule.min_rows_input,
#         num_input_rows_total=cfg.datamodule.datamodule.num_input_rows_total,
#         results_save_dir=results_dir,
#         desired_keys=eval_keys,
#         num_predict_steps=cfg.eval_config.eval_config.num_predict_steps,
#         break_after_index=cfg.eval_config.eval_config.break_after_index, 
#         dataset_wide_normalization_dict=dataset_wide_normalization_dict, 
#         )

In [54]:
def eval_models(model, 
                individual_flights_dir, 
                coordinate_system_enum, 
                auxiliary_input_channels, 
                auxiliary_output_channels, 
                min_rows_input, 
                num_input_rows_total, 
                results_save_dir, 
                desired_keys = None, 
                num_predict_steps = 10, 
                break_after_index = 3, 
                dataset_wide_normalization_dict = None, 
                ):

    # Eval over only desired flightpath keys
    flight_dfs = create_csv_dict(individual_flights_dir)
    if desired_keys is None:
        desired_keys = ['146014', '180338']
        eval_flights_dfs = {key: flight_dfs[key] for key in desired_keys if key in flight_dfs}
    else:
        eval_flights_dfs = {key[0]: {key[1]: flight_dfs[key[0]][key[1]]} for key in desired_keys}

    flight_dictionary_pre_loaded = False

    coordinate_system = helper_get_coordinate_system_based_on_enum(coordinate_system_enum)
    len_coordinate_system  = len(coordinate_system)
    desired_input_features = coordinate_system  + auxiliary_input_channels
    desired_output_features = coordinate_system  + auxiliary_output_channels

    flightpath_iter = flightpath_iterator(flights_dict=eval_flights_dfs, 
                            flight_dictionary_pre_loaded=flight_dictionary_pre_loaded, 
                            desired_input_features=desired_input_features, 
                            desired_output_features=desired_output_features, 
                            min_rows_input=min_rows_input, 
                            num_input_rows_total=num_input_rows_total, 
                            num_output_rows=num_predict_steps, 
                            len_coordinate_system=len_coordinate_system, 
                            shuffle_flights = False, 
                            shuffle_chunks = False, 
                            bool_yield_meta_flightpath = True, 
                            force_new_flightpath_every_val_step = False, 
                            dataset_wide_normalization_dict = dataset_wide_normalization_dict,
                            )
    
    # Evaluate the model on each of the desired flightpahts, N steps at a time (feeding the model M prior steps aircraft took)
    eval_dict, flightpaths = eval_over_flightpaths(model, flightpath_iter, num_predict_steps, break_after_index = break_after_index)
    # Get metrics 
    dict_overall_eval, dict_overall_eval_arrays = get_final_eval_metrics_all_models(eval_dict, flightpaths_dict=flightpaths)
    # Get averaged metrics for overall performance measure
    final_metrics_dict, final_array_metrics_dict = get_overall_error_metrics(dict_overall_eval, dict_overall_eval_arrays)
    # create figure of the average errors based on the model by different metric
    fig = plot_metrics_from_dict(final_array_metrics_dict)


    # save the figure
    os.makedirs(results_save_dir, exist_ok=True)
    save_path = os.path.join(results_save_dir, "metrics_plot.png")
    fig.savefig(save_path, format='png') 

    # Save the eval metrics (both arrays and final scores)
    save_as_pickle(final_metrics_dict, "final_metrics_dict.pkl", results_save_dir)
    save_as_pickle(final_array_metrics_dict, "final_array_metrics_dict.pkl", results_save_dir)

In [55]:
experiment_dir = "/Users/aleksandranikevich/Desktop/AircraftTrajectory/REPO/flight_pattern_of_life/models/TestHydra/Experiment_1/"

config_path = experiment_dir + "config/config.yaml"
last_checkpoint = experiment_dir + "model_checkpoints/last.ckpt"

# Get the config with all of it's parameters
cfg = OmegaConf.load(config_path)

In [56]:
# Get flightpath keys we will be using to evaluate the model / make sure those keys are not trained on
if cfg.eval_config.eval_config.desired_keys is None:
    eval_keys = get_eval_keys(cfg.datamodule.datamodule.all_flight_dataframes_dict)
else:
    eval_keys = cfg.eval_config.eval_config.desired_keys
print("Eval keys: ", eval_keys)

Eval keys:  [('178224', 'N90K')]


In [57]:
results_dir = "temp_results/"
dataset_wide_normalization_dict = None

eval_models(model_cnn, 
        individual_flights_dir=cfg.datamodule.datamodule.all_flight_dataframes_dict,
        coordinate_system_enum=cfg.coordinate_system.coordinate_system.coordinate_system_enum,
        auxiliary_input_channels=cfg.coordinate_system.coordinate_system.auxiliary_input_channels,
        auxiliary_output_channels=cfg.coordinate_system.coordinate_system.auxiliary_output_channels,
        min_rows_input=cfg.datamodule.datamodule.min_rows_input,
        num_input_rows_total=cfg.datamodule.datamodule.num_input_rows_total,
        results_save_dir=results_dir,
        desired_keys=eval_keys,
        num_predict_steps=cfg.eval_config.eval_config.num_predict_steps,
        break_after_index=cfg.eval_config.eval_config.break_after_index, 
        dataset_wide_normalization_dict=dataset_wide_normalization_dict, 
        )

Working on flightpath index number: 0, and flightpath is None: False
Working on flightpath index number: 1, and flightpath is None: False
Working on flightpath index number: 2, and flightpath is None: False
Working on flightpath index number: 3, and flightpath is None: False
Working on flightpath index number: 4, and flightpath is None: False
broken because flight chunk index  >= break after index


IndexError: list index out of range

In [58]:
# DEBUG EVAL 

In [77]:
model = model_cnn
individual_flights_dir=cfg.datamodule.datamodule.all_flight_dataframes_dict
coordinate_system_enum=cfg.coordinate_system.coordinate_system.coordinate_system_enum
auxiliary_input_channels=cfg.coordinate_system.coordinate_system.auxiliary_input_channels
auxiliary_output_channels=cfg.coordinate_system.coordinate_system.auxiliary_output_channels
min_rows_input=cfg.datamodule.datamodule.min_rows_input
num_input_rows_total=cfg.datamodule.datamodule.num_input_rows_total
results_save_dir=results_dir
desired_keys=eval_keys
num_predict_steps=cfg.eval_config.eval_config.num_predict_steps
break_after_index=cfg.eval_config.eval_config.break_after_index
dataset_wide_normalization_dict = None




In [78]:
# Eval over only desired flightpath keys
flight_dfs = create_csv_dict(individual_flights_dir)
if desired_keys is None:
    desired_keys = ['146014', '180338']
    eval_flights_dfs = {key: flight_dfs[key] for key in desired_keys if key in flight_dfs}
else:
    eval_flights_dfs = {key[0]: {key[1]: flight_dfs[key[0]][key[1]]} for key in desired_keys}

flight_dictionary_pre_loaded = False

coordinate_system = helper_get_coordinate_system_based_on_enum(coordinate_system_enum)
len_coordinate_system  = len(coordinate_system)
desired_input_features = coordinate_system  + auxiliary_input_channels
desired_output_features = coordinate_system  + auxiliary_output_channels

flightpath_iter = flightpath_iterator(flights_dict=eval_flights_dfs, 
                        flight_dictionary_pre_loaded=flight_dictionary_pre_loaded, 
                        desired_input_features=desired_input_features, 
                        desired_output_features=desired_output_features, 
                        min_rows_input=min_rows_input, 
                        num_input_rows_total=num_input_rows_total, 
                        num_output_rows=num_predict_steps, 
                        len_coordinate_system=len_coordinate_system, 
                        shuffle_flights = False, 
                        shuffle_chunks = False, 
                        bool_yield_meta_flightpath = True, 
                        force_new_flightpath_every_val_step = False, 
                        dataset_wide_normalization_dict = dataset_wide_normalization_dict,
                        )

# Evaluate the model on each of the desired flightpahts, N steps at a time (feeding the model M prior steps aircraft took)
eval_dict, flightpaths = eval_over_flightpaths(model, flightpath_iter, num_predict_steps, break_after_index = break_after_index)
# Get metrics 
dict_overall_eval, dict_overall_eval_arrays = get_final_eval_metrics_all_models(eval_dict, flightpaths_dict=flightpaths)
# Get averaged metrics for overall performance measure

Working on flightpath index number: 0, and flightpath is None: False
Working on flightpath index number: 1, and flightpath is None: False
Working on flightpath index number: 2, and flightpath is None: False
Working on flightpath index number: 3, and flightpath is None: False
Working on flightpath index number: 4, and flightpath is None: False
broken because flight chunk index  >= break after index


In [89]:
def get_final_eval_metrics_all_models(eval_dict, flightpaths_dict):
    """
    *Currently assumes Lat-Long coordinates
    """


    paths_we_evaled_keys_list = list(eval_dict.keys())
    models_and_other = list(eval_dict[paths_we_evaled_keys_list[0]].keys()) # Ex: dict_keys(['prediction_model_0', 'chunk_index', 'ground_truth'])
    all_model_keys = [model_name for model_name in models_and_other if "prediction_model_" in model_name]

    dict_overall_eval = {}
    dict_overall_eval_arrays = {}
    for flightpath_key in paths_we_evaled_keys_list:

        flightpath_dataframe = flightpaths_dict[flightpath_key]
        flightpath_dataframe.columns = [col.capitalize() if col.lower() in ['latitude', 'longitude'] else col for col in flightpath_dataframe.columns]
        flightpath_np = flightpath_dataframe[["Latitude", "Longitude"]].to_numpy()
        flightpath_np = np.squeeze(flightpath_np)

        ground_truth_list = eval_dict[flightpath_key]["ground_truth"]
        num_samples = len(ground_truth_list)

        #dict_disance_overall = {model_key: 0.0 for model_key in all_model_keys}
        dict_errors = {model_key: defaultdict(float) for model_key in all_model_keys}
        dict_errors_arrs = {model_key: defaultdict(list) for model_key in all_model_keys}

        for sample_index in range(num_samples):
            for model_key in all_model_keys:
                ground_truth_array = eval_dict[flightpath_key]["ground_truth"][sample_index]
                ground_truth_array_shape = ground_truth_array.shape
                model_prediction_array = eval_dict[flightpath_key][model_key][sample_index]
                ground_truth_array = np.squeeze(ground_truth_array)
                model_prediction_array = np.squeeze(model_prediction_array)[:, :ground_truth_array_shape[-1]] # TODO FORCE THE SAME NUMBER OF PREDICICTIONS AS GROUND TRUTH ARRAY

                # Sometimes there are no more samples in the ground truth array, we must then ignore the corresponding predictions
                if ground_truth_array.shape != model_prediction_array.shape:
                    print("CONTINUE")
                    print(ground_truth_array.shape, model_prediction_array.shape, "\n")
                    continue

                error_distance_arr = haversine(ground_truth_array[0], ground_truth_array[1], model_prediction_array[0], model_prediction_array[1])

                error_distance_total = np.sum(error_distance_arr)
                dict_errors[model_key]["error_distance_overall"] += error_distance_total
                dict_errors_arrs[model_key]["error_distance_arr"].append(error_distance_arr)


                # Normalized errors
                chunk_idx = eval_dict[flightpath_key]["chunk_index"][sample_index] - 1
                current_position = flightpath_np[chunk_idx]
                current_lat = current_position[0]
                current_long = current_position[1]
                distance_traveled_arr = haversine(current_lat, current_long, ground_truth_array[0], ground_truth_array[1])
                
                normalized_error_arr = error_distance_arr / distance_traveled_arr
                # sometimes the distance traveled to the next timestep is insignificant, we want to filter NaNs and Infs that are caused by this
                normalized_error_arr = np.nan_to_num(normalized_error_arr, nan=0.0, posinf=0.0, neginf=0.0)

                if np.isinf(normalized_error_arr).any() or np.isnan(normalized_error_arr).any():
                    print('\n\n')
                    print("current_lat, current_long, ground_truth_array[0], ground_truth_array[1]: ", current_lat, current_long, ground_truth_array[0], ground_truth_array[1])
                    print("distance_traveled_arr: ", distance_traveled_arr)
                    print("normalized_error_arr: ", normalized_error_arr)

                error_distance_overall_normalized = np.sum(normalized_error_arr)
                dict_errors[model_key]["error_distance_overall_normalized"] += error_distance_overall_normalized
                dict_errors_arrs[model_key]["normalized_error_arr"].append(normalized_error_arr)

                _, num_predictions = ground_truth_array.shape
                dict_errors[model_key]["num_predictions_total"] += num_predictions

                mse_error_arr = 0.5*((ground_truth_array[0] - current_lat)**2.0 + (ground_truth_array[1] - current_long)**2.0)
                dict_errors[model_key]["mse_error_arr"] += np.sum(mse_error_arr)
                dict_errors_arrs[model_key]["mse_error_arr"].append(mse_error_arr)

                

        ###dict_errors_arrs = {key: np.stack(arr_list, axis=0) for key, arr_list in dict_errors_arrs.items()}

        dict_overall_eval[flightpath_key] = dict_errors
        dict_overall_eval_arrays[flightpath_key] = dict_errors_arrs

    
    return dict_overall_eval, dict_overall_eval_arrays


In [90]:
dict_overall_eval, dict_overall_eval_arrays = get_final_eval_metrics_all_models(eval_dict, flightpaths_dict=flightpaths)

In [91]:
dict_overall_eval

{('178224',
  'N90K'): {'prediction_model_0': defaultdict(float,
              {'error_distance_overall': 60.21800422668457,
               'error_distance_overall_normalized': 18.138869285583496,
               'num_predictions_total': 40.0,
               'mse_error_arr': 0.026358262170106173})}}

In [83]:
len(eval_dict[('178224','N90K')]['prediction_model_0'])


4

In [86]:
eval_dict

defaultdict(<function eval_metrics.eval_over_flightpaths.<locals>.<lambda>()>,
            {('178224',
              'N90K'): defaultdict(list,
                         {'prediction_model_0': [array([[[ 39.65982 ,  39.654663,  39.65362 ,  39.649616,  39.6442  ,
                                     39.637028,  39.629818,  39.622623,  39.612156,  39.603725,
                                     39.599773,  39.594986,  39.594105,  39.59047 ,  39.585526,
                                     39.578964,  39.57236 ,  39.565838,  39.55624 ,  39.54863 ,
                                     39.54466 ,  39.539772,  39.538574,  39.5345  ,  39.529087,
                                     39.521942,  39.514656,  39.50735 ,  39.496857,  39.488194,
                                     39.48445 ,  39.4797  ,  39.47827 ,  39.47396 ,  39.468414,
                                     39.461098,  39.453598,  39.445896,  39.43515 ,  39.425953,
                                     39.42208 ,  39.4171  ,  39.41

In [85]:
flightpaths[('178224','N90K')]

Unnamed: 0.1,Unnamed: 0,fltKey,CID,UAID,Time,Latitude,Longitude,Altitude,PointSource,RecTypeCat,Significance,GroundSpeed,FlightCourse
0,45391384,178224,421,N90K,1715896498,41.44937,-90.50767,7.0,AIG200,1,1,134,91
1,45391385,178224,421,N90K,1715896503,41.44933,-90.50351,7.0,AIG200,1,10,134,91
2,45391386,178224,421,N90K,1715896508,41.44941,-90.49936,8.0,AIG200,1,10,133,89
3,45391387,178224,421,N90K,1715896513,41.44937,-90.49503,10.0,AIG200,1,1,139,91
4,45391388,178224,421,N90K,1715896518,41.44912,-90.49070,11.0,AIG200,1,10,141,94
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1173,45392557,178224,421,N90K,1715903092,37.64125,-97.42731,14.0,AIG200,1,2,88,20
1174,45392558,178224,421,N90K,1715903092,37.64170,-97.42706,14.0,AIG200,1,1,274,24
1175,45392559,178224,421,N90K,1715903097,37.64361,-97.42583,14.0,TH_FIXM,1,9,101,27
1176,45392560,178224,421,N90K,1715903108,37.64806,-97.42389,14.0,TH_FIXM,1,10,92,19


In [79]:
dict_overall_eval

{('178224', 'N90K'): {'prediction_model_0': defaultdict(float, {})}}

In [80]:
dict_overall_eval_arrays

{('178224', 'N90K'): {'prediction_model_0': defaultdict(list, {})}}

In [22]:
from omegaconf import OmegaConf
experiment_dir = "/Users/aleksandranikevich/Desktop/AircraftTrajectory/REPO/flight_pattern_of_life/models/TestHydra/Experiment_1/"
config_path = experiment_dir + "config/config.yaml"
last_checkpoint = experiment_dir + "model_checkpoints/last.ckpt"

cfg = OmegaConf.load(config_path)
model_cfg = cfg.model


In [38]:

# Create Dummy input tensor for any Lazy instantiated components 
num_input_rows = cfg["datamodule"]["datamodule"]["num_input_rows_total"]
num_input_coordinates = len(helper_get_coordinate_system_based_on_enum(CoordinateEnum[cfg["datamodule"]["datamodule"]['coordinate_system_enum']]))
num_auxilary_inputs = len(cfg["datamodule"]["datamodule"]['auxiliary_input_channels'])
num_input_channels = num_input_coordinates + num_auxilary_inputs

print("num_input_channels, num_input_rows: ", num_input_channels, num_input_rows)
dummy_input_tensor = torch.rand([1, num_input_channels, num_input_rows])
print(dummy_input_tensor.shape)

num_input_channels, num_input_rows:  2 500
torch.Size([1, 2, 500])


In [44]:

# Ensure 'model_cfg' is correctly accessed
model_cfg = cfg.model.model

# Instantiate the model using Hydra's instantiate method
model = instantiate(model_cfg)

# Pass a dummy tensor to instantiate any lazy modules
dummy_input_tensor = torch.randn(32, num_input_channels, num_input_rows)  # Adjust dimensions as needed
_ = model(dummy_input_tensor)

# Load the checkpoint
checkpoint = torch.load(last_checkpoint, map_location='cpu')

# Manually load the state_dict
model.load_state_dict(checkpoint['state_dict'], strict=False)

print(model)


Using Default model weights initialization
FlightModel(
  (model): SimpleNet(
    (first_conv_block): BasicBlock(
      (conv): Conv1d(2, 2, kernel_size=(9,), stride=(1,), padding=same, dilation=(3,), groups=2, padding_mode=replicate)
      (pw_conv): Conv1d(2, 64, kernel_size=(1,), stride=(1,), padding=same, dilation=(3,), padding_mode=replicate)
      (act): PReLU(num_parameters=1)
      (norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (residual_blocks): Sequential(
      (0): ResidualBlock(
        (block1): BasicBlock(
          (conv): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=same, dilation=(3,), groups=64, padding_mode=replicate)
          (pw_conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), padding=same, dilation=(3,), padding_mode=replicate)
          (act): PReLU(num_parameters=1)
          (norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (block2

  checkpoint = torch.load(last_checkpoint, map_location='cpu')


In [39]:
net = instantiate(model_cfg["model"])
_ = net(dummy_input_tensor)                     # To instantiate any lazy modules
model = FlightModel.load_from_checkpoint(
    checkpoint_path=last_checkpoint,
    map_location='cpu',
    model=net, 
    coordinate_system_enum=model_cfg['model']['coordinate_system_enum'], 
    loss_fn=model_cfg['model']['loss_fn'], 
    optimizer=model_cfg['model']['optimizer']
)

print(model)


Using Default model weights initialization
Initializing model with weights distribution 0.0 and standard deviation 0.0001


RuntimeError: Error(s) in loading state_dict for FlightModel:
	Missing key(s) in state_dict: "model.model.first_conv_block.conv.weight", "model.model.first_conv_block.conv.bias", "model.model.first_conv_block.pw_conv.weight", "model.model.first_conv_block.pw_conv.bias", "model.model.first_conv_block.act.weight", "model.model.residual_blocks.0.block1.conv.weight", "model.model.residual_blocks.0.block1.conv.bias", "model.model.residual_blocks.0.block1.pw_conv.weight", "model.model.residual_blocks.0.block1.pw_conv.bias", "model.model.residual_blocks.0.block1.act.weight", "model.model.residual_blocks.0.block2.conv.weight", "model.model.residual_blocks.0.block2.conv.bias", "model.model.residual_blocks.0.block2.pw_conv.weight", "model.model.residual_blocks.0.block2.pw_conv.bias", "model.model.residual_blocks.0.block2.act.weight", "model.model.residual_blocks.1.block1.conv.weight", "model.model.residual_blocks.1.block1.conv.bias", "model.model.residual_blocks.1.block1.pw_conv.weight", "model.model.residual_blocks.1.block1.pw_conv.bias", "model.model.residual_blocks.1.block1.act.weight", "model.model.residual_blocks.1.block2.conv.weight", "model.model.residual_blocks.1.block2.conv.bias", "model.model.residual_blocks.1.block2.pw_conv.weight", "model.model.residual_blocks.1.block2.pw_conv.bias", "model.model.residual_blocks.1.block2.act.weight", "model.model.residual_blocks.2.block1.conv.weight", "model.model.residual_blocks.2.block1.conv.bias", "model.model.residual_blocks.2.block1.pw_conv.weight", "model.model.residual_blocks.2.block1.pw_conv.bias", "model.model.residual_blocks.2.block1.act.weight", "model.model.residual_blocks.2.block2.conv.weight", "model.model.residual_blocks.2.block2.conv.bias", "model.model.residual_blocks.2.block2.pw_conv.weight", "model.model.residual_blocks.2.block2.pw_conv.bias", "model.model.residual_blocks.2.block2.act.weight", "model.model.residual_blocks.3.block1.conv.weight", "model.model.residual_blocks.3.block1.conv.bias", "model.model.residual_blocks.3.block1.pw_conv.weight", "model.model.residual_blocks.3.block1.pw_conv.bias", "model.model.residual_blocks.3.block1.act.weight", "model.model.residual_blocks.3.block2.conv.weight", "model.model.residual_blocks.3.block2.conv.bias", "model.model.residual_blocks.3.block2.pw_conv.weight", "model.model.residual_blocks.3.block2.pw_conv.bias", "model.model.residual_blocks.3.block2.act.weight", "model.model.residual_blocks.4.block1.conv.weight", "model.model.residual_blocks.4.block1.conv.bias", "model.model.residual_blocks.4.block1.pw_conv.weight", "model.model.residual_blocks.4.block1.pw_conv.bias", "model.model.residual_blocks.4.block1.act.weight", "model.model.residual_blocks.4.block2.conv.weight", "model.model.residual_blocks.4.block2.conv.bias", "model.model.residual_blocks.4.block2.pw_conv.weight", "model.model.residual_blocks.4.block2.pw_conv.bias", "model.model.residual_blocks.4.block2.act.weight", "model.model.residual_blocks.5.block1.conv.weight", "model.model.residual_blocks.5.block1.conv.bias", "model.model.residual_blocks.5.block1.pw_conv.weight", "model.model.residual_blocks.5.block1.pw_conv.bias", "model.model.residual_blocks.5.block1.act.weight", "model.model.residual_blocks.5.block2.conv.weight", "model.model.residual_blocks.5.block2.conv.bias", "model.model.residual_blocks.5.block2.pw_conv.weight", "model.model.residual_blocks.5.block2.pw_conv.bias", "model.model.residual_blocks.5.block2.act.weight", "model.model.residual_blocks.6.block1.conv.weight", "model.model.residual_blocks.6.block1.conv.bias", "model.model.residual_blocks.6.block1.pw_conv.weight", "model.model.residual_blocks.6.block1.pw_conv.bias", "model.model.residual_blocks.6.block1.act.weight", "model.model.residual_blocks.6.block2.conv.weight", "model.model.residual_blocks.6.block2.conv.bias", "model.model.residual_blocks.6.block2.pw_conv.weight", "model.model.residual_blocks.6.block2.pw_conv.bias", "model.model.residual_blocks.6.block2.act.weight", "model.model.residual_blocks.7.block1.conv.weight", "model.model.residual_blocks.7.block1.conv.bias", "model.model.residual_blocks.7.block1.pw_conv.weight", "model.model.residual_blocks.7.block1.pw_conv.bias", "model.model.residual_blocks.7.block1.act.weight", "model.model.residual_blocks.7.block2.conv.weight", "model.model.residual_blocks.7.block2.conv.bias", "model.model.residual_blocks.7.block2.pw_conv.weight", "model.model.residual_blocks.7.block2.pw_conv.bias", "model.model.residual_blocks.7.block2.act.weight", "model.model.last_conv_block.conv.weight", "model.model.last_conv_block.conv.bias", "model.model.last_conv_block.pw_conv.weight", "model.model.last_conv_block.pw_conv.bias", "model.model.last_conv_block.act.weight", "model.model.one_more_linear.weight", "model.model.one_more_linear.bias", "model.model.act_one_more_linear.weight", "model.model.last_linear.weight", "model.model.last_linear.bias", "model.model.act_final.weight". 
	Unexpected key(s) in state_dict: "model.first_conv_block.conv.weight", "model.first_conv_block.conv.bias", "model.first_conv_block.pw_conv.weight", "model.first_conv_block.pw_conv.bias", "model.first_conv_block.act.weight", "model.residual_blocks.0.block1.conv.weight", "model.residual_blocks.0.block1.conv.bias", "model.residual_blocks.0.block1.pw_conv.weight", "model.residual_blocks.0.block1.pw_conv.bias", "model.residual_blocks.0.block1.act.weight", "model.residual_blocks.0.block2.conv.weight", "model.residual_blocks.0.block2.conv.bias", "model.residual_blocks.0.block2.pw_conv.weight", "model.residual_blocks.0.block2.pw_conv.bias", "model.residual_blocks.0.block2.act.weight", "model.residual_blocks.1.block1.conv.weight", "model.residual_blocks.1.block1.conv.bias", "model.residual_blocks.1.block1.pw_conv.weight", "model.residual_blocks.1.block1.pw_conv.bias", "model.residual_blocks.1.block1.act.weight", "model.residual_blocks.1.block2.conv.weight", "model.residual_blocks.1.block2.conv.bias", "model.residual_blocks.1.block2.pw_conv.weight", "model.residual_blocks.1.block2.pw_conv.bias", "model.residual_blocks.1.block2.act.weight", "model.residual_blocks.2.block1.conv.weight", "model.residual_blocks.2.block1.conv.bias", "model.residual_blocks.2.block1.pw_conv.weight", "model.residual_blocks.2.block1.pw_conv.bias", "model.residual_blocks.2.block1.act.weight", "model.residual_blocks.2.block2.conv.weight", "model.residual_blocks.2.block2.conv.bias", "model.residual_blocks.2.block2.pw_conv.weight", "model.residual_blocks.2.block2.pw_conv.bias", "model.residual_blocks.2.block2.act.weight", "model.residual_blocks.3.block1.conv.weight", "model.residual_blocks.3.block1.conv.bias", "model.residual_blocks.3.block1.pw_conv.weight", "model.residual_blocks.3.block1.pw_conv.bias", "model.residual_blocks.3.block1.act.weight", "model.residual_blocks.3.block2.conv.weight", "model.residual_blocks.3.block2.conv.bias", "model.residual_blocks.3.block2.pw_conv.weight", "model.residual_blocks.3.block2.pw_conv.bias", "model.residual_blocks.3.block2.act.weight", "model.residual_blocks.4.block1.conv.weight", "model.residual_blocks.4.block1.conv.bias", "model.residual_blocks.4.block1.pw_conv.weight", "model.residual_blocks.4.block1.pw_conv.bias", "model.residual_blocks.4.block1.act.weight", "model.residual_blocks.4.block2.conv.weight", "model.residual_blocks.4.block2.conv.bias", "model.residual_blocks.4.block2.pw_conv.weight", "model.residual_blocks.4.block2.pw_conv.bias", "model.residual_blocks.4.block2.act.weight", "model.residual_blocks.5.block1.conv.weight", "model.residual_blocks.5.block1.conv.bias", "model.residual_blocks.5.block1.pw_conv.weight", "model.residual_blocks.5.block1.pw_conv.bias", "model.residual_blocks.5.block1.act.weight", "model.residual_blocks.5.block2.conv.weight", "model.residual_blocks.5.block2.conv.bias", "model.residual_blocks.5.block2.pw_conv.weight", "model.residual_blocks.5.block2.pw_conv.bias", "model.residual_blocks.5.block2.act.weight", "model.residual_blocks.6.block1.conv.weight", "model.residual_blocks.6.block1.conv.bias", "model.residual_blocks.6.block1.pw_conv.weight", "model.residual_blocks.6.block1.pw_conv.bias", "model.residual_blocks.6.block1.act.weight", "model.residual_blocks.6.block2.conv.weight", "model.residual_blocks.6.block2.conv.bias", "model.residual_blocks.6.block2.pw_conv.weight", "model.residual_blocks.6.block2.pw_conv.bias", "model.residual_blocks.6.block2.act.weight", "model.residual_blocks.7.block1.conv.weight", "model.residual_blocks.7.block1.conv.bias", "model.residual_blocks.7.block1.pw_conv.weight", "model.residual_blocks.7.block1.pw_conv.bias", "model.residual_blocks.7.block1.act.weight", "model.residual_blocks.7.block2.conv.weight", "model.residual_blocks.7.block2.conv.bias", "model.residual_blocks.7.block2.pw_conv.weight", "model.residual_blocks.7.block2.pw_conv.bias", "model.residual_blocks.7.block2.act.weight", "model.last_conv_block.conv.weight", "model.last_conv_block.conv.bias", "model.last_conv_block.pw_conv.weight", "model.last_conv_block.pw_conv.bias", "model.last_conv_block.act.weight", "model.one_more_linear.weight", "model.one_more_linear.bias", "model.act_one_more_linear.weight", "model.last_linear.weight", "model.last_linear.bias", "model.act_final.weight". 

In [43]:
from hydra.utils import instantiate
from omegaconf import OmegaConf


# Ensure 'model_cfg' is correctly accessed
model_cfg = cfg.model.model

# Instantiate the model using Hydra's instantiate method
model = instantiate(model_cfg)

# Pass a dummy tensor to instantiate any lazy modules
dummy_input_tensor = torch.randn(32, 2, 100)  # Adjust dimensions as needed
_ = model(dummy_input_tensor)

# Load the checkpoint
checkpoint = torch.load(last_checkpoint, map_location='cpu')

# Manually load the state_dict
model.load_state_dict(checkpoint['state_dict'], strict=False)

print(model)


Using Default model weights initialization


  checkpoint = torch.load(last_checkpoint, map_location='cpu')


RuntimeError: Error(s) in loading state_dict for FlightModel:
	size mismatch for model.one_more_linear.weight: copying a param with shape torch.Size([100, 1000]) from checkpoint, the shape in current model is torch.Size([100, 200]).

In [42]:
model_cfg

{'model': {'_target_': 'model.FlightModel', 'model': {'_target_': 'net.SimpleNet', 'in_channels': 2, 'out_channels': 2, 'intermediate_channels': 64, 'num_res_blocks': 8, 'num_output_rows': 10, 'dilation': 3, 'kernel_size': 9, 'norm_type': 'instance', 'stride': 1, 'bias': True, 'normalize_inside_of_network': False}, 'coordinate_system_enum': 'LatLongCoordinates', 'loss_fn': 'mse', 'optimizer': None, 'max_num_val_maps': 10, 'n_future_timesteps': 10, 'mean': None, 'std': None, 'learning_rate': 0.0001}}

In [40]:
model = instantiate(model_cfg)
checkpoint = torch.load(last_checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)

model

Using Default model weights initialization


  checkpoint = torch.load(last_checkpoint, map_location='cpu')


ConfigAttributeError: Missing key load_state_dict
    full_key: model.load_state_dict
    object_type=dict

In [None]:
import torch

# Instantiate the model components from the config
net = instantiate(model_cfg["model"])
_ = net(dummy_input_tensor)  # To instantiate any lazy modules

# Manually load the checkpoint
checkpoint = torch.load(last_checkpoint, map_location='cpu')

# Instantiate the FlightModel
model = FlightModel(
    model=net, 
    coordinate_system_enum=model_cfg['model']['coordinate_system_enum'], 
    loss_fn=model_cfg['model']['loss_fn'], 
    optimizer=model_cfg['model']['optimizer']
)

# Load the state_dict
model.load_state_dict(checkpoint['state_dict'], strict=False)

print(model)


In [14]:

net = instantiate(cfg.model)


Using Default model weights initialization


In [15]:
net

{'model': FlightModel(
  (model): SimpleNet(
    (first_conv_block): BasicBlock(
      (conv): Conv1d(2, 2, kernel_size=(9,), stride=(1,), padding=same, dilation=(3,), groups=2, padding_mode=replicate)
      (pw_conv): Conv1d(2, 64, kernel_size=(1,), stride=(1,), padding=same, dilation=(3,), padding_mode=replicate)
      (act): PReLU(num_parameters=1)
      (norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (residual_blocks): Sequential(
      (0): ResidualBlock(
        (block1): BasicBlock(
          (conv): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=same, dilation=(3,), groups=64, padding_mode=replicate)
          (pw_conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), padding=same, dilation=(3,), padding_mode=replicate)
          (act): PReLU(num_parameters=1)
          (norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        )
        (block2): BasicBlock(
          (conv): 

In [None]:
# Extract model configuration and remove _target_
model_cfg = cfg['model']['model'] 

flight_model_instantiated = omegaconf.instantiate(model_cfg, _recursive_=True)
_ = flight_model_instantiated(example_input)