# Transformer-based autoencoder for motion data reconstruction

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import pickle
from tqdm.notebook import tqdm

In [None]:
# Report the installed PyTorch version and whether a CUDA device is available.
for runtime_info in (torch.__version__, torch.cuda.is_available()):
    print(runtime_info)

In [None]:
DATASET_PATH = 'HumanAct12Poses/humanact12poses.pkl'
# Report up front whether the pickled HumanAct12 file exists before trying to load it.
if os.path.exists(DATASET_PATH):
    print(f"{DATASET_PATH}에 데이터셋 있음")
else:
    print(f"{DATASET_PATH}에 데이터셋 없음")

In [None]:
# Load the pickled HumanAct12 dataset; on failure `data` stays None and the reason is printed.
# NOTE: pickle can execute arbitrary code while loading -- only load files from a trusted source.
data = None
try:
    with open(DATASET_PATH, 'rb') as f:
        # encoding='latin1' is the setting the pickle docs give for reading Python 2 pickles
        # that contain NumPy arrays (presumably why it is used for this file).
        data = pickle.load(f, encoding='latin1')
except Exception as e:
    # Best-effort load, but never a bare `except:` -- that would also swallow
    # KeyboardInterrupt/SystemExit and hide why loading failed.
    print(f"HumanAct12 데이터셋 로드 실패: {e}")
else:
    print("HumanAct12 데이터셋 로드됨")

In [None]:
# Summarize the loaded dataset: its keys, how many entries each key holds,
# and the type/shape of the first entry of every non-empty key.
if data is None:
    print("데이터셋이 없습니다.")
else:
    print(f"데이터셋 키: {data.keys()}")

    total_sequences = 0
    for field_name, entries in data.items():
        entry_count = len(entries)
        print(f"Key: {field_name}, 시퀀스 개수: {entry_count}")
        total_sequences += entry_count

        if entry_count == 0:
            continue
        head_entry = entries[0]
        print(f"시퀀스 타입: {type(head_entry)}")
        if isinstance(head_entry, np.ndarray):
            print(f"  첫 번째 시퀀스 형태 (프레임 수, 조인트 수, 좌표 차원): {head_entry.shape}")
    print(f"총 시퀀스 개수: {total_sequences}")

In [None]:
# Coarse HumanAct12 action labels, numbered from 1 in the order listed.
humanact12_coarse_action_enumerator = dict(
    enumerate(
        [
            "warm_up", "walk", "run", "jump",
            "drink", "lift_dumbbell", "sit", "eat",
            "turn steering wheel", "phone", "boxing", "throw",
        ],
        start=1,
    )
)

# Length (number of frames) of the longest joints3D sequence.
longest_sequence_frames = max(sequence.shape[0] for sequence in data['joints3D'])
print(longest_sequence_frames)

In [None]:
# Bone list of the 24-joint SMPL skeleton as (from_joint, to_joint) index pairs,
# consumed by visualize_motion to draw one line segment per pair.
H_CONNECTIONS = [
    (0, 1),    # pelvis -> left hip
    (0, 2),    # pelvis -> right hip
    (0, 3),    # pelvis -> spine1
    (1, 4),    # left hip -> left knee
    (2, 5),    # right hip -> right knee
    (3, 6),    # spine1 -> spine2
    (4, 7),    # left knee -> left ankle
    (5, 8),    # right knee -> right ankle
    (6, 9),    # spine2 -> spine3
    (7, 10),   # left ankle -> left foot (toe)
    (8, 11),   # right ankle -> right foot (toe)
    (9, 12),   # spine3 -> neck
    (12, 15),  # neck -> head
    (9, 13),   # spine3 -> left collar (shoulder base)
    (9, 14),   # spine3 -> right collar (shoulder base)
    (13, 16),  # left collar -> left shoulder (arm root)
    (14, 17),  # right collar -> right shoulder (arm root)
    (16, 18),  # left shoulder -> left elbow
    (17, 19),  # right shoulder -> right elbow
    (18, 20),  # left elbow -> left wrist
    (19, 21),  # right elbow -> right wrist
    (20, 22),  # left wrist -> left hand
    (21, 23),  # right wrist -> right hand
]

