### Load DeepLabV3 Pretrained Model

In [None]:
import torch
import torchvision
from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights, deeplabv3_resnet50

# Seed torch's global RNG so the new head's initialisation and the DataLoader shuffling
# used later in the notebook are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Load pretrained model.
# torchvision has no `deeplabv3_resnet18` (importing it raised ImportError); ResNet-50 is
# the smallest ResNet backbone it ships for DeepLabV3. `weights=` replaces the deprecated
# `pretrained=True` flag.
model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)

# The pretrained weights come with an auxiliary head that only feeds an auxiliary loss.
# Nothing in this notebook reads model(...)['aux'], so drop it to skip that extra branch.
model.aux_classifier = None

# Freeze all layers (feature extraction)
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head (to allow training just on the final layer).
# In torchvision's DeepLabHead, index 4 is the final 1x1 conv and receives 256 channels.
model.classifier[4] = torch.nn.Conv2d(
    in_channels=256,
    out_channels=1,  # 1 class (foreground: cell vs. background)
    kernel_size=1
)

# Make the new layer trainable
for param in model.classifier[4].parameters():
    param.requires_grad = True

# The pretrained backbone expects ImageNet-normalised input, but the image tensors in this
# notebook are raw [0, 1] values (they are also thresholded into pseudo-labels and plotted
# as-is). Normalise inside the model with a forward pre-hook, so every `model(images)`
# call gets correctly scaled input while the caller's tensors stay untouched.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize_imagenet_input(module, args):
    """Forward pre-hook: swap the model's input batch for its ImageNet-normalised copy."""
    batch = args[0]
    mean = IMAGENET_MEAN.to(device=batch.device, dtype=batch.dtype)
    std = IMAGENET_STD.to(device=batch.device, dtype=batch.dtype)
    return ((batch - mean) / std,) + tuple(args[1:])


model.register_forward_pre_hook(normalize_imagenet_input)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### Load Stained Images and Prepare DataLoader


In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Define Dataset class
class StainedCellDataset(Dataset):
    """Unlabelled image dataset over every image file directly inside ``root_dir``.

    Args:
        root_dir: Folder containing the stained images.
        transform: Optional callable applied to each loaded RGB PIL image.

    Each item is the (optionally transformed) RGB image; no label is returned.
    """

    # Lower-case extensions accepted as images; anything else in the folder (Drive's
    # desktop.ini, a `.ipynb_checkpoints` directory, notes) is ignored.
    IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"})

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Sorted so item order (and any filenames derived from it) is deterministic;
        # os.listdir returns entries in arbitrary order.
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
            and os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # Context manager closes the underlying file handle once the image is converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

# Every stained image is resized to a fixed 256x256 and turned into a [0, 1] float tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# The stained images feed the segmentation head; batches are reshuffled each epoch.
stained_dir = '/content/drive/MyDrive/SegmentationProject/data/stained_images'
dataset = StainedCellDataset(root_dir=stained_dir, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=2)


### Train the Model on the Stained Images


In [None]:
import torch.optim as optim
import torch.nn as nn

# Define loss and optimizer
criterion = nn.BCEWithLogitsLoss()
# Only the replacement head is optimised; every other parameter was frozen above.
optimizer = optim.Adam(model.classifier[4].parameters(), lr=1e-3)

# Training loop
num_epochs = 25

model.train()
# Keep every BatchNorm layer in eval mode. requires_grad=False does not stop BatchNorm from
# overwriting its pretrained running statistics in train mode. Also, DeepLabV3's ASPP
# image-pooling branch (BatchNorm applied to a 1x1 map) raises "Expected more than 1 value
# per channel when training" whenever a batch holds a single image, e.g. the last batch of
# an odd-sized dataset with batch_size=2.
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()

for epoch in range(num_epochs):
    running_loss = 0.0
    for images in dataloader:
        images = images.to(device)

        # Pseudo-labels derived from the images themselves (there are no annotated masks):
        # a pixel is foreground when its mean RGB intensity exceeds 0.5.
        # NOTE(review): on a bright background this marks the background rather than the
        # stained cells as foreground; confirm against the staining used.
        labels = (images.mean(dim=1, keepdim=True) > 0.5).float()

        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")


### Predict and Visualize Segmentation Masks


In [None]:
import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():
    # Only the first (shuffled) batch is displayed; an empty loader shows nothing.
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        batch_on_device = first_batch.to(device)
        logits = model(batch_on_device)['out']
        binary_masks = (torch.sigmoid(logits) > 0.5).float()

        # One side-by-side figure per image: the input next to its thresholded mask.
        for image, mask in zip(batch_on_device, binary_masks):
            fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
            ax_image.imshow(image.cpu().permute(1, 2, 0))
            ax_image.set_title('Original Image')
            ax_mask.imshow(mask[0].cpu(), cmap='gray')
            ax_mask.set_title('Predicted Mask')
            plt.show()


