In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import random
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
import time
import re
from XRF55_Dataset import *

### Loda Data

In [2]:
train_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=True)
test_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=False)

In [3]:
def generate_none_empth_modality_list():
    modality_list = random.choices(
        [True, False],
        k= 2,
        weights=[50, 50]
    )
    rfid_ = random.choices(
        [True, False],
        k= 1,
        weights=[80, 20]
    )
    modality_list.append(rfid_[0])
    # print(modality_list)
    if sum(modality_list) == 0:
        modality_list = generate_none_empth_modality_list()
        return modality_list
    else:
        return modality_list

def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_rfid', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    # kpts = []
    # [kpts.append(np.array(t['output'])) for t in batch]
    # kpts = torch.FloatTensor(np.array(kpts))
    labels = []
    [labels.append(float(t[3])) for t in batch]
    labels = torch.FloatTensor(labels)

    # mmwave
    mmwave_data = np.array([(t[2]) for t in batch ])
    mmwave_data = torch.FloatTensor(mmwave_data)

    # wifi-csi
    wifi_data = np.array([(t[0]) for t in batch ])
    wifi_data = torch.FloatTensor(wifi_data)

    # rfid
    rfid_data = np.array([(t[1]) for t in batch ])
    rfid_data = torch.FloatTensor(rfid_data)
    
    modality_list = generate_none_empth_modality_list()

    return mmwave_data, wifi_data, rfid_data, labels, modality_list

In [4]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_padd)

In [5]:
for data in train_dataloader:
    mmwave_data, wifi_data, rfid_data, labels, modality_list = data
    print(mmwave_data.shape, wifi_data.shape, rfid_data.shape, labels.shape, modality_list)
    break

torch.Size([16, 17, 256, 128]) torch.Size([16, 270, 1000]) torch.Size([16, 23, 148]) torch.Size([16]) [True, True, True]


### Model

In [7]:
import torch
from MI_model_5_3input import modality_invariant_model
# from MI_model_6 import modality_invariant_model
# from MI_model import modality_invariant_model

model = modality_invariant_model(model_depth=2, num_classes=55)
model.cuda()
# model.cuda()

# rgb_data = torch.randn(16, 3, 480, 640).cuda()
# depth_data = torch.randn(16, 3, 480, 640).cuda()
# lidar_data = torch.randn(16, 1432, 3).cuda()
# mmwave_data = torch.randn(16, 67, 5).cuda()
# wifi_data = torch.randn(16, 3, 114, 10).cuda()
# # modality_list = [False, False, False, False, False]
# # modality_list = [True, True, True, True, True]


# out = model(rgb_data, depth_data,  mmwave_data, lidar_data, wifi_data, modality_list)

# print(out)

