In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MeshDataset(Dataset):
    def __init__(self, mesh_data, pairs=None):
        """
        Args:
            mesh_data (np.ndarray): (859, 6890, 3) shape의 numpy 배열.
            pairs (list of tuple): 각 tuple이 (i, j) 인덱스 pair. 기본값은 연속된 frame pair.
        """
        self.mesh_data = mesh_data
        if pairs is None:
            # 기본적으로 연속된 frame pair (0,1), (1,2), ..., (857,858)를 사용
            self.pairs = [(i, i+1) for i in range(mesh_data.shape[0]-1)]
        else:
            self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        i, j = self.pairs[idx]
        frame_i = self.mesh_data[i]  # shape: (6890, 3)
        frame_j = self.mesh_data[j]  # shape: (6890, 3)
        # torch.Tensor로 변환 (필요한 경우)
        return torch.from_numpy(frame_i).float(), torch.from_numpy(frame_j).float()

# sample_data 읽기
sample_data = np.load("../../../results/5 Effective Boxing Combos To Drill In_chunk2/mesh.npy")

# Dataset 생성 (원하는 pair 방식을 사용)
dataset = MeshDataset(sample_data)

# DataLoader 생성 (예: batch_size 4, shuffling)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# DataLoader 사용 예제: 한 배치에서 각 pair를 가져오기
for batch in dataloader:
    frames_i, frames_j = batch
    print("Frame_i batch shape:", frames_i.shape)  # (batch_size, 6890, 3)
    print("Frame_j batch shape:", frames_j.shape)  # (batch_size, 6890, 3)
    break


Frame_i batch shape: torch.Size([4, 6890, 3])
Frame_j batch shape: torch.Size([4, 6890, 3])


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class ConvEncoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        # conv3 출력: (batch, 128, 861) → flatten하면 128*861
        self.fc_mu = nn.Linear(128 * 861, latent_dim)
        self.fc_logvar = nn.Linear(128 * 861, latent_dim)
        
    def forward(self, x):
        # x: (batch, 6890, 3) → (batch, 3, 6890)
        x = x.permute(0, 2, 1)
        x = F.relu(self.conv1(x))  # (batch, 32, 3445)
        x = F.relu(self.conv2(x))  # (batch, 64, 1722)
        x = F.relu(self.conv3(x))  # (batch, 128, 861)
        x = x.view(x.size(0), -1)  # (batch, 128*861)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class ConvDecoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        # latent_dim → 128*861
        self.fc = nn.Linear(latent_dim, 128 * 861)
        # deconv1: (batch, 128, 861) → (batch, 64, 1722)
        self.deconv1 = nn.ConvTranspose1d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        # deconv2: (batch, 64, 1722) → (batch, 32, 3445) with output_padding=1
        self.deconv2 = nn.ConvTranspose1d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1, output_padding=1)
        # deconv3: (batch, 32, 3445) → (batch, 3, 6890)
        self.deconv3 = nn.ConvTranspose1d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1)
        
    def forward(self, z):
        # z: (batch, latent_dim)
        x = self.fc(z)  # (batch, 128*861)
        x = x.view(x.size(0), 128, 861)  # (batch, 128, 861)
        x = F.relu(self.deconv1(x))       # (batch, 64, 1722)
        x = F.relu(self.deconv2(x))       # (batch, 32, 3445)
        x = self.deconv3(x)               # (batch, 3, 6890)
        x = x.permute(0, 2, 1)            # (batch, 6890, 3)
        return x

class ConvVAE(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        self.encoder = ConvEncoder(latent_dim=latent_dim)
        self.decoder = ConvDecoder(latent_dim=latent_dim)
        
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar


In [9]:
def vae_loss(recon_x, x, mu, logvar):
    # 재구성 손실 (MSE)와 KL divergence 계산 (배치당 평균)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl_div, recon_loss, kl_div


In [10]:
import torch.optim as optim

latent_dim = 64
lr = 1e-3
epochs = 10

vae = ConvVAE(latent_dim=latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=lr)

for epoch in range(epochs):
    vae.train()
    total_loss = 0.0
    for frames_i, frames_j in dataloader:
        # 예시로 frames_i 사용 (shape: (batch, 6890, 3))
        optimizer.zero_grad()
        x_hat, mu, logvar = vae(frames_i)
        loss, recon_loss, kl_div = vae_loss(x_hat, frames_i, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * frames_i.size(0)
    avg_loss = total_loss / len(dataset)
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")


Epoch [1/10] - Loss: nan
Epoch [2/10] - Loss: nan
Epoch [3/10] - Loss: nan
Epoch [4/10] - Loss: nan
Epoch [5/10] - Loss: nan
Epoch [6/10] - Loss: nan
Epoch [7/10] - Loss: nan
Epoch [8/10] - Loss: nan
Epoch [9/10] - Loss: nan
Epoch [10/10] - Loss: nan
