In [38]:
import os
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import glob
import random
import time
import datetime
import pydicom

# Set seed for reproducibility
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Check for GPU
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# Dataset paths
BASE_DIR = "./dataset"
IMAGE_DIR = os.path.join(BASE_DIR, "images/images")
ANNOTATION_DIR = os.path.join(BASE_DIR, "annotations/annotations/tcia-lidc-xml")
CSV_FILE = os.path.join(BASE_DIR, "lidc_metadata.csv")

# Load metadata
print("Loading metadata from CSV...")
try:
    metadata_df = pd.read_csv(CSV_FILE)
    print(f"Loaded {len(metadata_df)} entries from metadata CSV")
    print(metadata_df.head())
except Exception as e:
    print(f"Error loading CSV: {e}")
    # Create a backup empty dataframe if loading fails
    metadata_df = pd.DataFrame(columns=['case_id', 'image_id', 'projection', 'findings'])

# Print dataset structure for debugging
print("\nExploring dataset structure:")
for subdir in ['157', '185', '186', '187', '188', '189']:
    subdir_path = os.path.join(ANNOTATION_DIR, subdir)
    if os.path.exists(subdir_path):
        xml_files = [f for f in os.listdir(subdir_path) if f.endswith('.xml')]
        print(f"Found {len(xml_files)} XML files in {subdir_path}")
        if xml_files:
            print(f"Sample files: {xml_files[:3]}")

# Improved XML parsing function
def parse_xml_annotations(xml_path):
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        
        # Extract namespace if present
        namespace = ''
        if root.tag.startswith('{'):
            namespace = root.tag.split('}')[0] + '}'
        
        # Try different paths to find image ID
        image_id = None
        possible_paths = [
            ".//{}ResponseHeader/{}SeriesInstanceUid".format(namespace, namespace),
            "./{}ResponseHeader/{}SeriesInstanceUid".format(namespace, namespace),
            ".//SeriesInstanceUid",
            ".//ResponseHeader/SeriesInstanceUid",
            ".//studyInstanceUID",
            ".//seriesUID"
        ]
        
        for path in possible_paths:
            try:
                element = root.find(path)
                if element is not None and element.text:
                    image_id = element.text
                    break
            except:
                continue
        
        # If we still don't have an image ID, try to get it from the filename
        if not image_id:
            filename = os.path.basename(xml_path)
            if filename.endswith('.xml'):
                filename = filename[:-4]  # Remove .xml extension
            image_id = filename
        
        nodules = []
        
        # Try different paths to find nodules
        nodule_paths = [
            ".//{}readingSession/{}unblindedReadNodule".format(namespace, namespace),
            ".//readingSession/unblindedReadNodule",
            ".//unblindedReadNodule",
            ".//nodule"
        ]
        
        for nodule_path in nodule_paths:
            nodule_elements = root.findall(nodule_path)
            if nodule_elements:
                for nodule in nodule_elements:
                    # Try different paths for ROI
                    roi_paths = [
                        ".//{}roi".format(namespace),
                        "./{}roi".format(namespace),
                        ".//roi",
                        "./roi"
                    ]
                    
                    coords = []
                    for roi_path in roi_paths:
                        roi = nodule.find(roi_path)
                        if roi is not None:
                            # Try different paths for edge maps
                            edge_map_paths = [
                                ".//{}edgeMap".format(namespace),
                                "./{}edgeMap".format(namespace),
                                ".//edgeMap",
                                "./edgeMap"
                            ]
                            
                            for edge_map_path in edge_map_paths:
                                edge_maps = roi.findall(edge_map_path)
                                if edge_maps:
                                    for edge_map in edge_maps:
                                        x = None
                                        y = None
                                        
                                        x_elements = edge_map.findall(".//xCoord") or edge_map.findall("./xCoord")
                                        if x_elements and x_elements[0].text:
                                            x = int(x_elements[0].text)
                                            
                                        y_elements = edge_map.findall(".//yCoord") or edge_map.findall("./yCoord")
                                        if y_elements and y_elements[0].text:
                                            y = int(y_elements[0].text)
                                            
                                        if x is not None and y is not None:
                                            coords.append((x, y))
                    
                    # Try to get characteristics
                    subtlety = 3  # Default value
                    malignancy = 3  # Default value
                    
                    subtlety_paths = [
                        ".//{}characteristics/{}subtlety".format(namespace, namespace),
                        ".//characteristics/subtlety",
                        ".//subtlety"
                    ]
                    
                    for path in subtlety_paths:
                        element = nodule.find(path)
                        if element is not None and element.text:
                            try:
                                subtlety = int(element.text)
                                break
                            except:
                                pass
                    
                    malignancy_paths = [
                        ".//{}characteristics/{}malignancy".format(namespace, namespace),
                        ".//characteristics/malignancy",
                        ".//malignancy"
                    ]
                    
                    for path in malignancy_paths:
                        element = nodule.find(path)
                        if element is not None and element.text:
                            try:
                                malignancy = int(element.text)
                                break
                            except:
                                pass
                    
                    if coords:
                        nodules.append({
                            "coords": coords,
                            "subtlety": subtlety,
                            "malignancy": malignancy
                        })
        
        return image_id, nodules
    
    except Exception as e:
        print(f"Error parsing {xml_path}: {str(e)}")
        return None, []

# Function to find all XML files
def find_all_xml_files():
    print("Searching for XML files...")
    xml_files = []
    for root, dirs, files in os.walk(ANNOTATION_DIR):
        for file in files:
            if file.endswith('.xml'):
                xml_files.append(os.path.join(root, file))
    print(f"Found {len(xml_files)} XML files")
    return xml_files

# Function to find image files by pattern matching
def find_image_files(image_id):
    # Try different patterns to match images
    patterns = [
        f"{IMAGE_DIR}/**/{image_id}*.dcm",
        f"{IMAGE_DIR}/**/*{image_id}*.dcm",
        f"{IMAGE_DIR}/**/{image_id.replace('-', '_')}*.dcm"
    ]
    
    for pattern in patterns:
        image_paths = glob.glob(pattern, recursive=True)
        if image_paths:
            return image_paths[0]
    
    return None

# Function to create binary masks from nodule coordinates
def create_nodule_mask(image_shape, nodules):
    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    
    for nodule in nodules:
        coords = nodule["coords"]
        if len(coords) > 2:
            pts = np.array(coords, np.int32)
            pts = pts.reshape((-1, 1, 2))
            cv2.fillPoly(mask, [pts], 255)
    
    return mask

# Improved function to load DICOM images
def load_dicom_image(image_path):
    try:
        # Try to load as DICOM
        ds = pydicom.dcmread(image_path)
        img = ds.pixel_array
        
        # Normalize to 0-255
        img = img - np.min(img)
        if np.max(img) > 0:
            img = img / np.max(img) * 255
        img = img.astype(np.uint8)
        
        # Check if image needs to be inverted (black background)
        if np.mean(img) > 127:
            img = 255 - img
            
        return img
    except Exception as e:
        print(f"Error loading DICOM {image_path}: {e}")
        
        # Try to load as regular image
        try:
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is not None:
                return img
        except:
            pass
            
        return None

