In [2]:
# Quick Copy-Paste Version - GPT-4 LLM Report Generator
# Just copy this entire cell to your notebook

import os
# Expose only GPU index 6 to this process (machine-specific value).
# NOTE(review): presumably placed here so it is set before torch is imported in a later
# cell and CUDA is initialised — confirm before moving this line.
os.environ["CUDA_VISIBLE_DEVICES"] = '6'
import base64
import numpy as np
from PIL import Image
from io import BytesIO
from typing import Dict, Optional

# OpenAI API for LLM
# The OpenAI SDK is optional at import time. Its availability is recorded in
# OPENAI_AVAILABLE, which MedicalReportGenerator.__init__ checks before creating a client.
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
    print("✓ OpenAI available")
except ImportError:
    OPENAI_AVAILABLE = False
    print("❌ Install: pip install openai")

#=============================================================================
# CONFIGURATION
#=============================================================================

class LLMConfig:
    """Configuration for LLM Report Generation.

    All settings are plain class attributes; the notebook reads them through a single
    ``LLMConfig()`` instance.
    """

    # API configuration.
    # The key is taken from the OPENAI_API_KEY environment variable instead of being
    # hardcoded, so a secret never ends up inside a saved or shared notebook.
    # An empty string means "not configured". It can still be overridden at runtime with
    # `llm_config.openai_api_key = "..."`.
    openai_api_key = os.environ.get("OPENAI_API_KEY", "")

    # Output settings
    output_dir = "./medical_reports"  # directory that receives the generated reports
    num_samples = 5                   # intended number of test samples to process

    # Model to use (can override XAI config).
    # The whole dict is handed to load_model_weights(); 'size' is also used as the square
    # input resolution when images are resized.
    # NOTE(review): 'folder' looks like the checkpoint sub-directory name — confirm in xai.
    model_to_use = {
        'name': 'densenet_161',
        'size': 224,
        'folder': '(224, 224)'
    }
    fold_to_analyze = 1  # cross-validation fold whose checkpoint is loaded

    # GPT Model Configuration
    gpt_model = 'gpt-4o'  # Options: 'gpt-4o', 'gpt-4-turbo', 'gpt-4'
    temperature = 0.3     # sampling temperature passed to the chat completion call

# Single configuration instance shared by the rest of the notebook.
llm_config = LLMConfig()

# Report only whether a key is present (the key itself is never printed) and which model
# will be used.
print(f"OpenAI API Key set: {bool(llm_config.openai_api_key)}")
print(f"GPT Model: {llm_config.gpt_model}")

#=============================================================================
# GRAD-CAM ATTENTION ANALYSIS
#=============================================================================

