In [1]:
import csv
import glob
import ntpath
import os

import numpy as np
import scipy.io as sio
import torch
import yaml
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from evaluate import error
from mmfi import make_dataset, make_dataloader

### Load Data

In [2]:
# Sanity check: split a sample MMFi frame path into its folder and file name.
# `ntpath` splits Windows-style (backslash) paths on every platform, whereas
# `os.path` only does so when running on Windows. The raw string keeps the
# backslashes literal, and the name `sample_path` avoids shadowing builtin `dir`.
sample_path = r'd:\Data\MMFi_Dataset\E01\S02\A01\depth\frame007.png'
sample_dir, sample_frame = ntpath.split(sample_path)
print(sample_dir, sample_frame)

d:\Data\MMFi_Dataset\E01\S02\A01\depth frame007.png


In [2]:
# Root of the local MMFi copy; set the MMFI_ROOT environment variable to override.
# A raw string keeps the Windows backslashes literal ('\D' and '\M' are invalid
# escape sequences in a normal string and warn on recent Python versions).
dataset_root = os.environ.get('MMFI_ROOT', r'd:\Data\My_MMFi_Data\MMFi_Dataset')
# Author's note (translated from pinyin): "currently using Radar_Fused".
# NOTE(review): confirm against config.yaml which modalities are actually loaded.
with open('config.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 'A13', 'A16', 'A17', 'A18', 'A19', 'A21', 'A23',

In [3]:
def collate_fn_padd(batch):
    '''
    Collate a list of MMFi sample dicts into an RGB batch and keypoint labels.

    Tensors are built manually here because the ToTensor transform assumes it
    receives images rather than arbitrary arrays.

    Each sample is a dict; for this dataset the observed keys were
    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output',
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave']).
    Only 'input_rgb' and 'output' are read.

    Args:
        batch: list of sample dicts produced by the MMFi dataset.

    Returns:
        (rgb_data, kpts):
            rgb_data -- float32 tensor of the stacked 'input_rgb' arrays with
                axes permuted (0, 3, 1, 2), i.e. channels-last to channels-first
                (assumes each 'input_rgb' is (H, W, C)).
            kpts -- float32 tensor of the stacked 'output' keypoints.
    '''
    # Stack into one ndarray before converting: PyTorch warns that building a
    # tensor from a list of ndarrays is slow.
    kpts = torch.FloatTensor(np.array([np.array(t['output']) for t in batch]))

    rgb_data = np.array([t['input_rgb'] for t in batch])
    rgb_data = torch.FloatTensor(rgb_data).permute(0, 3, 1, 2)

    # Other modalities are disabled for this RGB-only run. To re-enable one,
    # uncomment its block and extend the return tuple accordingly.
    # lengths = torch.tensor([t['input_rgb'].shape[0] for t in batch])

    # # depth
    # depth_data = np.array([(t['input_depth']) for t in batch ])
    # depth_data = torch.FloatTensor(depth_data).permute(0,3,1,2)

    # # mmwave (variable length -> padded)
    # mmwave_data = [torch.Tensor(t['input_mmwave']) for t in batch ]
    # mmwave_data = torch.nn.utils.rnn.pad_sequence(mmwave_data)
    # mmwave_data = mmwave_data.permute(1,0,2)

    # # lidar (variable length -> padded)
    # lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
    # lidar_data = torch.nn.utils.rnn.pad_sequence(lidar_data)
    # lidar_data = lidar_data.permute(1,0,2)

    # # wifi-csi
    # wifi_data = np.array([(t['input_wifi-csi']) for t in batch ])
    # wifi_data = torch.FloatTensor(wifi_data)

    # return rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, lengths
    return rgb_data, kpts

In [4]:
# Seed torch's global RNG. torch.manual_seed returns the default generator,
# which both loaders share (presumably for shuffling inside make_dataloader).
rng_generator = torch.manual_seed(config['init_rand_seed'])
train_loader, val_loader = (
    make_dataloader(
        split_dataset,
        is_training=is_training,
        generator=rng_generator,
        **config['loader'],
        collate_fn=collate_fn_padd,
    )
    for split_dataset, is_training in ((train_dataset, True), (val_dataset, False))
)

In [6]:
# Pull a single training batch and print its shapes to check the collate output.
for rgb_data, label in train_loader:
    print(rgb_data.shape, label.shape)
    break

torch.Size([32, 3, 480, 640]) torch.Size([32, 17, 3])


### Model

In [19]:
# Import only the class used, instead of `import *`, so it is clear which
# names this cell brings into the notebook namespace.
# NOTE(review): the ResNet50 class lives in the `rgb_ResNet18` package; confirm intended.
from RGB_benchmark.rgb_ResNet18.RGB_ResNet import RGB_ResNet50

model = RGB_ResNet50()

### Training

In [20]:
def test(model, tensor_loader, criterion1, criterion2, device):
    model.eval()
    test_mpjpe = 0
    test_pampjpe = 0
    test_mse = 0
    for data in tqdm(tensor_loader):
        depth_data, kpts = data
        inputs = depth_data.to(device)
        kpts.to(device)
        labels = kpts.type(torch.FloatTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        test_mse += criterion1(outputs,labels).item() * inputs.size(0)

        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
        
        mpjpe, pampjpe = criterion2(outputs,labels)
        test_mpjpe += mpjpe.item() * inputs.size(0)
        test_pampjpe += pampjpe.item() * inputs.size(0)
    test_mpjpe = test_mpjpe/len(tensor_loader.dataset)
    test_pampjpe = test_pampjpe/len(tensor_loader.dataset)
    test_mse = test_mse/len(tensor_loader.dataset)
    print("mse: {:.8f}, mpjpe: {:.8f}, pampjpe: {:.8f}".format(float(test_mse), float(test_mpjpe),float(test_pampjpe)))
    return test_mpjpe

In [21]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, train_criterion, test_criterion, device,
          parameter_dir='./RGB_benchmark/rgb_ResNet18/RGB_Resnet50.pt', eval_every=5):
    '''
    Train `model` with AdamW and checkpoint the best-validation-MPJPE weights.

    Args:
        model: network to train (expected to already live on `device`).
        train_loader: DataLoader yielding (inputs, kpts) training batches.
        test_loader: DataLoader passed to `test` for validation.
        num_epochs: number of passes over `train_loader`.
        learning_rate: AdamW learning rate.
        train_criterion: torch loss used for optimisation (also reported by `test`).
        test_criterion: metric callable forwarded to `test` as `criterion2`.
        device: device for inputs, labels and the loss computation.
        parameter_dir: file the best `state_dict` is written to; its parent
            directory is created if missing.
        eval_every: validate every `eval_every` epochs; 0 disables validation.

    Returns:
        None.
    '''
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Start from +inf so the first validation always writes a checkpoint,
    # regardless of the metric's scale.
    best_test_mpjpe = float('inf')
    for epoch in range(num_epochs):
        model.train()  # `test` switches the model to eval mode
        epoch_loss = 0.0
        for inputs, kpts in tqdm(train_loader):
            # Keep labels and outputs on `device`: no host round-trip per step.
            inputs = inputs.to(device)
            labels = kpts.to(device).float()

            optimizer.zero_grad()
            outputs = model(inputs).float()
            loss = train_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * inputs.size(0)
        epoch_loss = epoch_loss / len(train_loader.dataset)
        print('Epoch: {}, Loss: {:.8f}'.format(epoch, epoch_loss))
        if eval_every and (epoch + 1) % eval_every == 0:
            test_mpjpe = test(
                model=model,
                tensor_loader=test_loader,
                criterion1=train_criterion,
                criterion2=test_criterion,
                device=device,
            )
            if test_mpjpe <= best_test_mpjpe:
                print(f"best test mpjpe is:{test_mpjpe}")
                best_test_mpjpe = test_mpjpe
                save_dir = os.path.dirname(parameter_dir)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), parameter_dir)
    return

