In [None]:
# TEM nanopore detection with anti-overfitting measures (minimal version: loss only, no extra metrics)
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2
import tkinter as tk
from tkinter import filedialog
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages with channel dropout in between.

    padding=1 keeps the spatial size unchanged. The layer order inside
    ``self.conv`` must not change: saved checkpoints use the keys
    ``conv.0`` / ``conv.1`` / ``conv.4`` / ``conv.5``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super().__init__()
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Regularization: drop whole feature maps between the two convolutions.
        regularizer = [nn.Dropout2d(dropout_rate)]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*(first_stage + regularizer + second_stage))

    def forward(self, x):
        features = self.conv(x)
        return features

class UNet(nn.Module):
    """Four-level U-Net returning per-pixel logits (no sigmoid applied).

    Input height and width must be divisible by 16 (four 2x max-pool stages),
    otherwise the skip-connection concatenations will not line up.
    Attribute names are part of the checkpoint format and must stay fixed.
    """

    def __init__(self, in_channels=1, out_channels=1, dropout_rate=0.2):
        super().__init__()
        w1, w2, w3, w4, w5 = 64, 128, 256, 512, 1024

        # Contracting path: DoubleConv, then halve the resolution
        self.conv1 = DoubleConv(in_channels, w1, dropout_rate)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(w1, w2, dropout_rate)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(w2, w3, dropout_rate)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(w3, w4, dropout_rate)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(w4, w5, dropout_rate)

        # Expanding path: each DoubleConv sees upsampled + skip channels (2x width)
        self.up6 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(2 * w4, w4, dropout_rate)
        self.up7 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(2 * w3, w3, dropout_rate)
        self.up8 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(2 * w2, w2, dropout_rate)
        self.up9 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(2 * w1, w1, dropout_rate)

        # 1x1 projection to the requested number of output channels
        self.conv10 = nn.Conv2d(w1, out_channels, kernel_size=1)

    def forward(self, x):
        down_blocks = (self.conv1, self.conv2, self.conv3, self.conv4)
        poolers = (self.pool1, self.pool2, self.pool3, self.pool4)

        skips = []
        h = x
        for block, pool in zip(down_blocks, poolers):
            h = block(h)
            skips.append(h)  # kept for the matching decoder stage
            h = pool(h)

        h = self.conv5(h)

        upsamplers = (self.up6, self.up7, self.up8, self.up9)
        up_blocks = (self.conv6, self.conv7, self.conv8, self.conv9)
        # Deepest skip pairs with the first decoder stage.
        for upsample, block, skip in zip(upsamplers, up_blocks, reversed(skips)):
            h = block(torch.cat([upsample(h), skip], dim=1))

        return self.conv10(h)

def get_original_and_augmented_groups(data_dir):
    """Map each original image name to its augmented ``.tif`` variants.

    Only files that end in ``.tif`` and contain ``_aug_`` are considered; the
    text before the first ``_aug_`` is the original name, and the integer
    after it (up to the first '.') is the augmentation index.

    Returns:
        dict: original name -> list of ``(filename, aug_index)`` sorted by index.

    Raises:
        ValueError: if an augmentation suffix is not an integer.
    """
    tif_names = sorted(name for name in os.listdir(data_dir) if name.endswith('.tif'))

    groups = {}
    for name in tif_names:
        if "_aug_" not in name:
            continue
        pieces = name.split("_aug_")
        base = pieces[0]
        index = int(pieces[1].split('.')[0])
        groups.setdefault(base, []).append((name, index))

    return {base: sorted(members, key=lambda pair: pair[1]) for base, members in groups.items()}

class NanoporeDataset(Dataset):
    """Paired TEM image / mask samples restricted to a chosen set of originals.

    ``<stem>.tif`` is included when its original name is in ``original_images``
    and a mask named ``<stem>_mask.png`` exists in the same directory. The
    original name is the text before ``_aug_``, or the stem for un-augmented
    files. Samples are resized to 1024x1024 and returned as float32 tensors of
    shape (1, 1024, 1024), scaled to [0, 1].
    """

    def __init__(self, data_dir, original_images, is_validation=False):
        self.data_dir = data_dir
        self.is_validation = is_validation

        all_files = sorted(os.listdir(data_dir))
        # Sets give O(1) membership; list lookups made this loop quadratic.
        available = set(all_files)
        wanted = set(original_images)

        self.image_files = []
        for file in all_files:
            if not file.endswith('.tif'):
                continue
            original_name = file.split("_aug_")[0] if "_aug_" in file else os.path.splitext(file)[0]
            # Only include if the image belongs to the selection and its mask exists
            if original_name in wanted and self._mask_name(file) in available:
                self.image_files.append(file)

        print(f"{'Validation' if is_validation else 'Training'} set contains {len(self.image_files)} images")

    @staticmethod
    def _mask_name(image_name):
        """Return the mask filename paired with *image_name* (``<stem>_mask.png``)."""
        return os.path.splitext(image_name)[0] + '_mask.png'

    @staticmethod
    def _normalize_mask(mask):
        """Return *mask* as float32 in [0, 1], accepting 0/255 or 0/1 encodings."""
        mask = mask.astype(np.float32)
        # Dividing a 0/1 mask by 255 would turn every foreground pixel into
        # ~0.004, leaving BCE with an almost-empty target.
        if mask.max() > 1.0:
            mask /= 255.0
        return mask

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            img_path = os.path.join(self.data_dir, self.image_files[idx])
            mask_path = os.path.join(self.data_dir, self._mask_name(self.image_files[idx]))

            image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            if image is None or mask is None:
                raise ValueError(f"Failed to read image or mask: {self.image_files[idx]}")

            # INTER_AREA for the image; INTER_NEAREST keeps mask labels un-blended
            image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_AREA)
            mask = cv2.resize(mask, (1024, 1024), interpolation=cv2.INTER_NEAREST)

            image = image.astype(np.float32) / 255.0
            mask = self._normalize_mask(mask)

            # Add channel dimension -> (1, H, W)
            image = torch.from_numpy(image).unsqueeze(0)
            mask = torch.from_numpy(mask).unsqueeze(0)

            return image, mask

        except Exception as e:
            # Best-effort: log and return an all-zero sample so one unreadable
            # file does not abort a long training run.
            print(f"Error loading sample {idx}: {str(e)}")
            placeholder = torch.zeros((1, 1024, 1024), dtype=torch.float32)
            return placeholder, placeholder

def check_overfitting(train_loss, val_loss, threshold=0.3):
    """Flag a suspicious train/validation gap.

    Returns True when the validation loss is positive and the training loss is
    more than ``threshold`` (as a fraction) below it; otherwise False.
    """
    if not val_loss > 0:
        return False
    return bool(train_loss < val_loss * (1 - threshold))

