In [3]:
import pandas as pd
import os
import json

# Define paths
clahe_folder = '/home/intern08/Documents/X_ray_CoT/dataset/clahe_images'
reports_csv = '/home/intern08/Documents/X_ray_CoT/dataset/indiana_reports.csv'
projections_csv = '/home/intern08/Documents/X_ray_CoT/dataset/indiana_projections.csv'
output_manifest = '/home/intern08/Documents/X_ray_CoT/dataset/image_report_mapping.jsonl'

# Load CSVs
reports_df = pd.read_csv(reports_csv)
projections_df = pd.read_csv(projections_csv)

# Fix column name case and handle NaN
reports_df['findings'] = reports_df['findings'].fillna('').astype(str).str.strip()
reports_df['impression'] = reports_df['impression'].fillna('').astype(str).str.strip()

# Merge CSVs on 'uid'
merged_df = projections_df.merge(reports_df, on='uid', how='inner')

# Initialize manifest
manifest = []

# Link CLAHE images to reports
for idx, row in merged_df.iterrows():
    image_filename = row['filename']
    image_path = os.path.join(clahe_folder, image_filename)
    
    if os.path.exists(image_path):
        # Extract Findings and Impression (lowercase column names)
        findings = row['findings']
        impression = row['impression']
        report_text = f"Findings: {findings}\nImpression: {impression}".strip()
        
        # Only include non-empty reports
        if findings or impression:  # At least one field is non-empty
            manifest.append({
                'image_path': image_path,
                'report': report_text,
                'radgraph': {}  # Placeholder for Step 4.2.5
            })
        else:
            print(f"Skipping empty report for {image_filename}: Findings and Impression are empty.")
    else:
        print(f"Warning: Image not found: {image_path}")

# Save manifest as JSONL
with open(output_manifest, 'w') as f:
    for entry in manifest:
        f.write(json.dumps(entry) + '\n')

print(f"Manifest saved to {output_manifest}. Total entries: {len(manifest)}")

# Verify a sample entry
if manifest:
    print("\nSample manifest entry:")
    print(manifest[0])
else:
    print("No entries created. Check image paths or CSV data.")

# Count non-empty reports
non_empty_reports = sum(1 for entry in manifest if entry['report'].strip() != 'Findings: \nImpression:')
print(f"Non-empty reports: {non_empty_reports}/{len(manifest)}")

Skipping empty report for 16_IM-0389-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 16_IM-0389-2001.dcm.png: Findings and Impression are empty.
Skipping empty report for 614_IM-2200-4001.dcm.png: Findings and Impression are empty.
Skipping empty report for 614_IM-2200-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 673_IM-2247-3001.dcm.png: Findings and Impression are empty.
Skipping empty report for 894_IM-2404-0001-0001.dcm.png: Findings and Impression are empty.
Skipping empty report for 894_IM-2404-0001-0002.dcm.png: Findings and Impression are empty.
Skipping empty report for 1137_IM-0093-12012.dcm.png: Findings and Impression are empty.
Skipping empty report for 1137_IM-0093-4004.dcm.png: Findings and Impression are empty.
Skipping empty report for 1142_IM-0096-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 1142_IM-0096-2001.dcm.png: Findings and Impression are empty.
Skipping empty report for 1147

In [4]:
import os
clahe_files = set(os.listdir('/home/intern08/Documents/X_ray_CoT/dataset/clahe_images'))
projection_files = set(projections_df['filename'])
print(f"CLAHE images missing in projections: {projection_files - clahe_files}")
print(f"Projections files missing in CLAHE: {clahe_files - projection_files}")

CLAHE images missing in projections: set()
Projections files missing in CLAHE: {'2560_IM-1064-4001.dcm.png', '2084_IM-0715-1001-0002.dcm.png', '3809_IM-1919-1003002.dcm.png', '2084_IM-0715-2001-0001.dcm.png'}


In [5]:
print(merged_df[['uid', 'filename', 'findings', 'impression']].head(10))
print("Findings length stats:", merged_df['findings'].str.len().describe())
print("Impression length stats:", merged_df['impression'].str.len().describe())

   uid                   filename  \
0    1     1_IM-0001-4001.dcm.png   
1    1     1_IM-0001-3001.dcm.png   
2    2     2_IM-0652-1001.dcm.png   
3    2     2_IM-0652-2001.dcm.png   
4    3     3_IM-1384-1001.dcm.png   
5    3     3_IM-1384-2001.dcm.png   
6    4     4_IM-2050-1001.dcm.png   
7    4     4_IM-2050-2001.dcm.png   
8    5  5_IM-2117-1003002.dcm.png   
9    5  5_IM-2117-1004003.dcm.png   

                                            findings  \
0  The cardiac silhouette and mediastinum size ar...   
1  The cardiac silhouette and mediastinum size ar...   
2  Borderline cardiomegaly. Midline sternotomy XX...   
3  Borderline cardiomegaly. Midline sternotomy XX...   
4                                                      
5                                                      
6  There are diffuse bilateral interstitial and a...   
7  There are diffuse bilateral interstitial and a...   
8  The cardiomediastinal silhouette and pulmonary...   
9  The cardiomediastinal silhoue