def analyze_gradcam_for_prompt(cam: np.ndarray) -> Dict:
    """Summarise a 2-D Grad-CAM heatmap as plain numbers for the LLM prompt.

    Args:
        cam: 2-D activation map (its shape is unpacked as ``(height, width)``).

    Returns:
        Dict with the keys
        ``mean_activation`` (mean of the whole map),
        ``max_activation`` (maximum of the whole map),
        ``attention_coverage`` (fraction of pixels strictly above the 75th percentile) and
        ``attention_regions`` (list of ``{'region', 'activation'}`` dicts for every window
        whose mean is strictly above the global mean, strongest first).
    """
    # Pixels strictly above the 75th percentile count as "high attention".
    cutoff = np.percentile(cam, 75)
    hot_mask = cam > cutoff

    overall_mean = float(cam.mean())
    peak = float(cam.max())
    coverage = float(hot_mask.sum() / hot_mask.size)

    # Four image quadrants plus a centred window covering the middle half of each axis.
    # The names are assigned purely by image position (left columns -> "medial").
    # NOTE(review): whether that matches anatomy presumably depends on knee laterality.
    rows, cols = cam.shape
    mid_r, mid_c = rows // 2, cols // 2
    windows = (
        ('upper_medial', (slice(None, mid_r), slice(None, mid_c))),
        ('upper_lateral', (slice(None, mid_r), slice(mid_c, None))),
        ('lower_medial', (slice(mid_r, None), slice(None, mid_c))),
        ('lower_lateral', (slice(mid_r, None), slice(mid_c, None))),
        ('central_joint', (slice(rows // 4, 3 * rows // 4), slice(cols // 4, 3 * cols // 4))),
    )

    above_average = []
    for key, window in windows:
        window_mean = float(cam[window].mean())
        if window_mean > overall_mean:
            above_average.append({
                'region': key.replace('_', ' ').title(),
                'activation': window_mean,
            })

    return {
        'mean_activation': overall_mean,
        'max_activation': peak,
        'attention_coverage': coverage,
        'attention_regions': sorted(
            above_average, key=lambda entry: entry['activation'], reverse=True
        ),
    }

#=============================================================================
# GPT-4 MEDICAL REPORT GENERATOR
#=============================================================================

class MedicalReportGenerator:
    """Generate medical reports using GPT-4 Vision API.

    Wraps an OpenAI chat-completions client. For each case, a text prompt built from the
    classifier output is sent together with two inline PNG images (the X-ray and its
    Grad-CAM overlay), and the model's text reply is returned.
    """

    def __init__(self, model: str = 'gpt-4o', temperature: float = 0.3, api_key: Optional[str] = None):
        """
        Initialize GPT-4 based medical report generator

        Args:
            model: GPT model ('gpt-4o', 'gpt-4-turbo', 'gpt-4')
            temperature: Generation temperature (0.0 - 1.0)
            api_key: OpenAI API key (or uses OPENAI_API_KEY env)

        Raises:
            ImportError: if the `openai` package could not be imported.
        """
        if not OPENAI_AVAILABLE:
            raise ImportError("Install: pip install openai")

        self.model = model
        self.temperature = temperature
        self.client = OpenAI(api_key=api_key)
        print(f"✓ Initialized GPT-4 ({model}) with temperature={temperature}")

    def image_to_base64(self, image: np.ndarray) -> str:
        """Encode an image array as a base64 string of its PNG bytes (no data-URL prefix).

        NOTE(review): the array goes straight to ``Image.fromarray``; the callers in this
        notebook appear to pass uint8 RGB arrays — other dtypes are not converted here.
        """
        pil_image = Image.fromarray(image)
        buffer = BytesIO()
        pil_image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode()

    def generate_report(
        self,
        original_image: np.ndarray,
        gradcam_image: np.ndarray,
        prediction: int,
        confidence: float,
        all_probs: np.ndarray,
        attention_analysis: Dict,
        class_names: list,
        grade_descriptions: dict,
        true_label: Optional[int] = None
    ) -> str:
        """Generate comprehensive medical report using GPT-4 Vision

        Args:
            original_image: X-ray image array, sent as the first image.
            gradcam_image: Grad-CAM overlay array, sent as the second image.
            prediction: index of the predicted class in ``class_names``.
            confidence: probability of the predicted class (formatted as a percentage).
            all_probs: per-class probabilities, in ``class_names`` order.
            attention_analysis: dict with ``mean_activation``, ``max_activation``,
                ``attention_coverage`` and ``attention_regions`` (only the first three
                regions are quoted in the prompt).
            class_names: class labels indexed by class id.
            grade_descriptions: mapping from class id to a clinical description;
                'N/A' is used when the predicted id is missing.
            true_label: optional ground-truth class id; when given, the ground-truth label
                is included in the prompt.

        Returns:
            The report text. If the API call fails or the reply is empty, a string
            containing "Error generating report:" is returned instead of raising.
            Errors while encoding the images or formatting the prompt do propagate.
        """

        # Encode images
        original_b64 = self.image_to_base64(original_image)
        gradcam_b64 = self.image_to_base64(gradcam_image)

        # Format probability distribution
        prob_text = "\n".join([
            f"  - {class_names[i]}: {prob:.1%}"
            for i, prob in enumerate(all_probs)
        ])

        # Format attention regions
        attention_text = "\n".join([
            f"  - {region['region']}: Activation {region['activation']:.3f}"
            for region in attention_analysis['attention_regions'][:3]
        ]) if attention_analysis['attention_regions'] else "  - Distributed attention"

        # Build prompt
        prompt = f"""You are an expert radiologist assistant analyzing knee X-ray images for osteoarthritis classification using an AI model.

**Model Prediction:**
- Predicted Grade: {class_names[prediction]}
- Confidence: {confidence:.1%}
- Clinical Description: {grade_descriptions.get(prediction, 'N/A')}

**Probability Distribution:**
{prob_text}

**Grad-CAM Attention Analysis:**
The model focused on these regions when making its decision:
- Overall attention strength: {attention_analysis['mean_activation']:.3f}
- Peak activation: {attention_analysis['max_activation']:.3f}
- Coverage of high-attention areas: {attention_analysis['attention_coverage']:.1%}

Primary regions of interest:
{attention_text}

{f"**Ground Truth Label:** {class_names[true_label]}" if true_label is not None else ""}

**Task:**
Generate a professional medical report with these sections:

1. **CLINICAL IMPRESSION**
   - State the AI-predicted osteoarthritis grade
   - Assess confidence level and clinical significance

2. **RADIOLOGICAL FINDINGS**
   - Describe what anatomical features the AI focused on (based on Grad-CAM regions)
   - Interpret attention patterns in clinical terms
   - Mention relevant structures: joint space, osteophytes, bone margins, sclerosis

3. **AI MODEL INTERPRETATION**
   - Explain WHY the model focused on specific regions
   - Connect attention patterns to known OA features
   - Discuss if the AI's focus aligns with clinical practice

4. **CONFIDENCE ANALYSIS**
   - If confidence is low (<70%), discuss alternative diagnoses
   - Note any ambiguous features or borderline findings

5. **RECOMMENDATIONS**
   - Clinical correlation advised (standard disclaimer)
   - Suggest follow-up if appropriate
   - Note AI limitations

**Guidelines:**
- Keep report concise (300-400 words)
- Use professional medical terminology
- Translate AI insights into clinically meaningful observations
- Be honest about model limitations
- Focus on what the attention map reveals about the decision-making process

Generate the report now:"""

        # Call GPT-4 Vision API
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                temperature=self.temperature,
                max_tokens=2000,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{original_b64}",
                                    "detail": "high"
                                }
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{gradcam_b64}",
                                    "detail": "high"
                                }
                            }
                        ]
                    }
                ]
            )

            report_text = response.choices[0].message.content
            # `message.content` is nullable in the chat-completions response. Route an
            # empty reply through the error path below so callers always receive a
            # non-empty string, as the return annotation promises.
            if report_text is None or not report_text.strip():
                raise ValueError("the model returned an empty response")
            return report_text

        except Exception as e:
            return f"❌ Error generating report: {str(e)}\n\nPlease check your API key and internet connection."

