In [None]:
import pandas as pd
import torch
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np
import karman
from karman.nn import *
import os
from torch.utils.data import Subset
from torch import nn
import argparse
from pyfiglet import Figlet
from termcolor import colored
from dataclasses import dataclass
from matplotlib import pyplot as plt
import imageio
from PIL import Image
import io
from IPython.display import Image
from astropy.time import Time
import astropy

In [None]:
# Time- and position-dependent features that are recomputed for an arbitrary query
# point (see create_datapoint). Every entry needs its own trailing comma: adjacent
# string literals are silently concatenated by Python.
time_and_position_based_features_to_calculate = [
    'all__day_of_year__[d]',
    'all__seconds_in_day__[s]',
    'all__sun_right_ascension__[rad]',
    'all__sun_declination__[rad]',
    'all__sidereal_time__[rad]',
    'tudelft_thermo__longitude__[deg]',
    'tudelft_thermo__latitude__[deg]',
    'tudelft_thermo__local_solar_time__[h]',
]

def create_datapoint(date, longitude, latitude, altitude):
    """Compute the time/position feature values for a single query point.

    Parameters
    ----------
    date : str | datetime-like
        Epoch of the query point; anything accepted by ``pd.to_datetime``.
        It is passed to astropy ``Time`` as a string, so it is interpreted in
        astropy's default (UTC) scale.
    longitude, latitude : float
        Coordinates in degrees.
    altitude : float
        Altitude, stored under the ``tudelft_thermo__altitude__[m]`` key.

    Returns
    -------
    dict
        Feature name -> value. Periodic quantities are stored as ``_sin``/``_cos``
        pairs. The key names are intended to match the thermosphere dataset's
        columns (TODO confirm against the dataset's cyclical encoding).
    """
    date = pd.to_datetime(date)
    seconds_in_day = 3600 * date.hour + 60 * date.minute + date.second
    day_of_year = date.day_of_year

    t = Time(str(date), location=(f'{longitude}d', f'{latitude}d'))
    sun_position = astropy.coordinates.get_sun(t)
    sun_ra = sun_position.ra.rad
    sun_dec = sun_position.dec.rad
    sidereal = t.sidereal_time('mean').rad
    # Local solar time in hours, approximated as time of day + 1 h per 15 deg of
    # longitude, wrapped into [0, 24). Without the wrap, e.g. longitude=-90 at
    # 00:00 gives -6 instead of 18, and the 48 h-period encoding below maps those
    # two equivalent times to different feature values.
    local_solar_time = (seconds_in_day / 3600.0 + longitude / 15.0) % 24.0

    sample = {}
    seconds_angle = 2 * np.pi * seconds_in_day / (24.0 * 60 * 60)
    sample['all__seconds_in_day__[s]_sin'] = np.sin(seconds_angle)
    sample['all__seconds_in_day__[s]_cos'] = np.cos(seconds_angle)

    # NOTE(review): 366-day period presumably mirrors the training-time encoding; confirm.
    day_angle = 2 * np.pi * day_of_year / 366.0
    sample['all__day_of_year__[d]_sin'] = np.sin(day_angle)
    sample['all__day_of_year__[d]_cos'] = np.cos(day_angle)

    sample['all__sun_right_ascension__[rad]_sin'] = np.sin(sun_ra)
    sample['all__sun_right_ascension__[rad]_cos'] = np.cos(sun_ra)

    sample['all__sun_declination__[rad]_sin'] = np.sin(sun_dec)
    sample['all__sun_declination__[rad]_cos'] = np.cos(sun_dec)

    sample['all__sidereal_time__[rad]_sin'] = np.sin(sidereal)
    sample['all__sidereal_time__[rad]_cos'] = np.cos(sidereal)

    # NOTE(review): (lst + 12) / 48 gives a 48 h period, not 24 h. It is kept as-is
    # because it presumably mirrors the dataset's encoding; confirm before changing.
    lst_angle = 2 * np.pi * ((local_solar_time + 12.0) / 48)
    sample['tudelft_thermo__local_solar_time__[h]_sin'] = np.sin(lst_angle)
    sample['tudelft_thermo__local_solar_time__[h]_cos'] = np.cos(lst_angle)

    sample['tudelft_thermo__latitude__[deg]'] = float(latitude)

    longitude_angle = 2 * np.pi * ((longitude + 180) / 360.0)
    sample['tudelft_thermo__longitude__[deg]_sin'] = np.sin(longitude_angle)
    sample['tudelft_thermo__longitude__[deg]_cos'] = np.cos(longitude_angle)

    sample['all__year__[y]'] = float(date.year)
    # Cast like latitude/year so every scalar feature has a consistent float type.
    sample['tudelft_thermo__altitude__[m]'] = float(altitude)
    return sample

In [None]:
model_path = '/home/jupyter/karman-project/output_directory/best_model_NoFism2DailyFeedForward_2022-09-05 15:10:48.827248_fold_1'
# Convention: checkpoint files whose name starts with 'best_model' are the models of interest.

# torch.load unpickles arbitrary objects (the training options live under 'opt'),
# so only load checkpoints from a trusted source.
# map_location='cpu' lets the checkpoint load on a CPU-only machine in case it
# contains tensors saved on a GPU.
# NOTE(review): newer torch versions default to weights_only=True, which may refuse
# to unpickle the stored options object; pass weights_only=False there if needed.
model_opt = torch.load(model_path, map_location='cpu')['opt']

