In [1]:
import os
import sys
sys.path.append("/Users/shashanks./Downloads/Installations/ddn/")
import warnings
warnings.filterwarnings('ignore')

import torch
import numpy as np
import scipy.special
import torch.nn as nn
import matplotlib.pyplot as plt

from ddn.pytorch.node import *
from scipy.linalg import block_diag
from torch.utils.data import Dataset, DataLoader
from bernstein import bernstein_coeff_order10_new

#### CUDA Initializations

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cpu device


#### Initializations

In [3]:
t_fin = 8.0
num = 20

tot_time = np.linspace(0.0, t_fin, num)
tot_time_copy = tot_time.reshape(num, 1)
P, Pdot, Pddot = bernstein_coeff_order10_new(10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy)
nvar = np.shape(P)[1]

In [28]:
P.shape

(20, 11)

In [4]:
A_eq_mat = np.vstack((P[0], Pdot[0], Pddot[0], P[-1], Pdot[-1], Pddot[-1]))
A_eq_np = block_diag(A_eq_mat, A_eq_mat)
Q_np = 10 * block_diag(np.dot(Pddot.T, Pddot), np.dot(Pddot.T, Pddot))
q_np = np.zeros(2 * nvar)

#### QPNode

In [5]:
class QPNode(AbstractDeclarativeNode):
    def __init__(self, Q_np, q_np, A_eq_np, rho=1.0, nvar=22, maxiter=1000):
        super().__init__()
        self.rho = rho
        self.nvar = nvar
        self.maxiter = maxiter
        self.Q = torch.tensor(Q_np, dtype=torch.double).to(device)
        self.q = torch.tensor(q_np, dtype=torch.double).to(device)
        self.A = torch.tensor(A_eq_np, dtype=torch.double).to(device)
    
    def objective(self, b, lamda, y):
        """
        b: (B x 12)
        lamda: (B x 22)
        y: (B x 22)
        """
        lamda = lamda.transpose(0, 1)
        y = y.transpose(0, 1)
        cost_mat = self.rho * torch.matmul(self.A.T, self.A) + self.Q
        lincost_mat = -self.rho * torch.matmul(b, self.A).T + self.q.view(-1, 1) - lamda
        f = 0.5 * torch.diag(torch.matmul(y.T, torch.matmul(cost_mat, y))) + torch.diag(torch.matmul(lincost_mat.T, y))
        return f
    
    def compute_augmented_lagrangian(self, b, lamda):
        """
        b: (12,)
        lamda: (22,)
        """
        cost_mat = self.rho * torch.matmul(self.A.T, self.A) + self.Q
        lincost_mat = -self.rho * torch.matmul(b, self.A).T + self.q - lamda
        lincost_mat = lincost_mat.view(-1, 1)
        sol, _ = torch.solve(lincost_mat, -cost_mat)
        sol = sol.view(-1)
        res = torch.matmul(self.A, sol) - b
        return sol, res
    
    def optimize(self, b, lamda):
        sol, res = self.compute_augmented_lagrangian(b, lamda)
        for i in range(0, self.maxiter):
            sol, res = self.compute_augmented_lagrangian(b, lamda)
            lamda -= self.rho * torch.matmul(self.A.T, res)
        return sol
    
    def solve(self, b, lamda):
        batch_size, _ = b.size()
        y = torch.zeros(batch_size, 22, dtype=torch.double).to(device)
        for i in range(batch_size):
            b_cur = b[i]
            lamda_cur = lamda[i]
            sol = self.optimize(b_cur, lamda_cur)
            y[i, :] = sol
        return y, None

#### PyTorch Declarative Function

In [6]:
class QPFunction(torch.autograd.Function):
    """Generic declarative autograd function.
    Defines the forward and backward functions. Saves all inputs and outputs,
    which may be memory-inefficient for the specific problem.
    
    Assumptions:
    * All inputs are PyTorch tensors
    * All inputs have a single batch dimension (b, ...)
    """
    @staticmethod
    def forward(ctx, problem, *inputs):
        output, solve_ctx = torch.no_grad()(problem.solve)(*inputs)
        ctx.save_for_backward(output, *inputs)
        ctx.problem = problem
        ctx.solve_ctx = solve_ctx
        return output.clone()

    @staticmethod
    def backward(ctx, grad_output):
        output, *inputs = ctx.saved_tensors
        problem = ctx.problem
        solve_ctx = ctx.solve_ctx
        output.requires_grad = True
        inputs = tuple(inputs)
        grad_inputs = problem.gradient(*inputs, y=output, v=grad_output,
            ctx=solve_ctx)
        return (None, *grad_inputs)

#### PyTorch Declarative Layer