def _checkpoint_payload(model, optimizer, epoch, train_loss, val_loss):
    """Build the dict stored in the per-epoch and best-model checkpoints."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }


def _save_prediction_figures(images, masks, logits, vis_dir, epoch, max_samples=2):
    """Save image / true-mask / predicted-mask panels for up to *max_samples* items of a batch."""
    probs = torch.sigmoid(logits)
    for i in range(min(max_samples, len(images))):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (images[i], 'TEM Image'),
            (masks[i], 'True Mask'),
            (probs[i], 'Predicted Mask'),
        )
        for ax, (tensor, title) in zip(axes, panels):
            ax.imshow(tensor.cpu().numpy()[0], cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(vis_dir, f'epoch_{epoch}_sample_{i}.png'))
        plt.close()


def train_nanopore_detector(data_dir, output_dir,
                            batch_size=1,
                            epochs=30,
                            learning_rate=5e-5,
                            patience=3,
                            weight_decay=1e-4):
    """Train a UNet nanopore segmenter with anti-overfitting measures.

    The model is regularized with dropout, Adam weight decay,
    ReduceLROnPlateau learning-rate decay and early stopping on validation
    loss. Images are split into train/validation per *original* image, so all
    augmentations of one original land on the same side of the split.

    Args:
        data_dir: folder containing ``*.tif`` images and ``*_mask.png`` masks.
        output_dir: where checkpoints, ``best_model.pth``, loss curves and
            visualizations are written.
        batch_size, epochs, learning_rate, patience, weight_decay: training
            hyper-parameters; ``patience`` is the number of non-improving
            epochs tolerated before stopping.

    Returns:
        The model as it stands after the last epoch run (the best-validation
        weights are saved separately to ``best_model.pth``).

    Raises:
        ValueError: if the training or validation set is empty.
        Exception: any error during training is re-raised after an
            ``emergency_save.pth`` state dict is written.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    use_pin_memory = torch.cuda.is_available()

    # Get image groups
    original_groups = get_original_and_augmented_groups(data_dir)
    original_images = list(original_groups.keys())
    print(f"Found {len(original_images)} original images with augmentations")

    # Split by original image so augmented copies cannot leak into validation
    np.random.seed(42)  # For reproducibility
    val_size = max(2, int(len(original_images) * 0.15))  # 15% for validation
    val_originals = np.random.choice(original_images, size=val_size, replace=False)
    train_originals = [img for img in original_images if img not in val_originals]

    print(f"Using {len(train_originals)} originals for training, {len(val_originals)} for validation")
    print(f"Validation originals: {val_originals}")

    train_dataset = NanoporeDataset(data_dir, train_originals, is_validation=False)
    val_dataset = NanoporeDataset(data_dir, val_originals, is_validation=True)
    # Fail fast with a clear message instead of a ZeroDivisionError after epoch 1.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("Training and validation sets must both be non-empty; "
                         "check that every image has a matching '<name>_mask.png'.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=use_pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=use_pin_memory)

    model = UNet(in_channels=1, out_channels=1, dropout_rate=0.2).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )

    best_val_loss = float('inf')
    early_stop_counter = 0
    train_losses = []
    val_losses = []

    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    print("\nStarting training with anti-overfitting measures...")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate: {learning_rate}")
    print(f"Weight decay: {weight_decay}")

    try:
        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0

            with tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}') as pbar:
                for batch_idx, (images, masks) in enumerate(pbar):
                    images, masks = images.to(device), masks.to(device)

                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    pbar.set_postfix({'loss': loss.item()})

                    # Periodically release cached GPU memory
                    if batch_idx % 10 == 0 and torch.cuda.is_available():
                        torch.cuda.empty_cache()

            avg_train_loss = train_loss / len(train_loader)
            train_losses.append(avg_train_loss)

            # Validation phase
            model.eval()
            val_loss = 0.0

            with torch.no_grad():
                for batch_idx, (images, masks) in enumerate(val_loader):
                    images, masks = images.to(device), masks.to(device)
                    outputs = model(images)
                    val_loss += criterion(outputs, masks).item()

                    # First batch of every 5th epoch. Keyed on the batch index:
                    # testing ``val_loss == 0`` after accumulating never fired.
                    if epoch % 5 == 0 and batch_idx == 0:
                        _save_prediction_figures(images, masks, outputs, vis_dir, epoch)

            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)

            # Build the summary line first and print it once, so tqdm output
            # cannot interleave with a half-printed line.
            summary = f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}'
            if check_overfitting(avg_train_loss, avg_val_loss):
                summary += ' - WARNING: Potential overfitting detected!'
            print(summary)

            scheduler.step(avg_val_loss)

            # Save checkpoint every epoch
            torch.save(_checkpoint_payload(model, optimizer, epoch, avg_train_loss, avg_val_loss),
                       os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pth'))

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(_checkpoint_payload(model, optimizer, epoch, avg_train_loss, avg_val_loss),
                           os.path.join(output_dir, 'best_model.pth'))
                print(f'Saved new best model with validation loss: {best_val_loss:.4f}')
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                print(f'Validation loss did not improve. Counter: {early_stop_counter}/{patience}')

                if early_stop_counter >= patience:
                    print(f'Early stopping triggered after {epoch+1} epochs')
                    break

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Plot train/val loss curves
        plt.figure(figsize=(10, 6))
        plt.plot(train_losses, label='Training Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss')
        plt.savefig(os.path.join(output_dir, 'loss_curves.png'))
        plt.close()

        return model

    except Exception as e:
        print(f"Training error: {str(e)}")
        print("Saving emergency checkpoint...")
        torch.save(model.state_dict(), os.path.join(output_dir, 'emergency_save.pth'))
        raise

def main():
    """Pick a dataset folder via a dialog, train, and save the final weights.

    Outputs go to a ``nanopore_model_<YYYYmmdd_HHMMSS>`` folder created inside
    the selected directory. Training errors are reported, not re-raised.
    """
    # Hidden Tk root, needed only to host the folder dialog
    root = tk.Tk()
    root.withdraw()
    try:
        data_dir = filedialog.askdirectory(title="Select Directory with TEM Images and Masks")
    finally:
        # Destroy the root so re-running (e.g. a notebook cell) does not
        # accumulate orphaned Tk interpreters.
        root.destroy()

    if not data_dir:
        print("No directory selected. Exiting.")
        return

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(data_dir, f'nanopore_model_{timestamp}')
    os.makedirs(output_dir, exist_ok=True)

    print("\nStarting nanopore detection model training...")
    print(f"Model checkpoints will be saved to: {output_dir}")

    try:
        model = train_nanopore_detector(data_dir, output_dir)

        final_model_path = os.path.join(output_dir, 'final_model.pth')
        torch.save(model.state_dict(), final_model_path)
        print(f"\nTraining complete. Final model saved to: {final_model_path}")

    except Exception as e:
        print(f"Training failed with error: {e}")
        if torch.cuda.is_available():
            print(f"GPU Memory at failure: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Start training only when run directly (script or notebook cell, where
# __name__ is '__main__'), not when this module is imported.
if __name__ == '__main__':
    main()


Starting nanopore detection model training...
Model checkpoints will be saved to: D:/Rajat/augmented_images\nanopore_model_20250306_151818
Using device: cuda
Found 20 original images with augmentations
Using 17 originals for training, 3 for validation
Validation originals: ['Image_10' 'Image_7' 'Image_5']
Training set contains 272 images
Validation set contains 48 images





Starting training with anti-overfitting measures...
Early stopping patience: 3
Learning rate: 5e-05
Weight decay: 0.0001


Epoch 1/30: 100%|██████████| 272/272 [14:40<00:00,  3.24s/it, loss=0.336]


Epoch 1: Train Loss: 0.4445, Val Loss: 0.3981
Saved new best model with validation loss: 0.3981


Epoch 2/30: 100%|██████████| 272/272 [12:23<00:00,  2.73s/it, loss=0.244]


Epoch 2: Train Loss: 0.2887, Val Loss: 0.2369
Saved new best model with validation loss: 0.2369


Epoch 3/30: 100%|██████████| 272/272 [12:19<00:00,  2.72s/it, loss=0.227]


Epoch 3: Train Loss: 0.2272, Val Loss: 0.2164
Saved new best model with validation loss: 0.2164


Epoch 4/30: 100%|██████████| 272/272 [14:06<00:00,  3.11s/it, loss=0.151]


Epoch 4: Train Loss: 0.1841, Val Loss: 0.1442
Saved new best model with validation loss: 0.1442


Epoch 5/30: 100%|██████████| 272/272 [12:36<00:00,  2.78s/it, loss=0.155]


Epoch 5: Train Loss: 0.1528, Val Loss: 0.1345
Saved new best model with validation loss: 0.1345


Epoch 6/30: 100%|██████████| 272/272 [13:11<00:00,  2.91s/it, loss=0.109] 


Validation loss did not improve. Counter: 1/3


Epoch 7/30: 100%|██████████| 272/272 [12:31<00:00,  2.76s/it, loss=0.134] 


Epoch 7: Train Loss: 0.1108, Val Loss: 0.0948
Saved new best model with validation loss: 0.0948


Epoch 8/30: 100%|██████████| 272/272 [19:04<00:00,  4.21s/it, loss=0.12]  


Epoch 8: Train Loss: 0.0948, Val Loss: 0.1216
Validation loss did not improve. Counter: 1/3


Epoch 9/30: 100%|██████████| 272/272 [16:56<00:00,  3.74s/it, loss=0.069] 


Epoch 9: Train Loss: 0.0836, Val Loss: 0.0666
Saved new best model with validation loss: 0.0666


Epoch 10/30: 100%|██████████| 272/272 [17:36<00:00,  3.89s/it, loss=0.0624]


Epoch 10: Train Loss: 0.0748, Val Loss: 0.0603
Saved new best model with validation loss: 0.0603


Epoch 11/30: 100%|██████████| 272/272 [30:16<00:00,  6.68s/it, loss=0.0471]


Epoch 11: Train Loss: 0.0675, Val Loss: 0.0509
Saved new best model with validation loss: 0.0509


Epoch 12/30: 100%|██████████| 272/272 [14:28<00:00,  3.19s/it, loss=0.0513]


Epoch 12: Train Loss: 0.0617, Val Loss: 0.0456
Saved new best model with validation loss: 0.0456


Epoch 13/30: 100%|██████████| 272/272 [13:40<00:00,  3.02s/it, loss=0.0576]


Epoch 13: Train Loss: 0.0563, Val Loss: 0.0461
Validation loss did not improve. Counter: 1/3


Epoch 14/30: 100%|██████████| 272/272 [13:59<00:00,  3.09s/it, loss=0.0291]


Epoch 14: Train Loss: 0.0517, Val Loss: 0.0479
Validation loss did not improve. Counter: 2/3


Epoch 15/30: 100%|██████████| 272/272 [37:11<00:00,  8.20s/it, loss=0.0486]  


Validation loss did not improve. Counter: 3/3
Early stopping triggered after 15 epochs

Training complete. Final model saved to: D:/Rajat/augmented_images\nanopore_model_20250306_151818\final_model.pth


In [None]:
# Final model with accuracy values
# (originally a standalone script; its shebang "#!/usr/bin/env python3" only works on a script's first line)
"""
TEM nanopore detector training script with verbose metrics and plotting.
Copy-paste this file and run. Uses a GUI folder picker to choose the dataset directory.

Expected dataset layout: For each image .tif there is a corresponding mask named
    <image_name>_mask.png
Augmented files with suffix "_aug_<n>.tif" are grouped by original image name.
"""

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2
import tkinter as tk
from tkinter import filedialog
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt
import csv

# sklearn metrics
try:
    from sklearn.metrics import (
        roc_auc_score,
        average_precision_score,
        precision_recall_curve,
        roc_curve,
        auc,
        f1_score,
        jaccard_score,
        precision_score,
        recall_score,
    )
except Exception as e:
    raise ImportError("scikit-learn is required for metrics. Install it with `pip install scikit-learn`.") from e

# ---------------------------
# Model: UNet + DoubleConv
# ---------------------------
class DoubleConv(nn.Module):
    """Conv-BN-ReLU, channel dropout, Conv-BN-ReLU; spatial size preserved.

    Sub-module indices inside ``self.conv`` (0, 1, 4, 5 carry weights) are part
    of the checkpoint format, so the layer order is fixed.
    """

    @staticmethod
    def _conv_bn_relu(cin, cout):
        """Return the three layers of one 3x3 convolution stage."""
        return [
            nn.Conv2d(cin, cout, kernel_size=3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super().__init__()
        self.conv = nn.Sequential(
            *self._conv_bn_relu(in_channels, out_channels),
            nn.Dropout2d(dropout_rate),
            *self._conv_bn_relu(out_channels, out_channels),
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """Four-level U-Net producing raw logits with ``out_channels`` channels.

    Height and width must be divisible by 16 so that upsampled features match
    their skip connections. Attribute names define checkpoint keys.
    """

    def __init__(self, in_channels=1, out_channels=1, dropout_rate=0.2):
        super().__init__()
        # --- encoder: widths 64 -> 1024, halving resolution between levels
        self.conv1 = DoubleConv(in_channels, 64, dropout_rate)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128, dropout_rate)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256, dropout_rate)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512, dropout_rate)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024, dropout_rate)

        # --- decoder: upsample, then fuse with the encoder skip (channels double)
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(1024, 512, dropout_rate)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(512, 256, dropout_rate)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(256, 128, dropout_rate)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(128, 64, dropout_rate)

        # --- head: per-pixel 1x1 projection
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _decode(upsample, block, deep, skip):
        """Upsample *deep*, concatenate *skip* along channels, then refine."""
        return block(torch.cat([upsample(deep), skip], dim=1))

    def forward(self, x):
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool1(e1))
        e3 = self.conv3(self.pool2(e2))
        e4 = self.conv4(self.pool3(e3))
        bottleneck = self.conv5(self.pool4(e4))

        d = self._decode(self.up6, self.conv6, bottleneck, e4)
        d = self._decode(self.up7, self.conv7, d, e3)
        d = self._decode(self.up8, self.conv8, d, e2)
        d = self._decode(self.up9, self.conv9, d, e1)
        return self.conv10(d)

