# 모듈 불러오기

In [31]:
# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

In [32]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re

import scipy
from scipy.ndimage import gaussian_filter1d
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import os 
import glob
import itertools
from copy import deepcopy
from sklearn.model_selection import KFold

import datetime

from dataloader import *

from model_distribution import *

# Train

In [33]:
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

conv1d_dim1 = 32
conv1d_dim2 = 64
conv1d_dim3 = 128
dense_dim = 256

learning_rate = 0.001
n_epochs = 2000

criterion_distribution = nn.GaussianNLLLoss()

In [34]:
# 경로 입력 및 아이디 추출

file_path = "./data/total/"
data_path = glob.glob(file_path + '*')
name = []
for file_name in data_path:
    folder_name = os.path.split(file_name)[1][:7]
    name += [folder_name]
    
id_name = np.unique(name)

In [35]:
id_name

array(['IF00017', 'IF00024', 'IF00034', 'IF00041', 'IF01020', 'IF01045',
       'IF01047', 'IF02035', 'IF03014', 'IF03027', 'IF03039', 'IF94031',
       'IF99008', 'IF99013', 'IF99030', 'IF99032', 'IM01006', 'IM01029',
       'IM02040', 'IM03011', 'IM03048', 'IM96018', 'IM96033', 'IM97015',
       'IM98009', 'IM98019', 'IM98026', 'IM98036', 'IM98042', 'IM98049',
       'IM98050', 'IM99007', 'IM99010', 'IM99012', 'IM99021', 'IM99025',
       'IM99037'], dtype='<U7')

In [36]:
test_id = np.array(['IF03014', 'IF00041', 'IM02040', 'IM98049'])
# test_id = np.array(['IF03014', 'IF00041', 'IM02040', 'IM98042'])

In [37]:
import os
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import random
from copy import deepcopy
import datetime

# seed 고정
random_seed = 77
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

