In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pycocotools.coco import COCO
from PIL import Image
import matplotlib.pyplot as plt

# ---------------------------
# 1. Dataset Class
# ---------------------------
class COCOSegmentationDataset(Dataset):
    """COCO images paired with a single binary foreground mask.

    For every image the masks of all of its annotations are merged with a
    pixel-wise maximum, so the target is 1 wherever any annotated object
    is present and 0 elsewhere.

    Args:
        root: Directory that contains the image files.
        annFile: Path to the COCO annotation JSON file.
        transforms: Optional callable applied to the image and, separately,
            to the mask. Because it is called twice, a transform with
            random behaviour would not be synchronised between the two.
    """

    def __init__(self, root, annFile, transforms=None):
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transforms = transforms

    @staticmethod
    def _binarize_mask(mask):
        """Return a tensor mask with every non-zero value mapped to 1.

        The mask is built as uint8 with values 0/1. Transforms that rescale
        uint8 pixels to [0, 1] (for example torchvision's ``ToTensor``) turn
        the foreground value into 1/255, which is useless as a BCE target.
        Thresholding at zero restores a proper {0, 1} mask while keeping the
        tensor's dtype. Non-tensor masks (e.g. PIL images) pass through
        unchanged.
        """
        if isinstance(mask, torch.Tensor):
            return (mask > 0).to(mask.dtype)
        return mask

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the image at position ``index``."""
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        path = img_info['file_name']

        # Load image; the context manager closes the underlying file as soon
        # as the pixel data has been converted to an RGB copy.
        with Image.open(os.path.join(self.root, path)) as raw:
            img = raw.convert("RGB")

        # Load mask: union of every annotation mask for this image.
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.uint8)
        for ann in anns:
            mask = np.maximum(mask, self.coco.annToMask(ann))

        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img = self.transforms(img)
            mask = self._binarize_mask(self.transforms(mask))

        return img, mask

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

# ---------------------------
# 2. Define U-Net
# ---------------------------
class UNet(nn.Module):
    """U-Net with four down-sampling stages and four up-sampling stages.

    Each stage is two Conv(3x3, padding=1) -> BatchNorm -> ReLU units. The
    encoder halves the spatial size with 2x2 max pooling; the decoder
    doubles it with stride-2 transposed convolutions and concatenates the
    matching encoder feature map (skip connection) along the channel axis.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the returned logit map.

    The forward pass returns raw logits (no activation is applied after
    the final 1x1 convolution) with the same height and width as the input.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        def CBR(in_ch, out_ch):
            # Conv -> BatchNorm -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = nn.Sequential(CBR(in_channels, 64), CBR(64, 64))
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(CBR(64, 128), CBR(128, 128))
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = nn.Sequential(CBR(128, 256), CBR(256, 256))
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = nn.Sequential(CBR(256, 512), CBR(512, 512))
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.center = nn.Sequential(CBR(512, 1024), CBR(1024, 1024))

        # Decoder: each dec* block receives the upsampled map concatenated
        # with the skip connection, hence twice the up* output channels.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = nn.Sequential(CBR(1024, 512), CBR(512, 512))

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = nn.Sequential(CBR(512, 256), CBR(256, 256))

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = nn.Sequential(CBR(256, 128), CBR(128, 128))

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = nn.Sequential(CBR(128, 64), CBR(64, 64))

        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` along the channel axis.

        Max pooling floors odd sizes, so after the stride-2 transposed
        convolution the upsampled map can be one pixel smaller than the
        skip connection. In that case it is zero-padded to the skip's
        height/width first; when the sizes already agree nothing is padded.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return the logit map for the image batch ``x`` ([B, C, H, W])."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        center = self.center(self.pool4(e4))

        d4 = self.dec4(self._merge(self.up4(center), e4))
        d3 = self.dec3(self._merge(self.up3(d4), e3))
        d2 = self.dec2(self._merge(self.up2(d3), e2))
        d1 = self.dec1(self._merge(self.up1(d2), e1))

        return self.final(d1)

# ---------------------------
# 3. Training Setup
# ---------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Every sample is resized to 256x256 and converted to a tensor. The dataset
# applies this same pipeline to the image and to its mask.
_preprocessing_steps = [
    transforms.Resize((256,256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

# NOTE: the val2017 split (images + instance annotations) is what gets
# trained on here.
dataDir = r"C:/Users/rahul/Downloads/coco"
_image_dir = f"{dataDir}/val2017"
_annotation_json = f"{dataDir}/annotations/instances_val2017.json"
train_dataset = COCOSegmentationDataset(
    _image_dir,
    _annotation_json,
    transforms=transform,
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=4)

# UNet.forward returns raw logits, so the loss that takes logits is used.
model = UNet()
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ---------------------------
# 4. Training Loop
# ---------------------------
epochs = 5  # adjust as needed
for epoch_number in range(1, epochs + 1):
    model.train()
    running_loss = 0
    for batch in train_loader:
        images = batch[0].to(device)
        targets = batch[1].to(device)

        # The loss needs targets shaped like the model output, [B,1,H,W].
        if targets.ndim == 3:
            # [B,H,W] -> insert the missing channel axis.
            targets = targets[:, None]
        elif targets.ndim == 4 and targets.shape[1] != 1:
            # Multi-channel mask -> keep only the first channel.
            targets = targets[:, :1]

        logits = model(images)          # [B,1,H,W]
        batch_loss = criterion(logits, targets)

        # Standard step: clear old gradients, back-propagate, update weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Report the mean per-batch loss of this epoch.
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch_number}/{epochs}, Loss: {mean_loss:.4f}")


# Persist only the learned parameters (state dict), not the whole module.
trained_weights = model.state_dict()
torch.save(trained_weights, "unet_coco_val2017.pth")

# ---------------------------
# 5. Inference (Test on 1 Image)
# ---------------------------
# Switch to evaluation mode before predicting.
model.eval()

# The first dataset sample doubles as the demo input.
test_img, test_mask = train_dataset[0]

with torch.no_grad():
    # Add a batch axis, run the network, and turn logits into probabilities.
    logits = model(test_img[None].to(device))
    probabilities = torch.sigmoid(logits)[0, 0].cpu().numpy()

# Threshold at 0.5 and scale to 0/255 so the mask renders as black/white.
pred_mask = (probabilities > 0.5).astype(np.uint8) * 255

# Show result: input image, reference mask and prediction side by side.
panels = [
    ("Original", test_img.permute(1, 2, 0), {}),
    ("Ground Truth Mask", test_mask.squeeze(), {"cmap": "gray"}),
    ("Predicted Mask", pred_mask, {"cmap": "gray"}),
]
plt.figure(figsize=(12, 4))
for slot, (title, picture, imshow_options) in enumerate(panels, start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(picture, **imshow_options)
    plt.title(title)
    plt.axis("off")
plt.show()


OSError: [WinError 1455] The paging file is too small for this operation to complete. Error loading "C:\Users\rahul\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\lib\shm.dll" or one of its dependencies.