# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import os
from google.colab import drive
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# --- Attention Modules ---
class SpatialAttention(nn.Module):
    """Spatial attention gate.

    Collapses the channel axis into two maps (per-pixel mean and per-pixel
    max), convolves them into a single map, and multiplies the input by the
    sigmoid of that map. The output has the same shape as the input.

    Args:
        kernel_size: convolution size; only 3 or 7 are accepted.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Half-kernel padding (7 -> 3, 3 -> 1) keeps the mask the same size as x.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True).values
        mask = self.sigmoid(self.conv1(torch.cat((channel_mean, channel_max), dim=1)))
        return mask * x

class ChannelAttention(nn.Module):
    """Channel attention gate.

    Each channel is summarised by global average pooling and by global max
    pooling. Both descriptors go through one shared 1x1-conv bottleneck
    (fc1 -> ReLU -> fc2). The input is then rescaled channel-wise by the
    sigmoid of the summed responses.

    Args:
        in_planes: number of channels of the input tensor.
        ratio: bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, clamped to at least 1. Without the clamp,
            inputs narrower than ``ratio`` would build a zero-channel
            bottleneck instead of a working one.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1   = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _mlp(self, pooled):
        """Apply the shared bottleneck to one pooled (B, C, 1, 1) descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        out = self._mlp(self.avg_pool(x)) + self._mlp(self.max_pool(x))
        return self.sigmoid(out) * x

# --- Multi-Scale Feature Extractor ---
class DilatedConvLayer(nn.Module):
    """3x3 dilated convolution followed by a ReLU.

    Padding equals the dilation rate, so a 3x3 kernel keeps the spatial size.
    """

    def __init__(self, in_channels, out_channels, dilation_rate):
        super(DilatedConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation_rate,
            dilation=dilation_rate,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        pre_activation = self.conv(x)
        return self.relu(pre_activation)

class PyramidPoolingLayer(nn.Module):
    """Concatenate the input with upsampled, adaptively pooled copies of itself.

    For each size in ``pool_sizes`` the input is average-pooled to
    ``size x size`` and bilinearly resized back to the input's height and
    width. The output stacks the input first, then each resized map, along
    the channel axis. It therefore has ``C * (len(pool_sizes) + 1)`` channels.

    ``in_channels`` is accepted but not used.
    """

    def __init__(self, in_channels, pool_sizes=(1, 2, 3, 6)):
        super(PyramidPoolingLayer, self).__init__()
        self.pools = nn.ModuleList([nn.AdaptiveAvgPool2d(s) for s in pool_sizes])

    def forward(self, x):
        _, _, height, width = x.size()
        resized = [
            F.interpolate(pool(x), size=(height, width), mode='bilinear', align_corners=False)
            for pool in self.pools
        ]
        return torch.cat([x] + resized, dim=1)

class MultiScaleFeatureExtractor(nn.Module):
    """Three-branch multi-scale feature extractor producing 256 channels.

    The input must have at least two channels:
      * channel 0 (the "original" slot) feeds branch 1 (dilation 1) and
        branch 2 (dilation 2), each 64 channels wide;
      * channel 1 (the "noisy" slot) feeds branch 3, which concatenates
        dilation-4 and dilation-8 convolutions of 32 channels each.
    In every branch, spatial attention and channel attention are summed.
    The sum is pyramid-pooled and projected back to 64 channels with a
    1x1 conv. The three 64-channel results are concatenated (192 channels)
    and fused by a 1x1 conv into 256 channels.

    ``in_channels`` is accepted but not used; the channel split is fixed.
    Each forward pass stores the 192-channel concatenation in
    ``self.intermediate_features``.
    """

    def __init__(self, in_channels):
        super(MultiScaleFeatureExtractor, self).__init__()
        width = 64
        # PyramidPoolingLayer outputs its input plus one map per pool size.
        pyramid_factor = len(PyramidPoolingLayer(1).pools) + 1

        # Branch 1: channel 0, dilation 1.
        self.dcl1 = DilatedConvLayer(1, width, dilation_rate=1)
        self.sa1 = SpatialAttention()
        self.ca1 = ChannelAttention(width)
        self.ppl1 = PyramidPoolingLayer(width)
        self.conv1_1 = nn.Conv2d(width * pyramid_factor, width, kernel_size=1)

        # Branch 2: channel 0, dilation 2.
        self.dcl2 = DilatedConvLayer(1, width, dilation_rate=2)
        self.sa2 = SpatialAttention()
        self.ca2 = ChannelAttention(width)
        self.ppl2 = PyramidPoolingLayer(width)
        self.conv2_1 = nn.Conv2d(width * pyramid_factor, width, kernel_size=1)

        # Branch 3: channel 1, dilations 4 and 8 at half width each.
        half = width // 2
        self.dcl4 = DilatedConvLayer(1, half, dilation_rate=4)
        self.dcl8 = DilatedConvLayer(1, half, dilation_rate=8)
        merged = half + half
        self.sa3 = SpatialAttention()
        self.ca3 = ChannelAttention(merged)
        self.ppl3 = PyramidPoolingLayer(merged)
        self.conv3_1 = nn.Conv2d(merged * pyramid_factor, width, kernel_size=1)

        # Fuse the three width-channel branch outputs.
        self.final_conv = nn.Conv2d(width * 3, 256, kernel_size=1)

    @staticmethod
    def _attend_pool_project(features, spatial, channel, pyramid, project):
        """Sum spatial and channel attention, pyramid-pool, then 1x1-project."""
        return project(pyramid(spatial(features) + channel(features)))

    def forward(self, combined_input):
        original_view = combined_input[:, 0:1, :, :]
        noisy_view = combined_input[:, 1:2, :, :]

        f1 = self._attend_pool_project(
            self.dcl1(original_view), self.sa1, self.ca1, self.ppl1, self.conv1_1
        )
        f2 = self._attend_pool_project(
            self.dcl2(original_view), self.sa2, self.ca2, self.ppl2, self.conv2_1
        )
        dilated = torch.cat([self.dcl4(noisy_view), self.dcl8(noisy_view)], dim=1)
        f3 = self._attend_pool_project(
            dilated, self.sa3, self.ca3, self.ppl3, self.conv3_1
        )

        self.intermediate_features = torch.cat([f1, f2, f3], dim=1)
        return self.final_conv(self.intermediate_features)

# --- UNet for Denoising ---
class DoubleConv(nn.Module):
    """Two (3x3 conv -> ReLU) stages; padding=1 keeps the spatial size.

    The first conv maps ``in_channels`` to ``out_channels`` and the second
    keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max pooling (halves H and W) followed by DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample ``x1`` 2x, pad it to ``x2``'s size, concat, DoubleConv.

    Args:
        in_channels: channel count of ``cat([x2, upsampled x1])``, i.e. the
            input width of the DoubleConv.
        out_channels: output width of the DoubleConv.
        bilinear: True -> parameter-free bilinear upsampling
            (align_corners=True); False -> learned 2x2 stride-2 transposed conv.
            NOTE(review): the transposed-conv path assumes ``x1`` carries
            ``in_channels // 2`` channels (an equal-width skip); confirm
            before using it with asymmetric skip widths.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            half = in_channels // 2
            self.up = nn.ConvTranspose2d(half, half, kernel_size=2, stride=2)
        # Both modes use the same DoubleConv. It is created after ``up`` so
        # parameter initialisation follows registration order.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        top, left = dy // 2, dx // 2
        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """Output head: a 1x1 convolution from feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class DenoisingUNet(nn.Module):
    """Four-level UNet that denoises an image conditioned on extra feature maps.

    ``noisy_input`` and ``features`` are concatenated along channels, so the
    first conv sees ``n_channels + feature_channels`` channels. The encoder
    widths are 64-128-256-512-512, the decoder mirrors them with skip
    connections, and a 1x1 head emits ``n_classes`` channels. No output
    activation is applied.
    """

    def __init__(self, n_channels, n_classes, feature_channels):
        super().__init__()
        # Total input width of the first conv (image + conditioning features).
        self.n_channels = n_channels + feature_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(self.n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)  # 1024 // 2: bottleneck kept at 512
        self.up1 = Up(512 + 512, 512)
        self.up2 = Up(512 + 256, 256)
        self.up3 = Up(256 + 128, 128)
        self.up4 = Up(128 + 64, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, noisy_input, features):
        skips = [self.inc(torch.cat([noisy_input, features], dim=1))]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Deepest map starts the decoder; the remaining maps are consumed
        # from deepest to shallowest as skip connections.
        decoded = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            decoded = up(decoded, skips.pop())
        return self.outc(decoded)

# --- Dataset Class ---
class UltrasoundDataset(Dataset):
    """Paired (noisy, clean) grayscale images loaded from two folders.

    Image files (``.png``/``.jpg``/``.jpeg``/``.bmp``, case-insensitive) in
    each folder are sorted by filename and paired by position: the i-th clean
    file goes with the i-th noisy file. Filenames are not compared, so both
    folders must sort into matching order.

    Args:
        original_dir: folder with the clean reference images.
        noisy_dir: folder with the corresponding noisy images.
        transform: optional callable applied to each PIL image.
            NOTE: it is called separately on the two images, so a random
            transform would not be applied identically to both.

    Raises:
        ValueError: if the folders contain different numbers of images.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, original_dir, noisy_dir, transform=None):
        self.original_dir = original_dir
        self.noisy_dir = noisy_dir
        self.transform = transform
        self.original_files = self._list_images(original_dir)
        self.noisy_files = self._list_images(noisy_dir)

        if len(self.original_files) != len(self.noisy_files):
            raise ValueError(f"Number of original and noisy images does not match in: {original_dir} vs {noisy_dir}")

        self.pairs = [
            (os.path.join(original_dir, original_name), os.path.join(noisy_dir, noisy_name))
            for original_name, noisy_name in zip(self.original_files, self.noisy_files)
        ]

        print(f"Found {len(self.pairs)} images in: {original_dir}")

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image filenames (not full paths) in *directory*."""
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(cls._IMAGE_EXTENSIONS))

    @staticmethod
    def _load_gray(path):
        """Open *path*, convert it to grayscale ('L' mode), and close the file."""
        # convert() returns a new, fully loaded image, so the result stays
        # valid after the context manager closes the source file.
        with Image.open(path) as img:
            return img.convert('L')

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` for pair *idx*, transformed if configured."""
        original_path, noisy_path = self.pairs[idx]
        original_img = self._load_gray(original_path)
        noisy_img = self._load_gray(noisy_path)

        if self.transform:
            original_img = self.transform(original_img)
            noisy_img = self.transform(noisy_img)

        return noisy_img, original_img