def train_ddp(rank, world_size):
    # 분산 학습 초기화
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # 데이터셋 준비 및 K-Fold
    kfold = KFold(n_splits=10, shuffle=True, random_state=random_seed)
    id_name_trnval = np.setdiff1d(id_name, test_id)
    best_MAE_fold = 0

    for fold, (train_idx, valid_idx) in enumerate(kfold.split(id_name_trnval)):
        train_id = id_name_trnval[train_idx]
        valid_id = id_name_trnval[valid_idx]

        print('Train ID : {}\n Valid ID {}'.format(train_id, valid_id))
        
        train_dataset_R = Gait_Dataset_Salted(file_path, train_id, right=True)
        train_dataset_L = Gait_Dataset_Salted(file_path, train_id, right=False)
        valid_dataset_R = Gait_Dataset_Salted(file_path, valid_id, right=True)
        valid_dataset_L = Gait_Dataset_Salted(file_path, valid_id, right=False)

        train_dataset = torch.utils.data.ConcatDataset([train_dataset_R, train_dataset_L])
        valid_dataset = torch.utils.data.ConcatDataset([valid_dataset_R, valid_dataset_L])

        # DistributedSampler로 데이터 로더 생성
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, sampler=valid_sampler)

        print('Fold {} Dataloader Load Complete'.format(fold+1))

        # 모델 정의 및 DDP로 감싸기
        model = Encoder(conv1d_dim1, conv1d_dim2, conv1d_dim3, dense_dim).to(rank)
        model = DDP(model, device_ids=[rank])
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Early Stopping을 위한 변수
        best = 10000
        converge_cnt = 0
        best_MAE = 1000
        best_epoch = 0

        # 학습 시작 시간 기록
        start_time = datetime.datetime.now()
        print(f"Training started at: {start_time}")
        
        # Training loop
        for epoch in range(n_epochs):
            tot_trn_loss = 0.0

            model.train()
            train_sampler.set_epoch(epoch)  # Shuffle each epoch

            for i, data in enumerate(train_loader):
                inputs_acc, inputs_gyr, inputs_prs, stride_length, mu, sigma, folder_id = data
                inputs_acc, inputs_gyr, inputs_prs, stride_length, mu, sigma = inputs_acc.float(), inputs_gyr.float(), inputs_prs.float(), stride_length.float(), mu.float(), sigma.float()
                inputs_acc, inputs_gyr, inputs_prs = inputs_acc.to(rank), inputs_gyr.to(rank), inputs_prs.to(rank)
                mu, sigma = mu.reshape(-1, 1).to(rank), sigma.reshape(-1, 1).to(rank)
                stride_length = stride_length.reshape(-1, 1).to(rank)

                outputs, var = model(inputs_acc, inputs_gyr, inputs_prs)
                loss = criterion_distribution(mu, outputs, var)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tot_trn_loss += loss.item()

            # Evaluation Mode
            model.eval()
            tot_val_loss = 0
            tot_val_MAE = 0

            with torch.no_grad():
                for i, data in enumerate(val_loader):
                    inputs_acc, inputs_gyr, inputs_prs, stride_length, mu, sigma, folder_id = data
                    inputs_acc, inputs_gyr, inputs_prs, stride_length, mu, sigma = inputs_acc.float(), inputs_gyr.float(), inputs_prs.float(), stride_length.float(), mu.float(), sigma.float()
                    inputs_acc, inputs_gyr, inputs_prs = inputs_acc.to(rank), inputs_gyr.to(rank), inputs_prs.to(rank)
                    mu, sigma = mu.reshape(-1, 1).to(rank), sigma.reshape(-1, 1).to(rank)
                    stride_length = stride_length.reshape(-1, 1).to(rank)

                    outputs, var = model(inputs_acc, inputs_gyr, inputs_prs)
                    loss = criterion_distribution(mu, outputs, var)
                    tot_val_loss += loss.item()
                    tot_val_MAE += torch.sum(torch.abs(outputs - stride_length)) / len(stride_length)

            trn_loss = tot_trn_loss / len(train_loader)
            val_loss = tot_val_loss / len(val_loader)
            MAE = tot_val_MAE / len(val_loader)

            # Early Stopping
            if val_loss < best:
                best = np.mean(val_loss)
                best_MAE = MAE
                best_epoch = epoch+1
                if rank == 0:  # Only save model from rank 0 process
                    torch.save(deepcopy(model.state_dict()), f'./model/L2/L2_fold{fold+1}.pth')
                converge_cnt = 0
            else:
                converge_cnt += 1

            if converge_cnt > 50:
                print(f'Early stopping: Fold {fold+1}, Epoch {best_epoch}, Valid Loss {best:.3f}, MAE {best_MAE:.3f}')
                best_MAE_fold += best_MAE
                break

        # 학습 종료 시간 및 경과 시간 계산
        end_time = datetime.datetime.now()
        elapsed_time = end_time - start_time
        print(f"Training ended at: {end_time}")
        print(f"Total training time: {elapsed_time}")

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size)



Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/shinsjn/miniconda3/envs/Jupyter_test/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/shinsjn/miniconda3/envs/Jupyter_test/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_ddp' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/shinsjn/miniconda3/envs/Jupyter_test/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/shinsjn/miniconda3/envs/Jupyter_test/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_ddp' on <module '__main__' (built-in)>
W1104 15:04:46.750000 140126867067520 torch/multiproce

ProcessExitedException: process 1 terminated with exit code 1

# Visualize

In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

test_file_path = "./data/total/"

test_dataset_R = Gait_Dataset_Salted(test_file_path, test_id, right=True)
test_dataset_L = Gait_Dataset_Salted(test_file_path, test_id, right=False)

test_dataset = torch.utils.data.ConcatDataset([test_dataset_R, test_dataset_L])

test_loader = torch.utils.data.DataLoader(test_dataset,
                                       batch_size=256,
                                       shuffle=False,
                                       worker_init_fn=np.random.seed(42))

In [9]:
### Scatter Plot 

conv1d_dim1 = 32
conv1d_dim2 = 64
conv1d_dim3 = 128
dense_dim = 256

stride_length_list = pd.DataFrame()
sigma_list = pd.DataFrame()
tot_val_MAE = 0
tot_val_MAPE = 0