# ---------------------------
# Utilities and Dataset
# ---------------------------
def get_original_and_augmented_groups(data_dir):
    """Group ``.tif`` files (case-insensitive extension) by their original image.

    ``<name>_aug_<n>.tif`` is filed under ``<name>`` with augmentation index
    ``n``; the index is 0 when ``n`` is not an integer. An un-augmented
    ``<name>.tif`` registers ``<name>`` with an empty list, so originals
    without augmentations are still available for splitting.

    Returns:
        dict: original name -> list of ``(filename, aug_index)`` sorted by index.
    """
    all_files = sorted(f for f in os.listdir(data_dir) if f.lower().endswith('.tif'))
    original_groups = {}
    for file in all_files:
        if "_aug_" in file:
            original_name, aug_part = file.split("_aug_")[:2]
            try:
                aug_num = int(aug_part.split('.')[0])
            except ValueError:
                aug_num = 0
            original_groups.setdefault(original_name, []).append((file, aug_num))
        else:
            # splitext instead of str.replace('.tif', ''): '.TIF' files passed
            # the case-insensitive filter but kept their extension in the key.
            original_groups.setdefault(os.path.splitext(file)[0], [])
    return {name: sorted(files, key=lambda item: item[1]) for name, files in original_groups.items()}

class NanoporeDataset(Dataset):
    """Paired TEM image / mask samples for a chosen set of original images.

    ``<stem>.tif`` (extension matched case-insensitively) is included when its
    original name is in ``original_images`` and ``<stem>_mask.png`` exists
    beside it. The original name is the text before ``_aug_``, or the stem for
    un-augmented files. Samples are resized to ``target_size`` (width, height)
    and returned as float32 tensors of shape (1, height, width) in [0, 1].
    """

    def __init__(self, data_dir, original_images, is_validation=False, target_size=(1024,1024)):
        self.data_dir = data_dir
        self.is_validation = is_validation
        self.target_size = target_size

        all_files = sorted(os.listdir(data_dir))
        # Set lookups keep this loop linear in the number of files.
        available = set(all_files)
        wanted = set(original_images)

        image_files = []
        for file in all_files:
            if not file.lower().endswith('.tif'):
                continue
            original_name = file.split("_aug_")[0] if "_aug_" in file else os.path.splitext(file)[0]
            if original_name in wanted and self._mask_name(file) in available:
                image_files.append(file)
        self.image_files = image_files
        print(f"{'Validation' if is_validation else 'Training'} set contains {len(self.image_files)} images")

    @staticmethod
    def _mask_name(image_name):
        """Return ``<stem>_mask.png`` for *image_name*.

        Uses splitext because ``str.replace('.tif', ...)`` left ``X.TIF``
        unchanged, which made the image its own mask.
        """
        return os.path.splitext(image_name)[0] + '_mask.png'

    @staticmethod
    def _normalize_mask(mask):
        """Return *mask* as float32 in [0, 1]; accepts 0/255 or 0/1 encodings."""
        mask = mask.astype(np.float32)
        if mask.max() > 1.0:  # only 0/255-style masks need rescaling
            mask /= 255.0
        return mask

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            img_name = self.image_files[idx]
            img_path = os.path.join(self.data_dir, img_name)
            mask_path = os.path.join(self.data_dir, self._mask_name(img_name))

            image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            if image is None or mask is None:
                raise ValueError(f"Could not read {img_name} or mask")

            # INTER_NEAREST keeps mask labels from being blended
            image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_AREA)
            mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

            image = image.astype(np.float32) / 255.0
            mask = self._normalize_mask(mask)

            image = torch.from_numpy(image).unsqueeze(0)
            mask = torch.from_numpy(mask).unsqueeze(0)

            return image, mask
        except Exception as e:
            # Best-effort: log and return an all-zero sample. It must match
            # target_size, or batching with real samples fails.
            print(f"Error loading sample {idx}: {e}")
            width, height = self.target_size  # cv2 dsize order
            placeholder = torch.zeros((1, height, width), dtype=torch.float32)
            return placeholder, placeholder

