In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
import time
import re
import random
import sys

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from torchvision.transforms import Resize
from syn_DI_dataset import make_dataset, make_dataloader

from RGB_benchmark.rgb_ResNet18.RGB_ResNet import *
from depth_benchmark.depth_ResNet18 import *
from mmwave_benchmark.mmwave_point_transformer import *
from lidar_benchmark.lidar_point_transformer import *

In [2]:
class rgb_feature_extractor(nn.Module):
    """Turn an RGB network into a token-sequence feature extractor.

    All direct children of ``rgb_model`` except the final two are chained
    into ``self.part`` (presumably dropping the pooling / prediction head --
    TODO confirm against ``RGB_ResNet18``).

    ``forward`` returns a tensor of shape ``(batch, tokens, 512)``: the
    trunk output is reshaped to ``(batch, 512, -1)`` and the last two axes
    are swapped.
    """

    def __init__(self, rgb_model):
        super().__init__()
        child_modules = list(rgb_model.children())
        self.part = nn.Sequential(*child_modules[:-2])

    def forward(self, x):
        batch_size = x.size(0)
        feature_map = self.part(x)
        # 512 is the hard-coded channel width expected from the trunk.
        tokens = feature_map.view(batch_size, 512, -1)
        return tokens.permute(0, 2, 1)

In [3]:
class depth_feature_extractor(nn.Module):
    """Turn a depth network into a token-sequence feature extractor.

    Keeps every direct child of ``depth_model`` except the last two as
    ``self.part``. ``forward`` reshapes the trunk output to
    ``(batch, 512, -1)`` and returns it with the last two axes swapped,
    i.e. ``(batch, tokens, 512)``.
    """

    def __init__(self, depth_model):
        super().__init__()
        child_modules = list(depth_model.children())
        self.part = nn.Sequential(*child_modules[:-2])

    def forward(self, x):
        batch_size = x.size(0)
        feature_map = self.part(x)
        # 512 is the hard-coded channel width expected from the trunk.
        tokens = feature_map.view(batch_size, 512, -1)
        return tokens.permute(0, 2, 1)

In [4]:
class mmwave_feature_extractor(nn.Module):
    """Feature extractor built from an mmWave model minus its last child.

    ``self.part`` chains every direct child of ``mmwave_model`` except the
    final one. The chain is expected to return a pair; only the first
    element is returned by ``forward``.
    """

    def __init__(self, mmwave_model):
        super().__init__()
        child_modules = list(mmwave_model.children())
        self.part = nn.Sequential(*child_modules[:-1])

    def forward(self, x):
        # The second element of the pair is not needed downstream.
        features, _discarded = self.part(x)
        return features

In [5]:
class lidar_feature_extractor(nn.Module):
    """Feature extractor that reuses the first stages of a LiDAR backbone.

    From ``lidar_model.backbone`` it takes ``fc1``, ``transformer1`` and the
    first ``nblocks - 4`` (= 1) entries of ``transition_downs`` /
    ``transformers``.

    ``forward`` returns the last stage's point features reshaped to
    ``(batch, -1, 512)``.
    """

    def __init__(self, lidar_model):
        super(lidar_feature_extractor, self).__init__()
        backbone = lidar_model.backbone
        nblocks = 5
        # Only the first (nblocks - 4) down-sampling stages are reused.
        n_stages = nblocks - 4

        self.fc1 = backbone.fc1
        self.transformer1 = backbone.transformer1
        self.transition_downs = nn.ModuleList(
            [backbone.transition_downs[i] for i in range(n_stages)]
        )
        self.transformers = nn.ModuleList(
            [backbone.transformers[i] for i in range(n_stages)]
        )
        self.nblocks = nblocks

    def forward(self, x):
        # The first three channels of every point are used as coordinates.
        xyz = x[..., :3]
        # Each transformer stage returns a tuple; only element 0 is used.
        points = self.transformer1(xyz, self.fc1(x))[0]

        for transition_down, transformer in zip(self.transition_downs, self.transformers):
            xyz, points = transition_down(xyz, points)
            points = transformer(xyz, points)[0]

        # 512 is the hard-coded feature width expected after the last stage.
        return points.view(points.size(0), -1, 512)