# H_CONNECTIONS_2 = [
#     (0, 1), (0, 2), (0, 3),      # Pelvis to Hip (L/R) & Spine1
#     (1, 4), (2, 5),              # Hips to Knees
#     (3, 6),                      # Spine1 to Spine2
#     (4, 7), (5, 8),              # Knees to Ankle
#     (6, 9),                      # Spine2 to Spine3
#     (7, 10), (8, 11),            # Ankle to Foot (Toe)
#     (9, 12),                     # Spine3 to Neck
#     (12, 15),                    # Neck to Head
#     (13, 14),
#     (13, 16), (14, 17),          # Collars to Shoulders (Arm root)
#     (16, 18), (17, 19),          # Shoulders to Elbows
#     (18, 20), (19, 21),          # Elbows to Wrists
#     (20, 22), (21, 23)           # Wrists to Hands
# ]

In [None]:
def visualize_motion(motion_sequence_3d_input: np.ndarray, connections: list, title: str = "Motion Visualization"):
    """Render a 3D skeleton motion sequence as an inline HTML/JS animation.

    Args:
        motion_sequence_3d_input: joint positions shaped (N_frames, N_joints, 3).
            The input is not modified.
        connections: (joint_a, joint_b) index pairs; one line segment is drawn
            between the two joints of every pair in every frame.
        title: prefix of the per-frame axis title.

    Returns:
        IPython ``HTML`` object holding the animation, or ``None`` if rendering failed.

    Raises:
        ValueError: if the input is not (N_frames, N_joints, 3) or contains no frames.
    """
    motion = np.asarray(motion_sequence_3d_input)
    if motion.ndim != 3 or motion.shape[2] != 3:
        raise ValueError(f"입력 motion_sequence_3d는 (N_frames, N_joints, 3) 형태여야 합니다. 현재 형태: {motion.shape}")
    num_frames = motion.shape[0]
    if num_frames == 0:
        # Without this check the min/max reduction below fails with an obscure numpy error.
        raise ValueError("입력 motion_sequence_3d에 프레임이 없습니다.")

    # Remap axes for plotting: (x, y, z) -> (x, z, -y), i.e. the data's y axis is drawn
    # along matplotlib's z axis (negated). np.stack builds a new array; the input is untouched.
    motion = np.stack([motion[..., 0], motion[..., 2], -motion[..., 1]], axis=-1)

    # One cubic bounding box over all frames keeps the axes fixed and equally scaled.
    min_coords = motion.min(axis=(0, 1))
    max_coords = motion.max(axis=(0, 1))
    mid = (min_coords + max_coords) * 0.5
    half_range = (max_coords - min_coords).max() / 2
    if half_range == 0:
        # Degenerate input (a single static point): avoid identical lower/upper axis limits.
        half_range = 0.5

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    def update(frame):
        """Clear the axes and draw the joints and bones of one frame."""
        ax.cla()

        ax.set_xlim(mid[0] - half_range, mid[0] + half_range)
        ax.set_ylim(mid[1] - half_range, mid[1] + half_range)
        ax.set_zlim(mid[2] - half_range, mid[2] + half_range)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        # 1-based frame counter so the last frame reads "N/N".
        ax.set_title(f"{title} (Frame: {frame + 1}/{num_frames})")
        ax.view_init(elev=10, azim=-90)

        joints = motion[frame]

        # Joints
        ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='r', marker='o', s=10)

        # Bones (connections)
        for joint_from_idx, joint_to_idx in connections:
            start, end = joints[joint_from_idx], joints[joint_to_idx]
            ax.plot([start[0], end[0]],
                    [start[1], end[1]],
                    [start[2], end[2]],
                    color='blue', linewidth=2)

    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=50, blit=False, repeat=False)

    print("애니메이션 생성 중...")
    try:
        html_video = HTML(ani.to_jshtml())
    except Exception as e:
        print(f"애니메이션 생성 또는 표시 중 오류 발생: {e}")
        return None
    finally:
        # to_jshtml() has already rendered every frame, so the figure can always be closed.
        # Left open, pyplot keeps a reference to it, repeated calls accumulate figures,
        # and inline backends may also show a stray static image at the end of the cell.
        plt.close(fig)
    print("애니메이션 생성 완료.")
    return html_video

