In [23]:
import os
import time

from PIL import Image
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam, AdamW,SGD
import torch
import torch.nn.functional as F

In [4]:
class SegmentationDataset(Dataset):
    """Paired image / binary-mask dataset backed by two directories.

    Every entry of ``images_dir`` whose extension is ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) is one sample. The matching mask is looked up
    in ``mask_dir`` under the same base name with a ``.png`` extension.

    Args:
        images_dir: Directory holding the input images.
        mask_dir: Directory holding the ``<name>.png`` masks.
    """

    def __init__(self,images_dir,mask_dir):
        self.image_dir=images_dir
        self.mask_dir=mask_dir
        # The same pipeline is applied to images and masks:
        # resize to 512x512, then convert to a tensor.
        self.transform=transforms.Compose([
            transforms.Resize((512,512)),
            transforms.ToTensor()
        ]
        )
        self.valid_extension={".jpg",".jpeg",".png"}
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and machines.
        self.images=sorted(
            f for f in os.listdir(self.image_dir)
            if os.path.splitext(f)[1].lower() in self.valid_extension
        )

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in the sorted list of image file names.

        Returns:
            ``(image, mask)``: the transformed RGB image and the transformed
            single-channel mask, binarised to 0.0 / 1.0 with a 0.5 threshold.
        """
        file_name=self.images[index]
        image_path=os.path.join(self.image_dir,file_name)
        name,_=os.path.splitext(file_name)
        # Masks are always expected as PNG, whatever the image extension is.
        masked_path=os.path.join(self.mask_dir,f"{name}.png")
        # Context managers release the underlying file handles promptly.
        with Image.open(image_path) as raw_image:
            image=self.transform(raw_image.convert("RGB"))
        with Image.open(masked_path) as raw_mask:
            mask=self.transform(raw_mask.convert("L"))
        # Resizing can produce intermediate grey values; threshold back to binary.
        mask=(mask>0.5).float()
        return image,mask
    

    

In [5]:
def get_dataloader(image_dir,mask_dir,batch_size=2,shuffle=True):
    """Build a ``DataLoader`` over a ``SegmentationDataset``.

    Args:
        image_dir: Directory with the input images.
        mask_dir: Directory with the corresponding masks.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` argument.

    Returns:
        A ``DataLoader`` wrapping the dataset built from the two directories.
    """
    return DataLoader(
        SegmentationDataset(images_dir=image_dir, mask_dir=mask_dir),
        batch_size=batch_size,
        shuffle=shuffle,
    )


In [6]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Args:
        in_channels: Channel count of the input to the first convolution.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Build the layers in execution order; padding=1 with a 3x3 kernel
        # keeps the spatial size unchanged.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv_op = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> conv -> ReLU and return the result."""
        features = self.conv_op(x)
        return features



In [7]:
class DownSample(nn.Module):
    """Encoder stage: a ``DoubleConv`` followed by 2x2 max pooling.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the un-pooled ``DoubleConv`` output and ``pooled`` is
        the same tensor after max pooling.
        """
        features = self.conv(x)
        return features, self.pool(features)

In [8]:
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``DoubleConv``.

    Args:
        in_channels: Channel count of the tensor being upsampled. The
            transposed convolution reduces it to ``in_channels // 2``.
        out_channels: Channel count produced by the final ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, concatenate and convolve.

        Args:
            x1: Tensor to upsample.
            x2: Skip-connection tensor whose dims 2 and 3 set the target size.

        Returns:
            The ``DoubleConv`` output of ``cat([x2, upsampled x1], dim=1)``.
        """
        upsampled = self.up(x1)
        # Size gap along dims 2 and 3 between the skip tensor and the
        # upsampled tensor (non-zero when a dimension was odd before pooling).
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        pad_left, pad_top = gap_w // 2, gap_h // 2
        # F.pad takes the last dimension's padding first.
        upsampled = F.pad(
            upsampled,
            [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top],
        )
        return self.conv(torch.cat([x2, upsampled], dim=1))


In [9]:
class Unet(nn.Module):
    """U-Net built from four ``DownSample`` stages, a bottleneck and four ``UpSample`` stages.

    Args:
        in_channels: Channel count of the input tensor.
        num_classes: Channel count of the output produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Encoder: channel width doubles at every stage.
        self.downconv_1 = DownSample(in_channels, 64)
        self.downconv_2 = DownSample(64, 128)
        self.downconv_3 = DownSample(128, 256)
        self.downconv_4 = DownSample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: channel width halves at every stage.
        self.up_conv1 = UpSample(1024, 512)
        self.up_conv2 = UpSample(512, 256)
        self.up_conv3 = UpSample(256, 128)
        self.up_conv4 = UpSample(128, 64)

        # 1x1 convolution mapping the 64 decoder channels to the output channels.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder.

        Returns:
            The output of the final 1x1 convolution; no activation is applied here.
        """
        # Encoder pass: keep each un-pooled feature map as a skip connection.
        skip_1, x = self.downconv_1(x)
        skip_2, x = self.downconv_2(x)
        skip_3, x = self.downconv_3(x)
        skip_4, x = self.downconv_4(x)

        x = self.bottleneck(x)

        # Decoder pass: consume the skips from deepest to shallowest.
        x = self.up_conv1(x, skip_4)
        x = self.up_conv2(x, skip_3)
        x = self.up_conv3(x, skip_2)
        x = self.up_conv4(x, skip_1)

        return self.out(x)

In [25]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    The logits are passed through a sigmoid, both tensors are flattened, and
    the loss is ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
    computed over all elements at once.

    Args:
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both sums are zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar Dice loss.

        Args:
            inputs: Logits (the sigmoid is applied inside this method).
            targets: Target tensor with the same number of elements as ``inputs``.

        Returns:
            ``1 - dice_score``; smaller is better.
        """
        # Apply sigmoid to logits
        probs = torch.sigmoid(inputs)

        # reshape (rather than view) also works for non-contiguous tensors,
        # e.g. ones produced by transpose/permute, where view raises.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_score = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        # Return 1 - Dice coefficient -> smaller = better
        return 1 - dice_score


In [26]:
class BCEWithDiceLoss(nn.Module):
    """Combined loss: ``0.5 * BCEWithLogitsLoss + DiceLoss``.

    Args:
        smooth: Smoothing constant forwarded to ``DiceLoss``.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(smooth)

    def forward(self, inputs, targets):
        """Return half the BCE-with-logits loss plus the Dice loss.

        Both terms receive the same raw ``inputs`` and ``targets``.
        """
        weighted_bce = 0.5 * self.bce(inputs, targets)
        dice_term = self.dice(inputs, targets)
        return weighted_bce + dice_term


In [27]:
## Training loop

def train(model, dataloader, epochs=2, lr=0.001, save_path="unet_model", load_path=None):
    """Train ``model`` with SGD (momentum 0.9) on ``BCEWithDiceLoss`` and save checkpoints.

    Args:
        model: Module to train; moved to CUDA when available, else CPU.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate passed to the SGD optimizer.
        save_path: Path prefix for checkpoints. Files are written as
            ``{save_path}_{epoch}.pth`` and ``{save_path}_final.pth``.
        load_path: Optional state-dict file to load before training; used
            only when the path exists.

    Raises:
        ValueError: If ``dataloader`` yields no batches during an epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if load_path and os.path.exists(load_path):
        print(f"Loading model weights from {load_path}")
        # NOTE(review): torch.load deserialises the file; only load
        # checkpoints that come from a trusted source.
        model.load_state_dict(torch.load(load_path, map_location=device))
    else:
        print("No checkpoint found, training from scratch")

    print(f"Using device: {device}")
    model = model.to(device)

    criterion = BCEWithDiceLoss()  # combined BCE + Dice loss
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        # Count batches while iterating instead of relying on len(dataloader):
        # this also works for loaders without a length and lets an empty
        # loader be reported clearly instead of ending in a ZeroDivisionError.
        num_batches = 0

        for images, masks in dataloader:
            # Move data to the same device as the model
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        avg_loss = epoch_loss / num_batches
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, LR: {lr}")

        # Save every 10 epochs or at the end
        if (epoch + 1) % 10 == 0 or (epoch + 1) == epochs:
            torch.save(model.state_dict(), f"{save_path}_{epoch+1}.pth")
            print(f"Checkpoint saved at epoch {epoch+1}")

    # Write the final weights first so the message below is only printed
    # once the file actually exists.
    torch.save(model.state_dict(), f"{save_path}_final.pth")
    print(f"Training complete! Final model saved to {save_path}_final.pth")


In [13]:
# Training data loader: images and ground-truth masks, shuffled batches of 8.
dataloader = get_dataloader(
    image_dir="cnn/data/Human-Segmentation-Dataset/Training_Images",
    mask_dir="cnn/data/Human-Segmentation-Dataset/Ground_Truth",
    batch_size=8,
    shuffle=True,
)


In [28]:
# 3 input channels, 1 output channel.
model = Unet(in_channels=3, num_classes=1)

In [15]:
# Environment report: library versions and the GPU visible to this kernel.
import torch
import torchvision
import torchaudio

print("Torch:", torch.__version__, torch.version.cuda)
print("Torchvision:", torchvision.__version__)
print("Torchaudio:", torchaudio.__version__)
# get_device_name(0) raises when no CUDA device is present; guard it so this
# cell also runs on CPU-only machines, which train() supports as a fallback.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU:", "not available (running on CPU)")

Torch: 2.9.0+cu130 13.0
Torchvision: 0.24.0+cu130
Torchaudio: 2.9.0+cu130
GPU: NVIDIA GeForce RTX 5090


In [29]:
# Two-epoch training run with the learning rate set explicitly.
train(model, dataloader, epochs=2, lr=0.001)

No checkpoint found, training from scratch
Using device: cuda
Epoch [1/2], Loss: 0.9845, LR: 0.001
Epoch [2/2], Loss: 0.9870, LR: 0.001
Checkpoint saved at epoch 2
Training complete! Final model saved to unet_model_final.pth
