In [2]:
import csv
import os
import time
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from geopy.distance import distance, lonlat
from sklearn import preprocessing
from torch.utils.data import TensorDataset, DataLoader

In [3]:
class GRU_Handler:
    """A helper class to train, test and diagnose the GRU.

    Attributes:
        model: the wrapped ``nn.Module``. ``evaluate``/``distance_error`` call
            it as ``model(data, feed_forward)``, so its ``forward`` must accept
            that second argument.
        loss_fn: callable ``loss_fn(y_pred, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch with a loss value
            (``scheduler.step(metric)``).
        train_losses: mean training loss per epoch of the last ``train`` call.
        val_losses: mean validation loss per ``validation`` call.
        distances: distance errors recorded at each reporting epoch.
    """

    def __init__(self, model, loss_fn, optimizer, scheduler):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_losses = []
        self.val_losses = []
        self.distances = []

    @contextmanager
    def _inference_mode(self):
        """Temporarily put the model in eval mode with autograd disabled.

        The model's previous train/eval mode is restored on exit, even if the
        body raises.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                yield
        finally:
            self.model.train(was_training)

    def train(
        self,
        train_loader,
        val_loader=None,
        batch_size=100,
        n_epochs=15,
        input_dim=57,
        seq_len=89,
        print_step=100
    ):
        """Train the model for ``n_epochs`` epochs.

        Args:
            train_loader: iterable of ``(data, target)`` batches.
            val_loader: optional iterable of validation batches. When None,
                validation is skipped and the scheduler is stepped with the
                training loss instead of the validation loss.
            batch_size: unused; kept for backward compatibility.
            n_epochs: number of passes over ``train_loader``.
            input_dim: feature size used to reshape inputs.
            seq_len: sequence length used to reshape inputs; each batch is
                viewed as ``(-1, seq_len, input_dim)`` before the forward pass.
            print_step: report progress and record the distance error on
                ``train_loader`` every ``print_step`` epochs, starting with the
                first. A value of None or <= 0 disables reporting.

        Raises:
            ValueError: if ``train_loader`` (or ``val_loader``) yields no batches.
        """
        self.train_losses = []
        self.val_losses = []
        self.distances = []

        start_time = time.time()

        for epoch in range(n_epochs):
            # Validation/distance helpers restore the prior mode, but make sure
            # we really are in training mode (dropout etc.) for this epoch.
            self.model.train()
            train_loss = 0.0
            n_batches = 0

            for data, target in train_loader:
                n_batches += 1
                # Only the parameters need gradients, not the inputs.
                data = data.view(-1, seq_len, input_dim).cpu()
                target = target.cpu()

                self.optimizer.zero_grad()

                # Flatten prediction and target to (N, 2) so the loss compares
                # matching shapes.
                y_pred = self.predict(data).view(-1, 2)
                target = target.view(-1, 2)

                loss = self.loss_fn(y_pred, target)
                train_loss += loss.item()

                loss.backward()
                self.optimizer.step()

            if n_batches == 0:
                raise ValueError("train_loader yielded no batches")
            train_loss /= n_batches
            self.train_losses.append(train_loss)

            self.validation(val_loader)
            # Without a validation set there is no val loss to monitor, so fall
            # back to the training loss rather than indexing an empty list.
            monitored = self.val_losses[-1] if val_loader is not None else train_loss
            self.scheduler.step(monitored)

            if print_step and print_step > 0 and epoch % print_step == 0:
                mean_distance_error = self.distance_error(train_loader)
                self.distances.append(mean_distance_error)
                elapsed = time.time() - start_time
                start_time = time.time()
                val_text = (
                    "%.10f" % self.val_losses[-1] if val_loader is not None else "n/a"
                )
                print(
                    "Epoch %d Train loss: %.10f. Validation loss: %s. Elapsed time: %.5fs. Distance Error: %.5f km"
                    % (epoch + 1, train_loss, val_text, elapsed, mean_distance_error)
                )

    def predict(self, x_data):
        """Return the model's output for ``x_data`` (gradients are tracked)."""
        y_pred = self.model(x_data)
        return y_pred

    def validation(self, val_loader):
        """Append the mean per-batch loss over ``val_loader`` to ``val_losses``.

        Does nothing when ``val_loader`` is None. Runs in eval mode with
        gradients disabled and restores the model's previous mode afterwards.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        if val_loader is None:
            return
        val_loss = 0.0
        n_batches = 0
        with self._inference_mode():
            for data, target in val_loader:
                n_batches += 1
                # Same (N, 2) flattening as in ``train`` so losses are comparable.
                y_pred = self.model(data).view(-1, 2)
                val_loss += self.loss_fn(y_pred, target.view(-1, 2)).item()
        if n_batches == 0:
            raise ValueError("val_loader yielded no batches")
        self.val_losses.append(val_loss / n_batches)

    def evaluate(self, test_loader, seq_len=89, feed_forward=False):
        """Run the model over ``test_loader`` and collect un-normalised outputs.

        Args:
            test_loader: iterable of ``(data, target)`` batches.
            seq_len: sequence length used to fold the flat ``(N, 2)`` results
                back into ``(-1, seq_len, 2)`` arrays.
            feed_forward: passed through as the model's second forward argument.

        Returns:
            ``(actual3d, predicted3d, test_loss)``: the arrays have shape
            ``(-1, seq_len, 2)`` and ``test_loss`` is the mean per-batch loss
            on the normalised values.

        Raises:
            ValueError: if ``test_loader`` yields no batches.

        NOTE(review): relies on a notebook-level ``scaler`` object with an
        ``unnormalize`` method that is not defined in this class.
        """
        test_loss = 0.0
        n_batches = 0
        actual, predicted = [], []
        with self._inference_mode():
            for data, target in test_loader:
                n_batches += 1
                y_pred = self.model(data, feed_forward).view(-1, 2)
                target = target.view(-1, 2)
                test_loss += self.loss_fn(y_pred, target).item()
                actual.append(scaler.unnormalize(target.cpu().numpy()))
                predicted.append(scaler.unnormalize(y_pred.cpu().numpy()))
        if n_batches == 0:
            raise ValueError("test_loader yielded no batches")

        # The /10 presumably undoes a x10 scaling applied during preprocessing
        # -- TODO confirm against the data pipeline.
        actual = np.concatenate(actual, axis=0) / 10
        predicted = np.concatenate(predicted, axis=0) / 10
        actual3d = np.reshape(actual, (-1, seq_len, 2))
        predicted3d = np.reshape(predicted, (-1, seq_len, 2))
        return actual3d, predicted3d, test_loss / n_batches

    def distance_error(self, data_loader=None):
        """Return the mean per-batch distance (km) between predicted and true points.

        For each batch, the model is run with ``feed_forward=True``. Outputs and
        targets are flattened to ``(N, 2)``, and rows whose first target
        coordinate is exactly 0 are dropped (presumably padding -- TODO
        confirm). Both are then un-normalised and divided by 10, and each pair
        is measured with geopy's ``distance(...).km``. Column 0 is used as
        latitude and column 1 as longitude.

        Args:
            data_loader: iterable of ``(data, target)`` batches. When None,
                falls back to the notebook-level ``train_loader`` global, which
                was the original behaviour.

        Returns:
            The mean of the per-batch mean distances. Batches with no
            remaining rows are skipped. Returns ``nan`` if no batch had any.

        NOTE(review): relies on a notebook-level ``scaler`` object with an
        ``unnormalize`` method that is not defined in this class.
        """
        if data_loader is None:
            data_loader = train_loader
        batch_means = []
        with self._inference_mode():
            for data, target in data_loader:
                y_pred = self.model(data, feed_forward=True).view(-1, 2).cpu().numpy()
                target = target.view(-1, 2).cpu().numpy()

                mask = target[:, 0] != 0
                if not mask.any():
                    # Nothing to measure; an empty batch would otherwise crash.
                    continue
                # Boolean indexing copies, so the source tensors are never mutated.
                true_pts = scaler.unnormalize(target[mask]) / 10
                pred_pts = scaler.unnormalize(y_pred[mask]) / 10

                # lonlat(x, y) builds a Point(y, x), so this passes column 0 as
                # latitude and column 1 as longitude.
                km = [
                    distance(lonlat(t[1], t[0]), lonlat(p[1], p[0])).km
                    for t, p in zip(true_pts, pred_pts)
                ]
                batch_means.append(np.mean(km))

        if not batch_means:
            return float("nan")
        return float(np.mean(batch_means))

    def plot_losses(self):
        """Plot training and validation loss curves on the current figure.

        NOTE(review): ``plt`` is not imported in the visible source; this
        presumably relies on ``matplotlib.pyplot as plt`` being imported
        elsewhere in the notebook.
        """
        plt.plot(self.train_losses, label="Training loss")
        plt.plot(self.val_losses, label="Validation loss")
        plt.legend()
        plt.title("Losses for Training and Validation")

    def save_model(self, filename):
        """Save the wrapped model's ``state_dict`` to ``filename``."""
        torch.save(self.model.state_dict(), filename)