In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from cnn.classifier import Model
import yaml
import json
import numpy as np
from torch.utils.data import Dataset


# 설정
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
NUM_EPOCHS = 100
VALIDATION_SPLIT = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 데이터 준비
with open('/root/25th-conference-fakebusters/model/fakecatcher/config.yaml', 'r') as file:
    config = yaml.safe_load(file)
    fps_standard = config.get('fps_standard', 30)  # 기본값 30
    time_interval = config.get('seg_time_interval', 1)  # 기본값 1
    w = fps_standard * time_interval

with open('/root/25th-conference-fakebusters/model/fakecatcher/ppg_map_results_updated.json', 'r') as file:
    data = json.load(file)

# Extract PPG Maps and Labels
ppg_maps = []
labels = []
for item in data:
    ppg_maps.append(np.array(item['ppg_map']))  # (64, w) 형태의 배열
    labels.append(item['label'])               # 0 or 1

# Convert to numpy arrays
ppg_maps = np.array(ppg_maps)
labels = np.array(labels)     

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """
        Args:
            data (np.array): Shape (2831, 64, 90)
            labels (np.array): Shape (2831,)
        """
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: (data_sample, label)
        """
        return self.data[idx], self.labels[idx]

dataset = CustomDataset(ppg_maps, labels)
print(f"Dataset size: {len(dataset)}")
print(f"Sample PPG map shape: {dataset[0][0].shape}")
print(f"Sample label: {dataset[0][1]}")
train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 모델 준비
model_name = 'baseCNN'  # "ResNet34", "EfficientNetB3"로 변경 가능
model = Model.get_model(model_name, w)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 학습
Model.train(model, train_loader, val_loader, criterion, optimizer, device, epochs=NUM_EPOCHS)

# # 테스트
# test_loss, test_accuracy = Model.test(model, val_loader, criterion, device)
# print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


Dataset size: 2831
Sample PPG map shape: torch.Size([64, 90])
Sample label: 1
Epoch 1/100, Loss: 0.8494
Validation Loss: 0.6777, Accuracy: 0.5626
Epoch 2/100, Loss: 0.6932
Validation Loss: 0.6775, Accuracy: 0.5855
Epoch 3/100, Loss: 0.6830
Validation Loss: 0.6846, Accuracy: 0.5467
Epoch 4/100, Loss: 0.6825
Validation Loss: 0.6772, Accuracy: 0.5414
Epoch 5/100, Loss: 0.6742
Validation Loss: 0.6923, Accuracy: 0.5344
Epoch 6/100, Loss: 0.6782
Validation Loss: 0.6690, Accuracy: 0.5873
Epoch 7/100, Loss: 0.6693
Validation Loss: 0.6670, Accuracy: 0.5855
Epoch 8/100, Loss: 0.6680
Validation Loss: 0.6686, Accuracy: 0.5873
Epoch 9/100, Loss: 0.6669
Validation Loss: 0.6690, Accuracy: 0.5802
Epoch 10/100, Loss: 0.6585
Validation Loss: 0.6666, Accuracy: 0.5996
Epoch 11/100, Loss: 0.6509
Validation Loss: 0.6671, Accuracy: 0.5944
Epoch 12/100, Loss: 0.6513
Validation Loss: 0.6630, Accuracy: 0.6049
Epoch 13/100, Loss: 0.6454
Validation Loss: 0.6577, Accuracy: 0.6243
Epoch 14/100, Loss: 0.6354
Validat

In [4]:
import matplotlib.pyplot as plt

# Assuming Model.train() returns the training and validation loss history
train_losses, val_losses = Model.train(model, train_loader, val_loader, criterion, optimizer, device, epochs=NUM_EPOCHS)

# Plotting the loss changes
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Changes Over Epochs')
plt.legend()
plt.show()

Epoch 1/100, Loss: 0.0267
Validation Loss: 1.6056, Accuracy: 0.6631
Epoch 2/100, Loss: 0.0395
Validation Loss: 1.3771, Accuracy: 0.6614
Epoch 3/100, Loss: 0.0266
Validation Loss: 1.8077, Accuracy: 0.6543
Epoch 4/100, Loss: 0.0284
Validation Loss: 1.8008, Accuracy: 0.6561
Epoch 5/100, Loss: 0.0248
Validation Loss: 1.6707, Accuracy: 0.6649
Epoch 6/100, Loss: 0.0267
Validation Loss: 1.6138, Accuracy: 0.6490
Epoch 7/100, Loss: 0.0316
Validation Loss: 1.6858, Accuracy: 0.6684
Epoch 8/100, Loss: 0.0340
Validation Loss: 1.6908, Accuracy: 0.6631
Epoch 9/100, Loss: 0.0299
Validation Loss: 1.7709, Accuracy: 0.6878
Epoch 10/100, Loss: 0.0278
Validation Loss: 1.6210, Accuracy: 0.6684
Epoch 11/100, Loss: 0.0243
Validation Loss: 1.6626, Accuracy: 0.6614
Epoch 12/100, Loss: 0.0306
Validation Loss: 1.7488, Accuracy: 0.6684
Epoch 13/100, Loss: 0.0210
Validation Loss: 1.7285, Accuracy: 0.6667
Epoch 14/100, Loss: 0.0336
Validation Loss: 1.9280, Accuracy: 0.6473
Epoch 15/100, Loss: 0.0259
Validation Loss:

TypeError: cannot unpack non-iterable NoneType object