# Settings

In [61]:
# Output switches
save_preprocessed_data = True # When True, the merged sample/reflectance table is written to CSV
save_results = True # NOTE(review): not read in the cells visible here; presumably gates saving of predictions further down

# Asteroid data
asteroid_location = '1AU' # Options: 1AU, 2.3AU. Selects which ET_* column is used as the regression target
asteroid_irradiation_type = 12 # L_M = 12, I_H_M = 3. Irradiation code whose GP task is used when predicting for the asteroid spectra

# GP model
cycles = 30 # Number of models used for evaluation (one model is trained per cycle)
training_iterations = 150 # Default 150, optimal iteration count may change if model is modified

# Imports

In [62]:
import os

import gpytorch
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule used by the training loop

from math_functions import mean_squared_error, root_mean_squared_error, denoise_and_norm

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data Preprocessing

#### Sample Information

In [63]:
# Sample metadata is read from the first sheet of the workbook
file_path = 'data/data_all_H_laser(august2).xlsx'
sample_info = pd.read_excel(io=file_path, sheet_name=0)

# Print the column dtypes, then show the table itself as the cell's value
column_dtypes = sample_info.dtypes
print(column_dtypes, '\n')
sample_info

Sample name     object
type            object
Irradiation      int64
ET_1AU         float64
ET_2.3AU       float64
dtype: object 



Unnamed: 0,Sample name,type,Irradiation,ET_1AU,ET_2.3AU
0,PO_TXH_007,L_OL,10,4.227972e+08,1.358991e+09
1,PO_TXH_008,L_OL,10,8.455945e+08,2.717982e+09
2,PO_TXH_081_CP1,L_OL,10,4.227972e+08,1.358991e+09
3,PO_TXH_081_CP2,L_OL,10,8.455945e+08,2.717982e+09
4,RB_LE4OLV,L_OL,10,7.610350e+08,2.446184e+09
...,...,...,...,...,...
163,OC_LP_A2_2,I_H_M,3,2.361833e+02,1.249410e+03
164,OC_LP_A3_2,I_H_M,3,5.992057e+02,3.169798e+03
165,OC_LP_A1_3,I_H_M,3,9.491069e+01,5.020775e+02
166,OC_LP_A2_3,I_H_M,3,2.361833e+02,1.249410e+03


#### Load Wavelength and Reflectance

In [64]:
extract_path = 'data/reflectance_data'

too_large_first_wavelength = ['MJL_OLV_1', 'MJL_OLV_2', 'MJL_OLV_3'] # These have too large first wavelength
discarded_samples = ['KC_OL_lm_12', 'KC_OL_lvn_12', 'RB_LE3CPX', 'KC_OPX_lvn_11', 'KC_OPX_lm_11', 'TJ_OPX_1'] # Discarded because exposure times were unrealistically high
unwanted_samples = too_large_first_wavelength + discarded_samples

# Samples to load: listed in sample_info and not excluded above.
# Built once so the per-file membership test is a cheap set lookup.
wanted_samples = set(sample_info['Sample name']) - set(unwanted_samples)

# Per-sample arrays keyed by sample name: W holds wavelengths, R holds reflectances
w_data = {}
r_data = {}

# Walk the extracted directory tree and load the CSV of every wanted sample
for root, dirs, files in os.walk(extract_path):
    for file in files:
        if not file.endswith('.csv'):
            continue

        # The part of the filename before the first dot is the sample name
        file_name = file.split('.')[0]
        if file_name not in wanted_samples:
            # Skip before parsing: unwanted or unrelated CSVs are never read
            continue

        # Use a dedicated name so the workbook path stored in `file_path` is not overwritten
        spectrum_path = os.path.join(root, file)
        df = pd.read_csv(spectrum_path, header=None)

        # Row 1 of the raw file carries the column names ("W", "R"); data starts at row 2
        df.columns = df.iloc[1]
        df = df[2:].astype(float)

        # W is multiplied by 1000 (presumably micrometres -> nanometres; confirm against the data files)
        w_data[file_name] = df["W"].to_numpy() * 1000
        r_data[file_name] = df["R"].to_numpy()

#### Preprocess Wavelength Data

In [65]:
# Largest starting wavelength over all samples (stays -inf when no sample was loaded)
max_first_wavelength = max([float('-inf')] + [spectrum[0] for spectrum in w_data.values()])
print(max_first_wavelength)

# Smallest final wavelength over all samples (stays +inf when no sample was loaded)
min_last_wavelength = min([float('inf')] + [spectrum[-1] for spectrum in w_data.values()])
print(min_last_wavelength)

# One column per sample; wrapping each array in a Series lets pandas pad shorter spectra with NaN
w_df = pd.DataFrame({sample: pd.Series(values) for sample, values in w_data.items()})

w_df

510.46387599999997
2480.84157


Unnamed: 0,OC_LP_B2_1,KC_OPX_lm_9,OC_TXH_011_A60,RB_LE4OLV,OWN_OL4_EN1_1,MY_OL_30,XT_TXH_030_P2,KC_OPX_lm_8,OC_LP_B3_2,YY_OL_6,...,OC_LP_B2_2,OC_LVM_F1,AF_OL_010_1,KC_OPX_lvn_1,KC_OPX_lvn_10,OC_TXH_013_P05,KC_OL_hm_2,OC_LP_A3_1,RB_LE2CPX,OC_LP_A2_2
0,250.0,500.12,300.0,250.0,496.605913,255.855137,280.0,500.12,250.0,500.0,...,250.0,298.484286,253.848732,450.10,450.10,250.0,500.12,250.0,250.0,250.0
1,255.0,500.31,305.0,251.0,498.919256,258.242041,285.0,500.31,255.0,503.0,...,255.0,306.630801,255.609603,450.26,450.26,255.0,500.31,255.0,251.0,255.0
2,260.0,500.50,310.0,252.0,500.461484,261.023115,290.0,500.50,260.0,506.0,...,260.0,312.785660,257.873579,450.42,450.42,260.0,500.50,260.0,252.0,260.0
3,265.0,500.70,315.0,253.0,502.260751,263.815131,295.0,500.70,265.0,509.0,...,265.0,316.923165,260.389108,450.57,450.57,265.0,500.70,265.0,253.0,265.0
4,270.0,500.89,320.0,254.0,503.288903,266.194741,300.0,500.89,270.0,512.0,...,270.0,317.000261,262.149978,450.73,450.73,270.0,500.89,270.0,254.0,270.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4651,,,,,,,,,,,...,,,,2494.10,2494.10,,,,,
4652,,,,,,,,,,,...,,,,2495.30,2495.30,,,,,
4653,,,,,,,,,,,...,,,,2496.50,2496.50,,,,,
4654,,,,,,,,,,,...,,,,2497.70,2497.70,,,,,


