In [1]:
from torch.utils.data import DataLoader
from resnet import ResNet
from sklearn.model_selection import train_test_split, StratifiedKFold
import numpy as np
from custom_dataset import FinalDataset
import torch
from torch import nn

In [2]:
# HDF5 file that backs every FinalDataset instance in this notebook.
file_path = "Dataset.hdf5"

# Dataset built without an index selection; later cells read `dataset.labels`
# from it to stratify the cross-validation folds.
dataset = FinalDataset(file_path)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, device, pos_weight=50.0):
    """Run one training epoch with a class-weighted, element-wise loss.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches. Its ``dataset`` is
            only used to report the total sample count in the progress lines.
        model: Module being optimised; it is switched to training mode.
        loss_fn: Loss callable. It should return one value per element
            (e.g. built with ``reduction="none"``) so the per-element weights
            applied here have an effect before averaging.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        pos_weight: Weight given to elements whose target equals 1; every
            other element keeps weight 1. Defaults to 50.0.

    Returns:
        None. The loss is printed every 100 batches.
    """
    dataset = dataloader.dataset
    # Prefer the explicit index selection when the dataset exposes one, but
    # fall back to len(dataset) so datasets without `.indices` also work.
    indices = getattr(dataset, "indices", None)
    size = len(indices) if indices is not None else len(dataset)

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)

        # Up-weight positive targets to counter class imbalance.
        weights = torch.ones_like(y)
        weights[y == 1] = pos_weight

        loss = loss_fn(pred, y)
        loss = (loss * weights).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f} [{batch * len(X):>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset.indices)
    num_batches = len(dataloader)
    model.eval()
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            all_preds.append(pred.cpu())
            all_targets.append(y.cpu())
    
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    
    # Try different thresholds
    for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
        predicted = (all_preds > threshold).float()
        correct = (predicted == all_targets).sum().item()
        
        tp = ((predicted == 1) & (all_targets == 1)).sum().item()
        fp = ((predicted == 1) & (all_targets == 0)).sum().item()
        fn = ((predicted == 0) & (all_targets == 1)).sum().item()
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        print(f"Threshold {threshold}: Accuracy={100*correct/size:>0.1f}%, "
              f"Precision={100*precision:.4f}%, Recall={100*recall:.4f}%, F1={f1:.4f}")


In [None]:
# Per-sample labels of the full dataset; used only to stratify the folds.
labels = np.array(dataset.labels)
file_path = "Dataset.hdf5"  # HDF5 file each fold's subset is built from

SEED = 42
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# nn.BCEWithLogitsLoss applies the sigmoid itself, whereas nn.BCELoss expects
# probabilities in [0, 1]. BCELoss is kept here, which presumes the ResNet
# output already went through a sigmoid -- TODO confirm against resnet.py.
# reduction="none" keeps per-element losses so train_loop can weight them
# before averaging.
loss_fn = nn.BCELoss(reduction="none")
epochs = 5

# Stratified K-Fold training and validation
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f"Fold {fold+1}")

    # Build a fresh model and optimizer for every fold. Reusing one model
    # across folds would let it train on samples that later folds use for
    # validation, which invalidates the cross-validation metrics.
    torch.manual_seed(SEED + fold)  # reproducible init and shuffling per fold
    model = ResNet(12, 24).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    train_subset = FinalDataset(file_path, train_idx)
    val_subset = FinalDataset(file_path, val_idx)
    train_loader = DataLoader(train_subset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=256, shuffle=False)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1} - Fold {fold+1}\n-------------------------------")
        train_loop(train_loader, model, loss_fn, optimizer, device)
        test_loop(val_loader, model, loss_fn, device)

print("Done!")
# `model` now refers to the network trained during the last fold.
torch.save(model.state_dict(), "model.pth")

Using cuda device
Fold 1
Epoch 1 - Fold 1
-------------------------------
loss: 0.808981 [    0/329565]
loss: 1.202881 [25600/329565]
loss: 1.836979 [51200/329565]
loss: 2.486901 [76800/329565]
loss: 2.012782 [102400/329565]
loss: 1.135640 [128000/329565]
loss: 1.204883 [153600/329565]
loss: 1.161057 [179200/329565]
loss: 1.813601 [204800/329565]
loss: 1.474037 [230400/329565]
loss: 1.006143 [256000/329565]
loss: 1.158084 [281600/329565]
loss: 0.819912 [307200/329565]
Threshold 0.1: Accuracy=2.1%, Precision=0.0205, Recall=1.0000, F1=0.0402
Threshold 0.2: Accuracy=5.1%, Precision=0.0211, Recall=0.9960, F1=0.0413
Threshold 0.3: Accuracy=13.2%, Precision=0.0229, Recall=0.9907, F1=0.0447
Threshold 0.4: Accuracy=22.6%, Precision=0.0253, Recall=0.9800, F1=0.0494
Threshold 0.5: Accuracy=32.6%, Precision=0.0284, Recall=0.9614, F1=0.0553
Epoch 2 - Fold 1
-------------------------------
loss: 1.088853 [    0/329565]
loss: 1.440027 [25600/329565]
loss: 1.230439 [51200/329565]
loss: 1.036208 [7680