In [None]:
# DenseNet Model for Diabetic Retinopathy Detection
# Google Colab Notebook - Run each cell individually

# ========== CELL 1: Install Dependencies ==========
# NOTE: the `!` lines below are IPython/Colab shell escapes, not Python syntax,
# so this cell only runs inside a notebook kernel.
!pip install opencv-python-headless
!pip install scikit-learn
!pip install seaborn
# NOTE(review): plotly is installed here but is not imported anywhere in the
# code visible in this file - confirm it is still needed.
!pip install plotly

import warnings
# Suppress every warning category for the rest of the session. This also hides
# deprecation warnings, so remove it when debugging.
warnings.filterwarnings('ignore')
# Printed unconditionally: the message does not check that pip succeeded.
print("✅ Dependencies installed successfully!")

# ========== CELL 2: Import Libraries ==========
import os
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import cv2
from google.colab import files, drive
import seaborn as sns
from collections import Counter
import torch.nn.functional as F

# Report the environment so a run log shows which PyTorch build was used and
# whether a CUDA device is visible to it.
print("✅ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# ========== CELL 3: Mount Google Drive (Optional) ==========
# Uncomment if you want to save/load from Google Drive
# drive.mount('/content/drive')
# The message below is informational only: nothing is mounted unless the
# line above is uncommented.
print("💾 Google Drive mounting available if needed")

# ========== CELL 4: Upload Dataset ==========
# Ask for a dataset archive through the Colab upload widget and unpack it
# into /content/diabetic_retinopathy_data.
print("📁 Please upload your diabetic retinopathy dataset zip file:")
uploaded = files.upload()

# The upload result is a mapping keyed by file name; it is empty when the
# dialog is dismissed, so stop with a clear message instead of an IndexError.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded. Re-run this cell and select the dataset zip file."
    )

# Only the first uploaded file is used as the dataset archive.
zip_filename = next(iter(uploaded))

# Reject non-archives up front so the failure names the offending file.
if not zipfile.is_zipfile(zip_filename):
    raise ValueError(f"'{zip_filename}' is not a valid zip archive.")

print(f"📦 Extracting {zip_filename}...")

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall('/content/diabetic_retinopathy_data')

print("✅ Dataset extracted to: /content/diabetic_retinopathy_data")

# ========== CELL 5: Dataset Preprocessing Class ==========
class RetinalImagePreprocessor:
    """Load a fundus photograph and prepare it for the network.

    The pipeline is: read from disk, convert BGR to RGB, crop away the dark
    border, enhance contrast with CLAHE on the lightness channel, and resize
    to a ``target_size`` x ``target_size`` square.
    """

    def __init__(self, target_size=224):  # DenseNet uses 224x224
        self.target_size = target_size

    def crop_image_from_gray(self, img, tol=7):
        """Crop image to remove black borders.

        Rows and columns are kept when at least one of their pixels is
        brighter than ``tol`` (for 3-channel input the brightness is the
        RGB-to-gray conversion).

        Args:
            img: 2-D grayscale array or 3-D RGB array.
            tol: brightness threshold; pixels at or below it count as border.

        Returns:
            The cropped array. The input is returned unchanged when no pixel
            exceeds ``tol`` or when ``img`` is neither 2-D nor 3-D.
        """
        if img.ndim == 2:
            mask = img > tol
        elif img.ndim == 3:
            mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > tol
        else:
            return img

        # An entirely dark image would crop to an empty array, which cannot be
        # resized afterwards, so hand the input back untouched in that case.
        if not mask.any():
            return img

        # Indexing the first two axes keeps any trailing channel axis intact,
        # so one expression serves both grayscale and RGB input.
        return img[np.ix_(mask.any(axis=1), mask.any(axis=0))]

    def apply_clahe(self, image):
        """Apply CLAHE for contrast enhancement.

        The RGB image is converted to LAB, CLAHE (clip limit 3.0, 8x8 tiles)
        is applied to the first (lightness) channel only, and the result is
        converted back to RGB.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    def preprocess_retinal_image(self, image_path):
        """Complete preprocessing pipeline.

        Args:
            image_path: path of the image file to load.

        Returns:
            RGB array of shape ``(target_size, target_size, 3)``.

        Raises:
            OSError: if the file is missing or cannot be decoded.
        """
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising;
        # surface it here with the path instead of failing inside cvtColor.
        if image is None:
            raise OSError(f"Could not read image file: {image_path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.crop_image_from_gray(image)
        image = self.apply_clahe(image)
        image = cv2.resize(image, (self.target_size, self.target_size))
        return image

print("✅ Preprocessing class defined!")

# ========== CELL 6: Dataset Organization ==========
def organize_dataset(base_path):
    """Collect labelled image paths from a directory tree.

    An image (``.png``, ``.jpg`` or ``.jpeg``, case-insensitive) is labelled
    by the first class name found, case-insensitively, in the name of its
    containing folder. Only when the folder name matches no class is the
    file name searched instead. Images matching no class are skipped.

    Directories and files are visited in sorted order so the returned lists,
    and therefore any seeded split made from them, are reproducible across
    runs and filesystems.

    Args:
        base_path: root directory to scan recursively.

    Returns:
        Tuple ``(image_paths, labels, label_names)`` where ``labels[i]`` is
        the index into ``label_names`` for ``image_paths[i]``.
    """
    label_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
    lowered_names = [name.lower() for name in label_names]
    image_extensions = ('.png', '.jpg', '.jpeg')

    def match_label(text):
        """Return the index of the first class name contained in text, else -1."""
        text = text.lower()
        for idx, needle in enumerate(lowered_names):
            if needle in text:
                return idx
        return -1

    image_paths = []
    labels = []

    for root, dirs, filenames in os.walk(base_path):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        folder_label = match_label(os.path.basename(root))

        for filename in sorted(filenames):
            if not filename.lower().endswith(image_extensions):
                continue

            # The folder is the primary label source; the file name is only a
            # fallback, so a stray class word in a file name cannot override
            # the class directory the image was placed in.
            label = folder_label if folder_label != -1 else match_label(filename)
            if label != -1:
                image_paths.append(os.path.join(root, filename))
                labels.append(label)

    return image_paths, labels, label_names

# Organize the dataset
image_paths, labels, label_names = organize_dataset('/content/diabetic_retinopathy_data')

# Fail fast: with zero labelled images every later cell would fail with an
# unrelated error, far from the real cause.
if not image_paths:
    raise RuntimeError(
        "No labelled images found under /content/diabetic_retinopathy_data. "
        f"Folder or file names must contain one of: {label_names}"
    )

print(f"✅ Found {len(image_paths)} images")
print("📊 Class distribution:", Counter(labels))
print("🏷️ Classes:", label_names)

# A class with no images usually means its folder name does not contain the
# expected class name; flag it now because the model is still built with
# len(label_names) outputs.
found_labels = set(labels)
missing_classes = [name for idx, name in enumerate(label_names) if idx not in found_labels]
if missing_classes:
    print(f"⚠️ No images found for classes: {missing_classes}")

# ========== CELL 7: Custom Dataset Class ==========
class DiabeticRetinopathyDataset(Dataset):
    """Map-style dataset pairing image paths with integer class labels.

    Each item is loaded on access. When a ``preprocessor`` is supplied, its
    ``preprocess_retinal_image`` output is wrapped in a PIL image; otherwise
    the file is opened with PIL and converted to RGB. An optional
    ``transform`` is applied last.
    """

    def __init__(self, image_paths, labels, transform=None, preprocessor=None):
        self.image_paths, self.labels = image_paths, labels
        self.transform, self.preprocessor = transform, preprocessor

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        path = self.image_paths[idx]
        target = self.labels[idx]

        if not self.preprocessor:
            picture = Image.open(path).convert('RGB')
        else:
            processed = self.preprocessor.preprocess_retinal_image(path)
            picture = Image.fromarray(processed)

        if self.transform:
            picture = self.transform(picture)

        return picture, target

print("✅ Custom Dataset class defined!")

# ========== CELL 8: Data Transforms ==========
def _tensor_and_normalize():
    """Return the shared pipeline tail: tensor conversion, then normalization."""
    # NOTE(review): these look like the standard ImageNet channel statistics,
    # matching the ImageNet-pretrained backbone - confirm before changing.
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: resize, random geometric and colour augmentation, then
# the shared tensor/normalize tail.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # DenseNet input size
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.2),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    ]
    + _tensor_and_normalize()
)

# Validation/test pipeline: deterministic, no augmentation.
val_transform = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize()
)

print("✅ Data transforms defined!")

# ========== CELL 9: Train-Validation-Test Split ==========
# Stage 1: hold out 20% of all images as the test set, stratified by label.
X_temp, X_test, y_temp, y_test = train_test_split(
    image_paths,
    labels,
    stratify=labels,
    test_size=0.2,
    random_state=42,
)

# Stage 2: take 12.5% of the remaining 80% for validation, which is 10% of
# the whole dataset and leaves 70% for training.
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    stratify=y_temp,
    test_size=0.125,
    random_state=42,
)

# Report each split's size and its share of the full dataset.
total_images = len(image_paths)
print("📊 Dataset Split:")
for split_name, subset in (("Train", X_train), ("Validation", X_val), ("Test", X_test)):
    print(f"   {split_name} samples: {len(subset)} ({len(subset)/total_images*100:.1f}%)")

# ========== CELL 10: Create Data Loaders ==========
# Create preprocessor
preprocessor = RetinalImagePreprocessor(target_size=224)

# Create datasets: only the training set gets the augmenting transform.
train_dataset = DiabeticRetinopathyDataset(X_train, y_train, train_transform, preprocessor)
val_dataset = DiabeticRetinopathyDataset(X_val, y_val, val_transform, preprocessor)
test_dataset = DiabeticRetinopathyDataset(X_test, y_test, val_transform, preprocessor)

# Create data loaders
batch_size = 16  # DenseNet is memory intensive

# The classifier head contains BatchNorm1d layers, which raise in training
# mode when a batch holds a single sample. Drop the trailing batch only in
# the one case where it would contain exactly one image; otherwise every
# sample is still used each epoch.
drop_singleton_batch = len(train_dataset) % batch_size == 1

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                          pin_memory=True, drop_last=drop_singleton_batch)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print("✅ Data loaders created successfully!")
print(f"   Batch size: {batch_size}")
print(f"   Train batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

# ========== CELL 11: Visualize Sample Images ==========
def visualize_samples(data_loader, label_names, num_samples=8):
    """Show the first images of one batch with their class names.

    The images are de-normalized with the same mean/std used by the data
    transforms so they display with natural colours.

    Args:
        data_loader: loader yielding ``(images, labels)`` batches.
        label_names: class names indexed by label id.
        num_samples: maximum number of images to show; capped by the batch
            size. The grid grows in rows of four to fit this many images.
    """
    images, labels = next(iter(data_loader))
    shown = max(0, min(num_samples, len(images)))

    # Keep the original 2x4 layout as the minimum and add rows when more than
    # eight images are requested, instead of indexing past the grid.
    n_cols = 4
    n_rows = max(2, -(-shown // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
    axes = axes.ravel()

    # Undo Normalize once for the whole batch; the (3, 1, 1) shape broadcasts
    # over the batch and spatial axes.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    restored = torch.clamp(images[:shown] * std + mean, 0, 1)

    for i, ax in enumerate(axes):
        if i < shown:
            # Channels-first tensor to channels-last array for imshow.
            ax.imshow(restored[i].permute(1, 2, 0).numpy())
            ax.set_title(f'Class: {label_names[int(labels[i])]}', fontsize=12, fontweight='bold')
        # Hide the frame and ticks on image cells and on unused cells alike.
        ax.axis('off')

    plt.suptitle('Sample Images from Dataset', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Visualize samples
visualize_samples(train_loader, label_names)

# ========== CELL 12: Plot Class Distribution ==========
def plot_class_distribution(data_loader, label_names, title="Class Distribution"):
    """Draw a bar chart of per-class sample counts and print the same figures.

    The labels are gathered by iterating the whole loader once, so every
    batch is loaded in the process.

    Args:
        data_loader: loader yielding ``(images, labels)`` batches.
        label_names: class names indexed by label id.
        title: chart title, also used as the heading of the printed summary.
    """
    # Flatten the label tensors of every batch into one list.
    all_labels = [
        value
        for _, batch_labels in data_loader
        for value in batch_labels.numpy()
    ]

    plt.figure(figsize=(12, 6))
    unique_labels, counts = np.unique(all_labels, return_counts=True)

    palette = plt.cm.Set3(np.linspace(0, 1, len(label_names)))
    tick_names = [label_names[class_id] for class_id in unique_labels]
    bars = plt.bar(tick_names, counts, color=palette)

    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('Diabetic Retinopathy Stage')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    # Write each count just above its bar.
    for rect, n_images in zip(bars, counts):
        x_center = rect.get_x() + rect.get_width()/2
        plt.text(x_center, rect.get_height() + 0.5,
                 str(n_images), ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Text summary with each class's share of all samples seen.
    print(f"\n📊 {title}:")
    n_total = len(all_labels)
    for class_id, n_images in zip(unique_labels, counts):
        share = (n_images / n_total) * 100
        print(f"   {label_names[class_id]}: {n_images} images ({share:.1f}%)")

# Plot class distribution
plot_class_distribution(train_loader, label_names, "Training Set Class Distribution")

# ========== CELL 13: DenseNet Model ==========
class DenseNetDiabeticRetinopathy(nn.Module):
    """DenseNet backbone with channel attention and two classification heads.

    The torchvision DenseNet classifier is replaced by an identity so the
    backbone yields a pooled feature vector. A sigmoid-gated bottleneck
    re-weights those features for the main classifier, while an auxiliary
    classifier works on the un-gated features.

    In training mode ``forward`` returns ``(main_logits, aux_logits)``; in
    eval mode it returns ``main_logits`` only.

    Args:
        num_classes: number of output classes.
        model_size: DenseNet variant, one of ``'121'``, ``'169'``, ``'201'``
            (an int such as ``121`` is accepted too).
        pretrained: load the ``IMAGENET1K_V1`` weights when true.
        dropout_rate: dropout probability in both heads (the last dropout of
            the main head uses half of it).

    Raises:
        ValueError: if ``model_size`` is not a supported variant.
    """

    # model_size -> (torchvision builder name, weights enum name, feature width)
    _VARIANTS = {
        '121': ('densenet121', 'DenseNet121_Weights', 1024),
        '169': ('densenet169', 'DenseNet169_Weights', 1664),
        '201': ('densenet201', 'DenseNet201_Weights', 1920),
    }

    def __init__(self, num_classes=5, model_size='121', pretrained=True, dropout_rate=0.4):
        super().__init__()

        # Validate first: an unknown size used to fall through every branch
        # and fail later with an unrelated unbound-name error.
        model_size = str(model_size)
        if model_size not in self._VARIANTS:
            raise ValueError(
                f"Unsupported model_size {model_size!r}; "
                f"expected one of {sorted(self._VARIANTS)}"
            )
        builder_name, weights_name, feature_dim = self._VARIANTS[model_size]

        # Choose DenseNet variant
        weights = getattr(models, weights_name).IMAGENET1K_V1 if pretrained else None
        self.backbone = getattr(models, builder_name)(weights=weights)

        # Remove the original classifier
        self.backbone.classifier = nn.Identity()

        # Add channel attention mechanism: bottleneck to a quarter of the
        # width, then a sigmoid gate per feature channel.
        self.channel_attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 4),
            nn.ReLU(),
            nn.Linear(feature_dim // 4, feature_dim),
            nn.Sigmoid()
        )

        # Main classifier head. The layer order is part of the checkpoint
        # format (state_dict keys), so keep it stable.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(256, num_classes)
        )

        # Auxiliary classifier for feature learning
        self.aux_classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialize the Linear layers of the newly added heads only."""
        for m in [self.channel_attention, self.classifier, self.aux_classifier]:
            for layer in m.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            ``(main_logits, aux_logits)`` in training mode, ``main_logits``
            in eval mode.
        """
        # Extract features
        features = self.backbone(x)

        # Apply channel attention
        attention_weights = self.channel_attention(features)
        attended_features = features * attention_weights

        # Main classification
        main_output = self.classifier(attended_features)

        # The auxiliary head only contributes to the training loss, so skip
        # its computation entirely at inference time.
        if self.training:
            aux_output = self.aux_classifier(features)
            return main_output, aux_output
        return main_output

print("✅ DenseNet Model defined!")

# ========== CELL 14: Initialize Model and Training Setup ==========
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Using device: {device}")

# Initialize model (you can change to '121', '169', or '201')
model = DenseNetDiabeticRetinopathy(num_classes=len(label_names), model_size='121', pretrained=True, dropout_rate=0.4)
model = model.to(device)

# Mixed precision training setup. Gradient scaling only applies on CUDA, so
# the scaler is created disabled on CPU and acts as a pass-through there.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
# OneCycleLR fixes its total step count here (epochs * steps_per_epoch) and
# raises if stepped beyond it, so `epochs` must match the number of epochs
# used by the training loop.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, epochs=40, steps_per_epoch=len(train_loader))

print("✅ Model initialized and moved to device!")
print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"📊 Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# Only announce mixed precision when it is actually in effect.
if use_amp:
    print("🚀 Mixed precision training enabled!")
else:
    print("ℹ️ CUDA not available - training in full precision (mixed precision disabled)")

# ========== CELL 15: Training Function ==========
def train_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, device, aux_weight=0.3):
    """Train the model for one pass over ``train_loader``.

    Args:
        model: network to train. In training mode it may return either
            ``(main_logits, aux_logits)`` or a single logits tensor.
        train_loader: sized iterable of ``(inputs, targets)`` batches.
        criterion: loss applied to each logits tensor.
        optimizer: optimizer updating ``model``.
        scheduler: learning-rate scheduler stepped once per applied batch.
        scaler: gradient scaler used for the backward pass and step.
        device: device the batches are moved to.
        aux_weight: weight of the auxiliary loss added to the main loss.

    Returns:
        ``(mean_loss_per_batch, accuracy_percent)``; ``(0.0, 0.0)`` when the
        loader has no batches.
    """
    model.train()

    num_batches = len(train_loader)
    if num_batches == 0:
        # Nothing to train on: avoid dividing by zero below.
        return 0.0, 0.0

    # Autocast is only meaningful on CUDA; disabling it elsewhere avoids the
    # "CUDA is not available" warning without changing results.
    use_amp = torch.device(device).type == 'cuda'

    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            model_out = model(data)
            if isinstance(model_out, (tuple, list)):
                outputs, aux_output = model_out
                loss = criterion(outputs, targets) + aux_weight * criterion(aux_output, targets)
            else:
                # Model without an auxiliary head.
                outputs = model_out
                loss = criterion(outputs, targets)

        scale_before = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # The scaler lowers its scale exactly when it found inf/NaN gradients
        # and skipped the optimizer step; do not advance the LR schedule for
        # a step that was never applied.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 20 == 0:
            print(f'   Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')

    accuracy = 100. * correct / total if total else 0.0
    return total_loss / num_batches, accuracy

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` without updating it.

    Args:
        model: network to evaluate; must return a single logits tensor in
            eval mode.
        val_loader: sized iterable of ``(inputs, targets)`` batches.
        criterion: loss applied to the logits.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss_per_batch, accuracy_percent, predictions, targets,
        probabilities)`` where the last three are flat per-sample lists.
        All-zero/empty values are returned when the loader has no batches.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    all_probabilities = []

    num_batches = len(val_loader)
    if num_batches == 0:
        # Nothing to evaluate: avoid dividing by zero below.
        return 0.0, 0.0, all_predictions, all_targets, all_probabilities

    # Autocast is only meaningful on CUDA; see the matching note in training.
    use_amp = torch.device(device).type == 'cuda'

    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(data)
                loss = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            # Logits produced under autocast can be half precision; compute
            # the probabilities in float32 so downstream metrics are not
            # quantized.
            all_probabilities.extend(F.softmax(outputs.float(), dim=1).cpu().numpy())

    accuracy = 100. * correct / total if total else 0.0
    return (total_loss / num_batches, accuracy,
            all_predictions, all_targets, all_probabilities)

print("✅ Training functions defined!")

# ========== CELL 16: Training Loop ==========
num_epochs = 40

# The scheduler is stepped once per training batch. A schedule that exposes a
# fixed `total_steps` (as OneCycleLR does) raises when stepped past it, so
# check the budget now rather than failing part-way through training.
planned_steps = num_epochs * len(train_loader)
scheduler_total_steps = getattr(scheduler, 'total_steps', None)
if scheduler_total_steps is not None and planned_steps > scheduler_total_steps:
    raise ValueError(
        f"num_epochs={num_epochs} needs {planned_steps} scheduler steps but the "
        f"scheduler was built for {scheduler_total_steps}; recreate the scheduler "
        "with a matching number of epochs."
    )

# Per-epoch history used for the plots and the saved training record.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
learning_rates = []

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint, even if validation accuracy is 0%.
best_val_acc = float('-inf')

print("🚀 Starting training with mixed precision...")
print("=" * 60)

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}:')

    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, device)

    # Validation
    val_loss, val_acc, val_preds, val_targets, val_probs = validate_epoch(model, val_loader, criterion, device)

    # Learning rate in effect at the end of this epoch.
    current_lr = optimizer.param_groups[0]['lr']

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)
    learning_rates.append(current_lr)

    print(f'   Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'   Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print(f'   Learning Rate: {current_lr:.6f}')

    # Save best model (selected on validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), '/content/best_densenet_model.pth')
        print(f'   ✅ New best model saved! Validation Accuracy: {val_acc:.2f}%')

    print('-' * 60)

print("🎉 Training completed!")

# ========== CELL 17: Plot Training History ==========
def plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies, learning_rates):
    """Plot a 2x2 summary of a training run.

    Panels: loss curves, accuracy curves, the recorded learning rate per
    epoch, and the validation-minus-training loss gap. All inputs are
    per-epoch sequences of equal length.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    loss_ax, acc_ax = axes[0, 0], axes[0, 1]
    lr_ax, gap_ax = axes[1, 0], axes[1, 1]

    epochs = range(1, len(train_losses) + 1)

    def finish_panel(ax, heading, y_label, with_legend=False):
        """Apply the title, axis labels, optional legend and grid to a panel."""
        ax.set_title(heading, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Top-left: training vs validation loss.
    loss_ax.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    finish_panel(loss_ax, 'Training and Validation Loss', 'Loss', with_legend=True)

    # Top-right: training vs validation accuracy.
    acc_ax.plot(epochs, train_accuracies, 'b-', label='Train Accuracy', linewidth=2)
    acc_ax.plot(epochs, val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    finish_panel(acc_ax, 'Training and Validation Accuracy', 'Accuracy (%)', with_legend=True)

    # Bottom-left: learning rate recorded for each epoch.
    lr_ax.plot(epochs, learning_rates, 'g-', linewidth=2)
    finish_panel(lr_ax, 'OneCycleLR Schedule', 'Learning Rate')

    # Bottom-right: positive values mean validation loss exceeds training loss.
    generalization_gap = np.array(val_losses) - np.array(train_losses)
    gap_ax.plot(epochs, generalization_gap, 'purple', linewidth=2)
    finish_panel(gap_ax, 'Overfitting Indicator (Val Loss - Train Loss)', 'Loss Difference')
    gap_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()

# Plot training history
plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies, learning_rates)

# ========== CELL 18: Test Set Evaluation ==========
# Load best model. map_location places the checkpoint tensors on the device
# currently in use, so a checkpoint written on GPU also loads in a CPU-only
# session.
model.load_state_dict(torch.load('/content/best_densenet_model.pth', map_location=device))
print("✅ Best model loaded for testing")

# Test evaluation: reuses the validation routine on the held-out test loader.
test_loss, test_acc, test_preds, test_targets, test_probs = validate_epoch(model, test_loader, criterion, device)

print("🎯 Test Set Results:")
print("=" * 40)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

# ========== CELL 19: Confusion Matrix ==========
def plot_confusion_matrix(y_true, y_pred, label_names):
    """Plot the absolute and row-normalized confusion matrices side by side.

    Args:
        y_true: true class ids.
        y_pred: predicted class ids.
        label_names: class names indexed by class id; the matrix always has
            one row and column per name, even for classes that do not occur.
    """
    # Pin the label set so the matrix shape always matches the tick labels;
    # otherwise a class missing from the data shifts every later row/column.
    class_ids = list(range(len(label_names)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(12, 5))

    # Absolute confusion matrix
    plt.subplot(1, 2, 1)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=label_names, yticklabels=label_names)
    plt.title('Confusion Matrix (Absolute)', fontweight='bold')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    # Normalized confusion matrix: each row divided by its number of true
    # samples. Rows with no samples stay at zero instead of becoming NaN.
    plt.subplot(1, 2, 2)
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm.astype('float'), row_totals,
                        out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=label_names, yticklabels=label_names)
    plt.title('Confusion Matrix (Normalized)', fontweight='bold')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    plt.tight_layout()
    plt.show()

