In [None]:
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
from sklearn.model_selection import train_test_split
from tqdm import tqdm


In [None]:
# ============================
# 1️⃣ GPU / CPU Detection
# ============================
# Every tensor and model below is moved to `device`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}" if use_cuda
      else "⚠️ No GPU detected, training will be on CPU.")


In [None]:
# ============================
# 2️⃣ Live training plot
# ============================
# Per-epoch mean losses, appended by the training loops and drawn by live_plot().
train_loss_hist = []
val_loss_hist = []

def live_plot():
    """Replace the cell output with the current train/validation loss curves.

    Reads the module-level ``train_loss_hist`` and ``val_loss_hist`` lists;
    x position 0 corresponds to the first recorded epoch.
    """
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_loss_hist, label="Train Loss")
    ax.plot(val_loss_hist, label="Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Live Training Progress")
    ax.legend()
    ax.grid(True)
    plt.show()


In [None]:
# ============================
# 3️⃣ Helper Functions
# ============================
def project_01(im: np.ndarray) -> np.ndarray:
    """Min-max rescale ``im`` to [0, 1] after dropping singleton axes.

    Args:
        im: array of any shape; singleton dimensions are squeezed away.

    Returns:
        The rescaled array. A constant image (zero dynamic range) maps to all
        zeros rather than to NaNs from a 0/0 division.
    """
    im = np.squeeze(im)
    min_val = im.min()
    value_range = im.max() - min_val
    if value_range == 0:
        # Flat input: nothing to stretch; avoid 0/0 -> NaN.
        return np.zeros(im.shape, dtype=np.result_type(im.dtype, np.float32))
    return (im - min_val) / value_range

def normalize_im(im: np.ndarray, dmean: float, dstd: float) -> np.ndarray:
    """Standardise ``im`` with a given mean and std after dropping singleton axes.

    Args:
        im: input array; singleton dimensions are squeezed away first.
        dmean: value subtracted from every element.
        dstd: divisor applied after centering.
    """
    centered = np.squeeze(im) - dmean
    return centered / dstd

def matlab_style_gauss2D(shape=(7, 7), sigma=1) -> np.ndarray:
    """Build a 2-D Gaussian kernel in the style of MATLAB's ``fspecial('gaussian')``.

    Args:
        shape: (rows, cols) of the kernel; the grid is centred on the middle.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        float32 array of ``shape``. Entries below ``eps * max`` are zeroed, the
        kernel is normalised to sum to 1, and then scaled by 2.0, so the final
        sum is 2.0.
    """
    n_rows, n_cols = shape
    half_r = (n_rows - 1.) / 2.
    half_c = (n_cols - 1.) / 2.
    ys, xs = np.ogrid[-half_r:half_r + 1, -half_c:half_c + 1]
    kernel = np.exp(-(xs * xs + ys * ys) / (2. * sigma * sigma))

    # Drop negligible tails relative to the peak (as MATLAB's fspecial does).
    kernel[kernel < np.finfo(kernel.dtype).eps * kernel.max()] = 0

    total = kernel.sum()
    if total != 0:
        kernel /= total

    # NOTE(review): the x2 factor departs from fspecial (which sums to 1); it
    # looks deliberate and sets the heatmap amplitude — confirm before changing.
    kernel *= 2.0
    return kernel.astype(np.float32)

# Gaussian PSF filter: a 7x7, sigma=1 kernel used to blur predicted spike maps.
# Indexing with [None, None] gives the (out_ch=1, in_ch=1, 7, 7) conv2d weight layout.
psf_heatmap = matlab_style_gauss2D(shape=(7, 7), sigma=1)
gfilter = torch.tensor(psf_heatmap, dtype=torch.float32)[None, None, :, :].to(device)


In [None]:

# ============================
# 4️⃣ Custom Loss (L1 + L2)
# ============================
class L1L2Loss(nn.Module):
    """Blurred-heatmap MSE plus an L1 sparsity penalty on the predicted spikes.

    The raw prediction is blurred with ``gfilter`` and compared (MSE) to the
    target heatmap; the mean absolute value of the raw prediction is added
    to push the network toward sparse spike maps.

    Args:
        gfilter: conv2d weight of shape (1, 1, kh, kw) with odd kh and kw
            (e.g. the 7x7 Gaussian PSF built above).
    """

    def __init__(self, gfilter):
        super().__init__()
        # Buffer rather than Parameter: moves with .to(device) but is never trained.
        self.register_buffer("gfilter", gfilter)
        # "Same" padding derived from the kernel size (7x7 -> 3), so any odd
        # PSF size keeps the blurred map aligned with the target heatmap.
        self.padding = (gfilter.shape[-2] // 2, gfilter.shape[-1] // 2)

    def forward(self, heatmap_true, spikes_pred):
        """Return MSE(heatmap_true, blur(spikes_pred)) + mean(|spikes_pred|).

        Args:
            heatmap_true: target heatmap, same shape as the blurred prediction.
            spikes_pred: raw network output (N, 1, H, W).
        """
        heatmap_pred = F.conv2d(spikes_pred, self.gfilter, padding=self.padding)
        loss_heatmaps = F.mse_loss(heatmap_true, heatmap_pred)
        loss_spikes = torch.mean(torch.abs(spikes_pred))
        return loss_heatmaps + loss_spikes


In [None]:
class FocalLoss(nn.Module):
    """Sigmoid focal loss on logits (Lin et al., 2017), supporting soft targets.

    For logits ``x`` and targets ``t``, with ``p = sigmoid(x)``:

        p_t     = p * t + (1 - p) * (1 - t)
        alpha_t = alpha * t + (1 - alpha) * (1 - t)
        loss    = mean(alpha_t * (1 - p_t) ** gamma * BCEWithLogits(x, t))

    For binary targets this is the standard formulation. For continuous
    targets it interpolates smoothly, instead of treating every value below 1
    as a negative.

    Args:
        alpha: weight for the positive term; negatives get ``1 - alpha``.
            Pass a negative value to disable the class weighting.
        gamma: focusing exponent; ``gamma=0`` gives (alpha-weighted) BCE.
    """

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: raw logits (pre-sigmoid), e.g. the model output.
            targets: same shape as ``inputs``. Values are expected in [0, 1]
                — NOTE(review): confirm the heatmaps are scaled to this range.
        """
        bce_loss = self.bce(inputs, targets)
        probs = torch.sigmoid(inputs)
        # Probability assigned to the target "class"; exact for 0/1 targets.
        pt = probs * targets + (1 - probs) * (1 - targets)
        focal = (1 - pt) ** self.gamma * bce_loss
        if self.alpha >= 0:
            # Per-element class balance; a scalar alpha would only rescale the loss.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal = alpha_t * focal
        return focal.mean()



In [None]:
# ============================
# 5️⃣ CNN Architecture
# ============================
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        k: (odd) square kernel size; padding ``k // 2`` keeps H and W unchanged.
    """

    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        # "Same" padding for any odd k (k=3 -> padding 1). The conv bias is
        # omitted because the following BatchNorm has its own learnable shift.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=k // 2, bias=False)
        nn.init.orthogonal_(self.conv.weight)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [None]:
class CNNModel(nn.Module):
    """Encoder-decoder CNN mapping a 1-channel patch to a 1-channel map.

    Three conv+maxpool stages downsample by 8, then three nearest-neighbour
    upsample+conv stages restore the resolution. There are no skip connections.
    H and W should be divisible by 8 for the output to match the input size.
    The 1x1 prediction head has no bias and is orthogonally initialised.
    """

    def __init__(self):
        super().__init__()
        # Encoder (module creation order kept: it fixes RNG use and state_dict keys)
        self.enc1 = ConvBNReLU(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBNReLU(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBNReLU(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBNReLU(128, 512)  # bottleneck at 1/8 resolution

        # Decoder
        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec1 = ConvBNReLU(512, 128)
        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec2 = ConvBNReLU(128, 64)
        self.up3 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec3 = ConvBNReLU(64, 32)

        # Prediction head: bias-free 1x1 conv, orthogonal init
        self.pred = nn.Conv2d(32, 1, kernel_size=1, bias=False)
        nn.init.orthogonal_(self.pred.weight)

    def forward(self, x):
        for block, pool in ((self.enc1, self.pool1),
                            (self.enc2, self.pool2),
                            (self.enc3, self.pool3)):
            x = pool(block(x))
        x = self.enc4(x)

        for up, block in ((self.up1, self.dec1),
                          (self.up2, self.dec2),
                          (self.up3, self.dec3)):
            x = block(up(x))
        return self.pred(x)

In [None]:

# ============================
# 6️⃣ Load & preprocess data
# ============================
DATA_PATH = "from scratch_training_data.pkl"

# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open(DATA_PATH, "rb") as f:
    data = pickle.load(f)

patches = data["patches"]
heatmaps = data["heatmaps"]
assert len(patches) == len(heatmaps), "patches and heatmaps must be paired 1:1"

print("✅ Loaded data shapes:")
print("patches:", patches.shape)
print("heatmaps:", heatmaps.shape)

# Train/val split on indices first, so normalization statistics come from the
# training portion only (no validation leakage). train_test_split derives the
# partition from n_samples and random_state alone, so this is the same split
# as passing the arrays directly.
train_idx, val_idx = train_test_split(
    np.arange(len(patches)), test_size=0.1, random_state=42
)

# Normalization (statistics from the training split, applied to every patch)
mean_val = np.mean(patches[train_idx])
std_val = np.std(patches[train_idx])
if std_val == 0:
    raise ValueError("Training patches have zero variance; cannot standardise.")
patches = (patches - mean_val) / std_val

# Add channel axis: (N, H, W) -> (N, H, W, 1)
patches = patches[..., np.newaxis]
heatmaps = heatmaps[..., np.newaxis]

# Convert to torch tensors in (N, 1, H, W) layout on `device`
X_train = torch.tensor(patches[train_idx], dtype=torch.float32).permute(0, 3, 1, 2).to(device)
X_val = torch.tensor(patches[val_idx], dtype=torch.float32).permute(0, 3, 1, 2).to(device)
y_train = torch.tensor(heatmaps[train_idx], dtype=torch.float32).permute(0, 3, 1, 2).to(device)
y_val = torch.tensor(heatmaps[val_idx], dtype=torch.float32).permute(0, 3, 1, 2).to(device)

print(f"✅ Train: {X_train.shape}, Val: {X_val.shape}")


In [None]:
# ============================
# Train to overfit on mini-batch
# ============================
# A tiny subset every model should be able to memorise; a model/loss pair that
# cannot overfit it is unlikely to fit the full dataset. `patches` has already
# been standardised and given its channel axis by the loading cell.
# NOTE(review): indices are drawn from the whole pool, so the micro set may
# contain samples of the full validation split (X_val) used for the plots below.
micro_size = 100
micro_rng = np.random.default_rng(42)  # seeded so the micro experiments are reproducible
indices = micro_rng.choice(len(patches), size=min(micro_size, len(patches)), replace=False)

patches_micro = patches[indices]
heatmaps_micro = heatmaps[indices]

print("🎯 Microsample set created:")
print("patches_micro:", patches_micro.shape)
print("heatmaps_micro:", heatmaps_micro.shape)

# Train/val split (80/20)
X_train_micro, X_val_micro, y_train_micro, y_val_micro = train_test_split(
    patches_micro, heatmaps_micro, test_size=0.2, random_state=42
)

# Convert to torch tensors: (N, H, W, 1) -> (N, 1, H, W), float32, on `device`
X_train_micro = torch.tensor(X_train_micro, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
X_val_micro   = torch.tensor(X_val_micro, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
y_train_micro = torch.tensor(y_train_micro, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
y_val_micro   = torch.tensor(y_val_micro, dtype=torch.float32).permute(0, 3, 1, 2).to(device)

print(f"✅ Micro Train: {X_train_micro.shape}, Micro Val: {X_val_micro.shape}")


In [None]:
# ============================
# Sanity check: Visualize patches and heatmaps
# ============================
# Show random training samples (up to 10) as patch | GT heatmap | overlay.
n_show = min(10, len(X_train))
idxs = np.random.choice(len(X_train), size=n_show, replace=False)

def _to_unit(img):
    """Min-max rescale to [0, 1] for display; the epsilon avoids 0/0 on flat images."""
    return (img - img.min()) / (img.max() - img.min() + 1e-8)

fig, axes = plt.subplots(n_show, 3, figsize=(12, 3 * n_show), squeeze=False)

for row, idx in zip(axes, idxs):
    patch_disp = _to_unit(X_train[idx].cpu().numpy().squeeze())
    heatmap_disp = _to_unit(y_train[idx].cpu().numpy().squeeze())

    # Overlay: ground truth in the red channel over a dimmed grayscale patch.
    overlay = np.stack([heatmap_disp, patch_disp * 0.7, patch_disp * 0.7], axis=-1)

    panels = [(patch_disp, "gray", "Patch"),
              (heatmap_disp, "hot", "GT Heatmap"),
              (overlay, None, "Overlay")]
    for ax, (img, cmap, title) in zip(row, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()
plt.show()


Sanity check passed: each ground-truth heatmap lines up with its corresponding input patch.

Original Structure, L1+L2 Loss

In [None]:
# sample_idx = random.randint(0, X_val.shape[0]-1)

In [None]:
# Fixed validation index shared by every "Prediction example" cell, so all models are compared on the same sample.
sample_idx = 20

In [None]:
# ============================
# 7️⃣ Training setup
# ============================
# Original architecture + L1/L2 loss (blurred-heatmap MSE + L1 spike sparsity).
model_L1L2_old = CNNModel().to(device)
optimizer = optim.Adam(params=model_L1L2_old.parameters(), lr=1e-3)
criterion = L1L2Loss(gfilter)

In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split; a healthy model/loss pair should drive
# the training loss toward its minimum here.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_L1L2_old.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_L1L2_old(xb)
        loss = criterion(yb, outputs)  # L1L2Loss signature: (heatmap_true, spikes_pred)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_L1L2_old.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_L1L2_old(xb)
            loss = criterion(yb, outputs)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# L1L2Loss scores the *blurred* output against the heatmap, so the raw output
# is a spike map. Show it and its PSF-blurred version next to the truth.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_L1L2_old.eval()
with torch.no_grad():
    spikes = model_L1L2_old(X_val[sample_idx].unsqueeze(0))
    blurred = F.conv2d(spikes, gfilter, padding=gfilter.shape[-1] // 2)
predicted = spikes.cpu().squeeze().numpy()
predicted_blurred = blurred.cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Spikes"),
    (predicted_blurred, "hot", "Predicted Spikes ⊛ PSF"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


terrible

Original Structure, MSE Loss

In [None]:
# ============================
# 7️⃣ Training setup
# ============================
# Original architecture with plain pixel-wise MSE against the target heatmap.
model_MSE_old = CNNModel().to(device)
optimizer = optim.Adam(params=model_MSE_old.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_MSE_old.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_MSE_old(xb)
        loss = criterion(outputs, yb)  # (prediction, target), the PyTorch convention
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_MSE_old.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_MSE_old(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# Input, ground truth and prediction side by side for one fixed validation sample.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_MSE_old.eval()
with torch.no_grad():
    predicted = model_MSE_old(X_val[sample_idx].unsqueeze(0)).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Heatmap"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


terrible

Original Structure, BCE Loss

In [None]:
# ============================
# 7️⃣ Training setup
# ============================
# Original architecture with BCE on logits. pos_weight=20 multiplies the
# positive-target term, presumably to counter sparse foreground pixels.
model_BCE_old = CNNModel().to(device)
pos_weight = torch.tensor([20.0], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(params=model_BCE_old.parameters(), lr=1e-3)

In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_BCE_old.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_BCE_old(xb)
        loss = criterion(outputs, yb)  # BCEWithLogitsLoss(logits, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_BCE_old.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_BCE_old(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# The model outputs logits (it was trained with BCEWithLogitsLoss), so apply
# sigmoid to display the probability map the loss actually optimises.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_BCE_old.eval()
with torch.no_grad():
    logits = model_BCE_old(X_val[sample_idx].unsqueeze(0))
predicted = torch.sigmoid(logits).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Probability (sigmoid)"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


Not the worst thing I've ever seen — still a mediocre fit, but an improvement over the L1+L2 and MSE runs.

Original Structure, Focal Loss

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Original architecture with the sigmoid focal loss (expects logits first).
model_focal_old = CNNModel().to(device)
optimizer = optim.Adam(params=model_focal_old.parameters(), lr=1e-3)
criterion = FocalLoss(alpha=0.75, gamma=2.0)

In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_focal_old.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_focal_old(xb)
        # FocalLoss.forward(inputs=logits, targets): logits must come first.
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_focal_old.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_focal_old(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# The model outputs logits (FocalLoss applies sigmoid internally), so apply
# sigmoid to display the probability map the loss actually optimises.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_focal_old.eval()
with torch.no_grad():
    logits = model_focal_old(X_val[sample_idx].unsqueeze(0))
predicted = torch.sigmoid(logits).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Probability (sigmoid)"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


With the original architecture, only the focal loss shows promise, and even it plateaus: the loss stops improving beyond a point, even when deliberately overfitting the micro set.

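Updated architecture: give the 1×1 output convolution a bias term and drop its orthogonal initialization (the original output layer already had no BatchNorm; everything else is unchanged).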

In [None]:
# ============================
# 5️⃣ CNN Architecture (updated head)
# ============================
class CNNModel_new(nn.Module):
    """Same encoder-decoder as CNNModel, but with a biased, default-initialised head.

    Three conv+maxpool stages downsample by 8, then three nearest-neighbour
    upsample+conv stages restore the resolution (no skip connections). H and W
    should be divisible by 8 for the output to match the input size. The
    output is unbounded (logits when used with BCEWithLogits / FocalLoss).
    """

    def __init__(self):
        super().__init__()
        # Encoder (module creation order kept: it fixes RNG use and state_dict keys)
        self.enc1 = ConvBNReLU(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBNReLU(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBNReLU(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBNReLU(128, 512)  # bottleneck at 1/8 resolution

        # Decoder
        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec1 = ConvBNReLU(512, 128)
        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec2 = ConvBNReLU(128, 64)
        self.up3 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec3 = ConvBNReLU(64, 32)

        # ✅ Prediction head with bias, no BatchNorm, PyTorch default init
        self.pred = nn.Conv2d(32, 1, kernel_size=1, bias=True)

    def forward(self, x):
        for block, pool in ((self.enc1, self.pool1),
                            (self.enc2, self.pool2),
                            (self.enc3, self.pool3)):
            x = pool(block(x))
        x = self.enc4(x)

        for up, block in ((self.up1, self.dec1),
                          (self.up2, self.dec2),
                          (self.up3, self.dec3)):
            x = block(up(x))

        return self.pred(x)   # logits (use with BCEWithLogits or plain regression)


New Structure, L1+L2 Loss

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Updated architecture + L1/L2 loss (blurred-heatmap MSE + L1 spike sparsity).
model_L1L2 = CNNModel_new().to(device)
optimizer = optim.Adam(params=model_L1L2.parameters(), lr=1e-3)
criterion = L1L2Loss(gfilter)


In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_L1L2.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_L1L2(xb)
        loss = criterion(yb, outputs)  # L1L2Loss signature: (heatmap_true, spikes_pred)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_L1L2.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_L1L2(xb)
            loss = criterion(yb, outputs)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# L1L2Loss scores the *blurred* output against the heatmap, so the raw output
# is a spike map. Show it and its PSF-blurred version next to the truth.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_L1L2.eval()
with torch.no_grad():
    spikes = model_L1L2(X_val[sample_idx].unsqueeze(0))
    blurred = F.conv2d(spikes, gfilter, padding=gfilter.shape[-1] // 2)
predicted = spikes.cpu().squeeze().numpy()
predicted_blurred = blurred.cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Spikes"),
    (predicted_blurred, "hot", "Predicted Spikes ⊛ PSF"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


still terrible

New Structure, MSE Loss

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Updated architecture with plain pixel-wise MSE against the target heatmap.
model_MSE = CNNModel_new().to(device)
optimizer = optim.Adam(params=model_MSE.parameters(), lr=1e-3)
criterion = nn.MSELoss()


In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_MSE.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_MSE(xb)
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_MSE.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_MSE(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# Input, ground truth and prediction side by side for one fixed validation sample.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_MSE.eval()
with torch.no_grad():
    predicted = model_MSE(X_val[sample_idx].unsqueeze(0)).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Heatmap"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


still terrible

New Structure, BCE Loss

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Updated architecture with BCE on logits. pos_weight=20 multiplies the
# positive-target term, presumably to counter sparse foreground pixels.
model_BCE = CNNModel_new().to(device)
pos_weight = torch.tensor([20.0], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(params=model_BCE.parameters(), lr=1e-3)


In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_BCE.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_BCE(xb)
        loss = criterion(outputs, yb)  # BCEWithLogitsLoss(logits, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_BCE.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_BCE(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# The model outputs logits (it was trained with BCEWithLogitsLoss), so apply
# sigmoid to display the probability map the loss actually optimises.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_BCE.eval()
with torch.no_grad():
    logits = model_BCE(X_val[sample_idx].unsqueeze(0))
predicted = torch.sigmoid(logits).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Probability (sigmoid)"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


Slightly better with the new output layer; with many more epochs it may converge to a better fit.

New Structure, Focal Loss

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Updated architecture with the sigmoid focal loss (expects logits first).
model_FL = CNNModel_new().to(device)
optimizer = optim.Adam(params=model_FL.parameters(), lr=1e-3)
criterion = FocalLoss(alpha=0.75, gamma=2.0)


In [None]:
# ============================
# Training loop (Microsample)
# ============================
# Deliberately overfit the micro split as a capacity/loss sanity check.
epochs = 20
batch_size = 8

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_FL.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train_micro), device=X_train_micro.device)

    train_iter = tqdm(range(0, len(X_train_micro), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train_micro[batch_idx]
        yb = y_train_micro[batch_idx]

        optimizer.zero_grad()
        outputs = model_FL(xb)
        # FocalLoss.forward(inputs=logits, targets): logits must come first.
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_FL.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val_micro), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val_micro[i:i + batch_size]
            yb = y_val_micro[i:i + batch_size]
            outputs = model_FL(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# The model outputs logits (FocalLoss applies sigmoid internally), so apply
# sigmoid to display the probability map the loss actually optimises.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_FL.eval()
with torch.no_grad():
    logits = model_FL(X_val[sample_idx].unsqueeze(0))
predicted = torch.sigmoid(logits).cpu().squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Probability (sigmoid)"),
]
for ax, (img, cmap, title) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()


Train the updated architecture with focal loss on the full dataset

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Updated architecture + focal loss, trained on the full train split.
model_FL_Complete = CNNModel_new().to(device)
optimizer = optim.Adam(params=model_FL_Complete.parameters(), lr=1e-3)
criterion = FocalLoss(alpha=0.75, gamma=2.0)


In [None]:

# ============================
# 8️⃣ Training loop
# ============================
# Full-dataset training of the updated architecture with focal loss.
epochs = 20
batch_size = 16

# Fresh history so the live plot shows only this run.
train_loss_hist.clear()
val_loss_hist.clear()

for epoch in range(epochs):
    # ---- training ----
    model_FL_Complete.train()
    train_loss, n_train_batches = 0.0, 0
    # Reshuffle each epoch so batches differ between epochs.
    perm = torch.randperm(len(X_train), device=X_train.device)

    train_iter = tqdm(range(0, len(X_train), batch_size),
                      desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        batch_idx = perm[i:i + batch_size]
        xb = X_train[batch_idx]
        yb = y_train[batch_idx]

        optimizer.zero_grad()
        outputs = model_FL_Complete(xb)
        # FocalLoss.forward(inputs=logits, targets): logits must come first.
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1
        train_iter.set_postfix(loss=train_loss / n_train_batches)

    # ---- validation ----
    model_FL_Complete.eval()
    val_loss, n_val_batches = 0.0, 0
    val_iter = tqdm(range(0, len(X_val), batch_size),
                    desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val[i:i + batch_size]
            yb = y_val[i:i + batch_size]
            outputs = model_FL_Complete(xb)
            loss = criterion(outputs, yb)
            val_loss += loss.item()
            n_val_batches += 1
            val_iter.set_postfix(loss=val_loss / n_val_batches)

    # Mean over the batches actually run; len/batch_size is fractional when the
    # last batch is partial and would inflate the average.
    avg_train_loss = train_loss / max(n_train_batches, 1)
    avg_val_loss = val_loss / max(n_val_batches, 1)

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

# Save the trained weights (state_dict only; rebuild with CNNModel_new() to load).
torch.save(model_FL_Complete.state_dict(), "model_FL_Complete.pth")


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# sample_idx is not set in this cell; it must come from an earlier cell.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_FL_Complete.eval()
with torch.no_grad():
    predicted = model_FL_Complete(X_val[sample_idx].unsqueeze(0)).cpu().squeeze().numpy()

# Draw all three panels of the 1x3 grid so the prediction can be compared with its
# input and ground truth (otherwise two of the three slots stay empty).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Heatmap (model_FL_Complete)"),
]
for ax, (image, cmap, title) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
plt.show()


For comparison: train the old model (`CNNModel`) with focal loss on the full dataset

In [None]:

# ============================
# 7️⃣ Training setup
# ============================
# Original architecture (CNNModel), same loss and optimiser settings as the new model.
model_FL_old_Complete = CNNModel().to(device)
criterion = FocalLoss(gamma=2.0, alpha=0.75)
optimizer = optim.Adam(params=model_FL_old_Complete.parameters(), lr=1e-3)


In [None]:

# ============================
# 8️⃣ Training loop
# ============================
from tqdm import tqdm

epochs = 20
batch_size = 16

# live_plot() reads the global loss-history lists, which other training cells also
# append to; reset them so the plot shows only this model's curves.
train_loss_hist.clear()
val_loss_hist.clear()

# Real number of (possibly partial) batches per pass. Dividing the summed per-batch
# losses by len(X) / batch_size overstates the mean when the last batch is partial.
# max(..., 1) avoids a ZeroDivisionError on an empty split.
n_train_batches = max(len(range(0, len(X_train), batch_size)), 1)
n_val_batches = max(len(range(0, len(X_val), batch_size)), 1)

for epoch in range(epochs):
    model_FL_old_Complete.train()
    train_loss = 0.0

    # Training loop with tqdm
    train_iter = tqdm(range(0, len(X_train), batch_size), desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for i in train_iter:
        xb = X_train[i:i+batch_size]
        yb = y_train[i:i+batch_size]

        optimizer.zero_grad()
        outputs = model_FL_old_Complete(xb)
        loss = criterion(yb, outputs)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        avg_so_far = train_loss / ((i // batch_size) + 1)
        train_iter.set_postfix(loss=avg_so_far)

    # Validation loop with tqdm
    model_FL_old_Complete.eval()
    val_loss = 0.0
    val_iter = tqdm(range(0, len(X_val), batch_size), desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
    with torch.no_grad():
        for i in val_iter:
            xb = X_val[i:i+batch_size]
            yb = y_val[i:i+batch_size]
            outputs = model_FL_old_Complete(xb)
            loss = criterion(yb, outputs)
            val_loss += loss.item()
            avg_so_far = val_loss / ((i // batch_size) + 1)
            val_iter.set_postfix(loss=avg_so_far)

    # Mean of the per-batch losses over the actual batch count.
    avg_train_loss = train_loss / n_train_batches
    avg_val_loss = val_loss / n_val_batches

    train_loss_hist.append(avg_train_loss)
    val_loss_hist.append(avg_val_loss)
    live_plot()

    print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

# Save the model trained in THIS cell (model_FL_old_Complete), not model_FL_Complete.
torch.save(model_FL_old_Complete.state_dict(), 'model_FL_old_Complete.pth')


In [None]:

# ============================
# 9️⃣ Prediction example
# ============================
# sample_idx is not set in this cell; it must come from an earlier cell.
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_heatmap = y_val[sample_idx].cpu().squeeze().numpy()

model_FL_old_Complete.eval()
with torch.no_grad():
    predicted = model_FL_old_Complete(X_val[sample_idx].unsqueeze(0)).cpu().squeeze().numpy()

# Draw all three panels of the 1x3 grid so the prediction can be compared with its
# input and ground truth (otherwise two of the three slots stay empty).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (input_patch, "gray", "Input Patch"),
    (true_heatmap, "hot", "True Heatmap"),
    (predicted, "hot", "Predicted Heatmap (model_FL_old_Complete)"),
]
for ax, (image, cmap, title) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
plt.show()


The old model, without the changes to the output layer, seems to perform just as well.

In [None]:
# #load models
# Rebuild each model with the class it was trained with before loading its weights:
# model_FL_Complete was created from CNNModel_new, model_FL_old_Complete from CNNModel.
# Pass map_location=device to torch.load if the weights were saved on another device.
# NOTE(review): verify model_FL_old_Complete.pth was written from model_FL_old_Complete's state_dict.
# model_FL_Complete = CNNModel_new().to(device)
# model_FL_old_Complete = CNNModel().to(device)

# model_FL_Complete.load_state_dict(torch.load("model_FL_Complete.pth"))
# model_FL_old_Complete.load_state_dict(torch.load("model_FL_old_Complete.pth"))

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from skimage.feature import peak_local_max
from scipy.spatial import cKDTree

# ============================
# 🔍 Utility: extract emitter coordinates
# ============================
def extract_emitters(heatmap, threshold=0.3, min_distance=2):
    """
    Extract emitter coordinates from a heatmap using non-maximum suppression.

    heatmap: 2D numpy array
    threshold: minimum peak intensity, relative to the heatmap maximum
    min_distance: minimum spacing (px) between peaks, forwarded to peak_local_max

    Returns an (N, 2) integer array of (y, x) = (row, col) peak coordinates.
    An empty heatmap, or one whose maximum is not positive (or is NaN), yields an
    empty (0, 2) array: a threshold relative to a non-positive maximum is meaningless
    and could admit background peaks (and .max() raises on an empty array).
    NOTE(review): peak_local_max excludes peaks near the border by default — confirm
    that is intended for small patches.
    """
    heatmap = np.asarray(heatmap)
    peak = heatmap.max() if heatmap.size else 0.0
    if not peak > 0:  # also true for NaN
        return np.empty((0, heatmap.ndim), dtype=np.intp)

    coords = peak_local_max(
        heatmap,
        min_distance=min_distance,
        threshold_abs=threshold * peak
    )
    return coords  # array of (y, x) coords

# ============================
# 📐 Metric computation
# ============================
def evaluate_localization(model, X, Y, radius=2, threshold=0.3, device="cpu",
                          gt_threshold=0.3, min_distance=2):
    """
    Evaluate model predictions against ground truth heatmaps by matching emitters.

    Emitters are extracted from each ground-truth and predicted heatmap with
    `extract_emitters` and paired one-to-one (each emitter is matched at most once)
    via a minimum-total-distance assignment. A pair is a true positive when its
    distance is strictly below `radius`; unmatched predictions are false positives
    and unmatched ground-truth emitters are false negatives.

    Parameters
    ----------
    model : torch.nn.Module
    X, Y : indexable collections of input / ground-truth tensors
    radius : float
        Matching tolerance in pixels (exclusive).
    threshold : float
        Relative detection threshold for predicted heatmaps.
    device : str or torch.device
        Device each input is moved to before the forward pass.
    gt_threshold : float
        Relative detection threshold for ground-truth heatmaps.
    min_distance : int
        Minimum peak spacing passed to `extract_emitters` for both heatmaps.

    Returns
    -------
    precision, recall, f1, rmse
        rmse is the root-mean-square distance (px) over matched pairs, NaN if none.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    model.eval()
    all_tp, all_fp, all_fn = 0, 0, 0
    localization_errors = []

    with torch.no_grad():
        for i in range(len(X)):
            xb = X[i].unsqueeze(0).to(device)
            yb = Y[i].cpu().squeeze().numpy()

            pred = model(xb).cpu().squeeze().numpy()

            # Get emitter coordinates
            gt_coords = extract_emitters(yb, threshold=gt_threshold, min_distance=min_distance)
            pred_coords = extract_emitters(pred, threshold=threshold, min_distance=min_distance)

            # One-to-one matching: nearest-neighbour queries alone would let several
            # predictions claim the same GT emitter, inflating TP and driving FN negative.
            tp = 0
            if len(gt_coords) > 0 and len(pred_coords) > 0:
                dists = cdist(pred_coords, gt_coords)
                in_range = dists < radius
                # Out-of-range pairs cost more than any sum of in-range distances, so the
                # assignment maximises the number of in-range matches first.
                penalty = radius * (min(dists.shape) + 1) + 1.0
                rows, cols = linear_sum_assignment(np.where(in_range, dists, penalty))
                matched = in_range[rows, cols]
                tp = int(matched.sum())

                # Localization error for matched emitters
                localization_errors.extend(dists[rows, cols][matched])

            all_tp += tp
            all_fp += len(pred_coords) - tp
            all_fn += len(gt_coords) - tp

    # Precision, Recall, F1
    precision = all_tp / (all_tp + all_fp + 1e-8)
    recall    = all_tp / (all_tp + all_fn + 1e-8)
    f1        = 2 * precision * recall / (precision + recall + 1e-8)
    rmse      = np.sqrt(np.mean(np.square(localization_errors))) if localization_errors else np.nan

    return precision, recall, f1, rmse


In [None]:
precision, recall, f1, rmse = evaluate_localization(
    model_FL_old_Complete,
    X_val,
    y_val,
    radius=2,      # tolerance in px
    threshold=0.3, # adjust based on your heatmap scaling
    device=device
)

# Label the output: the two localization cells otherwise print identical-looking lines.
print(f"model_FL_old_Complete | Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}, RMSE: {rmse:.3f} px")


In [None]:
precision, recall, f1, rmse = evaluate_localization(
    model_FL_Complete,
    X_val,
    y_val,
    radius=2,      # tolerance in px
    threshold=0.3, # adjust based on your heatmap scaling
    device=device
)

# Label the output: the two localization cells otherwise print identical-looking lines.
print(f"model_FL_Complete | Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}, RMSE: {rmse:.3f} px")


In [None]:
import math
from skimage.metrics import structural_similarity as ssim

# ============================
# 🔍 Evaluation functions
# ============================
def psnr(img1, img2, data_range=None):
    """
    Peak signal-to-noise ratio (dB) between two images.

    Parameters
    ----------
    img1, img2 : array-like of the same shape
        Compared in float64, so integer (e.g. uint8) images cannot overflow in the
        subtraction or squaring.
    data_range : float, optional
        Peak signal value. If None, 1.0 is used when img1.max() <= 1.0, else 255.0.

    Returns
    -------
    float
        PSNR in dB; ``inf`` when the images are identical.
    """
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")
    if data_range is None:
        data_range = 1.0 if a.max() <= 1.0 else 255.0
    return 20 * math.log10(data_range / math.sqrt(mse))

def evaluate_model(model, X_val, y_val):
    """
    Score `model` on every validation sample with PSNR, SSIM and MSE.

    Each prediction and its target are independently divided by their own maximum
    (when that maximum is positive) before scoring, so the metrics compare the maps
    on a common [.., 1] peak scale. Returns the per-sample means as (psnr, ssim, mse).
    """
    def _peak_scaled(arr):
        top = arr.max()
        return arr / top if top > 0 else arr

    model.eval()
    scores = {"psnr": [], "ssim": [], "mse": []}
    with torch.no_grad():
        for idx in range(len(X_val)):
            target = _peak_scaled(y_val[idx].cpu().squeeze().numpy())
            output = _peak_scaled(model(X_val[idx].unsqueeze(0)).cpu().squeeze().numpy())

            scores["psnr"].append(psnr(output, target))
            scores["ssim"].append(ssim(output, target, data_range=1.0))
            scores["mse"].append(np.mean((output - target) ** 2))
    return tuple(np.mean(scores[key]) for key in ("psnr", "ssim", "mse"))

# ============================
# 📊 Compare all models
# ============================
# Every trained model, keyed by the name used as its row label in the results table.
models = dict(
    model_L1L2_old=model_L1L2_old,
    model_MSE_old=model_MSE_old,
    model_BCE_old=model_BCE_old,
    model_focal_old=model_focal_old,
    model_L1L2=model_L1L2,
    model_MSE=model_MSE,
    model_BCE=model_BCE,
    model_FL=model_FL,
    model_FL_Complete=model_FL_Complete,
    model_FL_old_Complete=model_FL_old_Complete,
)

metric_names = ("PSNR", "SSIM", "MSE")
results = {}
for name, model in models.items():
    print(f"Evaluating {name}...")
    psnr_val, ssim_val, mse_val = evaluate_model(model, X_val, y_val)
    results[name] = dict(zip(metric_names, (psnr_val, ssim_val, mse_val)))

# One row per model, one column per metric.
import pandas as pd
df_results = pd.DataFrame(results).T
print(df_results)

# Persist for the plotting cells below.
df_results.to_csv("model_comparison_metrics.csv")


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Reload the saved metric table (rows = models, columns = metrics).
df = pd.read_csv("model_comparison_metrics.csv", index_col=0)

fig, ax = plt.subplots(figsize=(10,6))
sns.heatmap(
    df,
    ax=ax,
    annot=True,
    fmt=".3f",
    cmap="viridis",
    vmax=50,  # clip the colour scale at 50
    cbar_kws={'label': 'Metric Value'},
)
ax.set_title("Model Comparison Heatmap (PSNR, SSIM, MSE)")
ax.set_ylabel("Model")
ax.set_xlabel("Metric")
plt.show()



In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Load metrics
df = pd.read_csv("model_comparison_metrics.csv", index_col=0)

# BCE models are excluded from the MSE colour range; their MSE cells are blacked out below.
is_bce = df.index.str.contains("BCE")

# Normalized copy: min-max scale each metric column to [0, 1] for colouring only
# (the annotations still show the raw values).
df_norm = df.copy()
for col in df.columns:
    rows = ~is_bce if col == "MSE" else np.ones(len(df), dtype=bool)
    values = df.loc[rows, col]
    col_min, col_max = values.min(), values.max()
    span = col_max - col_min
    # A constant column would divide by zero; give it a single mid-scale colour instead.
    df_norm.loc[rows, col] = (values - col_min) / span if span > 0 else 0.5
    if col == "MSE":
        # BCE rows get no colour value (masked and painted black later)
        df_norm.loc[~rows, col] = np.nan

# Mask for outlier cells (MSE column for BCE models)
mask = np.zeros(df_norm.shape, dtype=bool)
if "MSE" in df.columns:
    mask[is_bce, df.columns.get_loc("MSE")] = True

fig, ax = plt.subplots(figsize=(10,6))
sns.heatmap(
    df_norm, annot=df, fmt=".3f", cmap="viridis",
    cbar_kws={'label': 'Normalized Metric Value'},
    mask=mask,
    linewidths=0.5, linecolor="gray",
    ax=ax,
)

# seaborn leaves masked cells blank; paint them black. facecolor/edgecolor are given
# explicitly because combining `color=` with `ec=` makes matplotlib warn and drop the edge colour.
for i, j in zip(*np.nonzero(mask)):
    ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=True, facecolor='black', edgecolor='gray'))

ax.set_title("Model Comparison Heatmap (Normalized per Metric)")
ax.set_ylabel("Model")
ax.set_xlabel("Metric")
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Map names in CSV to actual model objects
model_dict = {
    "model_L1L2_old": model_L1L2_old,
    "model_MSE_old": model_MSE_old,
    "model_BCE_old": model_BCE_old,
    "model_focal_old": model_focal_old,
    "model_L1L2": model_L1L2,
    "model_MSE": model_MSE,
    "model_BCE": model_BCE,
    "model_FL": model_FL,
    "model_FL_Complete": model_FL_Complete,
    "model_FL_old_Complete": model_FL_old_Complete,
}

# Test sample (sample_idx is not set in this cell; it must come from an earlier cell)
input_patch = X_val[sample_idx].cpu().squeeze().numpy()
true_patch = y_val[sample_idx].cpu().squeeze().numpy()

# Grid: input + ground truth + one panel per model, 5 columns wide
# (3x5 for the 10 models listed above).
n_cols = 5
n_rows = -(-(2 + len(df.index)) // n_cols)  # ceiling division
plt.figure(figsize=(4 * n_cols, 4 * n_rows))

# Input
plt.subplot(n_rows, n_cols, 1)
plt.imshow(input_patch, cmap="gray")
plt.title("Input")
plt.axis("off")

# Ground truth
plt.subplot(n_rows, n_cols, 2)
plt.imshow(true_patch, cmap="gray")
plt.title("Ground Truth")
plt.axis("off")

# Predictions for each model, in the row order of the metrics CSV
for j, name in enumerate(df.index, start=3):
    model = model_dict[name]
    model.eval()
    with torch.no_grad():
        pred = model(X_val[sample_idx].unsqueeze(0)).cpu().squeeze().numpy()

    # Scores from the CSV. Named *_score so they do not overwrite the psnr()/ssim()
    # functions the metrics cell calls.
    psnr_score = df.loc[name, "PSNR"]
    ssim_score = df.loc[name, "SSIM"]

    plt.subplot(n_rows, n_cols, j)
    plt.imshow(pred, cmap="gray")
    plt.title(f"{name}\nPSNR:{psnr_score:.2f}, SSIM:{ssim_score:.3f}")
    plt.axis("off")

plt.suptitle("Side-by-Side Reconstructions", fontsize=16)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
# ============================
# 🟦 IoU evaluation
# ============================
def compute_iou(pred, gt, threshold=0.1):
    """
    Intersection-over-union of two maps after binarising both at `threshold`.

    Pixels >= threshold are foreground. When neither map has any foreground pixel
    the two masks agree trivially and 1.0 is returned.
    """
    pred_fg = pred >= threshold
    gt_fg = gt >= threshold

    union = np.logical_or(pred_fg, gt_fg).sum()
    if not union:
        return 1.0
    return np.logical_and(pred_fg, gt_fg).sum() / union

def evaluate_model_iou(model, X_val, y_val, threshold=0.2):
    """
    Mean IoU of `model` over the validation set.

    Prediction and target are each divided by their own maximum (when positive),
    then binarised at `threshold` by compute_iou.
    """
    def _peak_scaled(arr):
        peak = arr.max()
        return arr / peak if peak > 0 else arr

    model.eval()
    with torch.no_grad():
        scores = [
            compute_iou(
                _peak_scaled(model(X_val[k].unsqueeze(0)).cpu().squeeze().numpy()),
                _peak_scaled(y_val[k].cpu().squeeze().numpy()),
                threshold=threshold,
            )
            for k in range(len(X_val))
        ]
    return np.mean(scores)

# ============================
# 📊 IoU for last two models
# ============================
# Binarisation threshold for the IoU evaluation; also used in the column label so the
# saved table states the threshold actually applied.
IOU_THRESHOLD = 0.5

iou_results = {}
models_to_check = {
    "model_FL_Complete": model_FL_Complete,
    "model_FL_old_Complete": model_FL_old_Complete,
}

for name, model in models_to_check.items():
    print(f"Evaluating IoU for {name}...")
    iou_val = evaluate_model_iou(model, X_val, y_val, threshold=IOU_THRESHOLD)
    iou_results[name] = {f"IoU@{IOU_THRESHOLD}": iou_val}

# Convert to DataFrame
df_iou = pd.DataFrame(iou_results).T
print(df_iou)

# Save separately
df_iou.to_csv("model_iou_metrics.csv")


Violin plot / box-and-whisker plot

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# ============================
# 🎻 Collect individual IoU scores for violin plots
# ============================
def evaluate_model_iou_detailed(model, X_val, y_val, threshold=0.1):
    """
    Per-sample variant of evaluate_model_iou: return the list of IoU scores
    (one per validation sample) instead of their mean.

    Prediction and target are each divided by their own maximum (when positive)
    before compute_iou binarises them at `threshold`.
    """
    def _peak_scaled(arr):
        peak = arr.max()
        return arr / peak if peak > 0 else arr

    model.eval()
    scores = []
    with torch.no_grad():
        for k in range(len(X_val)):
            target = _peak_scaled(y_val[k].cpu().squeeze().numpy())
            output = _peak_scaled(model(X_val[k].unsqueeze(0)).cpu().squeeze().numpy())
            scores.append(compute_iou(output, target, threshold=threshold))
    return scores

import pandas as pd  # used below; do not rely on another cell having imported it

# Binarisation threshold for the IoU distributions (shown in the plot title).
iou_threshold = 0.5

# Collect detailed IoU scores for violin plots
print("Collecting detailed IoU scores for visualization...")
iou_detailed = {}
for name, model in models_to_check.items():
    print(f"Collecting IoU scores for {name}...")
    iou_detailed[name] = evaluate_model_iou_detailed(model, X_val, y_val, threshold=iou_threshold)

# Short display names, e.g. "model_FL_old_Complete" -> "FL_old"
short_names = {n: n.replace('model_', '').replace('_Complete', '') for n in iou_detailed}

# ============================
# 📊 Create violin plots
# ============================
# Long-format table: one row per (model, sample) IoU score. Explicit columns keep the
# frame well-formed even when there are no scores.
df_plot = pd.DataFrame(
    [{'Model': short_names[model_name], 'IoU Score': score}
     for model_name, scores in iou_detailed.items() for score in scores],
    columns=['Model', 'IoU Score'],
)

# Create the violin plot
plt.figure(figsize=(10, 6))
sns.violinplot(data=df_plot, x='Model', y='IoU Score', palette='Set2')

# Overlay box plot for additional statistics
sns.boxplot(data=df_plot, x='Model', y='IoU Score',
            width=0.3, boxprops=dict(alpha=0.3),
            whiskerprops=dict(alpha=0.3),
            capprops=dict(alpha=0.3))

plt.title(f'IoU Score Distribution Comparison (threshold={iou_threshold})\n(Violin Plot with Box Plot Overlay)',
          fontsize=14, fontweight='bold')
plt.ylabel('IoU Score', fontsize=12)
plt.xlabel('Model', fontsize=12)
plt.grid(True, alpha=0.3)

# Annotate each violin with its mean; x positions follow the order in which models first
# appear in df_plot, which is the iteration order of iou_detailed.
for i, scores in enumerate(iou_detailed.values()):
    mean_iou = np.mean(scores)
    plt.text(i, mean_iou + 0.02, f'μ={mean_iou:.3f}',
             ha='center', va='bottom', fontweight='bold', fontsize=10)

plt.tight_layout()
plt.savefig('iou_violin_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# ============================
# 📈 Additional statistical summary
# ============================
print("\n" + "="*50)
print("📊 DETAILED IoU STATISTICS")
print("="*50)

stats_summary = {}
for model_name, scores in iou_detailed.items():
    clean_name = short_names[model_name]
    # Named model_stats rather than the generic module-level name `stats`.
    model_stats = {
        'Mean': np.mean(scores),
        'Median': np.median(scores),
        'Std': np.std(scores),
        'Min': np.min(scores),
        'Max': np.max(scores),
        'Q25': np.percentile(scores, 25),
        'Q75': np.percentile(scores, 75)
    }
    stats_summary[clean_name] = model_stats

    print(f"\n{clean_name}:")
    print(f"  Mean: {model_stats['Mean']:.4f}")
    print(f"  Median: {model_stats['Median']:.4f}")
    print(f"  Std: {model_stats['Std']:.4f}")
    print(f"  Range: [{model_stats['Min']:.4f}, {model_stats['Max']:.4f}]")
    print(f"  IQR: [{model_stats['Q25']:.4f}, {model_stats['Q75']:.4f}]")

# Convert stats to DataFrame and save
df_stats = pd.DataFrame(stats_summary).T
df_stats.to_csv("model_iou_detailed_stats.csv")
print("\nDetailed statistics saved to 'model_iou_detailed_stats.csv'")

In [None]:
def visualize_random_iou_samples(model, X_val, y_val, n_samples=3, threshold=0.5, seed=None):
    """
    Plot ground-truth mask, predicted mask and their overlay for random validation samples.

    Both maps are divided by their own maximum (when positive) and binarised at
    `threshold`; the overlay title shows the resulting IoU.

    Parameters
    ----------
    model : torch.nn.Module
    X_val, y_val : indexable inputs / targets; X_val[i] gets a batch dim via unsqueeze(0)
    n_samples : int
        Number of rows to draw, capped at len(X_val). Nothing is drawn if it is 0.
    threshold : float
        Binarisation threshold on the peak-normalised maps.
    seed : int, optional
        If given, samples are drawn with random.Random(seed) for a reproducible
        selection; otherwise the global `random` module state is used.
    """
    model.eval()
    n_samples = min(n_samples, len(X_val))
    if n_samples <= 0:
        return
    sampler = random if seed is None else random.Random(seed)
    idxs = sampler.sample(range(len(X_val)), n_samples)

    # squeeze=False keeps `axes` 2-D so axes[row, col] also works when n_samples == 1.
    fig, axes = plt.subplots(n_samples, 3, figsize=(10, 4 * n_samples), squeeze=False)

    with torch.no_grad():
        for row, i in enumerate(idxs):
            xb = X_val[i].unsqueeze(0)
            yb = y_val[i].cpu().squeeze().numpy()
            pred = model(xb).cpu().squeeze().numpy()

            # Normalize
            if pred.max() > 0:
                pred = pred / pred.max()
            if yb.max() > 0:
                yb = yb / yb.max()

            # Binarize
            pred_bin = (pred >= threshold).astype(np.uint8)
            gt_bin = (yb >= threshold).astype(np.uint8)

            # Compute IoU
            iou_val = compute_iou(pred, yb, threshold=threshold)

            # Plot GT, Pred, Overlay
            axes[row, 0].imshow(gt_bin, cmap="gray")
            axes[row, 0].set_title("Ground Truth")
            axes[row, 0].axis("off")

            axes[row, 1].imshow(pred_bin, cmap="gray")
            axes[row, 1].set_title("Prediction")
            axes[row, 1].axis("off")

            axes[row, 2].imshow(gt_bin, cmap="gray", alpha=0.5)
            axes[row, 2].imshow(pred_bin, cmap="Reds", alpha=0.5)
            axes[row, 2].set_title(f"Overlay\nIoU={iou_val:.3f}")
            axes[row, 2].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# NOTE: threshold=0.2 here, whereas the IoU tables above are computed at 0.5.
visualize_random_iou_samples(
    model=model_FL_Complete,
    X_val=X_val,
    y_val=y_val,
    n_samples=3,
    threshold=0.2,
)

In [None]:
# NOTE: threshold=0.2 here, whereas the IoU tables above are computed at 0.5.
visualize_random_iou_samples(
    model=model_FL_old_Complete,
    X_val=X_val,
    y_val=y_val,
    n_samples=3,
    threshold=0.2,
)