In [12]:
empty_both = merged_df[(merged_df['findings'] == '') & (merged_df['impression'] == '')]
print(f"Reports with both fields empty: {len(empty_both)}")

Reports with both fields empty: 40


In [13]:
original_folder = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized'
missing_files = ['2560_IM-1064-4001.dcm.png', '2084_IM-0715-1001-0002.dcm.png', '3809_IM-1919-1003002.dcm.png', '2084_IM-0715-2001-0001.dcm.png']
for f in missing_files:
    print(f"{f} in original folder: {os.path.exists(os.path.join(original_folder, f))}")

2560_IM-1064-4001.dcm.png in original folder: False
2084_IM-0715-1001-0002.dcm.png in original folder: False
3809_IM-1919-1003002.dcm.png in original folder: False
2084_IM-0715-2001-0001.dcm.png in original folder: False


In [14]:
import cv2
import os
clahe_folder = '/home/intern08/Documents/X_ray_CoT/dataset/clahe_images'
os.makedirs(clahe_folder, exist_ok=True)
for f in missing_files:
    input_path = os.path.join(original_folder, f)
    output_path = os.path.join(clahe_folder, f)
    if os.path.exists(input_path):
        img = cv2.imread(input_path, cv2.IMREAD_GRAYSCALE)
        if img is not None:
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            clahe_img = clahe.apply(img)
            cv2.imwrite(output_path, clahe_img)
            print(f"Processed {f}")
        else:
            print(f"Failed to load {f}")
    else:
        print(f"{f} not found in original folder")

2560_IM-1064-4001.dcm.png not found in original folder
2084_IM-0715-1001-0002.dcm.png not found in original folder
3809_IM-1919-1003002.dcm.png not found in original folder
2084_IM-0715-2001-0001.dcm.png not found in original folder


In [16]:
import pandas as pd
import os
import json

# Define paths
clahe_folder = '/home/intern08/Documents/X_ray_CoT/dataset/clahe_images'
reports_csv = '/home/intern08/Documents/X_ray_CoT/dataset/indiana_reports.csv'
projections_csv = '/home/intern08/Documents/X_ray_CoT/dataset/indiana_projections.csv'
output_manifest = '/home/intern08/Documents/X_ray_CoT/dataset/image_report_mapping.jsonl'

# Load CSVs
reports_df = pd.read_csv(reports_csv)
projections_df = pd.read_csv(projections_csv)

# Fix column name case and handle NaN
reports_df['findings'] = reports_df['findings'].fillna('').astype(str).str.strip()
reports_df['impression'] = reports_df['impression'].fillna('').astype(str).str.strip()

# Merge CSVs on 'uid'
merged_df = projections_df.merge(reports_df, on='uid', how='inner')

# Initialize manifest
manifest = []

# Link CLAHE images to reports
for idx, row in merged_df.iterrows():
    image_filename = row['filename']
    image_path = os.path.join(clahe_folder, image_filename)
    
    if os.path.exists(image_path):
        # Extract Findings and Impression
        findings = row['findings']
        impression = row['impression']
        report_text = f"Findings: {findings}\nImpression: {impression}".strip()
        
        # Only include non-empty reports
        if findings or impression:
            manifest.append({
                'image_path': image_path,
                'report': report_text,
                'radgraph': {}  # Placeholder for Step 4.2.5
            })
        else:
            print(f"Skipping empty report for {image_filename}: Findings and Impression are empty.")
    else:
        print(f"Warning: Image not found: {image_path}")

# Save manifest as JSONL
with open(output_manifest, 'w') as f:
    for entry in manifest:
        f.write(json.dumps(entry) + '\n')

print(f"Manifest saved to {output_manifest}. Total entries: {len(manifest)}")

# Verify a sample entry
if manifest:
    print("\nSample manifest entry:")
    print(manifest[0])
else:
    print("No entries created. Check image paths or CSV data.")

# Count non-empty reports
non_empty_reports = sum(1 for entry in manifest if entry['report'].strip() != 'Findings: \nImpression:')
print(f"Non-empty reports: {non_empty_reports}/{len(manifest)}")

# Verify missing images
missing_files = ['3809_IM-1919-1003002.dcm.png', '2084_IM-0715-1001-0002.dcm.png', '2560_IM-1064-4001.dcm.png', '2084_IM-0715-2001-0001.dcm.png']
for f in missing_files:
    print(f"{f} in CLAHE folder: {os.path.exists(os.path.join(clahe_folder, f))}")

