In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
from mmfi_mmwave import make_dataset, make_dataloader

### Load Data

In [2]:
# Scratch check: see how os.path.split divides a sample path into
# (containing directory, file name).
# The path is not stored in a variable called `dir`, because that would shadow
# the builtin dir() for every later cell in this kernel.
sample_path = 'd:\\Data\\MMFi_Dataset\\E01\\S02\\A01\\depth\\frame007.png'
modality_dir, frame_name = os.path.split(sample_path)
print(modality_dir, frame_name)

d:\Data\MMFi_Dataset\E01\S02\A01\depth frame007.png


In [2]:
# Raw string: the Windows path contains backslashes (\D, \M) that are not valid
# escape sequences; a raw literal keeps the exact same value without relying on
# Python tolerating unknown escapes.
dataset_root = r'd:\Data\My_MMFi_Data\MMFi_Dataset'
# Currently using Radar_Fused.
# NOTE(review): yaml.FullLoader is kept as-is; if config_copy.yaml only holds
# plain scalars/lists/dicts, yaml.safe_load would be the stricter choice — confirm.
with open('config_copy.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 'A13', 'A16', 'A17', 'A18', 'A19', 'A21', 'A23',

In [3]:
def collate_fn_padd(batch):
    '''
    Collate variable-length samples into one zero-padded, batch-first tensor.

    Every sample is a dict. The target keypoints are read from ``'output'``;
    the input sequence is read from the *last* key of the first sample's dict,
    and that same key is used for every sample in the batch.

    note: tensors are built by hand here since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.

    Args:
        batch: non-empty list of sample dicts (an empty list raises IndexError).

    Returns:
        Tuple ``(padded, kpts, lengths, mask)``:
            padded  -- float32 tensor, batch first; shorter sequences are
                       zero-padded along dim 1 up to the longest one
            kpts    -- float32 tensor stacking each sample's ``'output'``
            lengths -- tensor holding each sample's un-padded length (dim 0)
            mask    -- bool tensor shaped like ``padded``; True where the value
                       is non-zero, so genuine zeros in the data are masked too
    '''
    # The modality to feed the model is whatever key comes last in the sample dict.
    input_key = list(batch[0].keys())[-1]

    # Targets: stack into a single array first so the tensor is built in one go.
    kpts = torch.as_tensor(
        np.array([np.array(sample['output']) for sample in batch]),
        dtype=torch.float32,
    )

    # Inputs: one float32 tensor per sample, each with its own length.
    sequences = [torch.as_tensor(sample[input_key], dtype=torch.float32) for sample in batch]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])

    # batch_first=True yields (batch, max_len, ...) directly, with no permute
    # that would tie the function to exactly three dimensions.
    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    mask = (padded != 0)

    return padded, kpts, lengths, mask

In [4]:
# Seed torch's global RNG from the config and reuse the returned generator for
# both loaders; everything except `is_training` is shared between them.
rng_generator = torch.manual_seed(config['init_rand_seed'])
loader_kwargs = dict(
    generator=rng_generator,
    collate_fn=collate_fn_padd,
    **config['loader'],
)
train_loader = make_dataloader(train_dataset, is_training=True, **loader_kwargs)
val_loader = make_dataloader(val_dataset, is_training=False, **loader_kwargs)

In [5]:
# Pull a single batch from the training loader and show the shapes of the
# padded inputs and the keypoint labels, then stop.
for i, batch in enumerate(train_loader):
    # The collate function returns (inputs, keypoints, lengths, mask);
    # only the first two are inspected here.
    data, label, _, _ = batch
    print(data.shape, label.shape)
    break

torch.Size([32, 66, 5]) torch.Size([32, 17, 3])


### Model

In [12]:
# Import only the class that is actually used. A wildcard import would pull
# every public name of the module into the notebook namespace and could
# silently overwrite names defined by earlier cells.
from mmwave_point_transformer_TD import mmwave_PointTransformerReg

model = mmwave_PointTransformerReg()

In [13]:
# Smoke test: run a random batch-first input through the freshly built model
# and print the shape of what comes back.
batch_size, seq_len, feat_dim = 32, 43, 5
a = torch.randn(batch_size, seq_len, feat_dim)
b = model(a)
print(b.shape)

torch.Size([32, 17, 3])


### Training

