In [1]:
import os
import sys
sys.path.append("/Users/shashanks./Downloads/Installations/ddn/")
import warnings
warnings.filterwarnings('ignore')

import torch
import numpy as np
import scipy.special
import torch.nn as nn
import matplotlib.pyplot as plt

from scipy.linalg import block_diag
from torch.utils.data import Dataset, DataLoader
from bernstein import bernstein_coeff_order10_new
from ddn.pytorch.node import AbstractDeclarativeNode

#### CUDA Initializations

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
print('Using {} device'.format(device))

Using cpu device


#### Time Grid and Bernstein Basis Initialization

In [3]:
# Time grid on which the Bernstein basis from bernstein_coeff_order10_new (called
# with order 10) and its first/second derivatives are sampled.
# NOTE(review): linspace(0, t_fin, num) places samples t_fin / (num - 1) ≈ 0.436 s apart,
# starting at t = 0, while TrajectoryDataset uses dt = 0.4 and (assuming stored samples
# are dt apart) its first future sample lies one dt after the last observation —
# confirm the two time grids are meant to differ.
t_fin = 4.8
num = 12
tot_time = np.linspace(0.0, t_fin, num)
tot_time_copy = tot_time.reshape(-1, 1)   # column vector, shape (num, 1)
P, Pdot, Pddot = bernstein_coeff_order10_new(10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy)
nvar = np.shape(P)[1]                     # columns of P = coefficients per axis

In [4]:
# Equality constraints for one axis: position, velocity and acceleration at the first
# sample, then the same three at the last sample. This row order matches the per-axis
# layout of b_inp built in TrajectoryDataset.
A_eq_mat = np.vstack([basis[row] for row in (0, -1) for basis in (P, Pdot, Pddot)])
# The x and y coefficient blocks are decoupled, so constraints and cost are block-diagonal.
A_eq_np = block_diag(A_eq_mat, A_eq_mat)
# Quadratic cost: 10 * sum of squared second-derivative samples (Pddot @ coeffs) per axis.
pddot_gram = Pddot.T @ Pddot
Q_np = 10 * block_diag(pddot_gram, pddot_gram)
q_np = np.zeros(2 * nvar)

#### QPNode (Augmented-Lagrangian QP Solver)