# Function to generate dataset with fallback mechanisms
def prepare_dataset(max_samples=None):
    print("\nPreparing dataset...")
    
    # Find all XML files
    xml_files = find_all_xml_files()
    
    # Parse XML annotations with progress bar
    annotations = {}
    successful_files = 0
    for xml_file in tqdm(xml_files, desc="Parsing XML files"):
        image_id, nodules = parse_xml_annotations(xml_file)
        if image_id and nodules:
            annotations[image_id] = nodules
            successful_files += 1
    
    print(f"Successfully parsed {successful_files} XML files with {len(annotations)} annotations")
    
    # If we have metadata CSV, use it to match images
    dataset = []
    
    if not metadata_df.empty:
        for _, row in tqdm(metadata_df.iterrows(), desc="Processing from metadata", total=len(metadata_df)):
            image_id = row['image_id']
            case_id = row['case_id']
            
            # Find image path
            image_path = find_image_files(image_id)
            
            if not image_path:
                continue
            
            # Check if we have annotation for this image
            nodules = []
            has_nodules = False
            
            # Try exact match
            if image_id in annotations:
                nodules = annotations[image_id]
                has_nodules = len(nodules) > 0
            else:
                # Try partial match
                for anno_id in annotations:
                    if anno_id in image_id or image_id in anno_id:
                        nodules = annotations[anno_id]
                        has_nodules = len(nodules) > 0
                        break
            
            # If no annotation found but CSV says there are nodules
            if not nodules and row.get('findings', '') == 'Nodules':
                has_nodules = True
            
            # Create dataset entry
            dataset.append({
                "image_id": image_id,
                "case_id": case_id,
                "image_path": image_path,
                "nodules": nodules,
                "has_nodules": has_nodules,
                "projection": row.get('projection', 'Unknown')
            })
            
            if max_samples and len(dataset) >= max_samples:
                break
    
    # If we couldn't create dataset from metadata, try direct file-based approach
    if not dataset:
        print("Creating dataset from direct file scanning...")
        image_files = glob.glob(f"{IMAGE_DIR}/**/*.dcm", recursive=True)
        
        for image_path in tqdm(image_files, desc="Scanning image files"):
            image_id = os.path.basename(image_path).split('.')[0]
            
            # Try to find a matching annotation
            nodules = []
            has_nodules = False
            
            # Try exact match
            if image_id in annotations:
                nodules = annotations[image_id]
                has_nodules = len(nodules) > 0
            else:
                # Try partial match
                for anno_id in annotations:
                    if anno_id in image_id or image_id in anno_id:
                        nodules = annotations[anno_id]
                        has_nodules = len(nodules) > 0
                        break
            
            # Create dataset entry
            dataset.append({
                "image_id": image_id,
                "case_id": image_id.split('-')[0] if '-' in image_id else image_id,
                "image_path": image_path,
                "nodules": nodules,
                "has_nodules": has_nodules,
                "projection": "Unknown"
            })
            
            if max_samples and len(dataset) >= max_samples:
                break
    
    # If still no dataset, create synthetic dataset for model architecture testing
    if not dataset:
        print("WARNING: Creating synthetic dataset for model testing...")
        image_files = glob.glob(f"{IMAGE_DIR}/**/*.dcm", recursive=True)
        
        if not image_files:
            image_files = glob.glob(f"{IMAGE_DIR}/**/*.*", recursive=True)
        
        if image_files:
            for i, image_path in enumerate(image_files[:max_samples or 100]):
                image_id = f"synthetic_{i}"
                
                # Create synthetic annotation
                has_nodules = random.choice([True, False])
                nodules = []
                
                if has_nodules:
                    # Load image to get dimensions
                    img = load_dicom_image(image_path)
                    if img is not None:
                        h, w = img.shape[:2]
                        
                        # Create random polygon
                        num_points = random.randint(3, 8)
                        center_x = random.randint(w//4, 3*w//4)
                        center_y = random.randint(h//4, 3*h//4)
                        radius = random.randint(10, min(w, h)//8)
                        
                        coords = []
                        for j in range(num_points):
                            angle = j * (2 * np.pi / num_points)
                            x = int(center_x + radius * np.cos(angle))
                            y = int(center_y + radius * np.sin(angle))
                            coords.append((x, y))
                        
                        nodules.append({
                            "coords": coords,
                            "subtlety": random.randint(1, 5),
                            "malignancy": random.randint(1, 5)
                        })
                
                dataset.append({
                    "image_id": image_id,
                    "case_id": f"case_{i}",
                    "image_path": image_path,
                    "nodules": nodules,
                    "has_nodules": has_nodules,
                    "projection": random.choice(["Frontal", "Lateral"])
                })
    
    # Balance dataset if needed
    if dataset:
        # Count positives and negatives
        positives = sum(1 for item in dataset if item["has_nodules"])
        negatives = len(dataset) - positives
        
        print(f"Dataset initially has {positives} positive and {negatives} negative samples")
        
        # Balance if too skewed
        if positives > 0 and negatives > 0 and (positives / negatives > 3 or negatives / positives > 3):
            if positives > negatives:
                target_count = min(negatives * 3, positives)
                positive_samples = random.sample([item for item in dataset if item["has_nodules"]], target_count)
                negative_samples = [item for item in dataset if not item["has_nodules"]]
                dataset = positive_samples + negative_samples
            else:
                target_count = min(positives * 3, negatives)
                negative_samples = random.sample([item for item in dataset if not item["has_nodules"]], target_count)
                positive_samples = [item for item in dataset if item["has_nodules"]]
                dataset = positive_samples + negative_samples
            
            print(f"Balanced dataset to {len([item for item in dataset if item['has_nodules']])} positive and {len([item for item in dataset if not item['has_nodules']])} negative samples")
    
    print(f"Final dataset has {len(dataset)} entries")
    return dataset

# ...existing code...

def preprocess_dataset(dataset, img_size=(224, 224), batch_size=32):
    print("\nPreprocessing dataset...")
    
    if not dataset:
        raise ValueError("Dataset is empty. Cannot preprocess.")
    
    # Split dataset
    train_data, temp_data = train_test_split(dataset, test_size=0.3, random_state=42)
    val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
    
    print(f"Split dataset into Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}")
    
    def generate_data(data_list, augment=False):
        print(f"Generating data from {len(data_list)} samples {'with' if augment else 'without'} augmentation...")
        
        for item in data_list:
            try:
                image_path = item["image_path"]
                img = load_dicom_image(image_path)
                
                if img is None:
                    print(f"Warning: Could not load image {image_path}")
                    continue
                    
                # Create mask for nodules
                mask = create_nodule_mask(img.shape, item["nodules"])
                
                # Resize image and mask
                img_resized = cv2.resize(img, img_size)
                mask_resized = cv2.resize(mask, img_size)
                
                # Data augmentation for training
                if augment:
                    # Random horizontal flip
                    if random.random() > 0.5:
                        img_resized = cv2.flip(img_resized, 1)
                        mask_resized = cv2.flip(mask_resized, 1)
                    
                    # Random rotation
                    if random.random() > 0.5:
                        angle = random.uniform(-15, 15)
                        M = cv2.getRotationMatrix2D((img_size[0]//2, img_size[1]//2), angle, 1.0)
                        img_resized = cv2.warpAffine(img_resized, M, img_size)
                        mask_resized = cv2.warpAffine(mask_resized, M, img_size)
                    
                    # Random brightness and contrast
                    if random.random() > 0.5:
                        alpha = random.uniform(0.8, 1.2)  # Contrast
                        beta = random.uniform(-10, 10)    # Brightness
                        img_resized = cv2.convertScaleAbs(img_resized, alpha=alpha, beta=beta)
                
                # Normalize the image
                img_resized = img_resized / 255.0
                
                # Prepare binary mask
                mask_resized = (mask_resized > 0).astype(np.float32)
                
                # Create sequential data for BiLSTM
                # Here we simulate sequential data by creating multiple "views" of the same image
                seq_data = []
                for i in range(3):  # 3 frames for the sequence
                    # Zoom in/out effect
                    scale = 1.0 - (i * 0.1)
                    width = int(img_size[0] * scale)
                    height = int(img_size[1] * scale)
                    
                    if width < img_size[0] and height < img_size[1]:
                        x = (img_size[0] - width) // 2
                        y = (img_size[1] - height) // 2
                        
                        roi = img_resized[y:y+height, x:x+width]
                        roi = cv2.resize(roi, img_size)
                        seq_data.append(roi)
                    else:
                        seq_data.append(img_resized)
                
                seq_data = np.array(seq_data)
                
                # Use one-hot encoding for classification
                is_nodule = 1.0 if item["has_nodules"] else 0.0
                
                # Create sample
                sample = (
                    {
                        "cnn_input": np.expand_dims(np.expand_dims(img_resized, axis=-1), axis=0),
                        "vit_input": np.expand_dims(np.expand_dims(img_resized, axis=-1), axis=0),
                        "lstm_input": np.expand_dims(np.expand_dims(seq_data, axis=-1), axis=0)
                    }, 
                    {
                        "classification": np.array([is_nodule]),
                        "segmentation": np.expand_dims(np.expand_dims(mask_resized, axis=0), axis=-1)
                    }
                )
                
                yield sample
                
            except Exception as e:
                print(f"Error processing {item.get('image_path', 'unknown')}: {e}")
                continue
    
    def create_tf_dataset(data_list, augment=False):
        output_signature = (
            {
                "cnn_input": tf.TensorSpec(shape=(1, img_size[0], img_size[1], 1), dtype=tf.float32),
                "vit_input": tf.TensorSpec(shape=(1, img_size[0], img_size[1], 1), dtype=tf.float32),
                "lstm_input": tf.TensorSpec(shape=(1, 3, img_size[0], img_size[1], 1), dtype=tf.float32)
            },
            {
                "classification": tf.TensorSpec(shape=(1,), dtype=tf.float32),
                "segmentation": tf.TensorSpec(shape=(1, img_size[0], img_size[1], 1), dtype=tf.float32)
            }
        )
        
        # Create the dataset
        try:
            dataset = tf.data.Dataset.from_generator(
                lambda: generate_data(data_list, augment),
                output_signature=output_signature
            )
            
            # Unbatch the data (remove the batch dimension of 1)
            dataset = dataset.map(lambda x, y: (
                {
                    "cnn_input": x["cnn_input"][0],
                    "vit_input": x["vit_input"][0],
                    "lstm_input": x["lstm_input"][0]
                },
                {
                    "classification": y["classification"][0],
                    "segmentation": y["segmentation"][0]
                }
            ))
            
            # Recache for faster access
            dataset = dataset.cache()
            
            # Batch the data
            dataset = dataset.batch(batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            
            return dataset
            
        except Exception as e:
            print(f"Error creating TensorFlow dataset: {e}")
            raise
    
    # Create datasets
    train_dataset = create_tf_dataset(train_data, augment=True)
    val_dataset = create_tf_dataset(val_data, augment=False)
    test_dataset = create_tf_dataset(test_data, augment=False)
    
    return train_dataset, val_dataset, test_dataset

# ...existing code...
# Define Vision Transformer Block
def vision_transformer_block(inputs, dim, num_heads):
    # Layer Normalization 1
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    
    # Multi-Head Attention
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=dim // num_heads
    )(x, x)
    
    # Skip Connection 1
    x = layers.Add()([attention_output, inputs])
    
    # Layer Normalization 2
    y = layers.LayerNormalization(epsilon=1e-6)(x)
    
    # MLP
    y = layers.Dense(dim * 4, activation="gelu")(y)
    y = layers.Dense(dim)(y)
    y = layers.Dropout(0.1)(y)
    
    # Skip Connection 2
    out = layers.Add()([y, x])
    
    return out

# Build the Vision Transformer Module

def build_vit(img_size=(224, 224), patch_size=16, num_heads=8, transformer_layers=6):
    print("Building Vision Transformer...")
    
    inputs = layers.Input(shape=(img_size[0], img_size[1], 1))
    
    # Patch embedding
    x = layers.Conv2D(
        filters=64,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid"
    )(inputs)
    
    # Calculate dimensions
    h = img_size[0] // patch_size
    w = img_size[1] // patch_size
    
    # Reshape patches
    x = layers.Reshape((h * w, 64))(x)
    
    # Add position embedding
    positions = tf.range(start=0, limit=h * w, delta=1)
    pos_embedding = layers.Embedding(input_dim=h * w, output_dim=64)(positions)
    x = x + pos_embedding
    
    # Transformer layers
    for _ in range(transformer_layers):
        # Layer Normalization 1
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        
        # Multi-Head Self Attention
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=64 // num_heads,
            value_dim=64 // num_heads
        )(x1, x1)
        
        # Skip Connection 1
        x2 = layers.Add()([attention_output, x])
        
        # Layer Normalization 2
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        
        # MLP
        x3 = layers.Dense(64 * 4, activation="gelu")(x3)
        x3 = layers.Dense(64)(x3)
        x3 = layers.Dropout(0.1)(x3)
        
        # Skip Connection 2
        x = layers.Add()([x2, x3])
    
    # Final Layer Norm
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    
    # Global Average Pooling
    x = layers.GlobalAveragePooling1D()(x)
    
    return tf.keras.Model(inputs=inputs, outputs=x, name="vision_transformer")

class ExpandDimsLayer(layers.Layer):
    def call(self, inputs):
        return tf.expand_dims(inputs, -1)

class TileLayer(layers.Layer):
    def call(self, inputs):
        return tf.tile(inputs, [1, 1, 1, 3])

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2], 3)



# Build the Hybrid Model

def build_hybrid_model(img_size=(224, 224)):
    print("Building hybrid model...")
    
    # 1. Input layers
    cnn_input = layers.Input(shape=(img_size[0], img_size[1], 1), name="cnn_input")
    vit_input = layers.Input(shape=(img_size[0], img_size[1], 1), name="vit_input")
    lstm_input = layers.Input(shape=(3, img_size[0], img_size[1], 1), name="lstm_input")
    
    # 2. CNN Branch
    print("Building CNN branch...")
    x_cnn = TileLayer()(cnn_input)
    efficient_net = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=(img_size[0], img_size[1], 3),
        pooling='avg'
    )
    for layer in efficient_net.layers[:100]:
        layer.trainable = False
    x_cnn = efficient_net(x_cnn)
    x_cnn = layers.Dense(128, activation='relu')(x_cnn)
    
    # 3. ViT Branch (in build_hybrid_model function)
    print("Building ViT branch...")
    vit_model = build_vit(img_size=img_size)
    x_vit = vit_model(vit_input)
    x_vit = layers.Dense(128, activation='relu')(x_vit)
    
    # 4. BiLSTM Branch
    print("Building BiLSTM branch...")
    x_lstm = layers.TimeDistributed(layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same', activation='relu'))(lstm_input)
    x_lstm = layers.TimeDistributed(layers.BatchNormalization())(x_lstm)
    
    x_lstm = layers.TimeDistributed(layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))(x_lstm)
    x_lstm = layers.TimeDistributed(layers.BatchNormalization())(x_lstm)
    
    x_lstm = layers.TimeDistributed(layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))(x_lstm)
    x_lstm = layers.TimeDistributed(layers.BatchNormalization())(x_lstm)
    
    # Global average pooling instead of flatten
    x_lstm = layers.TimeDistributed(layers.GlobalAveragePooling2D())(x_lstm)
    
    # BiLSTM layer
    x_lstm = layers.Bidirectional(layers.LSTM(64, return_sequences=False))(x_lstm)
    x_lstm = layers.Dense(128, activation='relu')(x_lstm)
    x_lstm = layers.Dropout(0.3)(x_lstm)

    
    # 5. Feature Fusion
    combined_features = layers.concatenate([x_cnn, x_vit, x_lstm])
    shared_features = layers.Dense(256, activation='relu')(combined_features)
    shared_features = layers.BatchNormalization()(shared_features)
    shared_features = layers.Dropout(0.3)(shared_features)
    
    # 6. Classification Branch
    classification_output = layers.Dense(1, activation='sigmoid', name="classification")(shared_features)
    
    # 7. Segmentation Branch
    print("Building segmentation branch...")
    initial_size = img_size[0] // 32
    initial_channels = 256
    
    x_seg = layers.Dense(initial_size * initial_size * initial_channels)(shared_features)
    x_seg = layers.Reshape((initial_size, initial_size, initial_channels))(x_seg)
    
    # Upsampling blocks
    for filters in [128, 64, 32, 16]:
        x_seg = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x_seg)
        x_seg = layers.BatchNormalization()(x_seg)
        x_seg = layers.Activation('relu')(x_seg)
        x_seg = layers.Conv2D(filters, 3, padding='same', activation='relu')(x_seg)
    
    # Final upsampling and convolution
    x_seg = layers.Conv2DTranspose(8, 3, strides=2, padding='same')(x_seg)
    x_seg = layers.BatchNormalization()(x_seg)
    x_seg = layers.Activation('relu')(x_seg)
    segmentation_output = layers.Conv2D(1, 1, activation='sigmoid', name="segmentation")(x_seg)
    
    # Create model
    model = models.Model(
        inputs=[cnn_input, vit_input, lstm_input],
        outputs=[classification_output, segmentation_output]
    )
    
    # Compile model
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-4),
        loss={
            "classification": "binary_crossentropy",
            "segmentation": "binary_crossentropy"
        },
        metrics={
            "classification": ["accuracy", tf.keras.metrics.AUC()],
            "segmentation": ["accuracy", tf.keras.metrics.IoU(num_classes=2, target_class_ids=[1])]
        },
        loss_weights={
            "classification": 1.0,
            "segmentation": 0.5
        }
    )
    
    return model


