In [None]:
!pip install kaggle



In [None]:
from google.colab import files

# Upload kaggle.json (Kaggle API credentials). The result is assigned instead of
# being left as the cell's last expression: otherwise the notebook displays the
# uploaded bytes -- including the API key -- and stores them in the cell output.
uploaded = files.upload()
if 'kaggle.json' not in uploaded:
    raise FileNotFoundError(
        f"Expected an upload named 'kaggle.json'; got: {', '.join(uploaded) or 'nothing'}")

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"westoncadena","key":"<REDACTED>"}'}

In [None]:
from google.colab import drive

# Mount Google Drive so that checkpoints and metrics outlive the Colab session.
drive.mount(mountpoint='/content/drive')

# Output folder for checkpoints and metrics (read as the global `save_dir`
# by the training cells). Change this to your desired folder.
save_dir = "/content/drive/MyDrive/Hack-Bone/"

Mounted at /content/drive


In [None]:
import os

# Move the uploaded Kaggle API credentials to where the kaggle CLI reads them.
KAGGLE_DIR = '/root/.kaggle'
KAGGLE_JSON = os.path.join(KAGGLE_DIR, 'kaggle.json')
os.makedirs(KAGGLE_DIR, exist_ok=True)

# Only move when the upload is still in the working directory, so re-running
# this cell after a successful move does not crash. os.replace overwrites a
# stale destination file.
if os.path.exists('kaggle.json'):
    os.replace('kaggle.json', KAGGLE_JSON)

# Owner read/write only. The mode must be written in octal: decimal 600 is
# 0o1130, a completely different permission set.
os.chmod(KAGGLE_JSON, 0o600)

In [None]:
import subprocess
import zipfile
import os

# Download the dataset archive with the Kaggle CLI. Running it through
# subprocess (instead of a `!kaggle` shell line) lets us stop right here when
# the download fails, instead of failing later with a missing-zip error.
dataset = 'mariusmarin/bs-80k'
download = subprocess.run(['kaggle', 'datasets', 'download', '-d', dataset],
                          capture_output=True, text=True)
print(download.stdout, download.stderr, sep='')
download.check_returncode()

# Archive written by the command above.
zip_file = 'bs-80k.zip'

# Extract all files from the zip archive into the 'bs-80k' folder.
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall('bs-80k')

Dataset URL: https://www.kaggle.com/datasets/mariusmarin/bs-80k
License(s): MIT


# mSegResRF-SPECT Implementation
*Based on "Novel Joint Classification Model" (Current Medical Imaging, 2024, Volume 20)*


## 1. Setup and Configuration

In [None]:
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet34_Weights
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score
# from sklearn.metrics import accuracy_score, sensitivity_score, specificity_score, f1_score, confusion_matrix

from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold

# Reproducibility: fix the global NumPy and PyTorch RNG seeds.
np.random.seed(42)
torch.manual_seed(42)

# Hyper-parameters and runtime settings shared by every stage of the pipeline.
CONFIG = dict(
    data_dir='bs-80k/temp',    # root folder with one sub-folder per region
    batch_size=16,
    lr=0.0001,                 # AdamW learning rate for the feature extractors
    epochs=100,
    num_workers=2,             # DataLoader worker processes
    image_size=256,            # images are first resized to image_size x image_size ...
    crop_size=224,             # ... then cropped to crop_size x crop_size
    train_test_split=0.8,      # fraction of each region's images used for CNN training
    k_folds=5,                 # folds for the Random Forest cross-validation
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)

# Body regions whose crops feed the per-region feature extractors
# ('testANT' / 'testPOST' are intentionally left out). The trailing numbers
# presumably are the region indices used in the paper -- TODO confirm.
SELECTED_REGIONS = [
    'headANT',
    'vertebraANT',    # 5
    'chestLANT',      # 8
    'chestRANT',      # 9
    'pelvisANT',      # 10
    'kneeLANT',       # 12
    'kneeRANT',       # 13
]


## 2. Dataset and DataLoader

