In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed them as a token sequence.

    A strided ``Conv2d`` (kernel == stride == ``patch_size``) cuts the image into
    patches and maps each one to ``embed_dim`` channels. The patch grid is then
    flattened into a sequence and passed through a ``Linear`` layer. A learnable
    CLS token is prepended and a learnable positional embedding is added.

    Args:
        in_channels: number of channels of the input image.
        patch_size: side length of each square patch (also the conv stride).
        embed_dim: dimensionality of every output token.
        img_size: side length of the square image the positional embedding is
            sized for; it holds ``(img_size // patch_size) ** 2 + 1`` positions.

    Input:  ``(batch, in_channels, H, W)``.
    Output: ``(batch, num_patches + 1, embed_dim)``; position 0 is the CLS token.

    Raises:
        ValueError: (in ``forward``) if the input produces a different number of
            patches than the positional embedding was built for.
    """

    def __init__(self, in_channels, patch_size, embed_dim, img_size):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.img_size = img_size

        self.patcher = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
        self.projection = nn.Linear(embed_dim, embed_dim)

        # The unpadded strided conv drops partial patches (floor division), so
        # this equals the grid size it produces for an img_size x img_size input.
        self.num_patches = (img_size // patch_size) ** 2
        # NOTE(review): both parameters are drawn from N(0, 1); ViT reference
        # implementations typically use a much smaller std (e.g. 0.02), so the
        # positional term may dominate the patch projections at init.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        batch_size = x.shape[0]
        height, width = x.shape[-2], x.shape[-1]

        x = self.patcher(x)  # (B, embed_dim, H // patch_size, W // patch_size)
        grid_h, grid_w = x.shape[2], x.shape[3]
        if grid_h * grid_w != self.num_patches:
            # Fail with an explicit message instead of an opaque broadcasting
            # error at the positional-embedding addition below.
            raise ValueError(
                f"PatchEmbedding was built for img_size={self.img_size}, "
                f"patch_size={self.patch_size} ({self.num_patches} patches), but an "
                f"input of spatial size {height}x{width} produced a {grid_h}x{grid_w} "
                f"grid ({grid_h * grid_w} patches)."
            )
        x = self.flatten(x).permute(0, 2, 1)  # (B, num_patches, embed_dim)
        x = self.projection(x)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # CLS token occupies position 0
        return x + self.positional_embedding

In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no mask, no dropout).

    Args:
        embed_dim: token dimensionality; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head works on
            ``embed_dim // num_heads`` features.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.

    Input and output shape: ``(batch, seq_len, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Without this check a bad config builds fine and only fails later
        # inside .view() in forward() with an opaque shape error.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out_projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()

        # Project, then split the feature axis into heads:
        # (B, T, E) -> (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention weights: (B, heads, T, T), rows sum to 1.
        attention = (Q @ K.transpose(-2, -1)) / self.head_dim**0.5
        attention = torch.softmax(attention, dim=-1)
        # Merge heads back: (B, heads, T, head_dim) -> (B, T, E).
        output = (attention @ V).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        return self.out_projection(output)

In [37]:
class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block.

    Each of the two sub-layers (self-attention, then a GELU MLP) receives a
    LayerNorm-ed copy of its input. Each sub-layer's result is added back onto
    the un-normalized residual stream.

    Args:
        embed_dim: width of the token embeddings (last dimension of the input).
        num_heads: number of heads handed to ``MultiHeadAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()
        hidden_dim = embed_dim * mlp_ratio
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        # Attention sub-layer with residual connection.
        after_attention = x + self.attention(self.norm1(x))
        # MLP sub-layer with residual connection.
        return after_attention + self.mlp(self.norm2(after_attention))

In [38]:
class CNNFeatureExtractor(nn.Module):
    """Four-stage convolutional encoder that returns one 256-d vector per image.

    Every stage is ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2, 2)``.
    Channel widths go ``in_channels -> 32 -> 64 -> 128 -> 256``, so each stage
    halves the spatial size. A global average pool then collapses the remaining
    map to 1x1.

    Args:
        in_channels: number of channels of the input image.

    Input:  ``(batch, in_channels, H, W)``.
    Output: ``(batch, 256)``.
    """

    def __init__(self, in_channels):
        super().__init__()
        channel_plan = [in_channels, 32, 64, 128, 256]
        stages = []
        for src, dst in zip(channel_plan[:-1], channel_plan[1:]):
            stages.extend([
                nn.Conv2d(src, dst, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        self.model = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        pooled = self.global_pool(self.model(x))  # (B, 256, 1, 1)
        return pooled.view(pooled.size(0), -1)

In [39]:
class HybridCNNViT(nn.Module):
    """Image classifier that fuses a CNN embedding with a ViT CLS embedding.

    The same input image goes through two branches:

    * ``CNNFeatureExtractor`` produces a 256-d vector;
    * ``PatchEmbedding`` followed by ``num_layers`` ``TransformerEncoder`` blocks,
      from which the token at position 0 (the CLS token) is taken.

    The two vectors are concatenated and classified by an MLP head
    (``Linear -> ReLU -> Dropout(0.5) -> Linear``).

    Args:
        cnn_channels: number of channels of the input image. Both branches see
            the same tensor, so this configures both of them.
        vit_embed_dim: token width of the ViT branch.
        num_classes: number of output logits.
        img_size: side length of the square input image (sizes the ViT
            positional embedding).
        patch_size: ViT patch side length.
        num_heads: attention heads per encoder block.
        num_layers: number of encoder blocks.

    Input:  ``(batch, cnn_channels, img_size, img_size)``.
    Output: ``(batch, num_classes)`` logits.
    """

    def __init__(self, cnn_channels, vit_embed_dim, num_classes, img_size, patch_size, num_heads, num_layers):
        super().__init__()
        self.cnn = CNNFeatureExtractor(cnn_channels)
        # Both branches receive the same image, so the ViT stem must accept the
        # same channel count as the CNN (previously hard-coded to 3, which broke
        # any cnn_channels != 3).
        self.vit_embedding = PatchEmbedding(
            in_channels=cnn_channels, patch_size=patch_size, embed_dim=vit_embed_dim, img_size=img_size
        )
        self.transformer = nn.Sequential(
            *[TransformerEncoder(vit_embed_dim, num_heads) for _ in range(num_layers)]
        )
        # 256 = width of the vector returned by CNNFeatureExtractor.
        self.fc = nn.Sequential(
            nn.Linear(256 + vit_embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        cnn_features = self.cnn(x)                      # (B, 256)
        vit_embeddings = self.vit_embedding(x)          # (B, num_patches + 1, D)
        vit_output = self.transformer(vit_embeddings)
        # NOTE(review): the encoder blocks are pre-norm and no final LayerNorm is
        # applied here, so the CLS features are un-normalized when fused with the
        # CNN features; consider adding one (changes the state_dict).
        vit_features = vit_output[:, 0]                 # CLS token, (B, D)
        combined_features = torch.cat([cnn_features, vit_features], dim=1)
        return self.fc(combined_features)

In [40]:

# Instantiate the Hybrid Model
from torchinfo import summary

torch.manual_seed(0)  # reproducible weights and dummy batch

img_size = 224
patch_size = 16
vit_embed_dim = 512
cnn_channels = 3
num_classes = 2
num_heads = 8
num_layers = 4
batch_size = 8

model = HybridCNNViT(
    cnn_channels=cnn_channels,
    vit_embed_dim=vit_embed_dim,
    num_classes=num_classes,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_layers=num_layers
)

# Test the Model. The dummy input shape is derived from the config above so it
# stays in sync if img_size / cnn_channels change. no_grad() skips building an
# autograd graph for this shape-only smoke test.
x = torch.randn(batch_size, cnn_channels, img_size, img_size)
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")
assert output.shape == (batch_size, num_classes)

summary(model=model,
        input_size=(batch_size, cnn_channels, img_size, img_size),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Output shape: torch.Size([8, 2])


Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
HybridCNNViT (HybridCNNViT)                        [8, 3, 224, 224]     [8, 2]               --                   True
├─CNNFeatureExtractor (cnn)                        [8, 3, 224, 224]     [8, 256]             --                   True
│    └─Sequential (model)                          [8, 3, 224, 224]     [8, 256, 14, 14]     --                   True
│    │    └─Conv2d (0)                             [8, 3, 224, 224]     [8, 32, 224, 224]    896                  True
│    │    └─ReLU (1)                               [8, 32, 224, 224]    [8, 32, 224, 224]    --                   --
│    │    └─MaxPool2d (2)                          [8, 32, 224, 224]    [8, 32, 112, 112]    --                   --
│    │    └─Conv2d (3)                             [8, 32, 112, 112]    [8, 64, 112, 112]    18,496               True
│    │    └─ReLU (4)                           

In [41]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.optim import Adam
# from torch.optim.lr_scheduler import CosineAnnealingLR

# class FocalLoss(nn.Module):
#     def __init__(self, gamma=2., alpha=0.25):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.alpha = alpha

#     def forward(self, inputs, targets):
#         BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-BCE_loss)
#         F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
#         return torch.mean(F_loss)

In [42]:
# class PatchEmbedding(nn.Module):
#     def __init__(self, in_channels, patch_size, embed_dim, img_size):
#         super().__init__()
#         self.in_channels = in_channels
#         self.patch_size = patch_size
#         self.embed_dim = embed_dim
#         self.img_size = img_size
        

#         self.patcher = nn.Conv2d(
#             in_channels=in_channels,
#             out_channels=embed_dim,
#             kernel_size=patch_size,
#             stride=patch_size
#         )
#         self.flatten = nn.Flatten(start_dim=2, end_dim=3)
#         self.projection = nn.Linear(embed_dim, embed_dim)
        
#         num_patches = (img_size // patch_size) ** 2
#         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
#         self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

#     def forward(self, x):
#         batch_size = x.shape[0]
        
#         x = self.patcher(x)
#         x = self.flatten(x).permute(0, 2, 1)
#         x = self.projection(x)
        
#         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
#         x = torch.cat([cls_tokens, x], dim=1)
#         x = x + self.positional_embedding
#         return x


In [43]:

# class MultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.num_heads = num_heads
#         self.head_dim = embed_dim // num_heads
        
#         self.query = nn.Linear(embed_dim, embed_dim)
#         self.key = nn.Linear(embed_dim, embed_dim)
#         self.value = nn.Linear(embed_dim, embed_dim)
#         self.out_projection = nn.Linear(embed_dim, embed_dim)

#     def forward(self, x):
#         batch_size, seq_len, embed_dim = x.size()
        
#         Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
#         K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
#         V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
#         attention = (Q @ K.transpose(-2, -1)) / self.head_dim**0.5
#         attention = torch.softmax(attention, dim=-1)
#         output = (attention @ V).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
#         return self.out_projection(output)


In [44]:
# class TransformerEncoder(nn.Module):
#     def __init__(self, embed_dim, num_heads, mlp_ratio=4):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attention = MultiHeadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
#         self.mlp = nn.Sequential(
#             nn.Linear(embed_dim, embed_dim * mlp_ratio),
#             nn.GELU(),
#             nn.Linear(embed_dim * mlp_ratio, embed_dim)
#         )

#     def forward(self, x):
#         x = x + self.attention(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x

In [45]:
# class CNNFeatureExtractor(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2),
#             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2),
#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2),
#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2)
#         )
#         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

#     def forward(self, x):
#         x = self.model(x)
#         x = self.global_pool(x)
#         return x.view(x.size(0), -1)

In [46]:
# class HybridCNNViT(nn.Module):
#     def __init__(self, cnn_channels, vit_embed_dim, num_classes, img_size, patch_size, num_heads, num_layers):
#         super().__init__()
#         self.cnn = CNNFeatureExtractor(cnn_channels)
#         self.vit_embedding = PatchEmbedding(
#             in_channels=3, patch_size=patch_size, embed_dim=vit_embed_dim, img_size=img_size
#         )
#         self.transformer = nn.Sequential(
#             *[TransformerEncoder(vit_embed_dim, num_heads) for _ in range(num_layers)]
#         )
#         self.fc = nn.Sequential(
#             nn.Linear(256 + vit_embed_dim, 512),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(512, num_classes)
#         )

#     def forward(self, x):
#         cnn_features = self.cnn(x)
#         vit_embeddings = self.vit_embedding(x)
#         vit_output = self.transformer(vit_embeddings)
#         vit_features = vit_output[:, 0]
#         combined_features = torch.cat([cnn_features, vit_features], dim=1)
#         return self.fc(combined_features)

In [47]:

# img_size = 224
# patch_size = 16
# vit_embed_dim = 512
# cnn_channels = 3
# num_classes = 2
# num_heads = 8
# num_layers = 4

# hybrid_model = HybridCNNViT(
#     cnn_channels=cnn_channels, 
#     vit_embed_dim=vit_embed_dim, 
#     num_classes=num_classes, 
#     img_size=img_size, 
#     patch_size=patch_size, 
#     num_heads=num_heads, 
#     num_layers=num_layers
# )

# # Define optimizer and scheduler
# optimizer = Adam(hybrid_model.parameters(), lr=1e-4)
# scheduler = CosineAnnealingLR(optimizer, T_max=10)

# # For gradient clipping during training
# # In your training loop, you would call `clip_grad_norm_`
# # Example:
# # torch.nn.utils.clip_grad_norm_(hybrid_model.parameters(), max_norm=1.0)

# # Test the Model
# x = torch.randn(8, 3, 224, 224)
# output = hybrid_model(x)
# print(f"Output shape: {output.shape}")

# from torchinfo import summary

# summary(model=hybrid_model, 
#         input_size=(8, 3, 224, 224), 
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])

Output shape: torch.Size([8, 2])


Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
HybridCNNViT (HybridCNNViT)                        [8, 3, 224, 224]     [8, 2]               --                   True
├─CNNFeatureExtractor (cnn)                        [8, 3, 224, 224]     [8, 256]             --                   True
│    └─Sequential (model)                          [8, 3, 224, 224]     [8, 256, 14, 14]     --                   True
│    │    └─Conv2d (0)                             [8, 3, 224, 224]     [8, 32, 224, 224]    896                  True
│    │    └─ReLU (1)                               [8, 32, 224, 224]    [8, 32, 224, 224]    --                   --
│    │    └─MaxPool2d (2)                          [8, 32, 224, 224]    [8, 32, 112, 112]    --                   --
│    │    └─Conv2d (3)                             [8, 32, 112, 112]    [8, 64, 112, 112]    18,496               True
│    │    └─ReLU (4)                           