In [1]:
import os
import cv2
import numpy as np
import pandas as pd
from glob import glob
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from albumentations import Compose, Normalize, Resize, RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
import matplotlib.pyplot as plt

class CFG:
    """Run-wide hyper-parameters and filesystem locations, read as class attributes."""

    # --- reproducibility ---
    seed = 42

    # --- optimisation ---
    lr = 1e-3
    batch_size = 4
    num_epochs = 20

    # --- model / input ---
    model_name = "timm-efficientnet-b4"  # encoder name handed to segmentation_models_pytorch
    input_size = (512, 512)  # (height, width) passed to Resize

    # --- data locations ---
    train_image_path = './rgb'
    train_mask_path = './label'
    test_image_path = '../Test_data'

    # --- outputs ---
    output_path = './predictions'
    best_model_path = './best_model.pth'
    log_dir = './logs'

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value used for every generator.

    Note:
        ``PYTHONHASHSEED`` is read only when an interpreter starts. Setting it
        here therefore affects subprocesses launched afterwards, not the
        current process's string hashing.
    """
    # Python-level and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch CPU generator plus the CUDA generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

set_seed(seed=CFG.seed)  # seed every RNG before the model below draws its initial weights

# Custom dataset class for loading images and masks
class LeafDataset(Dataset):
    """Paired image/mask dataset for binary (foreground vs. background) segmentation.

    Each item is ``(image, mask)``. The image is read with OpenCV and converted
    BGR -> RGB. The mask is read as single-channel grayscale and binarised so
    that every non-zero pixel becomes 1.0 and every zero pixel 0.0.

    Args:
        image_files: Paths of the RGB images.
        mask_files: Paths of the masks. ``mask_files[i]`` is paired with
            ``image_files[i]``; pairing is positional, not by name.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` keys.

    Raises:
        ValueError: If the two path lists differ in length.
    """

    def __init__(self, image_files, mask_files, transform=None):
        if len(image_files) != len(mask_files):
            raise ValueError(
                f"Got {len(image_files)} image files but {len(mask_files)} mask files"
            )
        self.image_files = image_files
        self.mask_files = mask_files
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        # cv2.imread returns None instead of raising on a missing/undecodable file.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the mask is still a NumPy array,
        # which has no .float(). as_tensor is a no-op for tensors, so both paths work.
        mask = torch.as_tensor(mask)
        mask = (mask > 0).float()  # binary: any non-zero label is foreground

        return img, mask

# Shared train/val preprocessing (no augmentation; the same pipeline serves
# both splits):
#   1. resize to CFG.input_size;
#   2. normalise with the ImageNet mean/std, matching the ImageNet-pretrained encoder;
#   3. convert the HWC NumPy image to a CHW torch tensor.
transform = Compose([
    Resize(height=CFG.input_size[0], width=CFG.input_size[1]),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Get image and mask files. Pairing is positional after sorting, so both
# directories must hold the same number of files whose names sort into the
# same order (for example, identical basenames).
image_files = sorted(glob(os.path.join(CFG.train_image_path, '*.png')))
mask_files = sorted(glob(os.path.join(CFG.train_mask_path, '*.png')))

# Fail early with a clear message rather than an obscure error from train_test_split.
if not image_files:
    raise FileNotFoundError(f"No .png images found in {CFG.train_image_path}")
if len(image_files) != len(mask_files):
    raise ValueError(
        f"Found {len(image_files)} images in {CFG.train_image_path} but "
        f"{len(mask_files)} masks in {CFG.train_mask_path}; each image needs exactly one mask"
    )

# 80/20 train/validation split, seeded from CFG.seed like the rest of the run.
train_img_files, val_img_files, train_mask_files, val_mask_files = train_test_split(
    image_files, mask_files, test_size=0.2, random_state=CFG.seed
)

# Create datasets and dataloaders (only the training loader shuffles).
train_dataset = LeafDataset(train_img_files, train_mask_files, transform=transform)
val_dataset = LeafDataset(val_img_files, val_mask_files, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False)

# U-Net++ with an ImageNet-pretrained encoder and a single output channel.
# activation=None keeps raw logits; the loss below uses BCEWithLogitsLoss and
# validation applies torch.sigmoid explicitly.
model = smp.UnetPlusPlus(
    encoder_name=CFG.model_name,
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None,
).to('cuda')

# Loss: Dice + BCE, both computed from raw logits.
class CombinedLoss(nn.Module):
    """Weighted sum of a binary Dice loss and BCE-with-logits.

    ``forward(inputs, targets)`` expects ``inputs`` to be raw logits and
    ``targets`` a float mask of the same shape. Both component losses apply
    the sigmoid internally.

    Args:
        dice_weight: Multiplier for the Dice term (default 1.0).
        bce_weight: Multiplier for the BCE term (default 1.0).
            With both weights at 1.0 the result is exactly ``dice + bce``.
    """

    def __init__(self, dice_weight=1.0, bce_weight=1.0):
        super().__init__()
        self.dice_loss = smp.losses.DiceLoss(mode='binary')
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight

    def forward(self, inputs, targets):
        dice = self.dice_loss(inputs, targets)
        bce = self.bce_loss(inputs, targets)
        return self.dice_weight * dice + self.bce_weight * bce

criterion = CombinedLoss()  # Dice + BCE-with-logits on the model's raw outputs
optimizer = optim.Adam(params=model.parameters(), lr=CFG.lr)

# Function to calculate Dice coefficient
def dice_coefficient(y_true, y_pred, smooth=1):
    """Return the smoothed Dice overlap ``(2|A∩B| + s) / (|A| + |B| + s)`` as a 0-d tensor.

    Both tensors are flattened first, so any shapes with the same element
    count work. ``smooth`` is added to numerator and denominator, which keeps
    the ratio defined (and equal to 1) when both inputs are all zero.
    """
    flat_true, flat_pred = y_true.flatten(), y_pred.flatten()
    overlap = torch.sum(flat_true * flat_pred)
    numerator = 2. * overlap + smooth
    denominator = torch.sum(flat_true) + torch.sum(flat_pred) + smooth
    return numerator / denominator

# Function to calculate Symmetric Best Dice (SBD)
def symmetric_best_dice(y_true, y_pred):
    """Return the Symmetric Best Dice between two label maps as a 0-d tensor.

    SBD (the leaf-segmentation / CVPPP definition) is
    ``min(BD(y_true, y_pred), BD(y_pred, y_true))``. ``BD(a, b)`` averages,
    over every object label in ``a``, the best Dice that object reaches
    against any object label in ``b``. Value 0 is background; every other
    distinct value is one object.

    For binary 0/1 masks each side holds a single object, so SBD reduces to
    the smoothed (smooth=1) Dice of the two masks. If either side has no
    objects at all, the smoothed whole-mask Dice is returned: 1.0 when both
    are empty, close to 0 when only one is.

    NOTE: averaging dice(a, b) and dice(b, a) is *not* SBD. Dice is symmetric,
    so that average is plain Dice and skips the per-object best matching.
    """
    smooth = 1  # same smoothing as dice_coefficient's default

    def _dice(a, b):
        # Smoothed Dice on flattened tensors (same formula as dice_coefficient).
        a = a.flatten()
        b = b.flatten()
        return (2. * torch.sum(a * b) + smooth) / (torch.sum(a) + torch.sum(b) + smooth)

    def _best_dice(src, dst):
        # Mean over objects in src of the best Dice against any object in dst.
        dst_masks = [(dst == label).float() for label in torch.unique(dst) if label != 0]
        best = [
            torch.stack([_dice((src == label).float(), m) for m in dst_masks]).max()
            for label in torch.unique(src)
            if label != 0
        ]
        return torch.stack(best).mean()

    if not (bool(torch.any(y_true != 0)) and bool(torch.any(y_pred != 0))):
        # No objects on at least one side: per-object matching is undefined.
        return _dice(y_true, y_pred)
    return torch.minimum(_best_dice(y_true, y_pred), _best_dice(y_pred, y_true))

# Training function
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader):
        images = images.to(device)
        masks = masks.to(device).unsqueeze(1)  # Add channel dimension

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)


def _validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader``; return ``(mean loss, mean per-image SBD)``."""
    model.eval()
    running_loss = 0.0
    sbd_scores = []
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device).unsqueeze(1)  # Add channel dimension

            outputs = model(images)
            running_loss += criterion(outputs, masks).item() * images.size(0)

            # Threshold sigmoid probabilities at 0.5 and score each image on the
            # CPU, staying in torch instead of round-tripping through NumPy.
            preds = (torch.sigmoid(outputs) > 0.5).float().cpu()
            targets = masks.float().cpu()
            sbd_scores.extend(
                symmetric_best_dice(true, pred).item()
                for true, pred in zip(targets, preds)
            )
    return running_loss / len(loader.dataset), np.mean(sbd_scores)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device='cuda'):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch:
      * runs one pass over ``train_loader``;
      * evaluates on ``val_loader`` (loss plus mean per-image SBD of predictions
        thresholded at 0.5);
      * prints the metrics and appends them to
        ``<CFG.log_dir>/<CFG.model_name>_training_log.csv``;
      * saves ``model.state_dict()`` to ``CFG.best_model_path`` whenever the
        validation loss improves.

    After training, the mean SBD of the best-loss epoch is written to
    ``<CFG.log_dir>/<CFG.model_name>_best_mean_sbd.txt``.

    Args:
        model: Network returning one logit channel per image.
        train_loader: Loader yielding ``(images, masks)`` for training.
        val_loader: Loader yielding ``(images, masks)`` for validation.
        criterion: Loss taking ``(logits, masks)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device batches are moved to; must be where ``model`` lives.
            Defaults to ``'cuda'``.

    Raises:
        ValueError: If either loader's dataset is empty. The per-sample
            averages would otherwise divide by zero.
    """
    if len(train_loader.dataset) == 0:
        raise ValueError("train_loader has an empty dataset")
    if len(val_loader.dataset) == 0:
        raise ValueError("val_loader has an empty dataset")

    best_val_loss = float('inf')
    best_sbd = 0

    os.makedirs(CFG.log_dir, exist_ok=True)
    checkpoint_dir = os.path.dirname(CFG.best_model_path)
    if checkpoint_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
    log_file_path = os.path.join(CFG.log_dir, f'{CFG.model_name}_training_log.csv')

    with open(log_file_path, 'w') as f:
        f.write('epoch,train_loss,val_loss,mean_sbd\n')

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, mean_sbd = _validate(model, val_loader, criterion, device)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Mean SBD: {mean_sbd:.4f}')

        # Checkpoint on validation loss. best_sbd is the SBD of *that* epoch,
        # so it describes the saved weights, not the max SBD over all epochs.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_sbd = mean_sbd
            torch.save(model.state_dict(), CFG.best_model_path)

        # Save logs
        with open(log_file_path, 'a') as f:
            f.write(f'{epoch+1},{train_loss},{val_loss},{mean_sbd}\n')

    print(f'Best Validation Loss: {best_val_loss}, Best Mean SBD: {best_sbd}')

    # Save the mean SBD of the checkpointed (best-loss) epoch
    with open(os.path.join(CFG.log_dir, f'{CFG.model_name}_best_mean_sbd.txt'), 'w') as f:
        f.write(f'{best_sbd:.4f}\n')

# Re-seed right before training so DataLoader shuffling starts from a known
# RNG state, independent of the random draws made while building the model.
set_seed(seed=CFG.seed)

# Train. The weights with the lowest validation loss are saved to CFG.best_model_path.
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=CFG.num_epochs,
)

# Restore that best checkpoint, rather than the last epoch's weights, for inference.
model.load_state_dict(torch.load(CFG.best_model_path))

# Function to predict and save results for the test set
def predict_and_save_results(model, test_image_path, output_path, device='cuda', threshold=0.5):
    """Segment every ``*.png`` in ``test_image_path`` and write binary masks.

    Each image is resized and normalised like the training data, then run
    through ``model``. The sigmoid probability is thresholded at ``threshold``,
    the mask is resized back to the original resolution with nearest-neighbour
    interpolation, and the result is written to ``output_path`` as
    ``<stem>_result.png`` (foreground 255, background 0).

    Args:
        model: Network returning one logit channel; should already hold the
            weights to use. It is switched to eval mode here.
        test_image_path: Directory searched (non-recursively) for ``*.png``.
        output_path: Directory for the results; created if missing.
        device: Device ``model`` lives on. Defaults to ``'cuda'``.
        threshold: Foreground probability cut-off. Defaults to 0.5.

    Raises:
        FileNotFoundError: If an input image cannot be read.
        OSError: If a result image cannot be written.
    """
    test_image_files = sorted(glob(os.path.join(test_image_path, '*.png')))
    transform = Compose([
        Resize(CFG.input_size[0], CFG.input_size[1]),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

    os.makedirs(output_path, exist_ok=True)
    # Inference must use running BatchNorm statistics and no dropout, even if
    # the caller left the model in training mode.
    model.eval()
    for img_file in tqdm(test_image_files):
        img = cv2.imread(img_file)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_file}")
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_transformed = transform(image=img_rgb)['image'].unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img_transformed)
        probs = torch.sigmoid(output).cpu().numpy()
        preds = (probs > threshold).astype(np.uint8)
        # Back to the original size; nearest-neighbour keeps the mask strictly 0/1.
        preds = cv2.resize(preds[0, 0], (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

        # splitext touches only the real extension, unlike str.replace, which
        # would rewrite every '.png' occurring anywhere in the name.
        stem = os.path.splitext(os.path.basename(img_file))[0]
        output_file = os.path.join(output_path, f'{stem}_result.png')
        if not cv2.imwrite(output_file, preds * 255):
            raise OSError(f"Could not write result image: {output_file}")

# Run inference on the test images with the restored best checkpoint.
predict_and_save_results(
    model,
    test_image_path=CFG.test_image_path,
    output_path=CFG.output_path,
)

print(f"Predicted results saved to {CFG.output_path}")


> SEEDING DONE
> SEEDING DONE


100%|███████████████████████████████████████████| 40/40 [00:09<00:00,  4.33it/s]


Epoch 1/20, Train Loss: 1.3380, Val Loss: 1.0703, Mean SBD: 0.6417


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.77it/s]


Epoch 2/20, Train Loss: 0.8477, Val Loss: 0.7085, Mean SBD: 0.7689


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.72it/s]


Epoch 3/20, Train Loss: 0.4323, Val Loss: 0.3114, Mean SBD: 0.8348


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.78it/s]


Epoch 4/20, Train Loss: 0.2280, Val Loss: 0.2036, Mean SBD: 0.8414


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.78it/s]


Epoch 5/20, Train Loss: 0.1862, Val Loss: 0.1847, Mean SBD: 0.8483


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.77it/s]


Epoch 6/20, Train Loss: 0.1695, Val Loss: 0.1686, Mean SBD: 0.8576


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.80it/s]


Epoch 7/20, Train Loss: 0.1571, Val Loss: 0.1635, Mean SBD: 0.8571


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.80it/s]


Epoch 8/20, Train Loss: 0.1518, Val Loss: 0.1796, Mean SBD: 0.8365


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.82it/s]


Epoch 9/20, Train Loss: 0.1431, Val Loss: 0.1620, Mean SBD: 0.8528


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.84it/s]


Epoch 10/20, Train Loss: 0.1385, Val Loss: 0.1547, Mean SBD: 0.8621


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.74it/s]


Epoch 11/20, Train Loss: 0.1318, Val Loss: 0.1498, Mean SBD: 0.8657


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.81it/s]


Epoch 12/20, Train Loss: 0.1308, Val Loss: 0.1565, Mean SBD: 0.8598


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.83it/s]


Epoch 13/20, Train Loss: 0.1288, Val Loss: 0.1515, Mean SBD: 0.8631


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.82it/s]


Epoch 14/20, Train Loss: 0.1229, Val Loss: 0.1556, Mean SBD: 0.8564


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.76it/s]


Epoch 15/20, Train Loss: 0.1177, Val Loss: 0.1507, Mean SBD: 0.8634


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.71it/s]


Epoch 16/20, Train Loss: 0.1228, Val Loss: 0.1569, Mean SBD: 0.8584


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.73it/s]


Epoch 17/20, Train Loss: 0.1137, Val Loss: 0.1439, Mean SBD: 0.8690


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.75it/s]


Epoch 18/20, Train Loss: 0.1093, Val Loss: 0.1535, Mean SBD: 0.8597


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.73it/s]


Epoch 19/20, Train Loss: 0.1050, Val Loss: 0.1457, Mean SBD: 0.8676


100%|███████████████████████████████████████████| 40/40 [00:08<00:00,  4.73it/s]


Epoch 20/20, Train Loss: 0.1002, Val Loss: 0.1467, Mean SBD: 0.8665
Best Validation Loss: 0.14389776363968848, Best Mean SBD: 0.8690464049577713


100%|███████████████████████████████████████████| 68/68 [00:02<00:00, 28.67it/s]

Predicted results saved to ./predictions