In [6]:
class csi_feature_extractor(nn.Module):
    """Feature extractor built from the encoder stages of a CSI model.

    ``self.part`` chains ``encoder_conv1``, ``encoder_bn1``,
    ``encoder_relu`` and ``encoder_layer1`` .. ``encoder_layer4`` of the
    given model. ``forward`` returns ``(batch, tokens, 512)``.
    """

    def __init__(self, model):
        super().__init__()
        encoder_stages = [
            model.encoder_conv1,
            model.encoder_bn1,
            model.encoder_relu,
            model.encoder_layer1,
            model.encoder_layer2,
            model.encoder_layer3,
            model.encoder_layer4,
            # torch.nn.AvgPool2d((1, 4))  -- disabled trailing pooling stage
        ]
        self.part = nn.Sequential(*encoder_stages)

    def forward(self, x):
        # Insert a singleton channel axis, swap axes 2 and 3, then merge the
        # last two axes so the input becomes a single-channel 2-D map.
        # For the (32, 3, 114, 10) dummy input used in this notebook the
        # result is (32, 1, 114, 30).
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        x = x.flatten(3, 4)
        # Bring the map to the fixed 136 x 32 spatial size.
        x = Resize([136, 32])(x)
        encoded = self.part(x)
        # 512 is the hard-coded channel width expected from the encoder.
        tokens = encoded.view(x.size(0), 512, -1)
        return tokens.permute(0, 2, 1)

In [7]:
class feature_extrator(nn.Module):
    """Bundle of the five per-modality feature extractors.

    On construction the RGB, depth, mmWave and LiDAR benchmark models are
    instantiated, their checkpoints are loaded from hard-coded relative
    paths, and each is wrapped in its ``*_feature_extractor`` and switched
    to eval mode. The CSI model is loaded as a whole pickled object.

    All checkpoints are loaded with ``map_location='cpu'`` so construction
    also works on hosts without CUDA and every sub-module starts on the
    same device; move the assembled model with ``.to(device)`` / ``.cuda()``.

    NOTE(review): ``nn.Module.train()`` on an enclosing model propagates to
    these sub-modules and undoes the ``eval()`` calls made here. Confirm
    whether the extractors are meant to stay frozen during training.
    """

    def __init__(self):
        super(feature_extrator, self).__init__()

        rgb_model = RGB_ResNet18()
        rgb_model.load_state_dict(
            torch.load('./RGB_benchmark/rgb_ResNet18/RGB_Resnet18_copy.pt', map_location='cpu'))
        rgb_extractor = rgb_feature_extractor(rgb_model)
        rgb_extractor.eval()

        depth_model = Depth_ResNet18()
        depth_model.load_state_dict(
            torch.load('depth_benchmark/depth_Resnet18.pt', map_location='cpu'))
        depth_extractor = depth_feature_extractor(depth_model)
        depth_extractor.eval()

        mmwave_model = mmwave_PointTransformerReg()
        mmwave_model.load_state_dict(
            torch.load('mmwave_benchmark/mmwave_all_random.pt', map_location='cpu'))
        mmwave_extractor = mmwave_feature_extractor(mmwave_model)
        mmwave_extractor.eval()

        lidar_model = lidar_PointTransformerReg()
        lidar_model.load_state_dict(
            torch.load('lidar_benchmark/lidar_all_random.pt', map_location='cpu'))
        lidar_extractor = lidar_feature_extractor(lidar_model)
        lidar_extractor.eval()

        # The pickled CSI model needs its defining modules to be importable.
        # Guard the insertion so repeated instantiation does not keep
        # prepending the same entry to sys.path.
        csi_module_dir = './CSI_benchmark'
        if csi_module_dir not in sys.path:
            sys.path.insert(0, csi_module_dir)
        # NOTE(review): this unpickles a complete model object, which can
        # execute arbitrary code -- only load checkpoint files you trust.
        csi_model = torch.load('CSI_benchmark/protocol3_random_1.pkl', map_location='cpu')
        csi_extractor = csi_feature_extractor(csi_model)
        csi_extractor.eval()

        self.rgb_extractor = rgb_extractor
        self.depth_extractor = depth_extractor
        self.mmwave_extractor = mmwave_extractor
        self.lidar_extractor = lidar_extractor
        self.csi_extractor = csi_extractor

    def forward(self, rgb_data, depth_data, mmwave_data, lidar_data, csi_data):
        """Run every extractor on its own input.

        Returns:
            Tuple ``(rgb_feature, depth_feature, mmwave_feature,
            lidar_feature, csi_feature)`` in that order.
        """
        rgb_feature = self.rgb_extractor(rgb_data)
        depth_feature = self.depth_extractor(depth_data)
        mmwave_feature = self.mmwave_extractor(mmwave_data)
        lidar_feature = self.lidar_extractor(lidar_data)
        csi_feature = self.csi_extractor(csi_data)
        return rgb_feature, depth_feature, mmwave_feature, lidar_feature, csi_feature

