In [1]:
# TAPNet Training Script for Fetal Ultrasound Videos (Fully Upgraded with Proposed Components)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image, ImageFilter
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, r2_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torchcam.methods import GradCAM
from torchvision.transforms.functional import to_pil_image
import cv2

# --- Generate video_labels.csv from pseudo_labels.csv ---
def generate_video_labels(pseudo_csv_path, output_csv_path):
    """Derive a per-frame video label table from the pseudo-label CSV.

    Reads ``pseudo_csv_path``, adds a ``video_id`` column holding the part of
    each ``image`` name that precedes ``'_frame_'``, and writes the columns
    ``image``, ``video_id``, ``plane`` and ``value`` (in that order, without
    the index) to ``output_csv_path``.

    Args:
        pseudo_csv_path: CSV with at least ``image``, ``plane`` and ``value``
            columns.
        output_csv_path: Destination path of the generated CSV.
    """
    pseudo = pd.read_csv(pseudo_csv_path)
    # Everything before the first '_frame_' marker identifies the source video.
    video_ids = [name.split('_frame_')[0] for name in pseudo['image']]
    labelled = pseudo.assign(video_id=video_ids)
    labelled[['image', 'video_id', 'plane', 'value']].to_csv(output_csv_path, index=False)
    print(f"✅ video_labels.csv generated at: {output_csv_path}")

# --- Frame Quality Filter (Blur Detection) ---
def is_blurry(image, threshold=100):
    """Report whether an image's Laplacian variance falls below ``threshold``.

    Args:
        image: Object exposing ``convert('L')`` (a PIL image in this file).
        threshold: Variance cut-off; strictly lower values count as blurry.

    Returns:
        The result of ``variance < threshold``.
    """
    gray_pixels = np.array(image.convert('L'))
    # Variance of the Laplacian response is used as the sharpness score.
    sharpness = cv2.Laplacian(gray_pixels, cv2.CV_64F).var()
    return sharpness < threshold

# --- Denoising Preprocessing ---
def denoise_image(image):
    """Return ``image`` after applying a size-3 median filter."""
    median_filter = ImageFilter.MedianFilter(size=3)
    return image.filter(median_filter)

# --- Temporal Attention Module ---
class TemporalAttention(nn.Module):
    """Softmax-weighted pooling along dimension 1 of a sequence of features.

    A single linear layer maps every step's feature vector (width
    ``2 * hidden_dim``) to one score; scores are normalised with a softmax
    over dimension 1 and used to average the steps.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar score per step; input width is twice hidden_dim.
        self.attn = nn.Linear(2 * hidden_dim, 1)

    def forward(self, rnn_output):
        """Collapse dimension 1 of ``rnn_output`` into a weighted sum.

        Args:
            rnn_output: Tensor whose last dimension is ``2 * hidden_dim`` and
                whose dimension 1 indexes the steps to pool over.

        Returns:
            ``rnn_output`` with dimension 1 reduced by the attention weights.
        """
        scores = self.attn(rnn_output)
        step_weights = scores.softmax(dim=1)
        weighted_steps = step_weights * rnn_output
        return weighted_steps.sum(dim=1)

# --- CNN + GRU + Attention TAPNet Model ---
class CNNEncoder(nn.Module):
    """ResNet-18 backbone used as a per-frame feature extractor.

    The network is built without pretrained weights, its first convolution is
    replaced so it accepts ``in_channels`` input channels, and its final
    ``fc`` layer is replaced by an identity so the features feeding that
    layer are returned directly.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Swap the stem so the input channel count is configurable.
        resnet.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Drop the classification head; callers consume the raw features.
        resnet.fc = nn.Identity()
        self.backbone = resnet

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        features = self.backbone(x)
        return features

