In [1]:
GIT_TOKEN = "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"
REPO_URL = f"https://{GIT_TOKEN}@github.com/semilleroCV/BreastCATT.git"
!git clone $REPO_URL

Cloning into 'BreastCATT'...
remote: Enumerating objects: 131, done.[K
remote: Counting objects: 100% (131/131), done.[K
remote: Compressing objects: 100% (99/99), done.[K
remote: Total 131 (delta 58), reused 95 (delta 30), pack-reused 0 (from 0)[K
Receiving objects: 100% (131/131), 928.33 KiB | 2.79 MiB/s, done.
Resolving deltas: 100% (58/58), done.


In [1]:
import sys
import os

# Expose the parent directory on the import path so packages that live one
# level above this notebook can be imported.
sys.path.append(os.path.abspath(os.pardir))

In [2]:
from huggingface_hub import hf_hub_download
import os

# Folder that will hold the segmentation checkpoint; created if it is missing.
save_path = "../checkpoints/segmentation"
os.makedirs(name=save_path, exist_ok=True)

# Fetch the weights file from the Hugging Face Hub into `save_path`.
# `checkpoint_path` holds the value returned by hf_hub_download.
checkpoint_path = hf_hub_download(
    filename="lucky-sweep-6_0.4937.pth",
    repo_id="SemilleroCV/transunet-breast-cancer",
    local_dir=save_path,
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
from transunet.vit_seg_modeling import VisionTransformer as ViT_seg
from transunet.vit_seg_modeling import CONFIGS
from torchvision import transforms

In [4]:
class SegmentationModel(nn.Module):
    """Wraps a `ViT_seg` network: builds it, loads weights, and yields thresholded masks."""

    def __init__(self, img_size: int, n_skip: int, num_classes: int, dir_model: str, device: torch.device, threshold: float = 0.5):
        """
        Initializes the segmentation model with a Vision Transformer (ViT) backbone.

        Args:
            img_size (int): The size of the input images.
            n_skip (int): The number of skip connections.
            num_classes (int): Number of segmentation classes.
            dir_model (str): Path to the model weights file.
            device (torch.device): The device on which to load the model (CPU or GPU).
            threshold (float, optional): Threshold for converting probabilities into binary predictions. Defaults to 0.5.

        Raises:
            Exception: Any error raised while reading or applying the weights is
                printed and then re-raised unchanged.
        """
        # Local import: only this constructor needs it, and it keeps the class self-contained.
        import copy

        super().__init__()

        self.device = device
        self.threshold = threshold

        # Work on a private deep copy: CONFIGS holds module-level objects, so mutating
        # the entry in place would leak n_classes / n_skip / grid into every other
        # user of the same config (including previously built SegmentationModel instances).
        self.config_vit = copy.deepcopy(CONFIGS["R50-ViT-B_16"])  # You can try others like "ViT-B_16"
        self.config_vit.n_classes = num_classes
        self.config_vit.n_skip = n_skip
        # NOTE(review): the grid is fixed at 14x14 regardless of img_size (14 == 224 // 16);
        # confirm this is intended before using a different img_size.
        self.config_vit.patches.grid = (14, 14)

        # Build the network from the private config and move it to the target device.
        self.model = ViT_seg(self.config_vit, img_size=img_size, num_classes=num_classes).to(device)

        # Load model weights from file, mapping tensors to the specified device.
        try:
            # NOTE(security): torch.load unpickles the file; only load checkpoints from
            # sources you trust (consider weights_only=True where the torch version allows).
            state_dict = torch.load(dir_model, map_location=device)
            self.model.load_state_dict(state_dict)
            print(f"✅ Weights loaded from {dir_model}")
        except Exception as e:
            print(f"Error loading weights from {dir_model}: {e}")
            raise

    def forward(self, x):
        """Runs the wrapped network on `x` and returns its raw output (logits)."""
        return self.model(x)

    def predict_on_image(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Performs prediction on a single image.

        The wrapped model is switched to evaluation mode (and left in that mode),
        and gradients are disabled for the forward pass.

        Args:
            image_tensor (torch.Tensor): Input image tensor expected to have shape [1, C, H, W].

        Returns:
            torch.Tensor: Binary (0.0 / 1.0) float mask, i.e. the model output with the
            leading batch dimension squeezed out when that dimension has size 1.
        """
        self.model.eval()  # Inference mode for layers such as dropout / batch norm
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)  # Same device as the model
            logits = self.model(image_tensor)
            probs = torch.sigmoid(logits)  # Independent per-pixel probabilities

            # Strictly-greater comparison: a probability equal to the threshold maps to 0.
            preds = (probs > self.threshold).float()
            # squeeze(0) only removes the batch axis when it has size 1.
            return preds.squeeze(0)

In [7]:
# Side length of the square input passed to the model as img_size.
img_size = 224
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Reuse the path returned by hf_hub_download in the download cell rather than
# re-typing it, so the two cannot drift apart.
dir_model = checkpoint_path
# Keyword arguments make the meaning of the former positional `3, 1` explicit.
segmentador = SegmentationModel(
    img_size=img_size,
    n_skip=3,
    num_classes=1,
    dir_model=dir_model,
    device=device,
    threshold=0.7,
)

✅ Weights loaded from ../checkpoints/segmentation/lucky-sweep-6_0.4937.pth


In [8]:
# Preprocessing applied to each PIL image before inference.
# The resize target reuses `img_size` (defined with the model settings) so the
# transform cannot silently diverge from the size the model was built with.
data_transform = transforms.Compose([
  transforms.Resize((img_size, img_size)),
  transforms.ToTensor()
])

In [10]:
from datasets import load_dataset

# Only the "test" split of the dataset is loaded.
dataset = load_dataset(
    path="SemilleroCV/BreastThermography",
    split="test",
)

Generating train split: 100%|██████████| 285/285 [00:00<00:00, 831.70 examples/s]
Generating test split: 100%|██████████| 72/72 [00:00<00:00, 676.66 examples/s]


In [103]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Load one sample from the dataset straight into a float32 array
# (index 66 is just the example chosen here; change it to inspect another image).
img = np.array(dataset[66]['image'], dtype=np.float32)

# Scale by a fixed reference temperature.
# NOTE(review): the result only lies in [0, 1] if pixel values fall within
# [0, MAX_TEMPERATURE] -- confirm against the dataset / training preprocessing.
MAX_TEMPERATURE = 36.44
normalized_img = img / MAX_TEMPERATURE

# Render with the 'inferno' colormap on an explicit figure/axes so the result
# does not depend on whatever pyplot figure happens to be current.
output_path = "sample_image_inferno.png"
fig, ax = plt.subplots()
ax.imshow(normalized_img, cmap="inferno")
ax.axis('off')  # No ticks or frame: the saved file should contain only the image
fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)  # Release the figure; nothing is displayed inline

print(f"Image saved to {output_path}")

Image saved to sample_image_inferno.png


In [104]:
# Prepare the model input: wrap the normalized array as a PIL image so the
# torchvision transform can consume it, then add a leading batch dimension.
sample_pil = Image.fromarray(normalized_img)
matrix = data_transform(sample_pil).unsqueeze(0)

# Thresholded mask; predict_on_image already squeezes the batch dimension.
output = segmentador.predict_on_image(matrix)

# Move to CPU and squeeze the remaining leading axis (the single class channel,
# since the model was built with num_classes=1) to get a 2-D array for imshow.
output_np = output.cpu().squeeze(0).numpy()

# Save the mask with the 'gray' colormap on an explicit figure/axes so the
# result does not depend on implicit pyplot state.
segmentation_output_path = "segmentation_output_gray.png"
fig, ax = plt.subplots()
ax.imshow(output_np, cmap='gray')
ax.axis('off')  # No ticks or frame: the saved file should contain only the mask
fig.savefig(segmentation_output_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)  # Release the figure; nothing is displayed inline

print(f"Segmentation output saved to {segmentation_output_path}")

Segmentation output saved to segmentation_output_gray.png
