# In [None]:  (notebook cell marker, commented out so the file parses as Python)
# 1. Updated Imports - Add these at the top of your script
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
import pandas as pd
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
import time
# New imports for EfficientNet-B1
from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
# For Windows memory management
import gc

# Finding names handled by the classifier. The position of a name in this list
# is its class index: it selects the one-hot slot in the targets and the
# matching output column of the model.
labels = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]


# Configuration
class Config:
    """Class-level settings for paths, model choice and training.

    Everything is a plain class attribute evaluated once when the class body
    runs; the class is used as a namespace and is never instantiated here.
    """

    # --- Filesystem layout (all folders hang off BASE_DIR) ---
    BASE_DIR = "C:/projects_ml/Radi_Assist"
    DATA_DIR = os.path.join(BASE_DIR, "data")
    MODEL_DIR = os.path.join(BASE_DIR, "models")
    RESULTS_DIR = os.path.join(BASE_DIR, "results")

    # --- Model / hardware ---
    MODEL_NAME = 'efficientnet_b1'
    # CUDA when PyTorch can see a GPU, otherwise CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Training hyper-parameters ---
    BATCH_SIZE = 8  # kept small (down from 16) to avoid GPU memory issues
    NUM_EPOCHS = 7
    LEARNING_RATE = 0.0001
    IMG_SIZE = 224
    SEED = 42

    @classmethod
    def create_directories(cls):
        """Create the model and results folders, including plot/prediction subfolders.

        Existing folders are left untouched.
        """
        required = (
            cls.MODEL_DIR,
            cls.RESULTS_DIR,
            os.path.join(cls.RESULTS_DIR, "plots"),
            os.path.join(cls.RESULTS_DIR, "predictions"),
        )
        for folder in required:
            os.makedirs(folder, exist_ok=True)