class TAPNet(nn.Module):
    """CNN encoder + bidirectional GRU + temporal attention with two heads.

    Every frame of a clip is encoded by ``CNNEncoder``; the frame features
    are run through a bidirectional GRU, pooled by ``TemporalAttention`` and
    passed to a linear classification head and a sigmoid-terminated
    regression head.

    Args:
        in_channels: Channel count of each input frame.
        n_classes: Number of outputs of the classification head.
        hidden_dim: GRU hidden size per direction.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, in_channels=1, n_classes=3, hidden_dim=128, num_layers=1):
        super().__init__()
        # Both GRU directions are concatenated, hence twice the hidden size.
        context_dim = hidden_dim * 2

        self.encoder = CNNEncoder(in_channels)
        self.rnn = nn.GRU(
            512,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.attn = TemporalAttention(hidden_dim)
        self.classifier = nn.Linear(context_dim, n_classes)
        # NOTE(review): the Sigmoid confines predictions to (0, 1); the run log
        # stored with this script shows MSE around 620, which suggests the
        # regression targets are not scaled to that range - confirm.
        self.regressor = nn.Sequential(
            nn.Linear(context_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_seq):
        """Classify and regress one value per clip.

        Args:
            x_seq: 5-D tensor unpacked as ``(B, T, C, H, W)``.

        Returns:
            Tuple ``(cls_out, reg_out)``: the classifier output for each clip
            and the regressor output with its trailing dimension squeezed.
        """
        # Unpacking doubles as a check that the input really is 5-D.
        _batch, _steps, _channels, _height, _width = x_seq.shape

        # Encode frame by frame, then line the features up along dim 1.
        per_frame = []
        for frame_batch in x_seq.unbind(dim=1):
            per_frame.append(self.encoder(frame_batch))
        feats_seq = torch.stack(per_frame, dim=1)

        rnn_out, _ = self.rnn(feats_seq)
        pooled = self.attn(rnn_out)

        cls_out = self.classifier(pooled)
        reg_out = self.regressor(pooled).squeeze(-1)
        return cls_out, reg_out

# --- Dataset with Frame Quality Selector + Denoising ---
class TAPNetVideoDataset(Dataset):
    """One item per video: a fixed-length stack of filtered, denoised frames.

    The label CSV must provide ``image``, ``video_id``, ``plane`` and
    ``value`` columns. Frames that are missing on disk or rejected by
    ``is_blurry`` are skipped; the remaining frames are denoised, converted
    to grayscale, resized and stacked. Clips are cut or zero-padded to
    ``max_len`` frames, repeating the last kept label/value for padding.

    Args:
        video_folder: Directory containing the frame image files.
        label_csv: Path to the per-frame label CSV.
        max_len: Number of frames every returned clip has (default 200).

    Raises:
        ValueError: If ``max_len`` is smaller than 1.
    """

    # Spatial size every frame is resized to; also used for fallback frames.
    FRAME_SIZE = (224, 224)
    # Plane name -> class index.
    LABEL_MAP = {'head': 0, 'abdomen': 1, 'femur': 2}

    def __init__(self, video_folder, label_csv, max_len=200):
        if max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {max_len}")
        self.data = pd.read_csv(label_csv)
        self.video_folder = video_folder
        self.max_len = max_len
        self.transform = transforms.Compose([
            transforms.Resize(self.FRAME_SIZE),
            transforms.Grayscale(),
            transforms.ToTensor()
        ])
        self.data['label'] = self.data['plane'].map(self.LABEL_MAP)
        self._index_videos()

    def _index_videos(self):
        """Group the label table by video once, in first-appearance order.

        Avoids recomputing the unique ids and rescanning the whole table on
        every ``__getitem__`` call.
        """
        self._frames_by_video = {
            video_id: rows
            for video_id, rows in self.data.groupby('video_id', sort=False)
        }
        self.video_ids = list(self._frames_by_video)

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        """Load the clip at position ``idx``.

        Returns:
            Tuple ``(img_tensor, label_tensor, value_tensor)`` where
            ``img_tensor`` stacks ``max_len`` frames, ``label_tensor`` is a
            ``torch.long`` tensor and ``value_tensor`` a ``torch.float32``
            tensor, each with ``max_len`` entries.
        """
        import warnings

        video_id = self.video_ids[idx]
        frames = self._frames_by_video[video_id]
        names = frames['image'].tolist()
        row_labels = frames['label'].tolist()
        row_values = frames['value'].tolist()

        imgs, labels, values = [], [], []
        for name, label, value in zip(names, row_labels, row_values):
            # Only the first max_len usable frames are ever returned, so stop
            # reading from disk once we have them.
            if len(imgs) == self.max_len:
                break
            img_path = os.path.join(self.video_folder, name)
            if not os.path.exists(img_path):
                continue
            # Context manager closes the underlying file handle promptly.
            with Image.open(img_path) as img:
                if is_blurry(img):
                    continue
                frame = self.transform(denoise_image(img).convert('L'))
            imgs.append(frame)
            labels.append(label)
            values.append(value)

        if not imgs:
            # Every frame was missing or rejected as blurry. Previously this
            # crashed with IndexError; fall back to a single blank frame that
            # carries the video's first label/value so batching still works.
            warnings.warn(
                f"No usable frames for video '{video_id}'; returning a blank clip.",
                RuntimeWarning,
            )
            imgs.append(torch.zeros(1, *self.FRAME_SIZE))
            labels.append(row_labels[0])
            values.append(row_values[0])

        pad_len = self.max_len - len(imgs)
        if pad_len > 0:
            imgs.extend([torch.zeros_like(imgs[0])] * pad_len)
            labels.extend([labels[-1]] * pad_len)
            values.extend([values[-1]] * pad_len)

        img_tensor = torch.stack(imgs)
        label_tensor = torch.tensor(labels, dtype=torch.long)
        value_tensor = torch.tensor(values, dtype=torch.float32)
        return img_tensor, label_tensor, value_tensor

# --- Training Loop with Reports + GA/EFW ---
def train_tapnet(
    pseudo_csv_path="G:/Sajal_Data/Obj_4_Code/TAPNet_Dataset/pseudo_labels.csv",
    label_csv="G:/Sajal_Data/Obj_4_Code/TAPNet_Dataset/video_labels.csv",
    video_folder="G:/Sajal_Data/Obj_4_Code/TAPNet_Dataset/video_frames",
    report_dir="G:/Sajal_Data/Obj_4_Code/TAPNet_Dataset/evaluation_reports",
    num_epochs=100,
    batch_size=2,
    lr=1e-4,
    checkpoint_path="tapnet_model.pth",
):
    """Train TAPNet and write per-epoch reports.

    For every epoch the function trains on all clips, prints classification
    and regression metrics computed on the training batches, saves a
    confusion-matrix plot, a regression scatter plot and a Grad-CAM image to
    ``report_dir``, and finally stores the model weights.

    Args:
        pseudo_csv_path: Source CSV used to create ``label_csv`` if missing.
        label_csv: Per-frame label CSV consumed by the dataset.
        video_folder: Directory holding the frame images.
        report_dir: Output directory for plots (created if absent).
        num_epochs: Number of passes over the dataset.
        batch_size: Clips per batch.
        lr: Adam learning rate.
        checkpoint_path: Where the final ``state_dict`` is saved.
    """
    os.makedirs(report_dir, exist_ok=True)

    if not os.path.exists(label_csv):
        generate_video_labels(pseudo_csv_path, label_csv)

    dataset = TAPNetVideoDataset(video_folder, label_csv)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TAPNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_cls = nn.CrossEntropyLoss()
    loss_reg = nn.MSELoss()

    # The Grad-CAM probe always uses the first frame of the first clip, so
    # load it once instead of re-reading the whole clip every epoch.
    cam_frame = dataset[0][0][0] if len(dataset) > 0 else None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        all_preds, all_labels, all_preds_reg, all_true_reg = [], [], [], []

        for imgs, labels, values in loader:
            imgs, labels, values = imgs.to(device), labels.to(device), values.to(device)
            # Clip-level targets: label of the first frame, mean of the values.
            target_cls = labels[:, 0]
            target_reg = values.mean(dim=1)

            optimizer.zero_grad()
            out_cls, out_val = model(imgs)
            loss_c = loss_cls(out_cls, target_cls)
            loss_r = loss_reg(out_val, target_reg)
            loss = loss_c + 0.01 * loss_r
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            all_preds.extend(torch.argmax(out_cls, dim=1).cpu().numpy())
            all_labels.extend(target_cls.cpu().numpy())
            all_preds_reg.extend(out_val.detach().cpu().numpy().flatten())
            all_true_reg.extend(target_reg.cpu().numpy())

        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        rec = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        mse = mean_squared_error(all_true_reg, all_preds_reg)
        r2 = r2_score(all_true_reg, all_preds_reg)

        _save_epoch_plots(report_dir, epoch + 1, all_labels, all_preds, all_true_reg, all_preds_reg)

        # --- Calculate GA (in weeks) and EFW (Hadlock formula approximation) ---
        # NOTE(review): these two lists are computed but not reported or saved
        # anywhere in this function - confirm whether they should be exported.
        ga_weeks = [val * 10 + 20 for val in all_preds_reg]  # example linear scale
        efw = [((hc/100)**3.2) * 1.07 if hc > 0 else 0 for hc in all_preds_reg]  # HC in cm, approximate formula

        print(f"Epoch {epoch+1}: Loss={total_loss:.2f} | Acc={acc:.3f} | Prec={prec:.3f} | Recall={rec:.3f} | F1={f1:.3f} | MSE={mse:.4f} | R2={r2:.4f}")

        if cam_frame is not None:
            _save_gradcam(
                model,
                cam_frame,
                device,
                os.path.join(report_dir, f"gradcam_epoch_{epoch+1}.png"),
            )

    torch.save(model.state_dict(), checkpoint_path)
    print("✅ TAPNet training complete. Visualizations saved in report folder.")


def _save_epoch_plots(report_dir, epoch_num, all_labels, all_preds, all_true_reg, all_preds_reg):
    """Write the confusion-matrix heatmap and regression scatter for one epoch."""
    plane_names = ['head', 'abdomen', 'femur']

    # Pin the class order so the matrix is always 3x3 and matches the tick
    # labels even when a class is absent from this epoch's targets/predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=plane_names, yticklabels=plane_names)
    plt.title(f"Epoch {epoch_num} Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig(os.path.join(report_dir, f"conf_matrix_epoch_{epoch_num}.png"))
    plt.close()

    plt.figure()
    plt.scatter(all_true_reg, all_preds_reg, alpha=0.6)
    plt.xlabel("True Biometric Value")
    plt.ylabel("Predicted Biometric Value")
    plt.title(f"Epoch {epoch_num} Regression Scatter")
    plt.grid(True)
    plt.savefig(os.path.join(report_dir, f"regression_epoch_{epoch_num}.png"))
    plt.close()


def _save_gradcam(model, frame, device, out_path):
    """Save a Grad-CAM map of the encoder's ``layer4`` for a single frame."""
    import contextlib

    model.eval()
    # Shape the lone frame as a one-clip, one-step batch for TAPNet.forward.
    clip = frame.unsqueeze(0).unsqueeze(0).to(device)

    # Grad-CAM backpropagates through the GRU while the model is in eval mode;
    # cuDNN only supports RNN backward in training mode, so bypass cuDNN for
    # this one pass when running on CUDA.
    if device.type == "cuda":
        cudnn_guard = torch.backends.cudnn.flags(enabled=False)
    else:
        cudnn_guard = contextlib.nullcontext()

    gradcam = GradCAM(model.encoder.backbone, target_layer="layer4")
    try:
        with cudnn_guard:
            output_cls, _ = model(clip)  # Forward pass
            predicted_class = torch.argmax(output_cls, dim=1).item()
            cam_map = gradcam(class_idx=predicted_class, scores=output_cls)
    finally:
        # Detach the extractor's hooks so they neither pile up across epochs
        # nor fire during the following training passes.
        gradcam.remove_hooks()

    cam_img = to_pil_image(cam_map[0].squeeze().cpu().clamp(0, 1))
    cam_img.save(out_path)

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_tapnet()

