In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import random
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
import time
import re
from XRF55_Dataset import XRF55_Datase 
from baseline_model import single_model, dual_model, triple_model

In [2]:
# Raw string: "\D" and "\X" are not valid escape sequences, so a normal string
# literal triggers DeprecationWarning/SyntaxWarning (the value is unchanged).
DATA_ROOT = r"D:\Data\XRF55\XRF_dataset"
train_dataset = XRF55_Datase(root_dir=DATA_ROOT, scene='all', is_train=True)
test_dataset = XRF55_Datase(root_dir=DATA_ROOT, scene='all', is_train=False)

In [3]:
def collate_fn_padd(batch):
    '''
    Stack a batch of XRF55 samples into float32 tensors (DataLoader collate_fn).

    Each sample is indexed positionally: ``t[0]`` is WiFi-CSI, ``t[1]`` is
    RFID, ``t[2]`` is mmWave and ``t[3]`` is the action label. No padding is
    performed despite the function name: all samples of a given modality must
    already share one shape, otherwise ``np.array`` cannot build a regular
    numeric array.

    Returns:
        (mmwave_data, wifi_data, rfid_data, labels): the three modality tensors
        stacked along a new leading batch dimension, plus a 1-D float32 label
        tensor (the training/eval loops cast it to int64 before the loss).
    '''
    def _stack(index):
        # Build one numpy array first so torch gets a single contiguous buffer
        # instead of a Python list of per-sample arrays.
        return torch.as_tensor(np.array([sample[index] for sample in batch]), dtype=torch.float32)

    labels = torch.tensor([float(sample[3]) for sample in batch], dtype=torch.float32)
    wifi_data = _stack(0)
    rfid_data = _stack(1)
    mmwave_data = _stack(2)
    return mmwave_data, wifi_data, rfid_data, labels

In [4]:
# Project root (contains baseline/config_all.yaml and baseline/baseline_weights).
# Set XRF55_PROJECT_DIR to run on another machine; the default is the original path.
parentdir = os.environ.get('XRF55_PROJECT_DIR', 'C:/Users/Chen_Xinyan/Desktop/Modality_Invariant/XRF55')
os.chdir(parentdir)

# safe_load: this notebook only reads plain YAML values from the config
# (e.g. config['modality_list'] as lists of strings), so arbitrary Python
# object construction is not needed.
with open('baseline/config_all.yaml', 'r') as fd:
    config = yaml.safe_load(fd)

In [5]:
def test(model, tensor_loader, criterion, device, modality_list):
    model.eval()
    test_acc = 0
    test_loss = 0
    # random.seed(config['modality_existances']['val_random_seed'])
    random.seed(3407)
    for data in tqdm(tensor_loader):
        start_time = time.time()
        mmwave_data, wifi_data, rfid_data, labels = data
        # t1 = time.time()
        # load_time = t1 - start_time 
        if len(modality_list) == 1:
            if modality_list[0] == 'mmwave':
                input_1 = mmwave_data.to(device)
                exist_list = [True, False, False]
            elif modality_list[0] == 'wifi':
                input_1 = wifi_data.to(device)
                exist_list = [False, True, False]
            elif modality_list[0] == 'rfid':
                input_1 = rfid_data.to(device)
                exist_list = [False, False, True]
            
            outputs = model(input_1, exist_list)
        elif len(modality_list) == 2:
            if modality_list[0] == 'mmwave' and modality_list[1] == 'wifi':
                input_1 = mmwave_data.to(device)
                input_2 = wifi_data.to(device)
                exist_list = [True, True, False]
            elif modality_list[0] == 'mmwave' and modality_list[1] == 'rfid':
                input_1 = mmwave_data.to(device)
                input_2 = rfid_data.to(device)
                exist_list = [True, False, True]
            elif modality_list[0] == 'wifi' and modality_list[1] == 'rfid':
                input_1 = wifi_data.to(device)
                input_2 = rfid_data.to(device)
                exist_list = [False, True, True]
            outputs = model(input_1, input_2, exist_list)
        elif len(modality_list) == 3:
            input_1 = mmwave_data.to(device)
            input_2 = wifi_data.to(device)
            input_3 = rfid_data.to(device)
            exist_list = [True, True, True]
            outputs = model(input_1, input_2, input_3, exist_list)
        

        labels.to(device)
        labels = labels.type(torch.LongTensor)

        
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        # t2 = time.time()
        # forward_time = t2 - t1
        loss = criterion(outputs,labels)
        predict_y = torch.argmax(outputs,dim=1).to(device)
        accuracy = (predict_y == labels.to(device)).sum().item() / labels.size(0)
        test_acc += accuracy
        test_loss += loss.item() * labels.size(0)
        # t3 = time.time()
        # record_time = t3 - t2
        # print('load_time: ', load_time)
        # print('forward_time: ', forward_time)
        # print('record_time: ', record_time)
    test_acc = test_acc/len(tensor_loader)
    test_loss = test_loss/len(tensor_loader.dataset)
    print("validation accuracy:{:.4f}, loss:{:.5f}".format(float(test_acc),float(test_loss)))
    return test_acc

