In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from XRF55_Dataset import XRF55_Datase 

### Loda Data

In [2]:
train_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=True)
test_dataset = XRF55_Datase(root_dir="D:\Data\XRF55\XRF_dataset", scene='all', is_train=False)

In [3]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    # ## get sequence lengths
    # for t in batch:
    #     print(t[0],t[1].shape,t[2].shape,t[3])
    #     print(a)
    # # #     # print(t[0].shape,t[1].shape)
    # # kpts = []
    # # [kpts.append(np.array(t['output'])) for t in batch]
    # # kpts = torch.FloatTensor(np.array(kpts))

    # # lengths = torch.tensor([t['input_rgb'].shape[0] for t in batch ])
    # all_actions = {'A01': 0., 'A02': 1., 'A03': 2., 'A04': 3., 'A05': 4., 
    #             'A06': 5., 'A07': 6., 'A08': 7., 'A09': 8., 'A10': 9.,
    #             'A11': 10., 'A12': 11., 'A13': 12., 'A14': 13., 'A15': 14., 
    #             'A16': 15., 'A17': 16., 'A18': 17., 'A19': 18., 'A20': 19., 
    #             'A21': 20., 'A22': 21., 'A23': 22., 'A24': 23., 'A25': 24., 
    #             'A26': 25., 'A27': 26.}
    
    labels = []
    [labels.append(float(t[3])) for t in batch]
    labels = torch.FloatTensor(labels)

    # wifi-csi
    wifi_data = np.array([(t[0]) for t in batch ])
    wifi_data = torch.FloatTensor(wifi_data)

    return wifi_data, labels

In [5]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn_padd)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_padd)

In [6]:
for data in train_dataloader:
    wifi_data, labels = data
    print(wifi_data.shape)
    print(labels)
    break

torch.Size([32, 270, 1000])
tensor([32., 32., 46., 22., 52.,  6., 27., 42., 14., 47., 11., 32., 14.,  3.,
        13.,  1., 31., 22., 17., 39.,  6., 52.,  5., 51.,  4., 24., 21.,  4.,
        40., 41.,  9.,  1.])


In [7]:
print(len(train_dataloader.dataset))

15400


### Model

In [15]:
import sys
path = os.getcwd()
os.chdir('c:/Users/Chen_Xinyan/Desktop/Modality_Invariant/XRF55')

from backbone_models.WIFI.BiLSTM import *
from backbone_models.WIFI.ResNet import *
# from HAR_CSI_benchmark.models.mynetwork import *
# wifi_data = torch.randn(32, 3, 114, 10)
# model = Wifi_BiLSTM(num_classes=55)
model = resnet18()

# out = model(wifi_data)
# print(out.shape)

In [16]:
print(model)

