In [1]:
import os
import glob
import numpy as np
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from loader import FeatureLoader
from model import AttentionMIL

In [2]:
def phase_step(model, dataloader, optimizer, criterion, phase, device):
    """Run one full pass over ``dataloader`` in training or validation mode.

    Args:
        model: Callable module returning a pair; the first element is used as
            the class scores (passed to ``criterion`` and ``argmax``), the
            second element is ignored here.
        dataloader: Iterable yielding mappings with tensor entries ``'X'``
            (model input) and ``'Y'`` (integer class label(s)).
        optimizer: Optimizer for ``model``; only touched when
            ``phase == 'train'``.
        criterion: Loss callable invoked as ``criterion(scores, Y.long())``.
        phase: ``'train'`` (gradients on, weights updated) or ``'valid'``
            (gradients off, no update).
        device: Device that ``X`` and ``Y`` are moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(mean loss per batch, accuracy)`` where accuracy
        is the number of correct predictions divided by the number of labels
        seen, so the value stays in ``[0, 1]`` for any batch size.

    Raises:
        ValueError: If ``phase`` is not ``'train'``/``'valid'`` or the
            dataloader yields no batches.
    """

    if phase not in ('train', 'valid'):
        raise ValueError(f"phase must be 'train' or 'valid', got {phase!r}")

    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    n_correct = 0
    n_labels = 0
    n_batches = 0

    with torch.set_grad_enabled(train):

        for data in dataloader:

            X = data['X'].to(device)
            Y = data['Y'].to(device)

            # Gradients only accumulate while training, so only clear them then.
            if train:
                optimizer.zero_grad()

            # Forward pass
            P, _ = model(X)
            loss = criterion(P, Y.long())

            # Backward pass
            if train:
                loss.backward()
                optimizer.step()

            with torch.no_grad():
                n_correct += P.argmax(dim=-1).eq(Y).sum().item()

            total_loss += loss.item()
            # numel() is 1 for a 0-d label, so un-batched items count as one sample.
            n_labels += Y.numel()
            n_batches += 1

    if n_batches == 0 or n_labels == 0:
        raise ValueError(f'dataloader produced no samples for phase {phase!r}')

    phase_loss = total_loss / n_batches
    phase_metr = n_correct / n_labels

    return phase_loss, phase_metr