#### Preprocess Reflectance Data

In [66]:
# Reflectance table built the same way as w_df: one column per sample, shorter spectra padded with NaN
r_df = pd.DataFrame({sample: pd.Series(values) for sample, values in r_data.items()})

r_df

Unnamed: 0,OC_LP_B2_1,KC_OPX_lm_9,OC_TXH_011_A60,RB_LE4OLV,OWN_OL4_EN1_1,MY_OL_30,XT_TXH_030_P2,KC_OPX_lm_8,OC_LP_B3_2,YY_OL_6,...,OC_LP_B2_2,OC_LVM_F1,AF_OL_010_1,KC_OPX_lvn_1,KC_OPX_lvn_10,OC_TXH_013_P05,KC_OL_hm_2,OC_LP_A3_1,RB_LE2CPX,OC_LP_A2_2
0,0.091267,0.17514,0.08142,0.102687,0.222230,0.078070,0.19662,0.18813,0.080964,0.229140,...,0.087153,0.071975,0.019167,0.29841,0.10807,0.087242,0.66890,0.073795,0.139433,0.068925
1,0.095063,0.15834,0.08278,0.102185,0.222707,0.079495,0.14896,0.18181,0.084610,0.226592,...,0.090949,0.073561,0.019169,0.29947,0.10395,0.088632,0.66179,0.075704,0.138631,0.070946
2,0.099236,0.15316,0.08263,0.102471,0.223281,0.080919,0.13586,0.16909,0.088477,0.226860,...,0.095122,0.075679,0.019171,0.29177,0.10738,0.088888,0.66537,0.077820,0.138315,0.073034
3,0.103076,0.15552,0.08353,0.102007,0.223663,0.083058,0.13219,0.16247,0.092523,0.227072,...,0.098962,0.077799,0.019173,0.28526,0.11490,0.087427,0.66944,0.079724,0.138342,0.074928
4,0.107064,0.15903,0.08438,0.101547,0.224141,0.084007,0.13455,0.16843,0.096120,0.226811,...,0.102950,0.079391,0.019174,0.28488,0.13047,0.088377,0.66752,0.081352,0.137929,0.076659
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4651,,,,,,,,,,,...,,,,0.54237,0.43920,,,,,
4652,,,,,,,,,,,...,,,,0.54303,0.43950,,,,,
4653,,,,,,,,,,,...,,,,0.54355,0.43975,,,,,
4654,,,,,,,,,,,...,,,,0.54356,0.43966,,,,,


#### Interpolate and Normalize Data

In [67]:
# Choose how the interpolated spectra are post-processed (first matching option wins)
use_davids_denoising_and_normalization = True
use_own_normalization = False

# Common wavelength grid onto which every spectrum is resampled: 550, 585, ... (35 nm steps, 2250 excluded).
# Alternatives that were tried: np.linspace(541., 2424., 51) and np.linspace(550., 2250., 51)
interpolation_wavelengths = np.arange(550., 2250., 35)
interpolation_dict = {}

for i in range(len(w_df.columns)):
	sample_name = w_df.columns[i]
	wavelengths = w_df.iloc[:, i].to_numpy()
	reflectances = r_df.iloc[:, i].to_numpy()

	# w_df / r_df pad shorter spectra with NaN; drop the padding so that
	# np.unique / np.interp only ever see real measurement points
	measured = ~np.isnan(wavelengths)
	wavelengths = wavelengths[measured]
	reflectances = reflectances[measured]

	# np.interp needs increasing sample points: sort and keep the first occurrence of duplicated wavelengths
	unique_wavelengths, unique_indices = np.unique(wavelengths, return_index=True)
	unique_reflectances = reflectances[unique_indices]

	# NOTE: np.interp clamps to the edge values outside the measured range instead of raising
	interpolation = np.interp(interpolation_wavelengths, unique_wavelengths, unique_reflectances)

	if use_davids_denoising_and_normalization:
		denoised_normalized_data = denoise_and_norm(interpolation, interpolation_wavelengths, denoising=True, normalising=True)
		interpolation_dict[sample_name] = denoised_normalized_data.flatten()
	elif use_own_normalization:
		# Normalise to the reflectance at the first grid wavelength
		interpolation_dict[sample_name] = interpolation / interpolation[0]
	else:
		interpolation_dict[sample_name] = interpolation

# Rows: grid wavelengths, columns: samples
interpolation_df = pd.DataFrame(interpolation_dict, index=interpolation_wavelengths)

interpolation_df

Unnamed: 0,OC_LP_B2_1,KC_OPX_lm_9,OC_TXH_011_A60,RB_LE4OLV,OWN_OL4_EN1_1,MY_OL_30,XT_TXH_030_P2,KC_OPX_lm_8,OC_LP_B3_2,YY_OL_6,...,OC_LP_B2_2,OC_LVM_F1,AF_OL_010_1,KC_OPX_lvn_1,KC_OPX_lvn_10,OC_TXH_013_P05,KC_OL_hm_2,OC_LP_A3_1,RB_LE2CPX,OC_LP_A2_2
550.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
585.0,1.026282,1.10189,1.039515,1.049587,1.038579,1.065196,1.028713,1.099574,1.029855,1.017196,...,1.027862,1.024853,1.016492,1.049467,1.08888,1.01804,1.009957,1.09077,1.036895,1.108395
620.0,1.042327,1.1834,1.069736,1.075839,1.076703,1.111911,1.031733,1.187837,1.045787,1.042254,...,1.042187,1.039262,1.019992,1.084846,1.166127,1.03424,1.005253,1.153725,1.065563,1.179331
655.0,1.05704,1.256333,1.099095,1.115246,1.111283,1.147167,1.048842,1.26457,1.06263,1.059391,...,1.057037,1.068499,1.028611,1.122279,1.231926,1.052626,1.00673,1.19465,1.094772,1.222508
690.0,1.066844,1.331937,1.113682,1.150254,1.14289,1.188282,1.074813,1.336385,1.073493,1.07551,...,1.06658,1.080303,1.030342,1.166682,1.293279,1.069449,1.008841,1.222008,1.121448,1.250131
725.0,1.070148,1.380208,1.114729,1.14795,1.168941,1.198847,1.080595,1.384343,1.078289,1.088378,...,1.069777,1.090136,1.014097,1.20235,1.33046,1.080164,0.978823,1.237932,1.165286,1.264677
760.0,1.058103,1.344052,1.096756,1.126082,1.189365,1.190178,1.03491,1.352643,1.067717,1.101679,...,1.057763,1.069218,0.983496,1.183759,1.29929,1.083778,0.928125,1.234658,1.193742,1.257656
795.0,1.026276,1.182157,1.048993,1.098423,1.195381,1.173645,0.911866,1.198078,1.037863,1.10779,...,1.026208,0.999457,0.946838,1.07454,1.153612,1.071056,0.87456,1.21104,1.237238,1.22888
830.0,0.977283,0.965048,0.985474,1.111019,1.195647,1.150912,0.756063,0.980659,0.990698,1.119908,...,0.97829,0.911064,0.935879,0.892865,0.958637,1.05004,0.835862,1.1666,1.275647,1.18108
865.0,0.926507,0.777401,0.923801,1.101294,1.189886,1.159474,0.626428,0.799926,0.933654,1.125979,...,0.921832,0.80648,0.931011,0.745857,0.781465,1.022005,0.812657,1.109838,1.282047,1.140348