In [22]:

# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


In [23]:
# Training hyper-parameters.
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3

train_criterion = nn.MSELoss()
test_criterion = error
# `device` is chosen once in the device-selection cell above and reused here.
# To resume, load a checkpoint saved from this same architecture, e.g.:
# model.load_state_dict(torch.load('./RGB_benchmark/rgb_ResNet18/RGB_Resnet50.pt'))
model.to(device)
train(
    model=model,
    train_loader=train_loader,
    test_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    train_criterion=train_criterion,
    test_criterion=test_criterion,
    device=device,
)

100%|██████████| 8019/8019 [4:58:13<00:00,  2.23s/it]   


Epoch: 0, Loss: 0.00964679


100%|██████████| 8019/8019 [4:39:23<00:00,  2.09s/it]   


Epoch: 1, Loss: 0.00269378


100%|██████████| 8019/8019 [4:36:00<00:00,  2.07s/it]   


Epoch: 2, Loss: 0.00172173


100%|██████████| 8019/8019 [4:32:29<00:00,  2.04s/it]   


Epoch: 3, Loss: 0.00130108


100%|██████████| 8019/8019 [4:32:10<00:00,  2.04s/it]   


Epoch: 4, Loss: 0.00118844


100%|██████████| 2005/2005 [48:01<00:00,  1.44s/it] 