# ---------------------------
# Metric helpers
# ---------------------------
def dice_coeff(preds_bool, targets_bool, eps=1e-7):
    """Dice overlap between two binary masks.

    Both arrays are cast to uint8 and combined bitwise, so they should hold
    only 0/1 (or False/True). ``eps`` keeps the score finite; it is exactly 1
    when both masks are empty.
    """
    pred_u8 = preds_bool.astype(np.uint8)
    true_u8 = targets_bool.astype(np.uint8)
    overlap = (pred_u8 & true_u8).sum()
    total = pred_u8.sum() + true_u8.sum()
    return (2. * overlap + eps) / (total + eps)

def iou_score_np(preds_bool, targets_bool, eps=1e-7):
    """Intersection-over-union (Jaccard index) of two binary masks.

    Inputs are cast to uint8 and combined bitwise; ``eps`` makes the score
    exactly 1 when both masks are empty.
    """
    pred_u8, true_u8 = (arr.astype(np.uint8) for arr in (preds_bool, targets_bool))
    inter = (pred_u8 & true_u8).sum()
    union = (pred_u8 | true_u8).sum()
    return (inter + eps) / (union + eps)

def check_overfitting(train_loss, val_loss, threshold=0.3):
    """Return True if training loss sits more than ``threshold`` below a positive validation loss."""
    limit = val_loss * (1 - threshold)
    return bool(val_loss > 0 and train_loss < limit)

# ---------------------------
# Training function (fixed: detach before numpy conversions)
# ---------------------------
def _nanopore_save_val_samples(images, targets, probs, vis_dir, epoch, max_samples=2):
    """Save input / ground-truth / predicted-probability panels for one validation batch.

    Writes ``epoch_{epoch}_val_sample_{i}.png`` into ``vis_dir`` for at most
    ``max_samples`` samples of the batch.

    Args:
        images: torch tensor batch; sample ``i`` is shown as ``images[i][0]``.
        targets: numpy mask batch indexed the same way (``targets[i][0]``).
        probs: numpy sigmoid-probability batch indexed the same way.
        vis_dir: existing output directory.
        epoch: 0-based epoch index used in the file names.
        max_samples: maximum number of samples written from the batch.
    """
    for i in range(min(max_samples, images.size(0))):
        panels = (
            (images[i].detach().cpu().numpy()[0], 'TEM'),
            (targets[i][0], 'Mask'),
            (probs[i][0], 'Pred Prob'),
        )
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(vis_dir, f'epoch_{epoch}_val_sample_{i}.png'))
        plt.close(fig)