# Set random seeds for reproducibility
def set_seeds(seed):
    """Seed the random number generators and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus every CUDA device when CUDA is available), then switches cuDNN
    to deterministic mode with auto-tuning (benchmark) turned off.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: `random` is not among this module's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade some speed for run-to-run repeatability of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Custom dataset for chest X-ray images
class ChestXRayDataset(Dataset):
    """Dataset that reads chest X-ray files from disk as 3-channel images.

    Each item is the (optionally transformed) image, returned together with
    its target row when targets were supplied. Files that cannot be read or
    processed are replaced by an all-zero image instead of raising.
    """

    def __init__(self, image_paths, targets=None, transform=None):
        """Keep the file list, the optional per-image targets and the optional transform."""
        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_paths)

    def _with_target(self, image, idx):
        """Return ``image`` alone, or ``(image, target)`` when targets were given."""
        if self.targets is None:
            return image
        return image, self.targets[idx]

    def __getitem__(self, idx):
        """Load item ``idx``; fall back to a blank image if anything goes wrong."""
        path = self.image_paths[idx]

        try:
            gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if gray is None:
                # cv2.imread signals an unreadable file by returning None.
                print(f"Warning: Could not read image at {path}")
                gray = np.zeros((Config.IMG_SIZE, Config.IMG_SIZE), dtype=np.uint8)

            # Replicate the single grey channel into three channels.
            rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            sample = self.transform(rgb) if self.transform else rgb
            return self._with_target(sample, idx)
        except Exception as err:
            print(f"Error processing image {path}: {str(err)}")
            # Blank stand-in: a (3, H, W) tensor when a transform is set,
            # otherwise an (H, W, 3) uint8 array like the untransformed path.
            if self.transform:
                blank = torch.zeros((3, Config.IMG_SIZE, Config.IMG_SIZE))
            else:
                blank = np.zeros((Config.IMG_SIZE, Config.IMG_SIZE, 3), dtype=np.uint8)
            return self._with_target(blank, idx)

# Data preparation and loading
def _collect_split(split):
    """Gather image paths and one-hot label rows for one split folder.

    Expects ``<DATA_DIR>/<split>/<class_name>/`` folders; classes without a
    folder (or without image files) are skipped.

    Returns:
        ``(image_paths, label_array)`` where ``label_array`` has one row of
        length ``len(labels)`` per image.

    Raises:
        ValueError: If the split folder itself does not exist.
    """
    split_dir = os.path.join(Config.DATA_DIR, split)
    if not os.path.exists(split_dir):
        raise ValueError(f"Directory {split_dir} does not exist")

    image_paths = []
    label_rows = []
    for class_idx, class_name in enumerate(labels):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.exists(class_dir):
            continue

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and anything derived from it) stable between runs.
        image_files = [
            os.path.join(class_dir, name)
            for name in sorted(os.listdir(class_dir))
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        if not image_files:
            continue

        print(f"Found {len(image_files)} images for class '{class_name}' in {split} set")
        one_hot = np.zeros(len(labels))
        one_hot[class_idx] = 1
        image_paths.extend(image_files)
        label_rows.extend([one_hot] * len(image_files))

    return image_paths, np.array(label_rows)


def _plot_training_overview(train_images, train_labels):
    """Save a class-distribution bar chart and a one-sample-per-class grid.

    The plots are informational only, so any failure is reported and swallowed
    rather than aborting data preparation.
    """
    try:
        plt.figure(figsize=(15, 6))
        class_counts = np.sum(train_labels, axis=0)
        plt.bar(labels, class_counts)
        plt.title('Class Distribution in Training Set')
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('Number of Images')
        plt.tight_layout()

        dist_plot_path = os.path.join(Config.RESULTS_DIR, "plots", "class_distribution.png")
        plt.savefig(dist_plot_path)
        plt.close()
        print(f"Class distribution plot saved to {dist_plot_path}")

        # First training image seen for each class, found in one pass.
        first_sample = {}
        for path, row in zip(train_images, train_labels):
            for class_idx in np.flatnonzero(row == 1):
                first_sample.setdefault(int(class_idx), path)

        grid_cols = 5
        grid_rows = -(-len(labels) // grid_cols)  # ceiling division

        plt.figure(figsize=(15, 10))
        shown_classes = 0
        for class_idx, class_name in enumerate(labels):
            if class_idx not in first_sample:
                continue
            try:
                sample_img = cv2.imread(first_sample[class_idx], cv2.IMREAD_GRAYSCALE)
                if sample_img is not None:
                    plt.subplot(grid_rows, grid_cols, shown_classes + 1)
                    plt.imshow(sample_img, cmap='gray')
                    plt.title(class_name)
                    plt.axis('off')
                    shown_classes += 1
            except Exception as e:
                print(f"Error visualizing sample for class {class_name}: {str(e)}")

        plt.tight_layout()
        samples_plot_path = os.path.join(Config.RESULTS_DIR, "plots", "sample_images.png")
        plt.savefig(samples_plot_path)
        plt.close()
        print(f"Sample images plot saved to {samples_plot_path}")
    except Exception as e:
        print(f"Warning: Error generating visualization plots: {str(e)}")


def prepare_data():
    """Build the train/val/test DataLoaders from the on-disk folder layout.

    Scans ``<DATA_DIR>/{train,val,test}/<class_name>/`` for image files, turns
    the containing folder into a one-hot target, wraps each split in a
    ``ChestXRayDataset`` (augmenting only the training split) and saves two
    overview plots of the training data.

    Returns:
        ``(train_loader, val_loader, test_loader,
        (train_image_paths, val_image_paths, test_image_paths))``.

    Raises:
        ValueError: If a split folder is missing or contains no images.
    """
    print("\n" + "="*50)
    print("PREPARING DATASET")
    print("="*50)

    # Normalisation constants shared by every split (same values for train,
    # validation and test so the model always sees identically scaled input).
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    resize = T.Resize((Config.IMG_SIZE, Config.IMG_SIZE))

    train_transform = T.Compose([
        T.ToPILImage(),
        resize,
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        T.ToTensor(),
        normalize,
    ])
    val_test_transform = T.Compose([
        T.ToPILImage(),
        resize,
        T.ToTensor(),
        normalize,
    ])
    split_transforms = {
        'train': train_transform,
        'val': val_test_transform,
        'test': val_test_transform,
    }

    # Collect every split first so a missing folder is reported before an
    # empty one.
    data_splits = {}
    for split in ('train', 'val', 'test'):
        image_paths, label_array = _collect_split(split)
        data_splits[split] = {'images': image_paths, 'labels': label_array}

    for split, content in data_splits.items():
        if len(content['images']) == 0:
            raise ValueError(f"No images found in {split} directory")

    print(f"\nDataset split sizes:")
    print(f"Train: {len(data_splits['train']['images'])} images")
    print(f"Validation: {len(data_splits['val']['images'])} images")
    print(f"Test: {len(data_splits['test']['images'])} images")

    loader_options = dict(
        batch_size=Config.BATCH_SIZE,
        num_workers=2,  # reduced from 4
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=Config.DEVICE.type == "cuda",
        persistent_workers=True,
    )
    loaders = {
        split: DataLoader(
            ChestXRayDataset(
                content['images'],
                content['labels'],
                transform=split_transforms[split],
            ),
            shuffle=(split == 'train'),  # only the training order is randomised
            **loader_options,
        )
        for split, content in data_splits.items()
    }

    _plot_training_overview(data_splits['train']['images'], data_splits['train']['labels'])

    return (loaders['train'], loaders['val'], loaders['test'],
            (data_splits['train']['images'],
             data_splits['val']['images'],
             data_splits['test']['images']))

# Model creation
def create_model():
    """Build a pretrained EfficientNet-B1 with one output per entry in ``labels``.

    Loads torchvision's EfficientNet-B1 with its ``DEFAULT`` weights, swaps the
    last classifier layer for a fresh ``nn.Linear`` with ``len(labels)``
    outputs, and moves the model to ``Config.DEVICE``.

    Returns:
        The model, already on ``Config.DEVICE``.

    Raises:
        RuntimeError: If building the model or moving it to the device fails;
            the original exception is chained as the cause.
    """
    print("\n" + "="*50)
    print("CREATING MODEL")
    print("="*50)

    try:
        model = efficientnet_b1(weights=EfficientNet_B1_Weights.DEFAULT)

        # Replace the final linear layer so the head matches our class count.
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, len(labels))

        print("Model: EfficientNet-B1 from torchvision")
        print(f"Number of classes: {len(labels)}")
        print(f"Using device: {Config.DEVICE}")

        # No mixed-precision setup happens here (training runs a standard
        # full-precision forward pass), so nothing is announced about it.
        return model.to(Config.DEVICE)
    except Exception as e:
        print(f"Error creating EfficientNet-B1 model from torchvision: {str(e)}")
        raise RuntimeError(
            "Failed to create EfficientNet-B1 model. Ensure torchvision is installed properly."
        ) from e


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    A batch that raises is reported and skipped so one bad batch does not end
    the run.

    Returns:
        ``(mean_loss, batch_preds, batch_targets)``. The loss is summed per
        sample and divided by the dataset size; the other two are lists of
        per-batch NumPy arrays (sigmoid probabilities and targets).
    """
    training = optimizer is not None
    phase = "training" if training else "validation"
    model.train(training)

    total_loss = 0
    batch_preds = []
    batch_targets = []

    progress_bar = tqdm(loader, desc=phase.capitalize())
    with torch.set_grad_enabled(training):
        for batch_idx, (images, targets) in enumerate(progress_bar):
            try:
                images = images.to(Config.DEVICE)
                targets = targets.float().to(Config.DEVICE)

                # Standard full-precision forward pass.
                outputs = model(images)
                loss = criterion(outputs, targets)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                total_loss += loss.item() * images.size(0)
                batch_preds.append(torch.sigmoid(outputs).detach().cpu().numpy())
                batch_targets.append(targets.cpu().numpy())

                progress_bar.set_postfix({"batch_loss": loss.item()})

                # Occasionally release cached GPU memory while training.
                if training and batch_idx % 100 == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"Error in {phase} batch: {str(e)}")
                continue

    if len(loader.dataset) > 0:
        total_loss /= len(loader.dataset)
    return total_loss, batch_preds, batch_targets