#=============================================================================
# TEST INITIALIZATION
#=============================================================================

print("\n" + "="*80)
print("Testing GPT-4 Report Generator Initialization")
print("="*80)

# Prefer the key stored on llm_config; otherwise fall back to the OPENAI_API_KEY
# environment variable, so that the `export OPENAI_API_KEY=...` hint printed below
# actually takes effect.
resolved_api_key = llm_config.openai_api_key or os.environ.get("OPENAI_API_KEY")

if resolved_api_key:
    try:
        # IMPORTANT: Pass API key explicitly to the generator
        report_generator = MedicalReportGenerator(
            model=llm_config.gpt_model,
            temperature=llm_config.temperature,
            api_key=resolved_api_key  # This is the critical line!
        )
        print("✅ GPT-4 Report Generator ready!")
        print(f"   Model: {llm_config.gpt_model}")
        print(f"   Temperature: {llm_config.temperature}")
    except Exception as e:
        # Later cells check `report_generator is None`, so always leave the name defined.
        print(f"❌ Failed to initialize: {e}")
        print("   Check your API key is valid")
        report_generator = None
else:
    print("⚠️ OPENAI_API_KEY not set. Set via:")
    print("   export OPENAI_API_KEY='your-key'")
    print("   OR set it directly in the notebook:")
    print("   llm_config.openai_api_key = 'your-key-here'")
    report_generator = None

print("="*80)

✓ OpenAI available
OpenAI API Key set: True
GPT Model: gpt-4o

Testing GPT-4 Report Generator Initialization
✓ Initialized GPT-4 (gpt-4o) with temperature=0.3
✅ GPT-4 Report Generator ready!
   Model: gpt-4o
   Temperature: 0.3


In [3]:
# Process Multiple Samples - Generate GPT-4 Medical Reports
# Copy this cell to your notebook (after initializing GPT-4 generator)