def _nanopore_val_metrics(val_probs, val_bin_targets, prob_threshold):
    """Compute pixel-level binary segmentation metrics for one validation pass.

    Args:
        val_probs: 1-D array of predicted foreground probabilities.
        val_bin_targets: 1-D 0/1 array of ground-truth labels, same length.
        prob_threshold: probability at or above which a pixel is predicted foreground.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'iou', 'dice',
        'pr_auc' and 'roc_auc'. All values are NaN when the arrays are empty.
        'pr_auc'/'roc_auc' are NaN when sklearn fails to compute them, and
        'roc_auc' is also NaN unless both classes are present.
    """
    nan = float('nan')
    if val_probs.size == 0 or val_bin_targets.size == 0:
        return {key: nan for key in ('accuracy', 'precision', 'recall', 'f1',
                                     'iou', 'dice', 'pr_auc', 'roc_auc')}

    val_bin_preds = (val_probs >= prob_threshold).astype(np.uint8)
    metrics = {
        'accuracy': (val_bin_preds == val_bin_targets).mean(),
        'precision': precision_score(val_bin_targets, val_bin_preds, zero_division=0),
        'recall': recall_score(val_bin_targets, val_bin_preds, zero_division=0),
        'f1': f1_score(val_bin_targets, val_bin_preds, zero_division=0),
    }
    try:
        metrics['iou'] = jaccard_score(val_bin_targets, val_bin_preds, zero_division=0)
    except Exception:
        # Fall back to the local implementation if sklearn's call fails.
        metrics['iou'] = iou_score_np(val_bin_preds.astype(bool), val_bin_targets.astype(bool))
    metrics['dice'] = dice_coeff(val_bin_preds.astype(bool), val_bin_targets.astype(bool))
    try:
        metrics['pr_auc'] = average_precision_score(val_bin_targets, val_probs)
    except Exception:
        metrics['pr_auc'] = nan
    n_pos = val_bin_targets.sum()
    has_both_classes = n_pos > 0 and (val_bin_targets.size - n_pos) > 0
    try:
        metrics['roc_auc'] = roc_auc_score(val_bin_targets, val_probs) if has_both_classes else nan
    except Exception:
        metrics['roc_auc'] = nan
    return metrics


def _nanopore_write_metrics_csv(history, csv_path):
    """Write per-epoch ``history`` to ``csv_path`` with a header row and 1-based epochs.

    Missing entries become empty cells; ``None`` values (undefined AUCs) are
    also written as empty cells by the csv module.
    """
    columns = ['train_loss', 'train_accuracy', 'val_loss', 'val_accuracy', 'val_precision',
               'val_recall', 'val_f1', 'val_iou', 'val_dice', 'val_pr_auc', 'val_roc_auc']
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['epoch'] + columns)
        for i in range(len(history['train_loss'])):
            writer.writerow([i + 1] + [history[col][i] if i < len(history[col]) else ''
                                       for col in columns])


def _nanopore_plot_history(history, output_dir):
    """Save loss, accuracy, validation-metric and AUC curves as PNGs (x axis: 1-based epoch)."""
    def series(key):
        # Undefined AUCs are stored as None; NaN makes matplotlib leave a gap.
        return [np.nan if v is None else v for v in history[key]]

    figure_specs = (
        ('loss_curves.png', (8, 6), 'Loss', 'Loss curves',
         (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss'))),
        ('accuracy_curves.png', (8, 6), 'Accuracy', 'Accuracy curves',
         (('train_accuracy', 'Train Acc'), ('val_accuracy', 'Val Acc'))),
        ('val_metrics_trends.png', (10, 6), 'Score', 'Validation metrics over epochs',
         (('val_precision', 'Val Precision'), ('val_recall', 'Val Recall'),
          ('val_f1', 'Val F1'), ('val_iou', 'Val IoU'), ('val_dice', 'Val Dice'))),
        ('val_auc_trends.png', (8, 6), 'AUC', 'AUC trends',
         (('val_pr_auc', 'Val PR-AUC'), ('val_roc_auc', 'Val ROC-AUC'))),
    )
    for filename, figsize, ylabel, title, curves in figure_specs:
        plt.figure(figsize=figsize)
        for key, label in curves:
            values = series(key)
            plt.plot(range(1, len(values) + 1), values, label=label)
        plt.xlabel('Epoch'); plt.ylabel(ylabel); plt.legend(); plt.title(title)
        plt.savefig(os.path.join(output_dir, filename))
        plt.close()


def _nanopore_plot_final_val_curves(val_probs, val_bin_targets, output_dir):
    """Save ROC and precision-recall curves for one validation pass.

    Args:
        val_probs: 1-D array of predicted probabilities.
        val_bin_targets: 1-D 0/1 array of ground-truth labels, same length.
        output_dir: directory for ``roc_curve_last_epoch.png`` and
            ``pr_curve_last_epoch.png``.

    Nothing is written when the arrays are empty. The ROC plot is skipped when
    only one class is present, because ROC is undefined then.
    """
    if val_probs.size == 0 or val_bin_targets.size == 0:
        return
    n_pos = int(val_bin_targets.sum())
    if 0 < n_pos < val_bin_targets.size:
        fpr, tpr, _ = roc_curve(val_bin_targets, val_probs)
        roc_auc_val = roc_auc_score(val_bin_targets, val_probs)
        plt.figure(figsize=(6, 6))
        plt.plot(fpr, tpr, label=f'ROC AUC={roc_auc_val:.4f}')
        plt.plot([0, 1], [0, 1], linestyle='--')
        plt.xlabel('FPR'); plt.ylabel('TPR'); plt.title('ROC Curve (last val epoch)')
        plt.legend()
        plt.savefig(os.path.join(output_dir, 'roc_curve_last_epoch.png'))
        plt.close()
    precision_vals, recall_vals, _ = precision_recall_curve(val_bin_targets, val_probs)
    pr_auc_val = auc(recall_vals, precision_vals)
    plt.figure(figsize=(6, 6))
    plt.plot(recall_vals, precision_vals, label=f'PR AUC={pr_auc_val:.4f}')
    plt.xlabel('Recall'); plt.ylabel('Precision'); plt.title('Precision-Recall (last val epoch)')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'pr_curve_last_epoch.png'))
    plt.close()