def _laplacian(img):
    """Discrete Laplacian of an N-D array, with edge-replicated borders."""
    img = np.asarray(img, dtype=np.float64)
    padded = np.pad(img, 1, mode='edge')
    interior = [slice(1, -1)] * img.ndim
    result = -2.0 * img.ndim * img
    for axis in range(img.ndim):
        # Previous and next neighbour along this axis.
        for neighbour in (slice(None, -2), slice(2, None)):
            window = list(interior)
            window[axis] = neighbour
            result = result + padded[tuple(window)]
    return result


def calculate_epi(denoised, original):
    """Edge Preservation Index between a denoised image and its reference.

    Both images are high-pass filtered with a discrete Laplacian, giving
    dD and dO. Each is mean-centred. The index is their normalised
    correlation:

        EPI = sum(dO_c * dD_c) / sqrt(sum(dO_c**2) * sum(dD_c**2))

    1.0 means the edge structure is fully preserved (up to a positive gain
    and offset). Lower values mean edges were smoothed or distorted.
    NOTE(review): this is the correlation-based EPI used in the
    speckle-reduction literature (after Sattar et al.); confirm it is the
    definition you intend to report.

    Args:
        denoised: denoised image, array-like of any rank.
        original: reference image with the same shape as ``denoised``.

    Returns:
        The EPI as a float, or ``nan`` when either image has no Laplacian
        variation (e.g. a constant image), where the correlation is undefined.
    """
    denoised_edges = _laplacian(denoised)
    original_edges = _laplacian(original)
    denoised_edges = denoised_edges - denoised_edges.mean()
    original_edges = original_edges - original_edges.mean()
    denom = np.sqrt(np.sum(original_edges ** 2) * np.sum(denoised_edges ** 2))
    if denom == 0:
        return float('nan')
    return float(np.sum(original_edges * denoised_edges) / denom)

