In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    A single linear layer produces queries, keys and values for all heads at
    once. The per-head results are concatenated back to ``embed_dim`` features
    and passed through an output projection. No masking, dropout or residual
    connection is applied here.

    Args:
        embed_dim: Feature size of each input/output token.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Fail fast with a clear message; otherwise the mismatch only surfaces
        # later as an opaque reshape error (or ZeroDivisionError) in forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Fused projection: output features [0:C) are Q, [C:2C) K, [2C:3C) V.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == embed_dim``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        # Scale by 1/sqrt(head_dim) to keep logits in a range where softmax
        # does not saturate; attention weights have shape (B, heads, N, N).
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)

        # (B, heads, N, head_dim) -> (B, N, heads * head_dim) == (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class CNNTransformerMelanomaClassifier(nn.Module):
    """ResNet-50 features, one self-attention layer, and a linear classifier.

    The ResNet-50 backbone, without its average-pool and fully connected layers,
    maps an image to a spatial feature map. Each spatial position becomes a
    token. Tokens are mixed by multi-head self-attention, averaged, layer-
    normalised and projected to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        num_heads: Attention heads; must divide the backbone channel count.
        pretrained: If True, load ImageNet weights for the backbone
            (downloaded on first use).
    """

    def __init__(self, num_classes=2, num_heads=8, pretrained=True):
        super().__init__()
        backbone = self._build_resnet50(pretrained)

        # Take the channel count from the classifier head we are about to drop
        # instead of probing with a dummy forward pass. The module is in
        # training mode at this point, and a forward pass there updates
        # BatchNorm running statistics even under torch.no_grad(). That would
        # corrupt the pretrained statistics.
        self.num_features = backbone.fc.in_features

        # Drop avgpool and fc so the backbone returns the spatial feature map.
        self.cnn = nn.Sequential(*list(backbone.children())[:-2])

        self.self_attention = MultiHeadSelfAttention(
            embed_dim=self.num_features, num_heads=num_heads
        )

        # Final classification layers.
        self.norm = nn.LayerNorm(self.num_features)
        self.fc = nn.Linear(self.num_features, num_classes)

    @staticmethod
    def _build_resnet50(pretrained):
        """Create a torchvision ResNet-50, optionally with ImageNet weights."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # Older torchvision only supports the boolean ``pretrained`` flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` maps
        # to, so results match. Using ``weights=`` avoids the deprecation
        # warning.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights)

    def forward(self, x):
        """Return logits of shape ``(B, num_classes)`` for a batch of images ``x``.

        ``x`` is expected to be ``(B, 3, H, W)``, matching the 3-channel input
        of the ResNet backbone.
        """
        # CNN feature extraction: (B, C, H', W').
        x = self.cnn(x)

        # One token per spatial position: (B, H'*W', C).
        x = x.flatten(2).transpose(1, 2)

        x = self.self_attention(x)

        # Global average pooling over tokens: (B, C).
        x = x.mean(dim=1)

        x = self.norm(x)
        return self.fc(x)

# Usage: smoke-test the model on a random image-sized tensor. The guard keeps
# a plain import of this file from downloading weights and running a forward
# pass.
if __name__ == "__main__":
    model = CNNTransformerMelanomaClassifier()
    # eval() makes BatchNorm use its stored running statistics rather than
    # updating them from this random input. Call model.train() before any
    # training. no_grad() skips building an autograd graph that is never used.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Output shape: {output.shape}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/robpickerill/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100.0%


Output shape: torch.Size([1, 2])