#### Combine Sample Info and Reflectance Data

In [68]:
# Turn the wavelength-indexed table into one row per sample and expose the sample name as a regular column
reflectance_data_transposed = (
    interpolation_df.T
    .reset_index()
    .rename(columns={'index': 'Sample name'})
)

# Join the sample metadata with the spectra on the sample name (pandas' default inner join)
merged_data = sample_info.merge(reflectance_data_transposed, on='Sample name')
merged_data

Unnamed: 0,Sample name,type,Irradiation,ET_1AU,ET_2.3AU,550.0,585.0,620.0,655.0,690.0,...,1915.0,1950.0,1985.0,2020.0,2055.0,2090.0,2125.0,2160.0,2195.0,2230.0
0,PO_TXH_007,L_OL,10,4.227972e+08,1.358991e+09,1.0,1.007881,0.985896,0.983607,0.980605,...,1.188147,1.191285,1.193287,1.193245,1.193972,1.195630,1.199715,1.197659,1.198879,1.201743
1,PO_TXH_008,L_OL,10,8.455945e+08,2.717982e+09,1.0,1.009062,0.991024,0.990374,0.988972,...,1.212832,1.215029,1.216749,1.216171,1.217804,1.219437,1.223788,1.222993,1.223918,1.226317
2,PO_TXH_081_CP1,L_OL,10,4.227972e+08,1.358991e+09,1.0,1.025258,1.028209,1.040079,1.043011,...,1.311502,1.310743,1.312417,1.318335,1.320112,1.319991,1.322872,1.320888,1.326771,1.322527
3,PO_TXH_081_CP2,L_OL,10,8.455945e+08,2.717982e+09,1.0,1.045553,1.057447,1.081275,1.079556,...,1.514210,1.508384,1.517911,1.524061,1.528268,1.528207,1.534357,1.533932,1.532657,1.535004
4,RB_LE4OLV,L_OL,10,7.610350e+08,2.446184e+09,1.0,1.049587,1.075839,1.115246,1.150254,...,2.195548,2.253296,2.285227,2.320041,2.349138,2.356762,2.362727,2.390682,2.418659,2.426168
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
163,OC_LP_A2_2,I_H_M,3,2.361833e+02,1.249410e+03,1.0,1.108395,1.179331,1.222508,1.250131,...,1.011697,1.010107,1.011611,1.012312,1.012525,1.015213,1.016467,1.009209,1.006724,1.007714
164,OC_LP_A3_2,I_H_M,3,5.992057e+02,3.169798e+03,1.0,1.099653,1.167046,1.210675,1.240013,...,1.049090,1.048354,1.050910,1.052026,1.053355,1.056345,1.058137,1.049916,1.048932,1.051390
165,OC_LP_A1_3,I_H_M,3,9.491069e+01,5.020775e+02,1.0,1.101308,1.165847,1.203071,1.224616,...,0.951818,0.949864,0.950533,0.950766,0.949719,0.951637,0.947837,0.940130,0.939281,0.939781
166,OC_LP_A2_3,I_H_M,3,2.361833e+02,1.249410e+03,1.0,1.104020,1.173051,1.214460,1.239397,...,0.998843,0.996978,0.998891,0.999081,0.998434,1.001088,0.996495,0.988218,0.988232,0.989602


#### Save the Data

In [69]:
# Write the merged metadata + spectra table to disk only when enabled in the settings cell.
# index=False: the row index is not written to the file.
if save_preprocessed_data:
	merged_data.to_csv('data/combined_data_H_L_550-2230_35interval.csv', index=False)

#### Prepare the Data for the Model

In [70]:
# Reload the preprocessed table from disk.
# NOTE(review): this read is unconditional, while the write above is gated by save_preprocessed_data;
# with saving disabled, whatever file is already on disk is loaded instead of the current merged_data.
data = pd.read_csv('data/combined_data_H_L_550-2230_35interval.csv')
# Summary statistics of the ET_1AU column before any scaling
data['ET_1AU'].describe()

count    1.680000e+02
mean     7.023619e+08
std      9.858265e+08
min      4.373764e-02
25%      4.791702e+07
50%      4.227972e+08
75%      8.625063e+08
max      5.637296e+09
Name: ET_1AU, dtype: float64

In [71]:
# Summary statistics of the ET_2.3AU column before any scaling
data['ET_2.3AU'].describe()

count    1.680000e+02
mean     2.257592e+09
std      3.168728e+09
min      2.313721e-01
25%      1.540190e+08
50%      1.358991e+09
75%      2.772342e+09
max      1.811988e+10
Name: ET_2.3AU, dtype: float64

In [None]:
# Lookup from the integer code in the 'Irradiation' column to a type label
# (e.g. 10 -> "L_OL" and 3 -> "I_H_M", as in the sample table's 'type' column).
# NOTE(review): not referenced by any other cell visible here; kept as documentation of the codes.
irradiation_types = {
    1: "I_H_OL",
    2: "I_H_PX",
    3: "I_H_M",
    4: "I_Ar_OL",
    5: "I_Ar_PX",
    6: "I_Ar_M",
    7: "I_He_OL",
    8: "I_He_PX",
    9: "I_He_M",
    10: "L_OL",
    11: "L_PX",
    12: "L_M"
}

