# Load Data and Modules

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from preprocessing_utils import pickle_save, pickle_load

In [4]:
from sklearn.model_selection import train_test_split
from scipy.spatial.transform import Rotation as R

In [5]:
ellipse_cols = ['sat_id', 'ellipse_id', 'major', 'minor', 'loss', 'r00', 'r01', 'r02', 'r10', 'r11', 'r12', 'r20', 'r21', 'r22',
                't1', 't2', 't3']

In [6]:
data_sim = pd.read_csv('../data/data_sim.csv')
data_real = pd.read_csv('../data/data_real.csv')

In [7]:
data_sim = pd.DataFrame(data_sim.values, columns = ellipse_cols)
data_real = pd.DataFrame(data_real.values, columns = ellipse_cols)

In [8]:
data_sim.shape, data_real.shape

((38629, 17), (26790, 17))

# Axis Angle Representation

In [9]:
def convert_matrix(data):
    matrix = data.loc[:, 'r00':'r22'].values.reshape(-1, 3, 3)
    rotvec = np.array([R.from_matrix(m).as_rotvec() for m in matrix])
    phis = np.linalg.norm(rotvec,axis=1,keepdims=True)
    rotvec = rotvec/phis
    data['r1'], data['r2'], data['r3'] = np.transpose(rotvec)
    data['phi'] = phis[:,0]
    print(rotvec.shape)
    data.drop(['r00', 'r01', 'r02', 'r10', 'r11', 'r12', 'r20', 'r21', 'r22'], axis=1, inplace=True)

In [10]:
a = np.random.randn(2,3)
norm = np.linalg.norm(a,axis=1,keepdims=True)
np.sum((a/norm)**2,axis=1)
norm[:,0]

array([1.44069049, 1.29727874])

In [11]:
convert_matrix(data_sim)
convert_matrix(data_real)

(38629, 3)
(26790, 3)


# Train-Test Segregation

In [12]:
merged_data = data_sim.merge(data_real, how='outer', on=['sat_id', 'ellipse_id'], suffixes=['_sim', '_real'], indicator=True)
merge = merged_data['_merge']
merged_data.drop(['ellipse_id', 'loss_sim', 'loss_real'], axis=1, inplace=True)

In [13]:
train_data = merged_data[merged_data['_merge'] == 'both'].drop('_merge', axis=1).reset_index(drop=True)
test_data = merged_data[merged_data['_merge'] == 'left_only'].drop('_merge', axis=1).dropna(axis=1).reset_index(drop=True)

In [14]:
sat_ids_train = train_data['sat_id']
sat_ids_test = test_data['sat_id']

In [15]:
train_data.drop('sat_id', axis=1, inplace=True)
test_data.drop('sat_id', axis=1, inplace=True)

In [16]:
labels = train_data.loc[:, 'major_real':]
train_data = train_data.loc[:, :'phi_sim']

In [17]:
cols = ['major', 'minor', 't1', 't2', 't3', 'r1', 'r2', 'r3','phi']

In [18]:
train_data = pd.DataFrame(train_data.values, columns=cols)
test_data = pd.DataFrame(test_data.values, columns=cols)
labels = pd.DataFrame(labels.values, columns=cols)

In [19]:
train_data.shape, test_data.shape, labels.shape

((26790, 9), (11839, 9), (26790, 9))

# Train-CV Split per Satellite

In [20]:
split_ratio = 0.1

In [21]:
train = {}
cv = {}
labels_train = {}
labels_cv = {}

In [22]:
for sat_id in range(600):
    sat_data = train_data[sat_ids_train == sat_id]
    sat_labels = labels[sat_ids_train == sat_id]
    train_size = int(sat_data.shape[0] * (1 - split_ratio))
    train[sat_id] = sat_data.iloc[:train_size, :].reset_index(drop=True)
    cv[sat_id] = sat_data.iloc[train_size:, :].reset_index(drop=True)
    labels_train[sat_id] = sat_labels.iloc[:train_size, :].reset_index(drop=True)
    labels_cv[sat_id] = sat_labels.iloc[train_size:, :].reset_index(drop=True)

In [23]:
def relative_error(sat_id,field):
    delta = np.mean(np.abs(labels_train[sat_id][field]))
    den = np.mean(np.abs(train[sat_id][field]))
    return delta/den

# Network Architecture

In [24]:
class SatNet(nn.Module):
    
    def __init__(self, input_size, hidden_size):
        super(SatNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, input_size)
    
    def forward(self, x):
        a1 = torch.relu(self.fc1(x))
        a2 = self.fc2(a1)
        return a2+x