ResNet(
  (conv1): Conv1d(270, 128, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv1d(128, 128, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv1d(128, 128, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)


### training 

In [17]:
def test(model, tensor_loader, criterion, device):
    model.eval()
    test_acc = 0
    test_loss = 0
    for data in tqdm(tensor_loader):
        wifi_data, labels = data
        inputs = wifi_data.to(device)
        labels.to(device)
        labels = labels.type(torch.LongTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        loss = criterion(outputs,labels)
        predict_y = torch.argmax(outputs,dim=1).to(device)
        accuracy = (predict_y == labels.to(device)).sum().item()
        test_acc += accuracy
        test_loss += loss.item() * labels.size(0)
        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
    test_acc = test_acc/len(tensor_loader.dataset)
    test_loss = test_loss/len(tensor_loader.dataset)
    print("validation accuracy:{:.4f}, loss:{:.5f}".format(float(test_acc),float(test_loss)))
    return test_acc

In [18]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, criterion, device):
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
    # optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[20,40],gamma=0.1)
    parameter_dir = './backbone_models/WIFI/ResNet18.pt'
    best_test_acc = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0.0
        num_iter = 100
        for i, data in enumerate(tqdm(train_loader)):
            # if i < num_iter:
            rgb_data, labels = data
            inputs = rgb_data.to(device)
            labels.to(device)
            labels = labels.type(torch.LongTensor)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.type(torch.FloatTensor)
            outputs.to(device)
            # print(outputs)
            # print(labels)
            loss = criterion(outputs,labels)
            # print(loss)
            # if loss == float('nan'):
            #     print('nan')
            #     print(outputs)
            #     print(labels)
            loss.backward()
            # print(length)
            # print("loss is ", loss.item())
            optimizer.step()

        
            epoch_loss += loss.item() * labels.size(0)
            predict_y = torch.argmax(outputs,dim=1).to(device)
            epoch_accuracy += (predict_y == labels.to(device)).sum().item()
            outputs = outputs.detach().numpy()
            labels = labels.detach().numpy()
            # else:
            #     break
            # print("epoch loss is ", epoch_loss)
        # epoch_loss = epoch_loss/(rgb_data.size(0)*100)
        # epoch_accuracy = epoch_accuracy/(rgb_data.size(0)*100)
        epoch_loss = epoch_loss/len(train_loader.dataset)
        epoch_accuracy = epoch_accuracy/len(train_loader.dataset)
        print('Epoch:{}, Accuracy:{:.4f},Loss:{:.9f}'.format(epoch+1, float(epoch_accuracy),float(epoch_loss)))
        if (epoch+1) % 10 == 0:
            test_acc = test(
                model=model,
                tensor_loader=test_loader,
                criterion = criterion,
                device= device
            )
            if test_acc >= best_test_acc:
                print(f"best test acuracy is:{test_acc}")
                best_test_acc = test_acc
                torch.save(model, parameter_dir)
        # scheduler.step()
    return

In [19]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) 

cuda:0


In [20]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load('./HAR_lidar_benchmark/lidar_all_random.pt'))
model.to(device)
train(
    model=model, 
    train_loader= train_dataloader,
    test_loader= test_dataloader,    
    num_epochs= 100,
    learning_rate=1e-3,
    criterion = criteria,
    device=device 
    )

100%|██████████| 482/482 [27:54<00:00,  3.47s/it]


Epoch:1, Accuracy:0.0388,Loss:3.900661711


100%|██████████| 482/482 [53:50<00:00,  6.70s/it]


Epoch:2, Accuracy:0.0714,Loss:3.494317501


100%|██████████| 482/482 [56:12<00:00,  7.00s/it]


Epoch:3, Accuracy:0.1051,Loss:3.204015212


100%|██████████| 482/482 [56:10<00:00,  6.99s/it]


Epoch:4, Accuracy:0.1480,Loss:2.934926641


100%|██████████| 482/482 [57:38<00:00,  7.18s/it]


Epoch:5, Accuracy:0.1925,Loss:2.671179093


100%|██████████| 482/482 [56:39<00:00,  7.05s/it]


Epoch:6, Accuracy:0.2479,Loss:2.405589844


100%|██████████| 482/482 [56:45<00:00,  7.07s/it]


Epoch:7, Accuracy:0.3023,Loss:2.207430410


100%|██████████| 482/482 [56:55<00:00,  7.09s/it]


Epoch:8, Accuracy:0.3590,Loss:1.994251876


100%|██████████| 482/482 [56:38<00:00,  7.05s/it]


Epoch:9, Accuracy:0.4018,Loss:1.830861649


100%|██████████| 482/482 [57:01<00:00,  7.10s/it]


Epoch:10, Accuracy:0.4574,Loss:1.656136164


100%|██████████| 207/207 [28:41<00:00,  8.32s/it]


validation accuracy:0.4477, loss:1.71686
best test acuracy is:0.44772727272727275


100%|██████████| 482/482 [59:31<00:00,  7.41s/it]


Epoch:11, Accuracy:0.5086,Loss:1.479271891


100%|██████████| 482/482 [57:37<00:00,  7.17s/it]


Epoch:12, Accuracy:0.5535,Loss:1.353387630


100%|██████████| 482/482 [56:38<00:00,  7.05s/it]


Epoch:13, Accuracy:0.5937,Loss:1.237335489


100%|██████████| 482/482 [57:07<00:00,  7.11s/it]


Epoch:14, Accuracy:0.6218,Loss:1.138118649


100%|██████████| 482/482 [56:57<00:00,  7.09s/it]


Epoch:15, Accuracy:0.6610,Loss:1.029347671


100%|██████████| 482/482 [57:01<00:00,  7.10s/it]


Epoch:16, Accuracy:0.6933,Loss:0.933706076


100%|██████████| 482/482 [57:08<00:00,  7.11s/it]


Epoch:17, Accuracy:0.7162,Loss:0.855265326


100%|██████████| 482/482 [58:26<00:00,  7.27s/it]


Epoch:18, Accuracy:0.7373,Loss:0.796738009


100%|██████████| 482/482 [57:08<00:00,  7.11s/it]


Epoch:19, Accuracy:0.7595,Loss:0.720339663


100%|██████████| 482/482 [57:30<00:00,  7.16s/it]


Epoch:20, Accuracy:0.7737,Loss:0.679843828


100%|██████████| 207/207 [28:29<00:00,  8.26s/it]


validation accuracy:0.6715, loss:1.10187
best test acuracy is:0.6715151515151515


100%|██████████| 482/482 [1:00:16<00:00,  7.50s/it]


Epoch:21, Accuracy:0.7864,Loss:0.634834167


100%|██████████| 482/482 [58:59<00:00,  7.34s/it]


Epoch:22, Accuracy:0.8055,Loss:0.584225200


100%|██████████| 482/482 [57:24<00:00,  7.15s/it]


Epoch:23, Accuracy:0.8224,Loss:0.527925159


100%|██████████| 482/482 [57:36<00:00,  7.17s/it]


Epoch:24, Accuracy:0.8290,Loss:0.508173848


100%|██████████| 482/482 [57:28<00:00,  7.15s/it]


Epoch:25, Accuracy:0.8406,Loss:0.475562974


100%|██████████| 482/482 [57:20<00:00,  7.14s/it]


Epoch:26, Accuracy:0.8422,Loss:0.460753017


100%|██████████| 482/482 [57:15<00:00,  7.13s/it]


Epoch:27, Accuracy:0.8582,Loss:0.421280505


100%|██████████| 482/482 [58:51<00:00,  7.33s/it]


Epoch:28, Accuracy:0.8627,Loss:0.415564882


100%|██████████| 482/482 [56:16<00:00,  7.00s/it]


Epoch:29, Accuracy:0.8757,Loss:0.373006281


100%|██████████| 482/482 [57:40<00:00,  7.18s/it]


Epoch:30, Accuracy:0.8832,Loss:0.357330561


100%|██████████| 207/207 [28:24<00:00,  8.23s/it]


validation accuracy:0.7917, loss:0.72887
best test acuracy is:0.7916666666666666


100%|██████████| 482/482 [1:00:06<00:00,  7.48s/it]


Epoch:31, Accuracy:0.8873,Loss:0.342743590


100%|██████████| 482/482 [55:53<00:00,  6.96s/it]


Epoch:32, Accuracy:0.8869,Loss:0.333631869


100%|██████████| 482/482 [56:27<00:00,  7.03s/it]


Epoch:33, Accuracy:0.8974,Loss:0.313361146


100%|██████████| 482/482 [56:12<00:00,  7.00s/it]


Epoch:34, Accuracy:0.8987,Loss:0.307381348


100%|██████████| 482/482 [56:43<00:00,  7.06s/it]


Epoch:35, Accuracy:0.9116,Loss:0.269082004


100%|██████████| 482/482 [56:19<00:00,  7.01s/it]


Epoch:36, Accuracy:0.9141,Loss:0.255614209


100%|██████████| 482/482 [56:24<00:00,  7.02s/it]


Epoch:37, Accuracy:0.9068,Loss:0.277872579


100%|██████████| 482/482 [56:37<00:00,  7.05s/it]


Epoch:38, Accuracy:0.9158,Loss:0.249638140


100%|██████████| 482/482 [56:43<00:00,  7.06s/it]


Epoch:39, Accuracy:0.9231,Loss:0.230859357


100%|██████████| 482/482 [56:54<00:00,  7.08s/it]


Epoch:40, Accuracy:0.9216,Loss:0.238769480


100%|██████████| 207/207 [28:19<00:00,  8.21s/it]


validation accuracy:0.8024, loss:0.77900
best test acuracy is:0.8024242424242424


100%|██████████| 482/482 [1:01:18<00:00,  7.63s/it]


Epoch:41, Accuracy:0.9236,Loss:0.235058944


100%|██████████| 482/482 [57:38<00:00,  7.18s/it]


Epoch:42, Accuracy:0.9364,Loss:0.195461922


100%|██████████| 482/482 [57:15<00:00,  7.13s/it]


Epoch:43, Accuracy:0.9271,Loss:0.219291970


100%|██████████| 482/482 [56:45<00:00,  7.07s/it]


Epoch:44, Accuracy:0.9353,Loss:0.201504504


100%|██████████| 482/482 [57:13<00:00,  7.12s/it]


Epoch:45, Accuracy:0.9347,Loss:0.196005564


100%|██████████| 482/482 [57:41<00:00,  7.18s/it]


Epoch:46, Accuracy:0.9338,Loss:0.203473041


100%|██████████| 482/482 [57:49<00:00,  7.20s/it]


Epoch:47, Accuracy:0.9403,Loss:0.184351529


100%|██████████| 482/482 [57:20<00:00,  7.14s/it]


Epoch:48, Accuracy:0.9407,Loss:0.182497604


100%|██████████| 482/482 [57:17<00:00,  7.13s/it]


Epoch:49, Accuracy:0.9461,Loss:0.164443531


100%|██████████| 482/482 [57:18<00:00,  7.13s/it]


Epoch:50, Accuracy:0.9451,Loss:0.167744943


100%|██████████| 207/207 [28:17<00:00,  8.20s/it]


validation accuracy:0.8098, loss:0.75774
best test acuracy is:0.8098484848484848


 29%|██▊       | 138/482 [16:59<42:21,  7.39s/it]


KeyboardInterrupt: 

In [None]:
os.chdir('c:/Users/Chen_Xinyan/Desktop/Modality_Invariant/XRF55')
parameter_dir = './backbone_models/WIFI/ResNet18.pt'
torch.save(model, parameter_dir)

In [8]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('./HAR_lidar_benchmark/lidar_all_random.pt'))
model.to(device)
test(
        model=model,
        tensor_loader=val_loader,
        criterion = criteria,
        device= device
    )

100%|██████████| 4010/4010 [10:04<00:00,  6.63it/s]

validation accuracy:0.7845, loss:0.75000





0.7845429604688864

In [11]:
parameter_dir = './HAR_lidar_benchmark/lidar_all_random.pt'
torch.save(model.state_dict(), parameter_dir)