# Load Data and Modules

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from sklearn.model_selection import train_test_split
from scipy.spatial.transform import Rotation as R

In [4]:
# Column layout of the raw CSVs: ids, ellipse axes, 'loss', the nine entries
# r00..r22 of a 3x3 rotation matrix (reshaped row-major in convert_matrix),
# then t1..t3 (presumably a translation vector - not confirmed here).
ellipse_cols = (
    ['sat_id', 'ellipse_id', 'major', 'minor', 'loss']
    + ['r{}{}'.format(row, col) for row in range(3) for col in range(3)]
    + ['t1', 't2', 't3']
)

In [5]:
# Load the simulated and real tables (paths are relative to this notebook).
data_sim, data_real = pd.read_csv('../data/data_sim.csv'), pd.read_csv('../data/data_real.csv')

In [6]:
# Re-label both tables with the canonical names. Going through `.values`
# discards the CSV headers and converts all columns to one common dtype.
data_sim, data_real = [pd.DataFrame(frame.values, columns=ellipse_cols) for frame in (data_sim, data_real)]

In [7]:
# (rows, columns) of the simulated and real tables.
tuple(frame.shape for frame in (data_sim, data_real))

((38629, 17), (26790, 17))

# Axis Angle Representation

In [8]:
def convert_matrix(data):
    """Replace the flattened rotation-matrix columns with a rotation vector, in place.

    Reads the nine columns ``r00``..``r22``, reshapes each row into a 3x3
    matrix (row-major) and converts it to an axis-angle rotation vector with
    SciPy. The three vector components are stored as new columns ``r1``,
    ``r2``, ``r3``, and the nine matrix columns are dropped.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``r00``..``r22``. Modified in place.

    Returns
    -------
    None
    """
    matrix_cols = ['r00', 'r01', 'r02', 'r10', 'r11', 'r12', 'r20', 'r21', 'r22']
    # Select by name so the result does not depend on the columns being contiguous.
    matrices = data[matrix_cols].values.reshape(-1, 3, 3)
    if len(matrices):
        # One batched call over the whole (N, 3, 3) stack instead of a Python loop.
        rotvec = R.from_matrix(matrices).as_rotvec()
    else:
        # Empty frame: keep a (0, 3) shape. The old per-row list gave a 1-D
        # empty array that could not be unpacked into three columns.
        rotvec = np.empty((0, 3))
    for k, name in enumerate(('r1', 'r2', 'r3')):
        data[name] = rotvec[:, k]
    data.drop(matrix_cols, axis=1, inplace=True)

In [9]:
# convert_matrix mutates each frame in place and returns None.
for frame in (data_sim, data_real):
    convert_matrix(frame)

# Train-Test Split (matched sim/real rows vs. sim-only rows)

In [10]:
# Outer-join sim and real rows on (sat_id, ellipse_id). The `_merge` indicator
# column records whether a row came from both tables or only one.
merged_data = data_sim.merge(
    data_real,
    on=['sat_id', 'ellipse_id'],
    how='outer',
    suffixes=['_sim', '_real'],
    indicator=True,
)
merge = merged_data['_merge']
merged_data = merged_data.drop(columns=['ellipse_id', 'loss_sim', 'loss_real'])

In [11]:
# Rows found in both tables (sim input + real target) are the training set.
# Rows found only in the simulated table are the test set.
train_data = merged_data[merged_data['_merge'] == 'both'].drop('_merge', axis=1).reset_index(drop=True)
# Take sat_id + the *_sim columns explicitly rather than `dropna(axis=1)`.
# That call dropped any sim column containing even one NaN, and on an empty
# test set it kept the all-NaN *_real columns, breaking the 8-column relabel.
test_data = merged_data.loc[merged_data['_merge'] == 'left_only', :'r3_sim'].reset_index(drop=True)

In [12]:
# Keep each row's satellite id (index is 0..n-1 after reset_index) before
# the column is removed from the feature frames.
sat_ids_train, sat_ids_test = train_data['sat_id'], test_data['sat_id']

In [13]:
# sat_id is not used as a model input; its values live on in sat_ids_train / sat_ids_test.
train_data = train_data.drop(columns='sat_id')
test_data = test_data.drop(columns='sat_id')

