<a href="https://colab.research.google.com/github/verityw/manipulation-final-project/blob/main/PoseInterpreterNetwork.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision.models
import collections
import math
import torch.nn.functional as F


from google.colab import drive
drive.mount('/content/gdrive')

SAMPLES = -1
BATCH_SIZE = 10
EPOCHS = 5
IMG_SHAPE = (480, 640, 4)
POSE_SHAPE = (6,)
VAL_PROPORTION = .1
device = "cuda"

Mounted at /content/gdrive


In [1]:
#cd to directory containing this notebook

In [None]:
DATA_DIRECTORY = os.path.join(os.getcwd(), "data")
X_DIR, Y_DIR = os.path.join(DATA_DIRECTORY, "x"), os.path.join(DATA_DIRECTORY, "y")
XVAL_DIR, YVAL_DIR = os.path.join(DATA_DIRECTORY, "xval"), os.path.join(DATA_DIRECTORY, "yval")
CHECKPOINTS = os.path.join(os.getcwd(), "checkpoints")

In [None]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, x_path, y_path, num_items):
        self.x_path = x_path
        self.y_path = y_path
        self.indices = np.arange(num_items, dtype=int)
    
    def __getitem__(self, index):
        ind = self.indices[index]
        img = np.load(os.path.join(self.x_path, str(index) + "img.npy"))
        img = torch.tensor(np.transpose(img, (2, 0, 1))) / 255.
        pose = torch.tensor(np.load(os.path.join(os.path.join(self.y_path, str(index) + "pose.npy"))))
        pos, rot = pose[:3], pose[3:]
        rot = rot * np.sign(rot[0]) # Make sure qw > 0
        return (img, [pos, rot]) # Return 3 x H x W image, (x, y, z), and (qw, qx, qy, qz)

    def __len__(self):
        return len(self.indices)

In [None]:
def conv(in_channels, out_channels, kernel_size):
    padding = (kernel_size-1) // 2
    assert 2*padding == kernel_size-1, "parameters incorrect. kernel={}, padding={}".format(kernel_size, padding)
    return nn.Sequential(
          nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=padding,bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
        )

def fc(in_channels, out_channels):
    return nn.Sequential(
        nn.Linear(in_channels, out_channels),
    )

class normalize(nn.Module):
    def __init__(self):
        super(normalize, self).__init__()
    def forward(self, x):
        norms = x.norm(dim=1, keepdim=True)
        x = x.div(norms)
        return x


class PoseEstimatorNet(nn.Module):
    def __init__(self):
        super(PoseEstimatorNet, self).__init__()
        self.resnet = torchvision.models.resnet18()
        self.initial = conv(4, 3, 1) # in_channels = 4 for RGBAlpha
        self.fc1 = fc(1000, 256)
        self.fc2_position = fc(256, 3)
        self.fc2_rotation = fc(256, 4)
        self.normalize = normalize()
        self.relu = torch.nn.ReLU() 
    def forward(self, x):
        # Convolve to be acceptable size for resnet
        x = self.initial(x)
        # ResNet18
        x = self.resnet(x)
        # Multi-layer perceptron
        x = self.fc1(x)
        x = self.relu(x)
        x_pos = self.fc2_position(x)
        x_rot = self.fc2_rotation(x)
        x_rot = self.normalize(x_rot)

        return x_pos, x_rot

In [None]:
# Losses

# L1 loss
class L1Loss(nn.Module):
    def __init__(self, alpha = 1):
        super(L1Loss, self).__init__()
        self.alpha = alpha

    def forward(self, y_pred, y_target):
        return torch.mean(torch.abs(y_pred[0] - y_target[0])) + self.alpha * torch.mean(torch.abs(y_pred[1] - y_target[1])) 

# PoseCNN loss
class PoseCNNLoss(nn.Module):
    def __init__(self, alpha = 1):
        super(PoseCNNLoss, self).__init__()
        self.alpha = alpha

    def forward(self, y_pred, y_target):
        position_loss = torch.abs(y_target[0] - y_pred[0]).mean()
        orientation_loss = (1 - (y_pred[1] * y_target[1]).sum(dim=1).pow(2)).mean()
        return position_loss + self.alpha * orientation_loss + torch.maximum(torch.tensor([0], device=device), -y_pred[0]).mean()

In [None]:
# Train the network
def iterate(mode, loader, model, optimizer, criterion, epoch):
    print("Starting epoch: " + str(epoch))
    if mode == "train":
        model.train()
    else:
        model.eval()
    running_loss = 0
    for i, [img, label] in enumerate(loader):
        img = img.to(device)
        pos, rot = label[0], label[1]
        pos, rot = pos.to(device), rot.to(device)

        optimizer.zero_grad()
        
        # Forward pass
        pos_pred, rot_pred = model(img)
        
        # Backpropagation
        loss = criterion([pos_pred, rot_pred], [pos, rot])
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss
        if i % 100 == 0:
            print("Batch: " + str(i+1))
            print("Batch Loss: " + str(loss))
            print("Average Loss:" + str(running_loss / (i + 1)))
    print("Finished epoch.")
    return running_loss / len(loader)

torch.autograd.set_detect_anomaly(True)
print("Initializing model")
model = PoseEstimatorNet()
model.to(device)
print("Initializing optimizer")
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Initializing loss criterion")
criterion = PoseCNNLoss(alpha=.1)
print("Initializing dataloaders")
train_loader = torch.utils.data.DataLoader(Dataset(X_DIR, Y_DIR, 10000), batch_size = 10)
#val_loader = torch.utils.data.DataLoader(Dataset(XVAL_DIR, YVAL_DIR, 2000), batch_size = 10)

for epoch in range(EPOCHS):
    loss = iterate("train", train_loader, model, optimizer, criterion, epoch)
    #val_loss = iterate("val", train_loader, model, optimizer, criterion, epoch)
    torch.save(model.state_dict(), os.path.join(CHECKPOINTS, str(loss) + ""))

In [None]:
Model = PoseEstimatorNet()
Model.load_state_dict(torch.load(os.path.join(CHECKPOINTS, "tensor(0.0926, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)"), map_location=torch.device('cpu')))

<All keys matched successfully>

In [None]:
val_loader = torch.utils.data.DataLoader(Dataset(XVAL_DIR, YVAL_DIR, 2000), batch_size = 1)

In [None]:
img, label = next(iter(val_loader))

In [None]:
pos, rot = Model(img)

In [None]:
print(pos)
print(rot)

tensor([[0.0377, 0.0474, 0.1031]], grad_fn=<AddmmBackward>)
tensor([[-0.6211, -0.6342, -0.4091, -0.2114]], grad_fn=<DivBackward0>)


In [None]:
print(label)

[tensor([[ 0.0486, -0.0219,  0.0468]], dtype=torch.float64), tensor([[ 0.6148, -0.0229, -0.7881, -0.0180]], dtype=torch.float64)]