In [14]:
class linear_projector(nn.Module):
    """Concatenate the selected modality features and project them.

    Args:
        input_dim: size of the last axis of every incoming feature tensor.
        output_dim: size of the last axis after projection.
    """

    def __init__(self, input_dim, output_dim):
        super(linear_projector, self).__init__()
        self.linear_projection = nn.Linear(input_dim, output_dim)

    def forward(self, rgb_feature, depth_feature, mmwave_feature, lidar_feature, csi_feature, modality_list):
        """Project the enabled modalities.

        Args:
            rgb_feature, depth_feature, mmwave_feature, lidar_feature,
            csi_feature: per-modality tensors, concatenated along ``dim=1``.
            modality_list: five flags, in the same order as the feature
                arguments; a truthy entry enables that modality.

        Returns:
            The linear projection of the concatenated enabled features. If
            no modality is enabled, a warning is printed and the projection
            of an all-zero tensor shaped like ``lidar_feature`` is returned.
        """
        feature_list = [rgb_feature, depth_feature, mmwave_feature, lidar_feature, csi_feature]
        selected = [feat for feat, enabled in zip(feature_list, modality_list) if enabled]

        if not selected:
            print('WARNING: modality_list is empty!')
            # zeros_like keeps the device and dtype of lidar_feature, so this
            # path works on CPU and on whichever GPU the inputs live on.
            feature = torch.zeros_like(lidar_feature)
        else:
            feature = torch.cat(selected, dim=1)
        return self.linear_projection(feature)

In [13]:
# Scratch check of the "no modality enabled" warning.
# Alternative setting with every modality enabled:
# modality_list = [True, True, True, True, True]
modality_list = [False] * 5
if not any(modality_list):
    print('WARNING: modality_list is empty!')




In [15]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        emb_size: embedding width of the input and output tokens.
        num_heads: number of attention heads the embedding is split into.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size = 256, num_heads = 4, dropout = 0.0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(emb_size, emb_size*3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask = None):
        """Apply self-attention to ``x`` of shape ``(batch, tokens, emb_size)``.

        Args:
            x: input tokens.
            mask: optional boolean tensor broadcastable to
                ``(batch, heads, queries, keys)``. Positions that are
                ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Split the fused projection into queries / keys / values per head.
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # Use the most negative value of the working dtype so the fill
            # is representable under half precision as well, and keep the
            # (out-of-place) result -- otherwise the mask has no effect.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): the weights are divided by sqrt(emb_size) *after* the
        # softmax, whereas standard scaled dot-product attention divides the
        # logits by sqrt(head_dim) before it. Left as is so that weights
        # trained with this layer keep producing the same outputs.
        scaling = self.emb_size ** (1/2)
        att = torch.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # Weighted sum of the values over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Add out of place: an in-place ``+=`` would write into the tensor
        # returned by ``fn``, which silently corrupts the caller's input
        # whenever ``fn`` hands its input back unchanged.
        return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The hidden width is ``expansion * emb_size``; input and output width
    are both ``emb_size``.
    """

    def __init__(self, emb_size, expansion = 4, drop_p = 0.):
        hidden_size = expansion * emb_size
        widen = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        dropout = nn.Dropout(drop_p)
        narrow = nn.Linear(hidden_size, emb_size)
        super().__init__(widen, activation, dropout, narrow)
        
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder block: attention then feed-forward.

    Each half is ``LayerNorm -> sub-layer -> Dropout(drop_p)`` wrapped in a
    ``ResidualAdd``. Extra keyword arguments are forwarded to
    ``MultiHeadAttention``.
    """

    def __init__(self,
                 emb_size = 256,
                 drop_p = 0.5,
                 forward_expansion = 4,
                 forward_drop_p = 0.,
                 ** kwargs):
        # Built in the same order as they appear in the block so parameter
        # initialisation consumes the RNG identically.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
        
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` ``TransformerEncoderBlock`` modules.

    Keyword arguments are passed unchanged to every block.
    """

    def __init__(self, depth = 1, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)
        