In [5]:
class QPNode(AbstractDeclarativeNode):
    """Declarative node mapping boundary values ``b`` to QP coefficients ``y``.

    Each row of ``b`` is handled by an augmented-Lagrangian (method of
    multipliers) iteration for

        min_y 0.5 * y^T Q y + q^T y   subject to   A y = b.

    One iteration minimises, in closed form,

        0.5 * y^T (Q + rho A^T A) y + (q - rho A^T b - lamda)^T y

    and then moves the multiplier estimate ``lamda`` by ``-rho A^T (A y - b)``.

    Args:
        Q_np: (n x n) quadratic cost matrix.
        q_np: (n,) linear cost vector.
        A_eq_np: (m x n) equality-constraint matrix.
        rho: penalty weight of the augmented Lagrangian.
        nvar: stored as ``self.nvar``; the solver takes all sizes from the matrices.
        maxiter: number of closed-form solves per ``optimize`` call (at least one
            solve is always performed).
    """

    def __init__(self, Q_np, q_np, A_eq_np, rho=1.0, nvar=22, maxiter=1000):
        super().__init__()
        self.rho = rho
        self.nvar = nvar
        self.maxiter = maxiter
        self.Q = torch.tensor(Q_np, dtype=torch.double).to(device)
        self.q = torch.tensor(q_np, dtype=torch.double).to(device)
        self.A = torch.tensor(A_eq_np, dtype=torch.double).to(device)

    def _cost_mat(self):
        """Subproblem Hessian Q + rho A^T A (built on demand so a changed rho is honoured)."""
        return self.rho * torch.matmul(self.A.T, self.A) + self.Q

    def _lincost(self, b, lamda):
        """Subproblem linear term q - rho A^T b - lamda, for a single row or a batch."""
        return -self.rho * torch.matmul(b, self.A) + self.q - lamda

    def objective(self, b, lamda, y):
        """Per-sample subproblem cost used for implicit differentiation.

        f(b, lamda, y) = 0.5 y^T (rho A^T A + Q) y + (q - rho A^T b - lamda)^T y

        b: (B x 12)
        lamda: (B x 22)
        y: (B x 22)
        Returns a (B,) tensor, one cost per row.
        """
        cost_mat = self._cost_mat()
        lincost = self._lincost(b, lamda)                          # (B, n)
        # Row-wise y_i^T C y_i and lincost_i^T y_i without forming (B x B) products.
        quad = torch.sum(torch.matmul(y, cost_mat) * y, dim=1)
        lin = torch.sum(lincost * y, dim=1)
        return 0.5 * quad + lin

    def compute_augmented_lagrangian(self, b, lamda, cost_mat=None):
        """Minimise the subproblem for a fixed multiplier estimate.

        b: (12,) or (B x 12)
        lamda: (22,) or (B x 22)
        cost_mat: optional precomputed ``Q + rho A^T A``.
        Returns (sol, res): the minimiser, (n,) or (B x n), and the constraint
        residual ``A sol - b``, shaped like ``b``.
        """
        if cost_mat is None:
            cost_mat = self._cost_mat()
        rhs = -self._lincost(b, lamda)
        # torch.solve is deprecated (removed in PyTorch 2.0); linalg.solve broadcasts the
        # single (n x n) matrix over one or many (n x 1) right-hand sides.
        sol = torch.linalg.solve(cost_mat, rhs.unsqueeze(-1)).squeeze(-1)
        res = torch.matmul(sol, self.A.T) - b
        return sol, res

    def optimize(self, b, lamda):
        """Run the multiplier iteration and return the last minimiser.

        ``b`` / ``lamda`` may be single rows or batches (rows are independent).
        ``lamda`` is the initial multiplier estimate and is not modified in place,
        so the caller's tensor is left untouched.
        """
        cost_mat = self._cost_mat()  # constant across iterations
        sol, res = self.compute_augmented_lagrangian(b, lamda, cost_mat)
        for _ in range(1, self.maxiter):
            lamda = lamda - self.rho * torch.matmul(res, self.A)
            sol, res = self.compute_augmented_lagrangian(b, lamda, cost_mat)
        return sol

    def solve(self, b, lamda):
        """Solve the QP for every row of the batch at once.

        b: (B x 12) boundary values; lamda: (B x 22) initial multipliers (not modified).
        Returns (y, None): y is (B x 22); None is the (unused) solve context.
        """
        y = self.optimize(b, lamda)
        return y, None

#### PyTorch Declarative Function

In [6]:
class QPFunction(torch.autograd.Function):
    """Autograd bridge for a declarative node.

    ``forward`` runs ``problem.solve`` without recording a graph and stores the
    solution together with every input; ``backward`` hands the upstream gradient
    to ``problem.gradient``. Saving all inputs may use more memory than a
    problem-specific implementation would need.

    Assumptions:
    * All inputs are PyTorch tensors
    * All inputs have a single batch dimension (b, ...)
    """
    @staticmethod
    def forward(ctx, problem, *inputs):
        with torch.no_grad():
            solution, solver_state = problem.solve(*inputs)
        ctx.save_for_backward(solution, *inputs)
        ctx.problem = problem
        ctx.solve_ctx = solver_state
        # Return a copy so the saved solution is not the tensor autograd marks as output.
        return solution.clone()

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        solution = saved[0]
        saved_inputs = tuple(saved[1:])
        # The node's gradient routine differentiates its objective w.r.t. y.
        solution.requires_grad = True
        grads = ctx.problem.gradient(*saved_inputs, y=solution, v=grad_output,
                                     ctx=ctx.solve_ctx)
        # No gradient for the non-tensor `problem` argument.
        return (None,) + tuple(grads)

#### PyTorch Declarative Layer

In [7]:
class DeclarativeLayer(torch.nn.Module):
    """nn.Module wrapper that routes its inputs through ``QPFunction``.

    Assumptions:
    * All inputs are PyTorch tensors
    * All inputs have a single batch dimension (b, ...)
    Usage:
        problem = <derived class of *DeclarativeNode>
        declarative_layer = DeclarativeLayer(problem)
        y = declarative_layer(x1, x2, ...)
    """
    def __init__(self, problem):
        super().__init__()
        # The node that owns solve()/gradient(); kept as a plain attribute.
        self.problem = problem

    def forward(self, *args):
        return QPFunction.apply(self.problem, *args)

#### TrajNet

