In [None]:
# Mount Google Drive inside the Colab VM; the dataset paths used further
# down live under this mount point.
from google.colab import drive

mount = '/content/gdrive'
drive.mount(mount)

Mounted at /content/gdrive


In [None]:
import cv2
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset

def generate_heatmap(size, center, sigma=5):
    """Render a square 2-D Gaussian heatmap peaked at ``center``.

    Args:
        size: Side length of the output grid, in pixels.
        center: ``(x, y)`` peak location; ``x`` indexes columns and ``y``
            indexes rows. Fractional or out-of-grid values are accepted,
            the Gaussian is simply evaluated at every pixel position.
        sigma: Spread of the Gaussian in pixels. Must be non-zero (only
            ``sigma**2`` enters the formula).

    Returns:
        A ``(size, size)`` NumPy array holding
        ``exp(-((col - x)**2 + (row - y)**2) / (2 * sigma**2))``.
        The value is exactly 1.0 at a pixel that coincides with ``center``.

    Raises:
        ValueError: If ``sigma`` is zero, which would otherwise divide by
            zero and silently yield NaN at the peak.
    """
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    # Broadcasting a (1, size) row against a (size, 1) column produces the
    # same grid as np.meshgrid without materialising two full arrays.
    cols = np.arange(size)[np.newaxis, :]
    rows = np.arange(size)[:, np.newaxis]
    sq_dist = (cols - center[0]) ** 2 + (rows - center[1]) ** 2
    return np.exp(-sq_dist / (2 * sigma**2))


class FetalLandmarkDataset(Dataset):
    """Dataset pairing grayscale images with per-landmark Gaussian heatmaps.

    Every CSV row names an image file (``image_name``) and gives the pixel
    coordinates, in the original image, of four landmarks stored in the
    columns ``<landmark>_x`` / ``<landmark>_y`` for each name in
    ``LANDMARKS``.

    Args:
        img_dir: Directory containing the image files.
        csv_path: Path (or file-like object) of the ground-truth CSV.
        img_size: Side length images and heatmaps are resized to.
        sigma: Gaussian spread, in resized pixels, of each target heatmap.
    """

    # Landmark column prefixes, in heatmap channel order.
    LANDMARKS = ("ofd_1", "ofd_2", "bpd_1", "bpd_2")

    def __init__(self, img_dir, csv_path, img_size=256, sigma=5):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.img_size = img_size
        self.sigma = sigma

    def __len__(self):
        """Return the number of annotated rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(image, heatmaps)`` of float32 tensors with shapes
            ``(1, img_size, img_size)`` and
            ``(len(LANDMARKS), img_size, img_size)``. The image is divided
            by 255 after resizing.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        row = self.df.iloc[idx]

        img_path = f"{self.img_dir}/{row['image_name']}"
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None instead of raising, so fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Keep the original size: landmark coordinates are expressed in it.
        h, w = img.shape

        img = cv2.resize(img, (self.img_size, self.img_size))
        img = img / 255.0

        heatmaps = []
        for name in self.LANDMARKS:
            # Rescale into the resized image; int() truncates to a pixel.
            x = int(row[f"{name}_x"] * self.img_size / w)
            y = int(row[f"{name}_y"] * self.img_size / h)
            heatmaps.append(
                generate_heatmap(self.img_size, (x, y), self.sigma)
            )

        heatmaps = np.stack(heatmaps)

        return (
            torch.tensor(img).unsqueeze(0).float(),
            torch.tensor(heatmaps).float()
        )


In [None]:
# Sanity check: build the dataset and verify the tensor shapes of one sample.
# NOTE: `dataset` is reused by the inference cell further down, so this cell
# has to run even though the check itself is optional.
dataset = FetalLandmarkDataset(
    img_dir="/content/gdrive/MyDrive/images",
    csv_path="/content/gdrive/MyDrive/role_challenge_dataset_ground_truth.csv"
)

img, heatmaps = dataset[0]

print(img.shape)        # should be [1, 256, 256]
print(heatmaps.shape)   # should be [4, 256, 256]

# Enforce the expectations instead of relying on a reader comparing the
# printed shapes with the comments above.
assert tuple(img.shape) == (1, dataset.img_size, dataset.img_size), img.shape
assert tuple(heatmaps.shape) == (4, dataset.img_size, dataset.img_size), heatmaps.shape


In [None]:
import torch.nn as nn

class SimpleUNet(nn.Module):
    """Small convolutional encoder-decoder: 1 input channel, 4 output channels.

    The encoder applies two conv + ReLU + 2x2 max-pool stages (1 -> 32 -> 64
    channels), shrinking each spatial dimension by 4. The decoder applies two
    stride-2 transposed convolutions (64 -> 32 -> 16 channels) that scale it
    back up by 4, and a final 1x1 convolution maps to 4 channels. There are
    no skip connections between encoder and decoder. The output has the same
    height/width as the input when they are divisible by 4.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and layer positions define the state_dict keys
        # (e.g. ``encoder.3.weight``); they must stay as they are for
        # existing checkpoints to load.
        down_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        up_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(),
        ]

        self.encoder = nn.Sequential(*down_layers)
        self.decoder = nn.Sequential(*up_layers)
        self.out = nn.Conv2d(16, 4, kernel_size=1)

    def forward(self, x):
        """Map a ``(N, 1, H, W)`` batch to ``(N, 4, H', W')`` raw outputs."""
        features = self.encoder(x)
        upsampled = self.decoder(features)
        return self.out(upsampled)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleUNet().to(device)

# SECURITY: torch.load unpickles the file. Only a state dict (tensors in a
# plain mapping) is needed here, so weights_only=True restricts unpickling
# and a tampered checkpoint cannot run arbitrary code while loading.
model.load_state_dict(
    torch.load(
        "hypothesis_final_full_saved_model.pth",
        map_location=device,
        weights_only=True,
    )
)

# Switch to evaluation mode before running inference.
model.eval()


In [None]:
# Run one forward pass on the first dataset sample and compare tensor shapes.
img, gt_heatmaps = dataset[0]

# The model expects a leading batch dimension; add it and move the image to
# the same device as the model.
img = torch.unsqueeze(img, dim=0).to(device)

# Gradients are not needed for inference.
with torch.no_grad():
    pred_heatmaps = model(img)

# The prediction keeps the batch dimension added above; the ground truth
# comes straight from the dataset and has none.
print("GT heatmaps:", gt_heatmaps.shape)
print("Pred heatmaps:", pred_heatmaps.shape)