In [14]:
# Targets: the real columns (major_real onward).
# Inputs: the sim columns (up to and including r3_sim).
labels, train_data = train_data.loc[:, 'major_real':], train_data.loc[:, :'r3_sim']

In [15]:
# Unsuffixed names shared by inputs and labels, in merged-table column order.
cols = ['major', 'minor'] + ['t{}'.format(i) for i in range(1, 4)] + ['r{}'.format(i) for i in range(1, 4)]

In [16]:
# Rebuild the three frames under the plain names in `cols`,
# dropping the _sim / _real suffixes.
train_data, test_data, labels = [
    pd.DataFrame(frame.values, columns=cols) for frame in (train_data, test_data, labels)
]

In [17]:
# Learn the residual (real - sim) rather than the real values themselves.
# NOTE(review): r1..r3 are rotation vectors, whose direction flips as the angle
# crosses pi. Their element-wise difference can therefore jump (cf. rows 5 -> 6
# of the labels preview). Confirm this target encoding is intended.
labels = labels.sub(train_data)

In [18]:
# (rows, columns) of training inputs, test inputs and training labels.
tuple(frame.shape for frame in (train_data, test_data, labels))

((26790, 8), (11839, 8), (26790, 8))

# Train-CV Split per Satellite

In [47]:
# Fraction of each satellite's rows (the trailing ones) held out for cross-validation.
split_ratio = 0.10

In [48]:
# Per-satellite containers keyed by sat_id, filled by the split loop.
train, cv, labels_train, labels_cv = {}, {}, {}, {}

In [49]:
# Split each satellite's rows, kept in their original order, into a leading
# training part and a trailing CV part holding about `split_ratio` of the rows.
# The satellite count now comes from the data: the old hard-coded 600 silently
# skipped any id >= 600 (the result is identical when the largest id is 599).
# Each boolean mask is also built once instead of twice.
n_satellites = int(sat_ids_train.max()) + 1 if len(sat_ids_train) else 0
for sat_id in range(n_satellites):
    in_sat = sat_ids_train == sat_id
    sat_data = train_data[in_sat]
    sat_labels = labels[in_sat]
    # NOTE(review): a satellite with a single row gets an empty training part.
    train_size = int(sat_data.shape[0] * (1 - split_ratio))
    train[sat_id] = sat_data.iloc[:train_size, :].reset_index(drop=True)
    cv[sat_id] = sat_data.iloc[train_size:, :].reset_index(drop=True)
    labels_train[sat_id] = sat_labels.iloc[:train_size, :].reset_index(drop=True)
    labels_cv[sat_id] = sat_labels.iloc[train_size:, :].reset_index(drop=True)

# Network Architecture