def _macro_auroc(batch_preds, batch_targets, phase):
    """Concatenate per-batch arrays and compute the macro-averaged AUROC.

    Returns:
        ``(auroc, preds, targets)``. ``auroc`` is 0 when nothing was collected
        or the computation failed; ``preds``/``targets`` are the concatenated
        arrays, or ``None`` where concatenation did not happen.
    """
    if not batch_preds or not batch_targets:
        return 0, None, None

    preds = targets = None
    try:
        preds = np.concatenate(batch_preds)
        targets = np.concatenate(batch_targets)
        auroc = roc_auc_score(targets, preds, average='macro')
    except Exception as e:
        print(f"Error calculating {phase} AUROC: {str(e)}")
        auroc = 0
    return auroc, preds, targets


def _save_class_auroc_plot(val_targets, val_preds, epoch_number):
    """Save a bar chart of per-class validation AUROC for ``epoch_number`` (1-based)."""
    try:
        class_aurocs = []
        for i in range(len(labels)):
            try:
                class_aurocs.append(roc_auc_score(val_targets[:, i], val_preds[:, i]))
            except Exception:
                # AUROC unavailable for this class: plot the chance level.
                class_aurocs.append(0.5)

        plt.figure(figsize=(12, 6))
        sns.barplot(x=labels, y=class_aurocs)
        plt.title(f'Per-Class AUROC at Epoch {epoch_number}')
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0, 1)
        plt.tight_layout()

        class_metrics_path = os.path.join(
            Config.RESULTS_DIR, "plots", f"class_metrics_epoch_{epoch_number}.png")
        plt.savefig(class_metrics_path)
        plt.close()
        print(f"Class metrics for epoch {epoch_number} saved to {class_metrics_path}")
    except Exception as e:
        print(f"Error visualizing class metrics: {str(e)}")