In [8]:
class TrajNet(nn.Module):
    """MLP that predicts free boundary values and refines them through a QP layer.

    The MLP maps the flattened observed trajectory ``x`` to an ``output_size``
    vector. Entries where ``mask`` is 1 are taken from the known boundary values
    ``b``; only the masked-out entries (indices 3 and 9) come from the network.
    ``opt_layer`` turns the boundary vector into ``2 * nvar`` coefficients, which
    are evaluated on the basis matrix ``P`` for the x and y halves.

    Args:
        opt_layer: callable ``(b_gen, lamda) -> (B, 2 * nvar)`` coefficient tensor.
        P: basis matrix, (num_samples, nvar).
        input_size, hidden_size, output_size: MLP layer sizes (the mask assumes 12 outputs).
        nvar: coefficients per axis.
        t_obs: stored as ``self.t_obs``; not used by ``forward``.

    ``forward(x, b)`` returns (B, 2 * num_samples): all x samples followed by all y samples.
    """
    def __init__(self, opt_layer, P, input_size=16, hidden_size=64, output_size=12, nvar=11, t_obs=8):
        super(TrajNet, self).__init__()
        self.nvar = nvar
        self.t_obs = t_obs
        # Buffers follow model.to()/.double(); persistent=False keeps state_dict keys unchanged.
        self.register_buffer('P', torch.tensor(P, dtype=torch.double), persistent=False)
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)
        self.opt_layer = opt_layer
        self.activation = nn.ReLU()
        self.register_buffer(
            'mask',
            torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0]], dtype=torch.double),
            persistent=False,
        )

    def forward(self, x, b):
        batch_size = x.size(0)
        hidden = self.activation(self.linear1(x))
        b_pred = self.linear2(hidden)
        # Keep the known boundary values from b; use the network only where mask == 0.
        b_gen = self.mask * b + (1 - self.mask) * b_pred

        # Initial multipliers are zero; allocate them on the same device as the inputs.
        lamda = torch.zeros(batch_size, 2 * self.nvar, dtype=torch.double, device=b_gen.device)
        sol = self.opt_layer(b_gen, lamda)

        # Evaluate each axis: (B, nvar) @ (nvar, num_samples) -> (B, num_samples).
        x_pred = torch.matmul(sol[:, :self.nvar], self.P.T)
        y_pred = torch.matmul(sol[:, self.nvar:], self.P.T)
        return torch.cat([x_pred, y_pred], dim=1)

#### Trajectory Data Loader

In [9]:
class TrajectoryDataset(Dataset):
    """Trajectories stored as ``<idx>.npy`` files in ``root_dir``.

    Each file holds a pickled dict with arrays ``'x_traj'`` and ``'y_traj'``.
    The first ``t_obs`` samples form the observed input; the rest are the target.

    Args:
        root_dir: directory containing ``0.npy``, ``1.npy``, ...
        t_obs: number of observed samples (at least 3 for the acceleration estimate).
        dt: sampling interval used in the finite-difference velocity/acceleration.
    """
    def __init__(self, root_dir, t_obs=8, dt=0.4):
        self.root_dir = root_dir
        self.t_obs = t_obs
        self.dt = dt

    def __len__(self):
        # Count only .npy files: stray entries (e.g. .DS_Store) would otherwise
        # inflate the length and make __getitem__ ask for a non-existent file.
        return sum(1 for name in os.listdir(self.root_dir) if name.endswith('.npy'))

    def get_vel(self, pos):
        """Backward-difference velocity from the last two entries of ``pos``."""
        return (pos[-1] - pos[-2]) / self.dt

    def get_acc(self, vel):
        """Backward-difference acceleration from the last two entries of ``vel``."""
        return (vel[-1] - vel[-2]) / self.dt

    def __getitem__(self, idx):
        file_path = os.path.join(self.root_dir, "{}.npy".format(idx))
        # NOTE: allow_pickle=True unpickles the file; only load trusted data.
        data = np.load(file_path, allow_pickle=True).item()
        x_traj = data['x_traj']
        y_traj = data['y_traj']

        x_inp, x_fut = x_traj[:self.t_obs], x_traj[self.t_obs:]
        y_inp, y_fut = y_traj[:self.t_obs], y_traj[self.t_obs:]

        # Velocity at the last observed sample and the one before it, then acceleration.
        vx_beg, vy_beg = self.get_vel(x_inp), self.get_vel(y_inp)
        vx_prev, vy_prev = self.get_vel(x_inp[:-1]), self.get_vel(y_inp[:-1])
        ax_beg = self.get_acc((vx_prev, vx_beg))
        ay_beg = self.get_acc((vy_prev, vy_beg))

        traj_inp = np.dstack((x_inp, y_inp)).flatten()   # x0, y0, x1, y1, ...
        traj_out = np.hstack((x_fut, y_fut)).flatten()   # all x, then all y
        # Per axis: [pos, vel, acc] at the start, then [pos, vel, acc] at the end.
        # End velocity/acceleration are fixed to 0; the end position 0 is a
        # placeholder that TrajNet's mask replaces with the network prediction.
        b_inp = np.array([x_inp[-1], vx_beg, ax_beg, 0, 0, 0, y_inp[-1], vy_beg, ay_beg, 0, 0, 0])
        return torch.tensor(traj_inp), torch.tensor(traj_out), torch.tensor(b_inp)