In [72]:
# Number of samples per irradiation code
counts = data['Irradiation'].value_counts()
print(counts)

# Target column for each supported asteroid location
target_columns = {'1AU': 'ET_1AU', '2.3AU': 'ET_2.3AU'}
if asteroid_location not in target_columns:
	raise ValueError(f"Unknown asteroid_location {asteroid_location!r}; expected one of {list(target_columns)}")
target_column = target_columns[asteroid_location]

# Scale the target to log10(ET + 1); the inverse transform is 10**y - 1.
# NOTE: the column is overwritten in place, so re-running this cell without
# reloading `data` applies the transform a second time.
data[target_column] = np.log10(data[target_column] + 1)

# Irradiation type combinations
# Hydrogen + laser on only olivine = 1, 10
# Hydrogen + laser on all minerals = 1, 2, 3, 10, 11, 12
# All = 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
selected_irradiations = [1, 2, 3, 10, 11, 12]

num_of_data_rows = data.shape[0]
print(num_of_data_rows)

data

10    58
12    46
11    23
3     18
1     13
2     10
Name: Irradiation, dtype: int64
168


Unnamed: 0,Sample name,type,Irradiation,ET_1AU,ET_2.3AU,550.0,585.0,620.0,655.0,690.0,...,1915.0,1950.0,1985.0,2020.0,2055.0,2090.0,2125.0,2160.0,2195.0,2230.0
0,PO_TXH_007,L_OL,10,8.626132,1.358991e+09,1.0,1.007881,0.985896,0.983607,0.980605,...,1.188147,1.191285,1.193287,1.193245,1.193972,1.195630,1.199715,1.197659,1.198879,1.201743
1,PO_TXH_008,L_OL,10,8.927162,2.717982e+09,1.0,1.009062,0.991024,0.990374,0.988972,...,1.212832,1.215029,1.216749,1.216171,1.217804,1.219437,1.223788,1.222993,1.223918,1.226317
2,PO_TXH_081_CP1,L_OL,10,8.626132,1.358991e+09,1.0,1.025258,1.028209,1.040079,1.043011,...,1.311502,1.310743,1.312417,1.318335,1.320112,1.319991,1.322872,1.320888,1.326771,1.322527
3,PO_TXH_081_CP2,L_OL,10,8.927162,2.717982e+09,1.0,1.045553,1.057447,1.081275,1.079556,...,1.514210,1.508384,1.517911,1.524061,1.528268,1.528207,1.534357,1.533932,1.532657,1.535004
4,RB_LE4OLV,L_OL,10,8.881405,2.446184e+09,1.0,1.049587,1.075839,1.115246,1.150254,...,2.195548,2.253296,2.285227,2.320041,2.349138,2.356762,2.362727,2.390682,2.418659,2.426168
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
163,OC_LP_A2_2,I_H_M,3,2.375084,1.249410e+03,1.0,1.108395,1.179331,1.222508,1.250131,...,1.011697,1.010107,1.011611,1.012312,1.012525,1.015213,1.016467,1.009209,1.006724,1.007714
164,OC_LP_A3_2,I_H_M,3,2.778300,3.169798e+03,1.0,1.099653,1.167046,1.210675,1.240013,...,1.049090,1.048354,1.050910,1.052026,1.053355,1.056345,1.058137,1.049916,1.048932,1.051390
165,OC_LP_A1_3,I_H_M,3,1.981867,5.020775e+02,1.0,1.101308,1.165847,1.203071,1.224616,...,0.951818,0.949864,0.950533,0.950766,0.949719,0.951637,0.947837,0.940130,0.939281,0.939781
166,OC_LP_A2_3,I_H_M,3,2.375084,1.249410e+03,1.0,1.104020,1.173051,1.214460,1.239397,...,0.998843,0.996978,0.998891,0.999081,0.998434,1.001088,0.996495,0.988218,0.988232,0.989602


In [73]:
X_task_sets = [] # list[ np[task 1 X], np[task 2 X], [...], ...]
y_task_sets = [] # list[ np[task 1 y], np[task 2 y], [...], ...]

# Row labels of `data` belonging to each task, in the same order as X_task_sets / y_task_sets
sample_id_task_sets = [] # list[ np[task 1 indices], np[task 2 indices], [...], ...]
# Irradiation code -> position of that task in X_task_sets (the task id used by the GP model)
irradiation_type_to_train_i_task = {}

# Any location other than '1AU' uses the 2.3 AU target
target_column = 'ET_1AU' if asteroid_location == '1AU' else 'ET_2.3AU'
# Everything that is not a reflectance feature
non_feature_columns = ['Sample name', 'type', 'Irradiation', 'ET_1AU', 'ET_2.3AU']

# Separate different irradiation types into tasks
for task in selected_irradiations:
	task_set = data[data['Irradiation'] == task]
	if len(task_set) == 0:
		# No samples of this type: it gets no task and no task id
		continue

	# The task id must be the position in X_task_sets, not the position in
	# selected_irradiations: the two differ as soon as an empty type is skipped.
	irradiation_type_to_train_i_task[task] = len(X_task_sets)

	sample_id_task_sets.append(task_set.index.to_numpy())
	X = task_set.drop(columns=non_feature_columns).values
	y = task_set[target_column].values
	X_task_sets.append(X)
	y_task_sets.append(y)

# Number of data for different tasks
number_of_total_data_points = 0
for task in y_task_sets:
	number_of_total_data_points += len(task)
	print(len(task))
print(f'Total {number_of_total_data_points}')

13
10
18
58
23
46
Total 168


In [74]:
# Number of considered wavelengths (feature columns). Read from the stored task sets rather than
# from the loop variable `X` left over from the previous cell; all task sets share the same columns.
data_dim = X_task_sets[0].shape[-1]
print(data_dim)

49


#### Load and Process Asteroid Data

In [75]:
# From the metadata sheet keep only the columns labelled 0 and 1, under readable names
asteroid_metadata = (
    pd.read_excel('data/asteroid_spectra-denoised-norm.xlsx', sheet_name='metadata')
    .rename(columns={0: 'asteroid number', 1: 'taxonomy class'})
    .drop(columns=[2, 3, 4, 5, 6, 7])
)
asteroid_metadata

