In [None]:
import torch
import torch.nn as nn
from torchvision import models
from transformers import SwinForImageClassification

# Number of output classes (assuming 5 DR severity levels).
NUM_CLASSES = 5

# Load DenseNet-121 with ImageNet weights. The `weights=` enum replaces the
# deprecated `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that
# flag used to select, so the loaded parameters are unchanged.
densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

# Keep only the convolutional feature extractor (drops DenseNet's own classifier head).
densenet_backbone = densenet.features

# Load pretrained Swin Transformer
swin_model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

# Replace Swin's classification head with a fresh linear layer sized for the
# pooled DenseNet-121 features. The input width is read from DenseNet's own
# classifier (1024 for DenseNet-121) rather than hard-coded.
swin_model.classifier = nn.Linear(
    in_features=densenet.classifier.in_features,
    out_features=NUM_CLASSES,
)



In [None]:
class HybridModel(nn.Module):
    """DenseNet feature extractor followed by the classifier head stored on ``swin_model``.

    Only ``swin_model.classifier`` takes part in the forward pass; the rest of
    the Swin network is registered as a submodule but is never called here.
    NOTE(review): if the Swin encoder is meant to contribute to the prediction,
    ``forward`` has to invoke it -- confirm the intended architecture.

    Args:
        densenet_backbone: Module producing a convolutional feature map
            (presumably ``densenet121(...).features`` -- verify against caller).
        swin_model: Object exposing a ``classifier`` module whose input width
            matches the channel count of the backbone's feature map.
    """

    def __init__(self, densenet_backbone, swin_model):
        super().__init__()
        self.densenet = densenet_backbone
        # Global average pooling collapses each channel's spatial map to one value.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.swin = swin_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for a batch of inputs ``x``."""
        feature_map = self.densenet(x)
        # torchvision's DenseNet.forward applies ReLU between `features` (which
        # ends in a BatchNorm) and global pooling; reproduce it so the pooled
        # activations match what the pretrained backbone was trained to emit.
        feature_map = torch.relu(feature_map)
        pooled = self.pool(feature_map).flatten(1)  # one vector per sample
        return self.swin.classifier(pooled)

# Assemble the combined network from the DenseNet backbone and the Swin model.
hybrid_model = HybridModel(
    densenet_backbone=densenet_backbone,
    swin_model=swin_model,
)


In [None]:
# AdamW over every parameter registered on the hybrid model, learning rate 1e-4.
optimizer = torch.optim.AdamW(params=hybrid_model.parameters(), lr=1e-4)

# Cross-entropy loss applied to the model's output logits.
criterion = nn.CrossEntropyLoss()