# if 'data' in locals() and data is not None and 'joints3D' in data and len(data['joints3D']) > 0:
#     sample_motion_sequence_3d = data['joints3D'][50]

#     print(f"시각화할 샘플 모션 시퀀스 형태: {sample_motion_sequence_3d.shape}")

#     motion_video = visualize_motion(sample_motion_sequence_3d, H_CONNECTIONS, title="Original Motion Sample")

#     if motion_video:
#         display(motion_video)
#     else:
#         print("애니메이션을 표시할 수 없습니다.")
        
# else:
#     print("시각화할 모션 데이터가 없습니다.")

In [None]:
# Hyperparameter settings
class HParams:
    """Bundle of the hyperparameters used by the data pipeline, the model and the training loop.

    Every value is a plain instance attribute, set in ``__init__`` in the order listed.
    """

    def __init__(self):
        self.__dict__.update(
            # Compute device: CUDA when available, otherwise CPU.
            device='cuda' if torch.cuda.is_available() else 'cpu',
            # Optimization
            batch_size=32,
            num_epochs=100,
            learning_rate=1e-4,
            # BERT-style encoder
            hidden_dim=256,   # hidden dimension of the encoder
            num_layers=8,     # number of encoder layers
            num_heads=8,      # number of attention heads
            dropout_rate=0.05,
            # Masked-reconstruction task
            mask_ratio=0.15,
            # Data
            frame_length=64,  # frames per trimmed/padded sequence
            input_dim=72,     # features per frame
            output_dim=72,
            # Dataset split ratios
            train_split_ratio=0.7,
            val_split_ratio=0.2,
            test_split_ratio=0.1,
        )


hparams = HParams()