Unnamed: 0,asteroid number,taxonomy class
0,1,C
1,2,B
2,8,Sw
3,10,C
4,13,Ch
...,...,...
586,14402,Xk
587,52768,Xk
588,54789,Xe
589,137170,Xk


The asteroid reflectance data must cover wavelengths from 550 to 2230 nm in 35 nm steps, matching the wavelength grid of the training spectra.

In [76]:
# The spectra sheet holds the values, the wavelengths sheet supplies the column labels
asteroid_reflectance_wavelengths = pd.read_excel('data/asteroid_spectra-denoised-norm.xlsx', sheet_name='wavelengths')
asteroid_reflectance_data = pd.read_excel('data/asteroid_spectra-denoised-norm.xlsx', sheet_name='spectra')
asteroid_reflectance_data.columns = asteroid_reflectance_wavelengths[0]

# Positional down-sampling onto the training grid: skip the first 20 columns,
# then take every 7th column and keep the first 49 of those
# (per the displayed result this yields 550, 585, ..., 2230)
spectra_after_offset = asteroid_reflectance_data.iloc[:, 20:]
columns_to_keep = spectra_after_offset.columns[::7]
asteroid_reflectance_data = spectra_after_offset[columns_to_keep].iloc[:, :49]
asteroid_reflectance_data

Unnamed: 0,550,585,620,655,690,725,760,795,830,865,...,1915,1950,1985,2020,2055,2090,2125,2160,2195,2230
0,1,1.006235,1.009625,1.012375,1.015883,1.019225,1.020920,1.019588,1.015045,1.008113,...,0.994535,0.995703,0.996647,0.997415,0.998057,0.998614,0.999115,0.999581,1.000035,1.000498
1,1,1.000946,1.000984,1.001122,1.000911,0.999782,0.997165,0.992498,0.985533,0.976664,...,0.860797,0.859856,0.858868,0.857766,0.856483,0.854955,0.853115,0.850903,0.848352,0.845606
2,1,1.056721,1.095177,1.130304,1.172215,1.207296,1.230174,1.233637,1.218064,1.195585,...,1.489804,1.486921,1.487053,1.491015,1.498196,1.507392,1.517392,1.526998,1.535565,1.543370
3,1,1.003665,1.005185,1.004877,1.003058,1.000080,0.996407,0.992530,0.988935,0.985965,...,1.143252,1.149032,1.154817,1.160722,1.166808,1.172950,1.178972,1.184700,1.189955,1.194562
4,1,0.997972,0.988595,0.979806,0.978641,0.985151,0.995229,1.004703,1.011116,1.014903,...,1.108138,1.108106,1.107679,1.106798,1.105512,1.103941,1.102205,1.100422,1.098713,1.097198
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
586,1,1.000215,1.003417,1.007521,1.009133,1.006300,0.999965,0.991527,0.982187,0.972925,...,1.141740,1.147074,1.149855,1.149072,1.145504,1.140993,1.137408,1.136621,1.140503,1.150315
587,1,1.017018,1.034954,1.052161,1.066625,1.076336,1.080078,1.079382,1.076576,1.073988,...,1.153789,1.160979,1.166787,1.169584,1.167741,1.160230,1.149082,1.137606,1.129125,1.126964
588,1,1.035785,1.069090,1.099624,1.131759,1.167938,1.202763,1.228519,1.241679,1.245785,...,1.379756,1.380304,1.381194,1.382253,1.383279,1.384069,1.384420,1.384131,1.382997,1.380816
589,1,1.006542,1.017187,1.030825,1.043771,1.060991,1.067755,1.059743,1.054216,1.043250,...,1.049537,1.049700,1.050227,1.050987,1.051847,1.052678,1.053347,1.053724,1.053710,1.053317


In [77]:
# Put metadata and spectra side by side; pandas aligns the rows on their index labels
asteroid_data = pd.concat((asteroid_metadata, asteroid_reflectance_data), axis='columns')
asteroid_data

Unnamed: 0,asteroid number,taxonomy class,550,585,620,655,690,725,760,795,...,1915,1950,1985,2020,2055,2090,2125,2160,2195,2230
0,1,C,1,1.006235,1.009625,1.012375,1.015883,1.019225,1.020920,1.019588,...,0.994535,0.995703,0.996647,0.997415,0.998057,0.998614,0.999115,0.999581,1.000035,1.000498
1,2,B,1,1.000946,1.000984,1.001122,1.000911,0.999782,0.997165,0.992498,...,0.860797,0.859856,0.858868,0.857766,0.856483,0.854955,0.853115,0.850903,0.848352,0.845606
2,8,Sw,1,1.056721,1.095177,1.130304,1.172215,1.207296,1.230174,1.233637,...,1.489804,1.486921,1.487053,1.491015,1.498196,1.507392,1.517392,1.526998,1.535565,1.543370
3,10,C,1,1.003665,1.005185,1.004877,1.003058,1.000080,0.996407,0.992530,...,1.143252,1.149032,1.154817,1.160722,1.166808,1.172950,1.178972,1.184700,1.189955,1.194562
4,13,Ch,1,0.997972,0.988595,0.979806,0.978641,0.985151,0.995229,1.004703,...,1.108138,1.108106,1.107679,1.106798,1.105512,1.103941,1.102205,1.100422,1.098713,1.097198
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
586,14402,Xk,1,1.000215,1.003417,1.007521,1.009133,1.006300,0.999965,0.991527,...,1.141740,1.147074,1.149855,1.149072,1.145504,1.140993,1.137408,1.136621,1.140503,1.150315
587,52768,Xk,1,1.017018,1.034954,1.052161,1.066625,1.076336,1.080078,1.079382,...,1.153789,1.160979,1.166787,1.169584,1.167741,1.160230,1.149082,1.137606,1.129125,1.126964
588,54789,Xe,1,1.035785,1.069090,1.099624,1.131759,1.167938,1.202763,1.228519,...,1.379756,1.380304,1.381194,1.382253,1.383279,1.384069,1.384420,1.384131,1.382997,1.380816
589,137170,Xk,1,1.006542,1.017187,1.030825,1.043771,1.060991,1.067755,1.059743,...,1.049537,1.049700,1.050227,1.050987,1.051847,1.052678,1.053347,1.053724,1.053710,1.053317