mse: 0.04068223, mpjpe: 0.30138793, pampjpe: 0.13811177
best test mpjpe is:0.3013879332790382


100%|██████████| 8019/8019 [5:23:17<00:00,  2.42s/it]   


Epoch: 5, Loss: 0.00098241


100%|██████████| 8019/8019 [4:34:22<00:00,  2.05s/it]  


Epoch: 6, Loss: 0.00101449


100%|██████████| 8019/8019 [4:34:06<00:00,  2.05s/it]   


Epoch: 7, Loss: 0.00185997


100%|██████████| 8019/8019 [4:38:56<00:00,  2.09s/it]   


Epoch: 8, Loss: 0.00100563


100%|██████████| 8019/8019 [4:19:49<00:00,  1.94s/it]   


Epoch: 9, Loss: 0.00088958


100%|██████████| 2005/2005 [1:15:36<00:00,  2.26s/it]


mse: 0.06352119, mpjpe: 0.40035685, pampjpe: 0.14821081


100%|██████████| 8019/8019 [5:28:05<00:00,  2.45s/it]   


Epoch: 10, Loss: 0.00083302


100%|██████████| 8019/8019 [4:48:02<00:00,  2.16s/it]   


Epoch: 11, Loss: 0.00080005


100%|██████████| 8019/8019 [4:49:54<00:00,  2.17s/it]   


Epoch: 12, Loss: 0.00074672


100%|██████████| 8019/8019 [4:45:50<00:00,  2.14s/it]   


Epoch: 13, Loss: 0.00071830


100%|██████████| 8019/8019 [4:41:29<00:00,  2.11s/it]   


Epoch: 14, Loss: 0.00069410


100%|██████████| 2005/2005 [1:17:54<00:00,  2.33s/it]


mse: 0.05702450, mpjpe: 0.36301359, pampjpe: 0.16099268


100%|██████████| 8019/8019 [5:35:19<00:00,  2.51s/it]   


Epoch: 15, Loss: 0.00068262


100%|██████████| 8019/8019 [4:45:36<00:00,  2.14s/it]   


Epoch: 16, Loss: 0.00066559


100%|██████████| 8019/8019 [4:51:10<00:00,  2.18s/it]   


Epoch: 17, Loss: 0.00065702


100%|██████████| 8019/8019 [4:52:31<00:00,  2.19s/it]   


Epoch: 18, Loss: 0.00064002


100%|██████████| 8019/8019 [4:37:35<00:00,  2.08s/it]  


Epoch: 19, Loss: 0.00062739


100%|██████████| 2005/2005 [1:10:24<00:00,  2.11s/it]


mse: 0.45956184, mpjpe: 1.15213124, pampjpe: 0.19156851


100%|██████████| 8019/8019 [5:25:17<00:00,  2.43s/it]   


Epoch: 20, Loss: 0.00062520


