In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import argparse
from dataloader import RGBDepthPano

from image_encoders import RGBEncoder, DepthEncoder
from TRM_net import BinaryDistPredictor_TRM, TRM_predict

from eval import waypoint_eval

import os
import glob
import utils
import random
from utils import nms
from utils import print_progress
from tensorboardX import SummaryWriter

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class Args:
    """Notebook stand-in for the experiment configuration object.

    Every option is stored as an instance attribute with the default used in
    this notebook. Individual options can be changed at construction time by
    keyword, e.g. ``Args(BATCH_SIZE=2)``; calling ``Args()`` with no arguments
    gives exactly the defaults below.

    Raises:
        AttributeError: if an override names an option that does not exist,
            so that a misspelled option is not silently ignored.
    """

    def __init__(self, **overrides):
        self.EXP_ID = 'test_ipynb'      # experiment name
        self.TRAINEVAL = 'train'
        self.VIS = 0
        self.ANGLES = 120               # number of angular sectors over 360 degrees
        self.NUM_IMGS = 12              # images per panorama
        self.NUM_CLASSES = 12           # distance bins per angular sector
        self.MAX_NUM_CANDIDATES = 5
        self.PREDICTOR_NET = 'TRM'
        self.EPOCH = 300
        self.BATCH_SIZE = 8
        self.LEARNING_RATE = 1e-6
        self.WEIGHT = 0
        self.TRM_LAYER = 2              # number of transformer blocks
        self.TRM_NEIGHBOR = 1           # neighbors attended to on each side
        self.HEATMAP_OFFSET = 5
        self.HIDDEN_DIM = 768

        # Apply overrides last so they win over the defaults; only names that
        # already exist as options are accepted.
        for option, value in overrides.items():
            if not hasattr(self, option):
                raise AttributeError('Unknown option: %s' % option)
            setattr(self, option, value)

args = Args()

In [None]:
def setup(args):
    """
    Set random seeds and create experiment directories
    Args:
        args: Command line arguments
    """
    torch.manual_seed(0)  # Set PyTorch random seed
    random.seed(0)  # Set Python random seed
    exp_log_path = './checkpoints/%s/'%(args.EXP_ID)  # Experiment log path
    os.makedirs(exp_log_path, exist_ok=True)  # Create experiment directory
    exp_log_path = './checkpoints/%s/snap/'%(args.EXP_ID)  # Model snapshot path
    os.makedirs(exp_log_path, exist_ok=True)  # Create model snapshot directory

## RGBEncoder 
It uses the Torchvision pre-trained ResNet-50. The last two layers (the fully connected and pooling layers) are removed, keeping only the feature-extraction part. 
* Input size: [batchsize, num_imgs, 3, 224, 224]
* Output size: [batchsize*num_imgs, 2048, 7, 7]

In [None]:
# RGB feature extractor: built with resnet_pretrain=True and trainable=False, then moved to the selected device.
rgb_encoder = RGBEncoder(resnet_pretrain=True, trainable=False).to(device)

## DepthEncoder
It is based on the ResNet from habitat_baselines.rl.ddppo.policy, specifically ResNet-50. The input first goes through the backbone network, where you can specify the number of output channels after the first convolution, the Group Norm parameters and the input image size. It then goes through a compression layer that changes the number of channels so that the flattened output has the size you want, which is after_compression_flat_size.
* Input size: [batchsize, num_imgs, 256, 256, 1]
* Output size: [batchsize*num_imgs, 128, 4, 4]

In [None]:
# Depth feature extractor: built with resnet_pretrain=True and trainable=False, then moved to the selected device.
depth_encoder = DepthEncoder(resnet_pretrain=True, trainable=False).to(device)

## Transformer Architecture
### VisPosEmbeddings
Input size: [batch_size, num_images, hidden_size]  
embeddings = vis_embeddings + position_embeddings  
Output size: [batch_size, num_images, hidden_size]  
### CaptionBertAttention
Attention layer in each block.  
Input size: [batch_size, num_images, hidden_size]  
Attention mask: [1, 1, num_images, num_images]  
#### utils.get_attention_mask
It receives the number of neighbors each image is allowed to attend to on each side, specified by args.TRM_NEIGHBOR, and returns a [1, 1, num_images, num_images] mask tensor, where 1 indicates allowed attention and 0 indicates forbidden attention. 