def train_model(model, train_loader, val_loader):
    """Train ``model`` for ``Config.NUM_EPOCHS`` epochs and checkpoint the best epoch.

    Uses ``BCEWithLogitsLoss`` with Adam at ``Config.LEARNING_RATE``. After
    every epoch the validation loss is compared with the best so far and, if
    lower, the weights are saved as ``<MODEL_NAME>_best.pth``. The final
    weights are saved as ``<MODEL_NAME>_final.pth`` and the loss/AUROC curves
    are plotted. Failures while saving or plotting are reported, not raised.

    Args:
        model: Network to optimise; assumed to already be on ``Config.DEVICE``.
        train_loader: Loader yielding ``(images, targets)`` training batches.
        val_loader: Loader yielding ``(images, targets)`` validation batches.

    Returns:
        Path of the best-checkpoint file. The file only exists if at least one
        epoch improved the validation loss and the save succeeded.
    """
    print("\n" + "="*50)
    print("TRAINING MODEL")
    print("="*50)

    criterion = BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=Config.LEARNING_RATE)

    # Start from a clean GPU cache.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    best_model_path = os.path.join(Config.MODEL_DIR, f"{Config.MODEL_NAME}_best.pth")
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    train_aurocs = []
    val_aurocs = []

    start_time = time.time()

    for epoch in range(Config.NUM_EPOCHS):
        epoch_start = time.time()

        print(f"\nEpoch {epoch+1}/{Config.NUM_EPOCHS}")
        print("-" * 20)

        train_loss, train_preds, train_targets = _run_epoch(
            model, train_loader, criterion, optimizer)
        train_auroc_val, _, _ = _macro_auroc(train_preds, train_targets, "training")

        val_loss, val_preds, val_targets = _run_epoch(model, val_loader, criterion)
        val_auroc_val, val_preds, val_targets = _macro_auroc(
            val_preds, val_targets, "validation")

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_aurocs.append(train_auroc_val)
        val_aurocs.append(val_auroc_val)

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{Config.NUM_EPOCHS} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {train_loss:.4f}, Train AUROC: {train_auroc_val:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val AUROC: {val_auroc_val:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            try:
                torch.save(model.state_dict(), best_model_path)
                print(f"Model improved! Saved to {best_model_path}")
            except Exception as e:
                print(f"Error saving model: {str(e)}")

        # Per-class plot on the first epoch, every 5th epoch and the last one,
        # but only when validation arrays were actually produced.
        report_epoch = (epoch + 1) % 5 == 0 or epoch == 0 or epoch == Config.NUM_EPOCHS - 1
        have_val_arrays = (
            val_preds is not None and val_targets is not None
            and val_preds.size > 0 and val_targets.size > 0
        )
        if report_epoch and have_val_arrays:
            _save_class_auroc_plot(val_targets, val_preds, epoch + 1)

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/60:.2f} minutes")

    final_model_path = os.path.join(Config.MODEL_DIR, f"{Config.MODEL_NAME}_final.pth")
    try:
        torch.save(model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
    except Exception as e:
        print(f"Error saving final model: {str(e)}")

    try:
        # x-axis follows the recorded history rather than the config value.
        epochs_axis = range(1, len(train_losses) + 1)

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(epochs_axis, train_losses, label='Train Loss')
        plt.plot(epochs_axis, val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(epochs_axis, train_aurocs, label='Train AUROC')
        plt.plot(epochs_axis, val_aurocs, label='Validation AUROC')
        plt.xlabel('Epoch')
        plt.ylabel('AUROC')
        plt.title('Training and Validation AUROC')
        plt.legend()

        plt.tight_layout()
        learning_curves_path = os.path.join(Config.RESULTS_DIR, "plots", "learning_curves.png")
        plt.savefig(learning_curves_path)
        plt.close()
        print(f"Learning curves saved to {learning_curves_path}")
    except Exception as e:
        print(f"Error plotting learning curves: {str(e)}")

    return best_model_path

def evaluate_model(model_path, test_loader):
    """Score a saved checkpoint on the test loader and write the reports.

    Loads the weights at ``model_path`` into a freshly created model (falling
    back to a newly created model if loading fails), runs it over
    ``test_loader``, then computes per-class AUROC plus accuracy, precision,
    recall and F1 at a 0.5 threshold. Saves a bar chart and a CSV of the
    metrics and finally calls ``visualize_predictions``.

    Args:
        model_path: Path to a ``state_dict`` checkpoint file.
        test_loader: Loader yielding ``(images, targets)`` batches.

    Returns:
        ``(average_auroc, metrics_dataframe)``, or ``(0, None)`` when no batch
        could be processed or the metric/report stage failed.
    """
    print("\n" + "="*50)
    print("EVALUATING MODEL")
    print("="*50)

    try:
        model = create_model()
        model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
        print(f"Loaded model from {model_path}")
    except Exception as load_err:
        print(f"Error loading model: {str(load_err)}")
        print("Using newly initialized model for evaluation")
        model = create_model()

    model.eval()

    loss_fn = BCEWithLogitsLoss()
    test_loss = 0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Testing"):
            try:
                images = images.to(Config.DEVICE)
                targets = targets.float().to(Config.DEVICE)

                logits = model(images)
                batch_loss = loss_fn(logits, targets)

                # Weight by batch size so the final division gives a per-sample mean.
                test_loss += batch_loss.item() * images.size(0)
                pred_chunks.append(torch.sigmoid(logits).cpu().numpy())
                target_chunks.append(targets.cpu().numpy())
            except Exception as batch_err:
                print(f"Error in test batch: {str(batch_err)}")
                continue

    if len(test_loader.dataset) > 0:
        test_loss /= len(test_loader.dataset)

    if not (pred_chunks and target_chunks):
        print("No predictions or targets collected during evaluation")
        return 0, None

    def auroc_or_chance(y_true, y_score):
        """AUROC for one class, or 0.5 when it cannot be computed."""
        try:
            return roc_auc_score(y_true, y_score)
        except Exception:
            return 0.5

    try:
        all_preds = np.concatenate(pred_chunks)
        all_targets = np.concatenate(target_chunks)
        n_classes = len(labels)

        class_aurocs = [
            auroc_or_chance(all_targets[:, col], all_preds[:, col])
            for col in range(n_classes)
        ]
        avg_auroc = sum(class_aurocs) / len(class_aurocs)

        print(f"\nTest Loss: {test_loss:.4f}")
        print(f"Average AUROC: {avg_auroc:.4f}")

        print("\nPer-Class Metrics:")
        for class_name, class_auroc in zip(labels, class_aurocs):
            print(f"{class_name}: AUROC = {class_auroc:.4f}")

        plt.figure(figsize=(14, 7))
        sns.barplot(x=labels, y=class_aurocs)
        plt.title('Per-Class AUROC on Test Set')
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0, 1)
        plt.tight_layout()

        test_metrics_path = os.path.join(Config.RESULTS_DIR, "plots", "test_metrics.png")
        plt.savefig(test_metrics_path)
        plt.close()
        print(f"Test metrics plot saved to {test_metrics_path}")

        # Hard decisions at a fixed 0.5 probability threshold.
        binary_preds = (all_preds > 0.5).astype(int)

        thresholded = {'Accuracy': [], 'Precision': [], 'Recall': [], 'F1 Score': []}
        for col in range(n_classes):
            truth = all_targets[:, col]
            decided = binary_preds[:, col]
            try:
                row = (
                    accuracy_score(truth, decided),
                    precision_score(truth, decided, zero_division=0),
                    recall_score(truth, decided, zero_division=0),
                    f1_score(truth, decided, zero_division=0),
                )
            except Exception as metric_err:
                print(f"Error calculating metrics for class {labels[col]}: {str(metric_err)}")
                row = (0, 0, 0, 0)
            for column_name, value in zip(thresholded, row):
                thresholded[column_name].append(value)

        metrics_df = pd.DataFrame({
            'Class': labels,
            'AUROC': class_aurocs,
            **thresholded,
        })

        metrics_csv_path = os.path.join(Config.RESULTS_DIR, "test_metrics.csv")
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"Detailed metrics saved to {metrics_csv_path}")

        visualize_predictions(model, test_loader)

        return avg_auroc, metrics_df

    except Exception as eval_err:
        print(f"Error during evaluation: {str(eval_err)}")
        return 0, None

def visualize_predictions(model, test_loader):
    """Plot predictions for the first test batch and save two figures.

    Saves ``sample_predictions.png`` (up to 12 images, each annotated with the
    three highest-probability classes and the actual classes) and
    ``prediction_heatmap.png`` (probabilities of up to 10 images) under the
    results ``plots`` folder. Errors are reported, never raised.

    Args:
        model: Network to run; inputs are moved to ``Config.DEVICE`` first.
        test_loader: Loader yielding ``(images, targets)`` batches.
    """
    print("\n" + "="*50)
    print("VISUALIZING PREDICTIONS")
    print("="*50)

    if len(test_loader) == 0:
        print("No test data available for visualization")
        return

    # Constants used to undo the input normalisation for display.
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    try:
        # Only the first batch is visualised.
        images, targets = next(iter(test_loader))

        model.eval()
        with torch.no_grad():
            logits = model(images.to(Config.DEVICE))
            probs = torch.sigmoid(logits).cpu().numpy()

        targets = targets.numpy()

        num_images = min(12, len(images))
        rows = min(3, (num_images + 3) // 4)
        cols = min(4, num_images)

        # squeeze=False always yields a 2-D axes array, so one flatten covers
        # every grid shape (including a single subplot).
        fig, axs = plt.subplots(rows, cols, figsize=(20, 15), squeeze=False)
        axs = axs.flatten()

        for i in range(num_images):
            try:
                # CHW tensor -> HWC array, then de-normalise into [0, 1].
                picture = images[i].permute(1, 2, 0).numpy()
                picture = np.clip(picture * channel_std + channel_mean, 0, 1)

                ax = axs[i]
                ax.imshow(picture)
                ax.set_title(f"Image {i+1}")
                ax.axis('off')

                top_indices = np.argsort(probs[i])[::-1][:3]
                actual_indices = np.where(targets[i] > 0.5)[0]

                predicted_part = "".join(
                    f"{rank}. {labels[k]}: {probs[i][k]:.2f}\n"
                    for rank, k in enumerate(top_indices, start=1)
                )
                actual_part = "".join(
                    f"{rank}. {labels[k]}\n"
                    for rank, k in enumerate(actual_indices, start=1)
                )
                text_content = "Predictions:\n" + predicted_part + "\nActual:\n" + actual_part

                ax.text(1.05, 0.5, text_content, transform=ax.transAxes,
                        verticalalignment='center', fontsize=10)

            except Exception as e:
                print(f"Error visualizing image {i}: {str(e)}")
                continue

        # Hide grid cells that received no image.
        for spare in axs[num_images:]:
            spare.axis('off')
            spare.set_visible(False)

        plt.tight_layout()
        predictions_path = os.path.join(Config.RESULTS_DIR, "plots", "sample_predictions.png")
        plt.savefig(predictions_path)
        plt.close()
        print(f"Sample predictions visualization saved to {predictions_path}")

        heatmap_rows = min(10, len(probs))
        plt.figure(figsize=(14, 10))
        sns.heatmap(probs[:heatmap_rows],
                    xticklabels=labels,
                    yticklabels=[f"Image {i+1}" for i in range(heatmap_rows)],
                    cmap="YlGnBu", vmin=0, vmax=1, annot=True, fmt='.2f')
        plt.title('Prediction Probabilities Heatmap')
        plt.tight_layout()

        heatmap_path = os.path.join(Config.RESULTS_DIR, "plots", "prediction_heatmap.png")
        plt.savefig(heatmap_path)
        plt.close()
        print(f"Prediction heatmap saved to {heatmap_path}")

    except Exception as e:
        print(f"Error during prediction visualization: {str(e)}")

class ChestXRayPredictor:
    """Single-image inference wrapper around a saved EfficientNet-B1 checkpoint.

    If the checkpoint cannot be loaded, ``self.model`` is set to ``None`` and
    ``predict`` returns an error dict instead of raising.
    """

    def __init__(self, model_path):
        """Rebuild the network, load ``model_path`` into it and prepare the input transform.

        Args:
            model_path: Path to a ``state_dict`` checkpoint whose classifier
                head has ``len(labels)`` outputs.
        """
        self.model_path = model_path
        self.device = Config.DEVICE

        try:
            # Pretrained weights are unnecessary: the checkpoint overwrites them.
            model = efficientnet_b1(weights=None)
            num_features = model.classifier[1].in_features
            model.classifier[1] = nn.Linear(num_features, len(labels))

            model.load_state_dict(torch.load(model_path, map_location=self.device))
            # predict() moves its input to self.device, so the weights must
            # live there too; otherwise inference on a GPU fails with a
            # device-mismatch error.
            model.to(self.device)
            model.eval()
            self.model = model
            print(f"Predictor initialized with model from {model_path}")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            self.model = None

        # Same preprocessing as the validation/test pipeline (no augmentation).
        self.transform = T.Compose([
            T.ToPILImage(),
            T.Resize((Config.IMG_SIZE, Config.IMG_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _severity_from_probability(prob):
        """Bucket a probability in [0, 1] into an integer score from 1 to 5.

        The score is a fixed function of the probability, so repeated calls on
        the same image report the same value. It is a coarse display heuristic
        derived only from the predicted probability.
        """
        return min(5, max(1, int(prob * 5)))

    def predict(self, image_path):
        """Classify one image file.

        Args:
            image_path: Path of the image to analyse.

        Returns:
            On success a dict with ``"probabilities"`` (label -> float),
            ``"differential_diagnosis"`` (up to five ``(label, prob)`` pairs,
            highest first) and ``"severity"`` (label -> int 1-5). On failure a
            dict with a single ``"error"`` message.
        """
        if self.model is None:
            return {"error": "Model not properly loaded"}

        try:
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                return {"error": f"Could not read image at {image_path}"}

            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            img_tensor = self.transform(img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                # Mixed-precision inference when running on a GPU.
                if self.device.type == "cuda":
                    with torch.cuda.amp.autocast():
                        outputs = self.model(img_tensor)
                else:
                    outputs = self.model(img_tensor)
                probs = torch.sigmoid(outputs).squeeze().cpu().numpy()

            ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)

            return {
                "probabilities": {label: float(prob) for label, prob in zip(labels, probs)},
                "differential_diagnosis": [(label, float(prob)) for label, prob in ranked][:5],
                "severity": {
                    label: self._severity_from_probability(prob)
                    for label, prob in zip(labels, probs)
                },
            }

        except Exception as e:
            return {"error": str(e)}

    def visualize_prediction(self, image_path, save_path=None):
        """Draw the image next to a bar chart of its predicted probabilities.

        The three most probable classes are drawn in red and every bar is
        annotated with its severity score.

        Args:
            image_path: Path of the image to analyse and display.
            save_path: Optional file path; when given, the figure is saved there.

        Returns:
            The matplotlib figure (left open; the caller decides when to close
            it), or ``None`` if prediction or image loading failed.
        """
        result = self.predict(image_path)
        if "error" in result:
            print(f"Error: {result['error']}")
            return None

        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Could not read image at {image_path} for visualization")
            return None

        fig, axs = plt.subplots(1, 2, figsize=(15, 7))

        axs[0].imshow(img, cmap='gray')
        axs[0].set_title('Original Chest X-Ray')
        axs[0].axis('off')

        sorted_items = sorted(result["probabilities"].items(), key=lambda x: x[1], reverse=True)
        labels_sorted = [item[0] for item in sorted_items]
        probs_sorted = [item[1] for item in sorted_items]

        bars = axs[1].barh(range(len(labels_sorted)), probs_sorted, color='skyblue')
        axs[1].set_yticks(range(len(labels_sorted)))
        axs[1].set_yticklabels(labels_sorted)
        axs[1].set_xlabel('Probability')
        axs[1].set_title('Prediction Probabilities')

        top_3_labels = {diff[0] for diff in result["differential_diagnosis"][:3]}
        for bar, label in zip(bars, labels_sorted):
            if label in top_3_labels:
                bar.set_color('red')

        for i, (label, prob) in enumerate(sorted_items):
            severity = result["severity"][label]
            axs[1].text(prob + 0.01, i, f"Severity: {severity}/5", va='center')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            print(f"Visualization saved to {save_path}")

        return fig
        
def main():
    """Run the whole pipeline: data preparation, training, evaluation, sample predictions.

    Creates the output folders, seeds the random generators, builds the
    loaders, trains the model, evaluates the best checkpoint on the test set
    and finally analyses up to five individual test images, saving one figure
    per image under the results ``predictions`` folder.
    """
    Config.create_directories()
    set_seeds(Config.SEED)

    print("=" * 50)
    print("CHEST X-RAY ANALYSIS SYSTEM")
    print("=" * 50)
    print(f"Using device: {Config.DEVICE}")

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # Start from a clean GPU cache.
        torch.cuda.empty_cache()
        gc.collect()
    else:
        print("Warning: No GPU available, using CPU")

    train_loader, val_loader, test_loader, (_, _, test_images) = prepare_data()

    # Smoke test: pull one batch so dataset problems surface before training.
    try:
        print("Testing dataloader with a single batch...")
        start_time = time.time()
        test_batch = next(iter(train_loader))
        print(f"Successfully loaded a batch in {time.time() - start_time:.2f} seconds")
        print(f"Batch images shape: {test_batch[0].shape}")
        print(f"Batch labels shape: {test_batch[1].shape}")
    except Exception as e:
        print(f"Error loading a batch: {str(e)}")
        print("This might indicate issues with your dataset or dataloader configuration.")

    # create_model() either returns a model or raises, so no None check is needed.
    model = create_model()
    print("Model created successfully, starting training...")
    best_model_path = train_model(model, train_loader, val_loader)

    # Free memory after training.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    avg_auroc, _ = evaluate_model(best_model_path, test_loader)

    print("\n" + "="*50)
    print("TESTING ON INDIVIDUAL SAMPLES")
    print("="*50)

    predictor = ChestXRayPredictor(best_model_path)

    num_samples = min(5, len(test_images))
    if num_samples > 0:
        for i in range(num_samples):
            try:
                img_path = test_images[i]
                print(f"\nAnalyzing image: {os.path.basename(img_path)}")

                result = predictor.predict(img_path)
                if "error" in result:
                    print(f"Error analyzing image: {result['error']}")
                    continue

                print("Differential Diagnosis:")
                for j, (label, prob) in enumerate(result['differential_diagnosis']):
                    print(f"{j+1}. {label}: {prob:.2%} (Severity: {result['severity'][label]}/5)")

                save_dir = os.path.join(Config.RESULTS_DIR, "predictions")
                os.makedirs(save_dir, exist_ok=True)
                # NOTE(review): the basename keeps its own extension, so the
                # saved files end in e.g. ".png.png"; left as-is so existing
                # output names do not change.
                save_path = os.path.join(save_dir, f"pred_{os.path.basename(img_path)}.png")
                fig = predictor.visualize_prediction(img_path, save_path)
                if fig is not None:
                    # The figure is already on disk; close it so figures do
                    # not accumulate across samples.
                    plt.close(fig)
            except Exception as e:
                print(f"Error processing test image {i}: {str(e)}")
    else:
        print("No test images available for individual testing")

    print("\n" + "="*50)
    print("SYSTEM SUMMARY")
    print("="*50)
    print("Model: EfficientNet-B1")
    print(f"Average AUROC on test set: {avg_auroc:.4f}")
    print(f"Best model saved to: {best_model_path}")
    print(f"All results saved to: {Config.RESULTS_DIR}")

    print("=" * 50)
    print("CHEST X-RAY ANALYSIS SYSTEM COMPLETE")
    print("=" * 50)

    # Final memory cleanup.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
# NOTE(review): prepare_data() builds DataLoaders with num_workers=2; on
# platforms that start worker processes by re-importing this module the guard
# presumably also keeps main() from re-running in each worker — confirm
# against the PyTorch DataLoader notes for the target platform.
if __name__ == "__main__":
    main()