In [33]:
import os 
import time
import random
from PIL import Image
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW, SGD
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms 

In [49]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.images = [image for image in os.listdir(image_dir) if image.endswith(('.png', '.jpg', '.jpeg'))]
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        name,ext=os.path.splitext(self.images[idx])
        mask_path = os.path.join(self.mask_dir, f'{name}.png')
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        image = self.transform(image)
        mask = self.transform(mask)
        mask = (mask > 0.5).float()
        return image, mask

In [50]:
def get_dataloader(image_dir, mask_dir, batch_size=2, shuffle=True):
    dataset=SegmentationDataset(image_dir, mask_dir)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)  

In [75]:
class DoubleConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv2d=nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.ReLU(inplace=True),       
        )
    def forward(self, x):
        return self.conv2d(x)    

In [None]:
class DownSample(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = DoubleConv(in_channel, out_channel)
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        down = self.conv(x)      
        pool = self.pool(down) 
        return down, pool


In [95]:
class UpSampple(nn.Module):
    def __init__(self,in_channel, out_channel):
        super().__init__()
        self.up_sam=nn.ConvTranspose2d(in_channel, in_channel//2, 2,2)
        self.conv=DoubleConv(in_channel, out_channel)
        
    def forward(self, x1, x2):
        x1=self.up_sam(x1)
        x=torch.cat([x1,x2],1)
        return self.conv(x)

In [104]:
class Unet(nn.Module):
    def __init__(self,in_channel, num_classes):
        super().__init__()
        self.down_con_1=DownSample(in_channel, 64)
        self.down_con_2=DownSample(64, 128)
        self.down_con_3=DownSample(128, 256)
        self.down_con_4=DownSample(256, 512)

        self.bottle_neck=DoubleConv(512,1024)
        
        self.up_con_1=UpSampple(1024, 512)
        self.up_con_2=UpSampple(512, 256)
        self.up_con_3=UpSampple(256, 128)
        self.up_con_4=UpSampple(128, 64)
        
        self.out=nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)
        
    def forward(self,x):
        down1, p1=self.down_con_1(x)
        down2, p2=self.down_con_2(p1)
        down3, p3=self.down_con_3(p2)
        down4, p4=self.down_con_4(p3)
        
        b=self.bottle_neck(p4)
        
        up1=self.up_con_1(b,down4)
        up2=self.up_con_2(up1,down3)
        up3=self.up_con_3(up2,down2)
        up4=self.up_con_4(up3,down1)
        
        output= self.out(up4)
        return output 
        
        

In [105]:
class DiceLoss(nn.Module):
  def __init__(self, smooth=1e-6):
    super(DiceLoss, self).__init__()
    self.smooth = smooth

  def forward(self, inputs, targets):
    inputs = inputs.view(-1) #Flatten
    targets = targets.view(-1)

    intersection = (inputs * targets).sum()
    dice_score = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

    return 1 - dice_score

In [106]:
class BCEWithDiceLoss(nn.Module):
  def __init__(self, smooth=1e-6):
    super(BCEWithDiceLoss, self).__init__()
    self.bce = nn.BCEWithLogitsLoss()
    self.dice = DiceLoss(smooth)

  def forward(self, inputs,targets):
    bce_loss = self.bce(inputs,targets)
    dice_loss = self.dice(inputs,targets)
    return 0.5*bce_loss + dice_loss

In [112]:
def train(model, dataloader, epochs=1, lr=0.0001, save_path="unet_model", load_path=None):
    if load_path is not None:
        print("Loading from the weight directly")
        model.load_state_dict(torch.load(load_path, map_location="cpu"))
    criterion=BCEWithDiceLoss()
    optimizer=AdamW(model.parameters() ,lr=lr)
    for e in range(epochs):
        model.train()
        epoch_loss=0
        
        for img, mask in dataloader:
            optimizer.zero_grad()
            
            output=model(img)
            
            loss=criterion(output, mask)
            loss.backward()
            optimizer.step()
            epoch_loss+=loss.item()
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{e+1}/{epochs}], Loss : {avg_loss:.4f}, LR : {lr}")
        
    torch.save(model.state_dict(),f"{save_path}.pth")

In [None]:
data=get_dataloader("U-Net/Human-Segmentation-Dataset/Training_Images", "U-Net/Human-Segmentation-Dataset/Ground_Truth")

In [114]:
model=Unet(in_channel=3,num_classes=1)

In [115]:
train(model,data)

Epoch [1/1], Loss : 1.1634, LR : 0.0001


In [None]:
import numpy as np

def predict(model_path, input_image_path):
    model = Unet(in_channel=3, num_classes=1)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    image = Image.open(input_image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(image_tensor)
        output = torch.sigmoid(output)
   
    mask = output.squeeze(0).squeeze(0).cpu().numpy()
    mask = (mask > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)

    combined = Image.new("RGB", (224 * 2, 224))
    combined.paste(image.resize((224, 224)), (0, 0))
    combined.paste(mask_image.convert("RGB"), (224, 0))
    combined.save("output.jpg")


In [None]:
predict(model_path="unet_model.pth", input_image_path="U-Net/Human-Segmentation-Dataset/Training_Images/24.jpg")

Using device: cpu

Prediction completed! Stats:
  Image Preprocessing Time: 0.0000 seconds
  Model Inference Time: 0.3965 seconds
  Postprocessing Time: 0.0013 seconds
  Total Prediction Time: 0.3978 seconds
Prediction saved as output.jpg
