In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from preprocessing_utils import process_sat_data, pickle_save, pickle_load

In [3]:
data_train = pd.read_csv('../data/train.csv')
data_test = pd.read_csv('../data/track1/test.csv')

In [4]:
train_data = data_train.copy(deep=True)
test_data = data_test.copy(deep=True)

In [5]:
process_sat_data(train_data, test_data, scale=10000)

In [6]:
import torch
import torch.nn as nn

In [7]:
from torch.utils.data import DataLoader
import torch.optim as optim

In [8]:
class SatNet(nn.Module):
    def __init__(self, name, omega, mean, diff):
        super(SatNet, self).__init__()
        self.name = name
        self.w = nn.Parameter(omega)
        self.off = nn.Parameter(mean)
        self.a1 = nn.Parameter(diff)
        self.a2 = nn.Parameter(diff)
        self.p1 = nn.Parameter(torch.randn(1))
        self.p2 = nn.Parameter(torch.randn(1))
    def forward(self, t):
        s1 = self.a1 * torch.cos(self.p1 + self.w * t)
        #s2 = self.a2 * torch.cos(self.p2 + self.w * t * 2)
        s = self.off + s1 #+ s2
        v1 = - self.a1 * self.w * torch.sin(self.w * t + self.p1)
        #v2 = - self.a2 * self.w * torch.sin(self.w * t + self.p2) * 2
        v = v1 #+ v2
        return [s, v]

In [9]:
def smape_loss(satellite_predicted_values,satellite_true_values):
    return torch.mean(torch.abs((satellite_predicted_values - satellite_true_values) 
        / (torch.abs(satellite_predicted_values) + torch.abs(satellite_true_values))))

In [10]:
def get_initial_values(ellipse_data):
    omega = 2 * np.pi / (torch.max(ellipse_data[:, 6]) - torch.min(ellipse_data[:, 6])) # wT = 2(pi)
    mean = torch.mean(ellipse_data[:,[0,1,2]], dim=0)
    amplitude_x = 0.5 * (torch.max(ellipse_data[:,0]) - torch.min(ellipse_data[:,0]))
    amplitude_y = 0.5 * (torch.max(ellipse_data[:,1]) - torch.min(ellipse_data[:,1]))
    amplitude_z = 0.5 * (torch.max(ellipse_data[:,2]) - torch.min(ellipse_data[:,2]))
    return [omega, mean, amplitude_x, amplitude_y, amplitude_z]

In [11]:
def print_status(itr, ellipse_data, loss_x, loss_y, loss_z, pred):
    x_pred, y_pred, z_pred, vx_pred, vy_pred, vz_pred = pred
    smape_x = smape_loss(x_pred, ellipse_data[:,0])
    smape_y = smape_loss(y_pred, ellipse_data[:,1])
    smape_z = smape_loss(z_pred, ellipse_data[:,2])
    smape_vx = smape_loss(vx_pred, ellipse_data[:,3])
    smape_vy = smape_loss(vy_pred, ellipse_data[:,4])
    smape_vz = smape_loss(vz_pred, ellipse_data[:,5])
    print('Iteration: ', itr)
    print('Loss: {} (X) | {} (Y) | {} (Z)'.format(loss_x.item(), loss_y.item(), loss_z.item()))
    print('SMAPE (Position): {} (X) | {} (Y) | {} (Z)'.format(smape_x, smape_y, smape_z))
    print('SMAPE (Velocity): {} (X) | {} (Y) | {} (Z)'.format(smape_vx, smape_vy, smape_vz))

In [12]:
df_cols = ['sat_id', 'ellipse_id', 'smape_x', 'smape_y', 'smape_z',
           'a1_x', 'a2_x', 'p1_x', 'p2_x', 'w_x', 'off_x',
           'a1_y', 'a2_y', 'p1_y', 'p2_y', 'w_y', 'off_y',
           'a1_z', 'a2_z', 'p1_z', 'p2_z', 'w_z', 'off_z']

In [13]:
def save_weights(sat_nets, sat_id, ellipse_id, smape_x, smape_y, smape_z, weights):
    row = [sat_id, ellipse_id, smape_x, smape_y, smape_z]
    for axis in ['x', 'y', 'z']:
        model = sat_nets[sat_id][ellipse_id][axis]
        params = [model.a1, model.a2, model.p1, model.p2, model.w, model.off]
        params = [item.item() for item in params]
        row += params
    weights = np.concatenate([weights.values, [row]], axis=0)
    weights = pd.DataFrame(weights, columns=df_cols)
    weights.to_excel('../data/weights.xlsx', index=False)
    return weights

