In [None]:
import pandas as pd
import torch
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np
import karman
from karman.nn import *
import os
from torch.utils.data import Subset
from torch import nn
import argparse
from pyfiglet import Figlet
from termcolor import colored
from dataclasses import dataclass
from matplotlib import pyplot as plt
import imageio
from PIL import Image
import io
# NOTE: this `Image` shadows PIL's `Image` imported above; the notebook uses it to display the GIF.
from IPython.display import Image
# Imported explicitly: `import astropy` alone does not load the coordinates submodule (needed for get_sun).
import astropy.coordinates
from astropy.time import Time


In [None]:
@dataclass
class Opt:
    """Minimal options container for this notebook.

    Only ``x`` is a declared dataclass field. The evaluation settings
    (``model_name``, ``model_folder``, ``batch_size``, ...) are attached as
    plain attributes after construction.
    """
    x: int = 1  # placeholder field; annotated with a real type instead of a literal value


In [None]:
print('Karman Model Evaluation')

f = Figlet(font='5lineoblique')
print(colored(f.renderText('KARMAN'), 'red'))

f = Figlet(font='digital')
print(colored(f.renderText("Thermospheric  density  calculations"), 'blue'))
print(colored(f'Version {karman.__version__}\n','blue'))

# Evaluation settings. `Opt` is used as an attribute bag; only `x` is a declared field.
opt = Opt(x=1)
opt.model_name = 'no_lag'
opt.model_folder = '/home/jupyter/karman-project/experiment_no_lag'
opt.batch_size = 512
opt.num_workers = 30
opt.data_directory = '/home/jupyter/karman-project/data_directory'

# Stick to this convention: a checkpoint whose file name contains 'best_model' is the model of interest.
# Sorting makes the choice deterministic when several files match (os.listdir order is arbitrary).
best_model_files = sorted(name for name in os.listdir(opt.model_folder) if 'best_model' in name)
if not best_model_files:
    raise FileNotFoundError(f"No file containing 'best_model' found in {opt.model_folder}")
if len(best_model_files) > 1:
    print(f'Several best_model checkpoints found, using {best_model_files[0]}: {best_model_files}')
model_path = os.path.join(opt.model_folder, best_model_files[0])

print('Loading Data')
# Load the checkpoint once: it carries both the training options ('opt') and the weights ('state_dict').
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine; the model is moved
# to `device` at the end of this cell.
# NOTE(review): torch.load unpickles arbitrary objects ('opt' is a pickled object) - only load trusted
# checkpoints. On recent torch versions that default to weights_only=True, `weights_only=False` may be
# needed to read 'opt'.
checkpoint = torch.load(model_path, map_location='cpu')
model_opt = checkpoint['opt']

dataset = karman.ThermosphericDensityDataset(
    directory=opt.data_directory,
    exclude_omni=model_opt.exclude_omni,
    exclude_fism2_daily=model_opt.exclude_fism2_daily,
    exclude_fism2_flare=model_opt.exclude_fism2_flare,
    lag_minutes_omni=model_opt.lag_minutes_omni,
    lag_days_fism2_daily=model_opt.lag_days_fism2_daily,
    lag_minutes_fism2_flare=model_opt.lag_minutes_fism2_flare,
    wavelength_bands_to_skip=model_opt.wavelength_bands_to_skip,
    omniweb_downsampling_ratio=model_opt.omniweb_downsampling_ratio,
    features_to_exclude_omni=model_opt.features_to_exclude_omni,
    features_to_exclude_thermo=model_opt.features_to_exclude_thermo,
    features_to_exclude_fism2_flare=model_opt.features_to_exclude_fism2_flare,
    features_to_exclude_fism2_daily=model_opt.features_to_exclude_fism2_daily
)

print('Loading Model')

# Rebuild the architecture recorded in the checkpoint; each branch checks that the dataset's
# exclude_* flags match the inputs that architecture expects.
if model_opt.model == 'FeedForwardDensityPredictor':
    # Will only use an FFNN with just the thermo static features data
    model = FeedForwardDensityPredictor(
                        num_features=dataset.data_thermo_matrix.shape[1]
                        )
elif model_opt.model == 'Fism2FlareDensityPredictor':
    if model_opt.exclude_fism2_daily and model_opt.exclude_omni:
        model = Fism2FlareDensityPredictor(
                        input_size_thermo=dataset.data_thermo_matrix.shape[1],
                        input_size_fism2_flare=dataset.fism2_flare_irradiance_matrix.shape[1],
                        output_size_fism2_flare=20
                        )
    else:
        raise RuntimeError(f"exclude_fism2_daily and exclude_omni are not set to True; while model chosen is {model_opt.model}")
elif model_opt.model == 'Fism2DailyDensityPredictor':
    if model_opt.exclude_fism2_flare and model_opt.exclude_omni:
        model = Fism2DailyDensityPredictor(
                        input_size_thermo=dataset.data_thermo_matrix.shape[1],
                        input_size_fism2_daily=dataset.fism2_daily_irradiance_matrix.shape[1],
                        output_size_fism2_daily=20
                        )
    else:
        raise RuntimeError(f"exclude_fism2_flare and exclude_omni are not set to True; while model chosen is {model_opt.model}")