def train_model(model, train_dataset, val_dataset, epochs=5):
    # Create callbacks
    checkpoint = ModelCheckpoint(
        "best_model.keras",
        monitor="val_classification_accuracy",
        save_best_only=True,
        mode="max",
        verbose=1
    )
    
    early_stopping = EarlyStopping(
        monitor="val_classification_accuracy",
        patience=5,
        mode="max",
        restore_best_weights=True,
        verbose=1
    )
    
    reduce_lr = ReduceLROnPlateau(
        monitor="val_classification_accuracy",
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        mode="max",
        verbose=1
    )
    
    # TensorBoard for visualization
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard = TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        write_graph=True,
        update_freq="epoch"
    )
    
    # Train model
    print("Starting training...")
    start_time = time.time()
    
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint, early_stopping, reduce_lr, tensorboard],
        verbose=1
    )
    
    total_time = time.time() - start_time
    print(f"Training completed in {total_time/60:.2f} minutes")
    
    return history, model

# ...existing code...

# Evaluate model
def evaluate_model(model, test_dataset):
    print("Evaluating model...")
    results = model.evaluate(test_dataset, verbose=1)
    
    metrics = model.metrics_names
    for i, metric in enumerate(metrics):
        print(f"{metric}: {results[i]:.4f}")
    
    return results

