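This notebook defines the model class for a single model:
1. DenseNet-121

## DenseNet-121

`ECGClassifier` below puts a small fully connected classification head on top of a pretrained DenseNet-121 feature extractor.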

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from densenet_features import densenet121_features

class ECGClassifier(nn.Module):
    """DenseNet-121 feature extractor followed by a 3-layer fully connected head.

    Args:
        num_classes: Number of logits produced by the final layer.
        pretrained: Forwarded to ``densenet121_features``. ``True`` (the default,
            and the original behavior) requests the pretrained weights, which the
            example below shows being downloaded on first use. ``False``
            presumably builds the backbone without them — confirm against
            ``densenet_features``.
        in_features: Number of channels the feature extractor outputs, i.e. the
            input width of ``fc1``. 1024 matches DenseNet-121; change it only when
            the backbone produces a different channel count.
        dropout: Drop probability used by both dropout layers in the head.

    ``forward`` returns raw, un-normalized logits of shape ``(batch, num_classes)``
    (there is no activation after ``fc3``). Pair it with ``nn.CrossEntropyLoss``
    or apply softmax downstream. The example below feeds ``(batch, 3, 224, 224)``
    image batches — other input sizes are assumed to work but are untested here.
    """

    def __init__(self, num_classes=11, pretrained=True, in_features=1024, dropout=0.5):
        super().__init__()
        # Backbone. Attribute names below are kept stable so that previously
        # saved state_dicts (keys like "densenet121.*", "fc1.*") still load.
        self.densenet121 = densenet121_features(pretrained=pretrained)

        # Classification head: global average pool -> in_features -> 128 -> 64 -> num_classes.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_features, 128)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(p=dropout)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map an input batch to class logits of shape ``(batch, num_classes)``."""
        # Backbone; assumed to return a 4-D map (batch, in_features, h, w) — TODO confirm.
        # NOTE(review): torchvision's DenseNet applies F.relu to the feature map
        # before pooling; check whether densenet121_features already includes it.
        x = self.densenet121(x)
        # Global average pooling: (batch, C, h, w) -> (batch, C, 1, 1).
        x = self.avgpool(x)
        # Drop the singleton spatial dims: (batch, C, 1, 1) -> (batch, C).
        x = torch.flatten(x, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        # Raw logits; no softmax here.
        return self.fc3(x)

# Example usage: smoke-test the classifier on a random batch.
# NOTE: constructing ECGClassifier() fetches the pretrained backbone weights on first run.
model = ECGClassifier(num_classes=11)
model.eval()  # disable dropout so the check is deterministic; call model.train() before training
inp = torch.randn(32, 3, 224, 224)  # random input batch: 32 images, 3 channels, 224x224
with torch.no_grad():  # inference-only check: don't build an autograd graph for the whole batch
    output = model(inp)
assert output.shape == (32, 11), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)  # torch.Size([32, 11])

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to ./pretrained_models/densenet121-a639ec97.pth
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30.8M/30.8M [00:03<00:00, 9.19MB/s]


torch.Size([32, 11])