def train_nanopore_detector(data_dir, output_dir,
                            batch_size=1,
                            epochs=30,
                            learning_rate=5e-5,
                            patience=3,
                            weight_decay=1e-4,
                            prob_threshold=0.5,
                            resize_to=(1024,1024)):
    """Train a UNet nanopore segmenter and write checkpoints, metrics and plots.

    The original images returned by ``get_original_and_augmented_groups`` are
    split (fixed seed 42) into validation (15%, at least 2) and training
    originals. Each epoch trains with BCEWithLogitsLoss + Adam, evaluates
    pixel-level metrics on the validation set, saves
    ``checkpoint_epoch_{epoch}.pth`` and, when validation loss improves,
    ``best_model.pth``. ReduceLROnPlateau (factor 0.5, patience 2) follows the
    validation loss. Training stops after ``patience`` consecutive epochs
    without improvement.

    Afterwards ``final_model.pth``, ``metrics.csv`` and PNG plots are written
    to ``output_dir``; sample predictions go to ``output_dir/visualizations``
    every 5th epoch.

    Args:
        data_dir: directory with images and masks, passed to the dataset helpers.
        output_dir: existing directory for all outputs.
        batch_size, epochs, learning_rate, weight_decay: usual training settings.
        patience: early-stopping patience in epochs.
        prob_threshold: sigmoid probability at or above which a pixel is
            predicted foreground (targets are binarized at 0.5).
        resize_to: passed to ``NanoporeDataset`` as ``target_size``.

    Returns:
        The model in its last-epoch state. This is not necessarily the best
        model; the best weights are in ``best_model.pth``.

    Raises:
        ValueError: too few original images to form a non-empty training split.
        Any exception raised during training is re-raised after attempting an
            emergency save to ``emergency_save.pth``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    original_groups = get_original_and_augmented_groups(data_dir)
    original_images = list(original_groups.keys())
    print(f"Found {len(original_images)} original images with augmentations")

    # Split at the level of original images so the two datasets are built from
    # disjoint sets of originals. Note: this seeds NumPy's *global* RNG.
    np.random.seed(42)
    val_size = max(2, int(len(original_images) * 0.15))
    if val_size >= len(original_images):
        raise ValueError(
            f"Need more than {val_size} original images for a train/validation split, "
            f"found {len(original_images)}")
    val_originals = np.random.choice(original_images, size=val_size, replace=False).tolist()
    val_set = set(val_originals)
    train_originals = [img for img in original_images if img not in val_set]

    print(f"Using {len(train_originals)} originals for training, {len(val_originals)} for validation")
    print(f"Validation originals: {val_originals}")

    train_dataset = NanoporeDataset(data_dir, train_originals, is_validation=False, target_size=resize_to)
    val_dataset = NanoporeDataset(data_dir, val_originals, is_validation=True, target_size=resize_to)

    use_pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0,
                              pin_memory=use_pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                            pin_memory=use_pin_memory)

    model = UNet(in_channels=1, out_channels=1, dropout_rate=0.2).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): recent PyTorch releases deprecate the scheduler `verbose`
    # argument; if it is rejected, drop it and log optimizer LRs manually.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

    best_val_loss = float('inf')
    early_stop_counter = 0

    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': [],
        'val_iou': [],
        'val_dice': [],
        'val_pr_auc': [],
        'val_roc_auc': []
    }

    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    print("\nStarting training with verbose metrics...")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate: {learning_rate}")
    print(f"Weight decay: {weight_decay}")

    # Flattened arrays from the most recent validation pass, used for the
    # final ROC/PR plots. Empty if no epoch ran.
    all_val_probs = np.array([])
    val_bin_targets = np.array([], dtype=np.uint8)

    try:
        for epoch in range(epochs):
            # -----------------------
            # Training (loss + pixel accuracy)
            # -----------------------
            model.train()
            running_train_loss = 0.0
            train_correct = 0
            train_total = 0

            with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [train]") as pbar:
                for images, masks in pbar:
                    images, masks = images.to(device), masks.to(device)
                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    loss.backward()
                    optimizer.step()

                    running_train_loss += loss.item()
                    # Count correct pixels on-device instead of buffering every
                    # probability map and mask on the host (two float32 values
                    # per pixel for the whole epoch, then copied again by
                    # np.concatenate).
                    with torch.no_grad():
                        bin_preds = torch.sigmoid(outputs) >= prob_threshold
                        bin_targets = masks >= 0.5
                        train_correct += int((bin_preds == bin_targets).sum().item())
                    train_total += masks.numel()

                    pbar.set_postfix({'loss': f"{loss.item():.4f}"})
                    # torch.cuda.empty_cache() is no longer called per batch: it
                    # only hands cached blocks back to the driver and slows the
                    # loop. It still runs once at the end of each epoch.

            avg_train_loss = running_train_loss / max(1, len(train_loader))
            history['train_loss'].append(avg_train_loss)
            train_accuracy = train_correct / train_total if train_total > 0 else float('nan')
            history['train_accuracy'].append(float(train_accuracy))

            # -----------------------
            # Validation
            # -----------------------
            model.eval()
            running_val_loss = 0.0
            val_probs_list = []
            val_targets_list = []
            vis_saved = False

            with torch.no_grad():
                for images, masks in val_loader:
                    images, masks = images.to(device), masks.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    running_val_loss += loss.item()

                    probs = torch.sigmoid(outputs).cpu().numpy()
                    targets = masks.cpu().numpy()
                    val_probs_list.append(probs.reshape(-1))
                    # Store binarized targets (1 byte/pixel); every metric and
                    # curve below uses the same >= 0.5 binarization.
                    val_targets_list.append((targets.reshape(-1) >= 0.5).astype(np.uint8))

                    if epoch % 5 == 0 and not vis_saved:
                        _nanopore_save_val_samples(images, targets, probs, vis_dir, epoch)
                        vis_saved = True

            avg_val_loss = running_val_loss / max(1, len(val_loader))
            history['val_loss'].append(avg_val_loss)

            if val_probs_list:
                all_val_probs = np.concatenate(val_probs_list)
                val_bin_targets = np.concatenate(val_targets_list)
            else:
                all_val_probs = np.array([])
                val_bin_targets = np.array([], dtype=np.uint8)

            metrics = _nanopore_val_metrics(all_val_probs, val_bin_targets, prob_threshold)
            val_accuracy = metrics['accuracy']
            prec, rec, f1 = metrics['precision'], metrics['recall'], metrics['f1']
            iou, dice = metrics['iou'], metrics['dice']
            pr_auc, roc_auc = metrics['pr_auc'], metrics['roc_auc']

            history['val_accuracy'].append(float(val_accuracy))
            history['val_precision'].append(float(prec))
            history['val_recall'].append(float(rec))
            history['val_f1'].append(float(f1))
            history['val_iou'].append(float(iou))
            history['val_dice'].append(float(dice))
            history['val_pr_auc'].append(float(pr_auc) if not np.isnan(pr_auc) else None)
            history['val_roc_auc'].append(float(roc_auc) if not np.isnan(roc_auc) else None)

            # -----------------------
            # Epoch summary
            # -----------------------
            print(
                f"Epoch {epoch+1}: "
                f"Train Loss {avg_train_loss:.4f}, Train Acc {train_accuracy:.4f} | "
                f"Val Loss {avg_val_loss:.4f}, Val Acc {val_accuracy:.4f} | "
                f"Val Prec {prec:.4f}, Rec {rec:.4f}, F1 {f1:.4f}, IoU {iou:.4f}, Dice {dice:.4f} | "
                f"PR-AUC {pr_auc if not np.isnan(pr_auc) else 'NA'}, ROC-AUC {roc_auc if not np.isnan(roc_auc) else 'NA'}"
            )

            if check_overfitting(avg_train_loss, avg_val_loss):
                print(" - WARNING: Potential overfitting detected!")

            # Plain Python floats (not numpy scalars) keep the checkpoints
            # loadable with torch.load(..., weights_only=True).
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'train_accuracy': float(train_accuracy),
                'val_loss': avg_val_loss,
                'val_accuracy': float(val_accuracy)
            }
            torch.save(checkpoint, os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pth'))

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(checkpoint, os.path.join(output_dir, 'best_model.pth'))
                print(f"Saved new best model (val loss {best_val_loss:.4f})")
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                print(f"Validation loss did not improve. Counter: {early_stop_counter}/{patience}")
                if early_stop_counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    break

            scheduler.step(avg_val_loss)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        final_model_path = os.path.join(output_dir, 'final_model.pth')
        torch.save(model.state_dict(), final_model_path)
        print(f"Training finished. Final model saved to: {final_model_path}")

        csv_path = os.path.join(output_dir, 'metrics.csv')
        _nanopore_write_metrics_csv(history, csv_path)
        print(f"Saved metrics CSV to: {csv_path}")

        _nanopore_plot_history(history, output_dir)

        # Best-effort: a plotting problem is reported but never discards a
        # finished training run.
        try:
            _nanopore_plot_final_val_curves(all_val_probs, val_bin_targets, output_dir)
        except Exception as plot_err:
            print(f"Could not save final ROC/PR curves: {plot_err}")

        return model

    except Exception as e:
        print(f"Training error: {e}")
        # emergency save
        try:
            torch.save(model.state_dict(), os.path.join(output_dir, 'emergency_save.pth'))
            print("Saved emergency checkpoint.")
        except Exception:
            pass
        raise

# ---------------------------
# main()
# ---------------------------
def main():
    """Pick a data directory with a Tk dialog and train the nanopore detector on it.

    Outputs go to a new ``nanopore_model_<YYYYmmdd_HHMMSS>`` folder inside the
    chosen directory. A training failure is reported (message, traceback and,
    on CUDA, allocated GPU memory) instead of being propagated.
    """
    import traceback

    root = tk.Tk()
    root.withdraw()  # only the folder dialog should be visible
    try:
        data_dir = filedialog.askdirectory(title="Select Directory with TEM Images and Masks")
    finally:
        # Tear down the hidden root; otherwise every call (e.g. each re-run of
        # a notebook cell) leaves another Tk instance alive.
        root.destroy()
    if not data_dir:
        print("No directory selected. Exiting.")
        return

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(data_dir, f'nanopore_model_{timestamp}')
    os.makedirs(output_dir, exist_ok=True)

    print("\nStarting nanopore detection model training...")
    print(f"Model checkpoints and outputs will be saved to: {output_dir}")

    try:
        train_nanopore_detector(data_dir, output_dir,
                                batch_size=1,
                                epochs=30,
                                learning_rate=5e-5,
                                patience=3,
                                weight_decay=1e-4,
                                prob_threshold=0.5,
                                resize_to=(1024,1024))
        print("\nTraining complete.")
    except Exception as e:
        print(f"Training failed: {e}")
        # Keep the full stack trace; the one-line message alone hides where it failed.
        traceback.print_exc()
        if torch.cuda.is_available():
            try:
                print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
            except Exception:
                pass

# Run training when executed as a script. In a Jupyter notebook cell,
# __name__ is also '__main__', so executing the cell starts training too.
if __name__ == '__main__':
    main()



Starting nanopore detection model training...
Model checkpoints and outputs will be saved to: D:/Rajat/Nanoporosity/albumented_images\nanopore_model_20251112_172446
Using device: cuda
Found 20 original images with augmentations
Using 17 originals for training, 3 for validation
Validation originals: ['Image_10', 'Image_7', 'Image_5']
Training set contains 255 images
Validation set contains 45 images

Starting training with verbose metrics...
Early stopping patience: 3
Learning rate: 5e-05
Weight decay: 0.0001


Epoch 1/30 [train]: 100%|██████████| 255/255 [23:06<00:00,  5.44s/it, loss=0.3197]


Epoch 1: Train Loss 0.4157, Train Acc 0.9478 | Val Loss 0.9736, Val Acc 0.9526 | Val Prec 0.0984, Rec 0.6872, F1 0.1721, IoU 0.0942, Dice 0.1721 | PR-AUC 0.12096339361937276, ROC-AUC 0.9180450571939965
Saved new best model (val loss 0.9736)


Epoch 2/30 [train]: 100%|██████████| 255/255 [21:36<00:00,  5.09s/it, loss=0.2320]


Epoch 2: Train Loss 0.2762, Train Acc 0.9821 | Val Loss 0.2579, Val Acc 0.9897 | Val Prec 0.3639, Rec 0.5829, F1 0.4480, IoU 0.2887, Dice 0.4480 | PR-AUC 0.33087785811500287, ROC-AUC 0.9430279594422356
Saved new best model (val loss 0.2579)


Epoch 3/30 [train]: 100%|██████████| 255/255 [21:10<00:00,  4.98s/it, loss=0.1980]


Epoch 3: Train Loss 0.2207, Train Acc 0.9831 | Val Loss 0.1778, Val Acc 0.9936 | Val Prec 0.5563, Rec 0.5326, F1 0.5442, IoU 0.3738, Dice 0.5442 | PR-AUC 0.5228139623833468, ROC-AUC 0.958017663409531
Saved new best model (val loss 0.1778)


Epoch 4/30 [train]: 100%|██████████| 255/255 [20:47<00:00,  4.89s/it, loss=0.1502]


Epoch 4: Train Loss 0.1839, Train Acc 0.9833 | Val Loss 0.1529, Val Acc 0.9932 | Val Prec 0.5232, Rec 0.5621, F1 0.5420, IoU 0.3717, Dice 0.5420 | PR-AUC 0.528280430683488, ROC-AUC 0.9563666998425551
Saved new best model (val loss 0.1529)


Epoch 5/30 [train]: 100%|██████████| 255/255 [20:52<00:00,  4.91s/it, loss=0.1780]


Epoch 5: Train Loss 0.1559, Train Acc 0.9836 | Val Loss 0.1285, Val Acc 0.9908 | Val Prec 0.4130, Rec 0.6587, F1 0.5077, IoU 0.3402, Dice 0.5077 | PR-AUC 0.5155693045012066, ROC-AUC 0.9683345574530089
Saved new best model (val loss 0.1285)


Epoch 6/30 [train]: 100%|██████████| 255/255 [19:49<00:00,  4.67s/it, loss=0.1232]


Epoch 6: Train Loss 0.1341, Train Acc 0.9838 | Val Loss 0.1087, Val Acc 0.9938 | Val Prec 0.5716, Rec 0.5212, F1 0.5452, IoU 0.3748, Dice 0.5452 | PR-AUC 0.5281021567477429, ROC-AUC 0.9605293913031249
Saved new best model (val loss 0.1087)


Epoch 7/30 [train]: 100%|██████████| 255/255 [19:57<00:00,  4.70s/it, loss=0.0929]


Epoch 7: Train Loss 0.1168, Train Acc 0.9841 | Val Loss 0.1515, Val Acc 0.9887 | Val Prec 0.3575, Rec 0.7166, F1 0.4770, IoU 0.3132, Dice 0.4770 | PR-AUC 0.40967357514639, ROC-AUC 0.963206403842076
Validation loss did not improve. Counter: 1/3


Epoch 8/30 [train]: 100%|██████████| 255/255 [20:08<00:00,  4.74s/it, loss=0.1348]


Epoch 8: Train Loss 0.1036, Train Acc 0.9842 | Val Loss 0.0817, Val Acc 0.9903 | Val Prec 0.3974, Rec 0.6800, F1 0.5016, IoU 0.3348, Dice 0.5016 | PR-AUC 0.5058390360169867, ROC-AUC 0.97523099837765
Saved new best model (val loss 0.0817)


Epoch 9/30 [train]: 100%|██████████| 255/255 [26:51<00:00,  6.32s/it, loss=0.0755]


Epoch 9: Train Loss 0.0922, Train Acc 0.9843 | Val Loss 0.0711, Val Acc 0.9924 | Val Prec 0.4785, Rec 0.6621, F1 0.5555, IoU 0.3846, Dice 0.5555 | PR-AUC 0.5756530978431165, ROC-AUC 0.9762750694087294
Saved new best model (val loss 0.0711)


Epoch 10/30 [train]: 100%|██████████| 255/255 [23:27<00:00,  5.52s/it, loss=0.0720]


Epoch 10: Train Loss 0.0824, Train Acc 0.9847 | Val Loss 0.0632, Val Acc 0.9924 | Val Prec 0.4774, Rec 0.6773, F1 0.5600, IoU 0.3889, Dice 0.5600 | PR-AUC 0.5880964991629962, ROC-AUC 0.978554399206711
Saved new best model (val loss 0.0632)


Epoch 11/30 [train]: 100%|██████████| 255/255 [22:02<00:00,  5.19s/it, loss=0.0549]


Epoch 11: Train Loss 0.0749, Train Acc 0.9850 | Val Loss 0.0583, Val Acc 0.9911 | Val Prec 0.4285, Rec 0.7166, F1 0.5363, IoU 0.3664, Dice 0.5363 | PR-AUC 0.5791546339015351, ROC-AUC 0.9795814795641707
Saved new best model (val loss 0.0583)


Epoch 12/30 [train]: 100%|██████████| 255/255 [22:06<00:00,  5.20s/it, loss=0.0719]


Epoch 12: Train Loss 0.0695, Train Acc 0.9849 | Val Loss 0.0516, Val Acc 0.9904 | Val Prec 0.4053, Rec 0.7268, F1 0.5204, IoU 0.3517, Dice 0.5204 | PR-AUC 0.590789499401029, ROC-AUC 0.9777933120287408
Saved new best model (val loss 0.0516)


Epoch 13/30 [train]: 100%|██████████| 255/255 [22:24<00:00,  5.27s/it, loss=0.0512]


Epoch 13: Train Loss 0.0641, Train Acc 0.9853 | Val Loss 0.0487, Val Acc 0.9924 | Val Prec 0.4805, Rec 0.6756, F1 0.5616, IoU 0.3904, Dice 0.5616 | PR-AUC 0.5898919179163282, ROC-AUC 0.9819114193271943
Saved new best model (val loss 0.0487)


Epoch 14/30 [train]: 100%|██████████| 255/255 [22:51<00:00,  5.38s/it, loss=0.0522]


Epoch 14: Train Loss 0.0596, Train Acc 0.9856 | Val Loss 0.0398, Val Acc 0.9922 | Val Prec 0.4702, Rec 0.6882, F1 0.5587, IoU 0.3876, Dice 0.5587 | PR-AUC 0.595173944572817, ROC-AUC 0.9815957254220865
Saved new best model (val loss 0.0398)


Epoch 15/30 [train]: 100%|██████████| 255/255 [22:16<00:00,  5.24s/it, loss=0.0598]


Epoch 15: Train Loss 0.0567, Train Acc 0.9855 | Val Loss 0.0597, Val Acc 0.9834 | Val Prec 0.2761, Rec 0.8087, F1 0.4117, IoU 0.2592, Dice 0.4117 | PR-AUC 0.5593558593045749, ROC-AUC 0.9802682623392329
Validation loss did not improve. Counter: 1/3


Epoch 16/30 [train]: 100%|██████████| 255/255 [22:53<00:00,  5.39s/it, loss=0.0399]


Epoch 16: Train Loss 0.0538, Train Acc 0.9858 | Val Loss 0.0428, Val Acc 0.9895 | Val Prec 0.3790, Rec 0.7175, F1 0.4960, IoU 0.3298, Dice 0.4960 | PR-AUC 0.5179639582053689, ROC-AUC 0.9730546860839561
Validation loss did not improve. Counter: 2/3


Epoch 17/30 [train]: 100%|██████████| 255/255 [23:29<00:00,  5.53s/it, loss=0.0239]


Epoch 17: Train Loss 0.0517, Train Acc 0.9858 | Val Loss 0.0368, Val Acc 0.9916 | Val Prec 0.4448, Rec 0.7027, F1 0.5448, IoU 0.3743, Dice 0.5448 | PR-AUC 0.5766557232264454, ROC-AUC 0.9833429230059626
Saved new best model (val loss 0.0368)


Epoch 18/30 [train]: 100%|██████████| 255/255 [25:27<00:00,  5.99s/it, loss=0.0922]


Epoch 18: Train Loss 0.0493, Train Acc 0.9861 | Val Loss 0.0399, Val Acc 0.9908 | Val Prec 0.4182, Rec 0.7268, F1 0.5309, IoU 0.3614, Dice 0.5309 | PR-AUC 0.5625363529432955, ROC-AUC 0.983569873039802
Validation loss did not improve. Counter: 1/3


Epoch 19/30 [train]: 100%|██████████| 255/255 [23:17<00:00,  5.48s/it, loss=0.0451]


Epoch 19: Train Loss 0.0481, Train Acc 0.9859 | Val Loss 0.0317, Val Acc 0.9915 | Val Prec 0.4411, Rec 0.7168, F1 0.5461, IoU 0.3756, Dice 0.5461 | PR-AUC 0.6062174846372399, ROC-AUC 0.9835252111894168
Saved new best model (val loss 0.0317)


Epoch 20/30 [train]: 100%|██████████| 255/255 [22:32<00:00,  5.30s/it, loss=0.0303]


Epoch 20: Train Loss 0.0464, Train Acc 0.9861 | Val Loss 0.0391, Val Acc 0.9902 | Val Prec 0.3971, Rec 0.7082, F1 0.5089, IoU 0.3413, Dice 0.5089 | PR-AUC 0.47968833520173026, ROC-AUC 0.9793505435510342
Validation loss did not improve. Counter: 1/3


Epoch 21/30 [train]: 100%|██████████| 255/255 [23:49<00:00,  5.61s/it, loss=0.1015]


Epoch 21: Train Loss 0.0447, Train Acc 0.9864 | Val Loss 0.0529, Val Acc 0.9882 | Val Prec 0.3500, Rec 0.7518, F1 0.4776, IoU 0.3138, Dice 0.4776 | PR-AUC 0.4489661367674278, ROC-AUC 0.9783012612770149
Validation loss did not improve. Counter: 2/3


Epoch 22/30 [train]: 100%|██████████| 255/255 [23:32<00:00,  5.54s/it, loss=0.0148]


Epoch 22: Train Loss 0.0434, Train Acc 0.9865 | Val Loss 0.0333, Val Acc 0.9907 | Val Prec 0.4157, Rec 0.7366, F1 0.5315, IoU 0.3619, Dice 0.5315 | PR-AUC 0.5398227986079676, ROC-AUC 0.9818984711088902
Validation loss did not improve. Counter: 3/3
Early stopping triggered after 22 epochs
Training finished. Final model saved to: D:/Rajat/Nanoporosity/albumented_images\nanopore_model_20251112_172446\final_model.pth
Saved metrics CSV to: D:/Rajat/Nanoporosity/albumented_images\nanopore_model_20251112_172446\metrics.csv

Training complete.
