In [9]:
# Teacher Model: Multi-Task (Classification + Segmentation + Regression with Normalized Output)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import functional as TF
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import random

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class TeacherNet(nn.Module):
    def __init__(self, in_channels=1, n_classes=3):
        super(TeacherNet, self).__init__()
        filters = [64, 128, 256, 512]

        self.encoder1 = ConvBlock(in_channels, filters[0])
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = ConvBlock(filters[0], filters[1])
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = ConvBlock(filters[1], filters[2])
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = ConvBlock(filters[2], filters[3])
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = ConvBlock(filters[3], filters[3])

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(filters[3], 3)
        )

        self.up1 = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=2, stride=2)
        self.decoder1 = ConvBlock(filters[2] + filters[3], filters[2])
        self.up2 = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=2, stride=2)
        self.decoder2 = ConvBlock(filters[1] + filters[2], filters[1])
        self.up3 = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=2, stride=2)
        self.decoder3 = ConvBlock(filters[0] + filters[1], filters[0])
        self.up4 = nn.ConvTranspose2d(filters[0], filters[0] // 2, kernel_size=2, stride=2)
        self.decoder4 = ConvBlock(filters[0] // 2, filters[0])
        self.segmentation_head = nn.Conv2d(filters[0], 1, kernel_size=1)

        self.regressor = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(filters[3], 1),
            nn.Sigmoid()  # Predict in range [0, 1]
        )

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        b = self.bottleneck(self.pool4(e4))

        plane_logits = self.classifier(b)

        d1 = self.up1(b)
        d1 = self.decoder1(torch.cat([d1, e4], dim=1))
        d2 = self.up2(d1)
        d2 = self.decoder2(torch.cat([d2, e3], dim=1))
        d3 = self.up3(d2)
        d3 = self.decoder3(torch.cat([d3, e2], dim=1))
        d4 = self.up4(d3)
        d4 = self.decoder4(d4)
        segmentation = self.segmentation_head(d4)

        value = self.regressor(b)

        return plane_logits, segmentation, value

# =========================
# NORMALIZATION SCRIPT
# =========================
def normalize_and_save():
    df = pd.read_csv("G:/Sajal_Data/Obj_4_Code/Teacher_model_training/data/kamra_teacher_expanded.csv")
    min_vals = df.groupby("type")["value"].min()
    max_vals = df.groupby("type")["value"].max()

    def normalize(row):
        min_val = min_vals[row["type"]]
        max_val = max_vals[row["type"]]
        return (row["value"] - min_val) / (max_val - min_val)

    df["value_norm"] = df.apply(normalize, axis=1)
    df.to_csv("kamra_measurements_normalized.csv", index=False)

    min_max_df = pd.DataFrame({"min": min_vals, "max": max_vals})
    min_max_df.to_csv("biometric_min_max.csv")

# =========================
# INFERENCE DENORMALIZATION
# =========================
def denormalize(value_norm, biom_type, min_max_path="biometric_min_max.csv"):
    min_max = pd.read_csv(min_max_path, index_col=0)
    row = min_max.loc[biom_type]
    return value_norm * (row["max"] - row["min"]) + row["min"]

# =========================
# DataLoader and Training Loop
# =========================
class RandomAugment:
    def __call__(self, image, mask):
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
        return image, mask

class UltrasoundDataset(Dataset):
    def __init__(self, image_dir, mask_dir, csv_path, transform=None, augment=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.data = pd.read_csv(csv_path)
        self.transform = transform or T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = os.path.join(self.image_dir, row['image'])
        mask_path = os.path.join(self.mask_dir, row['image'].replace('.jpg', '_mask.png'))

        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')

        if self.augment:
            image, mask = self.augment(image, mask)

        image = self.transform(image)
        mask = self.transform(mask)

        label_map = {'head': 0, 'abdomen': 1, 'femur': 2}
        label = torch.tensor(label_map[row['plane'].lower()], dtype=torch.long)
        value = torch.tensor(row['value_norm'], dtype=torch.float32)

        return image, mask, label, value

# Run normalization once before training
normalize_and_save()

model = TeacherNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
cls_loss = nn.CrossEntropyLoss()
seg_loss = nn.BCEWithLogitsLoss()
reg_loss = nn.MSELoss()

train_dataset = UltrasoundDataset(
    "G:/Sajal_Data/Obj_4_Code/Teacher_model_training/data/images",
    "G:/Sajal_Data/Obj_4_Code/Teacher_model_training/data/masks",
    "kamra_measurements_normalized.csv",
    augment=RandomAugment()
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(75):
    model.train()
    total_loss = 0
    for img, mask, label, value in train_loader:
        img, mask, label, value = img.to(device), mask.to(device), label.to(device), value.to(device)

        optimizer.zero_grad()
        out_cls, out_seg, out_val = model(img)

        loss_c = cls_loss(out_cls, label)
        loss_s = seg_loss(out_seg, mask)
        loss_r = reg_loss(out_val.view(-1), value.view(-1))

        loss = loss_c + loss_s + 0.01 * loss_r
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Total={loss.item():.2f} | Cls={loss_c.item():.2f} | Seg={loss_s.item():.2f} | Reg={loss_r.item():.2f}")

# Save the model
torch.save(model.state_dict(), "teacher_model.pth")

Epoch 1: Total=1.67 | Cls=1.16 | Seg=0.51 | Reg=0.04
Epoch 2: Total=1.11 | Cls=0.72 | Seg=0.39 | Reg=0.30
Epoch 3: Total=2.03 | Cls=1.69 | Seg=0.34 | Reg=0.12
Epoch 4: Total=0.91 | Cls=0.59 | Seg=0.32 | Reg=0.29
Epoch 5: Total=2.09 | Cls=1.77 | Seg=0.31 | Reg=0.23
Epoch 6: Total=1.71 | Cls=1.42 | Seg=0.28 | Reg=0.28
Epoch 7: Total=0.79 | Cls=0.50 | Seg=0.28 | Reg=0.28
Epoch 8: Total=0.75 | Cls=0.49 | Seg=0.26 | Reg=0.25
Epoch 9: Total=2.13 | Cls=1.85 | Seg=0.28 | Reg=0.00
Epoch 10: Total=1.77 | Cls=1.53 | Seg=0.24 | Reg=0.06
Epoch 11: Total=0.71 | Cls=0.48 | Seg=0.23 | Reg=0.27
Epoch 12: Total=0.61 | Cls=0.39 | Seg=0.22 | Reg=0.09
Epoch 13: Total=2.11 | Cls=1.90 | Seg=0.21 | Reg=0.04
Epoch 14: Total=0.67 | Cls=0.47 | Seg=0.20 | Reg=0.30
Epoch 15: Total=0.60 | Cls=0.41 | Seg=0.19 | Reg=0.29
Epoch 16: Total=1.69 | Cls=1.50 | Seg=0.18 | Reg=0.02
Epoch 17: Total=1.54 | Cls=1.36 | Seg=0.18 | Reg=0.13
Epoch 18: Total=1.60 | Cls=1.43 | Seg=0.17 | Reg=0.04
Epoch 19: Total=1.53 | Cls=1.37 | Seg