In [1]:
# Importing essential PyTorch libraries for deep learning operations
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os

# Importing mixed precision training utilities for enhanced performance
from torch.cuda.amp import autocast, GradScaler

# Implementing the DoubleConv block that forms the fundamental building block of U-Net
class DoubleConv(nn.Module):
    """Defining a double convolution block with (Conv -> BatchNorm -> ReLU) repeated twice"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # Constructing sequential convolution operations
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Processing input through the convolution block
        return self.conv(x)

# Implementing the complete U-Net architecture
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # Constructing the encoder (downsampling path)
        self.enc1 = DoubleConv(in_channels, 64)  # First encoding block
        self.enc2 = DoubleConv(64, 128)         # Second encoding block
        self.enc3 = DoubleConv(128, 256)        # Third encoding block
        self.enc4 = DoubleConv(256, 512)        # Fourth encoding block
        self.enc5 = DoubleConv(512, 1024)       # Bottleneck layer

        # Defining max pooling operation for spatial dimension reduction
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Constructing the decoder (upsampling path) with skip connections
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # First upsampling
        self.dec4 = DoubleConv(1024, 512)       # First decoding block

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)   # Second upsampling
        self.dec3 = DoubleConv(512, 256)        # Second decoding block

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)   # Third upsampling
        self.dec2 = DoubleConv(256, 128)        # Third decoding block

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)    # Fourth upsampling
        self.dec1 = DoubleConv(128, 64)         # Fourth decoding block

        # Final convolution layer producing the segmentation output
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Implementing the forward pass through the U-Net architecture
        
        # Processing through the encoder path
        x1 = self.enc1(x)          # First encoding block
        x2 = self.enc2(self.pool(x1))  # Second encoding block with downsampling
        x3 = self.enc3(self.pool(x2))  # Third encoding block with downsampling
        x4 = self.enc4(self.pool(x3))  # Fourth encoding block with downsampling
        x5 = self.enc5(self.pool(x4))  # Bottleneck layer with maximum downsampling

        # Processing through the decoder path with skip connections
        d4 = self.up4(x5)           # First upsampling operation
        d4 = torch.cat((d4, x4), dim=1)  # Incorporating skip connection
        d4 = self.dec4(d4)          # First decoding block

        d3 = self.up3(d4)           # Second upsampling operation
        d3 = torch.cat((d3, x3), dim=1)  # Incorporating skip connection
        d3 = self.dec3(d3)          # Second decoding block

        d2 = self.up2(d3)           # Third upsampling operation
        d2 = torch.cat((d2, x2), dim=1)  # Incorporating skip connection
        d2 = self.dec2(d2)          # Third decoding block

        d1 = self.up1(d2)           # Fourth upsampling operation
        d1 = torch.cat((d1, x1), dim=1)  # Incorporating skip connection
        d1 = self.dec1(d1)          # Fourth decoding block

        # Generating final segmentation output
        return self.final_conv(d1)

# Implementing custom dataset loader for image-mask pairs
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Listing all images in the face_crop directory
        self.images = os.listdir(os.path.join(root_dir, 'face_crop'))
        # Verifying corresponding masks exist for each image
        self.images = [img for img in self.images if os.path.exists(os.path.join(root_dir, 'face_crop_segmentation', img))]

    def __len__(self):
        # Returning total number of valid image-mask pairs
        return len(self.images)

    def __getitem__(self, idx):
        # Constructing paths to image and mask files
        img_name = os.path.join(self.root_dir, 'face_crop', self.images[idx])
        mask_name = os.path.join(self.root_dir, 'face_crop_segmentation', self.images[idx])

        # Loading image and mask
        image = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name).convert('L')

        # Applying transformations if specified
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

# Implementing Intersection over Union (IoU) metric calculation
def iou_score(preds, targets, threshold=0.5):
    # Converting predictions to binary masks using threshold
    preds = torch.sigmoid(preds) > threshold
    targets = targets > threshold
    
    # Calculating intersection and union areas
    intersection = (preds & targets).float().sum((1, 2, 3))
    union = (preds | targets).float().sum((1, 2, 3))
    
    # Computing IoU while preventing division by zero
    iou = (intersection / (union + 1e-6)).mean()
    return iou.item()

# Defining image preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resizing images to consistent dimensions
    transforms.ToTensor(),          # Converting images to tensor format
])
  
# Determining base directory path
base_dir = os.getcwd() 

# Constructing dataset path  
dataset_path = os.path.join(base_dir, "MSFD", "1")  

# Loading and splitting dataset  
full_dataset = CustomDataset(root_dir=dataset_path, transform=transform)  
train_size = int(0.8 * len(full_dataset))  # Allocating 80% for training
test_size = len(full_dataset) - train_size  # Allocating 20% for testing
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  

# Creating data loaders for efficient batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

# Determining computation device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initializing model components
model = UNet(in_channels=3, out_channels=1).to(device)  # Instantiating U-Net model
criterion = nn.BCEWithLogitsLoss()  # Defining loss function
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)  # Configuring optimizer
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)  # Setting learning rate scheduler
scaler = GradScaler()  # Initializing gradient scaler for mixed precision

# Implementing training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # Setting model to training mode
    train_loss = 0
    train_iou = 0
    
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Transferring data to computation device
        data, targets = data.to(device), targets.float().to(device)
        
        # Resetting gradients
        optimizer.zero_grad()

        # Performing forward pass with mixed precision
        with autocast():
            scores = model(data)  # Computing predictions
            loss = criterion(scores, targets)  # Calculating loss

        # Performing backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Accumulating training metrics
        train_loss += loss.item()
        train_iou += iou_score(scores, targets)

        # Printing batch progress
        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}, IoU: {train_iou / (batch_idx + 1):.4f}")

    # Updating learning rate
    scheduler.step()
    
    # Calculating epoch averages
    avg_train_loss = train_loss / len(train_loader)
    avg_train_iou = train_iou / len(train_loader)
    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Train IoU: {avg_train_iou:.4f}")

# Saving trained model weights
base_dir=os.getcwd()
model_path = os.path.join(base_dir, "standard_unet.pth")
torch.save(model.state_dict(), model_path)

# Implementing model evaluation function
def evaluate_model(model, test_loader, device):
    model.eval()  # Setting model to evaluation mode
    test_iou = 0

    # Disabling gradient computation for evaluation
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.float().to(device)

            scores = model(data)  # Computing predictions
            iou = iou_score(scores, targets)  # Calculating IoU
            test_iou += iou

    # Computing average test IoU
    avg_test_iou = test_iou / len(test_loader)
    return avg_test_iou

# Uncomment to perform evaluation
# test_iou = evaluate_model(model, test_loader, device)
# print(f"Test IoU Score: {test_iou:.4f}")

  scaler = GradScaler()  # Initializing gradient scaler for mixed precision


In [None]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Setting model to evaluation mode
    test_iou = 0

    with torch.no_grad():  # No gradient calculation needed for evaluation
        for data, targets in test_loader:
            data, targets = data.to(device), targets.float().to(device)

            scores = model(data)  # Forward pass
            iou = iou_score(scores, targets)  # Computting the IoU scores
            test_iou += iou

    avg_test_iou = test_iou / len(test_loader)  # Computing the average IoU scores
    return avg_test_iou

In [None]:
# Loading the Pretrained Model
model = UNet(in_channels=3, out_channels=1).to(device)
model_path=os.path.join(base_dir, "standard_unet.pth")
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # Setting the model to evaluation mode


UNet(
  (enc1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc2): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc3): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=

In [10]:
# Printing the final results
test_iou = evaluate_model(model, test_loader, device)
print(f"Test IoU Score: {test_iou:.4f}")


Test IoU Score: 0.9573