for fold_idx, model_name in enumerate(glob.glob('./model/L2/L2_fold' + '*')):
    print(model_name)
    model = Encoder(conv1d_dim1, conv1d_dim2, conv1d_dim3, dense_dim).to(device)
    model.load_state_dict(torch.load(model_name))
    model.eval()
    pred_list = []
    for i, data in enumerate(test_loader):
        inputs_acc, inputs_gyr, inputs_prs, stride_length, mu, sigma, folder_id = data
        inputs_acc, inputs_gyr, inputs_prs, stride_length = inputs_acc.float(), inputs_gyr.float(), inputs_prs.float(), stride_length.float() 
        inputs_acc, inputs_gyr, inputs_prs = inputs_acc.to(device), inputs_gyr.to(device), inputs_prs.to(device)

        stride_length = stride_length.reshape(-1, 1)
        stride_length = stride_length.to(device)

        outputs = model(inputs_acc, inputs_gyr, inputs_prs)
        stride_length_list.loc[:, fold_idx] = outputs[0].reshape(-1).cpu().detach().numpy()
        sigma_list.loc[:, fold_idx] = outputs[1].reshape(-1).cpu().detach().numpy()
        
pred = stride_length_list.mean(axis=1)
stride_length = stride_length.reshape(-1).cpu().detach().numpy()
# stride_length = stride_length.reshape(-1)

mu = mu.cpu().detach().numpy()
# mu = mu

# print(pred)
# print(stride_length_list)
MAE = np.sum(np.abs(pred - stride_length)) / len(stride_length)
MAPE = 100 - (np.mean(np.abs(pred - stride_length) / stride_length) * 100)
RMSE = np.sqrt(np.mean((stride_length - pred)**2))
RSE = RMSE / np.sqrt(np.sum((stride_length - mu)**2))


print('MAE : {:.4f}, MAPE : {:.2f}%, RMSE : {:.4f}, RSE : {:.4f}'.format(MAE, MAPE, RMSE, RSE))
# print('MAE : {:.4f}'.format(MAE))

plt.figure(figsize=(10, 10))
plt.scatter(stride_length, pred)
plt.xlim([80, 200])
plt.xlabel('stride_length')
plt.ylim([80, 200])
plt.ylabel('pred')

xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, linestyle='--', color='k', lw=3, scalex=False, scaley=False)

plt.show()

./model/L2/L2_fold8.pth


  model.load_state_dict(torch.load(model_name))


./model/L2/L2_fold7.pth


  model.load_state_dict(torch.load(model_name))


./model/L2/L2_fold9.pth


  model.load_state_dict(torch.load(model_name))


./model/L2/L2_fold5.pth


  model.load_state_dict(torch.load(model_name))


./model/L2/L2_fold1.pth


  model.load_state_dict(torch.load(model_name))