In [6]:
def train(model, train_loader, test_loader, modality_list, num_epochs, learning_rate, criterion, device):
    """Train ``model`` with AdamW, evaluate every 5 epochs, and save its weights.

    Args:
        model: network called as ``model(*inputs, exist_list)``. ``inputs`` holds
            one tensor per selected modality in the canonical order
            (mmwave, wifi, rfid), and ``exist_list`` is a 3-element bool list
            flagging which modalities are present.
        train_loader / test_loader: iterables of ``(mmwave, wifi, rfid, labels)``
            batches (``len(train_loader)`` is not required).
        modality_list: 1, 2 or 3 modality names. One- and two-element lists
            must follow the canonical order; a three-element list always feeds
            all three modalities in canonical order.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: AdamW learning rate.
        criterion: loss applied to ``(float logits, int64 labels)``.
        device: torch device that inputs, labels and logits are moved to.

    Side effects:
        Writes ``model.state_dict()`` to
        ``./baseline/baseline_weights/{Single|Dual|Triple}/<names joined by _>.pt``
        (relative to the current working directory). The file is written after
        every evaluation and again at the end, so an interrupted run keeps its
        latest checkpoint.

    Raises:
        ValueError: if ``modality_list`` is not a supported combination.
    """
    canonical = ('mmwave', 'wifi', 'rfid')
    if len(modality_list) == 3:
        selected = list(canonical)
    else:
        selected = list(modality_list)
        if not 1 <= len(selected) <= 2 or selected != [m for m in canonical if m in selected]:
            raise ValueError('unsupported modality_list: {!r}'.format(modality_list))

    subdir = {1: 'Single', 2: 'Dual', 3: 'Triple'}[len(modality_list)]
    parameter_dir = os.path.join('./baseline/baseline_weights/{}/'.format(subdir),
                                 '_'.join(modality_list) + '.pt')
    print('parameter_dir: ', parameter_dir)
    # Create the output folder up front; otherwise a missing directory only
    # surfaces as a torch.save error after the whole (hours-long) run.
    os.makedirs(os.path.dirname(parameter_dir), exist_ok=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        random.seed(epoch)
        loss_sum = 0.0
        n_batches = 0
        correct = 0
        seen = 0
        for mmwave_data, wifi_data, rfid_data, labels in tqdm(train_loader):
            batch = {'mmwave': mmwave_data, 'wifi': wifi_data, 'rfid': rfid_data}
            inputs = [batch[m].to(device) for m in selected]
            # Fresh list per batch, as before, in case the model mutates it.
            exist_list = [m in selected for m in canonical]
            # Keep logits and labels on the same device (the old code discarded
            # the `.to(device)` results and moved logits to the CPU).
            outputs = model(*inputs, exist_list).float().to(device)
            labels = labels.long().to(device)

            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            # `loss == float('nan')` is always False (NaN != NaN); use isnan.
            if torch.isnan(loss):
                print('nan')
                print(outputs)
                print(labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

        # Loss: mean of batch losses. Accuracy: sample-level.
        epoch_loss = loss_sum / n_batches if n_batches else 0.0
        epoch_accuracy = correct / seen if seen else 0.0
        print('Epoch:{}, Accuracy:{:.4f},Loss:{:.9f}'.format(epoch + 1, float(epoch_accuracy), float(epoch_loss)))
        if (epoch + 1) % 5 == 0:
            test(
                model=model,
                tensor_loader=test_loader,
                criterion=criterion,
                device=device,
                modality_list=modality_list,
            )
            # Checkpoint so an interrupted run does not lose its progress.
            torch.save(model.state_dict(), parameter_dir)
    torch.save(model.state_dict(), parameter_dir)

In [7]:
# Train one model per modality combination listed in the YAML config.
# Pick the device once; model and tensors both follow it (CPU fallback works).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Loaders, loss and model factories do not depend on the combination: build once.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_padd)
criterion = nn.CrossEntropyLoss()
model_builders = {1: single_model, 2: dual_model, 3: triple_model}

for modality_list in config['modality_list']:
    print('modality_list: ', modality_list)
    if len(modality_list) not in model_builders:
        # Previously this only printed 'error' and then moved/trained a stale
        # (or undefined) `model` from an earlier iteration.
        raise ValueError('unsupported modality_list: {!r}'.format(modality_list))
    model = model_builders[len(modality_list)](modality_list)
    # .to(device) instead of .cuda(): does not crash on CPU-only machines.
    model.to(device)
    train(
        model=model,
        train_loader=train_dataloader,
        test_loader=test_dataloader,
        modality_list=modality_list,
        num_epochs=40,
        learning_rate=1e-4,
        criterion=criterion,
        device=device,
    )

modality_list:  ['mmwave', 'wifi', 'rfid']
cuda:0
parameter_dir:  ./baseline/baseline_weights/Triple/mmwave_wifi_rfid.pt


  return self._call_impl(*args, **kwargs)
100%|██████████| 963/963 [40:10<00:00,  2.50s/it]


Epoch:1, Accuracy:0.5052,Loss:2.178683336


100%|██████████| 963/963 [34:35<00:00,  2.15s/it]


Epoch:2, Accuracy:0.9373,Loss:0.584495883


100%|██████████| 963/963 [34:26<00:00,  2.15s/it]


Epoch:3, Accuracy:0.9892,Loss:0.195967060


100%|██████████| 963/963 [34:35<00:00,  2.16s/it]


Epoch:4, Accuracy:0.9947,Loss:0.098679305


100%|██████████| 963/963 [35:38<00:00,  2.22s/it]


Epoch:5, Accuracy:0.9942,Loss:0.068858866


100%|██████████| 207/207 [12:16<00:00,  3.56s/it]


validation accuracy:0.7269, loss:0.98056


100%|██████████| 963/963 [39:05<00:00,  2.44s/it]


Epoch:6, Accuracy:0.9953,Loss:0.051013995


100%|██████████| 963/963 [32:34<00:00,  2.03s/it]


Epoch:7, Accuracy:0.9958,Loss:0.036280660


100%|██████████| 963/963 [31:38<00:00,  1.97s/it]


Epoch:8, Accuracy:0.9912,Loss:0.047437116


100%|██████████| 963/963 [31:21<00:00,  1.95s/it]


Epoch:9, Accuracy:0.9955,Loss:0.030392494


100%|██████████| 963/963 [31:02<00:00,  1.93s/it]


Epoch:10, Accuracy:0.9969,Loss:0.021960566


100%|██████████| 207/207 [12:08<00:00,  3.52s/it]


validation accuracy:0.7154, loss:1.06125


100%|██████████| 963/963 [34:15<00:00,  2.13s/it]


Epoch:11, Accuracy:0.9928,Loss:0.033733775


100%|██████████| 963/963 [30:37<00:00,  1.91s/it]


Epoch:12, Accuracy:0.9950,Loss:0.026464850


100%|██████████| 963/963 [30:38<00:00,  1.91s/it]


Epoch:13, Accuracy:0.9953,Loss:0.023433190


100%|██████████| 963/963 [30:09<00:00,  1.88s/it]


Epoch:14, Accuracy:0.9946,Loss:0.025639917


100%|██████████| 963/963 [30:07<00:00,  1.88s/it]


Epoch:15, Accuracy:0.9962,Loss:0.019163014


100%|██████████| 207/207 [11:39<00:00,  3.38s/it]


validation accuracy:0.7103, loss:1.17378


100%|██████████| 963/963 [34:31<00:00,  2.15s/it]


Epoch:16, Accuracy:0.9956,Loss:0.020576927


100%|██████████| 963/963 [31:25<00:00,  1.96s/it]


Epoch:17, Accuracy:0.9946,Loss:0.022596903


100%|██████████| 963/963 [31:09<00:00,  1.94s/it]


Epoch:18, Accuracy:0.9970,Loss:0.016235587


100%|██████████| 963/963 [31:11<00:00,  1.94s/it]


Epoch:19, Accuracy:0.9973,Loss:0.013976546


100%|██████████| 963/963 [31:30<00:00,  1.96s/it]


Epoch:20, Accuracy:0.9957,Loss:0.018250064


100%|██████████| 207/207 [11:45<00:00,  3.41s/it]


validation accuracy:0.6873, loss:1.33500


100%|██████████| 963/963 [35:05<00:00,  2.19s/it]


Epoch:21, Accuracy:0.9958,Loss:0.020972978


100%|██████████| 963/963 [31:14<00:00,  1.95s/it]


Epoch:22, Accuracy:0.9988,Loss:0.008422615


100%|██████████| 963/963 [31:12<00:00,  1.94s/it]


Epoch:23, Accuracy:0.9962,Loss:0.017426381


100%|██████████| 963/963 [32:06<00:00,  2.00s/it]


Epoch:24, Accuracy:0.9959,Loss:0.017645429


100%|██████████| 963/963 [31:06<00:00,  1.94s/it]


Epoch:25, Accuracy:0.9972,Loss:0.013973855


100%|██████████| 207/207 [11:48<00:00,  3.42s/it]


validation accuracy:0.7037, loss:1.30206


100%|██████████| 963/963 [34:37<00:00,  2.16s/it]


Epoch:26, Accuracy:0.9981,Loss:0.009210131


100%|██████████| 963/963 [32:56<00:00,  2.05s/it]


Epoch:27, Accuracy:0.9947,Loss:0.021163086


100%|██████████| 963/963 [30:54<00:00,  1.93s/it]


Epoch:28, Accuracy:0.9969,Loss:0.014535989


100%|██████████| 963/963 [31:21<00:00,  1.95s/it]


Epoch:29, Accuracy:0.9965,Loss:0.013962429


100%|██████████| 963/963 [31:08<00:00,  1.94s/it]


Epoch:30, Accuracy:0.9974,Loss:0.011318158


100%|██████████| 207/207 [11:56<00:00,  3.46s/it]


validation accuracy:0.7061, loss:1.31877


  0%|          | 0/963 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [8]:
# Manually persist the trained weights: the training loop above was interrupted
# (KeyboardInterrupt), so train() never reached its own final torch.save.
# NOTE: the path is hard-coded for the mmwave+wifi+rfid (Triple) run.
torch.save(
    obj=model.state_dict(),
    f='./baseline/baseline_weights/Triple/mmwave_wifi_rfid.pt',
)

In [12]:
# Re-evaluate the most recently trained model on the held-out split,
# reusing the loader/criterion/device/modality_list left by the training loop.
test_acc = test(
    model=model,
    tensor_loader=test_dataloader,
    criterion=criterion,
    device=device,
    modality_list=modality_list,
)

  1%|▏         | 3/207 [00:18<20:59,  6.17s/it]


KeyboardInterrupt: 