modality_invariant_model(
  (feature_extractor): feature_extrator(
    (mmwave_extractor): mmwave_feature_extractor(
      (part): Sequential(
        (0): Conv2d(17, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU()
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, ker

### training 

In [8]:
def multi_test(model, tensor_loader, criterion, device):
    model.eval()
    mmwave_test_loss = 0
    mmwave_test_accuracy = 0

    wifi_test_loss = 0
    wifi_test_accuracy = 0
    
    rfid_test_loss = 0
    rfid_test_accuracy = 0

    mmwave_wifi_test_loss = 0
    mmwave_wifi_test_accuracy = 0

    mmwave_rfid_test_loss = 0
    mmwave_rfid_test_accuracy = 0

    wifi_rfid_test_loss = 0
    wifi_rfid_test_accuracy = 0

    mmwave_wifi_rfid_test_loss = 0
    mmwave_wifi_rfid_test_accuracy = 0

    # random.seed(config['modality_existances']['val_random_seed'])
    for data in tqdm(tensor_loader):
        start_time = time.time()
        mmwave_data, wifi_data, rfid_data, labels, modality_list = data
        # t1 = time.time()
        # load_time = t1 - start_time 
        mmwave_data = mmwave_data.to(device)
        wifi_data = wifi_data.to(device)
        rfid_data = rfid_data.to(device)
        labels.to(device)
        labels = labels.type(torch.LongTensor)


        ' SINGLE MODALITY '
        ### mmwave
        mmwave_modality_list = [True, False, False]
        mmwave_outputs = model(mmwave_data, wifi_data,  rfid_data, mmwave_modality_list)
        mmwave_outputs = mmwave_outputs.type(torch.FloatTensor)
        mmwave_outputs.to(device)
        mmwave_test_loss += criterion(mmwave_outputs,labels).item() * mmwave_data.size(0)
        mmwave_predict_y = torch.argmax(mmwave_outputs,dim=1).to(device)
        mmwave_test_accuracy += (mmwave_predict_y == labels.to(device)).sum().item() / labels.size(0)
        mmwave_outputs = mmwave_outputs.detach().cpu()
        mmwave_predict_y = mmwave_predict_y.detach().cpu()
        ### wifi-cis
        wifi_modality_list = [False, True, False]
        wifi_outputs = model(mmwave_data, wifi_data,  rfid_data, wifi_modality_list)
        wifi_outputs = wifi_outputs.type(torch.FloatTensor)
        wifi_outputs.to(device)
        wifi_test_loss += criterion(wifi_outputs,labels).item() * wifi_data.size(0)
        wifi_predict_y = torch.argmax(wifi_outputs,dim=1).to(device)
        wifi_test_accuracy += (wifi_predict_y == labels.to(device)).sum().item() / labels.size(0)
        wifi_outputs = wifi_outputs.detach().cpu()
        wifi_predict_y = wifi_predict_y.detach().cpu()
        ### rfid
        rfid_modality_list = [False, False,True]
        rfid_outputs = model(mmwave_data, wifi_data,  rfid_data, rfid_modality_list)
        rfid_outputs = rfid_outputs.type(torch.FloatTensor)
        rfid_outputs.to(device)
        rfid_test_loss += criterion(rfid_outputs,labels).item() * rfid_data.size(0)
        rfid_predict_y = torch.argmax(rfid_outputs,dim=1).to(device)
        rfid_test_accuracy += (rfid_predict_y == labels.to(device)).sum().item() / labels.size(0)
        rfid_outputs = rfid_outputs.detach().cpu()
        rfid_predict_y = rfid_predict_y.detach().cpu()
        
        'Dual modality'
        ### mmwave + wifi
        mmwave_wifi_modality_list = [True, True, False]
        mmwave_wifi_outputs = model(mmwave_data, wifi_data,  rfid_data, mmwave_wifi_modality_list)
        mmwave_wifi_outputs = mmwave_wifi_outputs.type(torch.FloatTensor)
        mmwave_wifi_outputs.to(device)
        mmwave_wifi_test_loss += criterion(mmwave_wifi_outputs,labels).item() * mmwave_data.size(0)
        mmwave_wifi_predict_y = torch.argmax(mmwave_wifi_outputs,dim=1).to(device)
        mmwave_wifi_test_accuracy += (mmwave_wifi_predict_y == labels.to(device)).sum().item() / labels.size(0)
        mmwave_wifi_outputs = mmwave_wifi_outputs.detach().cpu()
        mmwave_wifi_predict_y = mmwave_wifi_predict_y.detach().cpu()

        ### mmwave + rfid
        mmwave_rfid_modality_list = [True, False, True]
        mmwave_rfid_outputs = model(mmwave_data, wifi_data,  rfid_data, mmwave_rfid_modality_list)
        mmwave_rfid_outputs = mmwave_rfid_outputs.type(torch.FloatTensor)
        mmwave_rfid_outputs.to(device)
        mmwave_rfid_test_loss += criterion(mmwave_rfid_outputs,labels).item() * mmwave_data.size(0)
        mmwave_rfid_predict_y = torch.argmax(mmwave_rfid_outputs,dim=1).to(device)
        mmwave_rfid_test_accuracy += (mmwave_rfid_predict_y == labels.to(device)).sum().item() / labels.size(0)
        mmwave_rfid_outputs = mmwave_rfid_outputs.detach().cpu()
        mmwave_rfid_predict_y = mmwave_rfid_predict_y.detach().cpu()

        ### wifi + rfid
        wifi_rfid_modality_list = [False, True, True]
        wifi_rfid_outputs = model(mmwave_data, wifi_data,  rfid_data, wifi_rfid_modality_list)
        wifi_rfid_outputs = wifi_rfid_outputs.type(torch.FloatTensor)
        wifi_rfid_outputs.to(device)
        wifi_rfid_test_loss += criterion(wifi_rfid_outputs,labels).item() * wifi_data.size(0)
        wifi_rfid_predict_y = torch.argmax(wifi_rfid_outputs,dim=1).to(device)
        wifi_rfid_test_accuracy += (wifi_rfid_predict_y == labels.to(device)).sum().item() / labels.size(0)
        wifi_rfid_outputs = wifi_rfid_outputs.detach().cpu()
        wifi_rfid_predict_y = wifi_rfid_predict_y.detach().cpu()

        'Three modality'
        ### mmwave + wifi + rfid
        mmwave_wifi_rfid_modality_list = [True, True, True]
        mmwave_wifi_rfid_outputs = model(mmwave_data, wifi_data,  rfid_data, mmwave_wifi_rfid_modality_list)
        mmwave_wifi_rfid_outputs = mmwave_wifi_rfid_outputs.type(torch.FloatTensor)
        mmwave_wifi_rfid_outputs.to(device)
        mmwave_wifi_rfid_test_loss += criterion(mmwave_wifi_rfid_outputs,labels).item() * mmwave_data.size(0)
        mmwave_wifi_rfid_predict_y = torch.argmax(mmwave_wifi_rfid_outputs,dim=1).to(device)
        mmwave_wifi_rfid_test_accuracy += (mmwave_wifi_rfid_predict_y == labels.to(device)).sum().item() / labels.size(0)
        mmwave_wifi_rfid_outputs = mmwave_wifi_rfid_outputs.detach().cpu()
        mmwave_wifi_rfid_predict_y = mmwave_wifi_rfid_predict_y.detach().cpu()


    'single modality'
    ### mmwave
    mmwave_test_loss = mmwave_test_loss/len(tensor_loader.dataset)
    mmwave_test_accuracy = mmwave_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}".format('mmWave',float(mmwave_test_loss), float(mmwave_test_accuracy)))
    ### wifi
    wifi_test_loss = wifi_test_loss/len(tensor_loader.dataset)
    wifi_test_accuracy = wifi_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}\n".format('WiFi-CSI',float(wifi_test_loss), float(wifi_test_accuracy)))
    ### rfid
    rfid_test_loss = rfid_test_loss/len(tensor_loader.dataset)
    rfid_test_accuracy = rfid_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}\n".format('RFID',float(rfid_test_loss), float(rfid_test_accuracy)))

    'dual modality'
    ### mmwave + wifi
    mmwave_wifi_test_loss = mmwave_wifi_test_loss/len(tensor_loader.dataset)
    mmwave_wifi_test_accuracy = mmwave_wifi_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}".format('mmWave+WiFi-CSI',float(mmwave_wifi_test_loss), float(mmwave_wifi_test_accuracy)))
    
    ### mmwave + rfid
    mmwave_rfid_test_loss = mmwave_rfid_test_loss/len(tensor_loader.dataset)
    mmwave_rfid_test_accuracy = mmwave_rfid_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}".format('mmWave+RFID',float(mmwave_rfid_test_loss), float(mmwave_rfid_test_accuracy)))

    ### wifi + rfid
    wifi_rfid_test_loss = wifi_rfid_test_loss/len(tensor_loader.dataset)
    wifi_rfid_test_accuracy = wifi_rfid_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}".format('WiFi-CSI+RFID',float(wifi_rfid_test_loss), float(wifi_rfid_test_accuracy)))

    'three modality'
    ### mmwave + wifi + rfid
    mmwave_wifi_rfid_test_loss = mmwave_wifi_rfid_test_loss/len(tensor_loader.dataset)
    mmwave_wifi_rfid_test_accuracy = mmwave_wifi_rfid_test_accuracy/len(tensor_loader)
    print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}".format('mmWave+WiFi-CSI+RFID',float(mmwave_wifi_rfid_test_loss), float(mmwave_wifi_rfid_test_accuracy)))
    return