# Plot training history
def plot_training_history(history):
    plt.figure(figsize=(12, 8))
    
    # Plot classification accuracy
    plt.subplot(2, 2, 1)
    plt.plot(history.history['classification_accuracy'])
    plt.plot(history.history['val_classification_accuracy'])
    plt.title('Classification Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot classification loss
    plt.subplot(2, 2, 2)
    plt.plot(history.history['classification_loss'])
    plt.plot(history.history['val_classification_loss'])
    plt.title('Classification Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot segmentation accuracy
    plt.subplot(2, 2, 3)
    plt.plot(history.history['segmentation_accuracy'])
    plt.plot(history.history['val_segmentation_accuracy'])
    plt.title('Segmentation Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot segmentation loss
    plt.subplot(2, 2, 4)
    plt.plot(history.history['segmentation_loss'])
    plt.plot(history.history['val_segmentation_loss'])
    plt.title('Segmentation Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# Prediction function
def predict_nodules(model, image_path, img_size=(224, 224)):
    # Load and preprocess image
    img = load_dicom_image(image_path)
    if img is None:
        print(f"Failed to load image: {image_path}")
        return None, None
    
    # Resize and normalize
    img_resized = cv2.resize(img, img_size)
    img_normalized = img_resized / 255.0
    
    # Create sequential data
    seq_data = []
    for i in range(3):
        scale = 1.0 - (i * 0.1)
        width = int(img_size[0] * scale)
        height = int(img_size[1] * scale)
        
        if width < img_size[0] and height < img_size[1]:
            x = (img_size[0] - width) // 2
            y = (img_size[1] - height) // 2
            
            roi = img_normalized[y:y+height, x:x+width]
            roi = cv2.resize(roi, img_size)
            seq_data.append(roi)
        else:
            seq_data.append(img_normalized)
    
    seq_data = np.array(seq_data)
    
    # Prepare inputs
    inputs = {
        "cnn_input": np.expand_dims(img_normalized, axis=0),
        "vit_input": np.expand_dims(img_normalized, axis=0),
        "lstm_input": np.expand_dims(seq_data, axis=0)
    }
    
    # Make prediction
    nodule_prob, nodule_mask = model.predict(inputs)
    
    # Post-process mask
    nodule_mask = nodule_mask[0]
    nodule_mask = (nodule_mask > 0.5).astype(np.uint8) * 255
    
    return nodule_prob[0][0], nodule_mask

# Visualize predictions
def visualize_prediction(image_path, model, img_size=(224, 224)):
    # Load image
    img = load_dicom_image(image_path)
    if img is None:
        print(f"Failed to load image: {image_path}")
        return
    
    # Resize image
    img_resized = cv2.resize(img, img_size)
    
    # Make prediction
    prob, mask = predict_nodules(model, image_path, img_size)
    
    # Plot results
    plt.figure(figsize=(12, 4))
    
    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(img_resized, cmap='gray')
    plt.title(f'Original Image')
    plt.axis('off')
    
    # Prediction mask
    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap='jet', alpha=0.7)
    plt.title(f'Predicted Mask')
    plt.axis('off')
    
    # Overlay
    plt.subplot(1, 3, 3)
    plt.imshow(img_resized, cmap='gray')
    plt.imshow(mask, cmap='jet', alpha=0.4)
    plt.title(f'Nodule Prob: {prob:.2f}')
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('prediction_visualization.png')
    plt.show()

# Main execution
def main():
    print("Starting Nodule Detection Pipeline")
    
    # Prepare dataset (limit samples for faster testing)
    # For full training, remove the max_samples parameter
    dataset = prepare_dataset(max_samples=1000)
    
    # Preprocess dataset
    img_size = (224, 224)
    batch_size = 8
    train_dataset, val_dataset, test_dataset = preprocess_dataset(dataset, img_size, batch_size)
    
    # Build model
    model = build_hybrid_model(img_size)
    model.summary()
    
    # Train model
    history, trained_model = train_model(model, train_dataset, val_dataset, epochs=5)
    
    # Evaluate model
    evaluate_model(trained_model, test_dataset)
    
    # Plot training history
    plot_training_history(history)
    
    # Save model
    trained_model.save('nodule_detection_model.h5')
    print("Model saved as 'nodule_detection_model.h5'")
    
    # Test prediction on a sample image
    if dataset:
        sample_image_path = dataset[0]["image_path"]
        visualize_prediction(sample_image_path, trained_model, img_size)

if __name__ == "__main__":
    main()

GPU Available: []
Loading metadata from CSV...
Loaded 463 entries from metadata CSV
          case_id               image_id projection findings
0  LIDC-IDRI-0001  LIDC-IDRI-0001-000001    Frontal  Nodules
1  LIDC-IDRI-0001  LIDC-IDRI-0001-000002    Lateral  Nodules
2  LIDC-IDRI-0003  LIDC-IDRI-0003-000001    Frontal  Nodules
3  LIDC-IDRI-0003  LIDC-IDRI-0003-000002    Lateral  Nodules
4  LIDC-IDRI-0004  LIDC-IDRI-0004-000001    Frontal  Nodules

Exploring dataset structure:
Found 11 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\157
Sample files: ['158.xml', '159.xml', '160.xml']
Found 232 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\185
Sample files: ['068.xml', '069.xml', '070.xml']
Found 300 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\186
Sample files: ['000.xml', '001.xml', '002.xml']
Found 300 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\187
Sample files: ['000.xml', '001.xml', '002.xml']
Found 300 XML files in ./data

Parsing XML files: 100%|██████████| 1319/1319 [00:43<00:00, 30.26it/s]


Successfully parsed 0 XML files with 0 annotations


Processing from metadata: 100%|██████████| 463/463 [00:01<00:00, 287.95it/s]


Dataset initially has 440 positive and 23 negative samples
Balanced dataset to 69 positive and 23 negative samples
Final dataset has 92 entries

Preprocessing dataset...
Split dataset into Train: 64, Validation: 14, Test: 14
Building hybrid model...
Building CNN branch...
Building ViT branch...
Building Vision Transformer...
Building BiLSTM branch...
Building segmentation branch...


Starting training...
Epoch 1/5




ValueError: Exception encountered when calling Functional.call().

[1mInput 0 of layer "vision_transformer" is incompatible with the layer: expected shape=(None, 224, 224, 1), found shape=(None, 3, 224, 224)[0m

Arguments received by Functional.call():
  • inputs={'cnn_input': 'tf.Tensor(shape=(None, 224, 224, 1), dtype=float32)', 'vit_input': 'tf.Tensor(shape=(None, 224, 224, 1), dtype=float32)', 'lstm_input': 'tf.Tensor(shape=(None, 3, 224, 224, 1), dtype=float32)'}
  • training=True
  • mask={'cnn_input': 'None', 'vit_input': 'None', 'lstm_input': 'None'}

In [None]:
import os
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import glob
import random
import time
import datetime
import pydicom

# Set seed for reproducibility
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Check for GPU
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# Dataset paths
BASE_DIR = "./dataset"
IMAGE_DIR = os.path.join(BASE_DIR, "images/images")
ANNOTATION_DIR = os.path.join(BASE_DIR, "annotations/annotations/tcia-lidc-xml")
CSV_FILE = os.path.join(BASE_DIR, "lidc_metadata.csv")

# Load metadata
print("Loading metadata from CSV...")
try:
    metadata_df = pd.read_csv(CSV_FILE)
    print(f"Loaded {len(metadata_df)} entries from metadata CSV")
    print(metadata_df.head())
except Exception as e:
    print(f"Error loading CSV: {e}")
    # Create a backup empty dataframe if loading fails
    metadata_df = pd.DataFrame(columns=['case_id', 'image_id', 'projection', 'findings'])

# Print dataset structure for debugging
print("\nExploring dataset structure:")
for subdir in ['157', '185', '186', '187', '188', '189']:
    subdir_path = os.path.join(ANNOTATION_DIR, subdir)
    if os.path.exists(subdir_path):
        xml_files = [f for f in os.listdir(subdir_path) if f.endswith('.xml')]
        print(f"Found {len(xml_files)} XML files in {subdir_path}")
        if xml_files:
            print(f"Sample files: {xml_files[:3]}")

# Improved XML parsing function
def parse_xml_annotations(xml_path):
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        
        # Get the namespace from root tag
        namespace = ''
        if '}' in root.tag:
            namespace = '{' + root.tag.split('}')[0].split('{')[1] + '}'
        
        # Get reading session elements
        reading_sessions = root.findall(f".//{namespace}readingSession")
        if not reading_sessions:
            reading_sessions = root.findall(".//readingSession")
        
        # Get series UID
        series_uid_elem = root.find(f".//{namespace}SeriesInstanceUid") or \
                         root.find(".//SeriesInstanceUid") or \
                         root.find(f".//{namespace}seriesuid") or \
                         root.find(".//seriesuid")
        
        image_id = series_uid_elem.text if series_uid_elem is not None else os.path.splitext(os.path.basename(xml_path))[0]
        
        nodules = []
        for session in reading_sessions:
            # Find nodule elements
            nodule_elems = session.findall(f".//{namespace}unblindedReadNodule") or \
                          session.findall(".//unblindedReadNodule")
            
            for nodule in nodule_elems:
                roi_elem = nodule.find(f".//{namespace}roi") or nodule.find(".//roi")
                if roi_elem is None:
                    continue
                
                # Get coordinates
                coords = []
                edge_maps = roi_elem.findall(f".//{namespace}edgeMap") or roi_elem.findall(".//edgeMap")
                for edge_map in edge_maps:
                    x_elem = edge_map.find(f".//{namespace}xCoord") or edge_map.find(".//xCoord")
                    y_elem = edge_map.find(f".//{namespace}yCoord") or edge_map.find(".//yCoord")
                    
                    if x_elem is not None and y_elem is not None:
                        try:
                            x = int(float(x_elem.text))
                            y = int(float(y_elem.text))
                            coords.append((x, y))
                        except (ValueError, TypeError):
                            continue
                
                if len(coords) < 3:  # Need at least 3 points for a polygon
                    continue
                
                # Get characteristics
                chars_elem = nodule.find(f".//{namespace}characteristics") or nodule.find(".//characteristics")
                subtlety = malignancy = 3  # Default values
                
                if chars_elem is not None:
                    subt_elem = chars_elem.find(f".//{namespace}subtlety") or chars_elem.find(".//subtlety")
                    malig_elem = chars_elem.find(f".//{namespace}malignancy") or chars_elem.find(".//malignancy")
                    
                    if subt_elem is not None and subt_elem.text:
                        subtlety = int(float(subt_elem.text))
                    if malig_elem is not None and malig_elem.text:
                        malignancy = int(float(malig_elem.text))
                
                nodules.append({
                    "coords": coords,
                    "subtlety": subtlety,
                    "malignancy": malignancy
                })
        
        if nodules:
            print(f"Successfully parsed {len(nodules)} nodules from {os.path.basename(xml_path)}")
            return image_id, nodules
        
        return None, []
        
    except Exception as e:
        print(f"Error parsing {xml_path}: {str(e)}")
        return None, []

# Function to find all XML files
def find_all_xml_files():
    print("Searching for XML files...")
    xml_files = []
    for root, dirs, files in os.walk(ANNOTATION_DIR):
        for file in files:
            if file.endswith('.xml'):
                xml_files.append(os.path.join(root, file))
    print(f"Found {len(xml_files)} XML files")
    return xml_files

# Function to find image files by pattern matching
def find_image_files(image_id):
    # Try different patterns to match images
    patterns = [
        f"{IMAGE_DIR}/**/{image_id}*.dcm",
        f"{IMAGE_DIR}/**/*{image_id}*.dcm",
        f"{IMAGE_DIR}/**/{image_id.replace('-', '_')}*.dcm"
    ]
    
    for pattern in patterns:
        image_paths = glob.glob(pattern, recursive=True)
        if image_paths:
            return image_paths[0]
    
    return None

# Function to create binary masks from nodule coordinates
def create_nodule_mask(image_shape, nodules):
    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    
    for nodule in nodules:
        coords = nodule["coords"]
        if len(coords) > 2:
            pts = np.array(coords, np.int32)
            pts = pts.reshape((-1, 1, 2))
            cv2.fillPoly(mask, [pts], 255)
    
    return mask

# Improved function to load DICOM images
def load_dicom_image(image_path):
    try:
        # Try to load as DICOM
        ds = pydicom.dcmread(image_path)
        img = ds.pixel_array
        
        # Normalize to 0-255
        img = img - np.min(img)
        if np.max(img) > 0:
            img = img / np.max(img) * 255
        img = img.astype(np.uint8)
        
        # Check if image needs to be inverted (black background)
        if np.mean(img) > 127:
            img = 255 - img
            
        return img
    except Exception as e:
        print(f"Error loading DICOM {image_path}: {e}")
        
        # Try to load as regular image
        try:
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is not None:
                return img
        except:
            pass
            
        return None

# Function to generate dataset with fallback mechanisms
def prepare_dataset(max_samples=None):
    print("\nPreparing dataset...")
    
    # Validate directories
    if not os.path.exists(IMAGE_DIR):
        raise ValueError(f"Image directory not found: {IMAGE_DIR}")
    if not os.path.exists(ANNOTATION_DIR):
        raise ValueError(f"Annotation directory not found: {ANNOTATION_DIR}")
    
    # Find all XML files
    xml_files = find_all_xml_files()
    if not xml_files:
        raise ValueError("No XML files found in the annotation directory")
    
    print(f"Found {len(xml_files)} XML files")
    
    # Parse XML annotations with progress bar
    annotations = {}
    successful_files = 0
    total_nodules = 0
    
    for xml_file in tqdm(xml_files, desc="Parsing XML files"):
        image_id, nodules = parse_xml_annotations(xml_file)
        if image_id and nodules:
            annotations[image_id] = nodules
            successful_files += 1
            total_nodules += len(nodules)
    
    print(f"\nParsing Summary:")
    print(f"Successfully parsed {successful_files} XML files")
    print(f"Total nodules found: {total_nodules}")
    print(f"Total unique images with annotations: {len(annotations)}")
    
    # Create dataset
    dataset = []
    
    # Try metadata-based approach first
    if not metadata_df.empty:
        print("\nProcessing from metadata...")
        for _, row in tqdm(metadata_df.iterrows(), desc="Processing metadata entries"):
            image_id = str(row['image_id'])
            case_id = str(row['case_id'])
            
            # Find image path
            image_path = find_image_files(image_id)
            if not image_path:
                continue
            
            # Verify image can be loaded
            img = load_dicom_image(image_path)
            if img is None:
                continue
            
            # Get nodule information
            nodules = annotations.get(image_id, [])
            has_nodules = bool(nodules) or ('findings' in row and 'Nodules' in str(row['findings']))
            
            dataset.append({
                "image_id": image_id,
                "case_id": case_id,
                "image_path": image_path,
                "nodules": nodules,
                "has_nodules": has_nodules,
                "projection": row.get('projection', 'Unknown')
            })
    
    # If dataset is too small, try direct file scanning
    if len(dataset) < (max_samples or 100):
        print("\nScanning image files directly...")
        image_files = glob.glob(f"{IMAGE_DIR}/**/*.dcm", recursive=True)
        
        for image_path in tqdm(image_files, desc="Processing image files"):
            if any(d['image_path'] == image_path for d in dataset):
                continue
                
            image_id = os.path.splitext(os.path.basename(image_path))[0]
            
            # Verify image can be loaded
            img = load_dicom_image(image_path)
            if img is None:
                continue
            
            nodules = annotations.get(image_id, [])
            dataset.append({
                "image_id": image_id,
                "case_id": image_id.split('-')[0],
                "image_path": image_path,
                "nodules": nodules,
                "has_nodules": bool(nodules),
                "projection": "Unknown"
            })
    
    # Balance dataset
    if dataset:
        positives = sum(1 for item in dataset if item["has_nodules"])
        negatives = len(dataset) - positives
        
        print(f"\nDataset balance:")
        print(f"Positive samples (with nodules): {positives}")
        print(f"Negative samples (without nodules): {negatives}")
        
        # Balance if needed
        if positives > 0 and negatives > 0:
            if positives > negatives * 3:
                target_count = negatives * 3
                positive_samples = random.sample([item for item in dataset if item["has_nodules"]], target_count)
                negative_samples = [item for item in dataset if not item["has_nodules"]]
                dataset = positive_samples + negative_samples
            elif negatives > positives * 3:
                target_count = positives * 3
                negative_samples = random.sample([item for item in dataset if not item["has_nodules"]], target_count)
                positive_samples = [item for item in dataset if item["has_nodules"]]
                dataset = positive_samples + negative_samples
            
            print(f"After balancing:")
            print(f"Positive samples: {len([item for item in dataset if item['has_nodules']])}")
            print(f"Negative samples: {len([item for item in dataset if not item['has_nodules']])}")
    
    if not dataset:
        raise ValueError("Failed to create dataset. No valid images found.")
    
    print(f"\nFinal dataset has {len(dataset)} entries")
    return dataset

# ...existing code...

def preprocess_dataset(dataset, img_size=(224, 224), batch_size=32):
    print("\nPreprocessing dataset...")
    
    def generate_data(data_list, augment=False):
        while True:
            for item in data_list:
                try:
                    image_path = item["image_path"]
                    img = load_dicom_image(image_path)
                    
                    if img is None:
                        continue
                    
                    # Ensure consistent image shape
                    img_resized = cv2.resize(img, img_size)
                    img_normalized = img_resized.astype(np.float32) / 255.0
                    img_normalized = np.expand_dims(img_normalized, axis=-1)
                    
                    yield img_normalized, float(item["has_nodules"])
                    
                except Exception as e:
                    print(f"Error processing {image_path}: {e}")
                    continue
    
    # Create datasets with correct shapes
    output_signature = (
        tf.TensorSpec(shape=(img_size[0], img_size[1], 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32)
    )
    
    # Split dataset
    train_data, temp_data = train_test_split(dataset, test_size=0.3, random_state=42)
    val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
    
    # Create tf.data.Dataset objects
    train_dataset = tf.data.Dataset.from_generator(
        lambda: generate_data(train_data, augment=True),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    val_dataset = tf.data.Dataset.from_generator(
        lambda: generate_data(val_data, augment=False),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    test_dataset = tf.data.Dataset.from_generator(
        lambda: generate_data(test_data, augment=False),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    return train_dataset, val_dataset, test_dataset

# ...existing code...
# Define Vision Transformer Block
def vision_transformer_block(inputs, dim, num_heads):
    # Layer Normalization 1
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    
    # Multi-Head Attention
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=dim // num_heads
    )(x, x)
    
    # Skip Connection 1
    x = layers.Add()([attention_output, inputs])
    
    # Layer Normalization 2
    y = layers.LayerNormalization(epsilon=1e-6)(x)
    
    # MLP
    y = layers.Dense(dim * 4, activation="gelu")(y)
    y = layers.Dense(dim)(y)
    y = layers.Dropout(0.1)(y)
    
    # Skip Connection 2
    out = layers.Add()([y, x])
    
    return out

# Build the Vision Transformer Module

def build_vit(img_size=(224, 224), patch_size=16, num_heads=8, transformer_layers=6):
    print("Building Vision Transformer...")
    
    # Input shape is (batch_size, height, width, channels)
    inputs = layers.Input(shape=(img_size[0], img_size[1], 1))
    
    # Ensure correct shape before processing
    x = layers.Permute((2, 3, 1))(inputs) if len(inputs.shape) == 4 and inputs.shape[1] == 3 else inputs
    
    # Convert to 3 channels using TileLayer
    x = TileLayer()(x)
    
    # Patch embedding
    patches = layers.Conv2D(
        filters=64,
        kernel_size=patch_size,
        strides=patch_size,
        padding="same",
        name="patch_embedding"
    )(x)
    
    # Calculate patch dimensions
    patch_h = img_size[0] // patch_size
    patch_w = img_size[1] // patch_size
    num_patches = patch_h * patch_w
    
    # Reshape patches
    x = layers.Reshape((num_patches, 64))(patches)
    
    # Add position embedding
    positions = tf.range(start=0, limit=num_patches, delta=1)
    pos_embed = layers.Embedding(input_dim=num_patches, output_dim=64)(positions)
    x = x + pos_embed
    
    # Transformer blocks
    for i in range(transformer_layers):
        x = vision_transformer_block(x, dim=64, num_heads=num_heads)
    
    # Final processing
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    return tf.keras.Model(inputs=inputs, outputs=x, name="vision_transformer")

class ExpandDimsLayer(layers.Layer):
    def call(self, inputs):
        return tf.expand_dims(inputs, -1)

class TileLayer(layers.Layer):
    def call(self, inputs):
        # Ensure input is float32
        x = tf.cast(inputs, tf.float32)
        
        # Handle different input shapes
        if len(x.shape) == 4:  # (batch, height, width, channels)
            if x.shape[-1] == 1:
                # Tile the channels dimension
                return tf.tile(x, [1, 1, 1, 3])
            elif x.shape[1] == 3:  # If channels are in wrong position
                # Transpose to move channels to end
                return tf.transpose(x, [0, 2, 3, 1])
            return x
        elif len(x.shape) == 3:  # (height, width, channels)
            if x.shape[-1] == 1:
                return tf.tile(x, [1, 1, 3])
            return x
        else:
            raise ValueError(f"Unexpected input shape: {x.shape}")

    def compute_output_shape(self, input_shape):
        return (*input_shape[:-1], 3)



# Build the Hybrid Model

def build_hybrid_model(img_size=(224, 224)):
    print("Building hybrid model...")
    
    # Input layers with explicit shapes
    cnn_input = layers.Input(shape=(img_size[0], img_size[1], 1), name="cnn_input")
    vit_input = layers.Input(shape=(img_size[0], img_size[1], 1), name="vit_input")
    lstm_input = layers.Input(shape=(3, img_size[0], img_size[1], 1), name="lstm_input")
    
    # CNN Branch
    x_cnn = TileLayer()(cnn_input)
    efficient_net = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=(img_size[0], img_size[1], 3),
        pooling='avg'
    )
    for layer in efficient_net.layers[:100]:
        layer.trainable = False
    x_cnn = efficient_net(x_cnn)
    x_cnn = layers.Dense(128, activation='relu')(x_cnn)
    
    # ViT Branch
    print("Building ViT branch...")
    vit_model = build_vit(img_size=img_size)
    # Ensure input shape matches
    x_vit = vit_model(vit_input)
    x_vit = layers.Dense(128, activation='relu')(x_vit)
    
    # BiLSTM Branch
    print("Building BiLSTM branch...")
    # Reshape LSTM input to handle the sequence
    x_lstm = layers.Permute((2, 3, 1, 4))(lstm_input)
    x_lstm = layers.Reshape((img_size[0], img_size[1], 3))(x_lstm)
    
    x_lstm = layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same')(x_lstm)
    x_lstm = layers.BatchNormalization()(x_lstm)
    x_lstm = layers.Activation('relu')(x_lstm)
    
    x_lstm = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same')(x_lstm)
    x_lstm = layers.BatchNormalization()(x_lstm)
    x_lstm = layers.Activation('relu')(x_lstm)
    
    x_lstm = layers.GlobalAveragePooling2D()(x_lstm)
    x_lstm = layers.Reshape((1, -1))(x_lstm)
    
    x_lstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x_lstm)
    x_lstm = layers.Bidirectional(layers.LSTM(32))(x_lstm)
    x_lstm = layers.Dense(128, activation='relu')(x_lstm)
    x_lstm = layers.Dropout(0.3)(x_lstm)
    
    # 5. Feature Fusion
    combined_features = layers.concatenate([x_cnn, x_vit, x_lstm])
    shared_features = layers.Dense(256, activation='relu')(combined_features)
    shared_features = layers.BatchNormalization()(shared_features)
    shared_features = layers.Dropout(0.3)(shared_features)
    
    # 6. Classification Branch
    classification_output = layers.Dense(1, activation='sigmoid', name="classification")(shared_features)
    
    # 7. Segmentation Branch
    print("Building segmentation branch...")
    initial_size = img_size[0] // 32
    initial_channels = 256
    
    x_seg = layers.Dense(initial_size * initial_size * initial_channels)(shared_features)
    x_seg = layers.Reshape((initial_size, initial_size, initial_channels))(x_seg)
    
    # Upsampling blocks
    for filters in [128, 64, 32, 16]:
        x_seg = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x_seg)
        x_seg = layers.BatchNormalization()(x_seg)
        x_seg = layers.Activation('relu')(x_seg)
        x_seg = layers.Conv2D(filters, 3, padding='same', activation='relu')(x_seg)
    
    # Final upsampling and convolution
    x_seg = layers.Conv2DTranspose(8, 3, strides=2, padding='same')(x_seg)
    x_seg = layers.BatchNormalization()(x_seg)
    x_seg = layers.Activation('relu')(x_seg)
    segmentation_output = layers.Conv2D(1, 1, activation='sigmoid', name="segmentation")(x_seg)
    
    # Create model
    model = models.Model(
        inputs=[cnn_input, vit_input, lstm_input],
        outputs=[classification_output, segmentation_output]
    )
    
    # Enable mixed precision training
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    
    # Compile model with updated settings
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-4),
        loss={
            "classification": "binary_crossentropy",
            "segmentation": "binary_crossentropy"
        },
        metrics={
            "classification": ["accuracy", tf.keras.metrics.AUC()],
            "segmentation": ["accuracy", tf.keras.metrics.IoU(num_classes=2, target_class_ids=[1])]
        },
        loss_weights={
            "classification": 1.0,
            "segmentation": 0.5
        }
    )
    
    return model


def train_model(model, train_dataset, val_dataset, epochs=5):
    # Create callbacks
    checkpoint = ModelCheckpoint(
        "best_model.keras",
        monitor="val_classification_accuracy",
        save_best_only=True,
        mode="max",
        verbose=1
    )
    
    early_stopping = EarlyStopping(
        monitor="val_classification_accuracy",
        patience=5,
        mode="max",
        restore_best_weights=True,
        verbose=1
    )
    
    reduce_lr = ReduceLROnPlateau(
        monitor="val_classification_accuracy",
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        mode="max",
        verbose=1
    )
    
    # TensorBoard for visualization
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard = TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        write_graph=True,
        update_freq="epoch"
    )
    
    # Train model
    print("Starting training...")
    start_time = time.time()
    
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint, early_stopping, reduce_lr, tensorboard],
        verbose=1
    )
    
    total_time = time.time() - start_time
    print(f"Training completed in {total_time/60:.2f} minutes")
    
    return history, model