def train_model(model, feature_extractor, train_loader, val_loader, optimizer, criterion, num_epochs, device):
    """Train the denoiser and report validation loss/PSNR/SSIM after each epoch.

    Features are computed by ``feature_extractor`` under ``torch.no_grad()``,
    so it receives no gradients here. Only the parameters held by
    ``optimizer`` are updated.

    Both feature-extractor input slots are filled with the *noisy* image.
    The clean image is the regression target. Feeding it to the network
    would leak the answer into the input, inflate the metrics, and leave a
    model that cannot run on real data, where no clean image exists.

    Args:
        model: denoiser called as ``model(noisy, features)``.
        feature_extractor: module mapping a 2-channel batch to features.
        train_loader, val_loader: yield ``(noisy, clean)`` batches, assumed
            to be ``(B, 1, H, W)`` with values in [0, 1] (the PSNR/SSIM calls
            use ``data_range=1.0``) — confirm against the dataset transform.
        optimizer: optimiser stepping the model's parameters.
        criterion: loss called as ``criterion(prediction, target)``.
        num_epochs: number of passes over ``train_loader``.
        device: device the batches are moved to.

    Returns:
        ``model``, the same object that was passed in, after training.
    """
    # Inference mode for the extractor; its features are treated as fixed inputs.
    feature_extractor.eval()
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for noisy_batch, clean_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)"):
            noisy_batch = noisy_batch.to(device)
            clean_batch = clean_batch.to(device)

            with torch.no_grad():
                # Noisy image in both slots: the clean target must never be an input.
                combined_batch = torch.cat([noisy_batch, noisy_batch], dim=1)
                features_batch = feature_extractor(combined_batch)

            optimizer.zero_grad()
            denoised_batch = model(noisy_batch, features_batch)
            loss = criterion(denoised_batch, clean_batch)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean
            # (assumes criterion averages over the batch, as MSELoss does).
            train_loss += loss.item() * noisy_batch.size(0)
        train_loss /= len(train_loader.dataset)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        val_psnr = []
        val_ssim = []
        with torch.no_grad():
            for noisy_batch, clean_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Val)"):
                noisy_batch = noisy_batch.to(device)
                clean_batch = clean_batch.to(device)
                combined_batch = torch.cat([noisy_batch, noisy_batch], dim=1)
                features_batch = feature_extractor(combined_batch)
                denoised_batch = model(noisy_batch, features_batch)
                loss = criterion(denoised_batch, clean_batch)
                val_loss += loss.item() * noisy_batch.size(0)

                # squeeze(1) drops only the channel axis. A bare squeeze() also
                # drops the batch axis for a 1-image batch, and the loop below
                # would then walk image rows instead of images. Outputs are
                # clamped to the [0, 1] range that data_range=1.0 assumes.
                denoised_np = denoised_batch.clamp(0.0, 1.0).squeeze(1).cpu().numpy()
                clean_np = clean_batch.squeeze(1).cpu().numpy()
                for clean_img, denoised_img in zip(clean_np, denoised_np):
                    val_psnr.append(peak_signal_noise_ratio(clean_img, denoised_img, data_range=1.0))
                    val_ssim.append(structural_similarity(clean_img, denoised_img, data_range=1.0))

        val_loss /= len(val_loader.dataset)
        avg_val_psnr = np.mean(val_psnr)
        avg_val_ssim = np.mean(val_ssim)
        print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}, Val PSNR: {avg_val_psnr:.4f} dB, Val SSIM: {avg_val_ssim:.4f}")

    return model