Epoch 1: Loss=294.88 | Acc=0.655 | Prec=0.646 | Recall=0.647 | F1=0.646 | MSE=631.3806 | R2=-29.9253
Epoch 2: Loss=284.81 | Acc=0.726 | Prec=0.720 | Recall=0.708 | F1=0.711 | MSE=621.4226 | R2=-29.4375
Epoch 3: Loss=282.50 | Acc=0.798 | Prec=0.792 | Recall=0.798 | F1=0.794 | MSE=620.5501 | R2=-29.3948
Epoch 4: Loss=278.61 | Acc=0.821 | Prec=0.830 | Recall=0.802 | F1=0.810 | MSE=620.1993 | R2=-29.3776
Epoch 5: Loss=278.56 | Acc=0.786 | Prec=0.787 | Recall=0.767 | F1=0.773 | MSE=620.0573 | R2=-29.3707
Epoch 6: Loss=275.34 | Acc=0.845 | Prec=0.845 | Recall=0.835 | F1=0.839 | MSE=619.8939 | R2=-29.3627
Epoch 7: Loss=279.92 | Acc=0.857 | Prec=0.866 | Recall=0.841 | F1=0.849 | MSE=619.7772 | R2=-29.3569
Epoch 8: Loss=275.69 | Acc=0.857 | Prec=0.866 | Recall=0.841 | F1=0.849 | MSE=619.8030 | R2=-29.3582
Epoch 9: Loss=270.81 | Acc=0.893 | Prec=0.895 | Recall=0.884 | F1=0.888 | MSE=619.7985 | R2=-29.3580
Epoch 10: Loss=276.15 | Acc=0.845 | Prec=0.850 | Recall=0.831 | F1=0.837 | MSE=619.7017 | R