Skipping empty report for 16_IM-0389-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 16_IM-0389-2001.dcm.png: Findings and Impression are empty.
Skipping empty report for 614_IM-2200-4001.dcm.png: Findings and Impression are empty.
Skipping empty report for 614_IM-2200-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 673_IM-2247-3001.dcm.png: Findings and Impression are empty.
Skipping empty report for 894_IM-2404-0001-0001.dcm.png: Findings and Impression are empty.
Skipping empty report for 894_IM-2404-0001-0002.dcm.png: Findings and Impression are empty.
Skipping empty report for 1137_IM-0093-12012.dcm.png: Findings and Impression are empty.
Skipping empty report for 1137_IM-0093-4004.dcm.png: Findings and Impression are empty.
Skipping empty report for 1142_IM-0096-1001.dcm.png: Findings and Impression are empty.
Skipping empty report for 1142_IM-0096-2001.dcm.png: Findings and Impression are empty.
Skipping empty report for 1147

# 4.3  Data Split Strategy
## Use patient‑level stratified split (70/10/20). Verify no leakage via hashed MRN IDs.

In [17]:
import pandas as pd
import json
import hashlib
import numpy as np
from sklearn.model_selection import train_test_split
from collections import defaultdict
import os

# Define paths
manifest_file = '/home/intern08/Documents/X_ray_CoT/dataset/image_report_mapping.jsonl'
output_dir = '/home/intern08/Documents/X_ray_CoT/dataset/data_split'
os.makedirs(output_dir, exist_ok=True)

# Load the manifest
print("Loading manifest...")
manifest = []
with open(manifest_file, 'r') as f:
    for line in f:
        manifest.append(json.loads(line))

print(f"Total entries in manifest: {len(manifest)}")

# Extract patient IDs from filenames 
# Indiana University dataset format: {patient_id}_{image_id}.dcm.png
print("Extracting patient IDs...")
patient_data = defaultdict(list)

for entry in manifest:
    # Extract filename from path
    filename = os.path.basename(entry['image_path'])
    
    # Extract patient ID (part before first underscore)
    patient_id = filename.split('_')[0]
    
    # Add entry with patient ID
    entry['patient_id'] = patient_id
    patient_data[patient_id].append(entry)

print(f"Number of unique patients: {len(patient_data)}")
print(f"Images per patient - Min: {min(len(imgs) for imgs in patient_data.values())}, "
      f"Max: {max(len(imgs) for imgs in patient_data.values())}, "
      f"Mean: {np.mean([len(imgs) for imgs in patient_data.values()]):.2f}")

# Create patient-level dataset for stratification
patients_list = []
for patient_id, entries in patient_data.items():
    # Create a hash of patient ID for reproducible splits
    patient_hash = hashlib.md5(patient_id.encode()).hexdigest()
    
    patients_list.append({
        'patient_id': patient_id,
        'patient_hash': patient_hash,
        'num_images': len(entries),
        'entries': entries
    })

# Convert to DataFrame for easier manipulation
patients_df = pd.DataFrame([{
    'patient_id': p['patient_id'],
    'patient_hash': p['patient_hash'],
    'num_images': p['num_images']
} for p in patients_list])

print("\nPatient distribution by number of images:")
print(patients_df['num_images'].value_counts().sort_index())

# Stratified split based on number of images per patient
# This ensures similar distribution of single vs multi-image patients across splits
print("\nPerforming stratified patient-level split (70/10/20)...")

# Create stratification groups based on number of images
def create_strata(num_images):
    if num_images == 1:
        return 'single'
    elif num_images <= 3:
        return 'few'
    else:
        return 'many'

patients_df['strata'] = patients_df['num_images'].apply(create_strata)
print("Stratification groups:")
print(patients_df['strata'].value_counts())

# First split: 70% train, 30% temp (which will be split into 10% val, 20% test)
train_patients, temp_patients = train_test_split(
    patients_df, 
    test_size=0.3, 
    stratify=patients_df['strata'],
    random_state=42
)

# Second split: Split the 30% into 10% val and 20% test
# This means 10/30 = 0.333 for validation from the temp set
val_patients, test_patients = train_test_split(
    temp_patients,
    test_size=0.667,  # 20/30 = 0.667 for test
    stratify=temp_patients['strata'],
    random_state=42
)

print(f"\nSplit sizes:")
print(f"Train patients: {len(train_patients)} ({len(train_patients)/len(patients_df)*100:.1f}%)")
print(f"Validation patients: {len(val_patients)} ({len(val_patients)/len(patients_df)*100:.1f}%)")
print(f"Test patients: {len(test_patients)} ({len(test_patients)/len(patients_df)*100:.1f}%)")

# Verify no patient leakage
train_ids = set(train_patients['patient_id'])
val_ids = set(val_patients['patient_id'])
test_ids = set(test_patients['patient_id'])

assert len(train_ids & val_ids) == 0, "Patient leakage between train and validation!"
assert len(train_ids & test_ids) == 0, "Patient leakage between train and test!"
assert len(val_ids & test_ids) == 0, "Patient leakage between validation and test!"
print("✓ No patient leakage detected")

# Create patient ID to split mapping
patient_to_split = {}
for pid in train_ids:
    patient_to_split[pid] = 'train'
for pid in val_ids:
    patient_to_split[pid] = 'val'
