In [None]:
# Only run this cell if you haven't configured Kaggle API in this Colab session before
import os
from google.colab import files

def configure_kaggle_api_colab_local():
    """Upload ``kaggle.json`` through the Colab file picker and store it in the current directory.

    The function prompts for an upload, writes the bytes received under the
    key ``'kaggle.json'`` to ``./kaggle.json`` and then restricts the file to
    owner read/write (``0o600``).

    Writing and permission-setting are handled separately so that a failed
    write is never reported as a permissions problem, and a failed ``chmod``
    (the file itself is already on disk) only produces a warning.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    print("📁 Please upload your kaggle.json file:")
    uploaded = files.upload()
    if 'kaggle.json' not in uploaded:
        print("❌ kaggle.json not uploaded.")
        return

    kaggle_json_path = "kaggle.json"

    # Step 1: persist the credentials. Without the file nothing else can work.
    try:
        with open(kaggle_json_path, "wb") as f:
            f.write(uploaded['kaggle.json'])
    except OSError as e:
        print(f"⚠️ Could not write {kaggle_json_path}.")
        print(f"Error: {e}")
        return
    except Exception as e:
        print(f"⚠️ An unexpected error occurred while saving kaggle.json.")
        print(f"Error: {e}")
        return

    # Step 2: tighten permissions. Failure here is non-fatal: the file exists.
    try:
        os.chmod(kaggle_json_path, 0o600)
    except OSError as e:
        print(f"⚠️ Warning: Could not set file permissions for {kaggle_json_path}.")
        print(f"Error: {e}")

    print(f"✅ kaggle.json saved to {kaggle_json_path}")
    print("🚀 Kaggle API is now configured and ready to use!")

# Prompt for kaggle.json and save it in the current working directory.
configure_kaggle_api_colab_local()

# Point KAGGLE_CONFIG_DIR at the current directory, which is where the helper
# above writes kaggle.json. The value is a relative path, so it is resolved
# against whatever the working directory is when it is read.
os.environ['KAGGLE_CONFIG_DIR'] = '.'

📁 Please upload your kaggle.json file:


Saving kaggle.json to kaggle.json
✅ kaggle.json saved to kaggle.json
🚀 Kaggle API is now configured and ready to use!


In [None]:
# Kaggle dataset slug, the archive name the download below produces for it,
# and the directory the archive is unpacked into.
dataset_id = "piyushsamant11/pidata-new-names"
zip_file_name = "pidata-new-names.zip"
extract_to = "dataset"

# Shell command: fetch the dataset archive into the current directory.
# {dataset_id} is interpolated from the Python variable above.
!kaggle datasets download -d {dataset_id} -p ./
import zipfile
import os
# Create the target directory if needed, then unpack the whole archive into it.
os.makedirs(extract_to, exist_ok=True)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
print(f"✅ Dataset extracted to: {extract_to}")
!rm {zip_file_name} # Clean up the zip file

Dataset URL: https://www.kaggle.com/datasets/piyushsamant11/pidata-new-names
License(s): unknown
Downloading pidata-new-names.zip to .
 90% 576M/639M [00:03<00:01, 65.8MB/s]
100% 639M/639M [00:03<00:00, 192MB/s] 
✅ Dataset extracted to: dataset


In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split
import torchvision.transforms as T
import random
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

def load_and_preprocess_ct_scan(file_path):
    """Read an image file as 8-bit grayscale and rescale it to float32 in [0, 1].

    Args:
        file_path: Path of the image to open.

    Returns:
        A 2-D ``float32`` array, or ``None`` if opening or converting the file
        fails (the error is printed, not raised).
    """
    try:
        grayscale = Image.open(file_path).convert('L')
        pixels = np.array(grayscale, dtype=np.float32)
        return pixels / 255.0
    except Exception as e:
        print(f"Error loading or preprocessing {file_path}: {e}")
        return None

def load_mask_image(file_path):
    """Read a mask image as grayscale and binarise it.

    Every pixel with a non-zero gray value becomes ``True``; zero pixels
    become ``False``.

    Args:
        file_path: Path of the mask image to open.

    Returns:
        A 2-D boolean array, or ``None`` if the file cannot be read (the error
        is printed, not raised).
    """
    try:
        gray_mask = Image.open(file_path).convert('L')
        # Any non-zero gray level counts as foreground.
        return np.array(gray_mask) != 0
    except Exception as e:
        print(f"Error loading mask {file_path}: {e}")
        return None

def extract_patch(image, mask, bbox, patch_size=128):
    """Cut a square patch centred on a bounding box out of an image and its mask.

    The patch is centred on the integer midpoint of ``bbox``. Parts of the
    patch window that fall outside the image are filled with zeros (``False``
    for boolean masks), so the result is ``patch_size`` x ``patch_size`` when
    the window overlaps the image.

    Args:
        image: 2-D array indexed as ``[row, column]``.
        mask: 2-D array with the same layout as ``image``.
        bbox: ``(x_min, y_min, x_max, y_max)`` where ``x`` is the column axis
            and ``y`` is the row axis. Values are truncated to ``int``.
        patch_size: Side length of the square patch in pixels.

    Returns:
        ``(image_patch, mask_patch)``, both zero-padded to the patch size.
    """
    x_min, y_min, x_max, y_max = map(int, bbox)
    center_x = (x_min + x_max) // 2
    center_y = (y_min + y_max) // 2
    half_size = patch_size // 2

    # Derive the far edge from the near edge so the window is exactly
    # patch_size wide even when patch_size is odd (center + half_size would
    # lose one pixel in that case).
    x1 = center_x - half_size
    y1 = center_y - half_size
    x2 = x1 + patch_size
    y2 = y1 + patch_size

    # Amount by which the window sticks out on each side of the image.
    pad_left = max(0, -x1)
    pad_top = max(0, -y1)
    pad_right = max(0, x2 - image.shape[1])
    pad_bottom = max(0, y2 - image.shape[0])

    # Clamp the window to the image before slicing.
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(image.shape[1], x2)
    y2 = min(image.shape[0], y2)

    image_patch = image[y1:y2, x1:x2]
    mask_patch = mask[y1:y2, x1:x2]

    padding = ((pad_top, pad_bottom), (pad_left, pad_right))
    padded_image_patch = np.pad(image_patch, padding, mode='constant')
    padded_mask_patch = np.pad(mask_patch, padding, mode='constant')

    return padded_image_patch, padded_mask_patch

class ContrastiveLungNoduleDataset(Dataset):
    """Triplet dataset of image patches for contrastive training.

    Each item is a tuple ``(anchor, positive, negative)``:

    * anchor   - patch centred on a randomly chosen connected component of the mask;
    * positive - randomly augmented copy of the anchor patch;
    * negative - either a patch containing no mask pixels, or a patch centred
      on a different connected component of the same image.

    Every patch is a float tensor normalised with mean 0.5 / std 0.5, so the
    anchor, positive and negative share one value range. With a count of 1 the
    positive/negative has shape ``(1, P, P)``; with a count ``n > 1`` it has
    shape ``(n, 1, P, P)``.

    ``__getitem__`` returns ``None`` when the image or mask cannot be loaded.
    Images whose mask has no connected component yield all-zero patches.

    Args:
        image_files: Sequence of image paths.
        mask_dir: Directory holding ``<image stem>.png`` masks.
        patch_size: Side length of the square patches.
        num_positive_pairs: Number of augmented positives per item.
        num_negative_samples: Number of negatives per item.
        transform: Optional callable applied to every ``(1, P, P)`` float
            tensor after normalisation. It must accept tensors.
    """

    # Normalisation applied to every patch (anchor, positive and negative).
    _NORM_MEAN = 0.5
    _NORM_STD = 0.5
    # How many random locations are tried when looking for a mask-free patch.
    _MAX_HEALTHY_ATTEMPTS = 10

    def __init__(self, image_files, mask_dir, patch_size=128, num_positive_pairs=1, num_negative_samples=1, transform=None):
        self.image_files = image_files
        self.mask_dir = mask_dir
        self.patch_size = patch_size
        self.num_positive_pairs = num_positive_pairs
        self.num_negative_samples = num_negative_samples
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _bbox_rowcol_to_xy(bbox):
        """Convert a regionprops bbox (min_row, min_col, max_row, max_col) to (x_min, y_min, x_max, y_max)."""
        min_row, min_col, max_row, max_col = bbox
        return (min_col, min_row, max_col, max_row)

    def _to_tensor(self, patch):
        """Turn a 2-D array in [0, 1] into a normalised (1, P, P) float tensor."""
        tensor = torch.from_numpy(np.ascontiguousarray(patch, dtype=np.float32)).unsqueeze(0)
        return (tensor - self._NORM_MEAN) / self._NORM_STD

    def _apply_transform(self, tensor):
        """Apply the optional user transform to one patch tensor."""
        return self.transform(tensor) if self.transform else tensor

    @staticmethod
    def _stack(tensors):
        """Return the single tensor as-is, or stack several along a new leading axis."""
        return tensors[0] if len(tensors) == 1 else torch.stack(tensors)

    def _crop(self, image, top, left):
        """Crop a patch at (top, left), zero-padding bottom/right if the image is too small."""
        size = self.patch_size
        patch = image[top:top + size, left:left + size]
        pad_rows = size - patch.shape[0]
        pad_cols = size - patch.shape[1]
        if pad_rows > 0 or pad_cols > 0:
            patch = np.pad(patch, ((0, pad_rows), (0, pad_cols)), mode='constant')
        return patch

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = load_and_preprocess_ct_scan(img_path)
        mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
        mask = load_mask_image(os.path.join(self.mask_dir, mask_name))

        if image is None or mask is None:
            return None

        # regionprops reports boxes in (row, col) order; extract_patch expects
        # (x, y) order, so convert once here and use the converted boxes throughout.
        regions = regionprops(label(mask))
        bboxes = [self._bbox_rowcol_to_xy(region.bbox) for region in regions]

        if not bboxes:
            # No component to anchor on: emit placeholder patches with the
            # same shapes a regular item would have.
            dummy = torch.zeros((1, self.patch_size, self.patch_size), dtype=torch.float32)
            return (
                dummy,
                self._stack([dummy] * self.num_positive_pairs),
                self._stack([dummy] * self.num_negative_samples),
            )

        anchor_bbox = random.choice(bboxes)
        anchor_image_patch, _ = extract_patch(image, mask, anchor_bbox, self.patch_size)
        anchor_tensor = self._apply_transform(self._to_tensor(anchor_image_patch))

        positive_patches = [
            self._apply_transform(self._augment(anchor_image_patch))
            for _ in range(self.num_positive_pairs)
        ]

        negative_patches = [
            self._apply_transform(
                self._to_tensor(
                    self._create_negative_patch(image, mask, bboxes, anchor_bbox=anchor_bbox)
                )
            )
            for _ in range(self.num_negative_samples)
        ]

        return anchor_tensor, self._stack(positive_patches), self._stack(negative_patches)

    def _create_negative_patch(self, image, mask, bboxes, anchor_bbox=None, healthy_ratio=0.5):
        """Return a 2-D negative patch for the given image.

        With probability ``healthy_ratio`` (or always, when fewer than two
        boxes exist) a mask-free location is searched for; otherwise the patch
        is centred on a box other than ``anchor_bbox``.
        """
        use_healthy = random.random() < healthy_ratio or not bboxes or len(bboxes) < 2
        if not use_healthy:
            other_bboxes = [bbox for bbox in bboxes if bbox != anchor_bbox]
            if other_bboxes:
                other_bbox = random.choice(other_bboxes)
                negative_image_patch, _ = extract_patch(image, mask, other_bbox, self.patch_size)
                return negative_image_patch
            # Every box equals the anchor: fall through to a mask-free patch.

        # Clamp the sampling range so images smaller than the patch do not
        # make random.randint raise; _crop pads such patches to full size.
        max_y = max(0, image.shape[0] - self.patch_size)
        max_x = max(0, image.shape[1] - self.patch_size)
        for _ in range(self._MAX_HEALTHY_ATTEMPTS):
            rand_y = random.randint(0, max_y)
            rand_x = random.randint(0, max_x)
            window = mask[rand_y:rand_y + self.patch_size, rand_x:rand_x + self.patch_size]
            if not np.any(window):
                return self._crop(image, rand_y, rand_x)

        # No mask-free window found: fall back to an arbitrary location.
        rand_y = random.randint(0, max_y)
        rand_x = random.randint(0, max_x)
        return self._crop(image, rand_y, rand_x)

    def _augment(self, image_patch):
        """Return a randomly augmented, normalised (1, P, P) tensor of a 2-D patch in [0, 1]."""
        transform = T.Compose([
            T.RandomRotation(degrees=10),
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            T.RandomAdjustSharpness(sharpness_factor=1.5, p=0.5),
            T.RandomEqualize(p=0.5),
            T.ToTensor(),
            T.Normalize(mean=[self._NORM_MEAN], std=[self._NORM_STD])
        ])
        # The PIL-based augmentations operate on 8-bit images.
        image_pil = Image.fromarray((image_patch * 255).astype(np.uint8))
        return transform(image_pil)

# Dataset layout: images and their annotation masks live in sibling folders.
DATA_DIR = "dataset/Dataset"
IMAGE_DIR = os.path.join(DATA_DIR, "Images")
MASK_DIR = os.path.join(DATA_DIR, "Annotations")

# os.listdir returns entries in arbitrary order. Sorting makes the seeded
# train/validation split below reproducible across runs and machines.
# Hidden files (names starting with '.') are skipped.
all_image_files = sorted(
    os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if not f.startswith('.')
)
train_files, val_files = train_test_split(all_image_files, test_size=0.2, random_state=42)

# Augmentation pipeline passed as `transform` to the training dataset below.
# NOTE(review): ContrastiveLungNoduleDataset calls `transform` on torch tensors
# (built with torch.from_numpy), while this pipeline ends in T.ToTensor(), which
# presumably expects a PIL image or ndarray, and includes T.RandomEqualize,
# which presumably needs 8-bit input. Confirm that iterating a dataset built
# with this transform does not raise.
train_transform = T.Compose([
    T.RandomRotation(degrees=15),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    T.RandomAdjustSharpness(sharpness_factor=1.2, p=0.5),
    T.RandomEqualize(p=0.3),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

# Validation pipeline: no random augmentation, only conversion and normalisation.
# NOTE(review): the same tensor-input concern as for train_transform applies.
val_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

if os.path.exists(IMAGE_DIR) and os.path.exists(MASK_DIR):
    # Triplet datasets for the two splits; only the transform differs.
    contrastive_train_dataset = ContrastiveLungNoduleDataset(
        train_files,
        MASK_DIR,
        patch_size=128,
        num_positive_pairs=1,
        num_negative_samples=1,
        transform=train_transform,
    )
    contrastive_val_dataset = ContrastiveLungNoduleDataset(
        val_files,
        MASK_DIR,
        patch_size=128,
        num_positive_pairs=1,
        num_negative_samples=1,
        transform=val_transform,
    )

    def collate_fn(batch):
        """Batch (anchor, positive, negative) triplets, skipping samples that are None.

        Returns a tuple of three stacked tensors, or None when every sample in
        the batch was None.
        """
        usable = [sample for sample in batch if sample is not None]
        if not usable:
            return None
        # Transpose the list of triplets into three columns and stack each one.
        anchors, positives, negatives = (torch.stack(column) for column in zip(*usable))
        return anchors, positives, negatives

    train_dataloader = DataLoader(
        contrastive_train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
        drop_last=True,
    )
    val_dataloader = DataLoader(
        contrastive_val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=2,
        collate_fn=collate_fn,
        drop_last=False,
    )

    print("\nContrastive Train dataloader created.")
    print("Contrastive Validation dataloader created.")

else:
    print("Make sure the IMAGE_DIR and MASK_DIR paths are correct and exist.")


Contrastive Train dataloader created.
Contrastive Validation dataloader created.


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights

class VGG16Encoder(nn.Module):
    """VGG16 convolutional trunk adapted to single-channel input, followed by a linear embedding head.

    The trunk is ``vgg16.features`` without its last child module. The first
    convolution is replaced by a one-input-channel convolution with the same
    output channels, kernel size, stride and padding.

    Args:
        embedding_dim: Size of the output embedding vector.
        pretrained: If true, load the ``IMAGENET1K_V1`` weights and initialise
            the single-channel first convolution from them.
        input_size: Spatial size of the inputs the linear head is built for,
            either an int (square) or a ``(height, width)`` pair. Defaults to
            128, the size that was previously hard-coded.
    """

    def __init__(self, embedding_dim, pretrained=True, input_size=128):
        super().__init__()
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg16 = models.vgg16(weights=weights)
        self.features = nn.Sequential(*list(vgg16.features.children())[:-1])

        rgb_conv = self.features[0]
        gray_conv = nn.Conv2d(1, rgb_conv.out_channels,
                              kernel_size=rgb_conv.kernel_size,
                              stride=rgb_conv.stride,
                              padding=rgb_conv.padding)
        if pretrained:
            # Keep the pretrained first-layer filters instead of discarding
            # them: summing the RGB kernels gives the response the original
            # layer would produce for a gray image replicated on all channels.
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
                if rgb_conv.bias is not None:
                    gray_conv.bias.copy_(rgb_conv.bias)
        self.features[0] = gray_conv

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # Probe the trunk once to learn how many features the head receives.
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, height, width)
            output_features = self.features(dummy_input)
            self.flattened_size = output_features.view(output_features.size(0), -1).shape[1]

        self.embedding_dim = embedding_dim
        self.fc = nn.Linear(self.flattened_size, self.embedding_dim)

    def forward(self, x):
        """Map a batch of ``(N, 1, H, W)`` inputs to ``(N, embedding_dim)`` embeddings.

        ``H`` and ``W`` must match the ``input_size`` the head was built for.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# Smoke test: build the encoder and check the embedding shape for a random batch.