In [160]:
class SatNet(nn.Module):
    """Two-layer fully connected regressor with ``input_size`` inputs and outputs.

    Parameters
    ----------
    input_size : int
        Number of input features; also the number of outputs.
    hidden_size : int
        Width of the hidden layer.
    activation : nn.Module or callable or None, optional
        Non-linearity applied between the two layers. The default ``None``
        applies none, which reproduces the original purely linear model.
        Without an activation, ``fc2(fc1(x))`` collapses to a single affine
        map of rank at most ``hidden_size``. Passing e.g. ``nn.ReLU()`` is what
        lets the hidden layer add capacity.
    """

    def __init__(self, input_size, hidden_size, activation=None):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, input_size)
        # nn.Identity has no parameters, so the default keeps state_dict keys
        # (fc1.*, fc2.*) and parameter initialisation identical to before.
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))`` for ``x`` whose last dim is ``input_size``."""
        return self.fc2(self.activation(self.fc1(x)))

In [161]:
def train_satnet(sat_data, satnet_params, iterations, lr, log_every=1000):
    """Fit a fresh SatNet on one satellite with full-batch Adam and MSE loss.

    Parameters
    ----------
    sat_data : sequence of 4 numpy arrays
        ``(train, cv, labels_train, labels_cv)``, each converted to a float32 tensor.
    satnet_params : sequence of 2 ints
        ``(input_size, hidden_size)`` forwarded to ``SatNet``.
    iterations : int
        Number of full-batch gradient steps.
    lr : float
        Adam learning rate.
    log_every : int or None, optional
        Print train/CV loss every ``log_every`` iterations (default 1000, as
        before). Pass 0 or None to disable logging.

    Returns
    -------
    SatNet
        The trained model.
    """
    x_train, x_cv, y_train, y_cv = [torch.from_numpy(arr).float() for arr in sat_data]
    input_size, hidden_size = satnet_params
    model = SatNet(input_size, hidden_size)
    optim = torch.optim.Adam(model.parameters(), lr)

    for itr in range(iterations):
        optim.zero_grad()
        loss = torch.mean((model(x_train) - y_train) ** 2)
        loss.backward()
        if log_every and itr % log_every == 0:
            # The CV loss is only needed for logging. Evaluate it here, with the
            # same pre-step weights as the train loss, without an autograd graph.
            with torch.no_grad():
                loss_cv = torch.mean((model(x_cv) - y_cv) ** 2)
            print('Iteration: {} | Loss (Train): {} | Loss (CV): {}'.format(itr, loss.item(), loss_cv.item()))
        optim.step()

    return model

In [162]:
# Hyper-parameters shared by every per-satellite model.
hidden_size, iterations, lr = 20, 10000, 1e-4
input_size = train_data.shape[1]  # number of feature columns (8 here)
satnet_params = [input_size, hidden_size]

In [164]:
# Train one model per satellite, currently only for ids 0..9.
models = {}
for sat_id in range(10):
    # Fixed typo in the progress header ("Satellited" -> "Satellite").
    print('\nSatellite ID:', sat_id, '\n')
    sat_data = [train[sat_id].values, cv[sat_id].values, labels_train[sat_id].values, labels_cv[sat_id].values]
    models[sat_id] = train_satnet(sat_data, satnet_params, iterations, lr)


Satellited ID: 0 

Iteration: 0 | Loss (Train): 0.9590092301368713 | Loss (CV): 0.7717316746711731
Iteration: 1000 | Loss (Train): 0.01771736703813076 | Loss (CV): 0.0043374584056437016
Iteration: 2000 | Loss (Train): 0.006953138392418623 | Loss (CV): 0.00025927380193024874
Iteration: 3000 | Loss (Train): 0.006901136599481106 | Loss (CV): 0.000193530140677467
Iteration: 4000 | Loss (Train): 0.006901047192513943 | Loss (CV): 0.00019347229681443423
Iteration: 5000 | Loss (Train): 0.006900903768837452 | Loss (CV): 0.00019368350331205875
Iteration: 6000 | Loss (Train): 0.006900662090629339 | Loss (CV): 0.00019403624173719436
Iteration: 7000 | Loss (Train): 0.006900261156260967 | Loss (CV): 0.0001946228148881346
Iteration: 8000 | Loss (Train): 0.006899595260620117 | Loss (CV): 0.00019560314831323922
Iteration: 9000 | Loss (Train): 0.0068984865210950375 | Loss (CV): 0.00019724918820429593

Satellited ID: 1 

Iteration: 0 | Loss (Train): 0.3217967450618744 | Loss (CV): 1.0191320180892944
Ite

In [166]:
# Preview the residual labels instead of rendering the whole 26790-row frame.
labels.head(10)

Unnamed: 0,major,minor,t1,t2,t3,r1,r2,r3
0,0.000395,-0.000241,0.001483,-0.000020,0.001934,0.002871,-0.001600,0.002137
1,0.000370,-0.000283,0.001507,-0.000028,0.001948,0.002899,-0.001616,0.002159
2,0.000051,-0.001444,0.001518,-0.000064,0.001934,0.002843,-0.001584,0.002126
3,-0.000035,-0.001737,0.001535,-0.000038,0.001939,0.002726,-0.001516,0.002053
4,0.000311,-0.000268,0.001525,-0.000077,0.001902,0.002605,-0.001445,0.001976
5,0.000282,-0.000328,0.001524,-0.000041,0.001899,0.002456,-0.001358,0.001879
6,0.000299,-0.000259,0.001498,-0.000102,0.001855,-1.923189,-0.263089,1.927239
7,0.000791,0.001535,0.001461,-0.000213,0.001792,-1.923094,-0.262922,1.926859
8,0.000628,0.000837,0.001469,-0.000160,0.001828,-1.922951,-0.262782,1.926539
9,0.000553,0.000555,0.001476,-0.000129,0.001860,-1.922803,-0.262651,1.926253
