In [1]:
from sray.utils.imgs_info import build_imgs_info, imgs_info_to_torch
from sray.dataset.database import ScannetDatabase
from sray.utils.imgs_info import build_imgs_info
import numpy as np
from mmengine.config import Config
from mmseg.models import build_segmentor
from sray.network.tpvformer10 import *
from sray.network.tpvformer10.tpv_head import TPVFormerHead,CustomPositionalEncoding
import torch

In [2]:
# Identifier handed to ScannetDatabase below.
# presumably "<source>/<scene>/<image variant>" — TODO confirm the expected format.
dataset_name = 'scannet/scene0188_00/black_320'
dataset = ScannetDatabase(dataset_name)

In [3]:
# Collect image info for eight reference images: ids 10, 20, ..., 80.
ref_img_info = build_imgs_info(dataset, list(range(10, 81, 10)))

In [4]:
# Inspect which entries build_imgs_info returned.
list(ref_img_info.keys())

['imgs_mmseg',
 'seg_logits',
 'pred_sem_seg',
 'mlvl_feats',
 'imgs',
 'poses',
 'Ks',
 'depth_range',
 'masks',
 'labels',
 'depth']

In [5]:
# The poses come back stacked: one 3x4 matrix per requested image (8 here).
ref_img_info['poses'].shape

(8, 3, 4)

In [6]:
def build_img_metas(ref_img_info, img_H=280, img_W=320):
    """Build one meta dict per reference view, each carrying a 4x4 projection matrix.

    For every ``(pose, K)`` pair taken in step from ``ref_img_info['poses']``
    and ``ref_img_info['Ks']``:

    * the pose's top-left 3x4 part is written into the first three rows of a
      4x4 identity matrix,
    * ``K`` is written into the top-left corner of another 4x4 identity
      matrix,
    * their product ``K_4x4 @ pose_4x4`` is stored under ``'lidar2img'``.

    NOTE(review): the pose is used directly as the "lidar2cam" transform, so it
    is presumably a world-to-camera extrinsic — confirm against the dataset.

    Args:
        ref_img_info: Mapping with ``'poses'`` and ``'Ks'`` entries that can be
            iterated together (e.g. arrays stacked along the first axis).
        img_H: Image height recorded in ``'img_shape'``.
        img_W: Image width recorded in ``'img_shape'``.

    Returns:
        A list with one dict per view, with keys ``'img_shape'``
        (``[[img_H, img_W]]``) and ``'lidar2img'`` (4x4 ``numpy`` array).
        Every dict owns its own ``'img_shape'`` list, so modifying one entry
        never leaks into the others.
    """
    img_metas = []
    for pose, k in zip(ref_img_info['poses'], ref_img_info['Ks']):
        # Promote the 3x4 pose to a homogeneous 4x4 transform.
        lidar2cam_rt = np.eye(4)
        lidar2cam_rt[:3, :4] = pose[:3, :4]

        # Pad the intrinsics to 4x4 so they compose with the pose.
        intrinsic = np.eye(4)
        intrinsic[:k.shape[0], :k.shape[1]] = k

        # Build a fresh dict (and a fresh nested list) for each view; a shallow
        # copy of a shared template would alias 'img_shape' across all views.
        img_metas.append({
            'img_shape': [[img_H, img_W]],
            'lidar2img': intrinsic @ lidar2cam_rt,
        })
    return img_metas

In [7]:
# Compute the per-view metadata and store it alongside the other image info.
img_metas = build_img_metas(ref_img_info)
ref_img_info['img_metas'] = img_metas

In [8]:
# Sanity check: expect one meta dict per reference image.
len(ref_img_info['img_metas'])

8

In [9]:
# Convert the collected image info with imgs_info_to_torch
# (presumably numpy arrays -> torch tensors; TODO confirm).
# NOTE(review): this rebinds ref_img_info, so re-running the cell feeds the
# already-converted data back in — confirm that is safe.
ref_img_info = imgs_info_to_torch(ref_img_info)

In [10]:
def model_builder(model_config):
    """Create a segmentor from ``model_config`` and initialise its weights.

    The config is handed to ``build_segmentor``; the resulting model's
    ``init_weights()`` is called before the model is returned.
    """
    segmentor = build_segmentor(model_config)
    segmentor.init_weights()
    return segmentor

In [11]:
# Load the TPVFormer config. The path is relative, so the notebook has to be
# run from the directory that contains `sray/`.
cfg = Config.fromfile('sray/network/tpvformer10/tpv_config.py')

In [12]:
# Show the model section of the config (the part used to build the network).
cfg.model