In [None]:
class HumanAct12Dataset(Dataset):
    """Fixed-length, standardized HumanAct12 joint sequences.

    Every ``data['joints3D']`` entry shaped (frames, num_joints, 3) is flattened to
    (frames, feature_dim), trimmed or padded to ``frame_length`` frames, and
    standardized per feature with ``self.scaler``.
    """

    num_joints = 24
    feature_dim = num_joints * 3  # 72 features per frame: (x, y, z) for every joint

    def __init__(self, data: dict, frame_length: int):
        self.data = data
        self.frame_length = frame_length
        self.sequences = []
        self.np_sequences = []
        self.scaler = StandardScaler()
        self.normalized_sequences = []

        if 'joints3D' not in data:
            return

        for joints in data['joints3D']:
            # Require exactly (frames, num_joints, 3): any other joint count would be
            # silently scrambled by the reshape to feature_dim below.
            if joints.ndim != 3 or joints.shape[1:] != (self.num_joints, 3):
                print("joints3D shape이 다름")
                continue
            # (frames, num_joints, 3) -> (frames, feature_dim)
            self.sequences.append(self._fit_length(joints.reshape(-1, self.feature_dim)))

        # (num_sequences, frame_length, feature_dim); the reshape also gives an empty
        # result the right rank.
        self.np_sequences = np.array(self.sequences).reshape(-1, frame_length, self.feature_dim)
        self.fit_normalization()

    def _fit_length(self, sequence: np.ndarray) -> np.ndarray:
        """Trim ``sequence`` to ``frame_length`` frames, or pad it by repeating its last frame (zeros if it has none)."""
        current_length = sequence.shape[0]
        if current_length >= self.frame_length:
            return sequence[:self.frame_length]
        padded = np.zeros((self.frame_length, sequence.shape[1]))
        padded[:current_length] = sequence
        if current_length > 0:
            padded[current_length:] = sequence[-1]
        return padded

    def fit_normalization(self, indices=None):
        """Fit the scaler on the selected sequences, then re-standardize every sequence.

        Args:
            indices: positions of the sequences whose frames the scaler is fit on
                (e.g. the training split, so validation/test statistics do not leak
                into the normalization). ``None`` or an empty selection uses all sequences.
        """
        if len(self.np_sequences) == 0:
            # Nothing to fit; keep an empty, correctly shaped array and an unfitted scaler.
            self.normalized_sequences = self.np_sequences
            return
        fit_source = self.np_sequences
        if indices is not None and len(indices) > 0:
            fit_source = self.np_sequences[np.asarray(indices, dtype=int)]
        self.scaler.fit(fit_source.reshape(-1, self.feature_dim))
        flat = self.np_sequences.reshape(-1, self.feature_dim)
        self.normalized_sequences = self.scaler.transform(flat).reshape(self.np_sequences.shape)

    def __len__(self):
        return len(self.normalized_sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a float32 tensor shaped (frame_length, feature_dim)."""
        return torch.tensor(self.normalized_sequences[idx], dtype=torch.float32)

full_dataset = HumanAct12Dataset(data, hparams.frame_length)

total_size = len(full_dataset)
train_size = int(hparams.train_split_ratio * total_size)
val_size = int(hparams.val_split_ratio * total_size)
test_size = total_size - train_size - val_size  # remainder goes to test so the sizes always sum to total_size

# Fixed seed: the split must be identical whenever this cell is re-run. Otherwise the
# checkpoints written by the training cell could later be visualized on test samples
# they were trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)

# Re-fit normalization on the training split only, so validation/test statistics do not
# leak into the scaler. The Subsets read through full_dataset, so they see the update.
if train_size > 0:
    full_dataset.fit_normalization(train_dataset.indices)

train_dataloader = DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=hparams.batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=hparams.batch_size, shuffle=False)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

if total_size > 0:
    sample_data = full_dataset[0]
    print(f"Sample data shape: {sample_data.shape}")


In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Even feature columns hold sin(pos * f_i) and odd columns cos(pos * f_i), with
    f_i = 10000 ** (-2i / d_model). Any ``d_model`` (odd or even) is supported.
    """

    def __init__(self, d_model, dropout_rate=0.1, max_len=64):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)  # regularization against overfitting

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A registered buffer: stored in state_dict and moved by .to(), but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor):
        """Add encodings to ``x`` shaped (batch_size, seq_len, d_model) with seq_len <= max_len."""
        # Dim 1 is the sequence axis because inputs here are batch-first. PyTorch's built-in
        # nn.Transformer layers default to (seq_len, batch, d_model) (batch_first=False),
        # which is why many reference implementations index dim 0 instead.
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head attention over batch-first inputs."""

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        # An explicit check rather than `assert`, which is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.d_model = d_model

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, src_mask: torch.Tensor = None):
        """Attend ``query`` over ``key``/``value``, each (batch_size, seq_len, d_model).

        ``src_mask`` (optional) must broadcast to (batch_size, num_heads, q_len, k_len),
        e.g. (batch_size, 1, 1, k_len); positions where it is 0 get (numerically) zero weight.
        Returns (batch_size, q_len, d_model).
        """
        batch_size = query.shape[0]

        # Project, then split into heads: (batch_size, num_heads, seq_len, d_k)
        query = self.wq(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.wk(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.wv(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores: how strongly each query position attends to each key position.
        scores: torch.Tensor = (query @ key.transpose(2, 3)) / np.sqrt(self.d_k)

        # An earlier attempt to mask padded frames through src_mask made the loss misbehave,
        # so callers currently pass no mask.
        if src_mask is not None:
            # Use the dtype's most negative finite value rather than -inf: a fully masked
            # row then yields a uniform softmax instead of NaN. Unlike a fixed -1e9, it is
            # also representable in half precision.
            scores = scores.masked_fill(src_mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output: torch.Tensor = attention_weights @ value

        # Concatenate heads and apply the output projection. transpose() returns a
        # non-contiguous view, so contiguous() is needed before view().
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.fc_out(output)
        return output  # (batch_size, seq_len, d_model)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        # The attribute name and layer order define the state_dict keys
        # ("sequential.0.*", "sequential.3.*") used by saved checkpoints.
        layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_ff, d_model),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        hidden = self.sequential(x)
        return hidden

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Self-attention and a feed-forward block, each followed by dropout, a residual add
    and LayerNorm. Normalization helps keep activations and gradients well-scaled, and
    the residual path preserves the previous layer's information in deep stacks.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate):
        super().__init__()
        # Submodules are created in the same order as before, so parameter initialization
        # and state_dict keys are unchanged.
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads, dropout_rate)
        self.feed_forward = FeedForward(d_model, d_ff, dropout_rate)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout_rate), nn.Dropout(dropout_rate)

    def forward(self, x, src_mask=None):
        # Sub-layer 1: self-attention -> dropout -> residual add -> LayerNorm.
        attended = self.self_attn(x, x, x, src_mask=src_mask)
        x = self.norm1(x + self.dropout1(attended))
        # Sub-layer 2: feed-forward -> dropout -> residual add -> LayerNorm.
        return self.norm2(x + self.dropout2(self.feed_forward(x)))

In [None]:
class MotionBERT(nn.Module):
    """BERT-style encoder that maps a (masked) motion sequence back to per-frame features.

    Pipeline: linear embedding (input_dim -> hidden_dim) -> sinusoidal positional
    encoding -> ``num_layers`` encoder layers -> linear projection back to input_dim.
    Input and output are both shaped (batch_size, frame_length, input_dim).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout_rate, frame_length):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.frame_length = frame_length

        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len=frame_length)

        ff_dim = 4 * hidden_dim  # feed-forward width: 4x the hidden size
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(hidden_dim, num_heads, ff_dim, dropout_rate) for _ in range(num_layers)]
        )

        self.output_layer = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, src_mask=None):
        hidden = self.positional_encoding(self.embedding(x))  # (batch, frames, hidden_dim)
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, src_mask)
        return self.output_layer(hidden)  # (batch, frames, input_dim)

