In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import random
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
import time
import re
from XRF55_Dataset import XRF55_Datase 

In [2]:
train_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=True)
test_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=False)

In [9]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    # kpts = []
    # [kpts.append(np.array(t['output'])) for t in batch]
    # kpts = torch.FloatTensor(np.array(kpts))
    labels = []
    [labels.append(float(t[3])) for t in batch]
    labels = torch.FloatTensor(labels)

    # mmwave
    mmwave_data = np.array([(t[2]) for t in batch ])
    mmwave_data = torch.FloatTensor(mmwave_data)

    # wifi-csi
    wifi_data = np.array([(t[0]) for t in batch ])
    wifi_data = torch.FloatTensor(wifi_data)

    # rfid
    rfid_data = np.array([(t[1]) for t in batch ])
    rfid_data = torch.FloatTensor(rfid_data)
    
    modality_list = [True, True, True]

    return mmwave_data, wifi_data, rfid_data, labels, modality_list

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_padd)

In [11]:
from baseline.baseline_model import single_model

# model = model(['RGB'])
# "modality_list: list of string, choose from ['RGB', 'Depth', 'mmWave', 'Lidar', 'Wifi']"
mmwave_model = single_model(['mmwave'])
# rgb_weights = torch.load('./baseline/baseline_weights/Single/RGB.pt')
# print(rgb_weights.keys())
# print(rgb_model)
# print(rgb_model.state_dict().keys())
mmwave_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/mmwave.pt'))
mmwave_model.cuda()

wifi_model = single_model(['wifi'])
wifi_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/wifi.pt'))
wifi_model.cuda()

rfid_model = single_model(['rfid'])
rfid_model.load_state_dict(torch.load('./baseline/baseline_weights/Single/rfid.pt'))
rfid_model.cuda()


single_model(
  (feature_extractor): single_feature_extrator(
    (rfid_extractor): rfid_feature_extractor(
      (part): Sequential(
        (0): Conv1d(23, 128, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): BasicBlock(
            (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
            (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
            (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv1d(128, 128, kernel_s

In [12]:
def get_result(mmwave_model,wifi_model,rfid_model, tensor_loader, device):
    mmwave_model.eval()
    wifi_model.eval()
    rfid_model.eval()
    # test_mpjpe = 0
    # test_pampjpe = 0
    # test_mse = 0
    # random.seed(config['modality_existances']['val_random_seed'])
    for i, data in tqdm(enumerate(tensor_loader)):
        # start_time = time.time()
        mmwave_data, wifi_data, rfid_data, label, exist_list = data
        mmwave_data = mmwave_data.to(device)
        wifi_data = wifi_data.to(device)
        rfid_data = rfid_data.to(device)
        label.to(device)
        labels = label.type(torch.FloatTensor)
        # outputs = model(input_1, exist_list)
        # outputs = outputs.type(torch.FloatTensor)
        mmwave_model.to('cuda')
        mmwave_outputs = mmwave_model(mmwave_data, [True, False, False])
        mmwave_model.to('cpu')
        del mmwave_data
        wifi_model.to('cuda')
        wifi_outputs = wifi_model(wifi_data, [False, True, False])
        wifi_model.to('cpu')
        del wifi_data
        rfid_model.to('cuda')
        rfid_outputs = rfid_model(rfid_data, [False, False, True])
        rfid_model.to('cpu')
        del rfid_data
        
        # rgb_outputs = rgb_outputs.to(device)
        mmwave_outputs = mmwave_outputs.detach().cpu().numpy()
        wifi_outputs = wifi_outputs.detach().cpu().numpy()
        rfid_outputs = rfid_outputs.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
        
        if i == 0:
            mmwave_result = mmwave_outputs
            wifi_result = wifi_outputs
            rfid_result = rfid_outputs
            all_label = labels
            print(mmwave_result.shape, wifi_result.shape, rfid_result.shape, all_label.shape)
        else:
            mmwave_result = np.vstack((mmwave_result, mmwave_outputs))
            wifi_result = np.vstack((wifi_result, wifi_outputs))
            rfid_result = np.vstack((rfid_result, rfid_outputs))
            all_label = np.hstack((all_label, labels))
            if i ==1:
                print(mmwave_result.shape, wifi_result.shape, rfid_result.shape, all_label.shape)

    np.save('./baseline_results/mmwave_result.npy', mmwave_result)
    np.save('./baseline_results/wifi_result.npy', wifi_result)
    np.save('./baseline_results/rfid_result.npy', rfid_result)
    np.save('./baseline_results/all_label.npy', all_label)

    return

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
get_result(mmwave_model,wifi_model,rfid_model, test_dataloader, device)

  return self._call_impl(*args, **kwargs)
1it [00:02,  2.75s/it]

(32, 55) (32, 55) (32, 55) (32,)


2it [00:07,  4.12s/it]

(64, 55) (64, 55) (64, 55) (64,)


207it [12:01,  3.49s/it]


In [14]:
mmwave_result=np.load('./baseline_results/mmwave_result.npy')
wifi_result=np.load('./baseline_results/wifi_result.npy')
rfid_result=np.load('./baseline_results/rfid_result.npy')
all_label=np.load('./baseline_results/all_label.npy')

In [15]:
predict_y = np.argmax(mmwave_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('mmwave accuracy:', epoch_accuracy)

mmwave accuracy: 0.8490909090909091


In [16]:
predict_y = np.argmax(wifi_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('wifi accuracy:', epoch_accuracy)

wifi accuracy: 0.7781818181818182


In [17]:
predict_y = np.argmax(rfid_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('rfid accuracy:', epoch_accuracy)

rfid accuracy: 0.4216666666666667


In [19]:
RW_result = (mmwave_result+ wifi_result)/2
predict_y = np.argmax(RW_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('RW accuracy:', epoch_accuracy)

RW accuracy: 0.943030303030303


In [21]:
RRF_result = (mmwave_result+ rfid_result)/2
predict_y = np.argmax(RRF_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('RRF accuracy:', epoch_accuracy)

RRF accuracy: 0.8236363636363636


In [22]:
WRF_result = (wifi_result+ rfid_result)/2
predict_y = np.argmax(WRF_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('WRF accuracy:', epoch_accuracy)

WRF accuracy: 0.7737878787878788


In [20]:
RWR_result = (mmwave_result+ wifi_result+rfid_result)/3
predict_y = np.argmax(RWR_result,axis=1)
epoch_accuracy = (predict_y == all_label).sum() / all_label.size
print('RWR accuracy:', epoch_accuracy)

RWR accuracy: 0.9224242424242424


In [15]:
ID_result = (rgb_result + depth_result)/2
mpjpe, pampjpe = error(ID_result,all_label)
print('ID mpjpe:',mpjpe)
print('ID pampjpe:',pampjpe)

ID mpjpe: 0.0967305
ID pampjpe: 0.053927209294340514