elif model_opt.model == 'OmniDensityPredictor':
    if model_opt.exclude_fism2_daily and model_opt.exclude_fism2_flare:
        model = OmniDensityPredictor(
                        input_size_thermo=dataset.data_thermo_matrix.shape[1],
                        input_size_omni=dataset.data_omni_matrix.shape[1],
                        output_size_omni=20
                        )
    else:
        raise RuntimeError(f"exclude_fism2_daily and exclude_fism2_flare are not set to True; while model chosen is {model_opt.model}")
elif model_opt.model == 'FullFeatureDensityPredictor':
    if not model_opt.exclude_omni and not model_opt.exclude_fism2_flare and not model_opt.exclude_fism2_daily:
        model = FullFeatureDensityPredictor(
                        input_size_thermo=dataset.data_thermo_matrix.shape[1],
                        input_size_fism2_flare=dataset.fism2_flare_irradiance_matrix.shape[1],
                        input_size_fism2_daily=dataset.fism2_daily_irradiance_matrix.shape[1],
                        input_size_omni=dataset.data_omni_matrix.shape[1],
                        output_size_fism2_flare=20,
                        output_size_fism2_daily=20,
                        output_size_omni=20
                        )
    else:
        # Previously fell through silently and failed later with a NameError on `model`.
        raise RuntimeError(f"exclude_omni, exclude_fism2_flare and exclude_fism2_daily must all be False; while model chosen is {model_opt.model}")
else:
    raise RuntimeError(f"Unknown model type {model_opt.model!r} in checkpoint {model_path}")