# Build the model from the shared hyperparameters and move it to the configured device.
model_kwargs = dict(
    input_dim=hparams.input_dim,
    hidden_dim=hparams.hidden_dim,
    num_layers=hparams.num_layers,
    num_heads=hparams.num_heads,
    dropout_rate=hparams.dropout_rate,
    frame_length=hparams.frame_length,
)
model = MotionBERT(**model_kwargs).to(hparams.device)

print(model)

In [None]:
def _num_to_mask(total: int, mask_ratio: float, divisor: int = 1) -> int:
    """Return int(total * mask_ratio) // divisor, raised to 1 when that rounds to 0 but mask_ratio > 0."""
    count = int(total * mask_ratio) // divisor
    if count == 0 and mask_ratio > 0:
        count = 1
    return count


def _mask_random_frames(input_tensor: torch.Tensor, masked_input: torch.Tensor,
                        mask_indices: torch.Tensor, num_frames_to_mask: int):
    """In place, mask ``num_frames_to_mask`` random whole frames of every sample.

    Each chosen frame is zeroed with probability 0.8. Otherwise (0.2) it is replaced by a
    random original frame of the same sample (possibly itself): there is no dedicated
    mask token, so BERT's "keep unchanged" 10% is folded into the random-replacement share.
    """
    batch_size, frame_length, _ = input_tensor.shape
    for b in range(batch_size):
        masked_frames_idx = torch.randperm(frame_length)[:num_frames_to_mask]
        for idx in masked_frames_idx:
            if torch.rand(1).item() < 0.8:
                masked_input[b, idx, :] = 0.0
            else:
                random_frame_idx = torch.randint(0, frame_length, (1,)).item()
                masked_input[b, idx, :] = input_tensor[b, random_frame_idx, :]
        mask_indices[b, masked_frames_idx, :] = True