for pid in test_ids:
    patient_to_split[pid] = 'test'

# Assign splits to all entries
train_entries = []
val_entries = []
test_entries = []

for entry in manifest:
    split = patient_to_split[entry['patient_id']]
    if split == 'train':
        train_entries.append(entry)
    elif split == 'val':
        val_entries.append(entry)
    else:  # test
        test_entries.append(entry)

print(f"\nImage-level split sizes:")
print(f"Train images: {len(train_entries)}")
print(f"Validation images: {len(val_entries)}")
print(f"Test images: {len(test_entries)}")

# Save splits as JSONL files
splits = {
    'train': train_entries,
    'val': val_entries,
    'test': test_entries
}

for split_name, entries in splits.items():
    output_file = os.path.join(output_dir, f'{split_name}.jsonl')
    with open(output_file, 'w') as f:
        for entry in entries:
            f.write(json.dumps(entry) + '\n')
    print(f"Saved {len(entries)} entries to {output_file}")

# Create a summary file with split statistics
summary = {
    'total_patients': len(patients_df),
    'total_images': len(manifest),
    'splits': {
        'train': {
            'patients': len(train_patients),
            'images': len(train_entries),
            'percentage': len(train_patients) / len(patients_df) * 100
        },
        'val': {
            'patients': len(val_patients),
            'images': len(val_entries),
            'percentage': len(val_patients) / len(patients_df) * 100
        },
        'test': {
            'patients': len(test_patients),
            'images': len(test_entries),
            'percentage': len(test_patients) / len(patients_df) * 100
        }
    },
    'stratification': {
        'single_image_patients': int(patients_df[patients_df['strata'] == 'single'].shape[0]),
        'few_image_patients': int(patients_df[patients_df['strata'] == 'few'].shape[0]),
        'many_image_patients': int(patients_df[patients_df['strata'] == 'many'].shape[0])
    }
}