def test_model(model, feature_extractor, test_loader, device):
    """Evaluate the denoiser on ``test_loader`` and print mean PSNR/SSIM/EPI.

    As at inference time, only the noisy image is available, so it fills
    both feature-extractor input slots. The clean image is used solely as
    the metric reference.

    Args:
        model: denoiser called as ``model(noisy, features)``.
        feature_extractor: module mapping a 2-channel batch to features.
        test_loader: yields ``(noisy, clean)`` batches, assumed to be
            ``(B, 1, H, W)`` in [0, 1] (metrics use ``data_range=1.0``).
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean PSNR, mean SSIM, mean EPI)`` over all test images.
    """
    model.eval()
    feature_extractor.eval()
    test_psnr = []
    test_ssim = []
    test_epi = []
    with torch.no_grad():
        for noisy_batch, clean_batch in tqdm(test_loader, desc="Testing"):
            noisy_batch = noisy_batch.to(device)
            clean_batch = clean_batch.to(device)
            # Noisy image in both slots: the clean target must never be an input.
            combined_batch = torch.cat([noisy_batch, noisy_batch], dim=1)
            features_batch = feature_extractor(combined_batch)
            denoised_batch = model(noisy_batch, features_batch)

            # squeeze(1) keeps the batch axis even when the batch has one image.
            # With a bare squeeze() and batch_size=1, every metric would be
            # computed per image row. Clamp to the [0, 1] range data_range assumes.
            denoised_np = denoised_batch.clamp(0.0, 1.0).squeeze(1).cpu().numpy()
            clean_np = clean_batch.squeeze(1).cpu().numpy()

            for clean_img, denoised_img in zip(clean_np, denoised_np):
                test_psnr.append(peak_signal_noise_ratio(clean_img, denoised_img, data_range=1.0))
                test_ssim.append(structural_similarity(clean_img, denoised_img, data_range=1.0))
                test_epi.append(calculate_epi(denoised_img, clean_img))

    avg_test_psnr = np.mean(test_psnr)
    avg_test_ssim = np.mean(test_ssim)
    avg_test_epi = np.mean(test_epi)
    print(f"Test PSNR: {avg_test_psnr:.4f} dB, Test SSIM: {avg_test_ssim:.4f}, Test EPI: {avg_test_epi:.4f}")
    return avg_test_psnr, avg_test_ssim, avg_test_epi