In [9]:
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(torch.load('./HAR_17thNov_iterD2.pt'))

multi_test(model, test_dataloader, criterion, device)

  return self._call_impl(*args, **kwargs)
100%|██████████| 207/207 [10:26<00:00,  3.03s/it]

modality: mmWave, Cross Entropy Loss: 0.83845632, Accuracy: 0.80646135
modality: WiFi-CSI, Cross Entropy Loss: 2.69654207, Accuracy: 0.24184783

modality: RFID, Cross Entropy Loss: 3.20865457, Accuracy: 0.38753019

modality: mmWave+WiFi-CSI, Cross Entropy Loss: 0.62377319, Accuracy: 0.84073068
modality: mmWave+RFID, Cross Entropy Loss: 0.62020258, Accuracy: 0.84148551
modality: WiFi-CSI+RFID, Cross Entropy Loss: 2.32784379, Accuracy: 0.48007246
modality: mmWave+WiFi-CSI+RFID, Cross Entropy Loss: 0.50496502, Accuracy: 0.86337560





In [10]:
print("modality: {}, Cross Entropy Loss: {:.8f}, Accuracy: {:.8f}\n".format('RFID',float(rfid_test_loss), float(rfid_test_accuracy)))

NameError: name 'rfid_test_loss' is not defined