In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
from mmfi import make_dataset, make_dataloader

### Loda Data

In [2]:
dir = 'd:\\Data\\MMFi_Dataset\\E01\\S02\\A01\\depth\\frame007.png'
_mod, _frame = os.path.split(dir)
print(_mod, _frame)

d:\Data\MMFi_Dataset\E01\S02\A01\depth frame007.png


In [2]:
dataset_root = 'd:\Data\My_MMFi_Data\MMFi_Dataset'
# xian zai shi yong de shi Radar_Fused
with open('config.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 'A13', 'A16', 'A17', 'A18', 'A19', 'A21', 'A23',

In [3]:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    kpts = []
    [kpts.append(np.array(t['output'])) for t in batch]
    kpts = torch.FloatTensor(np.array(kpts))

    lengths = torch.tensor([t['input_depth'].shape[0] for t in batch ])

    # # rgb
    # rgb_data = np.array([(t['input_rgb']) for t in batch ])
    # rgb_data = torch.FloatTensor(rgb_data).permute(0,3,1,2)

    # depth
    depth_data = np.array([(t['input_depth']) for t in batch ])
    depth_data = torch.FloatTensor(depth_data).permute(0,3,1,2)

    # # mmwave
    # ## padd
    # mmwave_data = [torch.Tensor(t['input_mmwave']) for t in batch ]
    # mmwave_data = torch.nn.utils.rnn.pad_sequence(mmwave_data)
    # ## compute mask
    # mmwave_data = mmwave_data.permute(1,0,2)

    # # lidar
    # ## padd
    # lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
    # lidar_data = torch.nn.utils.rnn.pad_sequence(lidar_data)
    # ## compute mask
    # lidar_data = lidar_data.permute(1,0,2)

    # # wifi-csi
    # wifi_data = np.array([(t['input_wifi-csi']) for t in batch ])
    # wifi_data = torch.FloatTensor(wifi_data)

    # return rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, lengths
    return depth_data, kpts

In [4]:
rng_generator = torch.manual_seed(config['init_rand_seed'])
train_loader = make_dataloader(train_dataset, is_training=True, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)
val_loader = make_dataloader(val_dataset, is_training=False, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)

In [14]:
print(len(train_loader), len(val_loader))
for i, data in enumerate(train_loader):
    # rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, lengths = data
    # print(rgb_data[0].shape, depth_data[0].shape, lidar_data[0].shape, mmwave_data[0].shape, wifi_data[0].shape,kpts.shape, lengths.shape)
    # print(rgb_data.shape, depth_data.shape, lidar_data.shape, mmwave_data.shape, wifi_data.shape, kpts.shape, lengths.shape)
    depth_data, label = data
    print(depth_data.shape, label.shape)
    break

8019 2005
torch.Size([32, 3, 480, 640]) torch.Size([32, 17, 3])


### Model

In [12]:
from depth_benchmark.depth_ResNet18 import *

model = Depth_ResNet18()

# x = torch.randn(32, 100, 5, 3)
# all_domain_kpts, all_domain_outputs = model(x)
# print(all_domain_kpts[0].shape, all_domain_outputs[1][4].shape)

### training 

In [15]:
def test(model, tensor_loader, criterion1, criterion2, device):
    model.eval()
    test_mpjpe = 0
    test_pampjpe = 0
    test_mse = 0
    for i, data in enumerate(tensor_loader):
        depth_data, kpts = data
        inputs = depth_data.to(device)
        kpts.to(device)
        labels = kpts.type(torch.FloatTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        test_mse += criterion1(outputs,labels).item() * inputs.size(0)

        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
        
        mpjpe, pampjpe = criterion2(outputs,labels)
        test_mpjpe += mpjpe.item() * inputs.size(0)
        test_pampjpe += pampjpe.item() * inputs.size(0)
        if i > 10:
            break
    test_mpjpe = test_mpjpe/len(tensor_loader.dataset)
    test_pampjpe = test_pampjpe/len(tensor_loader.dataset)
    test_mse = test_mse/len(tensor_loader.dataset)
    print("mse: {:.8f}, mpjpe: {:.8f}, pampjpe: {:.8f}".format(float(test_mse), float(test_mpjpe),float(test_pampjpe)))
    return test_mpjpe

In [16]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, train_criterion, test_criterion, device):
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
    # optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[20,40],gamma=0.1)
    parameter_dir = './depth_benchmark/depth_Resnet18.pt'
    best_test_mpjpe = 100
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0
        for i, data in enumerate(train_loader):
            depth_data, kpts = data
            inputs = depth_data.to(device)
            labels = kpts.to(device)
            labels = labels.type(torch.FloatTensor)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.to(device)
            outputs = outputs.type(torch.FloatTensor)
            loss = train_criterion(outputs,labels)
            loss.backward()
            # print(length)
            # print("loss is ", loss.item())
            optimizer.step()
            
            epoch_loss += loss.item() * inputs.size(0)
            # print("epoch loss is ", epoch_loss)
            if i > 10:
                break   
        epoch_loss = epoch_loss/len(train_loader.dataset)
        print('Epoch: {}, Loss: {:.8f}'.format(epoch, epoch_loss))
        if (epoch+1) % 5 == 0:
            test_mpjpe = test(
                model=model,
                tensor_loader=test_loader,
                criterion1 = train_criterion,
                criterion2 = test_criterion,
                device= device
            )
            if test_mpjpe <= best_test_mpjpe:
                print(f"best test mpjpe is:{test_mpjpe}")
                best_test_mpjpe = test_mpjpe
                torch.save(model.state_dict(), parameter_dir)
        # scheduler.step()
    return

In [17]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) 

cuda:0


In [18]:
train_criterion = nn.MSELoss()
test_criterion = error
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load('./pre-train_weights/lidar_p1_random.pt'))
model.to(device)
train(
    model=model, 
    train_loader= train_loader,
    test_loader= val_loader,    
    num_epochs= 50,
    learning_rate=1e-3,
    train_criterion = train_criterion,
    test_criterion = test_criterion,
    device=device 
    )

Epoch: 0, Loss: 0.00190598
Epoch: 1, Loss: 0.00015393
Epoch: 2, Loss: 0.00007392
Epoch: 3, Loss: 0.00005242
Epoch: 4, Loss: 0.00004551
mse: 0.00013087, mpjpe: 0.00143339, pampjpe: 0.00070077
best test mpjpe is:0.0014333886121451847
Epoch: 5, Loss: 0.00004022


KeyboardInterrupt: 