In [78]:
# List the distinct taxonomy classes present, to choose from in the next cell
asteroid_data['taxonomy class'].unique()

array(['C', 'B', 'Sw', 'Ch', 'X', 'S', 'Sqw', 'Cgh', 'T', 'Sr', 'D', 'L',
       'A', 'K', 'Sq', 'Vw', 'Q', 'Sv', 'V', 'Svw', 'U', 'Cg', 'Cb', 'O',
       'R', 'Qw', 'Sa', 'Sq:', 'Srw', 'Xk', 'Xc', 'Xe', 'Xn'],
      dtype=object)

In [79]:
# Select the wanted asteroid types
asteroid_types = ['Sw', 'S', 'Sqw', 'Sr', 'Sa', 'Srw', 'Q', 'V', 'Sq', 'Sv', 'Svw', 'Qw', 'Sq:']
asteroid_data = asteroid_data[asteroid_data['taxonomy class'].isin(asteroid_types)]
asteroid_data.reset_index(drop=True, inplace=True)
asteroid_data

Unnamed: 0,asteroid number,taxonomy class,550,585,620,655,690,725,760,795,...,1915,1950,1985,2020,2055,2090,2125,2160,2195,2230
0,8,Sw,1,1.056721,1.095177,1.130304,1.172215,1.207296,1.230174,1.233637,...,1.489804,1.486921,1.487053,1.491015,1.498196,1.507392,1.517392,1.526998,1.535565,1.543370
1,26,S,1,1.035074,1.071880,1.113490,1.150716,1.170832,1.165139,1.129637,...,1.286201,1.282915,1.281733,1.283009,1.287099,1.294130,1.302498,1.309571,1.313036,1.313275
2,27,S,1,1.043694,1.083981,1.124005,1.165305,1.189076,1.184025,1.158200,...,1.359991,1.358534,1.359297,1.361496,1.364300,1.367716,1.372451,1.378826,1.385693,1.391547
3,28,S,1,1.042823,1.082409,1.123155,1.160838,1.185129,1.186426,1.166564,...,1.254068,1.253994,1.255971,1.259443,1.263851,1.268656,1.273508,1.278225,1.282629,1.286544
4,29,S,1,1.028375,1.064617,1.099786,1.129347,1.152713,1.157703,1.141715,...,1.343379,1.343638,1.344793,1.347242,1.350773,1.354924,1.359253,1.363445,1.367240,1.370466
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
374,4,V,1,1.015984,1.031595,1.050125,1.069011,1.085567,1.083458,1.026807,...,0.896678,0.890320,0.889835,0.896239,0.910410,0.931447,0.956959,0.984520,1.012169,1.038345
375,1468,V,1,1.093883,1.181398,1.265464,1.349296,1.427508,1.470123,1.441391,...,1.581080,1.554699,1.543300,1.555390,1.586917,1.630635,1.684444,1.750545,1.828159,1.908830
376,1904,V,1,1.058941,1.101019,1.142001,1.194562,1.224986,1.193965,1.115869,...,1.054894,1.060022,1.073943,1.094997,1.121523,1.151999,1.185131,1.219655,1.254309,1.288179
377,1929,V,1,1.084892,1.144267,1.203031,1.271043,1.323662,1.304500,1.177044,...,1.146871,1.132937,1.127551,1.141166,1.172563,1.210858,1.250985,1.298490,1.358068,1.423017


In [80]:
# Independent one-column frame holding the asteroid numbers
asteroid_predictions = asteroid_data.loc[:, ['asteroid number']].copy()
asteroid_predictions

Unnamed: 0,asteroid number
0,8
1,26
2,27
3,28
4,29
...,...
374,4
375,1468
376,1904
377,1929


In [81]:
# Strip the two metadata columns so that only the reflectance features remain
asteroid_data = asteroid_data.drop(columns=['asteroid number', 'taxonomy class'])
asteroid_data

Unnamed: 0,550,585,620,655,690,725,760,795,830,865,...,1915,1950,1985,2020,2055,2090,2125,2160,2195,2230
0,1,1.056721,1.095177,1.130304,1.172215,1.207296,1.230174,1.233637,1.218064,1.195585,...,1.489804,1.486921,1.487053,1.491015,1.498196,1.507392,1.517392,1.526998,1.535565,1.543370
1,1,1.035074,1.071880,1.113490,1.150716,1.170832,1.165139,1.129637,1.078341,1.044297,...,1.286201,1.282915,1.281733,1.283009,1.287099,1.294130,1.302498,1.309571,1.313036,1.313275
2,1,1.043694,1.083981,1.124005,1.165305,1.189076,1.184025,1.158200,1.123083,1.091185,...,1.359991,1.358534,1.359297,1.361496,1.364300,1.367716,1.372451,1.378826,1.385693,1.391547
3,1,1.042823,1.082409,1.123155,1.160838,1.185129,1.186426,1.166564,1.137006,1.109706,...,1.254068,1.253994,1.255971,1.259443,1.263851,1.268656,1.273508,1.278225,1.282629,1.286544
4,1,1.028375,1.064617,1.099786,1.129347,1.152713,1.157703,1.141715,1.120613,1.105336,...,1.343379,1.343638,1.344793,1.347242,1.350773,1.354924,1.359253,1.363445,1.367240,1.370466
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
374,1,1.015984,1.031595,1.050125,1.069011,1.085567,1.083458,1.026807,0.924796,0.824631,...,0.896678,0.890320,0.889835,0.896239,0.910410,0.931447,0.956959,0.984520,1.012169,1.038345
375,1,1.093883,1.181398,1.265464,1.349296,1.427508,1.470123,1.441391,1.321847,1.167093,...,1.581080,1.554699,1.543300,1.555390,1.586917,1.630635,1.684444,1.750545,1.828159,1.908830
376,1,1.058941,1.101019,1.142001,1.194562,1.224986,1.193965,1.115869,1.026662,0.950155,...,1.054894,1.060022,1.073943,1.094997,1.121523,1.151999,1.185131,1.219655,1.254309,1.288179
377,1,1.084892,1.144267,1.203031,1.271043,1.323662,1.304500,1.177044,1.022877,0.895823,...,1.146871,1.132937,1.127551,1.141166,1.172563,1.210858,1.250985,1.298490,1.358068,1.423017


# GP Model

#### Feature Extractor