if __name__ == '__main__':
    # Mount Google Drive. google.colab is only importable inside a Colab runtime.
    drive.mount('/content/drive')

    # Define data directories
    train_original_dir = '/content/drive/MyDrive/US_Speckle_dir/train_original_US' # Replace with your paths
    train_noisy_dir = '/content/drive/MyDrive/US_Speckle_dir/train_noisy_US'
    val_original_dir = '/content/drive/MyDrive/US_Speckle_dir/val_original_US'
    val_noisy_dir = '/content/drive/MyDrive/US_Speckle_dir/val_noisy_US'
    test_original_dir = '/content/drive/MyDrive/US_Speckle_dir/test_original_US'
    test_noisy_dir = '/content/drive/MyDrive/US_Speckle_dir/test_noisy_US'

    # Hyperparameters
    batch_size = 4
    learning_rate = 0.001
    num_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Transformations
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Create Datasets
    train_dataset = UltrasoundDataset(train_original_dir, train_noisy_dir, transform=transform)
    val_dataset = UltrasoundDataset(val_original_dir, val_noisy_dir, transform=transform)
    test_dataset = UltrasoundDataset(test_original_dir, test_noisy_dir, transform=transform)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=1) # One image per batch for per-image evaluation

    # Initialize Feature Extractor and UNet.
    # NOTE: the extractor keeps its random initialisation. The optimizer
    # below only sees unet_model's parameters.
    feature_extractor = MultiScaleFeatureExtractor(2).to(device)
    # Take the width from the extractor's fusion conv instead of hard-coding
    # it, so the UNet input width cannot drift out of sync.
    feature_channels = feature_extractor.final_conv.out_channels
    unet_model = DenoisingUNet(n_channels=1, n_classes=1, feature_channels=feature_channels).to(device)

    # Loss and Optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(unet_model.parameters(), lr=learning_rate)

    # Train the model
    trained_unet = train_model(unet_model, feature_extractor, train_loader, val_loader, optimizer, criterion, num_epochs, device)

    # Test the model
    avg_test_psnr, avg_test_ssim, avg_test_epi = test_model(trained_unet, feature_extractor, test_loader, device)
    print(f"Average Test PSNR: {avg_test_psnr:.4f} dB")
    print(f"Average Test SSIM: {avg_test_ssim:.4f}")
    print(f"Average Test EPI: {avg_test_epi:.4f}")

    # You can add code here to visualize test results if needed

# Last recorded cell output: ValueError: mount failed