# ...existing code...

# Evaluate model
def evaluate_model(model, test_dataset):
    print("Evaluating model...")
    results = model.evaluate(test_dataset, verbose=1)
    
    metrics = model.metrics_names
    for i, metric in enumerate(metrics):
        print(f"{metric}: {results[i]:.4f}")
    
    return results

# Plot training history
def plot_training_history(history):
    plt.figure(figsize=(12, 8))
    
    # Plot classification accuracy
    plt.subplot(2, 2, 1)
    plt.plot(history.history['classification_accuracy'])
    plt.plot(history.history['val_classification_accuracy'])
    plt.title('Classification Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot classification loss
    plt.subplot(2, 2, 2)
    plt.plot(history.history['classification_loss'])
    plt.plot(history.history['val_classification_loss'])
    plt.title('Classification Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot segmentation accuracy
    plt.subplot(2, 2, 3)
    plt.plot(history.history['segmentation_accuracy'])
    plt.plot(history.history['val_segmentation_accuracy'])
    plt.title('Segmentation Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    # Plot segmentation loss
    plt.subplot(2, 2, 4)
    plt.plot(history.history['segmentation_loss'])
    plt.plot(history.history['val_segmentation_loss'])
    plt.title('Segmentation Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# Prediction function
def predict_nodules(model, image_path, img_size=(224, 224)):
    # Load and preprocess image
    img = load_dicom_image(image_path)
    if img is None:
        print(f"Failed to load image: {image_path}")
        return None, None
    
    # Resize and normalize
    img_resized = cv2.resize(img, img_size)
    img_normalized = img_resized / 255.0
    
    # Create sequential data
    seq_data = []
    for i in range(3):
        scale = 1.0 - (i * 0.1)
        width = int(img_size[0] * scale)
        height = int(img_size[1] * scale)
        
        if width < img_size[0] and height < img_size[1]:
            x = (img_size[0] - width) // 2
            y = (img_size[1] - height) // 2
            
            roi = img_normalized[y:y+height, x:x+width]
            roi = cv2.resize(roi, img_size)
            seq_data.append(roi)
        else:
            seq_data.append(img_normalized)
    
    seq_data = np.array(seq_data)
    
    # Prepare inputs
    inputs = {
        "cnn_input": np.expand_dims(img_normalized, axis=0),
        "vit_input": np.expand_dims(img_normalized, axis=0),
        "lstm_input": np.expand_dims(seq_data, axis=0)
    }
    
    # Make prediction
    nodule_prob, nodule_mask = model.predict(inputs)
    
    # Post-process mask
    nodule_mask = nodule_mask[0]
    nodule_mask = (nodule_mask > 0.5).astype(np.uint8) * 255
    
    return nodule_prob[0][0], nodule_mask

# Visualize predictions
def visualize_prediction(image_path, model, img_size=(224, 224)):
    # Load image
    img = load_dicom_image(image_path)
    if img is None:
        print(f"Failed to load image: {image_path}")
        return
    
    # Resize image
    img_resized = cv2.resize(img, img_size)
    
    # Make prediction
    prob, mask = predict_nodules(model, image_path, img_size)
    
    # Plot results
    plt.figure(figsize=(12, 4))
    
    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(img_resized, cmap='gray')
    plt.title(f'Original Image')
    plt.axis('off')
    
    # Prediction mask
    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap='jet', alpha=0.7)
    plt.title(f'Predicted Mask')
    plt.axis('off')
    
    # Overlay
    plt.subplot(1, 3, 3)
    plt.imshow(img_resized, cmap='gray')
    plt.imshow(mask, cmap='jet', alpha=0.4)
    plt.title(f'Nodule Prob: {prob:.2f}')
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('prediction_visualization.png')
    plt.show()

# Main execution
def main():
    try:
        print("Starting Nodule Detection Pipeline")
        print("Checking GPU availability...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            print(f"Found {len(gpus)} GPU(s)")
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        else:
            print("No GPU found, using CPU")
        
        # Load and validate CSV data
        if not os.path.exists(CSV_FILE):
            print(f"Warning: CSV file not found: {CSV_FILE}")
            print("Will attempt to create dataset without metadata")
        else:
            print(f"Loading metadata from {CSV_FILE}")
        
        # Prepare dataset with proper error handling
        dataset = prepare_dataset(max_samples=1000)
        
        # Set appropriate batch size based on available memory
        if gpus:
            batch_size = 8
        else:
            batch_size = 4
        
        img_size = (224, 224)
        print(f"\nUsing batch size: {batch_size}, image size: {img_size}")
        
        # Create datasets
        train_dataset, val_dataset, test_dataset = preprocess_dataset(
            dataset, img_size, batch_size
        )
        
        # Build model
        print("\nBuilding model...")
        model = build_simple_model(img_size)
        model.summary()
        
        # Train model
        print("\nStarting training...")
        history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=10,
            callbacks=[
                ModelCheckpoint("best_model.keras", monitor="val_accuracy", 
                              save_best_only=True, mode="max"),
                EarlyStopping(monitor="val_accuracy", patience=5, 
                            restore_best_weights=True),
                ReduceLROnPlateau(monitor="val_accuracy", factor=0.5, 
                                patience=3, min_lr=1e-6)
            ],
            verbose=1
        )
        
        # Evaluate model
        print("\nEvaluating model...")
        results = model.evaluate(test_dataset, verbose=1)
        print(f"Test accuracy: {results[1]:.4f}")
        
        # Plot training history
        plt.figure(figsize=(10, 4))
        
        # Plot accuracy
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'])
        plt.plot(history.history['val_accuracy'])
        plt.title('Model Accuracy')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'])
        
        # Plot loss
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('Model Loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'])
        
        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.show()
        
        # Save model
        model_path = 'nodule_detection_model.h5'
        model.save(model_path)
        print(f"\nModel saved as '{model_path}'")
        
        # Test prediction
        if dataset:
            print("\nTesting prediction on a sample image...")
            sample_image_path = dataset[0]["image_path"]
            visualize_prediction(sample_image_path, model, img_size)
        
    except Exception as e:
        print(f"\nError in main execution: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

def build_simple_model(img_size=(224, 224)):
    print("Building EfficientNet model...")
    
    # Input layer
    inputs = layers.Input(shape=(img_size[0], img_size[1], 1), name="image_input")
    
    # Convert single channel to 3 channels
    x = TileLayer()(inputs)
    
    # EfficientNet backbone
    efficient_net = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=(img_size[0], img_size[1], 3)
    )
    
    # Freeze early layers for transfer learning
    for layer in efficient_net.layers[:100]:
        layer.trainable = False
    
    x = efficient_net(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    
    # Classification head
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(1, activation='sigmoid', name="classification")(x)
    
    # Create model
    model = models.Model(inputs=inputs, outputs=outputs)
    
    # Compile model
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-4),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC()]
    )
    
    return model

