In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision.models import resnet50
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import json
import random
import math
import time
from datetime import datetime

In [1]:
import torch
import gc

# Clear PyTorch's CUDA cache
torch.cuda.empty_cache()

# Force Python garbage collection
gc.collect()

10

In [6]:
try:
    import timm
except ImportError:
    print("Installing timm library...")
    import subprocess
    subprocess.check_call(["pip", "install", "timm"])
    import timm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision.models import resnet50
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import json
import random
import math
import time
from datetime import datetime
from tqdm import tqdm  # Add this import

# Install required packages
try:
    import timm
    from tqdm import tqdm
except ImportError:
    print("Installing required libraries...")
    import subprocess
    subprocess.check_call(["pip", "install", "timm", "tqdm"])
    import timm
    from tqdm import tqdm

# Enhanced Configuration with EfficientNet settings
class Config:
    data_dir = "/data1/home/prakrutp/medical_imaging/dataset"
    csv_path = "/data1/home/prakrutp/medical_imaging/dataset/Data_Entry_2017.csv"
    image_size = 512
    batch_size = 16  # Start with smaller batch size
    num_epochs = 100
    learning_rate = 0.001
    num_classes = 14
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # EfficientNet specific
    model_name = "tf_efficientnet_b3"
    use_pretrained = True
    feature_dim = 1536
    
    # Advanced training
    use_progressive_resizing = False  # Disable for now to debug
    use_advanced_augmentation = True
    use_class_weights = True
    
    # Split files
    test_list_file = "test_list.txt"
    train_val_list_file = "train_val_list.txt"
    
    # Checkpoint settings
    checkpoint_dir = "checkpoints_efficientnet"
    checkpoint_interval = 5
    resume_training = True

# Create checkpoint directory
os.makedirs(Config.checkpoint_dir, exist_ok=True)

# Optimized Class Weight Calculation
def calculate_class_weights(dataset):
    """Calculate class weights for handling imbalance - OPTIMIZED VERSION"""
    print("Calculating class weights...")
    
    # Method 1: Use the dataframe directly (MUCH faster)
    if hasattr(dataset, 'data_frame'):
        print("Using dataframe for fast class weight calculation...")
        class_counts = np.zeros(Config.num_classes)
        
        for i, disease in enumerate(dataset.disease_classes):
            class_counts[i] = dataset.data_frame[disease].sum()
        
        total_samples = len(dataset)
        class_weights = total_samples / (Config.num_classes * np.maximum(class_counts, 1))
        class_weights = class_weights / np.sum(class_weights) * Config.num_classes
        
    else:
        # Method 2: Fallback to iterative (with progress bar)
        print("Using iterative method with progress tracking...")
        class_counts = np.zeros(Config.num_classes)
        
        for idx in tqdm(range(len(dataset)), desc="Calculating class weights"):
            _, labels = dataset[idx]
            class_counts += labels.numpy()
    
        # Avoid division by zero
        class_counts = np.maximum(class_counts, 1)
        
        # Inverse frequency weighting
        total_samples = len(dataset)
        class_weights = total_samples / (Config.num_classes * class_counts)
        
        # Normalize weights
        class_weights = class_weights / np.sum(class_weights) * Config.num_classes
    
    print("Class weights calculated:")
    for i, disease in enumerate(dataset.disease_classes):
        print(f"  {disease}: {class_weights[i]:.2f} (count: {class_counts[i]})")
    
    return torch.FloatTensor(class_weights).to(Config.device)

