In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
from HAR_mmwave_benchmark.mmfi_mmwave import make_dataset, make_dataloader

### Loda Data

In [2]:
dataset_root = 'd:\Data\My_MMFi_Data\MMFi_Dataset'
# xian zai shi yong de shi Radar_Fused
with open('config_copy.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 'A13', 'A16', 'A17', 'A18', 'A19', 'A21', 'A23',

In [3]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    # kpts = []
    # [kpts.append(np.array(t['output'])) for t in batch]
    # kpts = torch.FloatTensor(np.array(kpts))

    # lengths = torch.tensor([t['input_rgb'].shape[0] for t in batch ])
    all_actions = {'A01': 0., 'A02': 1., 'A03': 2., 'A04': 3., 'A05': 4., 
                'A06': 5., 'A07': 6., 'A08': 7., 'A09': 8., 'A10': 9.,
                'A11': 10., 'A12': 11., 'A13': 12., 'A14': 13., 'A15': 14., 
                'A16': 15., 'A17': 16., 'A18': 17., 'A19': 18., 'A20': 19., 
                'A21': 20., 'A22': 21., 'A23': 22., 'A24': 23., 'A25': 24., 
                'A26': 25., 'A27': 26.}
    
    labels = []
    [labels.append(all_actions[t['action']]) for t in batch]
    labels = torch.FloatTensor(labels)

    # rgb
    # rgb_data = np.array([(t['input_rgb']) for t in batch ])
    # rgb_data = torch.FloatTensor(rgb_data).permute(0,3,1,2)

    # # depth
    # depth_data = np.array([(t['input_depth']) for t in batch ])
    # depth_data = torch.FloatTensor(depth_data).permute(0,3,1,2)

    # # mmwave
    ## padd
    mmwave_data = [torch.Tensor(t['input_mmwave']) for t in batch ]
    mmwave_data = torch.nn.utils.rnn.pad_sequence(mmwave_data)
    ## compute mask
    mmwave_data = mmwave_data.permute(1,0,2)

    # # lidar
    # ## padd
    # lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
    # lidar_data = torch.nn.utils.rnn.pad_sequence(lidar_data)
    # ## compute mask
    # lidar_data = lidar_data.permute(1,0,2)

    # # wifi-csi
    # wifi_data = np.array([(t['input_wifi-csi']) for t in batch ])
    # wifi_data = torch.FloatTensor(wifi_data)

    # return rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, lengths
    return mmwave_data, labels

In [4]:
rng_generator = torch.manual_seed(config['init_rand_seed'])
train_loader = make_dataloader(train_dataset, is_training=True, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)
val_loader = make_dataloader(val_dataset, is_training=False, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)

In [5]:
for i, data in enumerate(train_loader):
    # rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, lengths = data
    # print(rgb_data[0].shape, depth_data[0].shape, lidar_data[0].shape, mmwave_data[0].shape, wifi_data[0].shape,kpts.shape, lengths.shape)
    # print(rgb_data.shape, depth_data.shape, lidar_data.shape, mmwave_data.shape, wifi_data.shape, kpts.shape, lengths.shape)
    mmwave_data, label = data
    print(mmwave_data.shape, label.shape)
    break

torch.Size([32, 66, 5]) torch.Size([32])


### Model

In [6]:
from HAR_mmwave_benchmark.mmwave_point_transformer_TD import *
# mmwave_data = torch.randn(32, 100, 5)
model = mmwave_PointTransformerReg()

# out = model(mmwave_data)
# print(out.shape)

### training 

In [7]:
def test(model, tensor_loader, criterion, device):
    model.eval()
    test_acc = 0
    test_loss = 0
    for data in tqdm(tensor_loader):
        rgb_data, labels = data
        inputs = rgb_data.to(device)
        labels.to(device)
        labels = labels.type(torch.LongTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        loss = criterion(outputs,labels)
        predict_y = torch.argmax(outputs,dim=1).to(device)
        accuracy = (predict_y == labels.to(device)).sum().item()
        test_acc += accuracy
        test_loss += loss.item() * labels.size(0)
        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
    test_acc = test_acc/len(tensor_loader.dataset)
    test_loss = test_loss/len(tensor_loader.dataset)
    print("validation accuracy:{:.4f}, loss:{:.5f}".format(float(test_acc),float(test_loss)))
    return test_acc

In [8]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, criterion, device):
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
    # optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[20,40],gamma=0.1)
    parameter_dir = './HAR_mmwave_benchmark/mmwave_all_random_TD.pt'
    best_test_acc = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0
        # num_iter = 2000
        for i, data in enumerate(tqdm(train_loader)):
            # if i < num_iter:
            rgb_data, labels = data
            inputs = rgb_data.to(device)
            labels.to(device)
            labels = labels.type(torch.LongTensor)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.type(torch.FloatTensor)
            outputs.to(device)
            loss = criterion(outputs,labels)
            loss.backward()
            # print(length)
            # print("loss is ", loss.item())
            optimizer.step()
            
            epoch_loss += loss.item() * labels.size(0)
            predict_y = torch.argmax(outputs,dim=1).to(device)
            epoch_accuracy += (predict_y == labels.to(device)).sum().item()
            outputs = outputs.detach().numpy()
            labels = labels.detach().numpy()
            # else:
            #     break
            # print("epoch loss is ", epoch_loss)
        epoch_loss = epoch_loss/len(train_loader.dataset)
        epoch_accuracy = epoch_accuracy/len(train_loader.dataset)
        print('Epoch:{}, Accuracy:{:.4f},Loss:{:.9f}'.format(epoch+1, float(epoch_accuracy),float(epoch_loss)))
        if (epoch+1) % 5 == 0:
            test_acc = test(
                model=model,
                tensor_loader=test_loader,
                criterion = criterion,
                device= device
            )
            if test_acc >= best_test_acc:
                print(f"best test acuracy is:{test_acc}")
                best_test_acc = test_acc
                torch.save(model.state_dict(), parameter_dir)
        # scheduler.step()
    return

In [9]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) 

