In [10]:
from habitat.sims.habitat_simulator.actions import HabitatSimActions

from habitat_dataset import get_dataset, HabitatDataset

from pathlib import Path
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import pandas as pd
import cv2

ACTIONS = {v: k for k, v in HabitatSimActions._known_actions.items()}

transform_ = transforms.ToPILImage()
dataset = HabitatDataset('data/depth/train/0002')

i = 20
rgb, _, _, action, meta = dataset[i]

In [11]:
import torch.nn as nn
from resnet import ResnetBase

def spatial_softmax_base():
    return nn.Sequential(
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(512, 256, 3, 2, 1, 1),
            nn.ReLU(True),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.ReLU(True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.ReLU(True)
    )

# Direct Imitation `[v2.x]`
Should at least work for not crashing/getting stuck... Might need to do DAgger or something to capture non-perfect situations

In [12]:
class DirectImitation(ResnetBase): # v2.x
    def __init__(self, resnet_model='resnet34', **resnet_kwargs):
        resnet_kwargs['input_channel'] = resnet_kwargs.get('input_channel', 3)

        super().__init__(resnet_model, **resnet_kwargs)

        self.normalize = nn.BatchNorm2d(resnet_kwargs['input_channel'])
        self.deconv = spatial_softmax_base()
        self.extract = nn.Sequential(
                nn.BatchNorm2d(64),
                nn.Conv2d(64, 5, 1, 1, 0))#,
                #common.SpatialSoftmax(temperature))
      
        self.fc1 = nn.Linear(64, 1)
        self.fc2 = nn.Linear(64, 1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        rgb = x[0]

        rgb = self.normalize(rgb)
        rgb = self.conv(rgb)
        rgb = self.deconv(rgb)

        return self.softmax(self.fc2(self.fc1(self.extract(rgb)).squeeze()).squeeze())

In [18]:
net = DirectImitation()

net((rgb.unsqueeze(dim=0),)).detach()

tensor([0.1701, 0.1881, 0.2680, 0.2013, 0.1725])

# Conditional Imitation `[v3.x]`

In [None]:
class ConditionalImitation(ResnetBase): # 3.x
    def __init__(self, resnet_model='resnet34', **resnet_kwargs):
        resnet_kwargs['input_channel'] = resnet_kwargs.get('input_channel', 3)

        super().__init__(resnet_model, **resnet_kwargs)

        self.normalize = nn.BatchNorm2d(resnet_kwargs['input_channel'])
        self.deconv = spatial_softmax_base()
        self.extract = nn.Sequential(
                nn.BatchNorm2d(64),
                nn.Conv2d(64, 5, 1, 1, 0))

        self.fc1 = nn.Linear(64, 1)
        self.fc2 = nn.Linear(64, 2)
        self.fc3 = nn.Linear(10, 5)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        rgb, meta = x

        rgb = self.normalize(rgb)
        rgb = self.conv(rgb)
        rgb = self.deconv(rgb)
        rgb = self.extract(rgb)
        rgb = self.fc1(rgb).squeeze()
        rgb = self.fc2(rgb)
        x = self.fc3(rgb.view((-1, 10)) + meta)

        return self.softmax(x)

In [67]:
net = ConditionalImitation()

net(rgb.unsqueeze(dim=0), meta).detach()

tensor([[7.0419e-02, 5.4644e-04, 2.0197e-02, 2.8781e-02, 8.8006e-01]])

---

In [10]:
for rgb, _, _, action, _ in dataset:
    img = np.array(transform_(rgb))

    cv2.putText(img, 'Predicted: {}'.format(ACTIONS[(net(rgb.unsqueeze(dim=0)).detach()).argmax().item()]),
        (10, 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255,255,255),
        2)

    cv2.putText(img, 'Actual:    {}'.format(ACTIONS[action.argmax().item()]),
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255,255,255),
        2)

    cv2.imshow('rgb', img)
    cv2.waitKey(1)