In [None]:
class BoneScanDataset(Dataset):
    """
    Map-style dataset of single-region bone-scan images.

    Item ``idx`` is the image ``<data_dir>/<region>/<file_list[idx]>`` loaded as
    RGB (and passed through ``transform`` when one is given), paired with
    ``labels[idx]``.

    Args:
        data_dir: root directory containing one sub-directory per region.
        region: name of the region sub-directory to read from.
        file_list: image file names inside that sub-directory.
        labels: per-image labels, aligned index-by-index with ``file_list``.
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, data_dir, region, file_list, labels, transform=None):
        self.data_dir = data_dir
        self.region = region
        self.file_list = file_list
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, self.region, self.file_list[idx])

        # Best-effort loading: a missing or unreadable file is replaced by a
        # black placeholder so one bad image does not abort a whole epoch.
        # Only OSError is caught (covers FileNotFoundError and PIL's
        # UnidentifiedImageError); programming errors still surface.
        # NOTE(review): the placeholder keeps the real label, so it still acts as
        # a (meaningless) training/eval sample -- consider filtering such files.
        try:
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy that no longer needs it.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except OSError:
            print(f"Error loading image: {img_path}")
            # Return a placeholder black image
            image = Image.new('RGB', (CONFIG['image_size'], CONFIG['image_size']), color=0)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label

## 3. Model Architecture

In [None]:
class FeatureExtractor(nn.Module):
    """
    ResNet-34 backbone followed by a linear embedding layer and a linear classifier.

    ``forward(x)`` returns classifier logits; ``forward(x, return_embedding=True)``
    returns the embedding instead. Replacing ``self.classifier`` with
    ``nn.Identity()`` makes a plain ``forward(x)`` return the embedding as well.

    Args:
        num_classes: width of the classifier output (default 2).
        embedding_dim: width of the embedding layer (default 256).
        pretrained: load ImageNet weights (``IMAGENET1K_V1``) for the backbone;
            ``False`` builds a randomly initialised backbone without downloading.
    """

    def __init__(self, num_classes=2, embedding_dim=256, pretrained=True):
        super(FeatureExtractor, self).__init__()
        weights = ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        base = models.resnet34(weights=weights)
        backbone_dim = base.fc.in_features  # 512 for ResNet-34
        # Every child except the final fc layer, i.e. ending with global average pooling.
        self.features = nn.Sequential(*list(base.children())[:-1])
        self.embedding = nn.Linear(backbone_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x, return_embedding=False):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C)
        embedding = self.embedding(x)
        if return_embedding:
            return embedding
        out = self.classifier(embedding)
        return out

## 4. Data Preparation Function

In [None]:
def prepare_data(data_dir):
    """
    Prepare the data by loading file names and labels from the region text files.
    Returns a dictionary with region names as keys and (file_list, labels) as values.
    """
    data_dict = {}
    whole_body_labels = {}
    whole_body_dict = {}

    # Load the wholeBodyANT labels
    whole_body_path = os.path.join(data_dir, 'wholeBodyANT')
    if os.path.exists(whole_body_path):
        label_file_path = os.path.join(whole_body_path, 'wholeBodyANT.txt')
        if os.path.exists(label_file_path):
            print(f"Found label file for whole body: {label_file_path}")
            with open(label_file_path, 'r') as f:
                lines = f.readlines()
                for line in lines:
                    parts = line.strip().split()
                    if len(parts) == 2:
                        img_name = parts[0]
                        label = int(parts[1])
                        whole_body_labels[img_name] = label
        else:
            print(f"Warning: No label file found for whole body at {label_file_path}.")
    else:
        print("Warning: wholeBodyANT directory not found.")

    for region in SELECTED_REGIONS:
        region_path = os.path.join(data_dir, region)

        if not os.path.exists(region_path):
            print(f"Warning: Region directory {region_path} does not exist.")
            continue

        # Read the region-specific label file
        label_file_path = os.path.join(region_path, f"{region}.txt")
        if not os.path.exists(label_file_path):
            print(f"Warning: Label file {label_file_path} does not exist.")
            continue
        else:
            print(f"Found label file: {label_file_path}")

        # Read the region label file
        region_image_labels = {}
        with open(label_file_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                parts = line.strip().split()
                if len(parts) == 2:
                    img_name = parts[0]
                    label = int(parts[1])
                    region_image_labels[img_name] = label

        # Get all available image files
        all_files = [f for f in os.listdir(region_path) if f.endswith('.jpg')]

        # Filter to include only images that have labels in both region-specific and wholeBody labels
        file_list = []
        labels = []

        for img_file in all_files:
            # Check if the image file is in both the region label and the wholeBody labels
            if img_file in region_image_labels and img_file in whole_body_labels:
                file_list.append(img_file)
                # Use the label from the wholeBodyANT file as the label for this image
                labels.append(region_image_labels[img_file])
                # Add the image to whole_body_labels if it's part of the region
                whole_body_dict[img_file] = whole_body_labels[img_file]

        if len(file_list) == 0:
            print(f"Warning: No valid labeled images found for region {region}.")
            continue

        data_dict[region] = (file_list, labels)

    return data_dict, whole_body_dict


In [None]:
def analyze_whole_body_dict(whole_body_dict):
    """
    Print the sample count and per-label counts of ``whole_body_dict``.

    Labels are listed in the order they first appear among the dict's values.
    Prints a warning and returns early when the dict is empty (or falsy).
    Returns None.
    """
    if not whole_body_dict:
        print("Warning: whole_body_dict is empty.")
        return

    # Tally how many images carry each label.
    counts = {}
    for lbl in whole_body_dict.values():
        if lbl in counts:
            counts[lbl] += 1
        else:
            counts[lbl] = 1

    print(f"Total number of samples in whole_body_dict: {len(whole_body_dict)}")
    print("Label distribution:")
    for lbl in counts:
        print(f"  Label {lbl}: {counts[lbl]} samples")


# Quick look at the data before training. The path comes from CONFIG so this
# cell and the training pipeline cannot silently point at different folders.
data_dir = CONFIG['data_dir']
data_dict, whole_body_dict = prepare_data(data_dir)

# Now analyze the whole_body_dict
analyze_whole_body_dict(whole_body_dict)



## 5. Feature Extractor Training

In [None]:
def _split_region_data(files, labels, train_fraction, seed=42):
    """Deterministically shuffle (file, label) pairs and split them into train/validation lists."""
    # Sort first so the result does not depend on os.listdir() order, then
    # shuffle with a fixed seed so validation is not a contiguous slice.
    pairs = sorted(zip(files, labels))
    order = np.random.default_rng(seed).permutation(len(pairs))
    pairs = [pairs[i] for i in order]
    split_idx = int(len(pairs) * train_fraction)
    train_pairs, val_pairs = pairs[:split_idx], pairs[split_idx:]
    return ([f for f, _ in train_pairs], [lbl for _, lbl in train_pairs],
            [f for f, _ in val_pairs], [lbl for _, lbl in val_pairs])


def _evaluate_classifier(model, loader, criterion, device):
    """Return (mean loss, accuracy) of ``model`` over a non-empty ``loader``."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
            correct += outputs.argmax(1).eq(targets).sum().item()
            total += targets.size(0)
    return total_loss / total, correct / total