In CaptionBertAttention, the input first goes through CaptionBertSelfAttention, a self-attention layer that adds the attention mask to the attention score matrix and outputs the context embeddings. It then goes through a feedforward network with a residual connection. 

Output size: [batch_size, num_images, hidden_size]

### CaptionBertEncoder
It consists of multiple CaptionBertAttention blocks; the number of blocks is specified by args.TRM_LAYER.

### BertImgModel
It simply calls CaptionBertEncoder.

### WaypointBert
It consists of a BertImgModel layer and a dropout layer.



## BinaryDistPredictor_TRM
### Compression Layer
* Receives input from RGBEncoder and DepthEncoder [batchsize * num_imgs, 2048, 7, 7], [batchsize * num_imgs, 128, 4, 4]
* Flatten the inputs to 2048x7x7 and 128x4x4 respectively and compress each to hidden_dim.
* Reshape to (batchsize, num_imgs, hidden_dim)
### Merge Layer
* Concatenate the two outputs to (batchsize, num_imgs, 2*hidden_dim) and then linear project to (batchsize, num_imgs, hidden_dim) with ReLU.
### Transformer Layer
* Get attention mask specified in args.TRM_NEIGHBOR
* Go through WaypointBert layer.
### Classifier Layer
* For each image in each sample in each batch, after Transformer Layer, its size is hidden_dim.
* Linear layer from hidden_dim to hidden_dim with ReLU.
* Linear layer from hidden_dim to n_classes*(num_angles/num_imgs). For example, n_classes=12 is the number of distance bins from the current node out to the maximum radius: 0.25m to 3.0m in steps of 0.25m (12 * 0.25m = 3.0m). num_angles=120 means the 360 degrees are divided into 120 sectors of 3 degrees each. num_imgs=12 is the number of images in the panorama. In this case, each image corresponds to num_angles/num_imgs=10 sectors of 3 degrees, so it produces a 10 (3-degree sectors) * 12 (0.25m bins) heatmap centered at this image. Each sample has 12 images, so in total this is a 120 (3-degree sectors) * 12 (0.25m bins) heatmap, where each image is responsible for its own 10 * 12 local heatmap.
* Output size is (batchsize, num_imgs, n_classes*(num_angles/num_imgs)). After a reshape, it becomes (batchsize, num_angles, n_classes), which is the 120 (3-degree sectors) * 12 (0.25m bins) heatmap. Each point in the heatmap corresponds to an independent vis_logit that captures the probability of that point being a waypoint.



In [6]:
# Waypoint predictor; its hidden size and number of classes are taken from the notebook config (args).
predictor = BinaryDistPredictor_TRM(args=args, hidden_dim=args.HIDDEN_DIM, n_classes=args.NUM_CLASSES).to(device)

## Ground Truth Dict
* navigability_dict[scan_id][node] contains information about a node in this scene(scan_id)
* target: A 120 * 12 target map. For each neighbor of this node whose distance from the node is <3.25m and >0.25m, its location in the target map is set to 1. If there is more than one effective neighbor waypoint at the same angle, only the furthest one is retained, which guarantees at most one waypoint per angle. If the node has no effective waypoints at any angle, the node is deleted. A Gaussian filter is then used to smooth the target map, and locations that lie inside obstacles according to the obstacle map are set to 0. If all the values in the target map are lower than some threshold, the node is deleted.
* obstacle: A 120 * 12 obstacle map. 1 indicates obstacles and 0 indicates open space.
* weight: A weight map 120 * 12. Not used.
* source_pos: The pose of the current node in the simulator.
* target_pose: The poses of the effective neighbor waypoints in the simulator.

In [None]:
# Load navigability data (ground truth waypoints, obstacles, and weights)
nav_dict_path = './training_data/%s_*_mp3d_waypoint_twm0.2_obstacle_first_withpos.json'%(args.ANGLES)
navigability_dict = utils.load_gt_navigability(nav_dict_path)

# Peek at the nesting: take the FIRST scan and its FIRST node (not a random pick)
all_scan_ids = list(navigability_dict.keys())
print('navigability_dict.keys(): ', all_scan_ids)
scan_id = all_scan_ids[0]

nodes_in_scan = list(navigability_dict[scan_id].keys())
print('navigability_dict[scan_id].keys(): ', nodes_in_scan)
node = nodes_in_scan[0]