if __name__ == "__main__":
    main()

GPU Available: []
Loading metadata from CSV...
Loaded 463 entries from metadata CSV
          case_id               image_id projection findings
0  LIDC-IDRI-0001  LIDC-IDRI-0001-000001    Frontal  Nodules
1  LIDC-IDRI-0001  LIDC-IDRI-0001-000002    Lateral  Nodules
2  LIDC-IDRI-0003  LIDC-IDRI-0003-000001    Frontal  Nodules
3  LIDC-IDRI-0003  LIDC-IDRI-0003-000002    Lateral  Nodules
4  LIDC-IDRI-0004  LIDC-IDRI-0004-000001    Frontal  Nodules

Exploring dataset structure:
Found 11 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\157
Sample files: ['158.xml', '159.xml', '160.xml']
Found 232 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\185
Sample files: ['068.xml', '069.xml', '070.xml']
Found 300 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\186
Sample files: ['000.xml', '001.xml', '002.xml']
Found 300 XML files in ./dataset\annotations/annotations/tcia-lidc-xml\187
Sample files: ['000.xml', '001.xml', '002.xml']
Found 300 XML files in ./data

Parsing XML files: 100%|██████████| 1319/1319 [00:41<00:00, 31.75it/s] 



Parsing Summary:
Successfully parsed 0 XML files
Total nodules found: 0
Total unique images with annotations: 0

Processing from metadata...


Processing metadata entries: 463it [00:42, 10.96it/s]



Scanning image files directly...


Processing image files: 100%|██████████| 463/463 [00:00<00:00, 48746.49it/s]



Dataset balance:
Positive samples (with nodules): 440
Negative samples (without nodules): 23
After balancing:
Positive samples: 69
Negative samples: 23

Final dataset has 92 entries

Using batch size: 4, image size: (224, 224)

Preprocessing dataset...

Building model...
Building EfficientNet model...



Starting training...
Epoch 1/10
   2045/Unknown [1m6584s[0m 3s/step - accuracy: 0.6631 - auc_35: 0.7150 - loss: 0.6666