In [7]:
class DeclarativeLayer(torch.nn.Module):
    """Generic declarative layer.
    
    Assumptions:
    * All inputs are PyTorch tensors
    * All inputs have a single batch dimension (b, ...)
    Usage:
        problem = <derived class of *DeclarativeNode>
        declarative_layer = DeclarativeLayer(problem)
        y = declarative_layer(x1, x2, ...)
    """
    def __init__(self, problem):
        super(DeclarativeLayer, self).__init__()
        self.problem = problem
        
    def forward(self, *inputs):
        return QPFunction.apply(self.problem, *inputs)

#### TrajNet

In [19]:
class TrajNet(nn.Module):
    def __init__(self, opt_layer, P, input_size=16, hidden_size=64, output_size=12, nvar=11, t_obs=8):
        super(TrajNet, self).__init__()
        self.nvar = nvar
        self.t_obs = t_obs
        self.P = torch.tensor(P, dtype=torch.double).to(device)
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)
        self.opt_layer = opt_layer
        self.activation = nn.ReLU()
        self.mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0]], dtype=torch.double).to(device)
    
    def forward(self, x, b):
        batch_size, _ = x.size()
        out = self.activation(self.linear1(x))
        b_pred = self.linear2(out)
        b_gen = self.mask * b + (1 - self.mask) * b_pred
        
        # Run optimization
        lamda = torch.zeros(batch_size, 2 * self.nvar, dtype=torch.double).to(device)
        sol = self.opt_layer(b_gen, lamda)
        print(sol.size())
        print(torch.matmul(self.P, sol[:, :self.nvar].transpose(0, 1)).size())
        
        # Compute final trajectory
        x_pred = torch.matmul(self.P, sol[:, :self.nvar].transpose(0, 1))[self.t_obs:]
        y_pred = torch.matmul(self.P, sol[:, self.nvar:].transpose(0, 1))[self.t_obs:]
        print(x_pred.size(), y_pred.size())
        
        x_pred = x_pred.transpose(0, 1)
        y_pred = y_pred.transpose(0, 1)
        out = torch.cat([x_pred, y_pred], dim=1)
        return out

#### Trajectory Data loader

In [20]:
class TrajectoryDataset(Dataset):
    def __init__(self, root_dir, t_obs=8):
        self.root_dir = root_dir
        self.t_obs = t_obs
    
    def __len__(self):
        return len(os.listdir(self.root_dir))
    
    def __getitem__(self, idx):
        file_name = "{}.npy".format(idx)
        file_path = os.path.join(self.root_dir, file_name)
        
        data = np.load(file_path, allow_pickle=True).item()
        x_traj = data['x_traj']
        y_traj = data['y_traj']
        
        x_inp = x_traj[:self.t_obs]
        y_inp = y_traj[:self.t_obs]
        x_fut = x_traj[self.t_obs:]
        y_fut = y_traj[self.t_obs:]

        traj_inp = np.dstack((x_inp, y_inp)).flatten()
        traj_out = np.hstack((x_fut, y_fut)).flatten()
        b_inp = np.array([data['x_init'], data['vx_init'], data['ax_init'], 0, 0, 0, data['y_init'], data['vy_init'], data['ay_init'], 0, 0, 0])
        
        return torch.tensor(traj_inp), torch.tensor(traj_out), torch.tensor(b_inp)

In [21]:
train_dataset = TrajectoryDataset("../datasets/data/", 8)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True, num_workers=0)

In [22]:
test_dataset = TrajectoryDataset("../datasets1/data/", 8)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=True, num_workers=0)

In [23]:
for batch_num, data in enumerate(train_loader):
    traj_inp, traj_out, b_inp = data
    print(traj_inp.size())
    break

torch.Size([20, 16])


In [24]:
traj_inp[:, -1]

tensor([-0.9509,  6.0606, -1.4452,  3.6275, -3.2600, 11.8367,  5.2798, 14.3161,
        -4.7147,  5.9726,  1.6495, 15.0284, -3.4917, -4.2722,  0.8660, -0.8652,
        -6.8108,  1.2075, -0.9170,  3.1887], dtype=torch.float64)

In [25]:
b_inp