In [8]:
def test(model, tensor_loader, criterion1, criterion2, device):
    model.eval()
    test_mpjpe = 0
    test_pampjpe = 0
    test_mse = 0
    for data in tqdm(tensor_loader):
        inputs, labels, _, _ = data
        inputs = inputs.to(device)
        labels.to(device)
        labels = labels.type(torch.FloatTensor)
        outputs = model(inputs)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        test_mse += criterion1(outputs,labels).item() * inputs.size(0)

        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
        
        mpjpe, pampjpe = criterion2(outputs,labels)
        test_mpjpe += mpjpe.item() * inputs.size(0)
        test_pampjpe += pampjpe.item() * inputs.size(0)
    test_mpjpe = test_mpjpe/len(tensor_loader.dataset)
    test_pampjpe = test_pampjpe/len(tensor_loader.dataset)
    test_mse = test_mse/len(tensor_loader.dataset)
    print("mse: {:.8f}, mpjpe: {:.8f}, pampjpe: {:.8f}".format(float(test_mse), float(test_mpjpe),float(test_pampjpe)))
    return test_mpjpe

In [14]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, train_criterion, test_criterion, device,
          parameter_dir='./mmwave_benchmark/mmwave_all_random_TD(1).pt', eval_every=5):
    """Train ``model`` with AdamW, evaluating periodically and keeping the best weights.

    After every ``eval_every`` epochs the model is evaluated with ``test``; when
    the returned value is less than or equal to the best one seen so far, the
    model's ``state_dict`` is written to ``parameter_dir``.

    Args:
        model: module to optimise; expected to already live on ``device``.
        train_loader: iterable yielding 4-tuples ``(inputs, labels, _, _)``.
        test_loader: loader handed to ``test`` for the periodic evaluation.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: learning rate given to ``torch.optim.AdamW``.
        train_criterion: callable ``(outputs, labels) -> scalar loss tensor``;
            also passed to ``test`` as ``criterion1``.
        test_criterion: passed to ``test`` as ``criterion2``.
        device: device that inputs and labels are moved to.
        parameter_dir: file path of the best checkpoint (default is the
            previously hard-coded location).
        eval_every: positive number of epochs between evaluations (default 5,
            the previously hard-coded interval).

    Returns:
        None.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Create the checkpoint folder up front so a missing directory cannot make
    # torch.save fail only after several epochs of training.
    checkpoint_folder = os.path.dirname(parameter_dir)
    if checkpoint_folder:
        os.makedirs(checkpoint_folder, exist_ok=True)

    # Start from +inf so the first evaluation always produces a checkpoint,
    # whatever scale the metric is reported in.
    best_test_mpjpe = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_samples = 0
        for data in tqdm(train_loader):
            inputs, labels, _, _ = data
            inputs = inputs.to(device)
            # Cast with .float() so the tensors stay on `device`; casting via
            # torch.FloatTensor would drag them back to the CPU every step.
            labels = labels.to(device).float()

            optimizer.zero_grad()
            outputs = model(inputs).float()
            loss = train_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight by batch size so the
            # epoch figure is a per-sample average.
            batch_size = inputs.size(0)
            epoch_loss += loss.item() * batch_size
            num_samples += batch_size

        # Average over the samples actually processed (guarding an empty loader).
        epoch_loss = epoch_loss / max(num_samples, 1)
        print('Epoch: {}, Loss: {:.8f}'.format(epoch, epoch_loss))

        if (epoch + 1) % eval_every == 0:
            test_mpjpe = test(
                model=model,
                tensor_loader=test_loader,
                criterion1=train_criterion,
                criterion2=test_criterion,
                device=device,
            )
            if test_mpjpe <= best_test_mpjpe:
                print(f"best test mpjpe is:{test_mpjpe}")
                best_test_mpjpe = test_mpjpe
                torch.save(model.state_dict(), parameter_dir)
    return

In [2]:

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [16]:
# Loss used for optimisation and the metric function used for evaluation.
train_criterion = nn.MSELoss()
test_criterion = error

# Pick the device and move the model onto it before training starts.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Gather the run settings in one place, then launch the training loop.
train_settings = dict(
    model=model,
    train_loader=train_loader,
    test_loader=val_loader,
    num_epochs=50,
    learning_rate=1e-3,
    train_criterion=train_criterion,
    test_criterion=test_criterion,
    device=device,
)
train(**train_settings)

  0%|          | 0/6495 [00:00<?, ?it/s]

100%|██████████| 6495/6495 [07:02<00:00, 15.39it/s]  


Epoch: 0, Loss: 0.01591376


100%|██████████| 6495/6495 [03:21<00:00, 32.23it/s] 


Epoch: 1, Loss: 0.00859070


100%|██████████| 6495/6495 [03:11<00:00, 33.94it/s]


Epoch: 2, Loss: 0.00708109


100%|██████████| 6495/6495 [03:11<00:00, 33.90it/s]


Epoch: 3, Loss: 0.00634870


100%|██████████| 6495/6495 [03:11<00:00, 33.90it/s]


Epoch: 4, Loss: 0.00577423


100%|██████████| 1702/1702 [00:48<00:00, 34.89it/s]


mse: 0.00622417, mpjpe: 0.11472560, pampjpe: 0.06042530
best test mpjpe is:0.11472559812785546


100%|██████████| 6495/6495 [03:07<00:00, 34.56it/s]


Epoch: 5, Loss: 0.00535929


100%|██████████| 6495/6495 [03:08<00:00, 34.43it/s]


Epoch: 6, Loss: 0.00504343


100%|██████████| 6495/6495 [03:08<00:00, 34.54it/s]


Epoch: 7, Loss: 0.00480117


100%|██████████| 6495/6495 [03:08<00:00, 34.40it/s]


Epoch: 8, Loss: 0.00461437


100%|██████████| 6495/6495 [03:07<00:00, 34.55it/s]


Epoch: 9, Loss: 0.00445668


100%|██████████| 1702/1702 [00:40<00:00, 42.43it/s]


mse: 0.00585786, mpjpe: 0.11190700, pampjpe: 0.05565229
best test mpjpe is:0.11190700365207754


100%|██████████| 6495/6495 [03:07<00:00, 34.63it/s]


Epoch: 10, Loss: 0.00432717


100%|██████████| 6495/6495 [03:08<00:00, 34.47it/s]


Epoch: 11, Loss: 0.00422894


100%|██████████| 6495/6495 [03:07<00:00, 34.57it/s]


Epoch: 12, Loss: 0.00409982


100%|██████████| 6495/6495 [03:08<00:00, 34.49it/s]


Epoch: 13, Loss: 0.00404086


100%|██████████| 6495/6495 [03:07<00:00, 34.58it/s]


Epoch: 14, Loss: 0.00395506


100%|██████████| 1702/1702 [00:39<00:00, 42.70it/s]


mse: 0.00570056, mpjpe: 0.10915165, pampjpe: 0.05564439
best test mpjpe is:0.10915165436045457


100%|██████████| 6495/6495 [03:09<00:00, 34.21it/s]


Epoch: 15, Loss: 0.00387668


100%|██████████| 6495/6495 [03:11<00:00, 34.00it/s]


Epoch: 16, Loss: 0.00383079


100%|██████████| 6495/6495 [03:37<00:00, 29.81it/s]


Epoch: 17, Loss: 0.00377115


100%|██████████| 6495/6495 [07:09<00:00, 15.13it/s]  


Epoch: 18, Loss: 0.00373067


100%|██████████| 6495/6495 [03:36<00:00, 29.94it/s]


Epoch: 19, Loss: 0.00369465


100%|██████████| 1702/1702 [00:39<00:00, 42.84it/s]


mse: 0.00562897, mpjpe: 0.10991267, pampjpe: 0.05388967


100%|██████████| 6495/6495 [03:08<00:00, 34.44it/s]


Epoch: 20, Loss: 0.00363856


100%|██████████| 6495/6495 [04:03<00:00, 26.62it/s]


Epoch: 21, Loss: 0.00358320


100%|██████████| 6495/6495 [03:13<00:00, 33.48it/s]


Epoch: 22, Loss: 0.00356208


100%|██████████| 6495/6495 [03:08<00:00, 34.43it/s]


Epoch: 23, Loss: 0.00352962


100%|██████████| 6495/6495 [03:09<00:00, 34.35it/s]


Epoch: 24, Loss: 0.00348399


100%|██████████| 1702/1702 [00:48<00:00, 35.16it/s]


mse: 0.00556113, mpjpe: 0.10797129, pampjpe: 0.05513898
best test mpjpe is:0.10797129311865866


100%|██████████| 6495/6495 [03:08<00:00, 34.46it/s]


Epoch: 25, Loss: 0.00347442


100%|██████████| 6495/6495 [03:21<00:00, 32.31it/s]


Epoch: 26, Loss: 0.00344428


100%|██████████| 6495/6495 [03:25<00:00, 31.62it/s]


Epoch: 27, Loss: 0.00341463


100%|██████████| 6495/6495 [03:25<00:00, 31.62it/s]


Epoch: 28, Loss: 0.00339544


100%|██████████| 6495/6495 [03:25<00:00, 31.66it/s]


Epoch: 29, Loss: 0.00336100


100%|██████████| 1702/1702 [00:42<00:00, 40.14it/s]


mse: 0.00551999, mpjpe: 0.10735389, pampjpe: 0.05376087
best test mpjpe is:0.10735388970091125


100%|██████████| 6495/6495 [03:25<00:00, 31.63it/s]


Epoch: 30, Loss: 0.00332704


100%|██████████| 6495/6495 [03:24<00:00, 31.78it/s]


Epoch: 31, Loss: 0.00332864


100%|██████████| 6495/6495 [03:26<00:00, 31.50it/s]


Epoch: 32, Loss: 0.00329157


100%|██████████| 6495/6495 [03:25<00:00, 31.57it/s]


Epoch: 33, Loss: 0.00329050


100%|██████████| 6495/6495 [03:26<00:00, 31.51it/s]


Epoch: 34, Loss: 0.00329816


100%|██████████| 1702/1702 [00:42<00:00, 40.22it/s]


mse: 0.00569757, mpjpe: 0.10970360, pampjpe: 0.05436388


100%|██████████| 6495/6495 [03:26<00:00, 31.51it/s]


Epoch: 35, Loss: 0.00324641


100%|██████████| 6495/6495 [03:25<00:00, 31.53it/s]


Epoch: 36, Loss: 0.00324567


100%|██████████| 6495/6495 [05:04<00:00, 21.30it/s]


Epoch: 37, Loss: 0.00323496


100%|██████████| 6495/6495 [06:45<00:00, 16.02it/s]  


Epoch: 38, Loss: 0.00320127


100%|██████████| 6495/6495 [03:23<00:00, 31.88it/s]


Epoch: 39, Loss: 0.00320334


100%|██████████| 1702/1702 [00:43<00:00, 39.48it/s]


mse: 0.00678205, mpjpe: 0.11931880, pampjpe: 0.06032864


100%|██████████| 6495/6495 [03:22<00:00, 32.04it/s]


Epoch: 40, Loss: 0.00319916


100%|██████████| 6495/6495 [03:24<00:00, 31.72it/s]


Epoch: 41, Loss: 0.00317225


100%|██████████| 6495/6495 [03:26<00:00, 31.39it/s]


Epoch: 42, Loss: 0.00315993


100%|██████████| 6495/6495 [03:26<00:00, 31.41it/s]


Epoch: 43, Loss: 0.00316245


100%|██████████| 6495/6495 [03:28<00:00, 31.12it/s]


Epoch: 44, Loss: 0.00317033


100%|██████████| 1702/1702 [00:43<00:00, 39.57it/s]


mse: 0.00542607, mpjpe: 0.10680229, pampjpe: 0.05372254
best test mpjpe is:0.10680228873008475


100%|██████████| 6495/6495 [03:25<00:00, 31.54it/s]


Epoch: 45, Loss: 0.00315111


100%|██████████| 6495/6495 [03:26<00:00, 31.51it/s]


Epoch: 46, Loss: 0.00313706


100%|██████████| 6495/6495 [03:25<00:00, 31.64it/s]


Epoch: 47, Loss: 0.00311078


100%|██████████| 6495/6495 [03:27<00:00, 31.35it/s]


Epoch: 48, Loss: 0.00310118


100%|██████████| 6495/6495 [03:27<00:00, 31.31it/s]


Epoch: 49, Loss: 0.03812058


100%|██████████| 1702/1702 [00:42<00:00, 40.13it/s]

mse: 0.01813097, mpjpe: 0.19555041, pampjpe: 0.10229196





In [17]:
# Persist the weights the model holds after the last training epoch.
# NOTE(review): the destination sits under RGB_benchmark/rgb_ResNet18, while
# `model` is an mmwave_PointTransformerReg — confirm this path is intended
# and does not overwrite an unrelated checkpoint.
parameter_dir = './RGB_benchmark/rgb_ResNet18/RGB_Resnet18.pt'
final_state = model.state_dict()
torch.save(final_state, parameter_dir)