100%|██████████| 8019/8019 [4:36:08<00:00,  2.07s/it]   


Epoch: 21, Loss: 0.00061900


100%|██████████| 8019/8019 [4:36:20<00:00,  2.07s/it]  


Epoch: 22, Loss: 0.00061508


100%|██████████| 8019/8019 [4:35:25<00:00,  2.06s/it]   


Epoch: 23, Loss: 0.00060704


100%|██████████| 8019/8019 [3:45:21<00:00,  1.69s/it]  


Epoch: 24, Loss: 0.00060335


100%|██████████| 2005/2005 [28:04<00:00,  1.19it/s]


mse: 0.43183633, mpjpe: 1.10675341, pampjpe: 0.20749692


100%|██████████| 8019/8019 [3:01:00<00:00,  1.35s/it]  


Epoch: 25, Loss: 0.00060003


100%|██████████| 8019/8019 [2:49:01<00:00,  1.26s/it]  


Epoch: 26, Loss: 0.00058681


100%|██████████| 8019/8019 [2:52:39<00:00,  1.29s/it]  


Epoch: 27, Loss: 0.00059806


100%|██████████| 8019/8019 [2:54:22<00:00,  1.30s/it]  


Epoch: 28, Loss: 0.00059475


100%|██████████| 8019/8019 [2:56:02<00:00,  1.32s/it]  


Epoch: 29, Loss: 0.00058091


100%|██████████| 2005/2005 [27:51<00:00,  1.20it/s]


mse: 0.05066622, mpjpe: 0.34670687, pampjpe: 0.15144829


100%|██████████| 8019/8019 [3:13:40<00:00,  1.45s/it]  


Epoch: 30, Loss: 0.00057391


100%|██████████| 8019/8019 [2:57:08<00:00,  1.33s/it]  


Epoch: 31, Loss: 0.00056215


100%|██████████| 8019/8019 [3:01:39<00:00,  1.36s/it]  


Epoch: 32, Loss: 0.00057507


100%|██████████| 8019/8019 [3:03:50<00:00,  1.38s/it]  


Epoch: 33, Loss: 0.00057088


100%|██████████| 8019/8019 [3:03:08<00:00,  1.37s/it]  


Epoch: 34, Loss: 0.00055848


100%|██████████| 2005/2005 [28:45<00:00,  1.16it/s]


mse: 0.07316065, mpjpe: 0.42770590, pampjpe: 0.15525877


100%|██████████| 8019/8019 [3:21:11<00:00,  1.51s/it]  


Epoch: 35, Loss: 0.00055530


100%|██████████| 8019/8019 [3:06:08<00:00,  1.39s/it]  


Epoch: 36, Loss: 0.00055550


100%|██████████| 8019/8019 [3:08:25<00:00,  1.41s/it]  


Epoch: 37, Loss: 0.00055336


100%|██████████| 8019/8019 [3:11:01<00:00,  1.43s/it]  


Epoch: 38, Loss: 0.00054867


100%|██████████| 8019/8019 [3:13:12<00:00,  1.45s/it]  


Epoch: 39, Loss: 0.00056164


100%|██████████| 2005/2005 [29:27<00:00,  1.13it/s]


mse: 0.20342006, mpjpe: 0.65906470, pampjpe: 0.25772225


100%|██████████| 8019/8019 [3:36:10<00:00,  1.62s/it]  


Epoch: 40, Loss: 0.00056649


100%|██████████| 8019/8019 [3:36:18<00:00,  1.62s/it]  


Epoch: 41, Loss: 0.00059600


  3%|▎         | 207/8019 [04:56<3:20:16,  1.54s/it]

In [17]:
# Save the current (last-epoch, not necessarily best) weights of the ResNet50 model.
# The distinct file name avoids two problems: labelling ResNet50 weights as
# "Resnet18", and overwriting the best-epoch checkpoint that `train` writes
# to RGB_Resnet50.pt.
parameter_dir = './RGB_benchmark/rgb_ResNet18/RGB_Resnet50_last.pt'
torch.save(model.state_dict(), parameter_dir)