In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim

# Pick the compute device once: CUDA when a usable GPU is present, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then a feed-forward MLP.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_heads: Number of heads handed to ``nn.MultiheadAttention``.
        ff_dim: Hidden width of the two-layer feed-forward network.
        dropout: Dropout probability used inside the attention module, inside
            the feed-forward network, and on both residual branches.
    """

    def __init__(self, embed_dim=1280, num_heads=4, ff_dim=256, dropout=0.1):
        super().__init__()
        # batch_first=True: the attention module takes (batch, seq, embed_dim) input.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Position-wise MLP: embed_dim -> ff_dim -> embed_dim.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )

        # Shared dropout applied to each sub-layer output before the residual add.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention and the feed-forward network with residuals.

        Args:
            x: Batch-first sequence tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention. The attention weights are never used, so ask the module
        # not to build/average them (need_weights=False); the attended output is
        # the same, and PyTorch can take its optimized attention path.
        attn_output, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward sub-layer with its own residual connection and norm.
        ff_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

class EffTransNet(nn.Module):
    """EfficientNet-B0 feature extractor followed by Transformer encoder blocks.

    The backbone's spatial feature map is flattened into a sequence of
    1280-dimensional tokens, passed through ``transformer_depth`` stacked
    ``TransformerEncoderBlock`` modules, averaged over the token axis, and
    projected to ``num_classes`` logits.

    Args:
        num_classes: Output size of the final linear classifier.
        transformer_depth: Number of ``TransformerEncoderBlock`` modules to stack.
    """

    def __init__(self, num_classes=5, transformer_depth=2):
        super().__init__()

        # Load ImageNet-pretrained EfficientNet-B0. Newer torchvision deprecates
        # `pretrained=True` in favour of `weights=`; IMAGENET1K_V1 is the
        # checkpoint the old flag loads. Fall back to the old flag on releases
        # that predate the weights enum.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            backbone = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.efficientnet_b0(pretrained=True)
        # Drop the last two children (torchvision's EfficientNet ends with the
        # global average pool and the classifier), keeping the conv feature stack.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.embed_dim = 1280
        self.transformer_layers = nn.Sequential(
            *[TransformerEncoderBlock(embed_dim=self.embed_dim) for _ in range(transformer_depth)]
        )

        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(self.embed_dim, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Image batch accepted by the EfficientNet-B0 backbone.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        x = self.backbone(x)              # [B, 1280, 7, 7] for 224x224 inputs
        # flatten/transpose instead of view/permute: also works when the feature
        # map is not contiguous (e.g. channels_last), where .view() would raise.
        x = x.flatten(2).transpose(1, 2)  # [B, 49, 1280]
        x = self.transformer_layers(x)    # [B, 49, 1280]
        # Average over the token axis: pool expects [B, C, L].
        x = self.pool(x.transpose(1, 2)).squeeze(-1)  # [B, 1280]
        x = self.classifier(x)            # [B, num_classes]
        return x

# Build the network, then move its parameters onto the selected device.
model = EffTransNet(num_classes=5)
model = model.to(device)

# Training objective and optimiser.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Optional: dump the layer hierarchy as a quick sanity check.
print(model)


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 145MB/s]


EffTransNet(
  (backbone): Sequential(
    (0): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activatio

# Load your own dataset (replace `your_dataset_path` with its location)

In [None]:
# Per-channel statistics expected by torchvision's ImageNet-pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # The pretrained EfficientNet-B0 backbone was trained on ImageNet-normalised
    # inputs; feeding raw [0, 1] tensors shifts the input distribution it expects.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# ImageFolder reads images from the given root and applies `transform` to each one.
train_dataset = datasets.ImageFolder("your_dataset_path", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Train the model and print the loss and accuracy for each epoch

In [None]:
NUM_EPOCHS = 10  # number of full passes over train_loader

for epoch in range(NUM_EPOCHS):
    model.train()
    # total_loss sums per-sample losses (batch mean * batch size) so the epoch
    # figure is a true mean that does not scale with the number of batches.
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        # criterion averages over the batch, so scale back up before summing.
        total_loss += loss.item() * batch_count
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += batch_count

    # Divide by the samples actually processed; max(..., 1) avoids a
    # ZeroDivisionError if the loader yielded nothing.
    avg_loss = total_loss / max(seen, 1)
    acc = correct / max(seen, 1)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {acc*100:.2f}%")