RuntimeError: Error(s) in loading state_dict for Encoder:
	Missing key(s) in state_dict: "conv1d_acc.0.weight", "conv1d_acc.0.bias", "conv1d_acc.1.weight", "conv1d_acc.1.bias", "conv1d_acc.1.running_mean", "conv1d_acc.1.running_var", "conv1d_acc.3.weight", "conv1d_acc.3.bias", "conv1d_acc.4.weight", "conv1d_acc.4.bias", "conv1d_acc.4.running_mean", "conv1d_acc.4.running_var", "conv1d_acc.6.weight", "conv1d_acc.6.bias", "conv1d_acc.7.weight", "conv1d_acc.7.bias", "conv1d_acc.7.running_mean", "conv1d_acc.7.running_var", "conv1d_gyr.0.weight", "conv1d_gyr.0.bias", "conv1d_gyr.1.weight", "conv1d_gyr.1.bias", "conv1d_gyr.1.running_mean", "conv1d_gyr.1.running_var", "conv1d_gyr.3.weight", "conv1d_gyr.3.bias", "conv1d_gyr.4.weight", "conv1d_gyr.4.bias", "conv1d_gyr.4.running_mean", "conv1d_gyr.4.running_var", "conv1d_gyr.6.weight", "conv1d_gyr.6.bias", "conv1d_gyr.7.weight", "conv1d_gyr.7.bias", "conv1d_gyr.7.running_mean", "conv1d_gyr.7.running_var", "conv1d_prs.0.weight", "conv1d_prs.0.bias", "conv1d_prs.1.weight", "conv1d_prs.1.bias", "conv1d_prs.1.running_mean", "conv1d_prs.1.running_var", "conv1d_prs.3.weight", "conv1d_prs.3.bias", "conv1d_prs.4.weight", "conv1d_prs.4.bias", "conv1d_prs.4.running_mean", "conv1d_prs.4.running_var", "conv1d_prs.6.weight", "conv1d_prs.6.bias", "conv1d_prs.7.weight", "conv1d_prs.7.bias", "conv1d_prs.7.running_mean", "conv1d_prs.7.running_var", "dense_mean.0.weight", "dense_mean.0.bias", "dense_mean.2.weight", "dense_mean.2.bias", "dense_mean.4.weight", "dense_mean.4.bias", "dense_var.0.weight", "dense_var.0.bias", "dense_var.2.weight", "dense_var.2.bias", "dense_var.4.weight", "dense_var.4.bias". 
	Unexpected key(s) in state_dict: "module.conv1d_acc.0.weight", "module.conv1d_acc.0.bias", "module.conv1d_acc.1.weight", "module.conv1d_acc.1.bias", "module.conv1d_acc.1.running_mean", "module.conv1d_acc.1.running_var", "module.conv1d_acc.1.num_batches_tracked", "module.conv1d_acc.3.weight", "module.conv1d_acc.3.bias", "module.conv1d_acc.4.weight", "module.conv1d_acc.4.bias", "module.conv1d_acc.4.running_mean", "module.conv1d_acc.4.running_var", "module.conv1d_acc.4.num_batches_tracked", "module.conv1d_acc.6.weight", "module.conv1d_acc.6.bias", "module.conv1d_acc.7.weight", "module.conv1d_acc.7.bias", "module.conv1d_acc.7.running_mean", "module.conv1d_acc.7.running_var", "module.conv1d_acc.7.num_batches_tracked", "module.conv1d_gyr.0.weight", "module.conv1d_gyr.0.bias", "module.conv1d_gyr.1.weight", "module.conv1d_gyr.1.bias", "module.conv1d_gyr.1.running_mean", "module.conv1d_gyr.1.running_var", "module.conv1d_gyr.1.num_batches_tracked", "module.conv1d_gyr.3.weight", "module.conv1d_gyr.3.bias", "module.conv1d_gyr.4.weight", "module.conv1d_gyr.4.bias", "module.conv1d_gyr.4.running_mean", "module.conv1d_gyr.4.running_var", "module.conv1d_gyr.4.num_batches_tracked", "module.conv1d_gyr.6.weight", "module.conv1d_gyr.6.bias", "module.conv1d_gyr.7.weight", "module.conv1d_gyr.7.bias", "module.conv1d_gyr.7.running_mean", "module.conv1d_gyr.7.running_var", "module.conv1d_gyr.7.num_batches_tracked", "module.conv1d_prs.0.weight", "module.conv1d_prs.0.bias", "module.conv1d_prs.1.weight", "module.conv1d_prs.1.bias", "module.conv1d_prs.1.running_mean", "module.conv1d_prs.1.running_var", "module.conv1d_prs.1.num_batches_tracked", "module.conv1d_prs.3.weight", "module.conv1d_prs.3.bias", "module.conv1d_prs.4.weight", "module.conv1d_prs.4.bias", "module.conv1d_prs.4.running_mean", "module.conv1d_prs.4.running_var", "module.conv1d_prs.4.num_batches_tracked", "module.conv1d_prs.6.weight", "module.conv1d_prs.6.bias", "module.conv1d_prs.7.weight", "module.conv1d_prs.7.bias", "module.conv1d_prs.7.running_mean", "module.conv1d_prs.7.running_var", "module.conv1d_prs.7.num_batches_tracked", "module.dense_mean.0.weight", "module.dense_mean.0.bias", "module.dense_mean.2.weight", "module.dense_mean.2.bias", "module.dense_mean.4.weight", "module.dense_mean.4.bias", "module.dense_var.0.weight", "module.dense_var.0.bias", "module.dense_var.2.weight", "module.dense_var.2.bias", "module.dense_var.4.weight", "module.dense_var.4.bias". 