# Save summary
summary_file = os.path.join(output_dir, 'split_summary.json')
with open(summary_file, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\n✓ Split summary saved to {summary_file}")

# Display final statistics
print("\n=== FINAL SPLIT STATISTICS ===")
for split_name in ['train', 'val', 'test']:
    split_patients = len([p for p in patients_df['patient_id'] if patient_to_split[p] == split_name])
    split_images = len(splits[split_name])
    print(f"{split_name.upper()}:")
    print(f"  Patients: {split_patients} ({split_patients/len(patients_df)*100:.1f}%)")
    print(f"  Images: {split_images} ({split_images/len(manifest)*100:.1f}%)")
    
    # Check stratification preservation
    split_patient_df = patients_df[patients_df['patient_id'].isin(
        [p for p in patients_df['patient_id'] if patient_to_split[p] == split_name]
    )]
    print(f"  Stratification: Single={len(split_patient_df[split_patient_df['strata']=='single'])}, "
          f"Few={len(split_patient_df[split_patient_df['strata']=='few'])}, "
          f"Many={len(split_patient_df[split_patient_df['strata']=='many'])}")

print("\n✓ Data splitting completed successfully!")
print(f"Files created in {output_dir}:")
print("  - train.jsonl")
print("  - val.jsonl") 
print("  - test.jsonl")
print("  - split_summary.json")

Loading manifest...
Total entries in manifest: 7426
Extracting patient IDs...
Number of unique patients: 3826
Images per patient - Min: 1, Max: 5, Mean: 1.94

Patient distribution by number of images:
num_images
1     435
2    3197
3     180
4      13
5       1
Name: count, dtype: int64

Performing stratified patient-level split (70/10/20)...
Stratification groups:
strata
few       3377
single     435
many        14
Name: count, dtype: int64

Split sizes:
Train patients: 2678 (70.0%)
Validation patients: 382 (10.0%)
Test patients: 766 (20.0%)
✓ No patient leakage detected

Image-level split sizes:
Train images: 5195
Validation images: 739
Test images: 1492
Saved 5195 entries to /home/intern08/Documents/X_ray_CoT/dataset/data_split/train.jsonl
Saved 739 entries to /home/intern08/Documents/X_ray_CoT/dataset/data_split/val.jsonl
Saved 1492 entries to /home/intern08/Documents/X_ray_CoT/dataset/data_split/test.jsonl

✓ Split summary saved to /home/intern08/Documents/X_ray_CoT/dataset/data_s

In [2]:
"""
Simplified Baseline Model for Radiology Report Generation
Compatible with Kaggle environment and existing libraries

Architecture:
- Vision Encoder: ResNet-50 (pre-trained)
- Text Decoder: GPT-2 with custom heads
- Simple fusion mechanism
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import json
import os
from PIL import Image
import numpy as np
from typing import Dict, List, Optional, Tuple
import logging
import gc
import time
from collections import defaultdict

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelConfig:
    """Configuration for the simplified baseline model"""
    # Model architecture
    vision_encoder = "resnet50"
    vision_hidden_size = 2048
    text_hidden_size = 768
    vocab_size = 50000
    max_length = 256
    
    # Training configuration
    batch_size = 8  # Reduced for memory constraints
    learning_rate = 1e-4
    weight_decay = 1e-4
    num_epochs = 5  # Reduced for faster training
    warmup_steps = 100
    
    # Image processing
    image_size = 224
    
    # Paths
    train_file = "/home/intern08/Documents/X_ray_CoT/dataset/data_split/train.jsonl"
    val_file = "/home/intern08/Documents/X_ray_CoT/dataset/data_split/val.jsonl"
    test_file = "/home/intern08/Documents/X_ray_CoT/dataset/data_split/test.jsonl"
    output_dir = "/home/intern08/Documents/X_ray_CoT/dataset"

class SimpleTokenizer:
    """Simple tokenizer for radiology reports"""
    
    def __init__(self, vocab_size=50000):
        self.vocab_size = vocab_size
        self.word_to_id = {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}
        self.id_to_word = {0: '<pad>', 1: '<unk>', 2: '<sos>', 3: '<eos>'}
        self.next_id = 4
        
    def build_vocab(self, texts):
        """Build vocabulary from texts"""
        word_counts = defaultdict(int)
        
        for text in texts:
            words = text.lower().split()
            for word in words:
                word_counts[word] += 1
        
        # Add most frequent words to vocabulary
        sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
        
        for word, count in sorted_words[:self.vocab_size - 4]:
            if word not in self.word_to_id:
                self.word_to_id[word] = self.next_id
                self.id_to_word[self.next_id] = word
                self.next_id += 1
        
        logger.info(f"Built vocabulary with {len(self.word_to_id)} words")
    
    def encode(self, text, max_length=256):
        """Encode text to token IDs"""
        words = text.lower().split()
        tokens = [self.word_to_id.get(word, 1) for word in words]  # 1 is <unk>
        
        # Add start and end tokens
        tokens = [2] + tokens + [3]  # 2 is <sos>, 3 is <eos>
        
        # Pad or truncate
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
        else:
            tokens.extend([0] * (max_length - len(tokens)))  # 0 is <pad>
        
        return tokens
    
    def decode(self, token_ids):
        """Decode token IDs to text"""
        words = []
        for token_id in token_ids:
            if token_id == 3:  # <eos>
                break
            if token_id not in [0, 2]:  # Skip <pad> and <sos>
                words.append(self.id_to_word.get(token_id, '<unk>'))
        
        return ' '.join(words)

class RadiologyDataset(Dataset):
    """Simplified dataset class for radiology images and reports"""
    
    def __init__(self, jsonl_file: str, tokenizer: SimpleTokenizer, transform=None, max_length=256):
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length
        
        # Load data
        self.data = []
        with open(jsonl_file, 'r') as f:
            for line in f:
                entry = json.loads(line)
                if os.path.exists(entry['image_path']):  # Only include existing images
                    self.data.append(entry)
        
        logger.info(f"Loaded {len(self.data)} valid samples from {jsonl_file}")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        entry = self.data[idx]
        
        # Load image
        try:
            image = Image.open(entry['image_path']).convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            logger.error(f"Error loading image {entry['image_path']}: {e}")
            # Create dummy image
            image = torch.zeros(3, 224, 224)
        
        # Tokenize report
        report = entry['report']
        tokens = self.tokenizer.encode(report, self.max_length)
        
        return {
            'image': image,
            'tokens': torch.tensor(tokens, dtype=torch.long),
            'report': report,
            'patient_id': entry.get('patient_id', 'unknown')
        }

class VisionEncoder(nn.Module):
    """Vision encoder using ResNet-50"""
    
    def __init__(self, hidden_size=768):
        super().__init__()
        # Load pre-trained ResNet-50
        resnet = models.resnet50(pretrained=True)
        
        # Remove final classification layer
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        # Add projection layer
        self.projection = nn.Linear(2048, hidden_size)
        
        # Freeze initial layers
        for param in list(self.features.parameters())[:50]:
            param.requires_grad = False
    
    def forward(self, x):
        # Extract features
        features = self.features(x)  # [batch_size, 2048, 1, 1]
        features = features.view(features.size(0), -1)  # [batch_size, 2048]
        
        # Project to hidden size
        features = self.projection(features)  # [batch_size, hidden_size]
        
        return features

class TextDecoder(nn.Module):
    """Simple text decoder with attention"""
    
    def __init__(self, vocab_size, hidden_size=768, num_layers=6):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        
        # Transformer decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=8,
            dim_feedforward=hidden_size * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        
        # Output projection
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        
        # Position encoding
        self.pos_encoding = self._create_position_encoding(512, hidden_size)
    
    def _create_position_encoding(self, max_len, hidden_size):
        """Create positional encoding"""
        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * 
                           (-np.log(10000.0) / hidden_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        return pe.unsqueeze(0)  # [1, max_len, hidden_size]
    
    def forward(self, tokens, vision_features, attention_mask=None):
        batch_size, seq_len = tokens.shape
        
        # Embed tokens
        token_embeds = self.embedding(tokens)  # [batch_size, seq_len, hidden_size]
        
        # Add positional encoding
        if seq_len <= self.pos_encoding.size(1):
            pos_enc = self.pos_encoding[:, :seq_len, :].to(token_embeds.device)
            token_embeds = token_embeds + pos_enc
        
        # Prepare vision features as memory
        vision_memory = vision_features.unsqueeze(1)  # [batch_size, 1, hidden_size]
        
        # Create causal mask for decoder
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        causal_mask = causal_mask.to(token_embeds.device)
        
        # Transformer decoder
        output = self.transformer_decoder(
            tgt=token_embeds,
            memory=vision_memory,
            tgt_mask=causal_mask
        )
        
        # Project to vocabulary
        logits = self.output_projection(output)
        
        return logits

class RadiologyModel(nn.Module):
    """Complete radiology report generation model"""
    
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.vision_encoder = VisionEncoder(hidden_size)
        self.text_decoder = TextDecoder(vocab_size, hidden_size)
        
    def forward(self, images, tokens):
        # Encode vision features
        vision_features = self.vision_encoder(images)
        
        # Decode text
        logits = self.text_decoder(tokens[:, :-1], vision_features)  # Teacher forcing
        
        return logits
    
    def generate(self, images, tokenizer, max_length=128, temperature=1.0):
        """Generate reports for given images"""
        self.eval()
        batch_size = images.size(0)
        
        # Encode vision features
        with torch.no_grad():
            vision_features = self.vision_encoder(images)
            
            # Start with <sos> token
            generated = torch.full((batch_size, 1), 2, dtype=torch.long, device=images.device)
            
            for _ in range(max_length):
                # Get logits for current sequence
                logits = self.text_decoder(generated, vision_features)
                
                # Get next token probabilities
                next_logits = logits[:, -1, :] / temperature
                next_probs = F.softmax(next_logits, dim=-1)
                
                # Sample next token
                next_tokens = torch.multinomial(next_probs, 1)
                
                # Append to generated sequence
                generated = torch.cat([generated, next_tokens], dim=1)
                
                # Stop if all sequences have generated <eos> token
                if (next_tokens == 3).all():  # 3 is <eos>
                    break
        
        return generated

class ModelTrainer:
    """Trainer class for the radiology model"""
    
    def __init__(self, config: ModelConfig):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")
        
        # Initialize tokenizer
        self.tokenizer = SimpleTokenizer(config.vocab_size)
        
        # Create transforms
        self.train_transform = transforms.Compose([
            transforms.Resize((config.image_size, config.image_size)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.val_transform = transforms.Compose([
            transforms.Resize((config.image_size, config.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def build_tokenizer(self, train_file: str):
        """Build tokenizer vocabulary from training data"""
        logger.info("Building vocabulary...")
        
        texts = []
        with open(train_file, 'r') as f:
            for line in f:
                entry = json.loads(line)
                texts.append(entry['report'])
        
        self.tokenizer.build_vocab(texts)
        
        # Save tokenizer
        tokenizer_path = os.path.join(self.config.output_dir, 'tokenizer.json')
        os.makedirs(os.path.dirname(tokenizer_path), exist_ok=True)
        
        tokenizer_data = {
            'word_to_id': self.tokenizer.word_to_id,
            'id_to_word': self.tokenizer.id_to_word,
            'vocab_size': len(self.tokenizer.word_to_id)
        }
        
        with open(tokenizer_path, 'w') as f:
            json.dump(tokenizer_data, f, indent=2)
        
        logger.info(f"Tokenizer saved to {tokenizer_path}")
    
    def create_dataloaders(self):
        """Create training and validation dataloaders"""
        
        train_dataset = RadiologyDataset(
            self.config.train_file, self.tokenizer, self.train_transform, self.config.max_length
        )
        
        val_dataset = RadiologyDataset(
            self.config.val_file, self.tokenizer, self.val_transform, self.config.max_length
        )
        
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True if torch.cuda.is_available() else False
        )
        
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True if torch.cuda.is_available() else False
        )
        
        return train_dataloader, val_dataloader
    
    def train_epoch(self, model, train_dataloader, optimizer, criterion, epoch):
        """Train for one epoch"""
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch_idx, batch in enumerate(train_dataloader):
            # Move to device
            images = batch['image'].to(self.device)
            tokens = batch['tokens'].to(self.device)
            
            # Forward pass
            optimizer.zero_grad()
            logits = model(images, tokens)
            
            # Calculate loss (ignore padding tokens)
            targets = tokens[:, 1:]  # Shift targets
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            
            # Create mask to ignore padding tokens
            mask = targets != 0  # 0 is <pad> token
            
            if mask.sum() > 0:  # Only calculate loss if there are non-padding tokens
                loss = criterion(logits[mask], targets[mask])
                
                # Backward pass
                loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                
                optimizer.step()
                
                total_loss += loss.item()
                num_batches += 1
            
            # Log progress
            if batch_idx % 100 == 0:
                logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(train_dataloader)}, "
                          f"Loss: {loss.item():.4f}")
            
            # Memory cleanup
            if batch_idx % 50 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        
        return total_loss / max(num_batches, 1)
    
    def validate_epoch(self, model, val_dataloader, criterion):
        """Validate for one epoch"""
        model.eval()
        total_loss = 0
        num_batches = 0
        
        with torch.no_grad():
            for batch in val_dataloader:
                images = batch['image'].to(self.device)
                tokens = batch['tokens'].to(self.device)
                
                # Forward pass
                logits = model(images, tokens)
                
                # Calculate loss
                targets = tokens[:, 1:]
                logits = logits.reshape(-1, logits.size(-1))
                targets = targets.reshape(-1)
                
                # Create mask to ignore padding tokens
                mask = targets != 0
                
                if mask.sum() > 0:
                    loss = criterion(logits[mask], targets[mask])
                    total_loss += loss.item()
                    num_batches += 1
        
        return total_loss / max(num_batches, 1)
    
    def generate_sample_reports(self, model, val_dataloader, num_samples=5):
        """Generate sample reports for evaluation"""
        model.eval()
        samples = []
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                if len(samples) >= num_samples:
                    break
                
                images = batch['image'].to(self.device)
                original_reports = batch['report']
                patient_ids = batch['patient_id']
                
                # Generate reports
                generated_tokens = model.generate(images, self.tokenizer, max_length=64)
                
                for i in range(min(len(images), num_samples - len(samples))):
                    generated_text = self.tokenizer.decode(generated_tokens[i].cpu().tolist())
                    
                    samples.append({
                        'patient_id': patient_ids[i],
                        'original': original_reports[i],
                        'generated': generated_text
                    })
        
        return samples
    
    def train(self):
        """Complete training pipeline"""
        
        # Create output directory
        os.makedirs(self.config.output_dir, exist_ok=True)
        
        # Build tokenizer
        self.build_tokenizer(self.config.train_file)
        
        # Create dataloaders
        train_dataloader, val_dataloader = self.create_dataloaders()
        
        # Initialize model
        vocab_size = len(self.tokenizer.word_to_id)
        model = RadiologyModel(vocab_size, self.config.text_hidden_size)
        model.to(self.device)
        
        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        
        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.config.num_epochs
        )
        
        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0
        patience = 3
        
        logger.info("Starting training...")
        
        for epoch in range(self.config.num_epochs):
            start_time = time.time()
            
            # Train
            train_loss = self.train_epoch(model, train_dataloader, optimizer, criterion, epoch)
            
            # Validate
            val_loss = self.validate_epoch(model, val_dataloader, criterion)
            
            # Update scheduler
            scheduler.step()
            
            epoch_time = time.time() - start_time
            
            logger.info(f"Epoch {epoch+1}/{self.config.num_epochs} completed in {epoch_time:.2f}s")
            logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                
                # Save model
                model_path = os.path.join(self.config.output_dir, 'best_model.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'config': self.config
                }, model_path)
                
                logger.info(f"New best model saved with val_loss: {val_loss:.4f}")
                
                # Generate sample reports
                samples = self.generate_sample_reports(model, val_dataloader)
                
                # Save samples
                samples_path = os.path.join(self.config.output_dir, f'samples_epoch_{epoch+1}.json')
                with open(samples_path, 'w') as f:
                    json.dump(samples, f, indent=2)
                
                logger.info("Sample reports:")
                for i, sample in enumerate(samples[:3]):
                    logger.info(f"Sample {i+1}:")
                    logger.info(f"  Original: {sample['original'][:100]}...")
                    logger.info(f"  Generated: {sample['generated'][:100]}...")
                    logger.info("")
            
            else:
                patience_counter += 1
            
            # Early stopping
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break
        
        logger.info("Training completed!")
        return os.path.join(self.config.output_dir, 'best_model.pth')

# Evaluation functions
def calculate_bleu_score(reference, hypothesis):
    """Simple BLEU-1 score calculation"""
    ref_words = set(reference.lower().split())
    hyp_words = set(hypothesis.lower().split())
    
    if len(hyp_words) == 0:
        return 0.0
    
    overlap = len(ref_words.intersection(hyp_words))
    return overlap / len(hyp_words)

def evaluate_model(model_path: str, test_file: str, tokenizer_path: str, config: ModelConfig):
    """Evaluate the trained model"""
    
    # Load tokenizer
    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)
    
    tokenizer = SimpleTokenizer()
    tokenizer.word_to_id = tokenizer_data['word_to_id']
    tokenizer.id_to_word = {int(k): v for k, v in tokenizer_data['id_to_word'].items()}
    
    # Load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    vocab_size = len(tokenizer.word_to_id)
    model = RadiologyModel(vocab_size, config.text_hidden_size)
    
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Create test dataset
    val_transform = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    test_dataset = RadiologyDataset(test_file, tokenizer, val_transform, config.max_length)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    
    # Evaluate
    bleu_scores = []
    all_results = []
    
    logger.info("Starting evaluation...")
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            images = batch['image'].to(device)
            original_reports = batch['report']
            patient_ids = batch['patient_id']
            
            # Generate reports
            generated_tokens = model.generate(images, tokenizer, max_length=128)
            
            for i in range(len(images)):
                generated_text = tokenizer.decode(generated_tokens[i].cpu().tolist())
                original_text = original_reports[i]
                
                # Calculate BLEU score
                bleu = calculate_bleu_score(original_text, generated_text)
                bleu_scores.append(bleu)
                
                all_results.append({
                    'patient_id': patient_ids[i],
                    'original': original_text,
                    'generated': generated_text,
                    'bleu_score': bleu
                })
            
            if batch_idx % 10 == 0:
                logger.info(f"Processed {batch_idx * config.batch_size} samples")
    
    # Calculate metrics
    avg_bleu = np.mean(bleu_scores)
    
    logger.info(f"Evaluation completed!")
    logger.info(f"Average BLEU-1 Score: {avg_bleu:.4f}")
    
    # Save results
    results_path = os.path.join(config.output_dir, 'evaluation_results.json')
    evaluation_summary = {
        'avg_bleu_score': avg_bleu,
        'num_samples': len(all_results),
        'detailed_results': all_results
    }
    
    with open(results_path, 'w') as f:
        json.dump(evaluation_summary, f, indent=2)
    
    logger.info(f"Results saved to {results_path}")
    
    return avg_bleu, all_results

# Main execution
if __name__ == "__main__":
    # Initialize configuration
    config = ModelConfig()
    
    # Check if files exist
    required_files = [config.train_file, config.val_file, config.test_file]
    missing_files = [f for f in required_files if not os.path.exists(f)]
    
    if missing_files:
        logger.error(f"Missing files: {missing_files}")
        logger.error("Please ensure the data split files exist before training.")
    else:
        # Initialize trainer
        trainer = ModelTrainer(config)
        
        # Start training
        try:
            logger.info("="*60)
            logger.info("STARTING BASELINE MODEL TRAINING")
            logger.info("="*60)
            logger.info(f"Configuration:")
            logger.info(f"  Batch size: {config.batch_size}")
            logger.info(f"  Learning rate: {config.learning_rate}")
            logger.info(f"  Number of epochs: {config.num_epochs}")
            logger.info(f"  Image size: {config.image_size}")
            logger.info(f"  Max sequence length: {config.max_length}")
            logger.info("="*60)
            
            model_path = trainer.train()
            
            logger.info("="*60)
            logger.info("TRAINING COMPLETED SUCCESSFULLY")
            logger.info("="*60)
            logger.info(f"Best model saved at: {model_path}")
            
            # Evaluate model
            logger.info("Starting evaluation on test set...")
            tokenizer_path = os.path.join(config.output_dir, 'tokenizer.json')
            
            avg_bleu, results = evaluate_model(model_path, config.test_file, tokenizer_path, config)
            
            logger.info("="*60)
            logger.info("EVALUATION COMPLETED")
            logger.info("="*60)
            logger.info(f"Average BLEU-1 Score: {avg_bleu:.4f}")
            logger.info(f"Evaluation results saved to: {config.output_dir}/evaluation_results.json")
            
            # Show sample results
            logger.info("\nSample Results:")
            logger.info("-" * 40)
            for i, result in enumerate(results[:3]):
                logger.info(f"Sample {i+1} (BLEU: {result['bleu_score']:.3f}):")
                logger.info(f"  Original: {result['original'][:150]}...")
                logger.info(f"  Generated: {result['generated'][:150]}...")
                logger.info("")
            
        except Exception as e:
            logger.error(f"Training/Evaluation failed: {e}")
            import traceback
            traceback.print_exc()

INFO:__main__:Using device: cuda
INFO:__main__:STARTING BASELINE MODEL TRAINING
INFO:__main__:Configuration:
INFO:__main__:  Batch size: 8
INFO:__main__:  Learning rate: 0.0001
INFO:__main__:  Number of epochs: 5
INFO:__main__:  Image size: 224
INFO:__main__:  Max sequence length: 256
INFO:__main__:Building vocabulary...
INFO:__main__:Built vocabulary with 2773 words
INFO:__main__:Tokenizer saved to /home/intern08/Documents/X_ray_CoT/dataset/tokenizer.json
INFO:__main__:Loaded 5195 valid samples from /home/intern08/Documents/X_ray_CoT/dataset/data_split/train.jsonl
INFO:__main__:Loaded 739 valid samples from /home/intern08/Documents/X_ray_CoT/dataset/data_split/val.jsonl
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/intern08/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100.0%
INFO:__main__:Starting training...
INFO:__main__:Epoch 0, Batch 0/650, Loss: 8.0559
INFO:__main__:Epoch 0, Batch 100/650, Loss: 3.0935
INFO:__main__:Epoch 0, Batch 200/650