In [4]:
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Load Preprocessed .npy Files

base_path = os.path.join(os.path.dirname(__file__), 'data') if '__file__' in globals() else 'data/'
audio_train = np.load(os.path.join(base_path, "audio_filtered_train.npy"), allow_pickle=True).item()
audio_val = np.load(os.path.join(base_path, "audio_filtered_val.npy"), allow_pickle=True).item()
audio_test = np.load(os.path.join(base_path, "audio_filtered_test.npy"), allow_pickle=True).item()

rgb_train = np.load(os.path.join(base_path, "rgb_filtered_train.npy"), allow_pickle=True).item()
rgb_val = np.load(os.path.join(base_path, "rgb_filtered_val.npy"), allow_pickle=True).item()
rgb_test = np.load(os.path.join(base_path, "rgb_filtered_test.npy"), allow_pickle=True).item()

flow_train = np.load(os.path.join(base_path, "flow_filtered_train.npy"), allow_pickle=True).item()
flow_val = np.load(os.path.join(base_path, "flow_filtered_val.npy"), allow_pickle=True).item()
flow_test = np.load(os.path.join(base_path, "flow_filtered_test.npy"), allow_pickle=True).item()


In [None]:
# Dataset Definition

class VideoToAudioDataset(Dataset):
    def __init__(self, rgb_dict, flow_dict, audio_dict):
        self.keys = list(audio_dict.keys())
        self.rgb_dict = rgb_dict
        self.flow_dict = flow_dict
        self.audio_dict = audio_dict

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]
        rgb_feat = self.rgb_dict[key]      # (18, 768)
        flow_feat = self.flow_dict[key]    # (18, 768)
        x = np.concatenate([rgb_feat, flow_feat], axis=-1)  # (18, 1536)
        y = self.audio_dict[key]           # (18, 128)
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


In [8]:
# 데이터 불러오기 

train_loader = DataLoader(VideoToAudioDataset(rgb_train, flow_train, audio_train), batch_size=4, shuffle=True)
val_loader = DataLoader(VideoToAudioDataset(rgb_val, flow_val, audio_val), batch_size=4)
test_loader = DataLoader(VideoToAudioDataset(rgb_test, flow_test, audio_test), batch_size=4)


In [None]:
# LSTM-based generator 

class AudioFeatureGenerator(nn.Module):
    def __init__(self, input_dim=1536, hidden_dim=512, num_layers=2, output_dim=218):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.3)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.lstm(x)  # (B, 18, H)
        return self.fc(out)    # (B, 18, 218)

In [12]:
# Training & Validation Functions

def sequence_mse_loss(pred, target):
    return ((pred - target) ** 2).mean()

def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(x)  # (B, 18, 218)
        loss = criterion(pred, y)  # element-wise comparison
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = criterion(pred, y)
            total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

In [13]:
# Run Training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AudioFeatureGenerator().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = sequence_mse_loss

for epoch in range(1, 11):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)
    print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")


RuntimeError: The size of tensor a (218) must match the size of tensor b (128) at non-singleton dimension 2