In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from XRF55_Dataset import XRF55_Datase 

### Loda Data

In [2]:
train_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=True)
test_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=False)

In [3]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    # ## get sequence lengths
    # for t in batch:
    #     print(t[0],t[1].shape,t[2].shape,t[3])
    #     print(a)
    # # #     # print(t[0].shape,t[1].shape)
    # # kpts = []
    # # [kpts.append(np.array(t['output'])) for t in batch]
    # # kpts = torch.FloatTensor(np.array(kpts))

    # # lengths = torch.tensor([t['input_rgb'].shape[0] for t in batch ])
    # all_actions = {'A01': 0., 'A02': 1., 'A03': 2., 'A04': 3., 'A05': 4., 
    #             'A06': 5., 'A07': 6., 'A08': 7., 'A09': 8., 'A10': 9.,
    #             'A11': 10., 'A12': 11., 'A13': 12., 'A14': 13., 'A15': 14., 
    #             'A16': 15., 'A17': 16., 'A18': 17., 'A19': 18., 'A20': 19., 
    #             'A21': 20., 'A22': 21., 'A23': 22., 'A24': 23., 'A25': 24., 
    #             'A26': 25., 'A27': 26.}
    
    labels = []
    [labels.append(float(t[3])) for t in batch]
    labels = torch.FloatTensor(labels)

    # wifi-csi
    rfid_data = np.array([(t[1]) for t in batch ])
    rfid_data = torch.FloatTensor(rfid_data)

    return rfid_data, labels

In [4]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_padd)

In [5]:
for data in train_dataloader:
    rfid_data, labels = data
    print(rfid_data.shape)
    print(labels)
    break

torch.Size([16, 23, 148])
tensor([20.,  6., 45., 54.,  4., 46., 53., 19., 42., 11., 28., 45., 53., 43.,
        46., 52.])


### Model

In [23]:
import sys
path = os.getcwd()
os.chdir('c:/Users/Chen_Xinyan/Desktop/Modality_Invariant/XRF55')

# from backbone_models.RFID.BiLSTM import *
from backbone_models.RFID.ResNet import *
# from HAR_CSI_benchmark.models.mynetwork import *
# wifi_data = torch.randn(32, 3, 114, 10)
# model = rfid_BiLSTM(num_classes=55)
model = resnet18()

# out = model(wifi_data)
# print(out.shape)

In [24]:
print(model)

ResNet(
  (conv1): Conv1d(23, 128, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2):

### training 

In [25]:
def test(model, tensor_loader, criterion, device):
    model.eval()
    test_acc = 0
    test_loss = 0
    for data in tqdm(tensor_loader):
        wifi_data, labels = data
        inputs = wifi_data.to(device)
        labels.to(device)
        labels = labels.type(torch.LongTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        loss = criterion(outputs,labels)
        predict_y = torch.argmax(outputs,dim=1).to(device)
        accuracy = (predict_y == labels.to(device)).sum().item()
        test_acc += accuracy
        test_loss += loss.item() * labels.size(0)
        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
    test_acc = test_acc/len(tensor_loader.dataset)
    test_loss = test_loss/len(tensor_loader.dataset)
    print("validation accuracy:{:.4f}, loss:{:.5f}".format(float(test_acc),float(test_loss)))
    return test_acc

In [26]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, criterion, device):
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
    # optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[20,40],gamma=0.1)
    parameter_dir = './backbone_models/RFID/ResNet18.pt'
    best_test_acc = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0.0
        # num_iter = 2000
        for i, data in enumerate(tqdm(train_loader)):
            # if i < num_iter:
            rgb_data, labels = data
            inputs = rgb_data.to(device)
            labels.to(device)
            labels = labels.type(torch.LongTensor)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.type(torch.FloatTensor)
            outputs.to(device)
            # print(outputs)
            # print(labels)
            loss = criterion(outputs,labels)
            # print(loss)
            # if loss == float('nan'):
            #     print('nan')
            #     print(outputs)
            #     print(labels)
            loss.backward()
            # print(length)
            # print("loss is ", loss.item())
            optimizer.step()

            
            epoch_loss += loss.item() * labels.size(0)
            predict_y = torch.argmax(outputs,dim=1).to(device)
            epoch_accuracy += (predict_y == labels.to(device)).sum().item()
            outputs = outputs.detach().numpy()
            labels = labels.detach().numpy()
            # else:
                # break
            # print("epoch loss is ", epoch_loss)
        # epoch_loss = epoch_loss/(rgb_data.size(0)*len(train_loader.dataset))
        # epoch_accuracy = epoch_accuracy/(rgb_data.size(0)*len(train_loader.dataset))
        epoch_loss = epoch_loss/len(train_loader.dataset)
        epoch_accuracy = epoch_accuracy/len(train_loader.dataset)
        print('Epoch:{}, Accuracy:{:.4f},Loss:{:.9f}'.format(epoch+1, float(epoch_accuracy),float(epoch_loss)))
        if (epoch+1) % 10 == 0:
            test_acc = test(
                model=model,
                tensor_loader=test_loader,
                criterion = criterion,
                device= device
            )
            if test_acc >= best_test_acc:
                print(f"best test acuracy is:{test_acc}")
                best_test_acc = test_acc
                torch.save(model, parameter_dir)
        scheduler.step()
    return

In [27]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) 

