In [1]:
#import statements

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from PIL import Image

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import random
from tqdm import tqdm

from torchvision.models import resnet
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms import RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
from torch.utils.data.dataset import Subset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.nn import CrossEntropyLoss


In [None]:
#generate segmentation masks for 17k images

image_folder = "train/train/images/"
xml_folder = "train/train/annotations/"

output_folder = "train/train/segmentation_masks/"
os.makedirs(output_folder, exist_ok=True)

def parse_xml(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    objects = []
    for obj in root.findall('object'):
        obj_name = obj.find('name').text
        bbox = obj.find('bndbox')
        xmin = float(bbox.find('xmin').text)
        ymin = float(bbox.find('ymin').text)
        xmax = float(bbox.find('xmax').text)
        ymax = float(bbox.find('ymax').text)
        objects.append((obj_name, (xmin, ymin, xmax, ymax)))
    return objects


def generate_segmentation_mask(image_size, objects):
    mask = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)  # Initialize mask with 3 channels (RGB)

    # Define colors for different classes
    colors = {}
    for obj_name, _ in objects:
        if obj_name not in colors:
            colors[obj_name] = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))  # Generate random color for each class

    # Assign colors to segmented regions
    for obj_name, (xmin, ymin, xmax, ymax) in objects:
        color = colors[obj_name]
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
        mask[ymin:ymax, xmin:xmax] = color  # Set pixels inside bounding box to color

    return mask



# Process each image and its corresponding XML annotation
for filename in tqdm(os.listdir(image_folder), desc="Processing Images"):
    if filename.endswith(".jpg"):
        image_path = os.path.join(image_folder, filename)
        xml_path = os.path.join(xml_folder, filename.replace(".jpg", ".xml"))

        # Check if corresponding XML file exists
        if os.path.exists(xml_path):
            objects = parse_xml(xml_path)

            # Read image
            image = cv2.imread(image_path)

            # Generate segmentation mask
            mask = generate_segmentation_mask(image.shape[:2], objects)

            mask = cv2.resize(mask, (256,256))

            # Save segmentation mask
            mask_filename = os.path.join(output_folder, filename.replace(".jpg", ".jpg"))
            cv2.imwrite(mask_filename, mask)
        else:
            print(f"XML annotation not found for {filename}")



In [None]:
#masks png to jpg

import shutil

# Set the path to the folders containing the images.
truth_path = "train/train/ground_truth"
truth_jpg_path = "train/train/ground_truth_jpg"
images_path = "train/train/images"
real_path = "train/train/true_real_images"

def convert_to_jpg(truth_path, truth_jpg_path):
    # Open the image
    img = Image.open(truth_path)
    # Convert to RGB if it's not already
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # Get the filename without extension
    filename = os.path.splitext(os.path.basename(truth_path))[0]
    # Save as JPEG
    img.save(os.path.join(truth_jpg_path, filename + ".jpg"), "JPEG")
    print(f"Converted {filename} to JPEG")

# Convert PNG images to JPEG
for filename in os.listdir(truth_path):
    if filename.endswith(".png"):
        convert_to_jpg(os.path.join(truth_path, filename), truth_jpg_path)

# Get a list of the files in each folder.
truth_files = os.listdir(truth_jpg_path)
images_files = os.listdir(images_path)

# Create a list of the images that match the name of images in both folders.
matching_images = []
for file in truth_files:
    if file in images_files:
        matching_images.append(file)

# Extract the images from folder1 and save them to folder2.
for image in matching_images:
    image_path = os.path.join(images_path, image)
    new_image_path = os.path.join(real_path, image)
    shutil.copy(image_path, new_image_path)

print("Processing & extraction complete")


In [None]:
# normalize image dimensions

import sys
sys.path.append('C:\\users\\rpoje\\appdata\\local\\packages\\pythonsoftwarefoundation.python.3.10_qbz5n2kfra8p0\\localcache\\local-packages\\python310\\site-packages')


# Input and output directories
input_dir = "train/train"
output_dir = "train/train/normalized"

# Create the output directory if it doesn't exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Function to crop and resize an image
def process_image(image_path, output_path, width, height):
    try:
        # Load the image
        img = cv2.imread(image_path)

        # Check if the image loaded successfully
        if img is not None:
            # Crop and resize the image
            cropped_img = cv2.resize(img, (width, height))

            # Create intermediate directories if they don't exist
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Save the processed image to the output directory
            cv2.imwrite(output_path, cropped_img)
        else:
            raise Exception("Error loading the image")

    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        # Delete the problematic image
        os.remove(image_path)