In [10]:
# Training data: first 8 samples of each trajectory are observed; batches are reshuffled each epoch.
train_dataset = TrajectoryDataset(root_dir="../datasets/data/", t_obs=8)
train_loader = DataLoader(dataset=train_dataset, batch_size=20, shuffle=True, num_workers=0)

In [11]:
# Evaluation data: no shuffling, so per-batch losses and the saved plot files
# (<cnt>.png) correspond to the same samples on every run.
test_dataset = TrajectoryDataset("../datasets1/data/", 8)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=False, num_workers=0)

In [12]:
# Peek at the first training batch to confirm the input shape.
for traj_inp, traj_out, b_inp in train_loader:
    print(traj_inp.size())
    break

torch.Size([20, 16])


#### Model Definition

In [13]:
# Assemble: QP node -> autograd-aware layer -> trajectory network, cast to float64
# (the QP node stores its matrices as double) and moved to the selected device.
problem = QPNode(Q_np, q_np, A_eq_np)
qp_layer = DeclarativeLayer(problem)
model = TrajNet(qp_layer, P).double().to(device)

#### Training

In [14]:
# Mean-squared error on the stacked (x, y) samples, optimised with plain SGD.
learning_rate = 0.01
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [15]:
epoch_train_loss = []
num_epochs = 50

for epoch in range(num_epochs):
    train_loss = []
    for batch_num, (traj_inp, traj_out, b_inp) in enumerate(train_loader):
        traj_inp, traj_out, b_inp = (t.to(device) for t in (traj_inp, traj_out, b_inp))

        # Clearing gradients before the forward pass is equivalent: forward never touches .grad.
        optimizer.zero_grad()
        preds = model(traj_inp, b_inp)
        loss = criterion(preds, traj_out)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss.append(batch_loss)
        if batch_num % 10 == 0:
            print("Epoch: {}, Batch: {}, Loss: {}".format(epoch, batch_num, batch_loss))

    mean_loss = np.mean(train_loss)
    epoch_train_loss.append(mean_loss)
    print("Epoch: {}, Mean Loss: {}".format(epoch, mean_loss))
    print("-" * 100)

Epoch: 0, Batch: 0, Loss: 33.013079039495516
Epoch: 0, Batch: 10, Loss: 29.159764343020683
Epoch: 0, Batch: 20, Loss: 23.020107455320634
Epoch: 0, Batch: 30, Loss: 15.766522517892419
Epoch: 0, Batch: 40, Loss: 11.178599180999472
Epoch: 0, Mean Loss: 19.5356222082199
----------------------------------------------------------------------------------------------------
Epoch: 1, Batch: 0, Loss: 18.448720579747395
Epoch: 1, Batch: 10, Loss: 14.640683673131361
Epoch: 1, Batch: 20, Loss: 12.927431889030574
Epoch: 1, Batch: 30, Loss: 6.135589737376947
Epoch: 1, Batch: 40, Loss: 5.089817243891978
Epoch: 1, Mean Loss: 8.627628542286864
----------------------------------------------------------------------------------------------------
Epoch: 2, Batch: 0, Loss: 3.894518983964221
Epoch: 2, Batch: 10, Loss: 6.02254978380591
Epoch: 2, Batch: 20, Loss: 6.378049243983864
Epoch: 2, Batch: 30, Loss: 11.613679196017696
Epoch: 2, Batch: 40, Loss: 7.266596082304288
Epoch: 2, Mean Loss: 6.173086442179863
--

