In [1]:
from PIL import Image
import argparse
import matplotlib.pyplot as plt
import sys
import random
import numpy as np
import os

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

import bqplot.scales
import ipyvolume as ipv
import ipywidgets as widgets


torch.cuda.is_available()


True

In [2]:
sys.path.append("/workspace/HKU-OccNet/")
from utils import SemanticKITTIDataset
from utils import visualize_labeled_array3d
from utils import plot_tensor2d

KITTI_DIR = "/workspace/Dataset/dataset"



train_set = SemanticKITTIDataset(root_dir=KITTI_DIR, mode='train', 
                                 sequences=['00'], split_ratio=0.3)
print(len(train_set))


1362


In [3]:
#left_img, right_img, vox_labels = train_set.get_data(666)
# plt.figure(1)
# plot_tensor2d(left_img)
# plt.figure(2)
# plot_tensor2d(right_img)
# plt.show()
#visualize_labeled_array3d(vox_labels.numpy().astype(np.uint16), size = 0.5, marker = 'box')

In [4]:
#left_img.shape

In [5]:
train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)

In [6]:
num_classes = len(train_set.class_names)
class_weights = train_set.class_weights

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
device

device(type='cuda', index=0)

In [8]:

img_width, img_height = 1241, 376

In [9]:
from monoscene.monoscene import MonoScene
from utils.monoscene_utils import *

pretrained_weight_path = '/workspace/PretrainedWeights/monoscene_kitti.ckpt'


monoscene_pt = MonoScene.load_from_checkpoint(
        pretrained_weight_path,
        dataset="kitti",
        n_classes=20,
        feature = 64,
        project_scale = 2,
        full_scene_size = (256, 256, 32),
)




class MonoScene(nn.Module):
    def __init__(self, MonoScene_pretrained, calib, img_width, img_height):
        super(MonoScene, self).__init__()

        self.monoscene_pt = MonoScene_pretrained
        self.batch_dict = get_projections(img_width, img_height, calib)
        for key in self.batch_dict:
            self.batch_dict[key] = self.batch_dict[key].unsqueeze(0)
            self.batch_dict[key] = self.batch_dict[key].to(device)
        
    
    # input_tensor 'x' should be batched stereo image tensor with
    # shape: N x 2 x C x H x W, where N is the batch size, 
    # 2 for left and right images, 
    # and C,H,W are the dimensions of RGB images 
    def forward(self, x):
        #print(x.shape)

        self.batch_dict["img"] = x.to(device)

        x = self.monoscene_pt(self.batch_dict)
        return x



print("Monoscene parameter:", sum(p.numel() for p in monoscene_pt.parameters() if p.requires_grad))
for p in monoscene_pt.parameters():
     p.requires_grad = False