# Plot confusion matrix
plot_confusion_matrix(test_targets, test_preds, label_names)

# ========== CELL 20: ROC Curves ==========
def plot_roc_curves(y_true, y_probs, label_names):
    """Plot one-vs-rest ROC curves, one per class, with their AUC values.

    Args:
        y_true: true class ids, one per sample.
        y_probs: per-sample class probabilities, one column per class.
        label_names: class names indexed by class id.

    Classes whose one-vs-rest target is constant (no positives or no
    negatives) are skipped because their ROC curve is undefined.
    """
    n_classes = len(label_names)
    y_true = np.asarray(y_true)
    # Convert once instead of on every loop iteration.
    y_probs = np.asarray(y_probs)

    # One-hot encode by comparison: column i is 1 where the true class is i.
    # This yields one column per class for any class count, including two.
    y_true_bin = (y_true[:, None] == np.arange(n_classes)[None, :]).astype(int)

    plt.figure(figsize=(10, 8))
    colors = plt.cm.Set1(np.linspace(0, 1, n_classes))

    for i in range(n_classes):
        if len(np.unique(y_true_bin[:, i])) > 1:
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, color=colors[i], lw=2,
                     label=f'{label_names[i]} (AUC = {roc_auc:.3f})')

    # Diagonal reference line of a chance-level classifier.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class ROC Curves - DenseNet Model', fontweight='bold')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()