patch_size = 128
embedding_dim = 128
encoder = VGG16Encoder(embedding_dim, pretrained=True)
encoder = encoder.to('cuda' if torch.cuda.is_available() else 'cpu')

# Two random single-channel patches, moved to the same device as the encoder.
dummy_patch = torch.randn(2, 1, patch_size, patch_size)
dummy_patch = dummy_patch.to('cuda' if torch.cuda.is_available() else 'cpu')
embeddings = encoder(dummy_patch)
print("Shape of embeddings (after correction):", embeddings.shape)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:06<00:00, 85.7MB/s]


Shape of embeddings (after correction): torch.Size([2, 128])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Hinge-style triplet loss on pairwise distances.

    For every triplet the loss is ``max(0, d(anchor, positive) - d(anchor, negative) + margin)``,
    where ``d`` is ``torch.nn.functional.pairwise_distance``; the per-triplet
    values are averaged over the batch.

    Args:
        margin: Required gap between the negative and positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss for batches of embeddings."""
        positive_gap = F.pairwise_distance(anchor, positive)
        negative_gap = F.pairwise_distance(anchor, negative)
        hinge = F.relu(positive_gap - negative_gap + self.margin)
        return hinge.mean()

# Example usage (for checking)
# Example usage (for checking)
# Runs one optimisation step on random tensors to confirm that the encoder,
# the loss and the optimiser work together. Note that this rebinds the
# notebook-level names `embedding_dim`, `encoder`, `triplet_loss_fn`,
# `optimizer`, `batch_size`, `anchors`, `positives`, `negatives` and `loss`.
if __name__ == '__main__':
    embedding_dim = 128
    encoder = VGG16Encoder(embedding_dim, pretrained=True).to('cuda' if torch.cuda.is_available() else 'cpu')
    triplet_loss_fn = TripletLoss(margin=1.0)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)

    # Random stand-ins for a batch of 128x128 single-channel patches.
    batch_size = 32
    anchors = torch.randn(batch_size, 1, 128, 128).to('cuda' if torch.cuda.is_available() else 'cpu')
    positives = torch.randn(batch_size, 1, 128, 128).to('cuda' if torch.cuda.is_available() else 'cpu')
    negatives = torch.randn(batch_size, 1, 128, 128).to('cuda' if torch.cuda.is_available() else 'cpu')

    # One forward/backward/update cycle.
    optimizer.zero_grad()
    anchor_embeddings = encoder(anchors)
    positive_embeddings = encoder(positives)
    negative_embeddings = encoder(negatives)

    loss = triplet_loss_fn(anchor_embeddings, positive_embeddings, negative_embeddings)
    loss.backward()
    optimizer.step()

    print("Triplet Loss:", loss.item())

Triplet Loss: 1.0108985900878906


In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
import torch.nn as nn

# --- Data Loading (re-define to ensure it's run in this context) ---
DATA_DIR = "dataset/Dataset"
IMAGE_DIR = os.path.join(DATA_DIR, "Images")
MASK_DIR = os.path.join(DATA_DIR, "Annotations")

# os.listdir order is arbitrary; sort so the seeded split is reproducible.
all_image_files = sorted(
    os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if not f.startswith('.')
)
train_files, val_files = train_test_split(all_image_files, test_size=0.2, random_state=42)

# No external transform here: the dataset produces the triplets on its own.
train_dataset = ContrastiveLungNoduleDataset(train_files, MASK_DIR, patch_size=128, num_positive_pairs=1, num_negative_samples=1)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_fn, drop_last=True)

val_dataset = ContrastiveLungNoduleDataset(val_files, MASK_DIR, patch_size=128, num_positive_pairs=1, num_negative_samples=1)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=collate_fn, drop_last=False)

# --- Model and Loss ---
# Resolve the device once instead of re-evaluating the expression per tensor.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_dim = 256
encoder = VGG16Encoder(embedding_dim, pretrained=True).to(device)
triplet_loss_fn = nn.TripletMarginLoss(margin=0.5).to(device)
optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
# The learning rate is multiplied by 0.1 every 5 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

num_epochs = 20

for epoch in range(num_epochs):
    # ---- training ----
    encoder.train()
    total_train_loss = 0.0
    train_samples_seen = 0
    train_loader_tqdm = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", unit="batch")
    for batch in train_loader_tqdm:
        if batch is None:
            # collate_fn returns None when every sample of the batch failed to load.
            continue
        anchors, positives, negatives = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        anchor_embeddings = encoder(anchors)
        positive_embeddings = encoder(positives)
        negative_embeddings = encoder(negatives)

        loss = triplet_loss_fn(anchor_embeddings, positive_embeddings, negative_embeddings)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item() * anchors.size(0)
        train_samples_seen += anchors.size(0)
        train_loader_tqdm.set_postfix(loss=f"{loss.item():.4f}")

    # Average over the samples actually processed: with drop_last=True and
    # skipped batches this is smaller than len(dataset).
    avg_train_loss = total_train_loss / max(train_samples_seen, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Training Loss: {avg_train_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
    scheduler.step()

    # ---- validation ----
    encoder.eval()
    total_val_loss = 0.0
    val_samples_seen = 0
    with torch.no_grad():
        val_loader_tqdm = tqdm(val_dataloader, desc=f"Validation Epoch [{epoch+1}/{num_epochs}]", unit="batch")
        for batch in val_loader_tqdm:
            if batch is None:
                continue
            anchors, positives, negatives = (tensor.to(device) for tensor in batch)

            anchor_embeddings = encoder(anchors)
            positive_embeddings = encoder(positives)
            negative_embeddings = encoder(negatives)

            loss = triplet_loss_fn(anchor_embeddings, positive_embeddings, negative_embeddings)
            total_val_loss += loss.item() * anchors.size(0)
            val_samples_seen += anchors.size(0)
            val_loader_tqdm.set_postfix(loss=f"{loss.item():.4f}")

    avg_val_loss = total_val_loss / max(val_samples_seen, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Validation Loss: {avg_val_loss:.4f}")

torch.save(encoder.state_dict(), 'contrastive_vgg16_encoder.pth')
print("Trained contrastive encoder saved!")

Epoch [1/20]: 100%|██████████| 63/63 [00:56<00:00,  1.12batch/s, loss=0.1122]


Epoch [1/20], Average Training Loss: 0.1957, LR: 0.000100


Validation Epoch [1/20]: 100%|██████████| 16/16 [00:10<00:00,  1.59batch/s, loss=0.0334]


Epoch [1/20], Average Validation Loss: 0.0807


Epoch [2/20]: 100%|██████████| 63/63 [00:57<00:00,  1.09batch/s, loss=0.1094]


Epoch [2/20], Average Training Loss: 0.0680, LR: 0.000100


Validation Epoch [2/20]: 100%|██████████| 16/16 [00:10<00:00,  1.46batch/s, loss=0.1164]


Epoch [2/20], Average Validation Loss: 0.0565


Epoch [3/20]: 100%|██████████| 63/63 [00:55<00:00,  1.14batch/s, loss=0.1108]


Epoch [3/20], Average Training Loss: 0.0596, LR: 0.000100


Validation Epoch [3/20]: 100%|██████████| 16/16 [00:10<00:00,  1.60batch/s, loss=0.0958]


Epoch [3/20], Average Validation Loss: 0.0466


Epoch [4/20]: 100%|██████████| 63/63 [01:00<00:00,  1.04batch/s, loss=0.0120]


Epoch [4/20], Average Training Loss: 0.0546, LR: 0.000100


Validation Epoch [4/20]: 100%|██████████| 16/16 [00:10<00:00,  1.46batch/s, loss=0.0889]


Epoch [4/20], Average Validation Loss: 0.0536


Epoch [5/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0755]


Epoch [5/20], Average Training Loss: 0.0507, LR: 0.000100


Validation Epoch [5/20]: 100%|██████████| 16/16 [00:09<00:00,  1.62batch/s, loss=0.0613]


Epoch [5/20], Average Validation Loss: 0.0466


Epoch [6/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0318]


Epoch [6/20], Average Training Loss: 0.0419, LR: 0.000010


Validation Epoch [6/20]: 100%|██████████| 16/16 [00:09<00:00,  1.60batch/s, loss=0.0034]


Epoch [6/20], Average Validation Loss: 0.0416


Epoch [7/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0521]


Epoch [7/20], Average Training Loss: 0.0354, LR: 0.000010


Validation Epoch [7/20]: 100%|██████████| 16/16 [00:09<00:00,  1.71batch/s, loss=0.0000]


Epoch [7/20], Average Validation Loss: 0.0361


Epoch [8/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0146]


Epoch [8/20], Average Training Loss: 0.0321, LR: 0.000010


Validation Epoch [8/20]: 100%|██████████| 16/16 [00:09<00:00,  1.63batch/s, loss=0.0647]


Epoch [8/20], Average Validation Loss: 0.0369


Epoch [9/20]: 100%|██████████| 63/63 [00:54<00:00,  1.15batch/s, loss=0.0251]


Epoch [9/20], Average Training Loss: 0.0347, LR: 0.000010


Validation Epoch [9/20]: 100%|██████████| 16/16 [00:10<00:00,  1.53batch/s, loss=0.0807]


Epoch [9/20], Average Validation Loss: 0.0330


Epoch [10/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0288]


Epoch [10/20], Average Training Loss: 0.0256, LR: 0.000010


Validation Epoch [10/20]: 100%|██████████| 16/16 [00:09<00:00,  1.63batch/s, loss=0.0824]


Epoch [10/20], Average Validation Loss: 0.0273


Epoch [11/20]: 100%|██████████| 63/63 [00:56<00:00,  1.11batch/s, loss=0.0360]


Epoch [11/20], Average Training Loss: 0.0298, LR: 0.000001


Validation Epoch [11/20]: 100%|██████████| 16/16 [00:09<00:00,  1.61batch/s, loss=0.0986]


Epoch [11/20], Average Validation Loss: 0.0389


Epoch [12/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0370]


Epoch [12/20], Average Training Loss: 0.0333, LR: 0.000001


Validation Epoch [12/20]: 100%|██████████| 16/16 [00:09<00:00,  1.73batch/s, loss=0.0466]


Epoch [12/20], Average Validation Loss: 0.0386


Epoch [13/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0320]


Epoch [13/20], Average Training Loss: 0.0340, LR: 0.000001


Validation Epoch [13/20]: 100%|██████████| 16/16 [00:09<00:00,  1.64batch/s, loss=0.0295]


Epoch [13/20], Average Validation Loss: 0.0363


Epoch [14/20]: 100%|██████████| 63/63 [00:54<00:00,  1.15batch/s, loss=0.0513]


Epoch [14/20], Average Training Loss: 0.0292, LR: 0.000001


Validation Epoch [14/20]: 100%|██████████| 16/16 [00:09<00:00,  1.60batch/s, loss=0.0746]


Epoch [14/20], Average Validation Loss: 0.0370


Epoch [15/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0604]


Epoch [15/20], Average Training Loss: 0.0312, LR: 0.000001


Validation Epoch [15/20]: 100%|██████████| 16/16 [00:10<00:00,  1.59batch/s, loss=0.0743]


Epoch [15/20], Average Validation Loss: 0.0326


Epoch [16/20]: 100%|██████████| 63/63 [00:55<00:00,  1.14batch/s, loss=0.0332]


Epoch [16/20], Average Training Loss: 0.0330, LR: 0.000000


Validation Epoch [16/20]: 100%|██████████| 16/16 [00:09<00:00,  1.62batch/s, loss=0.0797]


Epoch [16/20], Average Validation Loss: 0.0340


Epoch [17/20]: 100%|██████████| 63/63 [00:56<00:00,  1.12batch/s, loss=0.0412]


Epoch [17/20], Average Training Loss: 0.0339, LR: 0.000000


Validation Epoch [17/20]: 100%|██████████| 16/16 [00:10<00:00,  1.59batch/s, loss=0.1082]


Epoch [17/20], Average Validation Loss: 0.0388


Epoch [18/20]: 100%|██████████| 63/63 [00:55<00:00,  1.13batch/s, loss=0.0282]


Epoch [18/20], Average Training Loss: 0.0269, LR: 0.000000


Validation Epoch [18/20]: 100%|██████████| 16/16 [00:09<00:00,  1.62batch/s, loss=0.0602]


Epoch [18/20], Average Validation Loss: 0.0319


Epoch [19/20]: 100%|██████████| 63/63 [00:55<00:00,  1.14batch/s, loss=0.0431]


Epoch [19/20], Average Training Loss: 0.0285, LR: 0.000000


Validation Epoch [19/20]: 100%|██████████| 16/16 [00:09<00:00,  1.61batch/s, loss=0.1070]


Epoch [19/20], Average Validation Loss: 0.0337


Epoch [20/20]: 100%|██████████| 63/63 [00:56<00:00,  1.12batch/s, loss=0.0666]


Epoch [20/20], Average Training Loss: 0.0325, LR: 0.000000


Validation Epoch [20/20]: 100%|██████████| 16/16 [00:10<00:00,  1.50batch/s, loss=0.0961]


Epoch [20/20], Average Validation Loss: 0.0339
Trained contrastive encoder saved!


In [None]:
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np

class SegmentationLungNoduleDataset(Dataset):
    """Paired image/mask dataset for segmentation.

    For every image path the mask is looked up as ``<image stem>.png`` inside
    ``mask_dir``. Both are converted to 8-bit PIL images and then passed
    through their respective transform, if one was given.

    Args:
        image_files: Sequence of image paths.
        mask_dir: Directory holding the mask files.
        image_transform: Optional callable applied to the PIL image.
        mask_transform: Optional callable applied to the PIL mask.
    """

    def __init__(self, image_files, mask_dir, image_transform=None, mask_transform=None):
        self.image_files = image_files
        self.mask_dir = mask_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)``, or ``(None, None)`` if either file cannot be loaded."""
        img_path = self.image_files[idx]
        mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
        mask_path = os.path.join(self.mask_dir, mask_name)
        image = load_and_preprocess_ct_scan(img_path)
        mask = load_mask_image(mask_path)

        # Check for load failures before touching the arrays: the loaders
        # return None on error, and calling .astype on None would raise.
        if image is None or mask is None:
            return None, None

        mask = mask.astype(np.float32)

        # Back to 8-bit PIL images so PIL-based transforms can be applied.
        image_pil = Image.fromarray((image * 255).astype(np.uint8))
        mask_pil = Image.fromarray((mask * 255).astype(np.uint8))

        image_tensor = image_pil
        mask_tensor = mask_pil

        if self.image_transform:
            image_tensor = self.image_transform(image_tensor)
        if self.mask_transform:
            mask_tensor = self.mask_transform(mask_tensor)

        return image_tensor, mask_tensor

