In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Statistical analysis of Fractal Dimension (FD) for various medical imaging datasets.

This script loads specified datasets, calculates the box-counting fractal dimension
for a sample of images from each class, and computes key statistics to evaluate
FD as a potential regularizer for semi-supervised learning.
"""

import os
import glob
import warnings
import pandas as pd
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from skimage.color import rgb2lab
from tqdm import tqdm

# ===============================================================
# UTILITIES & FRACTAL DIMENSION FUNCTIONS
# ===============================================================

def box_counting(image, threshold=0.5):
    """
    Calculates the fractal dimension of an image using the box-counting method.

    Args:
      image (torch.Tensor): A PyTorch tensor of shape [C, H, W] with values in [0, 1].
      threshold (float): The binarization threshold.

    Returns:
      float: The estimated fractal dimension.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input must be a torch.Tensor, but got {type(image)}")

    # Convert to grayscale numpy array
    img_np = image.numpy()
    if img_np.shape[0] in (3, 4):  # Handle RGB or RGBA
        # Standard RGB to grayscale conversion
        img_np = 0.2989 * img_np[0, :, :] + 0.5870 * img_np[1, :, :] + 0.1140 * img_np[2, :, :]

    # Binarize the image
    binary_image = img_np < threshold

    # Check for empty images (all pixels are the same)
    if binary_image.sum() == 0 or binary_image.sum() == binary_image.size:
        return 1.0  # FD of a non-fractal object is its topological dimension

    p = min(binary_image.shape)
    n = int(np.floor(np.log2(p)))

    # Ensure at least two box sizes for regression
    if n < 2:
        return 1.0

    sizes = 2**np.arange(n, 0, -1)
    counts = []

    for size in sizes:
        S = np.add.reduceat(
            np.add.reduceat(binary_image, np.arange(0, binary_image.shape[0], size), axis=0),
            np.arange(0, binary_image.shape[1], size), axis=1)
        counts.append(len(np.where(S > 0)[0]))

    counts = np.array(counts)

    # Remove zero counts to avoid log(0) issues
    valid_indices = counts > 0
    if valid_indices.sum() < 2: # Need at least 2 points for a fit
        return 1.0

    counts = counts[valid_indices]
    sizes = sizes[valid_indices]

    # Linear regression on log-log scale
    coeffs = np.polyfit(np.log(sizes), np.log(counts), 1)
    return -coeffs[0]


# ===============================================================
# DATASET LOADER FUNCTIONS
# ===============================================================

def load_isic_data(path, sample_size=500):
    """Loads and samples data from the ISIC2024 dataset."""
    print("Loading ISIC 2024 dataset...")
    csv_path = os.path.join(path, "train-metadata.csv")
    img_dir = os.path.join(path, "train-image", "image")

    if not os.path.exists(csv_path) or not os.path.isdir(img_dir):
        print(f"Warning: ISIC paths not found. Searched for:\n - {csv_path}\n - {img_dir}\nSkipping analysis.")
        return None

    df = pd.read_csv(csv_path)
    df["image_path"] = df["isic_id"].apply(lambda x: os.path.join(img_dir, f"{x}.jpg"))
    df["label"] = df["target"].apply(lambda x: "Malignant" if x == 1 else "Benign")

    # Stratified sampling
    sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))
    print(f"Found {len(df)} images. Using a balanced sample of {len(sampled_df)} images.")
    return sampled_df[["image_path", "label"]]


def load_chest_xray_data(path, sample_size=500):
    """Loads and samples data from the Chest X-Ray (COVID) dataset."""
    print("Loading Chest X-Ray (COVID) dataset...")
    train_list_path = os.path.join(path, "train.txt")
    img_dir = os.path.join(path, "train")

    if not os.path.exists(train_list_path) or not os.path.isdir(img_dir):
        print(f"Warning: Chest X-Ray paths not found. Searched for:\n - {train_list_path}\n - {img_dir}\nSkipping analysis.")
        return None

    df = pd.read_csv(train_list_path, sep='\s+', header=None, names=['id', 'filename', 'label_str', 'source'])
    df['image_path'] = df['filename'].apply(lambda x: os.path.join(img_dir, x))
    df["label"] = df["label_str"].apply(lambda x: "COVID-19" if x.lower() == "positive" else "Normal")

    sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))
    print(f"Found {len(df)} images. Using a balanced sample of {len(sampled_df)} images.")
    return sampled_df[["image_path", "label"]]


def load_brain_tumor_data(path, sample_size=500):
    """Loads and samples data from the Brain Tumor MRI dataset."""
    print("Loading Brain Tumor MRI dataset...")
    train_dir = os.path.join(path, "Training")

    if not os.path.isdir(train_dir):
        print(f"Warning: Brain Tumor training path not found at {train_dir}. Skipping analysis.")
        return None

    image_paths = glob.glob(os.path.join(train_dir, "*", "*.jpg"))
    data = []
    for p in image_paths:
        label = os.path.basename(os.path.dirname(p))
        data.append({"image_path": p, "label": label})

    df = pd.DataFrame(data)
    sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))
    print(f"Found {len(df)} images. Using a sample of {len(sampled_df)} images.")
    return sampled_df


# ===============================================================
# CORE ANALYSIS LOGIC
# ===============================================================

def analyze_dataset_fd(df, fd_func, img_size=(256, 256)):
    """
    Performs FD analysis on a given dataframe of image paths and labels.

    Args:
        df (pd.DataFrame): DataFrame with 'image_path' and 'label' columns.
        fd_func (function): The function to calculate fractal dimension.
        img_size (tuple): The size to which images will be resized.

    Returns:
        pd.DataFrame: A formatted DataFrame with statistical results.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    results = []
    # Using tqdm for a progress bar
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Calculating FD"):
        try:
            img = Image.open(row['image_path']).convert("RGB")
            img_tensor = transform(img)
            fd = fd_func(img_tensor)
            results.append({'label': row['label'], 'fd': fd})
        except Exception as e:
            print(f"\nWarning: Could not process {row['image_path']}. Error: {e}")

    results_df = pd.DataFrame(results)

    # Calculate statistics
    stats = results_df.groupby('label')['fd'].agg(['mean', 'std']).reset_index()

    # Prepare the final summary table
    labels = stats['label'].tolist()
    if len(labels) == 2: # Binary classification case
        mean0, std0 = stats.loc[0, ['mean', 'std']]
        mean1, std1 = stats.loc[1, ['mean', 'std']]
        delta_mean = abs(mean1 - mean0)
        max_sigma = max(std0, std1)
        snr = delta_mean / max_sigma if max_sigma > 0 else 0

        summary_data = {
            "FD Method": ["Box-counting"],
            f"Mean FD ({labels[0]})": [f"{mean0:.3f}"],
            f"Mean FD ({labels[1]})": [f"{mean1:.3f}"],
            "ΔMean": [f"{delta_mean:.3f}"],
            f"σ ({labels[0]})": [f"{std0:.3f}"],
            f"σ ({labels[1]})": [f"{std1:.3f}"],
            "ΔMean/max(σ)": [f"**{snr:.3f}**"]
        }
        summary_df = pd.DataFrame(summary_data)

    else: # Multi-class case (e.g., Brain Tumor)
        summary_df = stats.rename(columns={'mean': 'Mean FD', 'std': 'Standard Deviation (σ)'})
        # Also create a binary comparison for the most different classes
        most_complex_class = stats.loc[stats['mean'].idxmax()]
        least_complex_class = stats.loc[stats['mean'].idxmin()]

        mean0, std0, label0 = least_complex_class[['mean', 'std', 'label']]
        mean1, std1, label1 = most_complex_class[['mean', 'std', 'label']]
        delta_mean = abs(mean1 - mean0)
        max_sigma = max(std0, std1)
        snr = delta_mean / max_sigma if max_sigma > 0 else 0

        binary_summary_data = {
            "FD Method": ["Box-counting"],
            f"Mean FD ({label0})": [f"{mean0:.3f}"],
            f"Mean FD ({label1})": [f"{mean1:.3f}"],
            "ΔMean": [f"{delta_mean:.3f}"],
            f"σ ({label0})": [f"{std0:.3f}"],
            f"σ ({label1})": [f"{std1:.3f}"],
            "ΔMean/max(σ)": [f"**{snr:.3f}**"]
        }
        binary_summary_df = pd.DataFrame(binary_summary_data)
        print("\n--- Multi-Class Overview ---")
        print(summary_df.to_markdown(index=False))
        print(f"\n--- Binary Comparison (most distinct classes) ---")
        summary_df = binary_summary_df

    return summary_df


# ===============================================================
# MAIN EXECUTION
# ===============================================================
if __name__ == "__main__":

    # --- CONFIGURATION ---
    # !!! IMPORTANT: Update these paths to match your local file structure !!!

    DATASET_PATHS = {
        "ISIC": "F:/datasets/ISIC2024",
        "ChestXRay": "F:/datasets/Chest X-Ray",
        "BrainTumor": "F:/datasets/Brain Tumor"
    }

    DATASET_LOADERS = {
        "ISIC": load_isic_data,
        "ChestXRay": load_chest_xray_data,
        "BrainTumor": load_brain_tumor_data
    }

    # Choose which datasets to analyze
    # To skip one, just comment it out from this list
    DATASETS_TO_ANALYZE = [
        "ISIC",
        "ChestXRay",
        "BrainTumor",
    ]

    # --- SCRIPT ---
    for name in DATASETS_TO_ANALYZE:
        print("\n" + "="*50)
        print(f"STARTING ANALYSIS FOR: {name}")
        print("="*50)

        path = DATASET_PATHS.get(name)
        loader_func = DATASET_LOADERS.get(name)

        if not path or not loader_func:
            print(f"Configuration for '{name}' not found. Skipping.")
            continue

        df = loader_func(path)

        if df is None or df.empty:
            print(f"No data loaded for {name}. Skipping analysis.")
            continue

        summary_table = analyze_dataset_fd(df, fd_func=box_counting)

        print(f"\n--- Statistical Comparison for {name} ---")
        print(summary_table.to_markdown(index=False))
        print("="*50 + "\n")

  df = pd.read_csv(train_list_path, sep='\s+', header=None, names=['id', 'filename', 'label_str', 'source'])



STARTING ANALYSIS FOR: ISIC
Loading ISIC 2024 dataset...


  df = pd.read_csv(csv_path)
  sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))


Found 401059 images. Using a balanced sample of 893 images.


Calculating FD: 100%|██████████| 893/893 [00:02<00:00, 439.34it/s]
  sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))



--- Statistical Comparison for ISIC ---
| FD Method    |   Mean FD (Benign) |   Mean FD (Malignant) |   ΔMean |   σ (Benign) |   σ (Malignant) | ΔMean/max(σ)   |
|:-------------|-------------------:|----------------------:|--------:|-------------:|----------------:|:---------------|
| Box-counting |              1.402 |                 1.544 |   0.142 |        0.456 |           0.372 | **0.312**      |


STARTING ANALYSIS FOR: ChestXRay
Loading Chest X-Ray (COVID) dataset...
Found 29986 images. Using a balanced sample of 1000 images.


Calculating FD: 100%|██████████| 1000/1000 [00:11<00:00, 90.28it/s]
  sampled_df = df.groupby('label', group_keys=False).apply(lambda x: x.sample(min(len(x), sample_size), random_state=42))



--- Statistical Comparison for ChestXRay ---
| FD Method    |   Mean FD (COVID-19) |   Mean FD (Normal) |   ΔMean |   σ (COVID-19) |   σ (Normal) | ΔMean/max(σ)   |
|:-------------|---------------------:|-------------------:|--------:|---------------:|-------------:|:---------------|
| Box-counting |                1.823 |              1.823 |       0 |          0.074 |        0.095 | **0.001**      |


STARTING ANALYSIS FOR: BrainTumor
Loading Brain Tumor MRI dataset...
Found 5712 images. Using a sample of 2000 images.


Calculating FD: 100%|██████████| 2000/2000 [00:07<00:00, 265.69it/s]


--- Multi-Class Overview ---
| label      |   Mean FD |   Standard Deviation (σ) |
|:-----------|----------:|-------------------------:|
| glioma     |   1.99772 |               0.00230447 |
| meningioma |   1.99058 |               0.0166172  |
| notumor    |   1.97445 |               0.0331868  |
| pituitary  |   1.99578 |               0.0029198  |

--- Binary Comparison (most distinct classes) ---

--- Statistical Comparison for BrainTumor ---
| FD Method    |   Mean FD (notumor) |   Mean FD (glioma) |   ΔMean |   σ (notumor) |   σ (glioma) | ΔMean/max(σ)   |
|:-------------|--------------------:|-------------------:|--------:|--------------:|-------------:|:---------------|
| Box-counting |               1.974 |              1.998 |   0.023 |         0.033 |        0.002 | **0.701**      |