cuda:0


In [10]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load('./HAR_mmwave_benchmark/mmwave_all_random_TD.pt'))
model.to(device)
train(
    model=model, 
    train_loader= train_loader,
    test_loader= val_loader,    
    num_epochs= 50,
    learning_rate=1e-4,
    criterion = criteria,
    device=device 
    )

100%|██████████| 6495/6495 [2:36:00<00:00,  1.44s/it]  


Epoch:1, Accuracy:0.6266,Loss:1.148837399


100%|██████████| 6495/6495 [1:07:09<00:00,  1.61it/s]


Epoch:2, Accuracy:0.8005,Loss:0.608032959


100%|██████████| 6495/6495 [05:40<00:00, 19.07it/s]  


Epoch:3, Accuracy:0.8461,Loss:0.464686784


100%|██████████| 6495/6495 [03:39<00:00, 29.58it/s] 


Epoch:4, Accuracy:0.8739,Loss:0.380783196


100%|██████████| 6495/6495 [03:26<00:00, 31.38it/s]


Epoch:5, Accuracy:0.8919,Loss:0.326804005


100%|██████████| 1702/1702 [15:30<00:00,  1.83it/s] 


validation accuracy:0.8129, loss:0.61255
best test acuracy is:0.8128892399830985


100%|██████████| 6495/6495 [05:25<00:00, 19.94it/s]  


Epoch:6, Accuracy:0.9042,Loss:0.287193558


100%|██████████| 6495/6495 [03:38<00:00, 29.75it/s]


Epoch:7, Accuracy:0.9134,Loss:0.258674495


100%|██████████| 6495/6495 [03:34<00:00, 30.29it/s]


Epoch:8, Accuracy:0.9213,Loss:0.234686705


100%|██████████| 6495/6495 [03:37<00:00, 29.85it/s]


Epoch:9, Accuracy:0.9281,Loss:0.216477968


100%|██████████| 6495/6495 [03:34<00:00, 30.32it/s] 