# Define separate transforms for images and masks
# Images: resize, convert to a tensor and normalise with mean 0.5 / std 0.5.
image_segmentation_transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

# Masks are label maps, not intensities: resize with nearest-neighbour
# interpolation so the resized mask stays strictly binary. An interpolating
# resize would blend foreground and background along the borders and produce
# fractional target values.
mask_segmentation_transform = T.Compose([
    T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor()
])

# Assuming train_files, val_files, and MASK_DIR are defined
# (they are created by the data-loading code of the contrastive stage, so that
# code must have been run first). Both splits use the same image and mask
# transforms; no random augmentation is applied.
segmentation_train_dataset = SegmentationLungNoduleDataset(train_files, MASK_DIR, image_transform=image_segmentation_transform, mask_transform=mask_segmentation_transform)
segmentation_val_dataset = SegmentationLungNoduleDataset(val_files, MASK_DIR, image_transform=image_segmentation_transform, mask_transform=mask_segmentation_transform)

def segmentation_collate_fn(batch):
    """Batch (image, mask) pairs, dropping pairs whose image is None.

    Returns:
        ``(images, masks)`` as stacked tensors, or ``(None, None)`` when no
        usable pair is left.
    """
    usable = [pair for pair in batch if pair[0] is not None]
    if not usable:
        return None, None
    # Transpose the list of pairs into an image column and a mask column.
    columns = list(zip(*usable))
    images = torch.stack(columns[0])
    masks = torch.stack(columns[1])
    return images, masks

# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps file order and every sample. Loading happens in the main
# process (num_workers=0).
segmentation_train_loader = DataLoader(segmentation_train_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=segmentation_collate_fn, drop_last=True)
segmentation_val_loader = DataLoader(segmentation_val_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=segmentation_collate_fn, drop_last=False)

print("\nSegmentation dataloaders created with separate transforms.")


Segmentation dataloaders created with separate transforms.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import VGG16_Weights

class UNetWithFrozenEncoder(nn.Module):
    """U-Net style decoder on top of a frozen convolutional encoder.

    The encoder is ``encoder.features``; all of its parameters are frozen. The
    outputs of its last four ``nn.MaxPool2d`` layers act as the bottleneck and
    the three skip connections. Three transposed-convolution stages (each
    followed by a skip concatenation and a 3x3 convolution) and one final
    transposed-convolution stage bring the feature map back to the input
    resolution. The output is passed through a sigmoid.

    The decoder channel widths (64, 128, 256, 512) are fixed, so the four
    pooling outputs used must have exactly these channel counts.

    Args:
        encoder: Object exposing a ``features`` sequence of layers.
        num_classes: Number of output channels.

    Raises:
        ValueError: If ``encoder.features`` has fewer than four ``nn.MaxPool2d`` layers.
    """

    def __init__(self, encoder, num_classes=1):
        super().__init__()
        self.encoder = encoder.features
        for param in self.encoder.parameters():
            param.requires_grad = False

        pool_positions = [
            index for index, layer in enumerate(self.encoder)
            if isinstance(layer, nn.MaxPool2d)
        ]
        if len(pool_positions) < 4:
            raise ValueError(
                "encoder.features must contain at least four nn.MaxPool2d layers, "
                f"found {len(pool_positions)}"
            )
        # Layers after the last pooling layer never feed the decoder, so the
        # forward pass stops here instead of computing and discarding them.
        self._last_pool_index = pool_positions[-1]

        vgg_channels = [64, 128, 256, 512] # 4 pooling layers

        self.upconv1 = nn.ConvTranspose2d(vgg_channels[-1], vgg_channels[-2], kernel_size=2, stride=2)
        self.conv_decoder1 = nn.Conv2d(vgg_channels[-2] * 2, vgg_channels[-2], kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(vgg_channels[-2], vgg_channels[-3], kernel_size=2, stride=2)
        self.conv_decoder2 = nn.Conv2d(vgg_channels[-3] * 2, vgg_channels[-3], kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(vgg_channels[-3], vgg_channels[-4], kernel_size=2, stride=2)
        self.conv_decoder3 = nn.Conv2d(vgg_channels[-4] * 2, vgg_channels[-4], kernel_size=3, padding=1)

        # Additional upsampling to get back to the input resolution
        self.upconv_final = nn.ConvTranspose2d(vgg_channels[-4], vgg_channels[-4] // 2, kernel_size=2, stride=2)
        self.conv_decoder_final = nn.Conv2d(vgg_channels[-4] // 2, vgg_channels[-4] // 2, kernel_size=3, padding=1)

        self.final_conv = nn.Conv2d(vgg_channels[-4] // 2, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel probabilities with the same spatial size as ``x``."""
        # Encoder forward pass: remember the output of every pooling layer.
        pool_outputs = []
        for index, layer in enumerate(self.encoder):
            x = layer(x)
            if isinstance(layer, nn.MaxPool2d):
                pool_outputs.append(x)
            if index == self._last_pool_index:
                break

        # Decoder forward pass with skip connections
        d1 = self.upconv1(pool_outputs[-1])
        d1 = torch.cat([d1, pool_outputs[-2]], dim=1)
        d1 = F.relu(self.conv_decoder1(d1))

        d2 = self.upconv2(d1)
        d2 = torch.cat([d2, pool_outputs[-3]], dim=1)
        d2 = F.relu(self.conv_decoder2(d2))

        d3 = self.upconv3(d2)
        d3 = torch.cat([d3, pool_outputs[-4]], dim=1)
        d3 = F.relu(self.conv_decoder3(d3))

        # Final upsampling (no skip connection at full resolution)
        d_final_up = self.upconv_final(d3)
        d_final = F.relu(self.conv_decoder_final(d_final_up))

        return torch.sigmoid(self.final_conv(d_final))

# Load the trained encoder weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_state_dict = torch.load('contrastive_vgg16_encoder.pth', map_location=device)

# Only the convolutional trunk is reused, so the embedding head of this
# instance is left at its initial values.
loaded_encoder = VGG16Encoder(embedding_dim=512, pretrained=False)
encoder_features_state_dict = loaded_encoder.features.state_dict()

# Keep the checkpoint entries that belong to the trunk and strip the
# 'features.' prefix so the keys match loaded_encoder.features.
pretrained_features_state_dict = {
    name[len('features.'):]: param
    for name, param in loaded_state_dict.items()
    if name.startswith('features.')
}

encoder_features_state_dict.update(pretrained_features_state_dict)
loaded_encoder.features.load_state_dict(encoder_features_state_dict)

segmentation_model = UNetWithFrozenEncoder(loaded_encoder, num_classes=1).to(device)

print("Segmentation model created successfully!")

Segmentation model created successfully!


In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Binary cross-entropy on probabilities: the model's forward() already
# applies a sigmoid to its output.
criterion = nn.BCELoss()
optimizer_segmentation = torch.optim.Adam(segmentation_model.parameters(), lr=0.001)
num_epochs_segmentation = 25
dice_eps = 1e-7  # Smoothing term that keeps the Dice ratio defined for empty masks.

for epoch in range(num_epochs_segmentation):
    # ---- Training pass ----
    segmentation_model.train()
    total_loss = 0.0
    num_train_samples = 0
    train_loader_tqdm = tqdm(segmentation_train_loader, desc=f"Segmentation Epoch [{epoch+1}/{num_epochs_segmentation}]", unit="batch")
    for images, masks in train_loader_tqdm:
        # The loader may yield None for a batch (presumably when samples
        # failed to load — confirm against the collate function); skip it.
        if images is None or masks is None:
            continue
        images = images.to(device)
        masks = masks.to(device)

        optimizer_segmentation.zero_grad()
        outputs = segmentation_model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer_segmentation.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        num_train_samples += batch_size
        train_loader_tqdm.set_postfix(loss=f"{loss.item():.4f}")

    # Average over the samples actually seen, so skipped batches do not
    # deflate the reported loss.
    avg_loss = total_loss / num_train_samples if num_train_samples > 0 else 0.0
    print(f"Segmentation Epoch [{epoch+1}/{num_epochs_segmentation}], Average Training Loss: {avg_loss:.4f}")

    # ---- Validation pass ----
    segmentation_model.eval()
    dice_score = 0.0
    num_val_samples = 0
    with torch.no_grad():
        val_loader_tqdm = tqdm(segmentation_val_loader, desc=f"Validation Epoch [{epoch+1}/{num_epochs_segmentation}]", unit="batch")
        for images, masks in val_loader_tqdm:
            if images is None or masks is None:
                continue
            # Use the same device object as the training pass.
            images = images.to(device)
            masks = masks.to(device)
            outputs = segmentation_model(images)
            predicted_masks = (outputs > 0.5).float()

            # Per-sample Dice. The smoothing term is added to numerator and
            # denominator so an empty prediction for an empty mask scores 1
            # (a correct result) rather than 0.
            intersection = (predicted_masks * masks).sum(dim=[1, 2, 3])
            denominator = predicted_masks.sum(dim=[1, 2, 3]) + masks.sum(dim=[1, 2, 3])
            per_sample_dice = (2 * intersection + dice_eps) / (denominator + dice_eps)

            # Accumulate per sample so a smaller final batch is not
            # over-weighted in the epoch average.
            dice_score += per_sample_dice.sum().item()
            num_val_samples += images.size(0)

    avg_dice_score = dice_score / num_val_samples if num_val_samples > 0 else 0
    print(f"Segmentation Epoch [{epoch+1}/{num_epochs_segmentation}], Average Validation Dice Score: {avg_dice_score:.4f}")

print("Downstream segmentation training and evaluation complete.")

torch.save(segmentation_model.state_dict(), 'segmentation_unet.pth')
print("Trained segmentation model weights saved to segmentation_unet.pth")

Segmentation Epoch [1/25]: 100%|██████████| 126/126 [00:53<00:00,  2.34batch/s, loss=0.0288]


Segmentation Epoch [1/25], Average Training Loss: 0.0893


Validation Epoch [1/25]: 100%|██████████| 32/32 [00:11<00:00,  2.82batch/s]


Segmentation Epoch [1/25], Average Validation Dice Score: 0.2891


Segmentation Epoch [2/25]: 100%|██████████| 126/126 [00:54<00:00,  2.33batch/s, loss=0.0215]


Segmentation Epoch [2/25], Average Training Loss: 0.0280


Validation Epoch [2/25]: 100%|██████████| 32/32 [00:11<00:00,  2.84batch/s]


Segmentation Epoch [2/25], Average Validation Dice Score: 0.3021


Segmentation Epoch [3/25]: 100%|██████████| 126/126 [00:54<00:00,  2.32batch/s, loss=0.0143]


Segmentation Epoch [3/25], Average Training Loss: 0.0234


Validation Epoch [3/25]: 100%|██████████| 32/32 [00:11<00:00,  2.85batch/s]


Segmentation Epoch [3/25], Average Validation Dice Score: 0.4696


Segmentation Epoch [4/25]: 100%|██████████| 126/126 [00:54<00:00,  2.30batch/s, loss=0.0143]


Segmentation Epoch [4/25], Average Training Loss: 0.0203


Validation Epoch [4/25]: 100%|██████████| 32/32 [00:11<00:00,  2.71batch/s]


Segmentation Epoch [4/25], Average Validation Dice Score: 0.5288


Segmentation Epoch [5/25]: 100%|██████████| 126/126 [00:53<00:00,  2.38batch/s, loss=0.0199]


Segmentation Epoch [5/25], Average Training Loss: 0.0192


Validation Epoch [5/25]: 100%|██████████| 32/32 [00:11<00:00,  2.88batch/s]


Segmentation Epoch [5/25], Average Validation Dice Score: 0.5648


Segmentation Epoch [6/25]: 100%|██████████| 126/126 [00:53<00:00,  2.36batch/s, loss=0.0176]


Segmentation Epoch [6/25], Average Training Loss: 0.0172


Validation Epoch [6/25]: 100%|██████████| 32/32 [00:11<00:00,  2.83batch/s]


Segmentation Epoch [6/25], Average Validation Dice Score: 0.5035


Segmentation Epoch [7/25]: 100%|██████████| 126/126 [00:53<00:00,  2.38batch/s, loss=0.0137]


Segmentation Epoch [7/25], Average Training Loss: 0.0158


Validation Epoch [7/25]: 100%|██████████| 32/32 [00:11<00:00,  2.74batch/s]


Segmentation Epoch [7/25], Average Validation Dice Score: 0.5791


Segmentation Epoch [8/25]: 100%|██████████| 126/126 [00:53<00:00,  2.37batch/s, loss=0.0116]


Segmentation Epoch [8/25], Average Training Loss: 0.0167


Validation Epoch [8/25]: 100%|██████████| 32/32 [00:11<00:00,  2.81batch/s]


Segmentation Epoch [8/25], Average Validation Dice Score: 0.5884


Segmentation Epoch [9/25]: 100%|██████████| 126/126 [00:53<00:00,  2.36batch/s, loss=0.0191]


Segmentation Epoch [9/25], Average Training Loss: 0.0140


Validation Epoch [9/25]: 100%|██████████| 32/32 [00:11<00:00,  2.88batch/s]


Segmentation Epoch [9/25], Average Validation Dice Score: 0.5895


Segmentation Epoch [10/25]: 100%|██████████| 126/126 [00:53<00:00,  2.36batch/s, loss=0.0136]


Segmentation Epoch [10/25], Average Training Loss: 0.0133


Validation Epoch [10/25]: 100%|██████████| 32/32 [00:11<00:00,  2.83batch/s]


Segmentation Epoch [10/25], Average Validation Dice Score: 0.6002


Segmentation Epoch [11/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0124]


Segmentation Epoch [11/25], Average Training Loss: 0.0127


Validation Epoch [11/25]: 100%|██████████| 32/32 [00:11<00:00,  2.83batch/s]


Segmentation Epoch [11/25], Average Validation Dice Score: 0.6119


Segmentation Epoch [12/25]: 100%|██████████| 126/126 [00:53<00:00,  2.36batch/s, loss=0.0103]


Segmentation Epoch [12/25], Average Training Loss: 0.0119


Validation Epoch [12/25]: 100%|██████████| 32/32 [00:11<00:00,  2.80batch/s]


Segmentation Epoch [12/25], Average Validation Dice Score: 0.6342


Segmentation Epoch [13/25]: 100%|██████████| 126/126 [00:54<00:00,  2.33batch/s, loss=0.0089]


Segmentation Epoch [13/25], Average Training Loss: 0.0117


Validation Epoch [13/25]: 100%|██████████| 32/32 [00:11<00:00,  2.81batch/s]


Segmentation Epoch [13/25], Average Validation Dice Score: 0.6483


Segmentation Epoch [14/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0126]


Segmentation Epoch [14/25], Average Training Loss: 0.0111


Validation Epoch [14/25]: 100%|██████████| 32/32 [00:11<00:00,  2.77batch/s]


Segmentation Epoch [14/25], Average Validation Dice Score: 0.6261


Segmentation Epoch [15/25]: 100%|██████████| 126/126 [00:53<00:00,  2.34batch/s, loss=0.0133]


Segmentation Epoch [15/25], Average Training Loss: 0.0112


Validation Epoch [15/25]: 100%|██████████| 32/32 [00:11<00:00,  2.81batch/s]


Segmentation Epoch [15/25], Average Validation Dice Score: 0.6466


Segmentation Epoch [16/25]: 100%|██████████| 126/126 [00:54<00:00,  2.33batch/s, loss=0.0137]


Segmentation Epoch [16/25], Average Training Loss: 0.0108


Validation Epoch [16/25]: 100%|██████████| 32/32 [00:11<00:00,  2.79batch/s]


Segmentation Epoch [16/25], Average Validation Dice Score: 0.6190


Segmentation Epoch [17/25]: 100%|██████████| 126/126 [00:53<00:00,  2.36batch/s, loss=0.0093]


Segmentation Epoch [17/25], Average Training Loss: 0.0100


Validation Epoch [17/25]: 100%|██████████| 32/32 [00:11<00:00,  2.81batch/s]


Segmentation Epoch [17/25], Average Validation Dice Score: 0.6508


Segmentation Epoch [18/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0120]


Segmentation Epoch [18/25], Average Training Loss: 0.0098


Validation Epoch [18/25]: 100%|██████████| 32/32 [00:11<00:00,  2.71batch/s]


Segmentation Epoch [18/25], Average Validation Dice Score: 0.6732


Segmentation Epoch [19/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0089]


Segmentation Epoch [19/25], Average Training Loss: 0.0093


Validation Epoch [19/25]: 100%|██████████| 32/32 [00:11<00:00,  2.82batch/s]


Segmentation Epoch [19/25], Average Validation Dice Score: 0.6676


Segmentation Epoch [20/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0077]


Segmentation Epoch [20/25], Average Training Loss: 0.0088


Validation Epoch [20/25]: 100%|██████████| 32/32 [00:11<00:00,  2.81batch/s]


Segmentation Epoch [20/25], Average Validation Dice Score: 0.6630


Segmentation Epoch [21/25]: 100%|██████████| 126/126 [00:53<00:00,  2.35batch/s, loss=0.0102]


Segmentation Epoch [21/25], Average Training Loss: 0.0086


Validation Epoch [21/25]: 100%|██████████| 32/32 [00:11<00:00,  2.80batch/s]


Segmentation Epoch [21/25], Average Validation Dice Score: 0.6772


Segmentation Epoch [22/25]: 100%|██████████| 126/126 [00:54<00:00,  2.31batch/s, loss=0.0124]


Segmentation Epoch [22/25], Average Training Loss: 0.0085


Validation Epoch [22/25]: 100%|██████████| 32/32 [00:11<00:00,  2.82batch/s]


Segmentation Epoch [22/25], Average Validation Dice Score: 0.6656


Segmentation Epoch [23/25]: 100%|██████████| 126/126 [00:54<00:00,  2.32batch/s, loss=0.0071]


Segmentation Epoch [23/25], Average Training Loss: 0.0086


Validation Epoch [23/25]: 100%|██████████| 32/32 [00:11<00:00,  2.79batch/s]


Segmentation Epoch [23/25], Average Validation Dice Score: 0.6732


Segmentation Epoch [24/25]: 100%|██████████| 126/126 [00:54<00:00,  2.32batch/s, loss=0.0089]


Segmentation Epoch [24/25], Average Training Loss: 0.0083


Validation Epoch [24/25]: 100%|██████████| 32/32 [00:11<00:00,  2.77batch/s]


Segmentation Epoch [24/25], Average Validation Dice Score: 0.6237


Segmentation Epoch [25/25]: 100%|██████████| 126/126 [00:54<00:00,  2.30batch/s, loss=0.0071]


Segmentation Epoch [25/25], Average Training Loss: 0.0082


Validation Epoch [25/25]: 100%|██████████| 32/32 [00:11<00:00,  2.77batch/s]

Segmentation Epoch [25/25], Average Validation Dice Score: 0.6719
Downstream segmentation training and evaluation complete.
Trained segmentation model weights saved to segmentation_unet.pth





In [None]:
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from skimage.measure import label, regionprops
import os

# Relative path of the mask directory; presumably the annotation folder of the
# extracted Kaggle dataset — confirm against the dataset layout.
MASK_DIR = "dataset/Dataset/Annotations"

def load_mask_image(file_path):
    """Load a mask image as a grayscale float array scaled by 1/255.

    Args:
        file_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float ``numpy.ndarray`` holding the 8-bit grayscale pixel values
        divided by 255, or ``None`` if loading failed. Failures are printed
        rather than raised.
    """
    try:
        # The context manager releases the underlying file handle
        # deterministically instead of leaving it to garbage collection.
        with Image.open(file_path) as mask_image:
            # Dividing the uint8 pixels by a Python float already yields a
            # float array, so no additional cast is required.
            return np.array(mask_image.convert('L')) / 255.0
    except Exception as e:
        # Deliberate best-effort: report the problem and let the caller skip
        # this mask.
        print(f"Error loading mask {file_path}: {e}")
        return None

def get_lesion_distances(ct_image_path, segmentation_model, image_transform, threshold=0.5, pixel_spacing=(1.0, 1.0)):
    """Segment an image, find connected lesion regions, and measure centroid distances.

    Args:
        ct_image_path: Path of the image to open with PIL; it is converted to
            8-bit grayscale.
        segmentation_model: Model called on the transformed image. Its output
            is compared against ``threshold`` to obtain a binary mask. The
            input is moved to the device of the model's first parameter.
        image_transform: Callable applied to the grayscale PIL image. Its
            result must support ``unsqueeze(0)`` (a batch axis is added).
        threshold: Output values strictly greater than this count as lesion.
        pixel_spacing: ``(row_spacing, col_spacing)`` multipliers applied to
            row and column centroid differences; distances are expressed in
            whatever unit these carry.

    Returns:
        A 4-tuple ``(original_image_array, predicted_mask_original_size,
        lesion_info, distances)``:
        - ``original_image_array``: the grayscale image as a NumPy array.
        - ``predicted_mask_original_size``: the binary mask resized with
          nearest-neighbour interpolation to the original image size.
        - ``lesion_info``: one dict per labelled region with keys ``'id'``
          (1-based), ``'bbox'``, ``'area_pixels'`` and ``'centroid'``.
        - ``distances``: ``(id_a, id_b, distance)`` for every unordered pair
          of regions; empty when fewer than two regions were found.
    """
    # Close the file handle as soon as the grayscale copy has been made.
    with Image.open(ct_image_path) as ct_image:
        original_image_pil = ct_image.convert('L')
    original_width, original_height = original_image_pil.size
    original_image_array = np.array(original_image_pil)

    model_device = next(segmentation_model.parameters()).device
    image_tensor = image_transform(original_image_pil).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output = segmentation_model(image_tensor)
        predicted_mask_resized = (output > threshold).float().squeeze().cpu().numpy()

    # Nearest-neighbour resampling keeps the mask strictly binary.
    predicted_mask_original_size = np.array(
        Image.fromarray(predicted_mask_resized).resize((original_width, original_height), Image.NEAREST)
    )

    labeled_mask_original_size = label(predicted_mask_original_size)
    regions_original_size = regionprops(labeled_mask_original_size)

    lesion_info = []
    for lesion_id, region in enumerate(regions_original_size, start=1):
        minr, minc, maxr, maxc = region.bbox
        centroid_row, centroid_col = region.centroid
        lesion_info.append({
            'id': lesion_id,
            'bbox': (minr, minc, maxr, maxc),
            'area_pixels': region.area,
            'centroid': (centroid_row, centroid_col)
        })

    # Distance between every unordered pair of centroids, scaled per axis.
    distances = []
    for first_index, first in enumerate(lesion_info):
        r1, c1 = first['centroid']
        for second in lesion_info[first_index + 1:]:
            r2, c2 = second['centroid']
            distance_real_world = np.sqrt(((c2 - c1) * pixel_spacing[1])**2 + ((r2 - r1) * pixel_spacing[0])**2)
            distances.append((first['id'], second['id'], distance_real_world))

    return original_image_array, predicted_mask_original_size, lesion_info, distances

In [None]:
%%writefile app.py
import streamlit as st
from PIL import Image
import torch
import numpy as np
import torchvision.transforms as T
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights
import torch.nn.functional as F

class VGG16Encoder(nn.Module):
    """VGG16 feature extractor for single-channel input with a linear embedding head.

    ``features`` contains every child of torchvision's ``vgg16().features``
    except the last one. Its first layer is replaced by a newly constructed
    one-input-channel ``nn.Conv2d`` that keeps the replaced layer's output
    channels, kernel size, stride and padding (no weights are copied into it).
    ``fc`` maps the flattened feature map to ``embedding_dim`` values; its
    input width is measured with a 128x128 probe image.
    """

    def __init__(self, embedding_dim, pretrained=True):
        """Build the backbone and size the embedding head.

        Args:
            embedding_dim: Number of output features of ``fc``.
            pretrained: When true, the VGG16 backbone is created with the
                ``IMAGENET1K_V1`` weights; otherwise with no weights.
        """
        super().__init__()
        backbone = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1 if pretrained else None)
        feature_layers = list(backbone.features.children())
        self.features = nn.Sequential(*feature_layers[:-1])

        # Swap the input convolution for a one-channel equivalent.
        replaced_conv = self.features[0]
        self.features[0] = nn.Conv2d(
            1,
            replaced_conv.out_channels,
            kernel_size=replaced_conv.kernel_size,
            stride=replaced_conv.stride,
            padding=replaced_conv.padding,
        )

        # Measure the flattened feature width with a throw-away 128x128 probe.
        with torch.no_grad():
            probe_features = self.features(torch.randn(1, 1, 128, 128))
        self.flattened_size = probe_features.view(probe_features.size(0), -1).shape[1]

        self.embedding_dim = embedding_dim
        self.fc = nn.Linear(self.flattened_size, self.embedding_dim)

    def forward(self, x):
        """Return the embedding produced by ``fc`` for the batch ``x``."""
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flattened)

class UNetWithFrozenEncoder(nn.Module):
    """U-Net style decoder stacked on a frozen convolutional encoder.

    The encoder is ``encoder.features`` with gradients disabled on all of its
    parameters. The decoder assumes the last four pooled encoder maps carry
    64, 128, 256 and 512 channels (shallow to deep).
    """

    def __init__(self, encoder, num_classes=1):
        """Freeze the encoder and create the decoder layers.

        Args:
            encoder: Object exposing a ``features`` module to use as encoder.
            num_classes: Number of output channels of the final 1x1 conv.
        """
        super().__init__()
        self.encoder = encoder.features
        # Freeze every encoder parameter.
        self.encoder.requires_grad_(False)

        # Channel widths at the four pooling stages, shallow -> deep.
        ch_pool1, ch_pool2, ch_pool3, ch_pool4 = 64, 128, 256, 512
        ch_out = ch_pool1 // 2

        # Each decoder conv takes twice the upsampled width because the skip
        # map is concatenated on the channel axis first.
        self.upconv1 = nn.ConvTranspose2d(ch_pool4, ch_pool3, kernel_size=2, stride=2)
        self.conv_decoder1 = nn.Conv2d(ch_pool3 * 2, ch_pool3, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(ch_pool3, ch_pool2, kernel_size=2, stride=2)
        self.conv_decoder2 = nn.Conv2d(ch_pool2 * 2, ch_pool2, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(ch_pool2, ch_pool1, kernel_size=2, stride=2)
        self.conv_decoder3 = nn.Conv2d(ch_pool1 * 2, ch_pool1, kernel_size=3, padding=1)

        # Extra upsampling stage without a skip connection.
        self.upconv_final = nn.ConvTranspose2d(ch_pool1, ch_out, kernel_size=2, stride=2)
        self.conv_decoder_final = nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)

        self.final_conv = nn.Conv2d(ch_out, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return sigmoid outputs."""
        # Encoder: apply the layers in order and remember each pooled map.
        skips = []
        features = x
        for stage in self.encoder:
            features = stage(features)
            if isinstance(stage, nn.MaxPool2d):
                skips.append(features)

        # Decoder: upsample, concatenate the next-shallower pooled map on the
        # channel axis, then fuse with conv + ReLU.
        decoder_stages = (
            (self.upconv1, self.conv_decoder1, skips[-2]),
            (self.upconv2, self.conv_decoder2, skips[-3]),
            (self.upconv3, self.conv_decoder3, skips[-4]),
        )
        decoded = skips[-1]
        for upsample, fuse, skip in decoder_stages:
            merged = torch.cat([upsample(decoded), skip], dim=1)
            decoded = F.relu(fuse(merged))

        # Final upsampling
        decoded = F.relu(self.conv_decoder_final(self.upconv_final(decoded)))

        return torch.sigmoid(self.final_conv(decoded))

# --- Load your trained segmentation model ---
def load_segmentation_model(model_path, device):
    """Build the segmentation network and load its trained weights.

    Args:
        model_path: Path of the ``state_dict`` checkpoint to load.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in eval mode, or ``None`` when the checkpoint is missing,
        unreadable, or does not match the architecture. In those cases the
        problem is reported with ``st.error``.
    """
    import pickle  # Local import: only needed to recognise unpickling failures.

    embedding_dim = 512 # Adjust if your encoder's embedding dim was different
    loaded_encoder = VGG16Encoder(embedding_dim=embedding_dim, pretrained=False)
    segmentation_model = UNetWithFrozenEncoder(loaded_encoder, num_classes=1).to(device)
    try:
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict checkpoint holds.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        segmentation_model.load_state_dict(state_dict)
    except FileNotFoundError:
        st.error(f"Error: Model weights not found at {model_path}")
        return None
    except (RuntimeError, pickle.UnpicklingError, EOFError) as e:
        # RuntimeError: key/shape mismatch or an unreadable archive.
        # UnpicklingError/EOFError: corrupt, truncated, or disallowed content.
        st.error(f"Error loading state_dict: {e}")
        return None
    segmentation_model.eval()
    return segmentation_model

# --- Preprocessing function ---
def preprocess_image(image):
    """Turn a PIL image into a normalised single-image batch.

    The image is converted to grayscale, resized to 256x256, converted to a
    tensor, normalised with mean 0.5 and std 0.5, and given a leading batch
    axis.

    Args:
        image: PIL image to preprocess.

    Returns:
        The transformed tensor with an added leading dimension.
    """
    steps = (
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    )
    # Apply the steps in sequence, starting from the grayscale image.
    result = image.convert('L')
    for step in T.Compose(list(steps)).transforms:
        result = step(result)
    return result.unsqueeze(0)

# --- Function to get lesion distances (modified for Streamlit) ---
def get_lesion_info_and_distances(image, model, pixel_spacing=(1.0, 1.0), threshold=0.5):
    """Segment a PIL image, describe each detected region, and measure centroid distances.

    Args:
        image: PIL image to analyse.
        model: Segmentation model; its output is compared against
            ``threshold`` to obtain a binary mask.
        pixel_spacing: ``(row_spacing, col_spacing)`` multipliers applied to
            row and column centroid differences.
        threshold: Output values strictly greater than this count as lesion.

    Returns:
        ``(original_image_array, predicted_mask_original_size, lesion_info,
        distances)`` where ``lesion_info`` is a list of dicts with keys
        ``'id'``, ``'bbox'``, ``'area_pixels'`` and ``'centroid'``, and
        ``distances`` is a list of ``(id_a, id_b, distance)`` tuples (empty
        when fewer than two regions exist).
    """
    original_width, original_height = image.size
    original_image_array = np.array(image.convert('L'))

    model_device = next(model.parameters()).device
    image_tensor = preprocess_image(image).to(model_device)

    with torch.no_grad():
        output = model(image_tensor)
        predicted_mask_resized = (output > threshold).float().squeeze().cpu().numpy()

    # Bring the mask back to the size of the uploaded image.
    mask_image = Image.fromarray(predicted_mask_resized)
    mask_image = mask_image.resize((original_width, original_height), Image.NEAREST)
    predicted_mask_original_size = np.array(mask_image)

    regions_original_size = regionprops(label(predicted_mask_original_size))

    lesion_info = []
    for lesion_id, region in enumerate(regions_original_size, start=1):
        minr, minc, maxr, maxc = region.bbox
        centroid_row, centroid_col = region.centroid
        lesion_info.append({
            'id': lesion_id,
            'bbox': (minr, minc, maxr, maxc),
            'area_pixels': region.area,
            'centroid': (centroid_row, centroid_col)
        })

    # Every unordered pair of regions, in the order (first, later).
    distances = []
    for first_index, first in enumerate(lesion_info):
        r1, c1 = first['centroid']
        for second in lesion_info[first_index + 1:]:
            r2, c2 = second['centroid']
            distance_real_world = np.sqrt(((c2 - c1) * pixel_spacing[1])**2 + ((r2 - r1) * pixel_spacing[0])**2)
            distances.append((first['id'], second['id'], distance_real_world))

    return original_image_array, predicted_mask_original_size, lesion_info, distances

# --- Main Streamlit App ---
def main():
    """Streamlit page: upload an image, segment it, and list lesion distances.

    Renders the upload and pixel-spacing widgets, loads the segmentation
    model, and, once both a model and an upload are available, shows the
    image, an overlay of the predicted mask with per-lesion boxes, and the
    pairwise centroid distances.
    """
    st.title("Necrotic Lung Lesion Distance Measurement")

    uploaded_file = st.file_uploader("Upload a CT Image...", type=["png", "jpg", "jpeg"])
    pixel_spacing_x = st.number_input("Pixel Spacing (X)", value=1.0)
    pixel_spacing_y = st.number_input("Pixel Spacing (Y)", value=1.0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = 'segmentation_unet.pth'
    # NOTE(review): this runs on every execution of main(); if the script is
    # re-run per interaction the weights are re-read each time — consider
    # wrapping the loader with st.cache_resource.
    model = load_segmentation_model(model_path, device)

    if model is None or uploaded_file is None:
        return

    # The upload is untrusted: report an unreadable file instead of crashing.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except (OSError, ValueError) as e:
        st.error(f"Error: could not read the uploaded image: {e}")
        return

    # pixel_spacing is (row spacing, column spacing), hence (Y, X).
    original_image, predicted_mask, lesion_info, distances = get_lesion_info_and_distances(
        image, model, pixel_spacing=(pixel_spacing_y, pixel_spacing_x)
    )

    st.subheader("Original CT Image")
    st.image(original_image, use_container_width=True)

    st.subheader("Detected Lesions")
    fig, ax = plt.subplots()
    ax.imshow(original_image, cmap='gray')
    ax.imshow(predicted_mask, cmap='viridis', alpha=0.5)
    for lesion in lesion_info:
        bbox = lesion['bbox']
        area = lesion['area_pixels']
        minr, minc, maxr, maxc = bbox
        # Rectangle takes (x, y) = (column, row) for its anchor corner.
        rect = patches.Rectangle((minc, minr), maxc - minc, maxr - minr, linewidth=1, edgecolor='lime', facecolor='none')
        ax.add_patch(rect)
        ax.text(minc, minr - 5, f"ID: {lesion['id']}, Area: {area}", color='lime', fontsize=8, ha='left', va='top')
        centroid_row, centroid_col = lesion['centroid']
        ax.plot(centroid_col, centroid_row, 'w+', markersize=5)
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close this one so
    # figures do not pile up in a long-running server process.
    plt.close(fig)

    st.subheader("Distances Between Lesions (mm)")
    if distances:
        for dist in distances:
            st.write(f"Lesion {dist[0]} and Lesion {dist[1]}: {dist[2]:.2f} mm")
    else:
        st.write("Less than two lesions detected.")

if __name__ == "__main__":
    main()

Writing app.py


In [None]:
!pip install streamlit pyngrok

import streamlit as st
from pyngrok import ngrok
import time
import subprocess
import os
import signal

import urllib.request
from getpass import getpass

# Kill any existing ngrok processes
ngrok.kill()

# Never hardcode the authtoken in the notebook: read it from the environment
# or prompt for it without echoing.
# NOTE(review): a literal token was previously embedded in this cell; treat it
# as compromised and rotate it in the ngrok dashboard.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass("🔑 Enter your ngrok authtoken: ")
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Set Streamlit server address and port
streamlit_address = "0.0.0.0"
streamlit_port = 8501
streamlit_command = [
    "streamlit", "run",
    "--server.address", streamlit_address,
    "--server.port", str(streamlit_port),
    "app.py"
]
streamlit_log_path = "streamlit.log"
streamlit_startup_timeout_s = 60

# Start Streamlit app in the background.
# Output goes to a log file rather than to pipes: nothing drains pipes while
# the server runs, so a full pipe buffer would eventually block the server.
# start_new_session=True puts the child in its own session/process group (the
# setsid() effect) without running Python code between fork and exec.
with open(streamlit_log_path, "wb") as streamlit_log:
    streamlit_process = subprocess.Popen(
        streamlit_command,
        stdout=streamlit_log,
        stderr=subprocess.STDOUT,
        start_new_session=True,
    )

print(f"⏳ Starting Streamlit app on {streamlit_address}:{streamlit_port}...")

internal_url = f"http://localhost:{streamlit_port}"
public_url = None
try:
    # Poll until the server answers, the process exits, or the deadline
    # passes, instead of sleeping for a fixed time.
    streamlit_ready = False
    deadline = time.monotonic() + streamlit_startup_timeout_s
    while time.monotonic() < deadline and streamlit_process.poll() is None:
        try:
            with urllib.request.urlopen(internal_url, timeout=5) as response:
                page = response.read().decode("utf-8", errors="replace")
            if "streamlit" in page.lower():
                streamlit_ready = True
                break
        except OSError:
            # Connection refused / timed out: the server is not up yet.
            pass
        time.sleep(1)

    if streamlit_ready:
        print(f"✅ Streamlit is responding at {internal_url}")
        try:
            # Open a ngrok tunnel to the Streamlit app's port
            public_url = ngrok.connect(streamlit_port, bind_tls=True).public_url
            print(f"🌍 Your public Streamlit URL (HTTPS): {public_url}")
            print("👆 Click on the link above to open your web application in a new tab.")
            print("ℹ️ The app keeps running in the background. To stop it, run: "
                  "ngrok.kill(); os.killpg(os.getpgid(streamlit_process.pid), signal.SIGTERM)")
        except Exception as e:
            print(f"⚠️ Error creating ngrok tunnel: {e}")
            print("Please check your ngrok authtoken and internet connection.")
    elif streamlit_process.poll() is not None:
        # The process exited: the log file holds everything it printed.
        with open(streamlit_log_path, "r", encoding="utf-8", errors="replace") as log_file:
            error_output = log_file.read()
        print(f"❌ Streamlit did not start correctly. Error output:\n{error_output}")
    else:
        print(f"❌ Streamlit did not respond at {internal_url} within the timeout. It might have failed to start.")
finally:
    # Only tear the server down when no public URL was obtained; a working
    # app must stay alive for the printed link to be usable.
    if public_url is None and streamlit_process.poll() is None:
        os.killpg(os.getpgid(streamlit_process.pid), signal.SIGTERM) # Terminate the process group
        streamlit_process.wait()

Collecting streamlit
  Downloading streamlit-1.45.1-py3-none-any.whl.metadata (8.9 kB)
Collecting pyngrok
  Downloading pyngrok-7.2.9-py3-none-any.whl.metadata (9.3 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.45.1-py3-none-any.whl (9.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m107.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.2.9-py3-none-any.whl (25 kB)
Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m114.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl (