In [1]:
import pandas as pd
import torch
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np
import karman
from karman.nn import *
import os
from torch.utils.data import Subset
from torch import nn
import argparse
from pyfiglet import Figlet
from termcolor import colored
from dataclasses import dataclass
from matplotlib import pyplot as plt
import imageio
from PIL import Image
import io
from IPython.display import Image
from astropy.time import Time
import astropy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
time_and_position_based_features_to_calculate = [
    'all__day_of_year__[d]',
    'all__seconds_in_day__[s]',
    'all__sun_right_ascension__[rad]',
    'all__sun_declination__[rad]',
    'all__sidereal_time__[rad]',
    'tudelft_thermo__longitude__[deg]',
    'tudelft_thermo__latitude__[deg]'
    'tudelft_thermo__local_solar_time__[h]']

def create_datapoint(date, longitude, latitude, altitude):
    sample = {}
    date = pd.to_datetime(date)
    second_in_day = 3600*date.hour + 60*date.minute + date.second
    day_of_year = date.day_of_year
    t=Time(str(date),location=(f'{longitude}d', f'{latitude}d'))
    sunpos=astropy.coordinates.get_sun(t)
    sun_ra = sunpos.ra.rad
    sun_dec = sunpos.dec.rad
    side_real = t.sidereal_time('mean').rad
    lst = (second_in_day/3600.0 + longitude/15.)

    sample['all__seconds_in_day__[s]_sin'] = np.sin(2*np.pi*second_in_day/(24.0*60*60))
    sample['all__seconds_in_day__[s]_cos'] = np.cos(2*np.pi*second_in_day/(24.0*60*60))

    sample['all__day_of_year__[d]_sin'] = np.sin(2*np.pi*day_of_year/366.0)
    sample['all__day_of_year__[d]_cos'] = np.cos(2*np.pi*day_of_year/366.0)

    sample['all__sun_right_ascension__[rad]_sin'] = np.sin(sun_ra)
    sample['all__sun_right_ascension__[rad]_cos'] = np.cos(sun_ra)

    sample['all__sun_declination__[rad]_sin'] = np.sin(sun_dec)
    sample['all__sun_declination__[rad]_cos'] = np.cos(sun_dec)

    sample['all__sidereal_time__[rad]_sin'] = np.sin(side_real)
    sample['all__sidereal_time__[rad]_cos'] = np.cos(side_real)

    sample['tudelft_thermo__local_solar_time__[h]_sin'] = np.sin(2*np.pi*((lst + 12.0)/48))
    sample['tudelft_thermo__local_solar_time__[h]_cos'] = np.cos(2*np.pi*((lst + 12.0)/48))

    sample['tudelft_thermo__latitude__[deg]'] = float(latitude)

    sample['tudelft_thermo__longitude__[deg]_sin'] = np.sin(2*np.pi*((longitude + 180)/360.0))
    sample['tudelft_thermo__longitude__[deg]_cos'] = np.cos(2*np.pi*((longitude + 180)/360.0))

    sample['all__year__[y]'] = float(date.year)
    sample['tudelft_thermo__altitude__[m]'] = altitude
    return sample

In [5]:
model_path = '/home/jupyter/karman-project/data_directory/run_flare_lag_2880/best_model_NoFism2DailyFeedForward_2022-09-08 20:38:18.491403_fold_2_seed_0'
#Stick to this convention 'best_model' is the model of interest.
model_opt=torch.load(model_path)['opt']

In [8]:
model_dataset = karman.ThermosphericDensityDataset(
        directory='/home/jupyter/karman-project/data_directory',
        lag_minutes_omni=model_opt.lag_minutes_omni,
        lag_minutes_fism2_flare_stan_bands=model_opt.lag_fism2_minutes_flare_stan_bands,
        omni_resolution=model_opt.omni_resolution,
        fism2_flare_stan_bands_resolution=model_opt.fism2_flare_stan_bands_resolution,
        fism2_daily_stan_bands_resolution=model_opt.fism2_daily_stan_bands_resolution,
        features_to_exclude_thermo=model_opt.features_to_exclude_thermo.split(','),
        features_to_exclude_omni=model_opt.features_to_exclude_omni.split(','),
        features_to_exclude_fism2_flare_stan_bands=model_opt.features_to_exclude_fism2_flare_stan_bands.split(','),
        features_to_exclude_fism2_daily_stan_bands=model_opt.features_to_exclude_fism2_daily_stan_bands.split(','),
        create_cyclical_features=model_opt.cyclical_features,
    )
model_scaler = model_dataset.data_thermo['scaler']

Loading Omni.
Loading FISM2 Flare Stan bands.
Loading FISM2 Daily Stan bands.
Creating thermospheric density dataset
Creating cyclical features

Finished Creating dataset.


In [9]:
#have to do it like this to keep the correct scaler from the model run.
space_x_dataset = karman.ThermosphericDensityDataset(
        directory='/home/jupyter/karman-project/data_directory',
        lag_minutes_omni=model_opt.lag_minutes_omni,
        lag_minutes_fism2_flare_stan_bands=model_opt.lag_fism2_minutes_flare_stan_bands,
        omni_resolution=model_opt.omni_resolution,
        fism2_flare_stan_bands_resolution=model_opt.fism2_flare_stan_bands_resolution,
        fism2_daily_stan_bands_resolution=model_opt.fism2_daily_stan_bands_resolution,
        features_to_exclude_thermo=model_opt.features_to_exclude_thermo.split(','),
        features_to_exclude_omni=model_opt.features_to_exclude_omni.split(','),
        features_to_exclude_fism2_flare_stan_bands=model_opt.features_to_exclude_fism2_flare_stan_bands.split(','),
        features_to_exclude_fism2_daily_stan_bands=model_opt.features_to_exclude_fism2_daily_stan_bands.split(','),
        create_cyclical_features=model_opt.cyclical_features,
        max_date=pd.to_datetime('2022-06-01')
    )

Loading Omni.
Loading FISM2 Flare Stan bands.
Loading FISM2 Daily Stan bands.
Creating thermospheric density dataset
Creating cyclical features

Finished Creating dataset.


In [12]:
space_x_dataset.time_series_data['fism2_flare_stan_bands']['data'].head()

Unnamed: 0_level_0,fism2_flare_stan_bands__E0_05_0_4__[photons/cm**2/s],fism2_flare_stan_bands__E0_4_0_8__[photons/cm**2/s],fism2_flare_stan_bands__E0_8_1_8__[photons/cm**2/s],fism2_flare_stan_bands__E1_8_3_2__[photons/cm**2/s],fism2_flare_stan_bands__E3_2_7_0__[photons/cm**2/s],fism2_flare_stan_bands__E7_0_15_5__[photons/cm**2/s],fism2_flare_stan_bands__E15_5_22_4__[photons/cm**2/s],fism2_flare_stan_bands__E22_4_29_0__[photons/cm**2/s],fism2_flare_stan_bands__E29_0_32_0__[photons/cm**2/s],fism2_flare_stan_bands__E32_0_54_0__[photons/cm**2/s],...,fism2_flare_stan_bands__E79_8_91_3_low__[photons/cm**2/s],fism2_flare_stan_bands__E79_8_91_3_med__[photons/cm**2/s],fism2_flare_stan_bands__E79_8_91_3_high__[photons/cm**2/s],fism2_flare_stan_bands__E91_3_97_5_low__[photons/cm**2/s],fism2_flare_stan_bands__E91_3_97_5_med__[photons/cm**2/s],fism2_flare_stan_bands__E91_3_97_5_high__[photons/cm**2/s],fism2_flare_stan_bands__E97_5_98_7__[photons/cm**2/s],fism2_flare_stan_bands__E98_7_102_7__[photons/cm**2/s],fism2_flare_stan_bands__E102_7_105_0__[photons/cm**2/s],fism2_flare_stan_bands__E105_0_121_0__[photons/cm**2/s]
all__dates_datetime__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2003-10-13 00:00:00,9537.197224,268857.862111,34153630.0,85796970.0,615234700.0,1564234000.0,7814994000.0,5780004000.0,8797233000.0,9218649000.0,...,3020788000.0,9110898000.0,4389259000.0,910333800.0,2260406000.0,1753422000.0,7223733000.0,6597644000.0,7610754000.0,31641820000.0
2003-10-13 00:30:00,9447.646047,271250.815862,34371180.0,86235900.0,617842600.0,1569459000.0,7839831000.0,5804002000.0,8814387000.0,9246465000.0,...,3026229000.0,9126623000.0,4396311000.0,911503700.0,2263505000.0,1756247000.0,7229472000.0,6607150000.0,7620247000.0,31689430000.0
2003-10-13 01:00:00,9196.903,269397.509711,34333890.0,86169310.0,617540300.0,1569036000.0,7837834000.0,5802141000.0,8813119000.0,9244285000.0,...,3025817000.0,9125481000.0,4395785000.0,911423000.0,2263291000.0,1756057000.0,7229077000.0,6606495000.0,7619585000.0,31685540000.0
2003-10-13 01:30:00,9376.005269,269981.550397,34305540.0,86106880.0,617247400.0,1568620000.0,7835873000.0,5800322000.0,8811895000.0,9242157000.0,...,3025416000.0,9124375000.0,4395276000.0,911345700.0,2263085000.0,1755875000.0,7228700000.0,6605868000.0,7618949000.0,31681820000.0
2003-10-13 02:00:00,10011.8187,273154.274082,34286620.0,86048730.0,616963400.0,1568210000.0,7833942000.0,5798538000.0,8810710000.0,9240074000.0,...,3025024000.0,9123301000.0,4394779000.0,911271200.0,2262887000.0,1755701000.0,7228336000.0,6605264000.0,7618336000.0,31678530000.0


In [72]:
latitude = 0
longitude = 0
altitude = 200_000
sample_data = create_datapoint(pd.to_datetime('2022-02-01'), longitude, latitude, altitude)

date = pd.to_datetime('2022-02-01')
location_nearest_date = (space_x_dataset.data_thermo['data']['all__dates_datetime__']-date).abs().argsort()[:1]
data = space_x_dataset.data_thermo['data'].iloc[location_nearest_date, :].drop(columns=model_opt.features_to_exclude_thermo.split(',')).copy()
for key, value in sample_data.items():
    data.at[location_nearest_date, key] = value

fism2_date_lag = date - pd.Timedelta(minutes=model_opt.lag_fism2_minutes_flare_stan_bands)
fism2_flare_data = space_x_dataset.time_series_data['fism2_flare_stan_bands']['data'].loc[fism2_date_lag:date, :].copy()

omni_date_lag = date - pd.Timedelta(minutes=model_opt.lag_minutes_omni)
omni_data = space_x_dataset.time_series_data['omni']['data'].loc[omni_date_lag:date, :].copy()

print(len(model_dataset.features_to_exclude_thermo))
print(len(model_dataset.cyclical_features))
print(len(data.columns))
thermo_features = model_dataset.data_thermo['scaler'].transform(data)
fism2_flare_features = model_dataset.time_series_data['fism2_flare_stan_bands']['scaler'].transform(fism2_flare_data)
omni_features = model_dataset.time_series_data['omni']['scaler'].transform(omni_data)

batch = {}
batch['omni'] = torch.FloatTensor(omni_features).unsqueeze(0)
batch['fism2_flare_stan_bands'] = torch.FloatTensor(fism2_flare_features).unsqueeze(0)
batch['instantaneous_features'] = torch.FloatTensor(thermo_features.flatten()).unsqueeze(0)

20
7
33


  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"