### Save Predicted Segmentation Masks to Drive


In [None]:
import os
import torchvision.transforms.functional as TF

# Create directory to save masks if not exists
save_dir = '/content/drive/MyDrive/SegmentationProject/data/predicted_masks'
os.makedirs(save_dir, exist_ok=True)

# Remove masks written by earlier runs so the counting step sees only this run's output.
# Otherwise masks for renamed or deleted source images would linger and be counted.
for old_name in os.listdir(save_dir):
    if old_name.startswith('mask_') and old_name.endswith('.png'):
        os.remove(os.path.join(save_dir, old_name))

# Predict in dataset order (no shuffling) so each mask can be named after the stained image
# it came from. The training `dataloader` shuffles, so naming masks by batch position gave
# names with no relation to the source image.
inference_loader = DataLoader(dataset, batch_size=2, shuffle=False)
source_files = iter(dataset.image_files)

model.eval()
with torch.no_grad():
    for images in inference_loader:
        images = images.to(device)
        outputs = model(images)['out']
        preds = (torch.sigmoid(outputs) > 0.5).float()

        for mask in preds[:, 0].cpu():
            # mask_<source stem>.png, e.g. img_01.tif -> mask_img_01.png
            stem = os.path.splitext(next(source_files))[0]
            TF.to_pil_image(mask).save(os.path.join(save_dir, f"mask_{stem}.png"))


### Create Pseudo Ground Truth by Counting Cells in Predicted Masks


In [None]:
import cv2
import numpy as np

# Path to saved masks
mask_dir = '/content/drive/MyDrive/SegmentationProject/data/predicted_masks'

# List all mask files. Only PNGs are kept, so stray entries such as Drive's desktop.ini
# or a `.ipynb_checkpoints` folder are not treated as masks.
mask_files = sorted(f for f in os.listdir(mask_dir) if f.lower().endswith('.png'))

# (mask filename, cell count) pairs
cell_counts = []

for mask_file in mask_files:
    mask_path = os.path.join(mask_dir, mask_file)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None instead of raising on unreadable files.
        raise IOError(f"Could not read mask: {mask_path}")

    # Threshold to make sure mask is binary
    _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # Each connected foreground component counts as one cell. Touching or overlapping
    # cells merge into a single component and are counted once. The label image is
    # discarded (not bound to `labels`, which would shadow the training-loop variable).
    num_labels, _ = cv2.connectedComponents(binary_mask)

    # Subtract 1 because background is also counted
    cell_count = num_labels - 1
    cell_counts.append((mask_file, cell_count))

# Print results
for filename, count in cell_counts:
    print(f"{filename}: {count} cells")


# Prediction Part: Regress Cell Counts from the Original Images


### Part 1–2: Load Original Images and Match with GT Cell Counts


In [None]:
# Preview the original-images folder. A bare path in a code cell was a SyntaxError that
# aborted Run All; listing the folder also confirms Drive is mounted.
sorted(os.listdir('/content/drive/MyDrive/SegmentationProject/data/original_images'))[:5]


In [None]:
from PIL import Image
import torchvision.transforms as transforms

# Path to original images
original_dir = '/content/drive/MyDrive/SegmentationProject/data/original_images'
original_files = sorted(
    f for f in os.listdir(original_dir)
    if os.path.isfile(os.path.join(original_dir, f))
)


def _mask_stem(mask_name):
    """Return '<stem>' for 'mask_<stem>.png'; names without the prefix just lose their extension."""
    stem = os.path.splitext(mask_name)[0]
    return stem[len('mask_'):] if stem.startswith('mask_') else stem


# GT from previous step, keyed by image stem: {stem: count}.
# NOTE(review): assumes an original image and its stained counterpart share a filename
# stem, and that masks are named 'mask_<stem>.png'; confirm against the data folders.
gt_dict = {_mask_stem(name): count for name, count in cell_counts}

# Define transform (resize + tensor)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Build dataset as list of (image, target)
image_targets = []
unmatched = []
for filename in original_files:
    stem = os.path.splitext(filename)[0]
    if stem not in gt_dict:
        # Previously a missing count silently defaulted to 0, filling the regression
        # targets with zeros. Skip and report such images instead.
        unmatched.append(filename)
        continue

    path = os.path.join(original_dir, filename)
    with Image.open(path) as img:
        image = transform(img.convert("RGB"))
    image_targets.append((image, torch.tensor([gt_dict[stem]], dtype=torch.float)))