In [3]:
def main(config):
    """Train and validate an ``AttentionMIL`` classifier on one cross-validation fold.

    The ``*.h5`` files in ``config["data_dir"]`` are split with a fixed
    5-fold ``StratifiedKFold`` (``random_state=0``, so every fold index refers
    to the same partition). The class label of a file is read from the sixth
    character of its file name (a 1-based digit, shifted to 0-based).

    Args:
        config: Dict with the keys
            ``"data_dir"``  directory searched for ``*.h5`` files,
            ``"save_dir"``  directory that receives ``model_f{fold}.pt``,
            ``"epochs"``    number of training epochs,
            ``"fold"``      index (0-4) of the split used for validation,
            ``"seed"``      optional; when present, seeds NumPy and PyTorch.

    Side effects:
        Prints progress and writes a checkpoint (epoch, model state, optimizer
        state) each time the validation accuracy improves.

    Raises:
        FileNotFoundError: If ``data_dir`` contains no ``*.h5`` files.
    """

    print(f'\nFold-{config["fold"]} ...', flush=True)

    # Seed the global RNGs so the shuffles (and any torch-based weight
    # initialisation) repeat from run to run. The fold partition itself is
    # pinned separately by random_state=0 below.
    seed = config.get("seed")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Arrange files and labels
    files = sorted(glob.glob(f'{config["data_dir"]}/*.h5'))
    if not files:
        raise FileNotFoundError(f'No .h5 files found in {config["data_dir"]}')

    # basename() instead of split('/') so the parsing does not depend on the
    # platform's path separator.
    labels = np.array([int(os.path.basename(f)[5]) for f in files]) - 1

    # K-Fold split
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    train_indices, valid_indices = list(skf.split(files, labels))[config["fold"]]

    train_samples = [{"X": files[i], "Y": labels[i]} for i in train_indices]
    valid_samples = [{"X": files[i], "Y": labels[i]} for i in valid_indices]

    np.random.shuffle(train_samples)

    print(f'\nNumber of train files: {len(train_samples)}', flush=True)
    print(f'Number of valid_files: {len(valid_samples)}', flush=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'\nDevice: {device}')

    # Create datasets and loaders. batch_size=None disables automatic
    # batching, so every step receives exactly one dataset item.
    # Pinned memory only helps host-to-GPU copies, so skip it on CPU.
    use_pinned = device.type == "cuda"
    train_dl = DataLoader(FeatureLoader(train_samples), batch_size=None,
                          shuffle=True, pin_memory=use_pinned)
    valid_dl = DataLoader(FeatureLoader(valid_samples), batch_size=None,
                          shuffle=False, pin_memory=use_pinned)

    # Create model
    model = AttentionMIL(feature_size=1024, classes=3).to(device)
    print('\nModel compiled', flush=True)

    # Loss, optimizer and scheduler
    #
    # NOTE(review): an earlier revision computed inverse-frequency class
    # weights from the hard-coded counts [96, 96, 48] here but never passed
    # them to the loss, so training was (and still is) unweighted. That dead
    # computation has been removed. If class weighting is wanted, be aware
    # that with one item per step and the default reduction='mean',
    # CrossEntropyLoss normalises by the selected class weight -- confirm the
    # intended scheme before enabling `weight=`.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), 1e-3)
    # mode='max' because the scheduler is stepped on validation accuracy.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

    os.makedirs(config["save_dir"], exist_ok=True)
    checkpoint_path = f'{config["save_dir"]}/model_f{config["fold"]}.pt'

    # Training
    print('\nTraining ...')

    monitor_metr = 0

    for epoch in range(config["epochs"]):

        print(f'\nEpoch {epoch:03}/{config["epochs"]:03}')

        valid_metr = None

        for phase, dataloader in (('train', train_dl), ('valid', valid_dl)):

            phase_loss, phase_metr = phase_step(model, dataloader, optimizer, criterion, phase, device)
            print(f'{phase}_loss: {phase_loss:0.4f} - {phase}_accuracy: {phase_metr:0.4f}')

            # Keep the validation score under its own name instead of relying
            # on the loop variable still holding it after the loop ends.
            if phase == 'valid':
                valid_metr = phase_metr

        # Checkpoint only on a strict improvement of validation accuracy.
        if valid_metr > monitor_metr:

            state = {'epoch': epoch,
                     'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()}

            torch.save(state, checkpoint_path)
            print(f'checkpoint saved: {checkpoint_path}')

            monitor_metr = valid_metr

        scheduler.step(valid_metr)

    print(f'\nComplete! Validation Accuracy = {monitor_metr}')

In [4]:
if __name__ == "__main__":

    # Settings shared by every fold; only "fold" differs between runs.
    base_config = {
        "data_dir": '/mnt/scratch/crc/data/features/features_0512',
        "save_dir": '/mnt/scratch/crc/models/models_0512',
        "epochs": 25,
        "seed": 0,
    }

    # Create the output directory once, before any training starts.
    # exist_ok=True replaces the check-then-create pattern, which can race
    # with another process, and is a no-op when the notebook is re-run.
    os.makedirs(base_config["save_dir"], exist_ok=True)

    # Train and validate each of the five cross-validation folds in turn.
    for fold in range(5):

        main({**base_config, "fold": fold})

        print('\n\n')


Fold-0 ...

Number of train files: 240
Number of valid_files: 60

Device: cuda

Model compiled

Training ...

Epoch 000/025
train_loss: 0.9069 - train_accuracy: 0.6417
valid_loss: 0.4044 - valid_accuracy: 0.8333
checkpoint saved: /mnt/scratch/crc/models/models_0512/model_f0.pt

Epoch 001/025
train_loss: 0.3491 - train_accuracy: 0.8667
valid_loss: 0.5030 - valid_accuracy: 0.8333

Epoch 002/025
train_loss: 0.1553 - train_accuracy: 0.9458
valid_loss: 0.4328 - valid_accuracy: 0.7667

Epoch 003/025
train_loss: 0.1565 - train_accuracy: 0.9458
valid_loss: 0.0948 - valid_accuracy: 0.9500
checkpoint saved: /mnt/scratch/crc/models/models_0512/model_f0.pt

Epoch 004/025
train_loss: 0.1209 - train_accuracy: 0.9583
valid_loss: 0.1029 - valid_accuracy: 0.9500

Epoch 005/025
train_loss: 0.0524 - train_accuracy: 0.9875
valid_loss: 0.0355 - valid_accuracy: 1.0000
checkpoint saved: /mnt/scratch/crc/models/models_0512/model_f0.pt

Epoch 006/025
train_loss: 0.1581 - train_accuracy: 0.9542
valid_loss: 0.0

valid_loss: 0.2463 - valid_accuracy: 0.9500

Epoch 021/025
train_loss: 0.0023 - train_accuracy: 1.0000
valid_loss: 0.2477 - valid_accuracy: 0.9500

Epoch 022/025
train_loss: 0.0021 - train_accuracy: 1.0000
valid_loss: 0.2478 - valid_accuracy: 0.9500

Epoch 023/025
train_loss: 0.0021 - train_accuracy: 1.0000
valid_loss: 0.2480 - valid_accuracy: 0.9500

Epoch 024/025
train_loss: 0.0021 - train_accuracy: 1.0000
valid_loss: 0.2482 - valid_accuracy: 0.9500

Complete! Validation Accuracy = 0.9666666666666667




Fold-3 ...

Number of train files: 240
Number of valid_files: 60

Device: cuda

Model compiled

Training ...

Epoch 000/025
train_loss: 0.8046 - train_accuracy: 0.6917
valid_loss: 0.2466 - valid_accuracy: 0.9167
checkpoint saved: /mnt/scratch/crc/models/models_0512/model_f3.pt

Epoch 001/025
train_loss: 0.2252 - train_accuracy: 0.8958
valid_loss: 0.2337 - valid_accuracy: 0.8833

Epoch 002/025
train_loss: 0.2180 - train_accuracy: 0.9167
valid_loss: 0.1076 - valid_accuracy: 0.9667
chec