In [21]:
# Libraries import

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from PIL import Image

from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

from torch.utils.data import Subset

from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

import os
import glob
import time
import csv

import seaborn as sns


import matplotlib.colors as mcolors

import onnx

import tifffile as tiff

import matplotlib.colors as mcolors


In [22]:
# Device: run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [23]:

# Dataloader

class CropDataloader(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    Images are the ``*.jpg`` files found in ``image_directory`` and masks are
    the ``*.tif`` files found in ``mask_directory``. Both listings are sorted
    by path and paired by position, so the i-th image is returned together
    with the i-th mask.

    Args:
        image_directory: directory containing the ``.jpg`` images.
        mask_directory: directory containing the ``.tif`` masks.

    Raises:
        ValueError: if the two directories do not contain the same number of
            files, because positional pairing would then be wrong.
    """

    def __init__(self, image_directory, mask_directory):

        super().__init__()

        self.image_paths = sorted(glob.glob(os.path.join(image_directory, "*.jpg")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_directory, "*.tif")))

        # Pairing is purely positional: a missing or extra file would silently
        # shift every following pair (or raise IndexError in the middle of an
        # epoch), so fail early with a clear message instead.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_directory!r} but "
                f"{len(self.mask_paths)} masks in {mask_directory!r}; "
                "images and masks are paired by sorted position and must match in number."
            )

    def __len__(self):
        """Return the number of (image, mask) pairs."""

        return len(self.image_paths)

    def __getitem__(self, index):
        """Return the ``index``-th pair as PIL images.

        The image is converted to RGB; the mask is returned as opened by PIL,
        without any mode conversion.
        """

        image = Image.open(self.image_paths[index]).convert('RGB')

        mask = Image.open(self.mask_paths[index])

        return image, mask

def custom_collate_fn(batch):
    """Turn a batch of (image, mask) samples into two parallel lists.

    Args:
        batch: sequence of two-item samples.

    Returns:
        A tuple ``(images, masks)`` where each element is a list holding the
        first / second item of every sample, in batch order.
    """
    image_column, mask_column = zip(*batch)

    return [*image_column], [*mask_column]

# Image paths and data loaders for training and validation.
# Both loaders yield one sample per batch and use custom_collate_fn, so each
# batch is a pair of lists (images, masks) of PIL images.

# Training dataloader (shuffled so every epoch sees the samples in a new order)

image_path_train = "/local/data/sdahal_p/Crop/data/train/images/"
mask_image_path_train = "/local/data/sdahal_p/Crop/data/train/masks/"

dataset_train = CropDataloader(image_path_train, mask_image_path_train)

dataloader_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)


# Validation dataloader
# Not shuffled: evaluation does not benefit from a random order, and a fixed
# order keeps per-epoch validation results comparable and reproducible.

image_path_validation = "/local/data/sdahal_p/Crop/data/validation/images/"
mask_image_path_validation = "/local/data/sdahal_p/Crop/data/validation/masks/"

dataset_validation = CropDataloader(image_path_validation, mask_image_path_validation)

dataloader_validation = DataLoader(dataset_validation, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)


In [24]:
# Testing data

class CropDataloaderTest(Dataset):
    """Unlabelled dataset over the ``*.jpg`` files of one directory.

    The file list is sorted by path; each item is the image opened with PIL
    and converted to RGB.
    """

    def __init__(self, image_directory):
        super().__init__()

        jpg_pattern = os.path.join(image_directory, "*.jpg")
        found = glob.glob(jpg_pattern)
        found.sort()
        self.image_paths = found

    def __len__(self):
        """Number of images found in the directory."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Open the ``index``-th image and return it as an RGB PIL image."""
        path = self.image_paths[index]
        return Image.open(path).convert('RGB')

def custom_collate_fn(batch):
    """Collate a batch of unlabelled samples into a plain list.

    NOTE: this redefines the ``custom_collate_fn`` declared earlier in the
    notebook for (image, mask) batches; from here on the name refers to this
    single-item version.

    Args:
        batch: iterable of samples.

    Returns:
        A new list containing the samples in batch order.
    """
    return [sample for sample in batch]

# Test images: unlabelled, one image per batch.
image_path = "/local/data/sdahal_p/Crop/data/test/"

dataset_test = CropDataloaderTest(image_path)

# Not shuffled: the position of a batch in this loader is used downstream to
# name output files, so the order must follow the sorted file list and stay
# the same from run to run.
dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)