KeyboardInterrupt: 

#### Testing Code

In [17]:
def plot_traj(i, traj_inp, traj_out, traj_pred, out_dir='./results_fc_v2'):
    """Scatter-plot one sample's observed, ground-truth and predicted trajectory and save it.

    Args:
        i: index used for the file name ``<out_dir>/<i>.png``.
        traj_inp: observed trajectory tensor, interleaved as x0, y0, x1, y1, ...
        traj_out: ground-truth future tensor, all x values followed by all y values.
        traj_pred: predicted future tensor, same layout as ``traj_out``.
        out_dir: directory for the PNG; created if it does not exist.
    """
    # detach().cpu() so tensors that live on the GPU or carry grad can be converted.
    traj_inp = traj_inp.detach().cpu().numpy()
    traj_out = traj_out.detach().cpu().numpy()
    traj_pred = traj_pred.detach().cpu().numpy()
    # Future arrays store all x values, then all y values.
    n_out = traj_out.shape[0] // 2
    n_pred = traj_pred.shape[0] // 2

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(traj_inp[::2], traj_inp[1::2], label='Inp traj')
    ax.scatter(traj_out[:n_out], traj_out[n_out:], label='GT')
    ax.scatter(traj_pred[:n_pred], traj_pred[n_pred:], label='Pred')
    ax.legend()
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Sample {}'.format(i))
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, '{}.png'.format(i)))
    # Close this specific figure so looping over many samples does not leak memory.
    plt.close(fig)

In [18]:
# Evaluate on the held-out set and save one plot per sample.
model.eval()
with torch.no_grad():
    cnt = 0
    test_loss = []
    total_loss = 0.0
    total_samples = 0
    for batch_num, data in enumerate(test_loader):
        traj_inp, traj_out, b_inp = data
        traj_inp = traj_inp.to(device)
        traj_out = traj_out.to(device)
        b_inp = b_inp.to(device)

        out = model(traj_inp, b_inp)
        loss = criterion(out, traj_out)

        test_loss.append(loss.item())
        # MSELoss averages within a batch; weight by batch size so a smaller
        # final batch does not skew the overall mean.
        batch_size = traj_inp.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        print("Batch: {}, Loss: {}".format(batch_num, loss.item()))

        # Move to CPU once per batch so plotting (which calls .numpy()) also works on CUDA.
        inp_cpu, out_cpu, pred_cpu = traj_inp.cpu(), traj_out.cpu(), out.cpu()
        for i in range(batch_size):
            plot_traj(cnt, inp_cpu[i], out_cpu[i], pred_cpu[i])
            cnt += 1

mean_loss = total_loss / total_samples if total_samples else float('nan')
print("Epoch Mean Test Loss: {}".format(mean_loss))

Batch: 0, Loss: 0.7403452611831383
Batch: 1, Loss: 0.8878560839394235
Batch: 2, Loss: 0.7715245435679279
Batch: 3, Loss: 0.7239928365878544
Batch: 4, Loss: 0.5795446747985719
Batch: 5, Loss: 0.6835180643143536
Batch: 6, Loss: 0.7633167415103669
Batch: 7, Loss: 0.7227271313081662
Batch: 8, Loss: 0.6911066513488902
Batch: 9, Loss: 0.7167628329901647
Batch: 10, Loss: 0.6758446407676194
Batch: 11, Loss: 0.8188903861222745
Batch: 12, Loss: 0.8621790973896707
Batch: 13, Loss: 0.755437868409265
Batch: 14, Loss: 0.7828438698031521
Batch: 15, Loss: 0.9055853502940451
Batch: 16, Loss: 0.6136821268468322
Batch: 17, Loss: 0.6767238409357347
Batch: 18, Loss: 0.9784208928677302
Batch: 19, Loss: 0.9911032979739381
Batch: 20, Loss: 1.0969428945865822
Batch: 21, Loss: 0.6040111059618366
Batch: 22, Loss: 0.7733043556507085
Batch: 23, Loss: 0.7934813396594625
Batch: 24, Loss: 0.7874994377472642
Batch: 25, Loss: 0.8985541685982532
Batch: 26, Loss: 0.6652912159243821
Batch: 27, Loss: 0.8470976246177196
Bat