tensor([[-1.0155e+01,  1.6676e+00,  1.2164e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.3217e+00,  5.4098e-01,  7.9870e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-6.1901e+00,  3.9988e-01,  2.4077e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.1493e+00,  1.8195e+00,  6.6076e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-7.5099e+00,  2.4063e+00,  2.6080e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -5.6933e+00,  1.5522e+00,  7.4712e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.3217e+01,  2.3134e+00,  5.6554e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -3.8643e+00,  2.4177e+00,  9.4448e-03,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 5.2593e+00,  1.1512e+00,  1.9182e-02,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -1.2664e+01,  2.2688e+00,  3.4699e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 4.0670e+00,  1.2292e+00,  5.6723e-01,  0.0000e+00,  0.0000e+00,
      

#### Model Definition

In [26]:
problem = QPNode(Q_np, q_np, A_eq_np)
qp_layer = DeclarativeLayer(problem)

model = TrajNet(qp_layer, P)
model = model.double()
model = model.to(device)

In [27]:
model(traj_inp, b_inp)

torch.Size([20, 22])
torch.Size([20, 20])
torch.Size([12, 20]) torch.Size([12, 20])


tensor([[-4.3574e+00, -3.6912e+00, -3.0554e+00, -2.4600e+00, -1.9157e+00,
         -1.4313e+00, -1.0101e+00, -6.5229e-01, -3.5986e-01, -1.4431e-01,
         -2.5318e-02, -3.3423e-04,  1.8945e+00,  1.7886e+00,  1.6641e+00,
          1.5329e+00,  1.4023e+00,  1.2749e+00,  1.1506e+00,  1.0320e+00,
          9.2797e-01,  8.5193e-01,  8.1397e-01,  8.0722e-01],
        [-3.4292e+00, -2.9873e+00, -2.5483e+00, -2.1241e+00, -1.7260e+00,
         -1.3625e+00, -1.0377e+00, -7.5411e-01, -5.1721e-01, -3.4088e-01,
         -2.4385e-01, -2.2371e-01,  2.4916e+00,  2.1214e+00,  1.6959e+00,
          1.2469e+00,  7.9796e-01,  3.6164e-01, -5.6453e-02, -4.4723e-01,
         -7.8825e-01, -1.0427e+00, -1.1771e+00, -1.2031e+00],
        [-2.2852e+00, -2.0018e+00, -1.7752e+00, -1.5942e+00, -1.4528e+00,
         -1.3495e+00, -1.2828e+00, -1.2471e+00, -1.2310e+00, -1.2216e+00,
         -1.2136e+00, -1.2108e+00, -1.1027e+00, -6.9019e-01, -3.1212e-01,
          3.2602e-02,  3.4101e-01,  6.0819e-01,  8.3121e-01,  

#### Training

In [16]:
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [17]:
epoch_train_loss = []
num_epochs = 50

for epoch in range(num_epochs):
    train_loss = []
    for batch_num, data in enumerate(train_loader):
        traj_inp, traj_out, b_inp = data
        traj_inp = traj_inp.to(device)
        traj_out = traj_out.to(device)
        b_inp = b_inp.to(device)

        out = model(traj_inp, b_inp)
        loss = criterion(out, traj_out)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.append(loss.item())
        if batch_num % 10 == 0:
            print("Epoch: {}, Batch: {}, Loss: {}".format(epoch, batch_num, loss.item()))
    
    mean_loss = np.mean(train_loss)
    epoch_train_loss.append(mean_loss)
    print("Epoch: {}, Mean Loss: {}".format(epoch, mean_loss))
    print("-"*100)

torch.Size([12, 20]) torch.Size([12, 20])
Epoch: 0, Batch: 0, Loss: 45.91413982231994
torch.Size([12, 20]) torch.Size([12, 20])
torch.Size([12, 20]) torch.Size([12, 20])


RuntimeError: 

#### Testing Code

In [None]:
model.eval()

In [None]:
traj_inp[0].size()

In [None]:
traj_out[0].size()

In [None]:
out[0].size()

In [None]:
traj_inp[0]

In [None]:
traj_inp[0][::2]

In [None]:
traj_inp[0][1::2]

In [None]:
def plot_traj(i, traj_inp, traj_out, traj_pred):
    traj_inp = traj_inp.numpy()
    traj_out = traj_out.numpy()
    traj_pred = traj_pred.numpy()
    
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(traj_inp[::2], traj_inp[1::2], label='Inp traj')
    ax.scatter(traj_out[:12], traj_out[12:], label='GT')
    ax.scatter(traj_pred[:12], traj_pred[12:], label='Pred')
    ax.legend()
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])
    plt.savefig('./results/{}.png'.format(i))
    plt.close()

In [None]:
with torch.no_grad():
    cnt = 0
    test_loss = []
    for batch_num, data in enumerate(test_loader):
        traj_inp, traj_out, b_inp = data
        traj_inp = traj_inp.to(device)
        traj_out = traj_out.to(device)
        b_inp = b_inp.to(device)

        out = model(traj_inp, b_inp)
        loss = criterion(out, traj_out)
        
        test_loss.append(loss.item())
        print("Batch: {}, Loss: {}".format(batch_num, loss.item()))
        
        for i in range(traj_inp.size()[0]):
            plot_traj(cnt, traj_inp[i], traj_out[i], out[i])
            cnt += 1

mean_loss = np.mean(test_loss)
print("Epoch Mean Test Loss: {}".format(mean_loss))

In [None]:
test_out = model(traj_inp, b_inp)

In [None]:
test_out.size(), traj_out.size()

In [None]:
test_out[14]

In [None]:
traj_out[14]

In [None]:
torch.mean((test_out[14] - traj_out[14]) ** 2)

In [None]:
criterion(test_out, traj_out)