def train_feature_extractors(data_dict, config):
    """
    Train one FeatureExtractor per region and return them as embedding models.

    For each region of SELECTED_REGIONS present in ``data_dict``:
      * split the images into train/validation with a fixed-seed shuffle,
      * train with AdamW + cross-entropy on the region-level labels,
      * validate every ``eval_every`` epochs AND after the final epoch, saving
        the best-validation-accuracy weights to
        ``<save dir>/resnet34_<region>_best.pth`` (if no validation ran -- empty
        validation split or zero epochs -- the final weights are saved instead,
        so a checkpoint always exists),
      * reload that checkpoint and replace the classifier with nn.Identity, so
        calling the model returns the embedding.

    Args:
        data_dict: region -> (file_list, labels), as returned by prepare_data().
        config: reads 'device', 'image_size', 'crop_size', 'train_test_split',
            'data_dir', 'batch_size', 'num_workers', 'lr', 'epochs'. Optional keys:
            'save_dir' (defaults to the global ``save_dir``), 'eval_every'
            (default 5), 'seed' (default 42, used for the split).

    Returns:
        dict region -> trained model (classifier head stripped).

    NOTE(review): every image of a region -- including those used here for
    training and checkpoint selection -- is later embedded by extract_features()
    and used in the Random Forest cross-validation, so those metrics are likely
    optimistic (train/test leakage).
    """
    overall_start = time.time()

    # Check if GPU is available
    device = config['device']
    if not torch.cuda.is_available():
        print("Warning: No GPU found. Training on CPU.")
        device = torch.device("cpu")

    out_dir = config['save_dir'] if 'save_dir' in config else save_dir
    eval_every = config.get('eval_every', 5)
    os.makedirs(out_dir, exist_ok=True)

    # Transforms (ImageNet normalisation to match the pretrained backbone)
    train_transform = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.RandomCrop(config['crop_size']),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.CenterCrop(config['crop_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    models_dict = {}

    for region in SELECTED_REGIONS:
        if region not in data_dict:
            continue

        print(f"\nTraining feature extractor for {region}...")
        region_start = time.time()

        files, labels = data_dict[region]
        train_files, train_labels, test_files, test_labels = _split_region_data(
            files, labels, config['train_test_split'], config.get('seed', 42))
        if not train_files:
            print(f"Warning: No training images for {region}; skipping.")
            continue

        # Datasets + Loaders
        train_dataset = BoneScanDataset(config['data_dir'], region, train_files, train_labels, train_transform)
        test_dataset = BoneScanDataset(config['data_dir'], region, test_files, test_labels, val_transform)

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

        # Model, loss, optimizer
        model = FeatureExtractor().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
        ckpt_path = os.path.join(out_dir, f"resnet34_{region}_best.pth")

        best_acc = -1.0

        for epoch in range(config['epochs']):
            model.train()
            train_loss = 0.0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * inputs.size(0)

            train_loss /= len(train_loader.dataset)

            # Also validate after the last epoch so trailing epochs are not
            # ignored and short runs (epochs < eval_every) still get a checkpoint.
            is_last_epoch = epoch + 1 == config['epochs']
            if test_files and ((epoch + 1) % eval_every == 0 or is_last_epoch):
                val_loss, val_acc = _evaluate_classifier(model, test_loader, criterion, device)

                print(f"Region: {region}, Epoch: {epoch+1}/{config['epochs']}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

                if val_acc > best_acc:
                    best_acc = val_acc
                    print(f"Saving best model for {region} at epoch {epoch+1}, val_acc = {val_acc:.4f}")
                    torch.save(model.state_dict(), ckpt_path)

        if best_acc < 0:
            # No validation happened (empty validation split or zero epochs);
            # keep the current weights so the reload below has a file to read.
            print(f"Warning: No validation run for {region}; saving final weights.")
            torch.save(model.state_dict(), ckpt_path)

        # Load and strip the classifier for feature extraction
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        model.classifier = nn.Identity()  # Remove the classifier for final use
        models_dict[region] = model

        region_end = time.time()
        print(f"Finished training for {region} in {(region_end - region_start):.2f} seconds.")

    overall_end = time.time()
    print(f"\nAll regions processed in {(overall_end - overall_start):.2f} seconds.")

    return models_dict

## 6. Feature Extraction

In [None]:
def extract_features(data_dict, models_dict, config):
    """
    Run every region's trained model over all of that region's images.

    Regions are visited in SELECTED_REGIONS order; a region is skipped when it is
    missing from ``data_dict`` or ``models_dict`` or has no images.

    Args:
        data_dict: region -> (file_list, labels), as returned by prepare_data().
        models_dict: region -> model; presumably the output of
            train_feature_extractors(), whose classifier head is nn.Identity so
            the forward pass yields the embedding.
        config: reads 'device', 'image_size', 'crop_size', 'data_dir',
            'batch_size', 'num_workers'.

    Returns:
        dict region -> (features, labels, files): ``features`` stacks the model
        outputs row-wise, row i belonging to ``files[i]``; ``labels`` are the
        region labels in the same order.
    """
    # Check if GPU is available
    device = config['device']
    if not torch.cuda.is_available():
        print("Warning: No GPU found. Extracting features on CPU.")
        device = torch.device("cpu")

    # Deterministic preprocessing (same as the validation transform used in training).
    val_transform = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.CenterCrop(config['crop_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    features_dict = {}

    for region in SELECTED_REGIONS:
        if region not in data_dict or region not in models_dict:
            continue

        print(f"Extracting features for {region}...")

        files, labels = data_dict[region]
        if not files:
            print(f"Warning: No images for {region}; skipping feature extraction.")
            continue

        dataset = BoneScanDataset(config['data_dir'], region, files, labels, val_transform)
        # shuffle=False keeps batches in file order, so output rows line up with `files`.
        dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

        model = models_dict[region].to(device)
        model.eval()

        batch_features = []
        batch_labels = []
        with torch.no_grad():
            for inputs, targets in dataloader:
                outputs = model(inputs.to(device))
                batch_features.append(outputs.cpu().numpy())
                batch_labels.append(targets.numpy())

        features_dict[region] = (np.vstack(batch_features), np.concatenate(batch_labels), list(files))

    return features_dict

## 7. Random Forest Classifier Training

In [None]:
def train_random_forest(features_dict, whole_body_dict, config):
    """
    Cross-validate and fit a Random Forest on concatenated per-region features.

    A sample is identified by its file stem. Only stems present in every region
    of ``features_dict`` and (as ``<stem>.jpg``) in ``whole_body_dict`` are used.
    Each feature vector concatenates the per-region features in the iteration
    order of ``features_dict``; the target is the whole-body label.

    Args:
        features_dict: region -> (features, labels, files) from extract_features();
            features row i belongs to files[i].
        whole_body_dict: "<stem>.jpg" -> whole-body label (assumed 0/1).
        config: reads 'k_folds'.

    Returns:
        (final_clf, (acc, sen, spe, f1, auc)) -- fold-averaged metrics and a
        classifier refit on all samples; (None, (0, 0, 0, 0, 0)) when training
        is impossible.

    NOTE(review): the per-region CNNs were trained on images that also appear
    here, so these cross-validation metrics are likely optimistic.
    """
    start_time = time.time()

    if not features_dict:
        print("Error: No common samples found across all regions.")
        return None, (0, 0, 0, 0, 0)

    # Build each region's stem -> row lookup once (it used to be rebuilt for
    # every sample, making assembly quadratic in the number of samples).
    region_index = {
        region: {os.path.splitext(f)[0]: i for i, f in enumerate(files)}
        for region, (_, _, files) in features_dict.items()
    }

    common_identifiers = sorted(set.intersection(*(set(index) for index in region_index.values())))
    if not common_identifiers:
        print("Error: No common samples found across all regions.")
        return None, (0, 0, 0, 0, 0)

    print(f"Found {len(common_identifiers)} common samples across all regions.")

    all_features = []
    all_labels = []
    for identifier in common_identifiers:
        file_name = f"{identifier}.jpg"
        if file_name not in whole_body_dict:
            continue
        # Same region order for every row so columns mean the same thing.
        all_features.append(np.concatenate([
            features_dict[region][0][region_index[region][identifier]]
            for region in features_dict
        ]))
        all_labels.append(whole_body_dict[file_name])

    if not all_features:
        print("Error: No valid samples with full feature vectors.")
        return None, (0, 0, 0, 0, 0)

    X = np.vstack(all_features)
    y = np.array(all_labels)

    classes, class_counts = np.unique(y, return_counts=True)
    print(f"Total samples: {len(y)}")
    print(f"Label distribution: {dict(zip(classes, class_counts))}")

    if len(classes) < 2:
        print("Warning: Only one class present in data. Cannot train classifier.")
        return None, (0, 0, 0, 0, 0)

    # StratifiedKFold needs at least n_splits members in each class.
    n_splits = min(config['k_folds'], int(class_counts.min()))
    if n_splits < 2:
        print(f"Warning: Smallest class has {int(class_counts.min())} sample(s); need at least 2 for stratified cross-validation.")
        return None, (0, 0, 0, 0, 0)

    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    accuracies, sensitivities, specificities, f1_scores, aucs = [], [], [], [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X, y), start=1):
        print(f"\nTraining fold {fold}/{n_splits}...")
        fold_start = time.time()

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Defensive: n_splits <= smallest class size should already put both
        # classes in every fold.
        if len(np.unique(y_train)) < 2 or len(np.unique(y_val)) < 2:
            print(f"Warning: Fold {fold} has insufficient class diversity. Skipping...")
            continue

        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X_train, y_train)

        # Column 1 is clf.classes_[1], i.e. label 1 when labels are {0, 1}.
        y_proba = clf.predict_proba(X_val)[:, 1]
        y_pred = (y_proba >= 0.5).astype(int)

        acc = accuracy_score(y_val, y_pred)
        # labels=[0, 1] forces a 2x2 matrix even if a class is never predicted.
        tn, fp, fn, tp = confusion_matrix(y_val, y_pred, labels=[0, 1]).ravel()
        sen = tp / (tp + fn) if (tp + fn) > 0 else 0
        spe = tn / (tn + fp) if (tn + fp) > 0 else 0
        f1 = f1_score(y_val, y_pred, zero_division=0)
        auc = roc_auc_score(y_val, y_proba)  # both classes present (checked above)

        accuracies.append(acc)
        sensitivities.append(sen)
        specificities.append(spe)
        f1_scores.append(f1)
        aucs.append(auc)

        fold_end = time.time()
        print(f"Fold {fold} - Acc: {acc:.4f}, Sen: {sen:.4f}, Spe: {spe:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
        print(f"Fold {fold} completed in {(fold_end - fold_start):.2f} seconds.")

    if not accuracies:
        print("No valid folds were processed. Cannot compute metrics.")
        return None, (0, 0, 0, 0, 0)

    avg_acc = np.mean(accuracies)
    avg_sen = np.mean(sensitivities)
    avg_spe = np.mean(specificities)
    avg_f1 = np.mean(f1_scores)
    avg_auc = np.mean(aucs)

    print(f"\nAverage metrics - Acc: {avg_acc:.4f}, Sen: {avg_sen:.4f}, Spe: {avg_spe:.4f}, F1: {avg_f1:.4f}, AUC: {avg_auc:.4f}")

    final_clf = RandomForestClassifier(n_estimators=100, random_state=42)
    final_clf.fit(X, y)

    end_time = time.time()
    print(f"\nTotal training time: {(end_time - start_time):.2f} seconds.")

    return final_clf, (avg_acc, avg_sen, avg_spe, avg_f1, avg_auc)

## 8. Main Training Pipeline

In [None]:
# Main training pipeline
def train_mSegResRF_SPECT():
    """
    Run the full mSegResRF-SPECT pipeline and save its artifacts under ``save_dir``.

    Steps: prepare region and whole-body labels -> train one feature extractor
    per region -> extract per-region features -> cross-validate and fit a Random
    Forest on the concatenated features -> save models and metrics.

    Returns:
        (rf_classifier, metrics, models_dict). rf_classifier is None (and the
        metrics all zero) when the Random Forest could not be trained.
    """
    print("Starting mSegResRF-SPECT training pipeline...")

    # 1. Prepare data
    print("Preparing data...")
    data_dict, whole_body_dict = prepare_data(CONFIG['data_dir'])

    analyze_whole_body_dict(whole_body_dict)

    # 2. Train feature extractors for each region
    print("Training feature extractors...")
    models_dict = train_feature_extractors(data_dict, CONFIG)

    # 3. Extract features using trained models
    print("Extracting features...")
    features_dict = extract_features(data_dict, models_dict, CONFIG)

    # 4. Train Random Forest classifier on concatenated features
    print("Training Random Forest classifier...")
    rf_classifier, metrics = train_random_forest(features_dict, whole_body_dict, CONFIG)

    # 5. Save the final model and report results
    print("Saving model and reporting results...")

    # The output folder may not exist yet if no region was trained.
    os.makedirs(save_dir, exist_ok=True)

    # NOTE: the bundle contains a pickled scikit-learn object; recent PyTorch
    # versions default torch.load to weights_only=True, so reading it back needs
    # torch.load(path, weights_only=False) -- only do that for trusted files.
    torch.save({
        'models_dict': {region: model.state_dict() for region, model in models_dict.items()},
        'rf_classifier': rf_classifier,
        'metrics': metrics
    }, os.path.join(save_dir, "mSegResRF_SPECT_final.pth"))

    # Also save the metrics as plain text for quick reference.
    metric_names = ("Accuracy", "Sensitivity", "Specificity", "F1 Score", "AUC")
    with open(os.path.join(save_dir, "metrics.txt"), "w") as f:
        for name, value in zip(metric_names, metrics):
            f.write(f"{name}: {value:.4f}\n")

    print("Training complete and results saved!")
    return rf_classifier, metrics, models_dict

# Run the training pipeline
if __name__ == "__main__":
    rf_classifier, metrics, models_dict = train_mSegResRF_SPECT()


Starting mSegResRF-SPECT training pipeline...
Preparing data...
Found label file for whole body: bs-80k/temp/wholeBodyANT/wholeBodyANT.txt
Found label file: bs-80k/temp/headANT/headANT.txt
Found label file: bs-80k/temp/chestLANT/chestLANT.txt
Found label file: bs-80k/temp/chestRANT/chestRANT.txt
Found label file: bs-80k/temp/pelvisANT/pelvisANT.txt
Found label file: bs-80k/temp/kneeLANT/kneeLANT.txt
Found label file: bs-80k/temp/kneeRANT/kneeRANT.txt
Total number of samples in whole_body_dict: 2925
Label distribution:
  Label 0: 1860 samples
  Label 1: 1065 samples
Training feature extractors...

Training feature extractor for headANT...
Region: headANT, Epoch: 5/100, Train Loss: 0.2417, Val Loss: 0.2317, Val Acc: 0.9265
Saving best model for headANT at epoch 5, val_acc = 0.9265
Region: headANT, Epoch: 10/100, Train Loss: 0.1997, Val Loss: 0.2447, Val Acc: 0.9043
Region: headANT, Epoch: 15/100, Train Loss: 0.1386, Val Loss: 0.2697, Val Acc: 0.9282
Saving best model for headANT at epoch