state_dict = checkpoint['state_dict']
# A model saved as a DataParallel wrapper has every key prefixed with 'module.'; strip exactly that
# prefix (including the dot) so the keys match the bare model. Keys are renamed in place rather than
# rebuilding the mapping, so any attributes torch attached to the loaded state dict are preserved.
prefix = 'module.'
for key in list(state_dict.keys()):
    if key.startswith(prefix):
        state_dict[key[len(prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Inference only: disable training-mode behaviour (e.g. dropout, batch-norm statistic updates) if the
# architecture has any. Returns the model, so it is also the cell's displayed output.
model.eval()


In [None]:
# Read the held-out test-set row indices: one integer per line.
test_indices_path = os.path.join(opt.data_directory, "test_indices.txt")
test_indices = []
with open(test_indices_path, 'r') as index_file:
    for row in index_file:
        test_indices.append(int(row.rstrip()))

In [None]:
# Reference test sample; index the dataset directly rather than calling the dunder `__getitem__`.
batch = dataset[test_indices[0]]

In [None]:
# Display the static input features of the reference test sample (the tensor the grid cell below overwrites per column).
batch['static_features']


In [None]:
# Display the `data_thermo` row for the same test sample.
# NOTE(review): presumably the un-normalised values behind `static_features` above; verify against the dataset code.
dataset.data_thermo.loc[dataset.index_list[test_indices[0]]]

In [None]:
# Largest seconds-in-day value in the data - a quick check of that column's range before building synthetic inputs.
dataset.data_thermo['all__seconds_in_day__[s]'].max()

In [None]:
# Spot check: sun position and mean sidereal time for one hand-picked time/location.
# NOTE(review): the values below look copied from the reference sample displayed above - confirm.
year = 2009
day = 109
second_in_day = 46935
lon = -36.9556
lat = 60
date = pd.to_datetime(f'{year}/{day:03d}', format='%Y/%j') + pd.Timedelta(seconds=second_in_day)
t = Time(str(date), location=(f'{float(lon)}d', f'{float(lat)}d'))
# `astropy.coordinates` must be imported before this cell; a bare `import astropy` does not load it.
sunpos = astropy.coordinates.get_sun(t)
print('sun right ascension [rad]:', sunpos.ra.rad)
print('sun declination [rad]:', sunpos.dec.rad)
print('mean sidereal time [rad]:', t.sidereal_time('mean').rad)


In [None]:
def _column_range(column):
    """Return the (min, max) of one `dataset.data_thermo` column; used below for min-max scaling."""
    values = dataset.data_thermo[column]
    return values.min(), values.max()


# Ranges of the astronomical features, used to scale freshly computed astropy values like the training data.
asc_min, asc_max = _column_range('all__sun_right_ascension__[rad]')
decl_min, decl_max = _column_range('all__sun_declination__[rad]')
sidereal_min, sidereal_max = _column_range('all__sidereal_time__[rad]')

In [None]:
# Add a leading batch dimension to every tensor of the reference sample. Rebuilding from the dataset
# (instead of unsqueezing `batch` in place) keeps the cell idempotent: re-running it no longer stacks
# extra leading dimensions.
batch = {key: value.unsqueeze(0) for key, value in dataset[test_indices[0]].items()}

In [None]:
# Animate the model output over a lat/lon grid for one day. Every feature not overwritten below keeps
# the value of the reference sample in `batch`.
YEAR = 2009
DAYS_OF_YEAR = range(109, 110)
SECONDS_PER_DAY = 24 * 60 * 60
TIME_STEP_S = 1200                                  # one frame every 20 minutes
GRID_STEP_DEG = 3
LATITUDES = list(range(-90, 90, GRID_STEP_DEG))     # map rows, south -> north
LONGITUDES = list(range(-180, 180, GRID_STEP_DEG))  # map columns, west -> east
COLOR_LIMITS = (0.75, 0.81)                         # fixed colour scale so frames are comparable
FRAME_DIR = '/home/jupyter'
GIF_PATH = 'day.gif'

n_lat = len(LATITUDES)
# Latitude is the only feature that varies down a map column, so build that column once.
# float64 so the final cast into `static_features` matches assigning the Python floats one by one.
# NOTE(review): latitude is divided by 360 (range [0, 0.5)) while longitude spans [0, 1); confirm this
# matches the dataset's normalisation.
lat_feature = torch.tensor([(lat + 90) / 360 for lat in LATITUDES], dtype=torch.float64)

images = []
frame_times = []
for day_of_year in DAYS_OF_YEAR:
    for time_of_day in tqdm(range(0, SECONDS_PER_DAY, TIME_STEP_S)):
        thermo_map = np.zeros((n_lat, len(LONGITUDES)))
        date = pd.to_datetime(f'{YEAR}/{day_of_year:03d}', format='%Y/%j') + pd.Timedelta(seconds=time_of_day)
        t_utc = Time(str(date))
        # The geocentric sun position does not depend on the observer, so compute it once per timestamp
        # instead of once per longitude (this call dominated the runtime). Dropping the observer location
        # only shifts the TDB conversion by a sub-microsecond amount.
        sunpos = astropy.coordinates.get_sun(t_utc)
        sun_ra_feature = (sunpos.ra.rad - asc_min) / (asc_max - asc_min)
        sun_dec_feature = (sunpos.dec.rad - decl_min) / (decl_max - decl_min)
        for col, lon in enumerate(LONGITUDES):
            # Local mean sidereal time depends only on longitude (same as setting location=(lon, 0)).
            side_real = t_utc.sidereal_time('mean', longitude=lon).rad

            # One row per latitude, all other inputs copied from the reference sample.
            grid_batch = {'static_features': torch.cat(n_lat * [batch['static_features']])}
            for key, value in batch.items():
                if key != 'static_features':
                    grid_batch[key] = torch.cat(n_lat * [value])

            static = grid_batch['static_features']
            static[:, 0] = day_of_year / 366.
            static[:, 2] = time_of_day / SECONDS_PER_DAY
            static[:, 3] = 0.5  # NOTE(review): fixed value for feature 3 (presumably altitude) - confirm
            static[:, 4] = (lon + 180) / 360
            static[:, 5] = lat_feature
            # NOTE(review): local solar time in hours (UT + lon/15) divided by 36 - confirm this scaling.
            static[:, 6] = (time_of_day / 3600. + lon / 15.) / 36
            static[:, 17] = sun_ra_feature
            static[:, 18] = sun_dec_feature
            static[:, 19] = (side_real - sidereal_min) / (sidereal_max - sidereal_min)

            # Inputs must be on the same device as the model (previously failed whenever CUDA was used).
            grid_batch = {key: value.to(device) for key, value in grid_batch.items()}
            with torch.no_grad():
                output = model(grid_batch)
            thermo_map[:, col] = np.squeeze(output.cpu().numpy())
        images.append(thermo_map)
        frame_times.append(date)

frame_paths = []
for i, (frame, frame_time) in enumerate(zip(images, frame_times)):
    fig, ax = plt.subplots()
    # Row 0 of `frame` is the southernmost latitude, so draw with origin='lower' to keep north up.
    im = ax.imshow(frame, vmin=COLOR_LIMITS[0], vmax=COLOR_LIMITS[1], origin='lower',
                   extent=(LONGITUDES[0], LONGITUDES[-1] + GRID_STEP_DEG,
                           LATITUDES[0], LATITUDES[-1] + GRID_STEP_DEG))
    ax.set(title=str(frame_time), xlabel='Longitude [deg]', ylabel='Latitude [deg]')
    fig.colorbar(im, ax=ax, label='model output')
    frame_path = os.path.join(FRAME_DIR, f'{i}.png')
    fig.savefig(frame_path)
    plt.close(fig)  # avoid accumulating open figures across frames
    frame_paths.append(frame_path)

with imageio.get_writer(GIF_PATH, mode='I') as writer:
    for frame_path in frame_paths:
        writer.append_data(imageio.imread(frame_path))
Image(filename=GIF_PATH)

In [None]:
# Per-frame value range, to check the predictions against the fixed colour limits used for the GIF.
for frame in images:
    lo, hi = frame.min(), frame.max()
    print(lo, hi)