def _mask_random_joints(masked_input: torch.Tensor, mask_indices: torch.Tensor,
                        num_joints: int, num_joints_to_mask: int):
    """In place, zero all frames of ``num_joints_to_mask`` random joints (3 features each) per sample."""
    batch_size = masked_input.shape[0]
    for b in range(batch_size):
        for joint_idx in torch.randperm(num_joints)[:num_joints_to_mask]:
            start = int(joint_idx) * 3
            masked_input[b, :, start:start + 3] = 0.0
            mask_indices[b, :, start:start + 3] = True


def create_mask(input_tensor: torch.Tensor, mask_ratio: float, mask_type: str):
    """Build a masked-reconstruction example from a batch of motion sequences.

    Args:
        input_tensor: (batch_size, frame_length, input_dim), with (x, y, z) per joint,
            so the joint count is input_dim // 3 (24 for input_dim=72).
        mask_ratio: fraction of frames ('frame') or joints ('joint') to mask per sample;
            'mixed' spends half of each budget on frames and half on joints.
        mask_type: 'frame', 'joint' or 'mixed'.

    Returns:
        (masked_input, target, mask_indices): the corrupted input; the original values at
        masked positions and 0 elsewhere; and a bool mask of the masked positions.
        All have the shape of ``input_tensor``.

    Raises:
        ValueError: for an unknown ``mask_type``.
    """
    _, frame_length, input_dim = input_tensor.shape
    num_joints = input_dim // 3

    masked_input = input_tensor.clone()
    mask_indices = torch.zeros_like(input_tensor, dtype=torch.bool)

    if mask_type == 'frame':
        _mask_random_frames(input_tensor, masked_input, mask_indices,
                            _num_to_mask(frame_length, mask_ratio))
    elif mask_type == 'joint':
        _mask_random_joints(masked_input, mask_indices, num_joints,
                            _num_to_mask(num_joints, mask_ratio))
    elif mask_type == 'mixed':
        # Half of each budget: frame masking first, then joint masking (masks may overlap).
        _mask_random_frames(input_tensor, masked_input, mask_indices,
                            _num_to_mask(frame_length, mask_ratio, divisor=2))
        _mask_random_joints(masked_input, mask_indices, num_joints,
                            _num_to_mask(num_joints, mask_ratio, divisor=2))
    else:
        raise ValueError("Invalid mask_type. Choose from 'frame', 'joint', 'mixed'.")

    # Unmasked positions are 0 in the target; the loss additionally weights by mask_indices.
    target = input_tensor * mask_indices.to(input_tensor.dtype)

    return masked_input, target, mask_indices

# # 마스킹 전략 테스트
# dummy_input = torch.randn(hparams.batch_size, hparams.frame_length, hparams.input_dim)

# # 프레임 마스킹
# masked_input_frame, target_frame, mask_indices_frame = create_mask(dummy_input, hparams.mask_ratio, 'frame')
# print(f"Frame Masking - Masked input shape: {masked_input_frame.shape}, Target shape: {target_frame.shape}")
# print(f"Frame Masking - Number of masked elements: {mask_indices_frame.sum().item()}")

# # 관절 마스킹
# masked_input_joint, target_joint, mask_indices_joint = create_mask(dummy_input, hparams.mask_ratio, 'joint')
# print(f"Joint Masking - Masked input shape: {masked_input_joint.shape}, Target shape: {target_joint.shape}")
# print(f"Joint Masking - Number of masked elements: {mask_indices_joint.sum().item()}")

# # 혼합 마스킹
# masked_input_mixed, target_mixed, mask_indices_mixed = create_mask(dummy_input, hparams.mask_ratio, 'mixed')
# print(f"Mixed Masking - Masked input shape: {masked_input_mixed.shape}, Target shape: {target_mixed.shape}")
# print(f"Mixed Masking - Number of masked elements: {mask_indices_mixed.sum().item()}")

