<a href="https://colab.research.google.com/github/ridash2005/Pancreati-Tumor-Segmentation/blob/main/U_NET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision matplotlib pydicom
import zipfile, pydicom, io, numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split




In [2]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import os, zipfile
# Copy to Colab local for speed
if not os.path.exists('/content/DCM.zip'):
    !cp "/content/drive/MyDrive/PIS/DCM.zip" "/content/DCM.zip"

zip_path = '/content/DCM.zip'
z = zipfile.ZipFile(zip_path)

# Find image slice paths
img_slice_paths = [f for f in z.namelist()
                   if f.startswith('DCM/Task1/ImagesTr/') and f.lower().endswith('.dcm')]

# Extract all folder IDs (patients/volumes)
folders = sorted(set(f.split('/')[3] for f in img_slice_paths))
print(f'Folders found: {folders[:10]}... Total: {len(folders)}')

# Pair up images and masks by basename within each folder
img_files, lbl_files = [], []
for folder in folders:
    imgs = [f for f in img_slice_paths if f.startswith(f'DCM/Task1/ImagesTr/{folder}/IMG')]
    lbls = [f for f in z.namelist() if f.startswith(f'DCM/Task1/LabelsTr/{folder}/IMG')]
    img_dict = {os.path.basename(f): f for f in imgs}
    lbl_dict = {os.path.basename(f): f for f in lbls}
    for fname in img_dict:
        if fname in lbl_dict:
            img_files.append(img_dict[fname])
            lbl_files.append(lbl_dict[fname])

print(f'Total paired slices: {len(img_files)}')
print('Sample image:', img_files[:3])
print('Sample mask :', lbl_files[:3])



Mounted at /content/drive
Folders found: ['10000', '10001', '10002', '10006', '10007', '10011', '10012', '10014', '10015', '10016']... Total: 92
Total paired slices: 6816
Sample image: ['DCM/Task1/ImagesTr/10000/IMG0001.dcm', 'DCM/Task1/ImagesTr/10000/IMG0002.dcm', 'DCM/Task1/ImagesTr/10000/IMG0003.dcm']
Sample mask : ['DCM/Task1/LabelsTr/10000/IMG0001.dcm', 'DCM/Task1/LabelsTr/10000/IMG0002.dcm', 'DCM/Task1/LabelsTr/10000/IMG0003.dcm']


In [3]:
from sklearn.model_selection import train_test_split

train_imgs, val_imgs, train_lbls, val_lbls = train_test_split(
    img_files, lbl_files, test_size=0.2, random_state=42
)

print(f'Train: {len(train_imgs)}, Val: {len(val_imgs)}')

Train: 5452, Val: 1364


In [4]:
import numpy as np, pydicom, io
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import zipfile

class ZipDicomSliceDataset(Dataset):
    def __init__(self, zip_path, img_files, lbl_files, target_size=(256,256)):
        self.zip_path = zip_path
        self.img_files = img_files
        self.lbl_files = lbl_files
        self.target_size = target_size
        self.to_tensor = transforms.ToTensor()
        self.resize = transforms.Resize(self.target_size, interpolation=Image.NEAREST)
        self.z = zipfile.ZipFile(zip_path)  # open once here

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        # Use the already opened zipfile here
        img_bytes = self.z.read(self.img_files[idx])
        mask_bytes = self.z.read(self.lbl_files[idx])

        img_dcm = pydicom.dcmread(io.BytesIO(img_bytes))
        mask_dcm = pydicom.dcmread(io.BytesIO(mask_bytes))

        image = img_dcm.pixel_array.astype(np.float32)
        image = (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-7)

        mask_arr = mask_dcm.pixel_array.astype(np.float32)
        mask_bin = ((mask_arr == 1.0) | (mask_arr == 2.0)).astype(np.float32)

        image = self.resize(Image.fromarray((image * 255).astype(np.uint8)))
        mask = self.resize(Image.fromarray((mask_bin * 255).astype(np.uint8)))

        image = self.to_tensor(image)
        mask = self.to_tensor(mask)
        mask = (mask > 0.5).float()

        # Diagnostic print for the first few samples
        if idx < 3:
            print(f"Sample {idx}: image shape {image.shape}, mask unique {mask.unique().tolist()}")

        return image, mask


In [5]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    ZipDicomSliceDataset(zip_path, train_imgs, train_lbls),
    batch_size=16, shuffle=True, num_workers=0
)
val_loader = DataLoader(
    ZipDicomSliceDataset(zip_path, val_imgs, val_lbls),
    batch_size=16, shuffle=False, num_workers=0
)

# Check first batch shape
imgs, masks = next(iter(train_loader))
print(f"Batch shape: {imgs.shape}, {masks.shape}")


Batch shape: torch.Size([16, 1, 256, 256]), torch.Size([16, 1, 256, 256])


In [6]:
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[64,128,256,512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, 2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, 1)
    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, 2)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_conn = skip_connections[idx//2]
            if x.shape != skip_conn.shape:
                x = F.interpolate(x, size=skip_conn.shape[2:])
            x = torch.cat((skip_conn, x), dim=1)
            x = self.ups[idx+1](x)
        return self.final_conv(x)


In [7]:
import torch, time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 1

train_losses, val_losses = [], []

for epoch in range(epochs):
    start_epoch = time.time()
    model.train()
    running_loss = 0.0
    for bi, (imgs, masks) in enumerate(train_loader):
        batch_start = time.time()
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
        if bi % 2 == 0:
            print(f'Epoch {epoch+1} - Batch {bi}: batch time {(time.time()-batch_start):.2f}s, batch loss={loss.item():.4f}')

    train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for vbi, (imgs, masks) in enumerate(val_loader):
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * imgs.size(0)
            if vbi==0:
                print(f"Validation batch shape: {imgs.shape}")
    val_loss = val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} Val Loss={val_loss:.4f}")
    print(f"Epoch {epoch+1} time: {(time.time()-start_epoch)/60:.2f} min")


Epoch 1 - Batch 0: batch time 109.81s, batch loss=0.7274
Epoch 1 - Batch 2: batch time 100.41s, batch loss=0.5785
Epoch 1 - Batch 4: batch time 100.99s, batch loss=0.4577
Epoch 1 - Batch 6: batch time 100.36s, batch loss=0.4206
Epoch 1 - Batch 8: batch time 100.27s, batch loss=0.3762
Epoch 1 - Batch 10: batch time 100.10s, batch loss=0.3382
Epoch 1 - Batch 12: batch time 99.46s, batch loss=0.3195
Epoch 1 - Batch 14: batch time 106.75s, batch loss=0.2947
Epoch 1 - Batch 16: batch time 98.97s, batch loss=0.2790
Epoch 1 - Batch 18: batch time 99.41s, batch loss=0.2556
Epoch 1 - Batch 20: batch time 106.38s, batch loss=0.2486
Sample 1: image shape torch.Size([1, 256, 256]), mask unique [0.0]
Epoch 1 - Batch 22: batch time 101.32s, batch loss=0.2325
Epoch 1 - Batch 24: batch time 100.23s, batch loss=0.2227
Epoch 1 - Batch 26: batch time 107.67s, batch loss=0.2203
Epoch 1 - Batch 28: batch time 101.20s, batch loss=0.2133
Epoch 1 - Batch 30: batch time 101.99s, batch loss=0.2067
Epoch 1 - Bat

KeyboardInterrupt: 