In [18]:
def train(train_data):
    sat_nets = {}
    weights = pd.DataFrame([[-1] * len(df_cols)], columns=df_cols)
    weights.to_excel('../data/weights.xlsx', index=False)
    for sat_id in range(6,7):
        
        print('------- Satellite ID:', sat_id, '------- ')
        cols = ['x_sim', 'y_sim', 'z_sim', 'Vx_sim', 'Vy_sim', 'Vz_sim', 'epoch']
        sat_data = train_data[train_data['sat_id'] == sat_id].loc[:, cols].values
        print('\n## Number of Ellipses: ', int(sat_data.shape[0]/24)+1, '\n')
        sat_nets[sat_id] = {}
        data = DataLoader(sat_data, batch_size=24) # 24 points per ellipse
        data_iters = iter(data)
        
        for ellipse_id, ellipse_data in enumerate(data_iters):
            
            print('\n**** Ellipse ID:', ellipse_id, '****\n')
            ellipse_data = ellipse_data.float()
            if (ellipse_id == 0):
                omega, mean, amplitude_x, amplitude_y, amplitude_z = get_initial_values(ellipse_data)
                sat_nets[sat_id][ellipse_id] = {
                'x': SatNet('x', omega, mean[0], amplitude_x),
                'y': SatNet('y', omega, mean[1], amplitude_y),
                'z': SatNet('z', omega, mean[2], amplitude_z)
                }
            else:
                prev = sat_nets[sat_id][ellipse_id-1]
                sat_nets[sat_id][ellipse_id] = {
                'x': SatNet('x', prev['x'].w, prev['x'].off, prev['x'].a1),
                'y': SatNet('y', prev['y'].w, prev['y'].off, prev['y'].a1),
                'z': SatNet('z', prev['z'].w, prev['z'].off, prev['z'].a1)
                }
            
            
            optim_x = optim.Adam(sat_nets[sat_id][ellipse_id]['x'].parameters(), 0.001, (0.9, 0.999))
            optim_y = optim.Adam(sat_nets[sat_id][ellipse_id]['y'].parameters(), 0.001, (0.9, 0.999))
            optim_z = optim.Adam(sat_nets[sat_id][ellipse_id]['z'].parameters(), 0.001, (0.9, 0.999))
            
            for itr in range(9000):
                optim_x.zero_grad()
                optim_y.zero_grad()
                optim_z.zero_grad()
                
                time = ellipse_data[:,6].float()
                x_pred, vx_pred = sat_nets[sat_id][ellipse_id]['x'](time)
                y_pred, vy_pred = sat_nets[sat_id][ellipse_id]['y'](time)
                z_pred, vz_pred = sat_nets[sat_id][ellipse_id]['z'](time)
                
                loss_sx = smape_loss(x_pred, ellipse_data[:,0])
                loss_sy = smape_loss(y_pred, ellipse_data[:,1])
                loss_sz = smape_loss(z_pred, ellipse_data[:,2])
                loss_vx = smape_loss(vx_pred, ellipse_data[:,3])
                loss_vy = smape_loss(vy_pred, ellipse_data[:,4])
                loss_vz = smape_loss(vz_pred, ellipse_data[:,5])
                loss_x = loss_sx + loss_vx
                loss_y = loss_sy + loss_vy
                loss_z = loss_sz + loss_vz
                
                loss_x.backward()
                loss_y.backward()
                loss_z.backward()
                
                optim_x.step()
                optim_y.step()
                optim_z.step()
                
                if(itr % 1000 == 0):
                    pred = [x_pred, y_pred, z_pred, vx_pred, vy_pred, vz_pred]
                    print_status(itr, ellipse_data, loss_x, loss_y, loss_z, pred)
                    
            weights = save_weights(sat_nets, sat_id, ellipse_id, loss_sx, loss_sy, loss_sz, weights)
    
    return sat_nets

In [20]:
sat_nets = train(train_data)

------- Satellite ID: 6 ------- 

## Number of Ellipses:  18 


**** Ellipse ID: 0 ****

Iteration:  0
Loss: 1.5567963123321533 (X) | 0.6417148113250732 (Y) | 0.7999927997589111 (Z)
SMAPE (Position): 0.7784578204154968 (X) | 0.27821651101112366 (Y) | 0.38843047618865967 (Z)
SMAPE (Velocity): 0.7783384919166565 (X) | 0.363498330116272 (Y) | 0.41156235337257385 (Z)
Iteration:  1000
Loss: 0.3723428249359131 (X) | 0.32150688767433167 (Y) | 0.21926578879356384 (Z)
SMAPE (Position): 0.10616769641637802 (X) | 0.0871230885386467 (Y) | 0.07226164638996124 (Z)
SMAPE (Velocity): 0.26617512106895447 (X) | 0.23438380658626556 (Y) | 0.1470041424036026 (Z)
Iteration:  2000
Loss: 0.3219643831253052 (X) | 0.3155679702758789 (Y) | 0.19887295365333557 (Z)
SMAPE (Position): 0.07385324686765671 (X) | 0.08680790662765503 (Y) | 0.06268690526485443 (Z)
SMAPE (Velocity): 0.24811114370822906 (X) | 0.22876004874706268 (Y) | 0.13618604838848114 (Z)
Iteration:  3000
Loss: 0.31669002771377563 (X) | 0.31244072318077