# Expanding the Capabilities of SDO: Autocalibration
In this notebook, we demonstrate the SDO Autocalibration model, as described in [Dos Santos et al., 2021](https://ui.adsabs.harvard.edu/abs/2021A%26A...648A..53D/abstract).

Luiz F. G. dos Santos, Souvik Bose, Valentina Salvatelli, Brad Neuberg, Mark C. M. Cheung, Miho Janvier, Meng Jin, Yarin Gal, Paul Boerner, Atılım Güneş Baydin

---

## Introduction

The main dataset used for the project is the SDO ML dataset (see Galvez et al., 2019). For this notebook, we use a version of the dataset that has not been corrected for degradation; it is available from [Zenodo](https://zenodo.org/record/4430801#.X_xiP-lKhmE).

---

## Table of Contents

The notebook is set out as follows:

1. Setting up the notebook
2. Reading and loading the SDO/AIA data
3. Autocalibration Inference <br>
    3a. Multi-channel Model <br>
    3b. Single-channel Model
4. Plotting the degradation curves
5. Downloading & Correcting AIA images
6. Discussion

## 1. Setting up the notebook


In [1]:
import datetime
import math
import warnings

import astropy.time
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sunpy.map
import torch
from astropy.visualization import time_support
from astropy.visualization import ImageNormalize, SqrtStretch, time_support
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from scipy.stats import ks_2samp
from sunpy.net import Fido, attrs

from sdo.datasets.degradation_sdo_dataset import DegradationSDO_Dataset
from sdo.models.autocalibration_models import Autocalibration6, Autocalibration10
from sdo.pytorch_utilities import create_dataloader

# from aiapy.calibrate import correct_degradation
# from aiapy.calibrate.util import get_correction_table

ModuleNotFoundError: No module named 'torch'

In [None]:
data_basedir = "/media/paul/data/autocal_npz/data_small/"
data_inventory = "/media/paul/data/autocal_npz/small_inventory.pkl"
results_path = "/media/paul/data/autocal_npz/results/"
experiment_name = "luiz_exp_36_apodize"

# AIA channels, zero-padded as they appear in model file names and table columns.
channels = ["0094", "0131", "0171", "0193", "0211", "0304", "0335"]
# Every channel comes from AIA: one instrument entry per channel.
instr = ["AIA"] * len(channels)
# Matplotlib mathtext labels such as "$94~\AA$". Raw strings keep the backslash
# explicitly instead of relying on Python's invalid-escape fallback
# ("\A" is a SyntaxWarning since Python 3.12).
channels_names = [rf"${int(channel)}~\AA$" for channel in channels]

## 2. Reading and loading the SDO/AIA data

In [None]:
# Checkpoint of the multi-channel Autocalibration model (experiment luiz_exp_36_apodize).
model_file = (
    "/media/paul/data/autocal_npz/Models/"
    "luiz_exp_36_apodize_model.pth"
)

In [None]:
# Test set for the multi-channel model: every channel in `channels`, read in a
# fixed (unshuffled) order so predictions line up with their dates.
multi_channel_dataset_kwargs = dict(
    data_basedir=data_basedir,
    data_inventory=data_inventory,
    instr=instr,
    channels=channels,
    yr_range=[2010, 2020],
    mnt_step=1,
    day_step=7,
    h_step=1,
    min_step=1,
    resolution=512,
    subsample=2,
    normalization=0,
    scaling=True,
    apodize=True,
    shuffle=False,
    holdout=False,
    test=True,
    test_ratio=1,
)
test_dataset = DegradationSDO_Dataset(**multi_channel_dataset_kwargs)

test_loader = create_dataloader(
    test_dataset,
    batch_size=24,
    num_dataloader_workers=4,
    shuffle=False,
    train=False,
)

In [None]:
# this step could potentially take a while

# The model is built on the CPU and never moved to a GPU in this notebook.
# load_state_dict copies the checkpoint into those CPU parameters, so the
# checkpoint is mapped onto the CPU as well. Forcing "cuda" here only made the
# cell fail on machines without a GPU. The single-channel models below are
# loaded the same way.
model = Autocalibration6(input_shape=[7, 256, 256], output_dim=7)
model.load_state_dict(torch.load(model_file, map_location=torch.device("cpu")))

## 3. Autocalibration Inference
### 3a. Multi-channel Model

In [None]:
# Run the multi-channel model over the whole test set. Per-batch predictions and
# dates are collected in lists and concatenated once at the end; calling
# torch.cat / np.append inside the loop recopies everything seen so far
# (quadratic in the number of batches).
model.eval()  # inference mode: disables dropout / uses running batch-norm stats, if the model has them

multi_channel_outputs = []
multi_channel_dates = []
with torch.no_grad():
    for input_data, dates in test_loader:
        multi_channel_outputs.append(model(input_data))
        multi_channel_dates.append(np.asarray(dates))

if not multi_channel_outputs:
    raise RuntimeError(
        "test_loader yielded no batches; check the dataset configuration"
    )

# One row per test sample, in loader order, with one column per model output.
degradation_multi_channel = torch.cat(multi_channel_outputs, 0).cpu().numpy()
# Dates of those samples, stacked along the first axis like the predictions.
dates_multi_channel_array = np.concatenate(multi_channel_dates, axis=0)

### 3b. Single-channel Model

In [None]:
# Run one single-channel model per channel, each on a test set that contains
# only its own channel. Afterwards degradation_single_channel[c] holds the
# predictions for channels[c] and dates_single_channel_array[c] the matching dates.
degradation_single_channel = []
dates_single_channel_array = []

# TODO: this directory differs from the /media/paul/... paths used elsewhere;
# move it to the configuration cell.
single_channel_model_dir = "/home/paul/Documents/SpaceML/sdoml_data/Models/"
# One checkpoint per channel, e.g. "1000_luiz_exp_33_0094_masked_model.pth".
single_channel_model_files = [
    f"1000_luiz_exp_33_{channel}_masked_model.pth" for channel in channels
]

for instrument, channel, weights_file in zip(
    instr, channels, single_channel_model_files
):
    single_dataset = DegradationSDO_Dataset(
        data_basedir=data_basedir,
        data_inventory=data_inventory,
        instr=[instrument],
        channels=[channel],
        yr_range=[2010, 2020],
        mnt_step=1,
        day_step=7,
        h_step=1,
        min_step=1,
        resolution=512,
        subsample=2,
        normalization=0,
        scaling=True,
        apodize=True,
        shuffle=False,
        holdout=False,
        test_ratio=1,
        test=True,
    )

    single_loader = create_dataloader(
        single_dataset,
        batch_size=128,
        num_dataloader_workers=8,
        shuffle=False,
        train=False,
    )

    # Separate names so the multi-channel `model` / `test_loader` defined above
    # are not silently overwritten by this loop.
    single_model = Autocalibration6(input_shape=[1, 256, 256], output_dim=1)
    single_model.load_state_dict(
        torch.load(
            single_channel_model_dir + weights_file,
            map_location=torch.device("cpu"),
        )
    )
    single_model.eval()  # inference mode (dropout / batch-norm, if present)

    # Collect per-batch results and concatenate once, avoiding quadratic
    # torch.cat / np.append inside the loop.
    channel_outputs = []
    channel_dates = []
    with torch.no_grad():
        for input_data, dates in single_loader:
            channel_outputs.append(single_model(input_data))
            channel_dates.append(np.asarray(dates))

    if not channel_outputs:
        raise RuntimeError(
            f"No test batches for channel {channel}; check the dataset configuration"
        )

    degradation_single_channel.append(torch.cat(channel_outputs, 0).cpu().numpy())
    dates_single_channel_array.append(np.concatenate(channel_dates, axis=0))

#### Converting all times to astropy time.

In [None]:
# Reference dates. Each is given twice so it can be drawn as a vertical line
# (plotted against y = [0, 2]).
eve_date = ["20140526", "20140526"]  # Last date with EVE MEGS-A data.
last_training_date = ["20131231", "20131231"]  # Last date of the training period.
xticks = [
    "20100101",
    "20110101",
    "20120101",
    "20130101",
    "20140101",
    "20150101",
    "20160101",
    "20170101",
    "20180101",
    "20190101",
    "20200101",
]


def _to_astropy_dates(date_strings, fmt):
    """Parse ``date_strings`` with the strptime format ``fmt`` into an astropy ``Time`` array.

    Only the calendar date is kept: any hour/minute parsed by ``fmt`` is dropped,
    so every entry falls at midnight of its day.
    """
    days = [datetime.datetime.strptime(text, fmt).date() for text in date_strings]
    return astropy.time.Time([astropy.time.Time(day.isoformat()) for day in days])


# Columns 0-4 of the multi-channel date array are formatted as year, month,
# day, hour and minute.
dates_multi_channel_str = list(
    map(
        "{:4d}{:02d}{:02d}{:02d}{:02d}".format,
        *dates_multi_channel_array[:, :5].T,
    )
)
dates_multi = _to_astropy_dates(dates_multi_channel_str, "%Y%m%d%H%M")

eve_line = _to_astropy_dates(eve_date, "%Y%m%d")
last_training_line = _to_astropy_dates(last_training_date, "%Y%m%d")
xticks = _to_astropy_dates(xticks, "%Y%m%d")

# Single-channel dates: only year, month and day (columns 0-2) are used.
dates_single = []
for channel_dates in dates_single_channel_array:
    channel_dates_str = list(
        map("{:4d}{:02d}{:02d}".format, *channel_dates[:, :3].T)
    )
    dates_single.append(_to_astropy_dates(channel_dates_str, "%Y%m%d"))

#### Defining helper functions for the moving median and moving standard deviation

In [None]:
def date_median(dates):
    """Return the median of ``dates`` as a nanosecond ``numpy.datetime64``.

    The timestamps are reinterpreted as int64 counts (assumes datetime64[ns]
    input). Their median is then rounded down to a whole nanosecond.
    """
    as_nanoseconds = dates.astype("int64")
    return np.datetime64(math.floor(np.median(as_nanoseconds)), "ns")


def moving_median(data, dates, window):
    """Median of ``data`` and of ``dates`` over consecutive, non-overlapping windows.

    Parameters
    ----------
    data : numpy.ndarray or pandas.Series
        Values to smooth, sliced positionally.
    dates : sliceable container of datetime64 values (e.g. DataFrame or DatetimeIndex)
        Timestamps aligned with ``data``.
    window : int
        Number of consecutive samples per window.

    Returns
    -------
    (numpy.ndarray, pandas.DatetimeIndex)
        The float64 median of each complete window, and the median timestamp of
        that window (via ``date_median``). Trailing samples that do not fill a
        whole window are ignored.
    """
    n_windows = data.shape[0] // window
    medians = []
    median_dates = []

    # Use every complete window. The previous range(n_windows - 1) dropped the
    # last full window, so the medians were one element shorter than the
    # matching moving_std output.
    for i in range(n_windows):
        window_slice = slice(i * window, (i + 1) * window)
        medians.append(np.median(data[window_slice]))
        median_dates.append(date_median(dates[window_slice]))

    return np.array(medians, dtype="float64"), pd.to_datetime(median_dates)


def moving_std(data, dates, window):
    """Standard deviation of ``data`` over consecutive, non-overlapping windows.

    ``dates`` is accepted for symmetry with ``moving_median`` but is not used.
    Only complete windows of ``window`` samples contribute; trailing samples are
    ignored. Returns a float64 numpy array with one entry per window.
    """
    n_full_windows = data.shape[0] // window
    window_starts = range(0, n_full_windows * window, window)
    return np.array(
        [np.std(data[start : start + window]) for start in window_starts],
        dtype="float64",
    )

---

## 4. Plotting the Degradation Curve

In [None]:
# Leftover interactive check: IPython automagic for %pwd, which displays the
# current working directory. No later cell uses it; safe to delete.
pwd

In [None]:
# Placeholder: the V9 degradation table is not loaded in this notebook.
v9dat = None

In [None]:
# Unused placeholder; no later cell reads it.
dates = list()

In [None]:
# Reference degradation curves from the AIA V8 calibration table (obtained from AIApy).
# TODO: move this absolute path to the configuration cell.
v8dat = pd.read_csv(
    "/media/paul/data/autocal_npz/data_v8table.csv",
    parse_dates=True,
    names=["DATE"] + channels,
)
# Fix inconsistent timestamp formatting ("Z" -> ".000"). The result is assigned
# back: replace(..., inplace=True) on v8dat["DATE"] is a chained in-place call,
# which is deprecated and does nothing under pandas Copy-on-Write.
v8dat["DATE"] = v8dat["DATE"].replace("Z", ".000", regex=True)
# infer_datetime_format is deprecated; format inference is the default in pandas >= 2.
v8dat["DATE"] = pd.to_datetime(v8dat["DATE"])
v8dat = v8dat.set_index("DATE")

time_support()  # Adding support of Astropy date to plot.
colors = ["Blue", "Orange", "Green", "Red", "Purple", "Brown", "Magenta"]

# Number of samples per window for the moving median / standard deviation.
v8_window = 28
model_window = 15

# Prediction timestamps as a DataFrame, built once instead of once per call.
multi_dates_frame = pd.DataFrame(dates_multi.tt.datetime)


def _as_astropy_time(timestamps):
    """Convert an iterable of pandas Timestamps into an astropy ``Time`` array."""
    return astropy.time.Time([astropy.time.Time(t.isoformat()) for t in timestamps])


fig, ax = plt.subplots(4, 2, figsize=(17, 22), dpi=300)
fig.subplots_adjust(wspace=0.1, hspace=0.18)
ax = ax.ravel()

for c, (channel, channel_name, color) in enumerate(
    zip(channels, channels_names, colors)
):
    # V8 curve normalised to its first value.
    v8dat[channel] = v8dat[channel].astype(np.float64)
    v8dat[channel] = v8dat[channel] / v8dat[channel].iloc[0]
    v8med_data, v8med_dates = moving_median(v8dat[channel], v8dat.index, v8_window)

    single_dates_frame = pd.DataFrame(dates_single[c].tt.datetime)
    multi_med_data, multi_med_dates = moving_median(
        degradation_multi_channel[:, c], multi_dates_frame, model_window
    )
    multi_std_data = moving_std(
        degradation_multi_channel[:, c], multi_dates_frame, model_window
    )
    single_med_data, single_med_dates = moving_median(
        degradation_single_channel[c], single_dates_frame, model_window
    )
    single_std_data = moving_std(
        degradation_single_channel[c], single_dates_frame, model_window
    )

    euv_v8_times = _as_astropy_time(v8med_dates)
    multi_med_times = _as_astropy_time(multi_med_dates)
    single_med_times = _as_astropy_time(single_med_dates)

    # TODO: 209 is hard-coded; presumably the number of test samples up to
    # last_training_date. Confirm, or derive it from dates_multi.
    ax[c].fill_between(
        dates_multi[0:209],
        1.5,
        0,
        color=color,
        label="Training data period",
        alpha=0.05,
    )
    ax[c].plot(
        last_training_line, [0, 2], "--", color="black", label="Last Training date"
    )
    ax[c].plot(eve_line, [0, 2], "--", color="gray", label="Last EVE MEGS-A data")
    ax[c].plot(
        multi_med_times,
        multi_med_data,
        label=channel_name + " - Multi-Channel",
        linewidth=2,
        color=color,
    )
    ax[c].plot(
        single_med_times,
        single_med_data,
        "--",
        label=channel_name + " - Single-Channel",
        linewidth=2,
        color=color,
        alpha=0.7,
    )
    ax[c].plot(
        euv_v8_times,
        v8med_data,
        color="gray",
        label=channel_name + " - Degradation V8",
        linewidth=3,
        alpha=0.9,
    )

    # Median +/- one standard deviation. zip() keeps the band aligned with the
    # shorter of the two series.
    up_multi = [m + s for m, s in zip(multi_med_data, multi_std_data)]
    down_multi = [m - s for m, s in zip(multi_med_data, multi_std_data)]
    up_single = [m + s for m, s in zip(single_med_data, single_std_data)]
    down_single = [m - s for m, s in zip(single_med_data, single_std_data)]

    ax[c].fill_between(
        multi_med_times,
        up_multi,
        down_multi,
        color=color,
        label="Standard Deviation",
        alpha=0.1,
    )
    ax[c].fill_between(
        single_med_times, up_single, down_single, color=color, alpha=0.1
    )

    if c % 2 == 0:
        ax[c].set_ylabel("Degradation", fontsize=14)

    ax[c].set_xlabel("Time (UTC)", fontsize=14)
    # Style this panel's tick labels. plt.xticks/plt.yticks act on the pyplot
    # "current" axes, not ax[c], and plt.xticks(np.arange(0, 4)) also replaced
    # that axes' tick positions.
    ax[c].tick_params(axis="both", labelsize=14)
    ax[c].legend(fontsize=10)
    ax[c].set_ylim(0, 1.05)
    ax[c].set_xlim(multi_med_times[0] - 1, multi_med_times[-1])

# Hide the grid cells left over once every channel has a panel.
for spare_ax in ax[len(channels):]:
    spare_ax.set_visible(False)

**Figure 2**

In [None]:
# Tabulate the multi-channel predictions. There is one row per test sample: a
# DATE column plus one column per channel ("0094", "0131", ...). This table
# is used below to correct AIA images.
# The original loop appended to an undefined `my_df` (NameError).
multi_channel_datetimes = dates_multi.tt.datetime  # convert once, not once per row
autocal_records = [
    {"DATE": when, **dict(zip(channels, row))}
    for when, row in zip(multi_channel_datetimes, degradation_multi_channel)
]

autocal_table = pd.DataFrame(autocal_records)

## 5. Downloading & Correcting AIA images

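First, we obtain AIA 304 Å images using SunPy's `Fido` search interface, sampling one image per year over two time windows between 2010 and 2020.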

In [None]:
# During testing, Fido.search returned zero results when asked for one image a
# year over the full ten years, so the query is split into two time windows.
# NOTE(review): the windows skip 2015-01-02 .. 2016-06-01; confirm that gap is intended.
q_one, q_two = (
    Fido.search(
        attrs.Time(window_start, window_end),
        attrs.Sample(1 * u.year),
        attrs.Instrument("AIA"),
        attrs.Wavelength(304 * u.angstrom),
    )
    for window_start, window_end in (
        ("2010-06-01T00:00:00", "2015-01-02T00:00:00"),
        ("2016-06-01T00:00:00", "2020-01-01T00:00:00"),
    )
)

In [None]:
# Download the search results. Per the sunpy docs, Fido.fetch does not raise on
# failed downloads; it records them in the returned results' `.errors`. Report
# them here instead of silently plotting fewer years below.
files_one = Fido.fetch(q_one)
files_two = Fido.fetch(q_two)
for fetched in (files_one, files_two):
    if fetched.errors:
        warnings.warn(
            f"{len(fetched.errors)} AIA file(s) failed to download "
            f"(retry with Fido.fetch(<results>)): {fetched.errors}"
        )
all_files = files_one + files_two

In [None]:
def correct_degradation(smap, table):
    """
    Correct degradation using time-step in the correction table that is closest to the observation date.

    Parameters
    ----------
    smap : sunpy map
        Map to correct. ``smap.date`` selects the table row and
        ``smap.meta["wavelnth"]`` selects the column.
    table : pandas.DataFrame
        Correction table with a ``DATE`` column of timestamps and one column per
        channel, named by the zero-padded wavelength (e.g. ``"0304"``).

    Returns
    -------
    sunpy map
        New map (same class, same metadata) whose data is ``smap.data`` divided
        by the degradation factor of the table row closest in time.
    """
    obs_time = pd.to_datetime(smap.date.value)
    # Index *label* of the row whose DATE is closest to the observation time.
    closest = (table["DATE"] - obs_time).abs().idxmin()
    column = f"{int(smap.meta['wavelnth']):04}"
    # idxmin returns a label, so look it up with .loc. The previous .iloc only
    # worked because the table happened to have a default RangeIndex.
    factor = table.loc[closest, column]
    return smap._new_instance(smap.data / factor, smap.meta)

In [None]:
# Load the downloaded files (sorted by path) and divide each map by the
# Autocalibration degradation closest in time to its observation.
maps = sunpy.map.Map(sorted(all_files))
# With a single file, Map() returns one map rather than a list; wrap it so the
# plotting cell can always use len() and zip().
if not isinstance(maps, list):
    maps = [maps]
maps_corrected = [correct_degradation(m, autocal_table) for m in maps]

In [None]:
# We set the image normalisation constant across the images
norm = ImageNormalize(vmin=0, vmax=4e2, stretch=SqrtStretch())

fig = plt.figure(figsize=(len(maps) * 3, 6))
plt.subplots_adjust(wspace=-0.2, hspace=0)

for i, (m, mc) in enumerate(zip(maps, maps_corrected)):
    ax = fig.add_subplot(2, len(maps), i + 1, projection=m)
    m.plot(axes=ax, norm=norm, annotate=False)
    ax.set_title(m.date.datetime.year)
    ax.coords[0].set_ticks_visible(False)
    ax.coords[0].set_ticklabel_visible(False)
    ax.coords[1].set_ticks_visible(False)
    ax.coords[1].set_ticklabel_visible(False)
    ax.set_aspect("equal")
    if i == 0:
        ax.set_ylabel("uncorrected")

    ax = fig.add_subplot(2, len(maps), i + 1 + len(maps), projection=mc)
    mc.plot(axes=ax, norm=norm, annotate=False)
    ax.coords[0].set_ticks_visible(False)
    ax.coords[0].set_ticklabel_visible(False)
    ax.coords[1].set_ticks_visible(False)
    ax.coords[1].set_ticklabel_visible(False)
    ax.set_aspect("equal")

**Figure 3:**

## 6. Discussion

---