In [25]:
def train_satnet(sat_data, satnet_params, iterations, lr,l2 = 1e-8):
    train, cv, labels_train, labels_cv = [torch.from_numpy(data).float() for data in sat_data]
    input_size, hidden_size = satnet_params
    model = SatNet(input_size, hidden_size)
    optim = torch.optim.Adam(model.parameters(), lr,weight_decay = l2)
    
    for itr in range(iterations):
        optim.zero_grad()
        pred = model(train)
        pred_cv = model(cv)

        axis_pred =  pred[:,:2]
        trans_pred = pred[:,2:5]
        rot_pred = pred[:,5:8]/(1e-5+torch.sqrt(torch.sum(pred[:,5:8]**2,dim=1,keepdim=True)))
        phi_pred = torch.fmod(pred[:,8],2*np.pi)

        axis_act =  labels_train[:,:2]
        trans_act = labels_train[:,2:5]
        rot_act = labels_train[:,5:8]/(1e-5+torch.sqrt(torch.sum(labels_train[:,5:8]**2,dim=1,keepdim=True)))
        phi_act = torch.fmod(labels_train[:,8],2*np.pi)

        loss = torch.mean((axis_pred-axis_act)**2) + torch.mean((trans_pred-trans_act)**2) + torch.mean((rot_pred-rot_act)**2) + torch.mean((phi_act-phi_pred)**2)
        

        loss.backward()
        optim.step()
        if(itr % 1000 == 0):
            
            axis_pred =  pred_cv[:,:2]
            trans_pred = pred_cv[:,2:5]
            rot_pred = pred_cv[:,5:8]/(1e-5+torch.sqrt(torch.sum(pred_cv[:,5:8]**2,dim=1,keepdim=True)))
            phi_pred = torch.fmod(pred_cv[:,8],2*np.pi)

            axis_act =  labels_cv[:,:2]
            trans_act = labels_cv[:,2:5]
            rot_act = labels_cv[:,5:8]/(1e-5+torch.sqrt(torch.sum(labels_cv[:,5:8]**2,dim=1,keepdim=True)))
            phi_act = torch.fmod(labels_cv[:,8],2*np.pi)

            axis_loss = torch.mean((axis_pred-axis_act)**2)
            trans_loss = torch.mean((trans_pred-trans_act)**2) 
            rot_loss = torch.mean((rot_pred-rot_act)**2) 
            phi_loss = torch.mean((phi_act-phi_pred)**2) 

            cv_loss = axis_loss + trans_loss + rot_loss + phi_loss

            print('Iteration: {} | Loss (Train): {} | Loss (CV): {}'.format(itr, loss.item(), cv_loss.item()))
            print('CV - axis {} trans {} rot {} phi {}'.format(axis_loss.item(),trans_loss.item(),rot_loss.item(),phi_loss.item()))
   
    return model

In [26]:
input_size = train_data.shape[1]
hidden_size = 60
iterations = 10000
lr = 0.0001
satnet_params = [input_size, hidden_size]

In [27]:
models = {}
sat_id = np.random.randint(600)
print('\nSatellite ID:', sat_id,'\n')
sat_data = [train[sat_id].values, cv[sat_id].values, labels_train[sat_id].values, labels_cv[sat_id].values]
models[sat_id] = train_satnet(sat_data, satnet_params, iterations, lr,l2 = 1e-5)


Satellite ID: 464 

Iteration: 0 | Loss (Train): 4.730100154876709 | Loss (CV): 0.24831189215183258
CV - axis 0.11112047731876373 trans 0.09274056553840637 rot 0.044450268149375916 phi 5.794747153231583e-07
Iteration: 1000 | Loss (Train): 0.5197770595550537 | Loss (CV): 4.0059814453125
CV - axis 8.068723400356248e-05 trans 9.703503019409254e-05 rot 0.6182366609573364 phi 3.3875670433044434
Iteration: 2000 | Loss (Train): 0.2726769745349884 | Loss (CV): 4.7191267013549805
CV - axis 7.742835441604257e-05 trans 9.679208596935496e-05 rot 0.6327710151672363 phi 4.086181640625
Iteration: 3000 | Loss (Train): 0.27124497294425964 | Loss (CV): 4.755521297454834
CV - axis 5.227794827078469e-05 trans 6.796960951760411e-05 rot 0.628968358039856 phi 4.1264328956604
Iteration: 4000 | Loss (Train): 0.2707412540912628 | Loss (CV): 4.715027809143066
CV - axis 1.2558839443954639e-05 trans 1.9669761968543753e-05 rot 0.6165617108345032 phi 4.098433971405029
Iteration: 5000 | Loss (Train): 0.2744807302951

In [28]:
satnets = pickle_load('../data/satnets.pickle')

In [44]:
def cv_score(sat_id):
    sat_cv = torch.from_numpy(cv[sat_id].values).float()
    pred_cv = satnets[sat_id](sat_cv)
    mse = pred_cv - torch.from_numpy(labels_cv[sat_id].values).float()
    mse = torch.mean(mse ** 2).item()
    return mse

In [51]:
scores = []
for sat_id, _ in satnets.items():
    score = cv_score(sat_id)
    scores.append(score)
    print(sat_id, score)

0 0.02037404477596283
1 0.5222344398498535
2 0.22667211294174194
3 0.07419899106025696
4 0.10231594741344452
5 0.10988763719797134
6 0.1876893788576126
7 0.39337870478630066
8 0.1258256733417511
9 0.13783220946788788
10 0.000905217370018363
11 0.21004505455493927
12 0.10607798397541046
13 0.06882281601428986
14 0.2971198856830597
15 0.09913583844900131
16 0.14530642330646515
17 0.32698193192481995
18 0.009965124540030956
19 0.028068194165825844
20 0.14388258755207062
21 0.23603878915309906
22 0.10253661125898361
23 0.10837914794683456
24 0.010680575855076313
25 0.10977118462324142
26 1.1370787620544434
27 0.11047343164682388
28 0.02859400399029255
29 0.02570461481809616
30 0.05909551680088043
31 0.14499376714229584
32 0.05390481278300285
33 0.10786513984203339
34 0.06427127122879028
35 0.8254964351654053
36 0.11031884700059891
37 0.15440674126148224
38 0.09727097302675247
39 0.015676429495215416
40 0.10112927854061127
41 0.14696034789085388
42 0.17483873665332794
43 0.13476671278476715

In [55]:
np.mean(scores)

0.2526441873541539