In [None]:
# Training loop
def train_model(model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer, criterion: nn.Module, hparams: HParams, mask_type: str, epoch_info: str):
    """Run one training epoch of masked reconstruction; return the mean of the per-batch losses.

    ``criterion`` must return element-wise losses (e.g. ``nn.MSELoss(reduction='none')``).
    Each batch loss is averaged over the masked elements only.
    """
    model.train()
    total_loss = 0.0
    with tqdm(total=len(dataloader), desc=f"{epoch_info} ({mask_type} Masking, Train)") as pbar:
        for batch in dataloader:
            batch = batch.to(hparams.device)  # (batch_size, frame_length, input_dim)

            masked_batch, target, mask_indices = create_mask(batch, hparams.mask_ratio, mask_type)

            optimizer.zero_grad()
            output = model(masked_batch)  # (batch_size, frame_length, input_dim)

            mask = mask_indices.float()
            # clamp_min(1): a batch with nothing masked (e.g. mask_ratio == 0) would otherwise
            # give 0/0 = NaN, and backward() would push NaN into every weight.
            loss = (criterion(output, target) * mask).sum() / mask.sum().clamp_min(1.0)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())
            pbar.update(1)

    # max(..., 1) guards against an empty dataloader.
    return total_loss / max(len(dataloader), 1)

