In [1]:
import numpy as np
import glob
import scipy.io as sio
import torch
from torch import nn
import random
import csv
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import yaml
from evaluate import error
import time
import re
from syn_DI_dataset import make_dataset, make_dataloader

In [2]:
torch.manual_seed(3407)

<torch._C.Generator at 0x13eee223ef0>

### Loda Data

In [3]:
dir = 'd:\\Data\\MMFi_Dataset\\E01\\S02\\A01\\depth\\frame007.png'
_mod, _frame = os.path.split(dir)
print(_mod, _frame)

d:\Data\MMFi_Dataset\E01\S02\A01\depth frame007.png


In [3]:
dataset_root = 'd:\Data\My_MMFi_Data\MMFi_Dataset'
# xian zai shi yong de shi Radar_Fused
with open('config.yaml', 'r') as fd:
    config = yaml.load(fd, Loader=yaml.FullLoader)

train_dataset, val_dataset = make_dataset(dataset_root, config)

['rgb', 'depth', 'lidar', 'mmwave', 'wifi-csi']
S02 ['A01', 'A02', 'A03', 'A04', 'A06', 'A08', 'A09', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A26']
S03 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A12', 'A15', 'A16', 'A17', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S05 ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A09', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A19', 'A20', 'A21', 'A24', 'A25']
S06 ['A01', 'A03', 'A04', 'A06', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A18', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S08 ['A01', 'A02', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10', 'A11', 'A13', 'A14', 'A17', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27']
S09 ['A01', 'A07', 'A08', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15', 'A17', 'A18', 'A20', 'A21', 'A24', 'A25', 'A27']
S11 ['A01', 'A02', 'A03', 'A05', 'A06', 'A08', 'A10', 'A11', 'A12', 

In [4]:
def generate_none_empth_modality_list():
    modality_list = random.choices(
        [True, False],
        k= 3,
        weights=[50, 50]
    )
    lidar_ = random.choices(
        [True, False],
        k= 1,
        weights=[70, 30]
    )
    wifi_ = random.choices(
        [True, False],
        k= 1,
        weights=[70, 30]
    )
    final_list = modality_list + lidar_ + wifi_
    if sum(final_list) == 0:
        final_list = generate_none_empth_modality_list()
        return final_list
    else:
        return final_list


def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.

    dict_keys(['modality', 'scene', 'subject', 'action', 'idx', 'output', 
    'input_rgb', 'input_depth', 'input_lidar', 'input_mmwave'])
    '''
    ## get sequence lengths
    # for t in batch:
    #     print(t.keys())
    #     print(a)
    # #     # print(t[0].shape,t[1].shape)
    kpts = []
    [kpts.append(np.array(t['output'])) for t in batch]
    kpts = torch.FloatTensor(np.array(kpts))

    lengths = torch.tensor([t['input_mmwave'].shape[0] for t in batch ])

    # rgb
    rgb_data = np.array([(t['input_rgb']) for t in batch ])
    rgb_data = torch.FloatTensor(rgb_data).permute(0,3,1,2)

    # depth
    depth_data = np.array([(t['input_depth']) for t in batch ])
    depth_data = torch.FloatTensor(depth_data).permute(0,3,1,2)

    # mmwave
    ## padd
    mmwave_data = [torch.Tensor(t['input_mmwave']) for t in batch ]
    mmwave_data = torch.nn.utils.rnn.pad_sequence(mmwave_data)
    ## compute mask
    mmwave_data = mmwave_data.permute(1,0,2)

    # lidar
    ## padd
    lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]
    lidar_data = torch.nn.utils.rnn.pad_sequence(lidar_data)
    ## compute mask
    lidar_data = lidar_data.permute(1,0,2)

    # wifi-csi
    wifi_data = np.array([(t['input_wifi-csi']) for t in batch ])
    wifi_data = torch.FloatTensor(wifi_data)
    
    # modality_list = generate_none_empth_modality_list()

    # return rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, modality_list
    return rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, kpts

In [5]:
rng_generator = torch.manual_seed(config['init_rand_seed'])
train_loader = make_dataloader(train_dataset, is_training=True, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)
val_loader = make_dataloader(val_dataset, is_training=False, generator=rng_generator, **config['loader'], collate_fn = collate_fn_padd)

#### check random seed

In [17]:
print(len(train_loader.dataset)/16)

12991.5


In [6]:
avg_time = 0
for epoch in range(3):
    random.seed(config['modality_existances']['train_random_seed'])
    i= 0
    for data in val_loader:
        # start_time = time.time()
        # rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, modality_list = data
        rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, kpts = data
        # epoch_time = time.time() - start_time
        # print(rgb_data[0].shape, depth_data[0].shape, lidar_data[0].shape, mmwave_data[0].shape, wifi_data[0].shape,kpts.shape, lengths.shape)
        # print(rgb_data.shape, depth_data.shape, lidar_data.shape, mmwave_data.shape, wifi_data.shape, kpts.shape)
        print(rgb_data.shape, depth_data.shape, mmwave_data.shape, lidar_data.shape, wifi_data.shape, kpts.shape)
        # print('epoch_time: ', epoch_time)
        # avg_time += epoch_time
        # print(rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, modality_list)
        # print(modality_list)
        i += 1
        if i > 5:
            print('............................................................................................')
            break
    # print('tot_time: ', avg_time)
    # print('avg_time: ', avg_time / (i+1))

  lidar_data = [torch.Tensor(t['input_lidar']) for t in batch ]


torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 25, 5]) torch.Size([16, 1135, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 31, 5]) torch.Size([16, 1144, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 23, 5]) torch.Size([16, 1147, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 32, 5]) torch.Size([16, 1141, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 31, 5]) torch.Size([16, 1138, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
torch.Size([16, 3, 480, 640]) torch.Size([16, 3, 480, 640]) torch.Size([16, 27, 5]) torch.Size([16, 1142, 3]) torch.Size([16, 3, 114, 10]) torch.Size([16, 17, 3])
......................

In [13]:
print(wifi_data[0])
for i in range(len(wifi_data)):
    print([np.isinf(wifi_data[i]).any(), np.isnan(wifi_data[i]).any()])

tensor([[[0.6910, 0.6351, 0.6291,  ..., 0.6200, 0.6925, 0.6778],
         [0.7105, 0.6586, 0.6445,  ..., 0.6356, 0.7155, 0.7018],
         [0.7295, 0.6837, 0.6835,  ..., 0.6633, 0.7371, 0.7330],
         ...,
         [0.6692, 0.6904, 0.6697,  ..., 0.6625, 0.7364, 0.7282],
         [0.6935, 0.6720, 0.6527,  ..., 0.6496, 0.7161, 0.6973],
         [0.7022, 0.6164, 0.6220,  ..., 0.6028, 0.6862, 0.6751]],

        [[0.7283, 0.6402, 0.6356,  ..., 0.6021, 0.6480, 0.6688],
         [0.7375, 0.6516, 0.6239,  ..., 0.6088, 0.6386, 0.6688],
         [0.7589, 0.6125, 0.6246,  ..., 0.6051, 0.6626, 0.6679],
         ...,
         [0.8889, 0.9083, 0.9153,  ..., 0.8729, 0.9158, 0.9366],
         [0.8806, 0.9071, 0.9057,  ..., 0.8696, 0.9087, 0.9247],
         [0.8710, 0.9004, 0.9010,  ..., 0.8649, 0.9038, 0.9193]],

        [[0.7061, 0.7954, 0.7334,  ..., 0.7507, 0.7676, 0.7970],
         [0.7375, 0.8103, 0.7451,  ..., 0.7716, 0.7855, 0.8094],
         [0.7479, 0.8268, 0.7675,  ..., 0.7906, 0.8038, 0.

### calculate modality time

In [None]:
file1 = open('./time.txt', 'r')
Lines = file1.readlines()


### Model

In [7]:
from meta_transformer_model import *

model = meta_transformer()
model_dict = model.state_dict()
# print(model_dict['feature_extractor.rgb_extractor.part.0.0.weight'])
# print(model_dict['encoder.encoder.4.attn.qkv.weight'])
pt_weights = torch.load('metaTransformer_IDRLW.pt')
# print(pt_weights['feature_extractor.rgb_extractor.part.0.0.weight'])
# print(pt_weights['encoder.encoder.4.attn.qkv.weight'])
trained_encoder = {k:v for k, v in pt_weights.items() if k.startswith('encoder')}
model_dict.update(trained_encoder) 
model.load_state_dict(model_dict)
# print(model.state_dict()['feature_extractor.rgb_extractor.part.0.0.weight'])
# print(model.state_dict()['encoder.encoder.4.attn.qkv.weight'])
model.cuda()

# rgb = torch.randn((16, 3, 480, 640)).cuda()
# depth = torch.randn((16, 3, 480, 640)).cuda()
# mmwave = torch.randn((16, 48, 5)).cuda()
# lidar = torch.randn((16, 970, 3)).cuda()
# wifi = torch.randn((16, 3, 114, 10)).cuda()
# # modality_list = ['rgb', 'depth', 'mmwave', 'lidar', 'wifi']
# modality_list = ['rgb', 'depth', 'lidar']
# out = model(rgb,depth,mmwave,lidar,wifi,modality_list)
# print(out.shape)

meta_transformer(
  (feature_extractor): feature_extrator(
    (rgb_extractor): rgb_feature_extractor(
      (part): Sequential(
        (0): Sequential(
          (0): Conv2d(3, 3, kernel_size=(14, 14), stride=(2, 2))
          (1): ReLU()
          (2): Conv2d(3, 3, kernel_size=(5, 56), stride=(1, 1))
          (3): ReLU()
          (4): Conv2d(3, 3, kernel_size=(5, 23), stride=(1, 1))
          (5): ReLU()
          (6): Conv2d(3, 16, kernel_size=(3, 14), stride=(1, 1))
        )
        (1): Conv2d(16, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (5): Sequential(
          (0): Block(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine

In [5]:
'model with pos encoding'
import torch
from MI_model_5 import modality_invariant_model

model = modality_invariant_model()
model.cuda()

# rgb_data = torch.randn(16, 3, 480, 640).cuda()
# depth_data = torch.randn(16, 3, 480, 640).cuda()
# lidar_data = torch.randn(16, 1432, 3).cuda()
# mmwave_data = torch.randn(16, 67, 5).cuda()
# wifi_data = torch.randn(16, 3, 114, 10).cuda()
# # modality_list = [False, True, True, True, False]
# modality_list = [True, True, True, True, True]
# # # # modality_list = [False, False, False, False, False]

# out = model(rgb_data, depth_data,  mmwave_data, lidar_data, wifi_data, modality_list)

# print(out.shape)
# print(model)

modality_invariant_model(
  (feature_extractor): feature_extrator(
    (rgb_extractor): rgb_feature_extractor(
      (part): Sequential(
        (0): Sequential(
          (0): Conv2d(3, 3, kernel_size=(14, 14), stride=(2, 2))
          (1): ReLU()
          (2): Conv2d(3, 3, kernel_size=(5, 56), stride=(1, 1))
          (3): ReLU()
          (4): Conv2d(3, 3, kernel_size=(5, 23), stride=(1, 1))
          (5): ReLU()
          (6): Conv2d(3, 16, kernel_size=(3, 14), stride=(1, 1))
        )
        (1): Conv2d(16, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (5): Sequential(
          (0): Block(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1

In [8]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(pytorch_total_params)
print(model)

67568275
modality_invariant_model(
  (feature_extractor): feature_extrator(
    (rgb_extractor): rgb_feature_extractor(
      (part): Sequential(
        (0): Sequential(
          (0): Conv2d(3, 3, kernel_size=(14, 14), stride=(2, 2))
          (1): ReLU()
          (2): Conv2d(3, 3, kernel_size=(5, 56), stride=(1, 1))
          (3): ReLU()
          (4): Conv2d(3, 3, kernel_size=(5, 23), stride=(1, 1))
          (5): ReLU()
          (6): Conv2d(3, 16, kernel_size=(3, 14), stride=(1, 1))
        )
        (1): Conv2d(16, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (5): Sequential(
          (0): Block(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (batch_norm1): BatchNorm2d(64, eps=1e-05, mom

### training 

In [8]:
def test(model, tensor_loader, criterion1, criterion2, device):
    model.eval()
    test_mpjpe = 0
    test_pampjpe = 0
    test_mse = 0
    random.seed(config['modality_existances']['val_random_seed'])
    for data in tqdm(tensor_loader):
        # rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, modality_list = data
        # rgb_data, depth_data, mmwave_data, kpts = data
        rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, kpts = data
        # t1 = time.time()
        # load_time = t1 - start_time 
        rgb_data = rgb_data.to(device)
        depth_data = depth_data.to(device)
        lidar_data = lidar_data.to(device)
        mmwave_data = mmwave_data.to(device)
        wifi_data = wifi_data.to(device)
        modality_list = ['rgb', 'depth', 'lidar']
        kpts.to(device)
        labels = kpts.type(torch.FloatTensor)
        outputs = model(rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, modality_list)
        outputs = outputs.type(torch.FloatTensor)
        outputs.to(device)
        # t2 = time.time()
        # forward_time = t2 - t1
        test_mse += criterion1(outputs,labels).item() * rgb_data.size(0)

        outputs = outputs.detach().numpy()
        labels = labels.detach().numpy()
        
        mpjpe, pampjpe = criterion2(outputs,labels)
        test_mpjpe += mpjpe.item() * rgb_data.size(0)
        test_pampjpe += pampjpe.item() * rgb_data.size(0)
        # t3 = time.time()
        # record_time = t3 - t2
        # print('load_time: ', load_time)
        # print('forward_time: ', forward_time)
        # print('record_time: ', record_time)
    test_mpjpe = test_mpjpe/len(tensor_loader.dataset)
    test_pampjpe = test_pampjpe/len(tensor_loader.dataset)
    test_mse = test_mse/len(tensor_loader.dataset)
    print("mse: {:.8f}, mpjpe: {:.8f}, pampjpe: {:.8f}".format(float(test_mse), float(test_mpjpe),float(test_pampjpe)))
    return test_mpjpe

In [9]:
train_criterion = nn.MSELoss()
test_criterion = error
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
test(model, val_loader, train_criterion, test_criterion, device)

  0%|          | 9/3403 [00:05<36:54,  1.53it/s]


KeyboardInterrupt: 

In [9]:
def train(model, train_loader, test_loader, num_epochs, learning_rate, train_criterion, test_criterion, device):
    optimizer = torch.optim.AdamW(
        [
                {'params': model.encoder.parameters()},
                {'params': model.regression_head.parameters()}
            ],
        lr = learning_rate
    )
    # optimizer = torch.optim.SGD([
    #             {'params': model.linear_projector.parameters()},
    #             # {'params': model.vit.parameters()}
    #             {'params': model.MIT.parameters()}
    #         ], lr = learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40],gamma=0.5)
    parameter_dir = './metaTransformer_IDL_FTALL.pt'
    best_test_mpjpe = 100
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_accuracy = 0
        random.seed(epoch)
        num_iter = 400
        for i, data in enumerate(tqdm(train_loader)):
            if i < num_iter:
                # rgb_data, depth_data, lidar_data, mmwave_data, wifi_data, kpts, modality_list = data
                rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, kpts = data
                # print('data')
                # print(rgb_data, depth_data, lidar_data, mmwave_data, wifi_data)
                # print('end')
                rgb_data = rgb_data.to(device)
                depth_data = depth_data.to(device)
                lidar_data = lidar_data.to(device)
                mmwave_data = mmwave_data.to(device)
                wifi_data = wifi_data.to(device)
                modality_list = ['rgb', 'depth', 'lidar']
                labels = kpts.to(device)
                labels = labels.type(torch.FloatTensor)
                
                optimizer.zero_grad()
                # outputs = model(rgb_data, depth_data,  mmwave_data, lidar_data,wifi_data, modality_list)
                outputs = model(rgb_data, depth_data, mmwave_data, lidar_data, wifi_data, modality_list)
                # print(outputs)
                outputs = outputs.to(device)
                outputs = outputs.type(torch.FloatTensor)
                loss = train_criterion(outputs,labels)
                if loss == float('nan'):
                    print('nan')
                    print(outputs)
                    print(labels)
                    
                loss.backward()
                # print(length)
                # print("loss is ", loss.item())
                optimizer.step()
                
                epoch_loss += loss.item() * rgb_data.size(0)
            else:
                break
            # print("epoch loss is ", epoch_loss)
        # epoch_loss = epoch_loss/len(train_loader.dataset)
        epoch_loss = epoch_loss/(rgb_data.size(0)*num_iter)
        print('Epoch: {}, Loss: {:.8f}'.format(epoch, epoch_loss))
        if (epoch+1) % 3 == 0:
            test_mpjpe = test(
                model=model,
                tensor_loader=test_loader,
                criterion1 = train_criterion,
                criterion2 = test_criterion,
                device= device
            )
            if test_mpjpe <= best_test_mpjpe:
                print(f"best test mpjpe is:{test_mpjpe}")
                best_test_mpjpe = test_mpjpe
                torch.save(model.state_dict(), parameter_dir)
        # scheduler.step()
    torch.save(model.state_dict(), parameter_dir)
    return

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) 

cuda:0


In [11]:
train_criterion = nn.MSELoss()
test_criterion = error
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(torch.load('./MI_10thJun.pt'))
model.to(device)
train(
    model=model,
    train_loader= train_loader,
     test_loader= val_loader,
    num_epochs= 60,
    learning_rate=1e-3,
    train_criterion = train_criterion, 
    test_criterion = test_criterion,
    device=device 
        )

  3%|▎         | 400/12991 [40:32<21:16:05,  6.08s/it]


Epoch: 0, Loss: 0.13771381


  3%|▎         | 400/12991 [39:12<20:34:13,  5.88s/it]


Epoch: 1, Loss: 0.00781199


  3%|▎         | 400/12991 [36:04<18:55:48,  5.41s/it]


Epoch: 2, Loss: 0.00513433


100%|██████████| 3403/3403 [1:52:44<00:00,  1.99s/it]  


mse: 0.00637186, mpjpe: 0.11660850, pampjpe: 0.08776047
best test mpjpe is:0.11660849739751611


  3%|▎         | 400/12991 [47:38<24:59:48,  7.15s/it]


Epoch: 3, Loss: 0.00331775


  3%|▎         | 400/12991 [40:31<21:15:33,  6.08s/it]


Epoch: 4, Loss: 0.00328973


  3%|▎         | 400/12991 [37:05<19:27:30,  5.56s/it]


Epoch: 5, Loss: 0.00280841


100%|██████████| 3403/3403 [1:49:23<00:00,  1.93s/it]  


mse: 0.00520741, mpjpe: 0.10559226, pampjpe: 0.07522064
best test mpjpe is:0.10559226044157874


  3%|▎         | 400/12991 [45:27<23:51:04,  6.82s/it]


Epoch: 6, Loss: 0.00232198


  3%|▎         | 400/12991 [29:40<15:34:11,  4.45s/it]


Epoch: 7, Loss: 0.00218082


  3%|▎         | 400/12991 [25:20<13:17:44,  3.80s/it]


Epoch: 8, Loss: 0.00214124


100%|██████████| 3403/3403 [1:03:14<00:00,  1.12s/it]


mse: 0.00453831, mpjpe: 0.09624872, pampjpe: 0.06806127
best test mpjpe is:0.0962487248712358


  3%|▎         | 400/12991 [38:19<20:06:10,  5.75s/it]


Epoch: 9, Loss: 0.00192308


  3%|▎         | 400/12991 [30:01<15:45:08,  4.50s/it]


Epoch: 10, Loss: 102.36965096


  3%|▎         | 400/12991 [26:09<13:43:14,  3.92s/it]


Epoch: 11, Loss: 0.05166870


  4%|▎         | 126/3403 [02:33<1:06:44,  1.22s/it]


KeyboardInterrupt: 

In [10]:
parameter_dir = './MI_2ndApr.pt'
torch.save(model.state_dict(), parameter_dir)

In [13]:
print(list(model.feature_extractor.csi_extractor.parameters()))
class csi_feature_extractor(nn.Module):
    def __init__(self, model):
        super(csi_feature_extractor, self).__init__()
        self.part = nn.Sequential(
            model.encoder_conv1,
            model.encoder_bn1,
            model.encoder_relu,
            model.encoder_layer1,
            model.encoder_layer2,
            model.encoder_layer3,
            model.encoder_layer4, 
            # torch.nn.AvgPool2d((1, 4))
        )
    def forward(self, x):
        x = x.unsqueeze(1)
        x = torch.transpose(x, 2, 3) #16,2,114,3,32
        x = torch.flatten(x, 3, 4)# 16,2,114,96
        torch_resize = Resize([136,32])
        x = torch_resize(x)
        x = self.part(x).view(x.size(0), 512, -1)
        x = x.permute(0, 2, 1)
        return x
import sys
sys.path.insert(0, './CSI_benchmark')
csi_model = torch.load('CSI_benchmark/protocol3_random_1.pkl')
csi_extractor = csi_feature_extractor(csi_model)
csi_extractor.eval()
print("original model:")
print(list(csi_extractor.parameters()))

[Parameter containing:
tensor([[[[ 0.1442, -0.0400,  0.0629],
          [-0.1392,  0.0949,  0.2094],
          [ 0.0652, -0.2481, -0.3184]]],


        [[[ 0.0702, -0.2592, -0.1461],
          [ 0.0290, -0.2670, -0.1601],
          [-0.0984, -0.1065,  0.1317]]],


        [[[-0.2569, -0.1761, -0.2460],
          [-0.0025,  0.0818, -0.1916],
          [ 0.0658,  0.0344, -0.1465]]],


        [[[ 0.1116, -0.2394, -0.2183],
          [-0.0552, -0.1350, -0.1518],
          [ 0.0485, -0.0301, -0.2722]]],


        [[[ 0.1532, -0.1014, -0.0618],
          [ 0.3229, -0.0714,  0.1710],
          [-0.1570,  0.2686,  0.3306]]],


        [[[ 0.1332, -0.3275, -0.0421],
          [-0.2583, -0.3117, -0.2423],
          [-0.1054,  0.2280,  0.0535]]],


        [[[-0.1630, -0.1043,  0.2208],
          [ 0.1549, -0.2130,  0.0888],
          [-0.1122, -0.2864,  0.3296]]],


        [[[ 0.2309, -0.1619, -0.2679],
          [-0.2516,  0.0649,  0.2708],
          [ 0.0679, -0.0229, -0.0502]]],


        [

: 