Epoch:10, Accuracy:0.9334,Loss:0.198688185


100%|██████████| 1702/1702 [00:31<00:00, 53.64it/s]


validation accuracy:0.8315, loss:0.61042
best test acuracy is:0.831462531919975


100%|██████████| 6495/6495 [03:40<00:00, 29.40it/s] 


Epoch:11, Accuracy:0.9374,Loss:0.184232015


100%|██████████| 6495/6495 [03:32<00:00, 30.57it/s]


Epoch:12, Accuracy:0.9417,Loss:0.172277568


100%|██████████| 6495/6495 [03:49<00:00, 28.30it/s]


Epoch:13, Accuracy:0.9453,Loss:0.162062073


100%|██████████| 6495/6495 [04:41<00:00, 23.07it/s]  


Epoch:14, Accuracy:0.9484,Loss:0.152734391


100%|██████████| 6495/6495 [05:13<00:00, 20.73it/s]  


Epoch:15, Accuracy:0.9513,Loss:0.142914341


100%|██████████| 1702/1702 [00:31<00:00, 53.34it/s]


validation accuracy:0.8490, loss:0.55514
best test acuracy is:0.8490070361729098


100%|██████████| 6495/6495 [04:14<00:00, 25.51it/s]  


Epoch:16, Accuracy:0.9540,Loss:0.136463212


100%|██████████| 6495/6495 [05:57<00:00, 18.17it/s]  


Epoch:17, Accuracy:0.9563,Loss:0.129490132


100%|██████████| 6495/6495 [03:58<00:00, 27.22it/s]


Epoch:18, Accuracy:0.9579,Loss:0.124508866


100%|██████████| 6495/6495 [06:03<00:00, 17.89it/s]  


Epoch:19, Accuracy:0.9594,Loss:0.118706104


100%|██████████| 6495/6495 [04:27<00:00, 24.30it/s]


Epoch:20, Accuracy:0.9611,Loss:0.113222602


100%|██████████| 1702/1702 [00:41<00:00, 40.60it/s]


validation accuracy:0.8518, loss:0.59767
best test acuracy is:0.8517810886778241


100%|██████████| 6495/6495 [03:17<00:00, 32.86it/s]


Epoch:21, Accuracy:0.9625,Loss:0.108592562


100%|██████████| 6495/6495 [03:20<00:00, 32.46it/s]


Epoch:22, Accuracy:0.9646,Loss:0.104016474


100%|██████████| 6495/6495 [03:20<00:00, 32.41it/s]


Epoch:23, Accuracy:0.9658,Loss:0.099294642


100%|██████████| 6495/6495 [03:18<00:00, 32.66it/s]


Epoch:24, Accuracy:0.9670,Loss:0.095960492


100%|██████████| 6495/6495 [03:17<00:00, 32.81it/s]


Epoch:25, Accuracy:0.9683,Loss:0.092517472


100%|██████████| 1702/1702 [00:31<00:00, 53.76it/s]


validation accuracy:0.8597, loss:0.58974
best test acuracy is:0.8597358220197306


100%|██████████| 6495/6495 [03:17<00:00, 32.90it/s]


Epoch:26, Accuracy:0.9688,Loss:0.090419449


100%|██████████| 6495/6495 [03:18<00:00, 32.78it/s]


Epoch:27, Accuracy:0.9702,Loss:0.086292725


100%|██████████| 6495/6495 [03:18<00:00, 32.79it/s]


Epoch:28, Accuracy:0.9711,Loss:0.083570060


100%|██████████| 6495/6495 [03:18<00:00, 32.75it/s]


Epoch:29, Accuracy:0.9715,Loss:0.082344831


100%|██████████| 6495/6495 [03:17<00:00, 32.86it/s]


Epoch:30, Accuracy:0.9725,Loss:0.079613929


100%|██████████| 1702/1702 [00:31<00:00, 53.23it/s]


validation accuracy:0.8510, loss:0.67029


100%|██████████| 6495/6495 [03:17<00:00, 32.82it/s]