cuda:0


In [28]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load('./HAR_lidar_benchmark/lidar_all_random.pt'))
model.to(device)
train(
    model=model, 
    train_loader= train_dataloader,
    test_loader= test_dataloader,    
    num_epochs= 100,
    learning_rate=1e-3,
    criterion = criteria,
    device=device 
    )

  0%|          | 0/963 [00:00<?, ?it/s]

100%|██████████| 963/963 [57:22<00:00,  3.58s/it]


Epoch:1, Accuracy:0.0229,Loss:4.062083960


100%|██████████| 963/963 [55:07<00:00,  3.43s/it]


Epoch:2, Accuracy:0.0263,Loss:3.949326643


100%|██████████| 963/963 [55:40<00:00,  3.47s/it]


Epoch:3, Accuracy:0.0299,Loss:3.901183975


100%|██████████| 963/963 [56:03<00:00,  3.49s/it] 


Epoch:4, Accuracy:0.0425,Loss:3.826634088


100%|██████████| 963/963 [55:55<00:00,  3.48s/it]


Epoch:5, Accuracy:0.0594,Loss:3.712388941


100%|██████████| 963/963 [55:58<00:00,  3.49s/it]


Epoch:6, Accuracy:0.0821,Loss:3.572993868


100%|██████████| 963/963 [56:05<00:00,  3.50s/it]


Epoch:7, Accuracy:0.1023,Loss:3.438852318


100%|██████████| 963/963 [55:59<00:00,  3.49s/it]


Epoch:8, Accuracy:0.1242,Loss:3.293470246


100%|██████████| 963/963 [55:55<00:00,  3.48s/it]


Epoch:9, Accuracy:0.1523,Loss:3.144632691


100%|██████████| 963/963 [59:30<00:00,  3.71s/it]


Epoch:10, Accuracy:0.1837,Loss:3.001229865


100%|██████████| 207/207 [28:04<00:00,  8.14s/it]


validation accuracy:0.1833, loss:2.99854
best test acuracy is:0.18333333333333332


100%|██████████| 963/963 [57:16<00:00,  3.57s/it] 


Epoch:11, Accuracy:0.2184,Loss:2.845357361


100%|██████████| 963/963 [56:13<00:00,  3.50s/it]


Epoch:12, Accuracy:0.2588,Loss:2.663544587


100%|██████████| 963/963 [56:16<00:00,  3.51s/it]


Epoch:13, Accuracy:0.2992,Loss:2.504124968


100%|██████████| 963/963 [56:17<00:00,  3.51s/it]


Epoch:14, Accuracy:0.3431,Loss:2.344157313


100%|██████████| 963/963 [56:50<00:00,  3.54s/it]


Epoch:15, Accuracy:0.3841,Loss:2.167912428


100%|██████████| 963/963 [56:28<00:00,  3.52s/it]


Epoch:16, Accuracy:0.4358,Loss:1.980705260


100%|██████████| 963/963 [57:54<00:00,  3.61s/it]


Epoch:17, Accuracy:0.4819,Loss:1.789539461


100%|██████████| 963/963 [56:43<00:00,  3.53s/it]


Epoch:18, Accuracy:0.5310,Loss:1.621152976


100%|██████████| 963/963 [56:38<00:00,  3.53s/it]


Epoch:19, Accuracy:0.5814,Loss:1.425345767


100%|██████████| 963/963 [59:41<00:00,  3.72s/it]


Epoch:20, Accuracy:0.6314,Loss:1.250891137


100%|██████████| 207/207 [28:01<00:00,  8.13s/it]


validation accuracy:0.3594, loss:2.51656
best test acuracy is:0.3593939393939394


100%|██████████| 963/963 [58:22<00:00,  3.64s/it]


Epoch:21, Accuracy:0.7904,Loss:0.774375729


100%|██████████| 963/963 [57:10<00:00,  3.56s/it]


Epoch:22, Accuracy:0.8283,Loss:0.658068372


100%|██████████| 963/963 [56:45<00:00,  3.54s/it]


Epoch:23, Accuracy:0.8503,Loss:0.592905074


100%|██████████| 963/963 [57:05<00:00,  3.56s/it]


Epoch:24, Accuracy:0.8652,Loss:0.533105125


100%|██████████| 963/963 [57:05<00:00,  3.56s/it]