In [None]:
DATA_DIRECTORY = '/home/jupyter/karman-project/data_directory'

# The checkpoint stores each feature-exclusion list as one comma-separated string;
# split them all up front, keyed by data source.
excluded_features = {
    source: getattr(model_opt, f'features_to_exclude_{source}').split(',')
    for source in ('thermo', 'omni', 'fism2_flare_stan_bands', 'fism2_daily_stan_bands')
}

# Rebuild the dataset with the same lags/resolutions the model was trained with.
dataset = karman.ThermosphericDensityDataset(
    directory=DATA_DIRECTORY,
    lag_minutes_omni=model_opt.lag_minutes_omni,
    lag_minutes_fism2_flare_stan_bands=model_opt.lag_fism2_minutes_flare_stan_bands,
    omni_resolution=model_opt.omni_resolution,
    fism2_flare_stan_bands_resolution=model_opt.fism2_flare_stan_bands_resolution,
    fism2_daily_stan_bands_resolution=model_opt.fism2_daily_stan_bands_resolution,
    features_to_exclude_thermo=excluded_features['thermo'],
    features_to_exclude_omni=excluded_features['omni'],
    features_to_exclude_fism2_flare_stan_bands=excluded_features['fism2_flare_stan_bands'],
    features_to_exclude_fism2_daily_stan_bands=excluded_features['fism2_daily_stan_bands'],
    create_cyclical_features=model_opt.cyclical_features,
    max_date='2022-06-01',
)

In [None]:
# Build one synthetic model input: take the real thermosphere row closest in time to
# `date`, then overwrite its time/position features with those of the query point.
latitude = 0
longitude = 0
altitude = 200_000
date = pd.to_datetime('2022-02-01')
sample_data = create_datapoint(date, longitude, latitude, altitude)

thermo_data = dataset.data_thermo['data']
# Integer position of the row whose timestamp is nearest to `date`.
# Series.argmin skips missing timestamps and avoids sorting the whole column.
location_nearest_date = int((thermo_data['all__dates_datetime__'] - date).abs().argmin())
# Double brackets keep a one-row DataFrame (not a Series); copy so edits stay local.
data = thermo_data.iloc[[location_nearest_date]].copy()
# `.at` needs a scalar *label*, so assign whole columns instead: on a one-row frame
# this sets that single cell, or adds the column if it does not exist yet.
for key, value in sample_data.items():
    data[key] = value

model_ready_input = dataset.data_thermo['scaler'].transform(
    data.drop(columns=dataset.features_to_exclude_thermo + dataset.cyclical_features)
)

In [None]:
# Dataset columns that the synthetic sample does not provide (iterating a dict yields its keys).
set(dataset.data_thermo['data'].columns).difference(sample_data)

In [None]:
# Inspect the thermosphere feature names the dataset excludes; presumably the list
# built from the checkpoint options when the dataset was constructed (TODO confirm).
# The sample-construction cell drops these columns before scaling.
dataset.features_to_exclude_thermo

In [None]:
# All column names of the thermosphere frame (DataFrame.keys() is an alias for .columns).
dataset.data_thermo['data'].keys()

In [None]:
# Columns of the thermosphere row(s) at the timestamp nearest to `date`.
# `nearest_date` is computed here so the cell does not depend on an undefined name.
thermo_dates = dataset.data_thermo['data']['all__dates_datetime__']
nearest_date = thermo_dates.iloc[(thermo_dates - date).abs().argmin()]
(dataset.data_thermo['data'][thermo_dates == nearest_date]).columns

In [None]:
# Ask the dataset for its entry at 2022-02-01 (pd.Timestamp parses the same as pd.to_datetime here).
dataset.__getdate__(pd.Timestamp('2022-02-01'))

In [None]:
date_start = '2022-01-01'
date_end = '2022-02-28'


def _plot_each_column(frame, start, end):
    """Draw one figure per column of `frame`, restricted to index range [start, end]."""
    window = frame.loc[start:end]
    for column in frame.columns:
        ax = window.plot(y=[column])
        ax.set_xlabel('')
        ax.set_title(column)
        plt.show()


omni_data = dataset.time_series_data['omni']['data']
_plot_each_column(omni_data, date_start, date_end)

# Bow-shock-nose distance sqrt(x^2 + y^2 + z^2) in Earth radii [Re]. It is built as a
# standalone Series so the dataset's OMNI frame is not mutated (an extra column
# would leak into later dataset use, and re-running the cell would plot it twice).
bsn_distance = np.sqrt(
    omni_data['omniweb__bsnx_gse__[Re]']**2
    + omni_data['omniweb__bsny_gse__[Re]']**2
    + omni_data['omniweb__bsnz_gse__[Re]']**2
).rename('mag_bsn_pos')
ax = bsn_distance.loc[date_start:date_end].plot(legend=True)
ax.set_xlabel('')
ax.set_title('mag_bsn_pos')
plt.show()

_plot_each_column(dataset.time_series_data['fism2_flare_stan_bands']['data'], date_start, date_end)