if unmatched:
    print(f"Warning: {len(unmatched)} original image(s) have no predicted-mask count and were skipped, e.g. {unmatched[:5]}")
if not image_targets:
    raise RuntimeError(
        "No original image matched a predicted-mask count; expected masks named "
        "'mask_<stem>.png' where <stem> is the original image's filename without extension."
    )


### Part 3: Build Dataset and DataLoader for Regression


In [None]:
from torch.utils.data import Dataset, DataLoader

# Define Dataset class for regression
class CellCountDataset(Dataset):
    """In-memory dataset over pre-built ``(image, target)`` pairs for count regression."""

    def __init__(self, data):
        # Sequence of (image_tensor, count_tensor) pairs, stored exactly as given.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Unpacking enforces that each stored sample really is a 2-item pair.
        sample = self.data[idx]
        image, target = sample
        return (image, target)

# Wrap the (image, count) pairs and serve them in shuffled batches of two.
regression_dataset = CellCountDataset(data=image_targets)
regression_dataloader = DataLoader(regression_dataset, shuffle=True, batch_size=2)


### Part 4: Build a Simple CNN Model for Cell Count Regression


In [None]:
import torch.nn as nn

# Define a simple CNN for regression
class SimpleCNN(nn.Module):
    """Small convolutional regressor mapping an RGB image batch to one scalar per image.

    Two conv(3x3)+ReLU+2x2-max-pool stages (3->16->32 channels) are followed by a
    32->64 conv+ReLU and adaptive average pooling to 1x1. A linear layer then maps the
    pooled 64-d vector to a single output, the predicted cell count.
    """

    def __init__(self):
        super().__init__()
        stages = [(3, 16), (16, 32), (32, 64)]
        layers = []
        for stage_idx, (c_in, c_out) in enumerate(stages):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
            # Downsample between stages; the last stage pools globally instead.
            if stage_idx == len(stages) - 1:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)
        self.regressor = nn.Linear(64, 1)

    def forward(self, x):
        pooled = self.features(x)  # (N, 64, 1, 1)
        return self.regressor(pooled.flatten(1))

# Build the count regressor and place it on the same device as the segmentation model.
regression_model = SimpleCNN()
regression_model = regression_model.to(device)


### Part 5: Define Loss Function and Optimizer for Regression


In [None]:
import torch.optim as optim

# Regress counts with mean-squared error, optimising every SimpleCNN parameter with Adam.
regression_criterion = nn.MSELoss()
regression_optimizer = optim.Adam(params=regression_model.parameters(), lr=1e-3)


### Part 6: Train the CNN Model to Predict Cell Counts


In [None]:
num_epochs = 30
regression_model.train()

for epoch_number in range(1, num_epochs + 1):
    epoch_loss_total = 0.0
    for batch_images, batch_targets in regression_dataloader:
        # Forward pass on the device holding the model
        predictions = regression_model(batch_images.to(device))
        batch_loss = regression_criterion(predictions, batch_targets.to(device))

        # Gradient step
        regression_optimizer.zero_grad()
        batch_loss.backward()
        regression_optimizer.step()

        epoch_loss_total += batch_loss.item()

    # Mean of the per-batch losses for this epoch
    mean_batch_loss = epoch_loss_total / len(regression_dataloader)
    print(f"Epoch [{epoch_number}/{num_epochs}], Loss: {mean_batch_loss:.4f}")


### Part 7: Evaluate Model and Plot Predictions vs Ground Truth



In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

regression_model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for images, targets in regression_dataloader:
        images = images.to(device)
        outputs = regression_model(images)
        # flatten() keeps one value per sample even for a final batch of size 1. There,
        # squeeze() produced a 0-d tensor whose .tolist() is a bare float, and
        # list.extend() raised TypeError.
        all_preds.extend(outputs.cpu().flatten().tolist())
        all_targets.extend(targets.flatten().tolist())

# Compute R² score
# NOTE(review): regression_dataloader is the training data, so this R² measures fit rather
# than generalisation. Hold out a split for an honest estimate.
r2 = r2_score(all_targets, all_preds)
print(f"R² Score: {r2:.4f}")

# Plot predictions vs ground truth, with the y = x line for reference
lo, hi = min(all_targets), max(all_targets)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(all_targets, all_preds, color='blue')
ax.plot([lo, hi], [lo, hi], 'r--')
ax.set_xlabel("Ground Truth")
ax.set_ylabel("Predicted")
ax.set_title("Predicted vs Ground Truth Cell Counts")
ax.grid(True)
plt.show()