import cv2
import pandas as pd
import torch
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam.utils.image import show_cam_on_image
from xai import (
    XAIConfig,
    ImageDataset,
    create_model,
    get_target_layer,
    load_model_weights,
    generate_gradcam,
    generate_lime_explanation,
    create_comprehensive_xai_visualization,
    select_samples_for_visualization,
    
)
# Shared XAI configuration instance; the code below reads config.test_csv,
# config.class_names, config.grade_descriptions and config.device from it.
config = XAIConfig()
def create_report_visualization(
    original_image,
    gradcam_image,
    prediction,
    confidence,
    all_probs,
    report,
    save_path,
    class_names,
    true_label=None,
    wrap_width=150
):
    """Create comprehensive visualization with LLM report and save it to ``save_path``.

    Layout: the top row shows the X-ray, the Grad-CAM overlay and a horizontal bar chart
    of the class probabilities; the lower two rows hold the report text and a disclaimer.

    Args:
        original_image: image shown in the first panel.
        gradcam_image: image shown in the second panel.
        prediction: index of the predicted class in ``class_names``.
        confidence: probability of the predicted class (shown in the bar chart title).
        all_probs: per-class probabilities, in ``class_names`` order.
        report: report text; ``**`` markdown markers are removed before drawing.
        save_path: output image path; its directory is created if it does not exist.
        class_names: class labels indexed by class id.
        true_label: optional ground-truth class id; adds the label to the first panel
            title and a correct/incorrect mark to the figure title.
        wrap_width: maximum number of characters per drawn report line.
    """
    import textwrap  # stdlib; imported locally so this function adds no cell-level import

    fig = plt.figure(figsize=(18, 12))
    try:
        gs = fig.add_gridspec(3, 3, height_ratios=[1, 1, 1.2], hspace=0.35, wspace=0.3)

        # Row 1: Original Image
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.imshow(original_image)
        title = 'Original Knee X-ray'
        if true_label is not None:
            title += f'\nGround Truth: {class_names[true_label]}'
        ax1.set_title(title, fontsize=12, fontweight='bold')
        ax1.axis('off')

        # Row 1: Grad-CAM Visualization
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.imshow(gradcam_image)
        ax2.set_title(f'AI Attention Map (Grad-CAM)\nPrediction: {class_names[prediction]}',
                      fontsize=12, fontweight='bold')
        ax2.axis('off')

        # Row 1: Probability Distribution (predicted class highlighted in green)
        ax3 = fig.add_subplot(gs[0, 2])
        colors = ['#27ae60' if i == prediction else '#3498db' for i in range(len(class_names))]
        ax3.barh(class_names, all_probs, color=colors)
        ax3.set_xlabel('Confidence', fontsize=10)
        ax3.set_title(f'Class Probabilities\nMax: {confidence:.1%}', fontsize=12, fontweight='bold')
        ax3.set_xlim([0, 1])
        ax3.grid(axis='x', alpha=0.3)

        # Value label just to the right of each bar
        for i, prob in enumerate(all_probs):
            ax3.text(prob + 0.02, i, f'{prob:.1%}', va='center', fontsize=9)

        # Row 2 & 3: Medical Report (spans full width)
        ax4 = fig.add_subplot(gs[1:, :])
        ax4.axis('off')

        # Report header
        ax4.text(0.5, 0.98, '🏥 GPT-4 GENERATED MEDICAL REPORT',
                 fontsize=16, fontweight='bold', ha='center', va='top',
                 transform=ax4.transAxes,
                 bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.3))

        # Report content.
        # Matplotlib's automatic text wrapping (wrap=True) does not take effect when the
        # figure is saved with bbox_inches='tight', so long lines are wrapped explicitly.
        # Each input line is wrapped on its own so the report's line breaks are kept.
        report_lines = []
        for line in report.replace('**', '').splitlines():
            report_lines.extend(textwrap.wrap(line, width=wrap_width) or [''])
        report_formatted = "\n".join(report_lines)
        ax4.text(0.02, 0.90, report_formatted,
                 fontsize=9.5, va='top', transform=ax4.transAxes,
                 family='serif',
                 bbox=dict(boxstyle='round,pad=1', facecolor='#f9f9f9',
                           edgecolor='gray', linewidth=1.5, alpha=0.8))

        # Footer disclaimer
        disclaimer = ("⚠️ DISCLAIMER: This is an AI-generated report for research purposes only. "
                      "All findings must be verified by a qualified radiologist before clinical use.")
        ax4.text(0.5, 0.01, disclaimer,
                 fontsize=8, ha='center', va='bottom', transform=ax4.transAxes,
                 style='italic', color='red', weight='bold')

        # Overall title: mark correct/incorrect only when a ground truth is available
        if true_label is None:
            correct_symbol = ''
        elif prediction == true_label:
            correct_symbol = '✅'
        else:
            correct_symbol = '❌'
        fig.suptitle(f'Knee Osteoarthritis AI Classification Report {correct_symbol}',
                     fontsize=18, fontweight='bold', y=0.98)

        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Close this specific figure even when drawing or saving fails, so figures do not
        # accumulate when the function is called in a loop.
        plt.close(fig)

    print(f"   ✓ Saved visualization: {os.path.basename(save_path)}")