# Keep all your existing dataset and helper functions here (they're fine)
class ChestXrayDataset(Dataset):
    def __init__(self, csv_file, base_dir, image_list=None, transform=None):
        self.df = pd.read_csv(csv_file)
        self.base_dir = base_dir
        self.transform = transform
        
        self.disease_classes = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 
            'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 
            'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 
            'Pleural_Thickening', 'Hernia'
        ]
        
        self._create_label_columns()
        
        if image_list is not None:
            self.data_frame = self.df[self.df['Image Index'].isin(image_list)].reset_index(drop=True)
        else:
            self.data_frame = self.df
        
        print(f"Dataset initialized with {len(self.data_frame)} images")
        
    def _create_label_columns(self):
        for disease in self.disease_classes:
            self.df.loc[:, disease] = self.df['Finding Labels'].apply(
                lambda x: 1 if disease in x else 0
            )
    
    def _find_image_path(self, img_name):
        for i in range(1, 13):
            folder_name = f"images_{i:03d}"
            possible_path = os.path.join(self.base_dir, folder_name, "images", img_name)
            if os.path.exists(possible_path):
                return possible_path
        return os.path.join(self.base_dir, "images_001", "images", img_name)
    
    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, idx):
        img_name = self.data_frame.iloc[idx]['Image Index']
        img_path = self._find_image_path(img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Warning: Could not load image {img_path}: {e}")
            image = Image.new('RGB', (Config.image_size, Config.image_size), color='black')
        
        labels = []
        for disease in self.disease_classes:
            labels.append(self.data_frame.iloc[idx][disease])
        
        if self.transform:
            image = self.transform(image)
            
        return image, torch.FloatTensor(labels)
    
    def get_class_distribution(self):
        distribution = {}
        for disease in self.disease_classes:
            count = self.data_frame[disease].sum()
            percentage = (count / len(self.data_frame)) * 100
            distribution[disease] = {'count': count, 'percentage': percentage}
        return distribution

def load_split_file(file_path):
    with open(file_path, 'r') as f:
        image_names = [line.strip() for line in f.readlines()]
    return image_names

def create_datasets_from_splits(csv_file, base_dir, test_list_file, train_val_list_file, transform=None):
    test_images = load_split_file(test_list_file)
    train_val_images = load_split_file(train_val_list_file)
    
    print(f"Test images: {len(test_images)}")
    print(f"Train+Val images: {len(train_val_images)}")
    
    train_images, val_images = train_test_split(
        train_val_images, test_size=0.2, random_state=42
    )
    
    print(f"Train images: {len(train_images)}")
    print(f"Validation images: {len(val_images)}")
    
    train_dataset = ChestXrayDataset(
        csv_file=csv_file,
        base_dir=base_dir,
        image_list=train_images,
        transform=transform
    )
    
    val_dataset = ChestXrayDataset(
        csv_file=csv_file,
        base_dir=base_dir,
        image_list=val_images,
        transform=transform
    )
    
    test_dataset = ChestXrayDataset(
        csv_file=csv_file,
        base_dir=base_dir,
        image_list=test_images,
        transform=transform
    )
    
    return train_dataset, val_dataset, test_dataset

# Enhanced Data Augmentation for EfficientNet
def get_advanced_transforms():
    if Config.use_advanced_augmentation:
        train_transform = transforms.Compose([
            transforms.Resize((Config.image_size, Config.image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((Config.image_size, Config.image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
    
    val_transform = transforms.Compose([
        transforms.Resize((Config.image_size, Config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    return train_transform, val_transform

# Simplified EfficientNet Model (remove progressive training for now)
class EfficientNetModel(nn.Module):
    def __init__(self, num_classes=14, model_name="tf_efficientnet_b3"):
        super(EfficientNetModel, self).__init__()
        
        self.backbone = timm.create_model(
            model_name, 
            pretrained=Config.use_pretrained,
            num_classes=0
        )
        
        if "b3" in model_name:
            self.feature_dim = 1536
        else:
            self.feature_dim = self.backbone.num_features
        
        self.classifier = nn.Linear(self.feature_dim, num_classes)
        self.sigmoid = nn.Sigmoid()
        
        print(f"Initialized {model_name} with {self.feature_dim} feature dimensions")
    
    def forward(self, x):
        features = self.backbone(x)
        output = self.classifier(features)
        output = self.sigmoid(output)
        return output

# Enhanced Loss Function
class EnhancedWeightedLoss(nn.Module):
    def __init__(self, class_weights=None):
        super(EnhancedWeightedLoss, self).__init__()
        self.class_weights = class_weights
    
    def forward(self, outputs, targets):
        batch_size, num_classes = targets.shape
        
        # Simplified loss calculation
        loss = torch.zeros(1, device=outputs.device)
        
        for i in range(num_classes):
            class_output = outputs[:, i]
            class_target = targets[:, i]
            
            # Binary cross entropy for each class
            class_loss = - (class_target * torch.log(class_output + 1e-8) + 
                          (1 - class_target) * torch.log(1 - class_output + 1e-8))
            
            # Apply class weights if provided
            if self.class_weights is not None:
                class_loss = class_loss * self.class_weights[i]
            
            loss += class_loss.mean()
        
        return loss / num_classes

# Checkpoint functions (simplified)
def save_checkpoint(epoch, model, optimizer, train_losses, val_losses, val_auc_scores, best_auc, is_best=False):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_auc_scores': val_auc_scores,
        'best_auc': best_auc,
    }
    
    filename = f"checkpoint_epoch_{epoch:03d}.pth"
    checkpoint_path = os.path.join(Config.checkpoint_dir, filename)
    torch.save(checkpoint, checkpoint_path)
    
    if is_best:
        best_path = os.path.join(Config.checkpoint_dir, "best_model.pth")
        torch.save(checkpoint, best_path)
    
    print(f"Checkpoint saved: {checkpoint_path}")

def find_latest_checkpoint():
    checkpoint_files = [f for f in os.listdir(Config.checkpoint_dir) if f.startswith('checkpoint_epoch_')]
    if not checkpoint_files:
        return None
    epochs = [int(f.split('_')[2].split('.')[0]) for f in checkpoint_files]
    latest_epoch = max(epochs)
    return os.path.join(Config.checkpoint_dir, f"checkpoint_epoch_{latest_epoch:03d}.pth")

# Training function with progress bars
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, start_epoch=0):
    train_losses = []
    val_losses = []
    val_auc_scores = []
    best_auc = 0.0
    
    for epoch in range(start_epoch, num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)
        
        # Training phase
        model.train()
        running_loss = 0.0
        
        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}')
        for batch_idx, (inputs, labels) in enumerate(train_pbar):
            inputs = inputs.to(Config.device)
            labels = labels.to(Config.device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
        
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []
        
        val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch+1}')
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs = inputs.to(Config.device)
                labels = labels.to(Config.device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                
                all_labels.append(labels.cpu().numpy())
                all_preds.append(outputs.cpu().numpy())
        
        val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(val_loss)
        
        # Calculate AUC
        all_labels = np.concatenate(all_labels)
        all_preds = np.concatenate(all_preds)
        
        auc_scores = []
        for i in range(Config.num_classes):
            try:
                if np.sum(all_labels[:, i]) > 0:
                    auc = roc_auc_score(all_labels[:, i], all_preds[:, i])
                    auc_scores.append(auc)
                else:
                    auc_scores.append(0.0)
            except:
                auc_scores.append(0.0)
        
        mean_auc = np.mean(auc_scores)
        val_auc_scores.append(mean_auc)
        
        print(f'Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')
        print(f'Validation AUC: {mean_auc:.4f}')
        
        # Save checkpoint
        if mean_auc > best_auc:
            best_auc = mean_auc
            save_checkpoint(epoch+1, model, optimizer, train_losses, val_losses, val_auc_scores, best_auc, is_best=True)
        elif (epoch + 1) % Config.checkpoint_interval == 0:
            save_checkpoint(epoch+1, model, optimizer, train_losses, val_losses, val_auc_scores, best_auc)
    
    return model, train_losses, val_losses, val_auc_scores

# Simplified Main Function
def main_efficientnet():
    print("ðŸš€ Setting up EFFICIENTNET model...")
    print(f"Using device: {Config.device}")
    print(f"Model: {Config.model_name}")
    
    # Create transforms
    train_transform, val_transform = get_advanced_transforms()
    
    # Create datasets
    test_list_path = os.path.join(Config.data_dir, Config.test_list_file)
    train_val_list_path = os.path.join(Config.data_dir, Config.train_val_list_file)
    csv_path = os.path.join(Config.data_dir, Config.csv_path)
    
    print("Loading datasets...")
    train_dataset, val_dataset, test_dataset = create_datasets_from_splits(
        csv_file=csv_path,
        base_dir=Config.data_dir,
        test_list_file=test_list_path,
        train_val_list_file=train_val_list_path,
        transform=train_transform
    )
    
    val_dataset.transform = val_transform
    test_dataset.transform = val_transform
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    
    # Calculate class weights (optimized)
    class_weights = None
    if Config.use_class_weights:
        class_weights = calculate_class_weights(train_dataset)
    
    # Create data loaders
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset, 
        batch_size=Config.batch_size, 
        shuffle=True, 
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=Config.batch_size, 
        shuffle=False, 
        num_workers=4,
        pin_memory=True
    )
    
    # Initialize model
    print(f"Initializing {Config.model_name}...")
    model = EfficientNetModel(
        num_classes=Config.num_classes, 
        model_name=Config.model_name
    )
    model = model.to(Config.device)
    
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=0.01)
    criterion = EnhancedWeightedLoss(class_weights=class_weights)
    
    # Count parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized with {trainable_params:,} trainable parameters")
    
    # Start training
    print("\nStarting training...")
    model, train_losses, val_losses, val_auc_scores = train_model(
        model, train_loader, val_loader, criterion, optimizer, 
        Config.num_epochs
    )
    
    print("\nTraining completed!")
    print(f"Best validation AUC: {max(val_auc_scores):.4f}")

if __name__ == "__main__":
    main_efficientnet()

ðŸš€ Setting up EFFICIENTNET model with modern enhancements...
Using device: cuda
Model: tf_efficientnet_b3
Available GPU VRAM: 23.6 GB
Using batch size 12 for EfficientNet
Test images: 25596
Train+Val images: 86524
Train images: 69219
Validation images: 17305
Dataset initialized with 69219 images
Dataset initialized with 17305 images
Dataset initialized with 25596 images
Train dataset size: 69219
Validation dataset size: 17305
Test dataset size: 25596


KeyboardInterrupt: 