In [25]:

# IoU function

def compute_iou(preds, target, num_classes=4):
    """Compute the intersection-over-union of every class id in ``range(num_classes)``.

    Args:
        preds: tensor of predicted class ids.
        target: tensor of reference class ids, compared element-wise with ``preds``.
        num_classes: how many class ids (0 .. num_classes - 1) to score.

    Returns:
        A list of floats with one entry per class id. An entry is NaN when the
        class occurs in neither ``preds`` nor ``target`` (empty union).
    """
    scores = []

    for class_id in range(num_classes):
        in_pred = preds == class_id
        in_target = target == class_id

        overlap = torch.logical_and(in_pred, in_target).sum().item()
        combined = torch.logical_or(in_pred, in_target).sum().item()

        # An empty union means the class is absent everywhere: report NaN so
        # callers can skip it with a nan-aware mean.
        scores.append(overlap / combined if combined != 0 else float('nan'))

    return scores

In [26]:
# Confusion matrix

def func_confusion_matrix(all_preds, all_targets, class_id, class_mapping):
    """Plot a one-vs-rest confusion matrix, in percent, for a single class.

    Predictions and targets are binarised (1 where the value equals
    ``class_id``, 0 elsewhere), a 2x2 confusion matrix is built from them, and
    each cell is shown as a percentage of all samples in a seaborn heatmap.

    Args:
        all_preds: array-like of predicted class ids.
        all_targets: array-like of reference class ids, aligned with ``all_preds``.
        class_id: the class id treated as the positive class.
        class_mapping: dict mapping class ids to display names; a missing id
            falls back to ``"Class <id>"``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """

    binary_targets = (np.asarray(all_targets) == class_id).astype(int)
    binary_preds = (np.asarray(all_preds) == class_id).astype(int)

    # Pin both labels so the matrix is always 2x2. Without `labels`, a class
    # that is absent from (or covers all of) both arrays collapses the result
    # to a single cell and the heatmap axes no longer mean rest / class.
    binary_conf_matrix = confusion_matrix(binary_targets, binary_preds, labels=[0, 1])

    binary_conf_matrix_percent = binary_conf_matrix.astype(np.float32)
    total = binary_conf_matrix.sum()
    if total > 0:
        # Leave the matrix at zero (instead of NaN) when there are no samples.
        binary_conf_matrix_percent = binary_conf_matrix_percent / total * 100

    # Class name from mapping
    class_name = class_mapping.get(class_id, f"Class {class_id}")

    # Row/column 0 is "every other class", row/column 1 is the class itself.
    tick_names = ["Other", class_name]

    plt.figure(figsize=(5, 4))
    sns.heatmap(
        binary_conf_matrix_percent,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=tick_names,
        yticklabels=tick_names,
    )
    plt.xlabel("Predicted Label")
    plt.ylabel("Actual Label")
    plt.title(f"Confusion Matrix for {class_name} (Percentage)")
    plt.show()

    

In [27]:
# Training configuration

# Pretrained checkpoint used for both the image processor and the model.
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"

# Number of segmentation classes the model is fine-tuned to predict.
num_classes = 4

image_processor = SegformerImageProcessor.from_pretrained(model_name)

# Load the checkpoint with a `num_classes`-way output. ignore_mismatched_sizes=True
# lets loading proceed when checkpoint weights do not match that shape.
model = SegformerForSemanticSegmentation.from_pretrained(
    model_name,
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)

model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.6 after every 4 calls to scheduler.step().
# The `verbose` keyword is deprecated in recent PyTorch releases (passing it,
# even as False, emits a warning), so it is not passed; behavior is unchanged.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.6, last_epoch=-1)


In [28]:

# Training loop


num_epochs = 32

# 0-based epoch whose validation predictions are plotted as confusion matrices.
confusion_matrix_epoch = 15

# Display names of the class ids, used for the confusion-matrix titles.
class_mapping = {
    0: "Residue with sunlight",
    1: "Residue with shade",
    2: "Background with sunlight",
    3: "Background with shade"
}

