In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import random
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
import time
import re
from syn_DI_dataset import make_dataset, make_dataloader

In [2]:
dataset_root = 'd:\Data\My_MMFi_Data\MMFi_Dataset'
with open('config.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

['rgb', 'depth', 'lidar', 'mmwave', 'wifi-csi']
S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 

In [3]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    kpts = []
    [kpts.append(np.array(t['output'])) for t in batch]
    kpts = torch.FloatTensor(np.array(kpts))

    # lengths = torch.tensor([t['input_mmwave'].shape[0] for t in batch ])

    # rgb
    rgb_data = np.array([(t['input_rgb']) for t in batch ])
    rgb_data = torch.FloatTensor(rgb_data).permute(0,3,1,2)

    # depth
    depth_data = np.array([(t['input_depth']) for t in batch ])
    depth_data = torch.FloatTensor(depth_data).permute(0,3,1,2)

    # mmwave
    ## padd
    mmwave_data = [torch.Tensor(t['input_mmwave']) for t in batch ]
    mmwave_data = torch.nn.utils.rnn.pad_sequence(mmwave_data)
    ## compute mask
    mmwave_data = mmwave_data.permute(1,0,2)

    # lidar
    ## padd
    lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
    lidar_data = torch.nn.utils.rnn.pad_sequence(lidar_data)
    ## compute mask
    lidar_data = lidar_data.permute(1,0,2)

    # wifi-csi
    wifi_data = np.array([(t['input_wifi-csi']) for t in batch ])
    wifi_data = torch.FloatTensor(wifi_data)
    "要改"
    exist_list = [True, False, False, False, False]

    return rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, exist_list
    # return rgb_data, kpts, exist_list

In [4]:
rng_generator = torch.manual_seed(config['init_rand_seed'])
train_loader = make_dataloader(train_dataset, is_training=True, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)
val_loader = make_dataloader(val_dataset, is_training=False, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)

In [5]:
from baseline.baseline_model import single_model

# model = model(['RGB'])
# "modality_list: list of string, choose from ['RGB', 'Depth', 'mmWave', 'Lidar', 'Wifi']"
rgb_model = single_model(['rgb'])
# rgb_weights = torch.load('./baseline/baseline_weights/Single/RGB.pt')
# print(rgb_weights.keys())
# print(rgb_model)
# print(rgb_model.state_dict().keys())
rgb_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/RGB.pt'))
rgb_model.cuda()

depth_model = single_model(['depth'])
depth_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/Depth.pt'))
depth_model.cuda()

mmwave_model = single_model(['mmwave'])
mmwave_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/mmWave.pt'))
mmwave_model.cuda()

lidar_model = single_model(['lidar'])
lidar_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/Lidar.pt'))
lidar_model.cuda()

wifi_model = single_model(['wifi-csi'])
wifi_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/Wifi.pt'))
wifi_model.cuda()

single_model(
  (feature_extractor): single_feature_extrator(
    (csi_extractor): csi_feature_extractor(
      (part): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm

In [6]:
def get_result(rgb_model,depth_model,mmwave_model,lidar_model,wifi_model, tensor_loader, device):
    rgb_model.eval()
    depth_model.eval()
    mmwave_model.eval()
    lidar_model.eval()
    wifi_model.eval()
    # test_mpjpe = 0
    # test_pampjpe = 0
    # test_mse = 0
    # random.seed(config['modality_existances']['val_random_seed'])
    for i, data in tqdm(enumerate(tensor_loader)):
        # start_time = time.time()
        rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, exist_list = data
        rgb_data = rgb_data.to(device)
        depth_data = depth_data.to(device)
        lidar_data = lidar_data.to(device)
        mmwave_data = mmwave_data.to(device)
        wifi_data = wifi_data.to(device)
        kpts.to(device)
        labels = kpts.type(torch.FloatTensor)
        # outputs = model(input_1, exist_list)
        # outputs = outputs.type(torch.FloatTensor)
        rgb_model.to('cuda')
        rgb_outputs = rgb_model(rgb_data, [True, False, False, False, False])
        rgb_model.to('cpu')
        del rgb_data
        depth_model.to('cuda')
        depth_outputs = depth_model(depth_data, [False, True, False, False, False])
        depth_model.to('cpu')
        del depth_data
        mmwave_model.to('cuda')
        mmwave_outputs = mmwave_model(mmwave_data, [False, False, True, False, False])
        mmwave_model.to('cpu')
        del mmwave_data
        lidar_model.to('cuda')
        lidar_outputs = lidar_model(lidar_data, [False, False, False, True, False])
        lidar_model.to('cpu')
        del lidar_data
        wifi_model.to('cuda')
        wifi_outputs = wifi_model(wifi_data, [False, False, False, False, True])
        wifi_model.to('cpu')
        del wifi_data
        
        # rgb_outputs = rgb_outputs.to(device)
        rgb_outputs = rgb_outputs.detach().cpu().numpy()
        depth_outputs = depth_outputs.detach().cpu().numpy()
        mmwave_outputs = mmwave_outputs.detach().cpu().numpy()
        lidar_outputs = lidar_outputs.detach().cpu().numpy()
        wifi_outputs = wifi_outputs.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
        if i == 0:
            rgb_result = rgb_outputs
            depth_result = depth_outputs
            mmwave_result = mmwave_outputs
            lidar_result = lidar_outputs
            wifi_result = wifi_outputs
            all_label = labels
        else:
            rgb_result = np.vstack((rgb_result, rgb_outputs))
            depth_result = np.vstack((depth_result, depth_outputs))
            mmwave_result = np.vstack((mmwave_result, mmwave_outputs))
            lidar_result = np.vstack((lidar_result, lidar_outputs))
            wifi_result = np.vstack((wifi_result, wifi_outputs))
            all_label = np.vstack((all_label, labels))
    np.save('./baseline_results/rgb_result.npy', rgb_result)
    np.save('./baseline_results/depth_result.npy', depth_result)
    np.save('./baseline_results/mmwave_result.npy', mmwave_result)
    np.save('./baseline_results/lidar_result.npy', lidar_result)
    np.save('./baseline_results/wifi_result.npy', wifi_result)
    np.save('./baseline_results/all_label.npy', all_label)

    return

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
get_result(rgb_model,depth_model,mmwave_model,lidar_model,wifi_model, val_loader, device)

  lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
3403it [1:09:01,  1.22s/it]


In [8]:
rgb_result=np.load('./baseline_results/rgb_result.npy')
depth_result=np.load('./baseline_results/depth_result.npy')
mmwave_result=np.load('./baseline_results/mmwave_result.npy')
lidar_result=np.load('./baseline_results/lidar_result.npy')
wifi_result=np.load('./baseline_results/wifi_result.npy')
all_label=np.load('./baseline_results/all_label.npy')

In [9]:
mpjpe, pampjpe = error(rgb_result,all_label)
print('RGB mpjpe:',mpjpe)
print('RGB pampjpe:',pampjpe)

RGB mpjpe: 0.1209835
RGB pampjpe: 0.06802574693815214


In [10]:
mpjpe, pampjpe = error(depth_result,all_label)
print('Depth mpjpe:',mpjpe)
print('Depth pampjpe:',pampjpe)

Depth mpjpe: 0.102404244
Depth pampjpe: 0.05273534376296619


In [11]:
mpjpe, pampjpe = error(lidar_result,all_label)
print('Lidar mpjpe:',mpjpe)
print('Lidar pampjpe:',pampjpe)

Lidar mpjpe: 0.16146122
Lidar pampjpe: 0.10347314387242829


In [12]:
mpjpe, pampjpe = error(mmwave_result,all_label)
print('mmWave mpjpe:',mpjpe)
print('mmWave pampjpe:',pampjpe)

mmWave mpjpe: 0.14125669
mmWave pampjpe: 0.07241973172052109


In [13]:
mpjpe, pampjpe = error(wifi_result,all_label)
print('Wifi mpjpe:',mpjpe)
print('Wifi pampjpe:',pampjpe)

Wifi mpjpe: 0.22709544
Wifi pampjpe: 0.10794727373808227


In [15]:
ID_result = (rgb_result + depth_result)/2
mpjpe, pampjpe = error(ID_result,all_label)
print('ID mpjpe:',mpjpe)
print('ID pampjpe:',pampjpe)

ID mpjpe: 0.0967305
ID pampjpe: 0.053927209294340514


In [16]:
IL_result = (rgb_result + lidar_result)/2
mpjpe, pampjpe = error(IL_result,all_label)
print('IL mpjpe:',mpjpe)
print('IL pampjpe:',pampjpe)

IL mpjpe: 0.12204741
IL pampjpe: 0.07560428465871315


In [17]:
IR_result = (rgb_result + mmwave_result)/2
mpjpe, pampjpe = error(IR_result,all_label)
print('IR mpjpe:',mpjpe)
print('IR pampjpe:',pampjpe)

IR mpjpe: 0.10099864
IR pampjpe: 0.057282707213582695


In [18]:
IW_result = (rgb_result + wifi_result)/2
mpjpe, pampjpe = error(IW_result,all_label)
print('IW mpjpe:',mpjpe)
print('IW pampjpe:',pampjpe)

IW mpjpe: 0.1464063
IW pampjpe: 0.07734011737127457


In [19]:
DL_result = (depth_result + lidar_result)/2
mpjpe, pampjpe = error(DL_result,all_label)
print('DL mpjpe:',mpjpe)
print('DL pampjpe:',pampjpe)

DL mpjpe: 0.11172355
DL pampjpe: 0.06875045324114044


In [20]:
DR_result = (depth_result + mmwave_result)/2
mpjpe, pampjpe = error(DR_result,all_label)
print('DR mpjpe:',mpjpe)
print('DR pampjpe:',pampjpe)

DR mpjpe: 0.09408594
DR pampjpe: 0.05178810151339959


In [21]:
DW_result = (depth_result + wifi_result)/2
mpjpe, pampjpe = error(DW_result,all_label)
print('DW mpjpe:',mpjpe)
print('DW pampjpe:',pampjpe)

DW mpjpe: 0.14171347
DW pampjpe: 0.07138707609416639


In [22]:
RL_result = (mmwave_result + lidar_result)/2
mpjpe, pampjpe = error(RL_result,all_label)
print('RL mpjpe:',mpjpe)
print('RL pampjpe:',pampjpe)

RL mpjpe: 0.11641539
RL pampjpe: 0.07146439485644991


In [23]:
RW_result = (mmwave_result + wifi_result)/2
mpjpe, pampjpe = error(RW_result,all_label)
print('RW mpjpe:',mpjpe)
print('RW pampjpe:',pampjpe)

RW mpjpe: 0.1441773
RW pampjpe: 0.07215036383181073


In [24]:
LW_result = (lidar_result + wifi_result)/2
mpjpe, pampjpe = error(LW_result,all_label)
print('LW mpjpe:',mpjpe)
print('LW pampjpe:',pampjpe)

LW mpjpe: 0.16706093
LW pampjpe: 0.10072880605367486


In [25]:
IDL_result = (rgb_result + depth_result + lidar_result)/3
mpjpe, pampjpe = error(IDL_result,all_label)
print('IDL mpjpe:',mpjpe)
print('IDL pampjpe:',pampjpe)

IDL mpjpe: 0.102884695
IDL pampjpe: 0.062288294354467075


In [26]:
IDR_result = (rgb_result + depth_result + mmwave_result)/3
mpjpe, pampjpe = error(IDR_result,all_label)
print('IDR mpjpe:',mpjpe)
print('IDR pampjpe:',pampjpe)

IDR mpjpe: 0.086676255
IDR pampjpe: 0.05006913068905228


In [28]:
IDW_result = (rgb_result + depth_result + wifi_result)/3
mpjpe, pampjpe = error(IDW_result,all_label)
print('IDW mpjpe:',mpjpe)
print('IDW pampjpe:',pampjpe)

IDW mpjpe: 0.11859307
IDW pampjpe: 0.06366522570369035


In [29]:
ILR_result = (rgb_result + lidar_result + mmwave_result)/3
mpjpe, pampjpe = error(ILR_result,all_label)
print('ILR mpjpe:',mpjpe)
print('ILR pampjpe:',pampjpe)

ILR mpjpe: 0.10136549
ILR pampjpe: 0.06270019567711774


In [30]:
ILW_result = (rgb_result + lidar_result + wifi_result)/3
mpjpe, pampjpe = error(ILW_result,all_label)
print('ILW mpjpe:',mpjpe)
print('ILW pampjpe:',pampjpe)

ILW mpjpe: 0.13493872
ILW pampjpe: 0.08199176980513029


In [31]:
IRW_result = (rgb_result + mmwave_result + wifi_result)/3
mpjpe, pampjpe = error(IRW_result,all_label)
print('IRW mpjpe:',mpjpe)
print('IRW pampjpe:',pampjpe)

IRW mpjpe: 0.11651134
IRW pampjpe: 0.06309395703861613


In [32]:
DLR_result = (depth_result + lidar_result + mmwave_result)/3
mpjpe, pampjpe = error(DLR_result,all_label)
print('DLR mpjpe:',mpjpe)
print('DLR pampjpe:',pampjpe)

DLR mpjpe: 0.095291585
DLR pampjpe: 0.05856401709935985


In [34]:
DLW_result = (depth_result + lidar_result + wifi_result)/3
mpjpe, pampjpe = error(DLW_result,all_label)
print('DLW mpjpe:',mpjpe)
print('DLW pampjpe:',pampjpe)

DLW mpjpe: 0.13074219
DLW pampjpe: 0.07806211223969958


In [35]:
DRW_result = (depth_result + mmwave_result + wifi_result)/3
mpjpe, pampjpe = error(DRW_result,all_label)
print('DRW mpjpe:',mpjpe)
print('DRW pampjpe:',pampjpe)

DRW mpjpe: 0.11318378
DRW pampjpe: 0.059399814542748605


In [36]:
LRW_result = (lidar_result + mmwave_result + wifi_result)/3
mpjpe, pampjpe = error(LRW_result,all_label)
print('LRW mpjpe:',mpjpe)
print('LRW pampjpe:',pampjpe)

LRW mpjpe: 0.12963524
LRW pampjpe: 0.0780878697796786


In [37]:
IDLR_result = (rgb_result + depth_result + lidar_result + mmwave_result)/4
mpjpe, pampjpe = error(IDLR_result,all_label)
print('IDLR mpjpe:',mpjpe)
print('IDLR pampjpe:',pampjpe)


IDLR mpjpe: 0.09071056
IDLR pampjpe: 0.05577281046866885


In [38]:
IDLW_result = (rgb_result + depth_result + lidar_result + wifi_result)/4
mpjpe, pampjpe = error(IDLW_result,all_label)
print('IDLW mpjpe:',mpjpe)
print('IDLW pampjpe:',pampjpe)

IDLW mpjpe: 0.11672417
IDLW pampjpe: 0.07002651349511252


In [39]:
IDRW_result = (rgb_result + depth_result + mmwave_result + wifi_result)/4
mpjpe, pampjpe = error(IDRW_result,all_label)
print('IDRW mpjpe:',mpjpe)
print('IDRW pampjpe:',pampjpe)

IDRW mpjpe: 0.10192484
IDRW pampjpe: 0.05629642005526464


In [40]:
ILRW_result = (rgb_result + lidar_result + mmwave_result + wifi_result)/4
mpjpe, pampjpe = error(ILRW_result,all_label)
print('ILRW mpjpe:',mpjpe)
print('ILRW pampjpe:',pampjpe)

ILRW mpjpe: 0.11402801
ILRW pampjpe: 0.06952264365507982


In [41]:
DLRW_result = (depth_result + lidar_result + mmwave_result + wifi_result)/4
mpjpe, pampjpe = error(DLRW_result,all_label)
print('DLRW mpjpe:',mpjpe)
print('DLRW pampjpe:',pampjpe)

DLRW mpjpe: 0.11081125
DLRW pampjpe: 0.06666690455020988


In [42]:
IDLRW_result = (rgb_result + depth_result + lidar_result + mmwave_result + wifi_result)/5
mpjpe, pampjpe = error(IDLRW_result,all_label)
print('IDLRW mpjpe:',mpjpe)
print('IDLRW pampjpe:',pampjpe)

IDLRW mpjpe: 0.10297314
IDLRW pampjpe: 0.062427426992494484