print('navigability_dict[scan_id][node].keys(): ', navigability_dict[scan_id][node].keys())



## Generate training images
In ./training_data/rgbd_fov90/{split}/{scan}/{scan}_{node}_mp3d_imgs.pkl   
It has 12 rgb and depth images for each node in each scan(scene)

In [9]:
# Create data loaders for RGB and depth images
train_img_dir = './training_data/rgbd_fov90/train/*/*.pkl'  # Glob for training image pickles
eval_img_dir = './training_data/rgbd_fov90/val_unseen/*/*.pkl'  # Glob for evaluation image pickles

# Despite their names, these two objects are what gets handed to DataLoader as its dataset
traindataloader = RGBDepthPano(args, train_img_dir, navigability_dict)
evaldataloader = RGBDepthPano(args, eval_img_dir, navigability_dict)

# Settings shared by both batch loaders; only the training loader shuffles
loader_kwargs = dict(batch_size=args.BATCH_SIZE, num_workers=4)
trainloader = torch.utils.data.DataLoader(traindataloader, shuffle=True, **loader_kwargs)
evalloader = torch.utils.data.DataLoader(evaldataloader, shuffle=False, **loader_kwargs)


In [None]:

# Get a single batch from the dataloader
dataiter = iter(trainloader)
batch = next(dataiter)

# Select the first sample from the batch. Every value in the batch is indexed the
# same way, so no type check is needed (the original conditional had two identical branches).
sample = {k: v[0] for k, v in batch.items()}
# Print sample information
print(f"Sample ID: {sample['sample_id']}")
print(f"Scan ID: {sample['scan_id']}")
print(f"Waypoint ID: {sample['waypoint_id']}")
print('Number of RGB images: ', len(sample['rgb']))
print(f"RGB shape: {sample['rgb'][0].shape}")
print(f"Depth shape: {sample['depth'][0].shape}")


In [12]:
# Element-wise squared error: reduction='none' makes the criterion return one loss value
# per element instead of reducing to a scalar.
criterion_mse = torch.nn.MSELoss(reduction='none')
# Only the predictor's parameters are handed to the optimizer; the RGB and depth encoders
# are not included here.
params = list(predictor.parameters())
optimizer = torch.optim.AdamW(params, lr=args.LEARNING_RATE)

In [None]:
# Pull one batch from the training loader (replaces an enumerate loop that broke
# after its first iteration)
batch = next(iter(trainloader))
data = batch

scan_ids = data['scan_id']  # Scene IDs
waypoint_ids = data['waypoint_id']  # Waypoint IDs
rgb_imgs = data['rgb'].to(device)  # RGB images
depth_imgs = data['depth'].to(device)  # Depth images

print('input rgb shape: ', rgb_imgs.shape) #[B, N, C, H, W]
print('input depth shape: ', depth_imgs.shape) #[B, N, H, W, 1]

# Encode every image of every panorama in the batch
rgb_feats = rgb_encoder(rgb_imgs)
depth_feats = depth_encoder(depth_imgs)

# Expected shapes follow the RGBEncoder / DepthEncoder notes earlier in this notebook
print('output rgb_feats shape: ', rgb_feats.shape) #[B*N, 2048, 7, 7]
print('output depth_feats shape: ', depth_feats.shape) #[B*N, 128, 4, 4]
    

In [None]:
# Ground-truth maps for the scans/waypoints of the current batch; the last two
# return values are not used here
target, obstacle, weight, _, _ = utils.get_gt_nav_map(
    args.ANGLES, navigability_dict, scan_ids, waypoint_ids)

# Move all three maps to the compute device
target, obstacle, weight = (
    gt_map.to(device) for gt_map in (target, obstacle, weight))

# Each map is expected to be (B, angles, 12)
for map_name, gt_map in (('target', target), ('obstacle', obstacle), ('weight', weight)):
    print('%s shape: ' % map_name, gt_map.shape)



## Use MSE between predicted heat map and target map to define loss

In [None]:
# Run the predictor on the encoded RGB and depth features; the first argument 'train'
# presumably selects training mode (confirm in TRM_net.TRM_predict).
vis_logits = TRM_predict('train', args, predictor, rgb_feats, depth_feats)
# criterion_mse was created with reduction='none', so loss_vis holds one value per
# heatmap element rather than a scalar.
loss_vis = criterion_mse(vis_logits, target)