{'type': 'TPVFormer',
 'use_grid_mask': True,
 'tpv_aggregator': {'type': 'TPVAggregator',
  'tpv_h': 160,
  'tpv_w': 160,
  'tpv_z': 64,
  'nbr_classes': 17,
  'in_dims': 64,
  'hidden_dims': 128,
  'out_dims': 64,
  'scale_h': 1,
  'scale_w': 1,
  'scale_z': 1},
 'img_backbone': {'type': 'ResNet',
  'depth': 101,
  'num_stages': 4,
  'out_indices': (1, 2, 3),
  'frozen_stages': 1,
  'norm_cfg': {'type': 'BN2d', 'requires_grad': False},
  'norm_eval': True,
  'style': 'caffe',
  'dcn': {'type': 'DCNv2', 'deform_groups': 1, 'fallback_on_stride': False},
  'stage_with_dcn': (False, False, True, True)},
 'img_neck': {'type': 'FPN',
  'in_channels': [512, 1024, 2048],
  'out_channels': 64,
  'start_level': 0,
  'add_extra_convs': 'on_output',
  'num_outs': 4,
  'relu_before_extra_convs': True},
 'tpv_head': {'type': 'TPVFormerHead',
  'tpv_h': 160,
  'tpv_w': 160,
  'tpv_z': 64,
  'pc_range': [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
  'num_feature_levels': 4,
  'num_cams': 8,
  'embed_dims'

In [13]:
# Build the network from the model config and move it to the GPU
# (requires a CUDA device).
my_model = model_builder(cfg.model).cuda()

2023-10-23 04:58:40,289 - mmcv - INFO - 
tpv_head.level_embeds - torch.Size([4, 64]): 
Initialized by user-defined `init_weights` in TPVFormerHead  
 
2023-10-23 04:58:40,290 - mmcv - INFO - 
tpv_head.cams_embeds - torch.Size([8, 64]): 
Initialized by user-defined `init_weights` in TPVFormerHead  
 
2023-10-23 04:58:40,290 - mmcv - INFO - 
tpv_head.positional_encoding.h_embed.weight - torch.Size([160, 24]): 
Initialized by user-defined `init_weights` in TPVFormerHead  
 
2023-10-23 04:58:40,291 - mmcv - INFO - 
tpv_head.positional_encoding.w_embed.weight - torch.Size([160, 24]): 
Initialized by user-defined `init_weights` in TPVFormerHead  
 
2023-10-23 04:58:40,291 - mmcv - INFO - 
tpv_head.positional_encoding.z_embed.weight - torch.Size([64, 16]): 
Initialized by user-defined `init_weights` in TPVFormerHead  
 
2023-10-23 04:58:40,291 - mmcv - INFO - 
tpv_head.encoder.layers.0.attentions.0.output_proj.0.weight - torch.Size([64, 64]): 
Initialized by user-defined `init_weights` in TPV

In [14]:
# Run the model on the reference views. `[None]` adds a leading axis of size 1
# in front of the stacked images; the call returns three outputs.
tpv_hw, tpv_zh, tpv_wz = my_model(
    img_metas=ref_img_info['img_metas'],
    img=ref_img_info['imgs'][None],
)

In [16]:
# Plane sizes; these match tpv_h / tpv_w / tpv_z in the printed cfg.model.
h, w, z = 160, 160, 64

# Swap the last two axes of each output and unfold the flattened axis into a
# 2-D grid. Assumes each output is 3-D with the flattened plane cells on axis 1
# and channels on axis 2 — TODO confirm against TPVFormer's forward().
# Each result is bound to a name: as bare expressions the first two were
# computed and immediately discarded (only a cell's last expression is shown).
tpv_hw_map = tpv_hw.permute(0, 2, 1).reshape((1, -1, h, w))
tpv_zh_map = tpv_zh.permute(0, 2, 1).reshape((1, -1, z, h))
tpv_wz_map = tpv_wz.permute(0, 2, 1).reshape((1, -1, w, z))

# Displayed value is unchanged: the reshaped wz output.
tpv_wz_map


tensor([[[[-1.1519, -1.4062, -1.8597,  ..., -0.0708,  0.1429,  0.4717],
          [-2.0356, -1.9693, -1.1195,  ...,  0.5641, -0.3887,  0.9794],
          [-2.7295, -0.7180, -1.6982,  ..., -0.0984, -0.4568,  0.3929],
          ...,
          [ 0.7365,  0.1642,  0.7138,  ...,  0.2032,  2.1435,  1.3940],
          [ 0.6291,  1.1392, -0.3433,  ..., -1.0481, -0.6113,  0.8371],
          [-0.7565, -1.8308, -0.7792,  ..., -0.3452,  0.1050,  0.3485]],

         [[ 1.5901,  1.1290,  1.0374,  ..., -0.5502, -0.9338, -1.0303],
          [ 0.6859,  1.6061,  0.4410,  ...,  0.6342, -1.4760, -2.1971],
          [ 0.5807,  1.3325,  1.1537,  ..., -0.7519, -2.4986, -2.1104],
          ...,
          [ 1.1667,  0.4242, -0.0357,  ..., -0.9584,  0.2288,  0.1296],
          [ 0.3384, -0.1551,  0.3613,  ..., -0.6977,  0.1690, -0.9875],
          [-0.4996, -0.2300, -0.1627,  ..., -1.0490, -0.3526, -0.6052]],

         [[ 0.1712, -0.6607, -0.0332,  ..., -1.1934, -0.9618, -0.7588],
          [-0.3532, -1.0614, -

In [None]:
# Toy 1x3 tensor used to try out the index selection in the next cell.
coord = torch.tensor([[1, 2, 3]])

In [None]:
# Pick component pairs (0, 1), (1, 2) and (2, 0) along the last axis.
# presumably the per-plane 2-D coordinates for the three TPV planes — TODO confirm.
coord[..., [0, 1]], coord[..., [1, 2]], coord[..., [2, 0]]