Epoch:25, Accuracy:0.8787,Loss:0.491784404


100%|██████████| 963/963 [56:48<00:00,  3.54s/it]


Epoch:26, Accuracy:0.8908,Loss:0.447533258


100%|██████████| 963/963 [57:32<00:00,  3.58s/it]


Epoch:27, Accuracy:0.8987,Loss:0.415589905


100%|██████████| 963/963 [56:57<00:00,  3.55s/it] 


Epoch:28, Accuracy:0.9047,Loss:0.386384400


100%|██████████| 963/963 [56:41<00:00,  3.53s/it]


Epoch:29, Accuracy:0.9131,Loss:0.355463136


100%|██████████| 963/963 [58:57<00:00,  3.67s/it]


Epoch:30, Accuracy:0.9194,Loss:0.329727458


100%|██████████| 207/207 [27:46<00:00,  8.05s/it]


validation accuracy:0.4406, loss:2.38394
best test acuracy is:0.4406060606060606


100%|██████████| 963/963 [56:04<00:00,  3.49s/it]


Epoch:31, Accuracy:0.9298,Loss:0.298450712


100%|██████████| 963/963 [55:33<00:00,  3.46s/it]


Epoch:32, Accuracy:0.9327,Loss:0.280895252


100%|██████████| 963/963 [55:46<00:00,  3.48s/it]


Epoch:33, Accuracy:0.9366,Loss:0.261874654


100%|██████████| 963/963 [56:20<00:00,  3.51s/it]


Epoch:34, Accuracy:0.9452,Loss:0.239971697


100%|██████████| 963/963 [56:11<00:00,  3.50s/it]


Epoch:35, Accuracy:0.9452,Loss:0.226166229


100%|██████████| 963/963 [56:28<00:00,  3.52s/it]


Epoch:36, Accuracy:0.9488,Loss:0.212367637


100%|██████████| 963/963 [56:24<00:00,  3.51s/it]


Epoch:37, Accuracy:0.9533,Loss:0.200216604


100%|██████████| 963/963 [56:12<00:00,  3.50s/it]


Epoch:38, Accuracy:0.9576,Loss:0.183205575


100%|██████████| 963/963 [56:05<00:00,  3.49s/it]


Epoch:39, Accuracy:0.9637,Loss:0.163215774


100%|██████████| 963/963 [59:38<00:00,  3.72s/it]


Epoch:40, Accuracy:0.9639,Loss:0.156252421


100%|██████████| 207/207 [27:53<00:00,  8.08s/it]


validation accuracy:0.4409, loss:2.68783
best test acuracy is:0.4409090909090909


100%|██████████| 963/963 [58:01<00:00,  3.62s/it]


Epoch:41, Accuracy:0.9708,Loss:0.138057392


100%|██████████| 963/963 [56:23<00:00,  3.51s/it]


Epoch:42, Accuracy:0.9718,Loss:0.131007332


100%|██████████| 963/963 [56:29<00:00,  3.52s/it]


Epoch:43, Accuracy:0.9746,Loss:0.125583159


100%|██████████| 963/963 [56:28<00:00,  3.52s/it]


Epoch:44, Accuracy:0.9742,Loss:0.125243983


100%|██████████| 963/963 [56:22<00:00,  3.51s/it]


Epoch:45, Accuracy:0.9757,Loss:0.120731865


100%|██████████| 963/963 [58:03<00:00,  3.62s/it]


Epoch:46, Accuracy:0.9738,Loss:0.120721038


100%|██████████| 963/963 [56:46<00:00,  3.54s/it]


Epoch:47, Accuracy:0.9780,Loss:0.112130272


100%|██████████| 963/963 [56:40<00:00,  3.53s/it]


Epoch:48, Accuracy:0.9749,Loss:0.118645594


100%|██████████| 963/963 [56:47<00:00,  3.54s/it]


Epoch:49, Accuracy:0.9782,Loss:0.113112458


100%|██████████| 963/963 [59:30<00:00,  3.71s/it]


Epoch:50, Accuracy:0.9779,Loss:0.113987292


100%|██████████| 207/207 [15:27<00:00,  4.48s/it]


validation accuracy:0.4380, loss:2.72519


 43%|████▎     | 416/963 [13:11<17:21,  1.90s/it]


KeyboardInterrupt: 

In [8]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('./HAR_lidar_benchmark/lidar_all_random.pt'))
model.to(device)
test(
        model=model,
        tensor_loader=val_loader,
        criterion = criteria,
        device= device
    )

100%|██████████| 4010/4010 [10:04<00:00,  6.63it/s]

validation accuracy:0.7845, loss:0.75000





0.7845429604688864

In [11]:
parameter_dir = './HAR_lidar_benchmark/lidar_all_random.pt'
torch.save(model.state_dict(), parameter_dir)