In [83]:
class FeatureExtractor(torch.nn.Module):
    """Small 1-D convolutional network applied to a spectrum before the GP kernel.

    Pipeline: Conv1d (1 -> 16 channels, kernel 3, 'same' padding) -> ReLU ->
    MaxPool1d (halves the length) -> flatten -> Linear back to ``input_dim`` features.

    Args:
        input_dim: Number of values per input spectrum. When omitted, the
            notebook-level ``data_dim`` is used, so ``FeatureExtractor()``
            behaves as before.
    """

    def __init__(self, input_dim=None):
        super().__init__()
        if input_dim is None:
            input_dim = data_dim
        self.input_dim = input_dim
        self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding='same') # Convolutional layer, keeps the length
        self.pool = torch.nn.MaxPool1d(kernel_size=2, stride=2) # Halves the length (floor division)
        self.fc1 = torch.nn.Linear((input_dim // 2) * 16, input_dim) # Fully connected layer back to input_dim features

    def forward(self, x):
        """Map a (batch_size, input_dim) tensor to a (batch_size, input_dim) feature tensor."""
        x = x.unsqueeze(1)  # Add a channel dimension: (batch_size, 1, input_dim)
        x = F.relu(self.conv1(x))  # Convolution followed by ReLU: (batch_size, 16, input_dim)
        x = self.pool(x)  # Max pooling: (batch_size, 16, input_dim // 2)
        x = x.view(x.size(0), -1)  # Flatten: (batch_size, 16 * (input_dim // 2))
        x = self.fc1(x)  # Linear layer: (batch_size, input_dim)
        return x

#### Model Training

In [84]:
# One column per trained model, labelled 'pred 1', 'pred 2', ...
column_names = ['pred ' + str(cycle_number) for cycle_number in range(1, cycles + 1)]
# Empty frame that the training loop fills column by column
results = pd.DataFrame(columns=column_names)

In [85]:
# ---- Tensors that are identical for every cycle (built once) ----
X_train_list = []
y_train_list = []
train_i_task_list = []

# Process data for each task
for task in range(len(X_task_sets)):
	X_train = torch.tensor(X_task_sets[task], dtype=torch.float32)
	y_train = torch.tensor(y_task_sets[task], dtype=torch.float32)

	X_train_list.append(X_train)
	y_train_list.append(y_train)
	# The task id of a sample is the position of its task in X_task_sets
	train_i_task_list.append(torch.full((X_train.shape[0],1), dtype=torch.long, fill_value=task))

full_X_train = torch.cat(X_train_list)
full_train_i = torch.cat(train_i_task_list)
full_y_train = torch.cat(y_train_list)

# All asteroid spectra are predicted with the task of the configured irradiation type
if asteroid_irradiation_type not in irradiation_type_to_train_i_task:
	raise ValueError(
		f"asteroid_irradiation_type={asteroid_irradiation_type!r} has no training task; "
		f"available irradiation types: {sorted(irradiation_type_to_train_i_task)}"
	)
X_test = torch.tensor(asteroid_data.values, dtype=torch.float32)
test_i_tasks = torch.full((X_test.shape[0],1), dtype=torch.long, fill_value=irradiation_type_to_train_i_task[asteroid_irradiation_type])


# The model itself
class MultitaskGPModel(gpytorch.models.ExactGP):
	"""Exact GP over learned spectral features with a per-task (irradiation type) covariance.

	The covariance is the element-wise product of a Matern kernel on the
	extracted features and an IndexKernel on the task ids.

	Args:
		train_x: Tuple ``(X, task_ids)`` of training inputs.
		train_y: Training targets.
		likelihood: GPyTorch likelihood.
		feature_extractor: Module applied to ``X`` before the kernel; a new
			``FeatureExtractor()`` is created when omitted.
		num_tasks: Number of tasks for the IndexKernel; defaults to ``len(X_task_sets)``.
	"""

	def __init__(self, train_x, train_y, likelihood, feature_extractor=None, num_tasks=None):
		super().__init__(train_x, train_y, likelihood)
		if feature_extractor is None:
			feature_extractor = FeatureExtractor()
		if num_tasks is None:
			num_tasks = len(X_task_sets)
		self.mean_module = gpytorch.means.ConstantMean() # Mean function
		self.covar_module = gpytorch.kernels.keops.MaternKernel(nu=0.5) # Covariance function
		self.task_covar_module = gpytorch.kernels.IndexKernel(num_tasks=num_tasks, rank=3) # Task covariance function
		self.feature_extractor = feature_extractor
		self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1., 1.) # Scaling/normalization

	def forward(self,x,i):
		extracted_features = self.feature_extractor(x)

		# Ensure extracted features are scaled appropriately
		extracted_features = self.scale_to_bounds(extracted_features)

		mean_x = self.mean_module(extracted_features)

		# Get input-input covariance
		covar_x = self.covar_module(extracted_features)
		# Get task-task covariance
		covar_i = self.task_covar_module(i)
		# Multiply the two together to get the covariance we want
		covar = covar_x.mul(covar_i)

		return gpytorch.distributions.MultivariateNormal(mean_x, covar)


# ---- Train `cycles` independent models and collect their predictions ----
# NOTE(review): no random seed is set in the visible cells, so the predictions differ from run to run.
for cycle in range(cycles):
	print(f'Cycle {cycle+1}/{cycles}')

	# Fresh network, likelihood and GP for every cycle
	feature_extractor = FeatureExtractor()
	likelihood = gpytorch.likelihoods.GaussianLikelihood()
	model = MultitaskGPModel((full_X_train, full_train_i), full_y_train, likelihood, feature_extractor=feature_extractor) # full_train_i tells the model which task each data point belongs

	model.train()
	likelihood.train()

	optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

	# "Loss" for GPs - the marginal log likelihood
	mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

	# Training loop
	iterator = tqdm.notebook.tqdm(range(training_iterations))
	for i in iterator:
		with gpytorch.settings.cholesky_max_tries(5):
			optimizer.zero_grad()
			output = model(full_X_train, full_train_i)
			loss = -mll(output, full_y_train)
			loss.backward()
			iterator.set_postfix(loss=loss.item())
			optimizer.step()

	# Evaluation
	model.eval()
	likelihood.eval()
	with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
		y_preds = likelihood(model(X_test, test_i_tasks))
		mean = y_preds.mean # The means are used as prediction values
		lower, upper = y_preds.confidence_region() # GP model's own confidence bounds (not using these at the moment)

	mean = mean.numpy()
	lower = lower.numpy()
	upper = upper.numpy()

	# Append the results (still in log10(ET + 1) units)
	results[column_names[cycle]] = mean

	# Undo the log10(ET + 1) target scaling
	unscaled_mean = 10 ** mean - 1

Cycle 1/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 2/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 3/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 4/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 5/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 6/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 7/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 8/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 9/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 10/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 11/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 12/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 13/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 14/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 15/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 16/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 17/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 18/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 19/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 20/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 21/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 22/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 23/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 24/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 25/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 26/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 27/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 28/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 29/30


  0%|          | 0/150 [00:00<?, ?it/s]

Cycle 30/30


  0%|          | 0/150 [00:00<?, ?it/s]

#### Process the Results

In [86]:
# Show the raw per-cycle predictions: one 'pred N' column per trained model.
# Values are still base-10 exponents here; the next cell converts them with 10 ** results.
results

Unnamed: 0,pred 1,pred 2,pred 3,pred 4,pred 5,pred 6,pred 7,pred 8,pred 9,pred 10,...,pred 21,pred 22,pred 23,pred 24,pred 25,pred 26,pred 27,pred 28,pred 29,pred 30
0,8.998983,9.130715,9.149955,9.150545,8.992460,9.116740,9.167773,9.161362,9.436657,9.078914,...,9.200199,9.089815,8.999155,9.117628,9.012719,9.128922,9.138183,9.036764,9.050103,9.086860
1,9.071829,9.013955,8.961753,9.040102,8.943403,8.976604,9.033328,9.008644,9.125957,8.956416,...,9.071217,9.076895,8.858221,9.157667,8.937730,9.146424,9.203490,9.097677,9.047692,9.180457
2,9.054144,9.036370,9.028281,9.098055,8.974089,8.946139,9.109622,9.020195,9.088909,9.009974,...,9.128071,9.046930,8.941339,9.181471,8.957822,9.150468,9.197356,9.111990,9.055963,9.154517
3,8.933554,9.066430,9.039145,8.898623,9.200254,9.046626,8.966159,8.993408,9.183300,9.018855,...,8.994694,8.986944,8.830553,8.929060,8.986413,9.032975,8.908477,8.865591,8.968865,8.948867
4,8.991339,9.134393,9.084800,9.122774,9.138472,9.135028,9.133761,9.188347,9.230145,9.098933,...,9.166737,9.098940,8.950968,9.139692,9.007756,9.196946,9.143432,9.037741,9.058343,9.153472
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
374,8.812582,8.196694,9.021842,8.792269,8.812254,8.777904,8.364261,8.469711,8.927990,8.752131,...,8.253101,8.716188,8.942484,8.078932,8.661622,8.100297,8.260528,8.502523,8.790978,8.514991
375,9.269583,9.142327,9.431663,9.164003,9.140928,9.226298,9.117450,8.645668,9.279369,9.212337,...,9.064335,9.091059,8.558619,8.611464,8.738858,8.575944,8.164733,8.766256,9.285607,8.381598
376,9.210302,9.090096,9.261954,8.966097,8.837721,9.111552,9.016284,8.866341,9.330456,9.035213,...,9.094944,9.245298,8.977007,8.925032,8.900933,8.838761,9.074035,8.929037,9.075952,9.139968
377,9.354666,9.141442,9.507651,9.117037,9.126844,9.296603,9.094653,8.722229,9.595776,9.241420,...,9.276096,9.434797,8.963686,8.593519,8.526929,8.948274,8.750854,8.903585,9.069268,8.884475


In [87]:
# Undo the log10 transform (10 ** value) so each prediction is back in the original scale,
# then take the row-wise mean over all per-cycle prediction columns.
original_scale_results = results.rpow(10)
original_scale_average_results = original_scale_results.mean(axis=1)
original_scale_average_results

0      1.302505e+09
1      1.155035e+09
2      1.191469e+09
3      9.916480e+08
4      1.313671e+09
           ...     
374    4.469937e+08
375    1.089572e+09
376    1.185564e+09
377    1.432658e+09
378    1.289398e+09
Length: 379, dtype: float32

In [88]:
# Summarise the ensemble for every row: the mean prediction and the spread
# (pandas' default `std`, i.e. ddof=1) across the per-cycle prediction columns.
combined_results = pd.DataFrame({
	'average_pred': original_scale_average_results,
	'standard_deviation': original_scale_results.std(axis=1),
})
combined_results

Unnamed: 0,average_pred,standard_deviation
0,1.302505e+09,320610176.0
1,1.155035e+09,257334080.0
2,1.191469e+09,227266320.0
3,9.916480e+08,213998128.0
4,1.313671e+09,195218064.0
...,...,...
374,4.469937e+08,294660960.0
375,1.089572e+09,727312128.0
376,1.185564e+09,399982624.0
377,1.432658e+09,881835328.0


In [89]:
# Attach the ensemble summary columns to the asteroid table.
# Summary columns left over from a previous run of this cell are dropped first,
# so re-running the cell does not append duplicate columns.
asteroid_predictions = pd.concat(
	[
		asteroid_predictions.drop(columns=combined_results.columns, errors='ignore'),
		combined_results,
	],
	axis=1,
)
asteroid_predictions

Unnamed: 0,asteroid number,average_pred,standard_deviation
0,8,1.302505e+09,320610176.0
1,26,1.155035e+09,257334080.0
2,27,1.191469e+09,227266320.0
3,28,9.916480e+08,213998128.0
4,29,1.313671e+09,195218064.0
...,...,...,...
374,4,4.469937e+08,294660960.0
375,1468,1.089572e+09,727312128.0
376,1904,1.185564e+09,399982624.0
377,1929,1.432658e+09,881835328.0


In [90]:
# Persist the per-asteroid predictions to an Excel file whose name encodes
# the asteroid location and the irradiation type chosen in the settings cell.
if save_results:
	# Any location other than '1AU' is saved under the '2.3AU' name, matching the previous if/else.
	location_label = '1AU' if asteroid_location == '1AU' else '2.3AU'
	# Create the output folder if needed so to_excel does not fail on a missing directory.
	os.makedirs('results', exist_ok=True)
	asteroid_predictions.to_excel(
		f'results/GP_asteroid_predictions_at_{location_label}_as_irradiation{asteroid_irradiation_type}.xlsx',
		index=False,
	)