# Plot ROC curves
plot_roc_curves(test_targets, test_probs, label_names)

# ========== CELL 21: Classification Report ==========
print("📊 Detailed Classification Report:")
print("=" * 60)

# Pin the label set: with target_names alone the report raises when a class is
# absent from both targets and predictions. zero_division=0 reports 0 for
# classes without predictions or samples rather than emitting warnings.
class_ids = list(range(len(label_names)))
report = classification_report(test_targets, test_preds, labels=class_ids,
                               target_names=label_names, digits=3, zero_division=0)
print(report)

# Per-class metrics
report_dict = classification_report(test_targets, test_preds, labels=class_ids,
                                    target_names=label_names, output_dict=True, zero_division=0)

print("\n📈 Per-Class Detailed Metrics:")
print("-" * 60)
for class_name in label_names:
    if class_name in report_dict:
        precision = report_dict[class_name]['precision']
        recall = report_dict[class_name]['recall']
        f1 = report_dict[class_name]['f1-score']
        support = report_dict[class_name]['support']
        print(f"{class_name:15} | Precision: {precision:.3f} | Recall: {recall:.3f} | "
              f"F1: {f1:.3f} | Support: {support}")

# The report may omit the 'accuracy' entry when an explicit label list is
# given; fall back to computing it directly so this cell never fails on it.
overall_accuracy = report_dict.get('accuracy')
if overall_accuracy is None:
    overall_accuracy = float(np.mean(np.asarray(test_targets) == np.asarray(test_preds)))