Epoch:31, Accuracy:0.9741,Loss:0.075596412


100%|██████████| 6495/6495 [03:18<00:00, 32.75it/s]


Epoch:32, Accuracy:0.9740,Loss:0.074904160


100%|██████████| 6495/6495 [03:17<00:00, 32.84it/s]


Epoch:33, Accuracy:0.9751,Loss:0.072422173


100%|██████████| 6495/6495 [03:55<00:00, 27.57it/s]  


Epoch:34, Accuracy:0.9758,Loss:0.069814361


100%|██████████| 6495/6495 [05:27<00:00, 19.86it/s]  


Epoch:35, Accuracy:0.9764,Loss:0.068709844


100%|██████████| 1702/1702 [00:41<00:00, 40.56it/s]


validation accuracy:0.8499, loss:0.66513


100%|██████████| 6495/6495 [04:06<00:00, 26.39it/s]  


Epoch:36, Accuracy:0.9773,Loss:0.065848449


100%|██████████| 6495/6495 [03:18<00:00, 32.76it/s]


Epoch:37, Accuracy:0.9772,Loss:0.065767680


100%|██████████| 6495/6495 [03:18<00:00, 32.74it/s]


Epoch:38, Accuracy:0.9780,Loss:0.064120363


100%|██████████| 6495/6495 [03:18<00:00, 32.79it/s]


Epoch:39, Accuracy:0.9787,Loss:0.061755427


100%|██████████| 6495/6495 [03:18<00:00, 32.77it/s]


Epoch:40, Accuracy:0.9790,Loss:0.061031962


100%|██████████| 1702/1702 [00:32<00:00, 52.97it/s]


validation accuracy:0.8578, loss:0.64905


100%|██████████| 6495/6495 [03:17<00:00, 32.92it/s]


Epoch:41, Accuracy:0.9795,Loss:0.059994157


100%|██████████| 6495/6495 [03:18<00:00, 32.69it/s]


Epoch:42, Accuracy:0.9800,Loss:0.058120550


100%|██████████| 6495/6495 [03:17<00:00, 32.91it/s]


Epoch:43, Accuracy:0.9806,Loss:0.055285146


100%|██████████| 6495/6495 [03:18<00:00, 32.78it/s]


Epoch:44, Accuracy:0.9806,Loss:0.055886112


100%|██████████| 6495/6495 [03:17<00:00, 32.85it/s]


Epoch:45, Accuracy:0.9813,Loss:0.054941986


100%|██████████| 1702/1702 [00:31<00:00, 53.72it/s]


validation accuracy:0.8591, loss:0.68453


100%|██████████| 6495/6495 [03:17<00:00, 32.86it/s]


Epoch:46, Accuracy:0.9815,Loss:0.053724266


100%|██████████| 6495/6495 [03:18<00:00, 32.76it/s]


Epoch:47, Accuracy:0.9820,Loss:0.052154583


100%|██████████| 6495/6495 [03:17<00:00, 32.83it/s]


Epoch:48, Accuracy:0.9821,Loss:0.051604681


100%|██████████| 6495/6495 [03:18<00:00, 32.71it/s]


Epoch:49, Accuracy:0.9825,Loss:0.050721025


100%|██████████| 6495/6495 [03:17<00:00, 32.83it/s]


Epoch:50, Accuracy:0.9827,Loss:0.050256905


100%|██████████| 1702/1702 [00:31<00:00, 53.19it/s]

validation accuracy:0.8568, loss:0.69956





In [8]:
criteria = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('./HAR_lidar_benchmark/lidar_all_random.pt'))
model.to(device)
test(
        model=model,
        tensor_loader=val_loader,
        criterion = criteria,
        device= device
    )

100%|██████████| 4010/4010 [10:04<00:00,  6.63it/s]

validation accuracy:0.7845, loss:0.75000





0.7845429604688864

In [11]:
parameter_dir = './HAR_lidar_benchmark/lidar_all_random.pt'
torch.save(model.state_dict(), parameter_dir)