checkpoint_path = '/local/data/sdahal_p/genome-test/model_checkpoint.pth'

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Training and validation mean IoU interleaved (train, val, train, val, ...),
# stored as plain floats.
epoch_mean_iou = []

# One 0-d tensor per epoch.
training_iou = []

validation_iou = []


def _mean_iou_over_batches(batch_ious):
    """Reduce per-batch, per-class IoU lists to (per-class IoU list, mean IoU as 0-d tensor).

    NaN entries (class absent from a batch) are ignored; an empty input gives
    ``([], nan)``.
    """
    if not batch_ious:
        return [], torch.tensor(float('nan'))

    per_class = torch.nanmean(torch.tensor(batch_ious), dim=0)

    return per_class.tolist(), torch.nanmean(per_class)


for epoch in range(num_epochs):

    # ------------------------------ Training ------------------------------

    model.train()

    total_loss = 0.0

    num_batches = 0

    train_batch_ious = []

    for images, masks in dataloader_train:

        encoding = image_processor(
            images=images,
            segmentation_maps=masks,
            return_tensors="pt"
        )

        pixel_values = encoding["pixel_values"].to(device)

        labels = encoding["labels"].to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

        num_batches += 1

        # Score every batch so the epoch IoU reflects the whole training set,
        # not only the last batch. Logits are upsampled to the label size
        # before taking the per-pixel argmax.
        with torch.no_grad():

            logits = outputs.logits

            logits_upsampled = F.interpolate(logits, size=tuple(labels.shape[-2:]), mode="bilinear", align_corners=False)

            preds = torch.argmax(logits_upsampled, dim=1)

        train_batch_ious.append(compute_iou(preds, labels))

    avg_loss = total_loss / max(num_batches, 1)

    iou_scores_training, mean_iou_training = _mean_iou_over_batches(train_batch_ious)

    epoch_mean_iou.append(mean_iou_training.item())

    # Save the model every second epoch (the file is overwritten each time).

    if epoch % 2 == 0:

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }

        torch.save(checkpoint, checkpoint_path)

    scheduler.step()

    # ----------------------------- Validation -----------------------------

    # Switch to inference mode so validation does not run with
    # training-mode layers.
    model.eval()

    validation_loss = 0.0

    val_batch_ious = []

    # Flattened per-batch arrays; concatenated once after the loop, which is
    # far cheaper than extending a Python list pixel by pixel.
    pred_chunks = []

    target_chunks = []

    with torch.no_grad():

        for images, masks in dataloader_validation:

            encoding = image_processor(
                images=images,
                segmentation_maps=masks,
                return_tensors="pt"
            )

            pixel_values = encoding["pixel_values"].to(device)

            labels = encoding["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)

            loss = outputs.loss

            validation_loss += loss.item()

            logits = outputs.logits

            logits_upsampled = F.interpolate(logits, size=tuple(labels.shape[-2:]), mode="bilinear", align_corners=False)

            # Prediction of the *current* batch, at label resolution.
            preds = torch.argmax(logits_upsampled, dim=1)

            val_batch_ious.append(compute_iou(preds, labels))

            pred_chunks.append(preds.reshape(-1).cpu().numpy())

            target_chunks.append(labels.reshape(-1).cpu().numpy())

    iou_scores_validation, mean_iou_validation = _mean_iou_over_batches(val_batch_ious)

    epoch_mean_iou.append(mean_iou_validation.item())

    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])

    all_targets = np.concatenate(target_chunks) if target_chunks else np.array([])

    if epoch == confusion_matrix_epoch:

        for class_id in class_mapping:

            func_confusion_matrix(all_preds, all_targets, class_id, class_mapping)

    avg_val_loss = validation_loss / max(len(dataloader_validation), 1)

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_loss:.4f}")

    training_iou.append(mean_iou_training)

    validation_iou.append(mean_iou_validation)

# Write the per-epoch training / validation mean IoU to a CSV log.
csv_filename = "/local/data/sdahal_p/Crop/iou_scores_log.csv"

# Epochs are numbered from 1 in the log.
epochs = [*range(1, len(training_iou) + 1)]

# Unwrap every stored score into a plain Python number via .item().
training_iou = list(map(lambda score: score.item(), training_iou))
validation_iou = list(map(lambda score: score.item(), validation_iou))