# Function to process images in subdirectories
def process_subdirectory(input_subdir, output_subdir, width, height):
    for root, _, files in os.walk(input_subdir):
        for filename in files:
            input_path = os.path.join(root, filename)
            rel_path = os.path.relpath(input_path, input_subdir)
            output_path = os.path.join(output_subdir, rel_path)
            process_image(input_path, output_path, width, height)
        print("done-1")


train_input_dir = os.path.join(input_dir, "images")
train_output_dir = os.path.join(output_dir, "images")
mask_input_dir = os.path.join(input_dir, "segmentation_masks")
mask_output_dir = os.path.join(output_dir, "masks")

# Process original image data
process_subdirectory(train_input_dir, train_output_dir, 256, 256)

# Process mask data
process_subdirectory(mask_input_dir, mask_output_dir, 256, 256)

print("Image processing and error handling completed.")


In [None]:
#Cascaded Atrous Layers

class AtrousCascade(nn.Module):
    def __init__(self, in_channels, out_channels=256, atrous_rates=[6, 12, 18]):
        super(AtrousCascade, self).__init__()

        self.conv1x1_input = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn0 = nn.BatchNorm2d(out_channels)

        #1x1 convolution
        self.conv1x1_1 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        #three 3x3 convolutions
        self.conv3x3_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[0], dilation=atrous_rates[0])
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3x3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[1], dilation=atrous_rates[1])
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv3x3_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[2], dilation=atrous_rates[2])
        self.bn4 = nn.BatchNorm2d(out_channels)

        #final logits
        self.conv1x1_final = nn.Conv2d(out_channels,out_channels,kernel_size=1)

    def forward(self, x):
        out0x0 = F.relu(self.bn0(self.conv1x1_input(x)))
        out1x1 = F.relu(self.bn1(self.conv1x1_1(out0x0)))
        out3x3_1 = F.relu(self.bn2(self.conv3x3_1(out1x1)))
        out3x3_2 = F.relu(self.bn3(self.conv3x3_2(out3x3_1)))
        out3x3_3 = F.relu(self.bn4(self.conv3x3_3(out3x3_2)))

        # Final 1x1 convolution for logits
        out = self.conv1x1_final(out3x3_3)

        return out

In [3]:
#Atrous Spatial Pyramid Pooling

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels=256, atrous_rates=[6, 12, 18]):
        super(ASPP, self).__init__()

        self.conv1x1_input = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn0 = nn.BatchNorm2d(out_channels)

        #1x1 convolution
        self.conv1x1_1 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        #three 3x3 convolutions
        self.conv3x3_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[0], dilation=atrous_rates[0])
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3x3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[1], dilation=atrous_rates[1])
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv3x3_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=atrous_rates[2], dilation=atrous_rates[2])
        self.bn4 = nn.BatchNorm2d(out_channels)

        # Concatenate the parallel layers
        self.conv1x1_concat = nn.Conv2d(out_channels * 4, out_channels, kernel_size=1)
        self.bn_concat = nn.BatchNorm2d(out_channels)

        #final logits
        self.conv1x1_final = nn.Conv2d(out_channels,out_channels,kernel_size=1)

    def forward(self, x):
        out0x0 = F.relu(self.bn0(self.conv1x1_input(x)))
        out1x1 = F.relu(self.bn1(self.conv1x1_1(out0x0)))
        out3x3_1 = F.relu(self.bn2(self.conv3x3_1(out0x0)))
        out3x3_2 = F.relu(self.bn3(self.conv3x3_2(out0x0)))
        out3x3_3 = F.relu(self.bn4(self.conv3x3_3(out0x0)))

        out = torch.cat([out1x1, out3x3_1, out3x3_2, out3x3_3], dim=1)

        out = F.relu(self.bn_concat(self.conv1x1_concat(out)))

        # Final 1x1 convolution for logits
        out = self.conv1x1_final(out)

        return out

In [4]:
#DeepLabV3

class DeepLabV3(nn.Module):
    def __init__(self, num_classes):
        super(DeepLabV3, self).__init__()
        # Load pre-trained ResNet-101 backbone
        resnet101 = resnet.resnet101(pretrained=True)
        # Remove the fully connected layer AND the average pooling layer
        self.backbone = nn.Sequential(*list(resnet101.children())[:-2])
        #print(self.backbone)

        # ASPP module
        self.aspp = ASPP(in_channels=2048, out_channels=256)

        # Upsampling layer
        self.upsample = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)

        # Final convolution layer
        self.conv_final = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        # Backbone feature extraction
        x = self.backbone(x)

        # ASPP module
        x = self.aspp(x)

        # Upsampling
        x = self.upsample(x)

        # Final convolution
        x = self.conv_final(x)

        return x

In [5]:
#training protocols

# Learning rate policy function
def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    lr = init_lr * ((1 - iter / max_iter) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer

# Data augmentation
train_transform = transforms.Compose([
    RandomResizedCrop(513),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
class CustomDataset(Dataset):
    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        self.image_folder = 'train/true_real_images'
        self.mask_folder = 'train/ground_truth_jpg'

        self.image_paths = sorted([os.path.join(self.root, self.image_folder, filename) for filename in os.listdir(os.path.join(self.root, self.image_folder))])
        self.mask_paths = sorted([os.path.join(self.root, self.mask_folder, filename) for filename in os.listdir(os.path.join(self.root, self.mask_folder))])

        print("Number of images:", len(self.image_paths))
        print("Number of masks:", len(self.mask_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.mask_paths[index])

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        #mask = mask.permute(2, 0, 1)

        return image, mask

In [11]:
# Dataset and DataLoader
data_dir = 'train/'
train_dataset = CustomDataset(root=data_dir,transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Define the model
model = DeepLabV3(num_classes=21)


Number of images: 2913
Number of masks: 2913


In [12]:
#dice loss

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, target):
        # Flatten the input and target tensors
        input_flat = input.view(-1)
        target_flat = target.view(-1)

        # Calculate intersection and union
        intersection = (input_flat * target_flat).sum()
        union = input_flat.sum() + target_flat.sum()

        # Calculate Dice coefficient
        dice = (2. * intersection + self.smooth) / (union + self.smooth)

        # Calculate Dice loss
        dice_loss = 1 - dice
        return dice_loss


In [14]:
# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.007, momentum=0.9, weight_decay=0.9997)

# Loss function
criterion = CrossEntropyLoss()

# Training loop
max_iter = 30000 #30k mentioned in paper
for epoch in range(1, max_iter + 1):
    model.train()
    optimizer = poly_lr_scheduler(optimizer, 0.007, epoch, max_iter)

    for images, labels in train_loader:
        print(images.shape,labels.shape)
        optimizer.zero_grad()
        outputs = model(images)
        print(outputs.shape)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Freeze batch normalization parameters
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()
        print(module.requires_grad)

torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])
torch.Size([16, 3, 513, 513]) torch.Size([16, 3, 513, 513])
torch.Size([16, 21, 272, 272])

In [None]:
class CustomDataset2(Dataset):
    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        self.image_folder = 'train/images'
        self.mask_folder = 'train/segmentation_masks'

        self.image_paths = sorted([os.path.join(self.root, self.image_folder, filename) for filename in os.listdir(os.path.join(self.root, self.image_folder))])
        self.mask_paths = sorted([os.path.join(self.root, self.mask_folder, filename) for filename in os.listdir(os.path.join(self.root, self.mask_folder))])

        print("Number of images:", len(self.image_paths))
        print("Number of masks:", len(self.mask_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.mask_paths[index])

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        #mask = mask.permute(2, 0, 1)

        return image, mask

In [None]:
# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.9997)

#Loss function
criterion = CrossEntropyLoss()

# Training loop
max_iter = 30000 #30k mentioned in paper
for epoch in range(1, max_iter + 1):
    model.train()
    optimizer = poly_lr_scheduler(optimizer, 0.007, epoch, max_iter)

    for images, labels in train_loader:
        print(images.shape,labels.shape)
        optimizer.zero_grad()
        outputs = model(images)
        print(outputs.shape)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Freeze batch normalization parameters
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()
        print(module.requires_grad)

In [None]:
'''# Validation loop example:
model.eval()
with torch.no_grad():
    for images, masks in val_loader:
        outputs = model(images)
        # Evaluate your model's performance on the validation set'''