Lightning automatically upgraded your loaded checkpoint from v1.1.3 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../PretrainedWeights/monoscene_kitti.ckpt`


n_relations 4
Loading base model ()...

Using cache found in /root/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master


Done.
Removing last two layers (global_pool & classifier).
Building Encoder-Decoder model..Done.
Monoscene parameter: 149555444


In [10]:
calib = read_calib("/workspace/HKU-OccNet/calib.txt")

#get_projections(img_width, img_height, calib)

In [11]:
from utils import load_STTR_model
from utils import NestedTensor, batched_index_select


pretrained_weight_path = '/workspace/PretrainedWeights/kitti_finetuned_model.pth.tar'

sttr_pt = load_STTR_model(pretrained_weight_path)
for param in sttr_pt.parameters():
    param.requires_grad = False


print("Number of parameters (in millions):", sum(p.numel() for p in sttr_pt.parameters()) / 1_000_000, 'M')


class STTR_InputAdapterLayer(nn.Module):
    def __init__(self, downsample=3):
        super(STTR_InputAdapterLayer, self).__init__()
        self.downsample = downsample

    def forward(self, input_tensor):
        input_tensor = input_tensor
        bs, _, _, h, w = input_tensor.shape  # Extract batch size, height, and width

        # Extract left and right images from the input tensor
        left_imgs = input_tensor[:, 0, :, :, :].squeeze(1)
        right_imgs = input_tensor[:, 1, :, :, :].squeeze(1)

        col_offset = int(self.downsample / 2)
        row_offset = int(self.downsample / 2)
        sampled_cols = torch.arange(col_offset, w, self.downsample)[None,].expand(bs, -1).cuda()
        sampled_rows = torch.arange(row_offset, h, self.downsample)[None,].expand(bs, -1).cuda()
        

        # Create NestedTensor for the batch
        nested_tensor = NestedTensor(left_imgs, right_imgs,  
                                    sampled_cols=sampled_cols, sampled_rows=sampled_rows)

        return nested_tensor

class STTR(nn.Module):
    def __init__(self, STTR_pretrained):
        super(STTR, self).__init__()
        self.sttr_adapter_layer = STTR_InputAdapterLayer(downsample=3)
        self.sttr_pt = STTR_pretrained
        
    
    # input_tensor 'x' should be batched stereo image tensor with
    # shape: N x 2 x C x H x W, where N is the batch size, 
    # 2 for left and right images, 
    # and C,H,W are the dimensions of RGB images 
    def forward(self, x): 
        x = self.sttr_adapter_layer(x)
        # bs, _, h, w = x.left.size()
        # feat = self.sttr_pt.backbone(x)
        # tokens = self.sttr_pt.tokenizer(feat)
        # pos_enc = self.sttr_pt.pos_encoder(x)
        # # separate left and right
        # feat_left = tokens[:bs]
        # feat_right = tokens[bs:]  # NxCxHxW
        # # downsample
        # if x.sampled_cols is not None:
        #     feat_left = batched_index_select(feat_left, 3, x.sampled_cols)
        #     feat_right = batched_index_select(feat_right, 3, x.sampled_cols)
        # if x.sampled_rows is not None:
        #     feat_left = batched_index_select(feat_left, 2, x.sampled_rows)
        #     feat_right = batched_index_select(feat_right, 2, x.sampled_rows)
        # attn_weight = self.sttr_pt.transformer(feat_left, feat_right, pos_enc)
        # output = self.sttr_pt.regression_head(attn_weight, x)
        output = self.sttr_pt(x)
        disp_map = output['disp_pred'][0]
        occ_map = output['occ_pred'][0] > 0.5
        disp_map[occ_map] = 0.0
        
        return disp_map




Pre-trained model successfully loaded.
Number of parameters (in millions): 2.513811 M


In [12]:
# import torch
# import torch.nn as nn

# class StereoToVoxelNet(nn.Module):
#     def __init__(self, input_channels=3, H=376, W=1241):
#         super(StereoToVoxelNet, self).__init__()
        
#         self.conv1_stereo = nn.Conv2d(input_channels * 2, 32, kernel_size=3, stride=1, padding=1)
#         self.conv2_stereo = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
#         self.conv3_stereo = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

#         # MLP layers for depth disparity map
#         self.mlp1 = nn.Linear(H * W, 1024)
#         self.mlp2 = nn.Linear(1024, 2048)

#         self.adaptive_pool_stereo = nn.AdaptiveAvgPool2d((64, 64))

#         # Upsample layer to upscale stereo features to match size before concatenation
#         self.upsample = nn.Upsample(size=(128, 128), mode='bilinear', align_corners=False)

#         # Final layers
#         self.final_conv = nn.Conv2d(128 * 128 + 2048, 128, kernel_size=3, stride=1, padding=1)
#         self.reshape = nn.Unflatten(1, (128, 128, 128, 16))

#     def forward(self, input_tensor, depth_disp):
#         N, _, C, H, W = input_tensor.shape
#         input_tensor = input_tensor.view(N, -1, H, W)

#         # Process stereo images
#         x_stereo = nn.ReLU()(self.conv1_stereo(input_tensor))
#         x_stereo = nn.ReLU()(self.conv2_stereo(x_stereo))
#         x_stereo = nn.ReLU()(self.conv3_stereo(x_stereo))
#         x_stereo = self.adaptive_pool_stereo(x_stereo)

#         # Process depth disparity map
#         #print(depth_disp.shape)
#         depth_flat = depth_disp.view(N, -1)  # Flatten the depth map
#         #print(depth_flat.shape)
#         depth_features = nn.ReLU()(self.mlp1(depth_flat))
#         depth_features = nn.ReLU()(self.mlp2(depth_features))

#         # Combine features from stereo and depth
#         x_combined = torch.cat([x_stereo.flatten(1), depth_features], dim=1)

#         # Final processing
#         print(x_combined.shape) #(2, 526336)
#         x = nn.ReLU()(self.final_conv(x_combined.view(N, -1, 128, 128)))
        
#         x = self.reshape(x)

#         return x


In [13]:
from utils import sem_scal_loss, geo_scal_loss, CE_ssc_loss
from utils import Header


In [14]:
sttr = STTR(sttr_pt)
monoscene = MonoScene(monoscene_pt, calib, img_width, img_height)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [15]:
# img shape: 3, 376, 1241 
# voxel shape: 256, 256, 32



In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageDepthEncoder(nn.Module):
    def __init__(self):
        super(ImageDepthEncoder, self).__init__()
        # Image encoder (e.g., a simple CNN)
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Depth map encoder (e.g., another CNN)
        self.depth_encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, image, depth_map):
        image_features = self.image_encoder(image)
        depth_map = depth_map.unsqueeze(1)  # Add channel dimension
        depth_features = self.depth_encoder(depth_map)
        # Concatenate features along channel dimension
        combined_features = torch.cat([image_features, depth_features], dim=1)
        return combined_features

class DepthProjectVoxel(nn.Module):
    def __init__(self):
        super(DepthProjectVoxel, self).__init__()
        self.encoder = ImageDepthEncoder()
        # 3D Convolution layers for voxel projection
        self.conv3d_layers = nn.Sequential(
            nn.Conv3d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Assuming occupancy probabilities in [0, 1]
        )

    def forward(self, image, depth_map):
        combined_features = self.encoder(image, depth_map)
        #print(combined_features.shape) #printed torch.Size([1, 256, 94, 310])
        
        # Reshape and expand dimensions to fit 3D convolution
        combined_features = combined_features.view(-1, 256, 1, 47, 155) 
        voxel_probabilities = self.conv3d_layers(combined_features)
        print(voxel_probabilities.shape)
        
        voxel_probabilities = F.interpolate(voxel_probabilities, size=(256, 256, 32), mode='trilinear', align_corners=True)

        voxel_probabilities = voxel_probabilities.squeeze(0).squeeze(0)  
       
        voxel_probabilities = nn.AdaptiveAvgPool3d((256, 256, 32))(voxel_probabilities)

        return voxel_probabilities



In [17]:
class CrossAttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CrossAttentionModule, self).__init__()
        # Define layers for cross-attention
        self.query_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value):
        batch_size = query.shape[0]

        # Transform query, key, value
        query = self.query_conv(query)
        key = self.key_conv(key)
        value = self.value_conv(value)

        # Reshape for matmul
        query = query.view(batch_size, -1, query.shape[2]*query.shape[3]*query.shape[4])
        key = key.view(batch_size, -1, key.shape[2]*key.shape[3]*key.shape[4]).permute(0, 2, 1)
        value = value.view(batch_size, -1, value.shape[2]*value.shape[3]*value.shape[4])

        # Attention mechanism
        attention = torch.bmm(query, key)
        attention = self.softmax(attention)
        out = torch.bmm(value, attention.permute(0, 2, 1))

        # Reshape back to original size
        out = out.view(batch_size, -1, query.shape[2], query.shape[3], query.shape[4])
        return out

class VoxelCrossAttn(nn.Module):
    def __init__(self, num_classes):
        super(VoxelCrossAttn, self).__init__()
        self.monoscene = ...  # Your existing MonoScene model
        self.cross_attention = CrossAttentionModule(num_classes, num_classes)
        self.additional_layers = ...  # Additional layers as needed

    def forward(self, pred_1h, vox_proposal):
        mono_pred_1h = pred_1h

        # Cross-attention
        vox_proposal_expanded = vox_proposal.unsqueeze(0).expand_as(mono_pred_1h)
        attn_output = self.cross_attention(vox_proposal_expanded, mono_pred_1h, mono_pred_1h)

        # Additional layers for processing
        output = self.additional_layers(attn_output)

        return output

In [18]:
import torch
import torch.nn as nn

class STFB_Occ(nn.Module):
    def __init__(self, num_classes, sttr, monoscene):
        super(STFB_Occ, self).__init__()
        self.sttr = sttr
        self.monoscene = monoscene
        
        self.depth_prj_voxel = DepthProjectVoxel()
        

        # Initial convolution layers
        self.conv1 = nn.Conv2d(4, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # FPN Layers
        self.toplayer = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)  # Reduce channels
        self.latlayer1 = nn.Conv2d(128, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0)

        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        
        self.channel_reducer = nn.Conv2d(in_channels=768, out_channels=3, kernel_size=1)
        
        self.cross_attn = VoxelCrossAttn(num_classes)

    def _upsample_add(self, x, y):
        _, _, H, W = y.size()
        return nn.functional.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    def forward(self, x):
        # Existing STTR depth estimation
        x_depth = self.sttr(x)
        x_depth = x_depth.unsqueeze(0)
        

        left_image = x[:, 0, :, :, :].squeeze(1)  # Shape: [N, C, H, W] (N always 1)
        
        print(left_image.shape, x_depth.shape)
        vox_proposal = self.depth_prj_voxel(left_image, x_depth)
        print(vox_proposal.shape)
        
        
        x = torch.cat((left_image, x_depth.unsqueeze(0)), dim=1)  # Shape: [N, C+1, H, W]

        # Convolution layers
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)

        # Top-down pathway
        p3 = self.toplayer(c3)
        p2 = self._upsample_add(p3, self.latlayer1(c2))
        p1 = self._upsample_add(p2, self.latlayer2(c1))

        # Smoothing
        p2 = self.smooth1(p2)
        p1 = self.smooth2(p1)

        # Final monoscene processing
        fused_features = torch.cat([p1, p2, p3], dim=1)
        reduced_features = self.channel_reducer(fused_features)
        
        #print(fused_features.shape)
        mono_pred_1h = self.monoscene(reduced_features) 
        # mono_pred_1h has shape [batch_size, num_classes, 255, 255, 32]
        #print(mono_pred_1h.shape) #[1, 20, 256, 256, 32]
        # vox_propasal has shape = [255, 255, 32]
        
        vox_pred_1h = self.cross_attn(mono_pred_1h, vox_proposal)
        #return mono_pred_1h
        return vox_pred_1h


In [19]:
model = STFB_Occ(num_classes=num_classes, 
                 sttr=sttr, 
                 monoscene=monoscene)
print("Number of parameters (in millions):", sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000, 'M')


Number of parameters (in millions): 2.92856 M


In [20]:
#model.sttr()

In [21]:
# for i, (image2, image3, voxel_labels) in tqdm(enumerate(train_dataloader),total = len(train_dataloader)):
#     if i == 0:
#         print(image2.shape)
#         print(image3.shape)
#         print(voxel_labels.shape)
#         inputs = torch.stack((image2, image3), dim=1).to(device)
        
#         print(model.sttr(inputs).shape)
#         break

In [22]:
from tqdm import tqdm
from torch.optim import Adam
import torch.nn.functional as F
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.01)
num_epochs = 3
best_loss = np.inf
for epoch in range(num_epochs):
    train_loss = 0.0
    for i, (image2, image3, voxel_labels) in tqdm(enumerate(train_dataloader),total = len(train_dataloader)):
        inputs = torch.stack((image2, image3), dim=1).to(device)
        voxel_labels = voxel_labels.to(device)
        
        voxel_pred_1h = model(inputs)
        
        loss = sem_scal_loss(voxel_pred_1h, voxel_labels)
        loss += geo_scal_loss(voxel_pred_1h, voxel_labels)
        class_weights = class_weights.float().to(device)
        loss += CE_ssc_loss(voxel_pred_1h, voxel_labels, class_weights)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        #print(loss.item())  
    valid_loss = 0.0
    with torch.no_grad():
        for (image2, image3, voxel_labels) in tqdm(val_dataloader):
            inputs = torch.stack((image2, image3), dim=1).to(device)
            voxel_labels = voxel_labels.to(device)
            
            voxel_pred_1h = model(inputs)
            
            loss = sem_scal_loss(voxel_pred_1h, voxel_labels)
            loss += geo_scal_loss(voxel_pred_1h, voxel_labels)
            class_weights = class_weights.float().to(device)
            loss += CE_ssc_loss(voxel_pred_1h, voxel_labels, class_weights)
            
            valid_loss += loss.item()
            
    print(f'Epoch {epoch + 1}: Training loss: {train_loss / len(train_dataloader)}, Validation loss: {valid_loss / len(val_dataloader)}')
    if (train_loss / len(train_dataloader)) < best_loss:
        torch.save(model.state_dict(), 'STFBOcc.pth')
        best_loss = (train_loss / len(train_dataloader))

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


torch.Size([1, 3, 376, 1241]) torch.Size([1, 376, 1241])
torch.Size([4, 1, 1, 47, 155])
torch.Size([4, 1, 256, 256, 32])


  projected_pix // scale_2d,
  projected_pix // scale_2d,
  0%|          | 0/1362 [00:02<?, ?it/s]