In [75]:
if model_opt.model == 'NoFism2DailyFeedForward':
    model = NoFism2DailyFeedForward(
        dropout=model_opt.dropout,
        hidden_size=model_opt.hidden_size,
        out_features=model_opt.out_features
        ).to(dtype=torch.float32)

    state_dict = torch.load(os.path.join(model_path))['state_dict']
    #Sanitize state_dict key names
    for key in list(state_dict.keys()):
        if key.startswith('module'):
        # Model was saved as dataparallel model
            # Remove 'module.' from start of key
            state_dict[key[7:]] = state_dict.pop(key)
        else:
            continue
    with torch.no_grad():
        model.forward(batch)
    model.load_state_dict(state_dict)
    print('Loaded state dict')

with torch.no_grad():
    output = model(batch)

model_dataset.unscale_density(output)

Loaded state dict




tensor([[1.6872e-10]])

In [79]:
from tqdm import tqdm
densities = []

for date in tqdm(list(pd.date_range(start='2022-01-30', end='2022-02-05', freq='1440T'))):
    latitude = 0
    longitude = 0
    altitude = 200_000
    sample_data = create_datapoint(date, longitude, latitude, altitude)
    location_nearest_date = (space_x_dataset.data_thermo['data']['all__dates_datetime__']-date).abs().argsort()[:1]
    data = space_x_dataset.data_thermo['data'].iloc[location_nearest_date, :].drop(columns=model_opt.features_to_exclude_thermo.split(',')).copy()
    for key, value in sample_data.items():
        data.at[location_nearest_date, key] = value

    fism2_date_lag = date - pd.Timedelta(minutes=model_opt.lag_fism2_minutes_flare_stan_bands)
    fism2_flare_data = space_x_dataset.time_series_data['fism2_flare_stan_bands']['data'].loc[fism2_date_lag:date, :].copy()

    omni_date_lag = date - pd.Timedelta(minutes=model_opt.lag_minutes_omni)
    omni_data = space_x_dataset.time_series_data['omni']['data'].loc[omni_date_lag:date, :].copy()

    thermo_features = model_dataset.data_thermo['scaler'].transform(data)
    fism2_flare_features = model_dataset.time_series_data['fism2_flare_stan_bands']['scaler'].transform(fism2_flare_data)
    omni_features = model_dataset.time_series_data['omni']['scaler'].transform(omni_data)

    batch = {}
    batch['omni'] = torch.FloatTensor(omni_features).unsqueeze(0)
    batch['fism2_flare_stan_bands'] = torch.FloatTensor(fism2_flare_features).unsqueeze(0)
    batch['instantaneous_features'] = torch.FloatTensor(thermo_features.flatten()).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    densities.append(model_dataset.unscale_density(output))

  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, but {self.__class__.__name__} was fitted without"
  f"X has feature names, 

In [None]:
from matplotlib import pyplot as plt
date_start = '2022-01-01'
date_end = '2022-02-28'

for column in dataset.time_series_data['omni']['data'].columns:
    dataset.time_series_data['omni']['data'].loc[date_start:date_end].plot(y=[column])
    plt.xlabel('')
    plt.show()

dataset.time_series_data['omni']['data']['mag_bsn_pos'] = dataset.time_series_data['omni']['data']['omniweb__bsnz_gse__[Re]']**2 + \
    dataset.time_series_data['omni']['data']['omniweb__bsnx_gse__[Re]']**2 +\
    dataset.time_series_data['omni']['data']['omniweb__bsny_gse__[Re]']**2

dataset.time_series_data['omni']['data'].loc[date_start:date_end].plot(y=['mag_bsn_pos'])
plt.xlabel('')
plt.show()

for column in dataset.time_series_data['fism2_flare_stan_bands']['data'].columns:
    dataset.time_series_data['fism2_flare_stan_bands']['data'].loc[date_start:date_end].plot(y=[column])
    plt.xlabel('')
    plt.show()