class regression_Head(nn.Sequential):
    """Regression head: mean over tokens -> LayerNorm -> Linear.

    Maps ``(batch, tokens, emb_size)`` to ``(batch, num_classes)``.
    """

    def __init__(self, emb_size, num_classes):
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalise = nn.LayerNorm(emb_size)
        to_outputs = nn.Linear(emb_size, num_classes)
        super().__init__(pool_tokens, normalise, to_outputs)
        
class ViT(nn.Sequential):
    """Transformer encoder followed by a regression head.

    Args:
        emb_size: token embedding width.
        depth: number of encoder blocks.
        num_classes: output width of the head (keyword-only, default 17*3).
        **kwargs: forwarded to ``TransformerEncoder``.
    """

    def __init__(self,
                emb_size = 256,
                depth = 1,
                *,
                num_classes = 17*3,
                **kwargs):
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = regression_Head(emb_size, num_classes)
        super().__init__(encoder, head)

In [16]:
class modality_invariant_model(nn.Module):
    """End-to-end model: per-modality extractors -> projection -> ViT.

    ``forward`` returns a tensor reshaped to ``(-1, 17, 3)``.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = feature_extrator()
        # Project 512-wide modality tokens down to the ViT's 256-wide embedding.
        self.linear_projector = linear_projector(512, 256)
        self.vit = ViT()

    def forward(self, rgb_data, depth_data, mmwave_data, lidar_data, csi_data, modality_list):
        rgb_tokens, depth_tokens, mmwave_tokens, lidar_tokens, csi_tokens = self.feature_extractor(
            rgb_data, depth_data, mmwave_data, lidar_data, csi_data
        )
        fused_tokens = self.linear_projector(
            rgb_tokens, depth_tokens, mmwave_tokens, lidar_tokens, csi_tokens, modality_list
        )
        prediction = self.vit(fused_tokens)
        return prediction.view(-1, 17, 3)

In [17]:
# Smoke test: push random inputs through the full model and check the output
# shape. Use the GPU when one is present, otherwise fall back to the CPU so
# the cell does not fail on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rgb_data = torch.rand(32, 3, 480, 640, device=device)
depth_data = torch.rand(32, 3, 480, 640, device=device)
mmwave_data = torch.rand(32, 31, 5, device=device)
lidar_data = torch.rand(32, 1467, 3, device=device)
csi_data = torch.rand(32, 3, 114, 10, device=device)
# Other modality selections tried here:
# modality_list = random.choices(
#     [True, False],
#     k= 5,
#     weights=[80, 20]
# )
# modality_list = [True, True, True, True, True]
modality_list = [False,False,False,False,False]
model = modality_invariant_model().to(device)
out = model(rgb_data, depth_data, mmwave_data, lidar_data, csi_data, modality_list)
print(out.shape)

torch.Size([32, 17, 3])




In [13]:
# Print the module hierarchy of the `model` built in the smoke-test cell.
print(model)

modality_invariant_model(
  (feature_extractor): feature_extrator(
    (rgb_extractor): rgb_feature_extractor(
      (part): Sequential(
        (0): Sequential(
          (0): Conv2d(3, 3, kernel_size=(14, 14), stride=(2, 2))
          (1): ReLU()
          (2): Conv2d(3, 3, kernel_size=(5, 56), stride=(1, 1))
          (3): ReLU()
          (4): Conv2d(3, 3, kernel_size=(5, 23), stride=(1, 1))
          (5): ReLU()
          (6): Conv2d(3, 16, kernel_size=(3, 14), stride=(1, 1))
        )
        (1): Conv2d(16, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (5): Sequential(
          (0): Block(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1