def process_multiple_samples(num_samples=5):
    """
    Process multiple samples and generate GPT-4 medical reports

    For each randomly drawn row of the test CSV the function loads the image, runs
    Grad-CAM, asks ``report_generator`` for a report, and writes one PNG visualization
    and one TXT report into ``llm_config.output_dir``. A failure on one sample is printed
    with its traceback and counted; the loop then continues with the next sample.

    Args:
        num_samples: Number of samples to process (capped at the number of CSV rows)

    Returns:
        None. Progress and a final summary are printed.
    """
    import traceback

    print("="*80)
    print(" "*20 + "🤖 GENERATING GPT-4 MEDICAL REPORTS")
    print("="*80)

    # Check if report_generator is initialized
    if 'report_generator' not in globals() or report_generator is None:
        print("\n❌ Error: report_generator not initialized!")
        print("Please run the GPT-4 initialization cell first.")
        return

    # Load test data
    print(f"\n📊 Loading test data from: {config.test_csv}")
    test_df = pd.read_csv(config.test_csv)
    print(f"✓ Total samples available: {len(test_df)}")

    # Select samples: a plain random draw with a fixed seed, so reruns pick the same rows
    samples = test_df.sample(n=min(num_samples, len(test_df)), random_state=42)
    print(f"\n🎯 Processing {len(samples)} samples...\n")

    # Load model
    print(f"🔧 Loading model: {llm_config.model_to_use['name']}...")
    model, model_name = load_model_weights(llm_config.model_to_use, fold=llm_config.fold_to_analyze)
    target_layer = get_target_layer(model, model_name)
    print("✓ Model loaded\n")

    # Image preprocessing
    img_size = llm_config.model_to_use['size']
    transform = A.Compose([
        A.Resize(img_size, img_size, interpolation=cv2.INTER_CUBIC),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Make sure the output directory exists before anything is written into it
    os.makedirs(llm_config.output_dir, exist_ok=True)

    # Process each sample
    successful = 0
    failed = 0

    for idx, (_, row) in enumerate(samples.iterrows(), 1):
        img_path = row['data']
        true_label = row['label']

        print(f"{'='*80}")
        print(f"[{idx}/{len(samples)}] {os.path.basename(img_path)}")
        print(f"Ground Truth: {config.class_names[true_label]}")
        print(f"{'='*80}")

        try:
            # Load image
            print("📁 Loading image...")
            image = cv2.imread(img_path)
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Prepare image: a resized copy for display/overlay, a normalized tensor for the model
            image_resized = cv2.resize(image, (img_size, img_size))
            image_np = image_resized.astype(np.float32) / 255.0
            augmented = transform(image=image)
            img_tensor = augmented['image']

            # Generate Grad-CAM
            print("🔍 Generating Grad-CAM...")
            cam, prediction, confidence = generate_gradcam(
                model, img_tensor, target_layer, method='gradcam'
            )

            # Get all probabilities
            with torch.no_grad():
                input_tensor = img_tensor.unsqueeze(0).to(config.device)
                output = model(input_tensor)
                all_probs = torch.nn.functional.softmax(output, dim=1)[0].cpu().numpy()

            # Create Grad-CAM visualization
            gradcam_img = show_cam_on_image(image_np, cam, use_rgb=True)

            print(f"   Prediction: {config.class_names[prediction]} ({confidence:.1%})")

            # Analyze attention patterns
            print("📊 Analyzing attention patterns...")
            attention_analysis = analyze_gradcam_for_prompt(cam)

            # Generate medical report with GPT-4
            print("✍️  Generating medical report with GPT-4...")
            report = report_generator.generate_report(
                image_resized,
                gradcam_img,
                prediction,
                confidence,
                all_probs,
                attention_analysis,
                config.class_names,
                config.grade_descriptions,
                true_label
            )

            # generate_report() signals an API failure by returning an error message
            # rather than raising. Treat that (or an empty reply) as a failed sample so
            # the message is never saved as a medical report or counted as a success.
            if (
                not isinstance(report, str)
                or not report.strip()
                or "Error generating report:" in report.lstrip()[:60]
            ):
                raise RuntimeError(f"Report generation failed: {report}")

            # Build output names from the image's base name without its extension, so the
            # visualization is always a .png and the text report always a separate .txt
            # file, whatever the extension of the source image.
            case_stem = os.path.splitext(os.path.basename(img_path))[0]
            save_path = os.path.join(
                llm_config.output_dir,
                f"report_{idx:02d}_{case_stem}.png"
            )
            report_txt = os.path.splitext(save_path)[0] + '.txt'

            # Save visualization
            create_report_visualization(
                image_resized,
                gradcam_img,
                prediction,
                confidence,
                all_probs,
                report,
                save_path,
                config.class_names,
                true_label
            )

            # Save text report
            with open(report_txt, 'w', encoding='utf-8') as f:
                f.write("KNEE OSTEOARTHRITIS AI CLASSIFICATION REPORT\n")
                f.write("="*80 + "\n\n")
                f.write(f"Patient Case ID: {os.path.basename(img_path)}\n")
                f.write(f"Ground Truth: {config.class_names[true_label]}\n")
                f.write(f"AI Prediction: {config.class_names[prediction]} (Confidence: {confidence:.1%})\n")
                f.write(f"\nModel: {model_name}\n")
                f.write(f"Fold: {llm_config.fold_to_analyze}\n")
                f.write(f"\n{'='*80}\n\n")
                f.write(report)
                f.write(f"\n\n{'='*80}\n")
                f.write("TECHNICAL DETAILS\n")
                f.write("="*80 + "\n")
                f.write(f"Attention Coverage: {attention_analysis['attention_coverage']:.1%}\n")
                f.write(f"Mean Activation: {attention_analysis['mean_activation']:.3f}\n")
                f.write(f"Max Activation: {attention_analysis['max_activation']:.3f}\n")
                if attention_analysis['attention_regions']:
                    f.write("\nTop Attention Regions:\n")
                    for region in attention_analysis['attention_regions'][:3]:
                        f.write(f"  - {region['region']}: {region['activation']:.3f}\n")

            print(f"   ✓ Saved text report: {os.path.basename(report_txt)}")
            print(f"✅ Sample {idx} completed successfully\n")
            successful += 1

        except Exception as e:
            # Best effort per sample: report the failure and continue with the next one
            print(f"❌ Error processing sample {idx}: {e}")
            traceback.print_exc()
            failed += 1
            print()

    # Summary
    print("="*80)
    print("📊 PROCESSING SUMMARY")
    print("="*80)
    print(f"✅ Successful: {successful}/{len(samples)}")
    if failed > 0:
        print(f"❌ Failed: {failed}/{len(samples)}")
    print(f"\n📁 Reports saved to: {llm_config.output_dir}")
    print(f"   - {successful} PNG visualizations")
    print(f"   - {successful} TXT reports")
    print("="*80)


# ============================================================================
# RUN THE PROCESSING
# ============================================================================

if __name__ == "__main__" or True:  # Works in both script and notebook
    # The sample count comes from the shared configuration (llm_config.num_samples,
    # 5 by default) instead of a second hardcoded number; change it there.
    process_multiple_samples(num_samples=llm_config.num_samples)

  check_for_updates()


Device: cuda
Output directory: ./xai_visualizations
CUDA available: True
                    🤖 GENERATING GPT-4 MEDICAL REPORTS

📊 Loading test data from: ../KneeXray/test/test_correct.csv
✓ Total samples available: 1656

🎯 Processing 5 samples...

🔧 Loading model: densenet_161...
   Loading: 1fold_epoch8.pt
   From: ./models/densenet_161/(224, 224)/1fold_epoch8.pt
   ✓ Successfully loaded: 1fold_epoch8.pt
✓ Model loaded

[1/5] 9946846R.png
Ground Truth: Grade 4
📁 Loading image...
🔍 Generating Grad-CAM...
   Prediction: Grade 4 (52.3%)
📊 Analyzing attention patterns...
✍️  Generating medical report with GPT-4...
   ✓ Saved visualization: report_01_9946846R.png
   ✓ Saved text report: report_01_9946846R.txt
✅ Sample 1 completed successfully

[2/5] 9283061L.png
Ground Truth: Grade 0
📁 Loading image...
🔍 Generating Grad-CAM...
   Prediction: Grade 0 (72.5%)
📊 Analyzing attention patterns...
✍️  Generating medical report with GPT-4...
   ‚ú