# Load Data and Modules

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [32]:
from preprocessing_utils import pickle_save, pickle_load

In [3]:
from sklearn.model_selection import train_test_split
from scipy.spatial.transform import Rotation as R

In [4]:
# Raw CSV column names: ids, ellipse axes and loss, the nine rotation-matrix entries
# r00..r22 (row index then column index), and the three translation components.
ellipse_cols = (
    ['sat_id', 'ellipse_id', 'major', 'minor', 'loss']
    + ['r00', 'r01', 'r02', 'r10', 'r11', 'r12', 'r20', 'r21', 'r22']
    + ['t1', 't2', 't3']
)

In [5]:
# Load the simulated and real ellipse datasets.
# NOTE(review): read_csv consumes the first row as a header, and the next cell replaces
# the column names wholesale — confirm both CSVs really have a header row, otherwise
# the first record of each file is silently lost.
data_sim, data_real = (
    pd.read_csv('../data/data_sim.csv'),
    pd.read_csv('../data/data_real.csv'),
)

In [6]:
# Rebuild both frames with the canonical `ellipse_cols` names. Going through `.values`
# (a single NumPy array) also gives every column one common dtype.
data_sim, data_real = (
    pd.DataFrame(frame.values, columns=ellipse_cols) for frame in (data_sim, data_real)
)

In [7]:
tuple(frame.shape for frame in (data_sim, data_real))  # (rows, columns) of the sim and real data

((38629, 17), (26790, 17))

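# Axis-Angle Representation

Each row's 3×3 rotation matrix (`r00`–`r22`) is replaced by its unit rotation axis (`r1`, `r2`, `r3`) and its rotation angle `phi`.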

In [8]:
def convert_matrix(data):
    """Replace the rotation-matrix columns of `data` with an axis-angle encoding, in place.

    The columns from 'r00' through 'r22' are reshaped into 3x3 matrices (row-major) and
    converted to rotation vectors with SciPy. The unit axis is stored in new columns
    'r1', 'r2', 'r3' and the angle (the rotation vector's norm) in 'phi'. The nine
    matrix columns are then dropped. The shape of the axis array is printed.

    A zero-angle (identity) rotation has no defined axis. Such rows get the zero vector
    as their axis instead of the NaNs that a plain division by 0 would produce.

    Returns None; `data` is modified in place.
    """
    matrix = data.loc[:, 'r00':'r22'].values.reshape(-1, 3, 3)
    # Rotation.from_matrix accepts a stack of matrices, so all rows convert in one call.
    rotvec = R.from_matrix(matrix).as_rotvec()
    phis = np.linalg.norm(rotvec, axis=1, keepdims=True)
    # Divide only where the angle is non-zero; zero-angle rows keep an all-zero axis.
    rotvec = np.divide(rotvec, phis, out=np.zeros_like(rotvec), where=phis > 0)
    data['r1'], data['r2'], data['r3'] = np.transpose(rotvec)
    data['phi'] = phis[:, 0]
    print(rotvec.shape)
    data.drop(['r00', 'r01', 'r02', 'r10', 'r11', 'r12', 'r20', 'r21', 'r22'], axis=1, inplace=True)

In [9]:
# Sanity check for the row normalisation used in `convert_matrix`: dividing each row by its
# L2 norm (keepdims=True keeps shape (N, 1) so it broadcasts) yields unit-length rows.
# Seeded so the displayed norms are reproducible.
a = np.random.default_rng(0).standard_normal((2, 3))
norm = np.linalg.norm(a, axis=1, keepdims=True)
assert np.allclose(np.sum((a / norm) ** 2, axis=1), 1.0)
norm[:, 0]

array([0.76247935, 0.38325686])

In [10]:
# Convert both datasets in place; convert_matrix prints each resulting axis-array shape.
for frame in (data_sim, data_real):
    convert_matrix(frame)

(38629, 3)
(26790, 3)


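# Train/Test Split

Ellipses (matched on `sat_id` and `ellipse_id`) that appear in both the simulated and the real data become training pairs, with simulated inputs and real labels. Ellipses that appear only in the simulated data form the test set.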

In [11]:
# Outer-join sim and real rows per (sat_id, ellipse_id); `_merge` records which side(s) had the row.
merged_data = data_sim.merge(
    data_real,
    how='outer',
    on=['sat_id', 'ellipse_id'],
    suffixes=['_sim', '_real'],
    indicator=True,
)
merge = merged_data['_merge']
merged_data = merged_data.drop(columns=['ellipse_id', 'loss_sim', 'loss_real'])

In [12]:
# Rows present on both sides are training pairs. Sim-only rows are the test set; their
# all-NaN `_real` columns are removed by dropna(axis=1).
train_data = (
    merged_data.loc[merged_data['_merge'].eq('both')]
    .drop(columns='_merge')
    .reset_index(drop=True)
)
test_data = (
    merged_data.loc[merged_data['_merge'].eq('left_only')]
    .drop(columns='_merge')
    .dropna(axis=1)
    .reset_index(drop=True)
)

In [14]:
# Keep each row's satellite id aside before the id column is dropped from the features.
sat_ids_train, sat_ids_test = train_data['sat_id'], test_data['sat_id']

In [15]:
train_data = train_data.drop(columns='sat_id')
test_data = test_data.drop(columns='sat_id')

In [17]:
# Relies on the merged column order: all `_sim` columns come first (ending at 'phi_sim'),
# then all `_real` columns (starting at 'major_real'), which become the labels.
labels, train_data = train_data.loc[:, 'major_real':], train_data.loc[:, :'phi_sim']

In [18]:
# Feature names shared by inputs and labels once the _sim/_real suffixes are stripped.
cols = ['major', 'minor'] + ['t1', 't2', 't3'] + ['r1', 'r2', 'r3'] + ['phi']

In [19]:
# Give inputs, test inputs and labels the same suffix-free column names from `cols`.
train_data, test_data, labels = (
    pd.DataFrame(frame.values, columns=cols) for frame in (train_data, test_data, labels)
)

In [20]:
tuple(frame.shape for frame in (train_data, test_data, labels))  # (rows, columns) per frame

((26790, 9), (11839, 9), (26790, 9))

# Train-CV Split per Satellite

In [21]:
split_ratio = 0.10  # fraction of each satellite's rows (taken from the end) held out for CV

In [22]:
# Per-satellite containers, keyed by sat_id: inputs / labels for the train and CV parts.
train, cv, labels_train, labels_cv = {}, {}, {}, {}

In [23]:
# For every satellite id 0..599 (assumed to be the full id range — TODO confirm), keep the
# first (1 - split_ratio) share of its rows, in their current order, for training and the
# remaining tail for cross-validation. Indices are reset so each piece starts at 0.
for sat_id in range(600):
    in_sat = sat_ids_train == sat_id
    sat_data, sat_labels = train_data[in_sat], labels[in_sat]
    train_size = int(sat_data.shape[0] * (1 - split_ratio))
    train[sat_id], cv[sat_id] = (
        sat_data.iloc[:train_size, :].reset_index(drop=True),
        sat_data.iloc[train_size:, :].reset_index(drop=True),
    )
    labels_train[sat_id], labels_cv[sat_id] = (
        sat_labels.iloc[:train_size, :].reset_index(drop=True),
        sat_labels.iloc[train_size:, :].reset_index(drop=True),
    )

In [24]:
def relative_error(sat_id, field, inputs=None, targets=None):
    """Mean absolute sim-vs-real discrepancy of one column, relative to the sim magnitude.

    Computes mean(|targets[sat_id][field] - inputs[sat_id][field]|) / mean(|inputs[sat_id][field]|).
    Rows of the two columns are aligned by index.

    Parameters
    ----------
    sat_id : key selecting one satellite's DataFrame in both dicts.
    field : column name to compare.
    inputs : dict of per-satellite input frames; defaults to the global `train`.
    targets : dict of per-satellite label frames; defaults to the global `labels_train`.

    Returns
    -------
    The ratio as a NumPy float (nan/inf if the inputs are empty or all zero).
    """
    if inputs is None:
        inputs = train
    if targets is None:
        targets = labels_train
    sim = inputs[sat_id][field]
    # Numerator is the discrepancy (label - input), not the label's own magnitude.
    delta = np.mean(np.abs(targets[sat_id][field] - sim))
    den = np.mean(np.abs(sim))
    return delta / den

# Network Architecture

In [25]:
class SatNet(nn.Module):
    """Residual two-layer MLP computing x + fc2(relu(fc1(x))).

    fc2 maps back to `input_size`, so the network learns a correction that is added to
    its own input.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Creating fc1 before fc2 fixes both the parameter-initialisation order and the
        # state_dict key names, so keep this order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        correction = self.fc2(hidden)
        return correction + x

In [27]:
def _satnet_losses(pred, target):
    """Per-component mean-squared errors between SatNet outputs and targets.

    Columns are addressed by position: 0-1 (major/minor), 2-4 (t1-t3), 5-7 (r1-r3)
    and 8 (phi). This assumes the column order of `cols`; TODO confirm if features change.
    Returns (axis_loss, trans_loss, rot_loss, phi_loss) as scalar tensors.

    The rotation axes are rescaled to unit length before comparison; 1e-5 keeps the
    division finite for all-zero vectors. phi is wrapped with torch.fmod(., 2*pi) on both
    sides. fmod keeps the sign of its input, so negative angles are not mapped into [0, 2*pi).
    """
    def unit_axis(t):
        return t[:, 5:8] / (1e-5 + torch.sqrt(torch.sum(t[:, 5:8] ** 2, dim=1, keepdim=True)))

    axis_loss = torch.mean((pred[:, :2] - target[:, :2]) ** 2)
    trans_loss = torch.mean((pred[:, 2:5] - target[:, 2:5]) ** 2)
    rot_loss = torch.mean((unit_axis(pred) - unit_axis(target)) ** 2)
    phi_loss = torch.mean((torch.fmod(target[:, 8], 2 * np.pi) - torch.fmod(pred[:, 8], 2 * np.pi)) ** 2)
    return axis_loss, trans_loss, rot_loss, phi_loss


def train_satnet(sat_data, satnet_params, iterations, lr, l2=1e-8, log_every=1000):
    """Train one SatNet with full-batch Adam and return it.

    Parameters
    ----------
    sat_data : sequence of four NumPy arrays
        (train inputs, CV inputs, train labels, CV labels); converted to float32 tensors.
    satnet_params : (input_size, hidden_size) passed to SatNet.
    iterations : number of optimisation steps; each step uses the whole training set.
    lr : Adam learning rate.
    l2 : Adam weight_decay.
    log_every : print train/CV losses on iterations divisible by this (must be > 0).

    Returns
    -------
    The trained SatNet.

    The training loss is the sum of the four `_satnet_losses` components. On logged
    iterations the CV loss is evaluated with the same pre-step weights as the printed
    training loss.
    """
    x_train, x_cv, y_train, y_cv = [torch.from_numpy(data).float() for data in sat_data]
    input_size, hidden_size = satnet_params
    model = SatNet(input_size, hidden_size)
    optim = torch.optim.Adam(model.parameters(), lr, weight_decay=l2)

    for itr in range(iterations):
        optim.zero_grad()
        axis_loss, trans_loss, rot_loss, phi_loss = _satnet_losses(model(x_train), y_train)
        loss = axis_loss + trans_loss + rot_loss + phi_loss

        log_now = itr % log_every == 0
        if log_now:
            # CV results only feed the log, so evaluate them here and without autograd.
            with torch.no_grad():
                cv_parts = _satnet_losses(model(x_cv), y_cv)

        loss.backward()
        optim.step()

        if log_now:
            axis_cv, trans_cv, rot_cv, phi_cv = cv_parts
            cv_loss = axis_cv + trans_cv + rot_cv + phi_cv
            print('Iteration: {} | Loss (Train): {} | Loss (CV): {}'.format(itr, loss.item(), cv_loss.item()))
            print('CV - axis {} trans {} rot {} phi {}'.format(axis_cv.item(), trans_cv.item(), rot_cv.item(), phi_cv.item()))

    return model

In [28]:
# SatNet hyper-parameters shared by every per-satellite model.
hidden_size = 60
input_size = len(train_data.columns)  # one input (and output) unit per feature column
satnet_params = [input_size, hidden_size]
iterations = 10 ** 4
lr = 1e-4

In [30]:
# Quick single-satellite run to sanity-check the training loop before the full sweep.
# Both the satellite pick and torch's weight initialisation are seeded, so re-running
# this cell reproduces the same satellite and the same model.
models = {}
sat_id = int(np.random.default_rng(0).integers(600))
torch.manual_seed(sat_id)
print('\nSatellite ID:', sat_id,'\n')
sat_data = [train[sat_id].values, cv[sat_id].values, labels_train[sat_id].values, labels_cv[sat_id].values]
models[sat_id] = train_satnet(sat_data, satnet_params, iterations, lr, l2=1e-5)


Satellite ID: 255 

Iteration: 0 | Loss (Train): 10.313447952270508 | Loss (CV): 6.7493510246276855
CV - axis 0.017890410497784615 trans 0.167100191116333 rot 1.120578408241272 phi 5.443781852722168
Iteration: 1000 | Loss (Train): 0.4488983154296875 | Loss (CV): 1.8658528327941895
CV - axis 0.0009981930488720536 trans 0.008395507000386715 rot 0.67507004737854 phi 1.1813890933990479
Iteration: 2000 | Loss (Train): 0.38977351784706116 | Loss (CV): 1.6959383487701416
CV - axis 0.001066126162186265 trans 0.00833609513938427 rot 0.4094248414039612 phi 1.277111291885376
Iteration: 3000 | Loss (Train): 0.36874112486839294 | Loss (CV): 1.6165990829467773
CV - axis 0.0007534809992648661 trans 0.007567178923636675 rot 0.31479009985923767 phi 1.2934882640838623
Iteration: 4000 | Loss (Train): 0.3444545865058899 | Loss (CV): 1.499233603477478
CV - axis 0.0005369294667616487 trans 0.008473901078104973 rot 0.2133978307247162 phi 1.276824951171875
Iteration: 5000 | Loss (Train): 0.33494389057159424 

In [33]:
# Train one SatNet per satellite id 0..599. torch is re-seeded with the satellite id, so
# each model is reproducible on its own, regardless of loop order or interrupted runs.
models = {}
for sat_id in range(600):
    print('\nSatellite ID:', sat_id,'\n')
    torch.manual_seed(sat_id)
    sat_data = [train[sat_id].values, cv[sat_id].values, labels_train[sat_id].values, labels_cv[sat_id].values]
    models[sat_id] = train_satnet(sat_data, satnet_params, iterations, lr, l2=1e-5)
    # Save after every satellite so an interrupted run keeps the models finished so far.
    # NOTE(review): the whole dict is passed each time, so the save cost presumably grows
    # with len(models) — consider checkpointing less often if this becomes slow.
    pickle_save(models, '../data/satnets.pickle')


Satellite ID: 0 

Iteration: 0 | Loss (Train): 1.322561502456665 | Loss (CV): 0.17919306457042694
CV - axis 0.004270268138498068 trans 0.1131247952580452 rot 0.060729119926691055 phi 0.0010688744951039553
Iteration: 1000 | Loss (Train): 0.08173535019159317 | Loss (CV): 0.0010363210458308458
CV - axis 1.414160942658782e-05 trans 3.418833500745677e-07 rot 8.513706416124478e-05 phi 0.0009367004968225956
Iteration: 2000 | Loss (Train): 0.07466565817594528 | Loss (CV): 0.00010920071508735418
CV - axis 1.475288809160702e-05 trans 3.203429912446154e-07 rot 8.268362580565736e-05 phi 1.1443858056736644e-05
Iteration: 3000 | Loss (Train): 0.07466395199298859 | Loss (CV): 0.00011188250209670514
CV - axis 1.5631547285011038e-05 trans 3.226311946491478e-07 rot 8.349534618901089e-05 phi 1.2432983567123301e-05
Iteration: 4000 | Loss (Train): 0.07465741038322449 | Loss (CV): 0.00012218212941661477
CV - axis 1.9527760741766542e-05 trans 3.167228044276271e-07 rot 8.859350782586262e-05 phi 1.37441411425