iou_columns = {
    "Epoch": epochs,
    "Training IoU": training_iou,
    "Validation IoU": validation_iou
}
df = pd.DataFrame(iou_columns)

# mode='w' overwrites any earlier log at the same path.
df.to_csv(csv_filename, mode='w', index=False)

# Training ends



In [29]:
# Testing the model


# Output locations: copies of the input images, rendered segmentation
# figures, and per-image area reports.
original_dir = "/local/data/sdahal_p/Crop/original_images"
predicted_dir = "/local/data/sdahal_p/Crop/predicted_segmentations"
area_dir = "/local/data/sdahal_p/Crop/area_reports"

os.makedirs(original_dir, exist_ok=True)
os.makedirs(predicted_dir, exist_ok=True)
os.makedirs(area_dir, exist_ok=True)

# Class id -> display name, and the colour used for each id in the figures.
# NOTE(review): id 1 reads "Residue with shade 1" here but "Residue with shade"
# in the training cell's class_mapping — confirm which spelling is intended.
class_labels = {0: "Residue with sunlight", 1: "Residue with shade 1", 2: "Background with sunlight", 3: "Background with shade"}
class_colors = ["black", "red", "green", "blue"]

# Discrete colormap: class id k falls in the band [k, k + 1).
cmap = mcolors.ListedColormap(class_colors)
bounds = list(class_labels.keys()) + [len(class_labels)]
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Area represented by one pixel of the predicted map, used for the area reports.
pixel_size_mm2 = 0.01


# Raw label maps are written here as .tif (relative to the working directory).
output_folder = "segmentation_results"
os.makedirs(output_folder, exist_ok=True)

# Inference must run in eval mode; the training cell does not leave the model
# in that mode.
model.eval()

# Running index of the sample, used in every output file name. With one image
# per batch it equals the batch index.
idx = 0

with torch.no_grad():
    for images in dataloader_test:

        # The test loader's collate function yields a list of PIL images.
        images_pil = list(images)

        encoding = image_processor(images=images_pil, return_tensors="pt")
        pixel_values = encoding["pixel_values"].to(device)
        outputs = model(pixel_values=pixel_values)

        logits = outputs.logits
        logits_upsampled = F.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)

        # Shape (batch, 512, 512); uint8 is enough for the class ids and is
        # the dtype written to the .tif files.
        batch_segmentations = torch.argmax(logits_upsampled, dim=1).cpu().numpy().astype(np.uint8)

        for image_pil, predicted_segmentation in zip(images_pil, batch_segmentations):

            # Raw label map.
            filename = os.path.join(output_folder, f"segmentation_{idx}.tif")

            tiff.imwrite(filename, predicted_segmentation)

            # Copy of the input image.
            original_image_path = os.path.join(original_dir, f"image_{idx}.png")
            image_pil.save(original_image_path)

            # Colour-coded rendering with a labelled colorbar.
            predicted_image_path = os.path.join(predicted_dir, f"segmentation_{idx}.png")

            fig, ax = plt.subplots(figsize=(6, 6))
            img = ax.imshow(predicted_segmentation, cmap=cmap, norm=norm)
            ax.axis("off")

            # Put each tick in the middle of its colour band so every label
            # lines up with exactly one colour.
            cbar = fig.colorbar(img, ax=ax, ticks=[class_id + 0.5 for class_id in class_labels])
            cbar.ax.set_yticklabels(list(class_labels.values()))

            fig.savefig(predicted_image_path, bbox_inches='tight', dpi=300)
            plt.close(fig)

            # Area per class = number of pixels with that id * area of one pixel.
            area_report = {}
            for class_id, class_name in class_labels.items():
                pixel_count = np.sum(predicted_segmentation == class_id)
                area_mm2 = pixel_count * pixel_size_mm2
                area_report[class_name] = area_mm2

            area_report_path = os.path.join(area_dir, f"area_{idx}.txt")
            # Explicit UTF-8: the report contains the non-ASCII character "²".
            with open(area_report_path, "w", encoding="utf-8") as f:
                for class_name, area in area_report.items():
                    f.write(f"{class_name}: {area:.2f} mm²\n")

            idx += 1