# Evaluation loop
def evaluate_model(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, hparams: HParams, mask_type: str, epoch_info: str):
    """Evaluate masked reconstruction without gradients; return the mean of the per-batch losses.

    Masks are drawn at random on every call, so the value varies slightly between calls.
    ``criterion`` must return element-wise losses, as in ``train_model``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        with tqdm(total=len(dataloader), desc=f"{epoch_info} ({mask_type} Masking, Val)") as pbar:
            for batch in dataloader:
                batch = batch.to(hparams.device)

                masked_batch, target, mask_indices = create_mask(batch, hparams.mask_ratio, mask_type)
                output = model(masked_batch)

                mask = mask_indices.float()
                # clamp_min(1) avoids 0/0 = NaN when nothing was masked.
                loss = (criterion(output, target) * mask).sum() / mask.sum().clamp_min(1.0)

                total_loss += loss.item()
                pbar.set_postfix(loss=loss.item())
                pbar.update(1)

    # max(..., 1) guards against an empty dataloader.
    return total_loss / max(len(dataloader), 1)

# Train one freshly initialized model per masking strategy and record per-epoch losses.
mask_types = ['frame', 'joint', 'mixed']
train_losses = {strategy: [] for strategy in mask_types}
val_losses = {strategy: [] for strategy in mask_types}

for mask_type in mask_types:
    print(f"\n--- Training with {mask_type} Masking ---")

    # New weights, and therefore a new optimizer, for every masking strategy.
    model = MotionBERT(
        input_dim=hparams.input_dim,
        hidden_dim=hparams.hidden_dim,
        num_layers=hparams.num_layers,
        num_heads=hparams.num_heads,
        dropout_rate=hparams.dropout_rate,
        frame_length=hparams.frame_length,
    ).to(hparams.device)
    optimizer = optim.AdamW(model.parameters(), lr=hparams.learning_rate, weight_decay=0.01)
    # reduction='none' keeps element-wise losses so only masked positions get averaged.
    criterion = nn.MSELoss(reduction='none')

    for epoch_number in range(1, hparams.num_epochs + 1):
        epoch_info = f"Epoch {epoch_number}/{hparams.num_epochs}"

        avg_train_loss = train_model(model, train_dataloader, optimizer, criterion, hparams, mask_type, epoch_info)
        train_losses[mask_type].append(avg_train_loss)

        avg_val_loss = evaluate_model(model, val_dataloader, criterion, hparams, mask_type, epoch_info)
        val_losses[mask_type].append(avg_val_loss)

        print(f"{epoch_info} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # One checkpoint per masking strategy (final-epoch weights).
    checkpoint_path = f'motion_bert_model_{mask_type}2.pth'
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved: {checkpoint_path}")

print("\n--- Training Complete ---")

In [None]:
# Plot per-epoch training (solid) and validation (dashed) loss for every masking strategy.
fig, ax = plt.subplots(figsize=(12, 6))
for mask_type in mask_types:
    train_curve = train_losses[mask_type]
    val_curve = val_losses[mask_type]
    print(f"{mask_type} train_losses: {train_curve[-1]}")
    print(f"{mask_type} val_losses: {val_curve[-1]}")
    ax.plot(train_curve, label=f'{mask_type} Train Loss')
    ax.plot(val_curve, label=f'{mask_type} Val Loss', linestyle='--')

ax.set_title('Training and Validation Loss over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
def denormalize_motion(normalized_data_tensor: torch.Tensor, scaler: "StandardScaler", input_dim: int, frame_length: int):
    """Undo standardization of one motion sample and reshape it to per-joint 3D coordinates.

    Args:
        normalized_data_tensor: (frame_length, input_dim), or a batch
            (batch_size, frame_length, input_dim) of which only the FIRST sample is used.
            The tensor may require grad.
        scaler: fitted scaler whose ``inverse_transform`` maps standardized rows
            (n, input_dim) back to the original scale.
        input_dim: features per frame; (x, y, z) per joint, so a multiple of 3.
        frame_length: number of frames in the sample.

    Returns:
        np.ndarray shaped (frame_length, input_dim // 3, 3).
    """
    # detach(): .numpy() refuses tensors that require grad, so this also works outside no_grad().
    normalized_data_np = normalized_data_tensor.detach().cpu().numpy()

    # Drop the batch dimension, keeping only the first sample.
    if normalized_data_np.ndim == 3:
        normalized_data_np = normalized_data_np[0]

    # (frame_length, input_dim)
    denormalized_data_flat = scaler.inverse_transform(normalized_data_np.reshape(-1, input_dim))

    # (frame_length, num_joints, 3), with the joint count derived from input_dim (24 for 72).
    return denormalized_data_flat.reshape(frame_length, input_dim // 3, 3)

In [None]:
# For the first few test samples, show the original motion and its reconstruction
# by each masking-strategy checkpoint.
num_samples_to_show = min(3, len(test_dataset))  # avoid IndexError on small test splits
if num_samples_to_show == 0:
    print("No data available in the test dataset for visualization.")

model.eval()
with torch.no_grad():
    for i in range(num_samples_to_show):
        motion: torch.Tensor = test_dataset[i]
        denormalized_motion = denormalize_motion(motion, full_dataset.scaler, hparams.input_dim, hparams.frame_length)
        original_video = visualize_motion(denormalized_motion, H_CONNECTIONS, title="Original Motion Sample")
        if original_video:
            display(original_video)
        else:
            print("애니메이션을 표시할 수 없습니다.")

        for mask_type in mask_types:
            print(f"\n--- Visualizing {mask_type} Masking Reconstruction ---")

            # Load the checkpoint trained with this masking strategy.
            model_path = f'motion_bert_model_{mask_type}2.pth'
            try:
                state_dict = torch.load(model_path, map_location=hparams.device)
            except FileNotFoundError:
                print(f"Model file {model_path} not found. Skipping visualization for {mask_type}.")
                continue
            model.load_state_dict(state_dict)
            print(f"Loaded model from {model_path}")

            sample_data = motion.unsqueeze(0).to(hparams.device)  # (1, frame_length, input_dim)
            masked_sample, target_sample, mask_indices_sample = create_mask(sample_data, hparams.mask_ratio, mask_type)
            reconstructed_sample = model(masked_sample)

            # masked_sample equals the input at unmasked positions; take the model's
            # prediction only at the masked positions.
            final_reconstructed_sample = masked_sample.clone()
            final_reconstructed_sample[mask_indices_sample] = reconstructed_sample[mask_indices_sample]

            reconstructed_sequence = denormalize_motion(final_reconstructed_sample, full_dataset.scaler, hparams.input_dim, hparams.frame_length)
            motion_video = visualize_motion(reconstructed_sequence, H_CONNECTIONS, title=f"{mask_type}, Reconstructed Motion Sample")

            if motion_video:
                display(motion_video)
            else:
                print("애니메이션을 표시할 수 없습니다.")