print(f"\n🎯 Overall Metrics:")
print(f"   Accuracy: {overall_accuracy:.3f}")
print(f"   Macro Avg F1: {report_dict['macro avg']['f1-score']:.3f}")
print(f"   Weighted Avg F1: {report_dict['weighted avg']['f1-score']:.3f}")

# ========== CELL 22: Save Model ==========
# Save final model
# NOTE(review): the test-evaluation cell loads the best checkpoint into
# `model` before this point, so when cells run in order this file holds the
# best-validation weights rather than the last-epoch weights - confirm that
# is intended.
torch.save(model.state_dict(), '/content/densenet_diabetic_retinopathy.pth')
print("💾 Model saved as: densenet_diabetic_retinopathy.pth")

# Save training history
# Per-epoch curves plus the final test metrics, bundled into one dict.
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accuracies': train_accuracies,
    'val_accuracies': val_accuracies,
    'learning_rates': learning_rates,
    'test_accuracy': test_acc,
    'test_loss': test_loss
}

import pickle
# Pickle is used for convenience; only unpickle this file from a trusted
# source, since loading a pickle can execute arbitrary code.
with open('/content/densenet_training_history.pkl', 'wb') as f:
    pickle.dump(training_history, f)

print("📈 Training history saved!")

# ========== CELL 23: Download Results ==========
# Download model and results
# Re-imported so this cell also works when run on its own in a fresh session.
from google.colab import files

print("⬇️ Downloading model and results...")
# Request a download of the two files written by the save cell.
files.download('/content/densenet_diabetic_retinopathy.pth')
files.download('/content/densenet_training_history.pkl')

print("✅ DenseNet Model Training Complete!")
print(f"🎯 Final Test Accuracy: {test_acc:.2f}%")
# Printed unconditionally: it does not confirm the downloads finished.
print("🎉 All files downloaded successfully!")