In [None]:
from google.colab import drive

In [None]:
# Mount the user's Google Drive at /content/drive/ (Colab-only helper imported above).
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import torch
import torch.nn as nn
from torch import einsum


class MobileViT(nn.Module):
    """Patch-based Transformer image classifier.

    A strided convolution cuts the 3-channel input image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and embeds each one as a
    ``d_model``-wide token. A learnable class token is prepended, learnable
    position embeddings are added, and the token sequence is passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks. The encoded class token is
    projected to ``num_classes`` logits.

    Args:
        image_size: Side length in pixels of the (square) input images.
        patch_size: Side length in pixels of each square patch.
        num_classes: Number of output logits.
        d_model: Token embedding width.
        d_ff: Hidden width of the feed-forward network in each encoder layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        dropout: Dropout probability used on the embeddings and inside the
            encoder layers.
    """

    def __init__(self, image_size, patch_size, num_classes, d_model, d_ff, num_layers, num_heads, dropout=0.1):
        super().__init__()

        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # kernel_size == stride == patch_size: one output position per patch.
        self.patch_embedding = nn.Conv2d(3, d_model, patch_size, patch_size)

        # One position vector per patch plus one for the class token.
        self.position_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dropout = nn.Dropout(dropout)

        # batch_first=True is required: forward() builds (batch, tokens, d_model)
        # tensors. With the default (batch_first=False) the encoder would treat
        # the batch axis as the sequence axis, so attention would mix different
        # images together and the class token would never see any patch.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, num_heads, d_ff, dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image tensor of shape ``(batch, 3, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        x = self.patch_embedding(x)        # (batch, d_model, grid_h, grid_w)
        x = x.flatten(2).transpose(1, 2)   # (batch, num_patches, d_model)

        # Prepend one copy of the class token to every sample.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        # Out-of-place add keeps autograd free of in-place edits.
        x = x + self.position_embedding
        x = self.dropout(x)

        x = self.transformer_encoder(x)

        # Token 0 is the class token; it summarises the whole image.
        return self.fc(x[:, 0, :])

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    The model is switched to training mode. For every ``(images, labels)``
    batch, both tensors are moved to ``device``, gradients are cleared, the
    loss is computed and back-propagated, and ``optimizer`` takes one step.

    Args:
        model: Module being optimised.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    epoch_loss = 0.0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    num_batches = len(train_loader)
    return epoch_loss / num_batches


def evaluate(model, val_loader, criterion, device):
    """Measure loss and top-1 accuracy of ``model`` on ``val_loader``.

    The model is switched to evaluation mode and every batch is processed
    with gradient tracking disabled.

    Args:
        model: Module to evaluate.
        val_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``: the sum of per-batch losses divided by
        ``len(val_loader)``, and the percentage (0-100) of samples whose
        highest-scoring class equals the label.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            # Index of the largest logit along the class dimension.
            predictions = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    accuracy = 100.0 * n_correct / n_seen
    mean_loss = loss_sum / len(val_loader)

    return mean_loss, accuracy


def main(train_dir='path/to/train/folder', test_dir='path/to/test/folder', val_dir=None):
    """Train the MobileViT classifier, then report loss and accuracy on the test set.

    Args:
        train_dir: Root directory of the training images, in the layout
            expected by ``torchvision.datasets.ImageFolder``.
        test_dir: Root directory of the test images, same layout.
        val_dir: Optional root directory of validation images. When given,
            validation loss and accuracy are computed and printed after
            every epoch; when ``None`` only the training loss is printed.
    """
    # Set hyperparameters
    image_size = 224
    patch_size = 16
    num_classes = 4
    d_model = 512
    d_ff = 2048
    num_layers = 12
    num_heads = 8
    dropout = 0.1
    batch_size = 16
    learning_rate = 0.001
    num_epochs = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load data: random crop/flip augmentation for training, deterministic
    # resize + centre crop for validation and testing.
    train_transforms = transforms.Compose(
        [
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    val_transforms = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_dataset = ImageFolder(root=train_dir, transform=train_transforms)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    val_loader = None
    if val_dir is not None:
        val_dataset = ImageFolder(root=val_dir, transform=val_transforms)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Initialize MobileViT model
    model = MobileViT(image_size, patch_size, num_classes, d_model, d_ff, num_layers, num_heads, dropout).to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)

        # The loss is not a percentage, so it is printed without a '%' suffix.
        message = f"Epoch [{epoch + 1}/{num_epochs}]\t Train Loss: {train_loss:.4f}"
        if val_loader is not None:
            val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
            message += f"\t Val Loss: {val_loss:.4f}\t Val Accuracy: {val_accuracy:.2f}%"
        print(message)

    # Testing loop
    test_dataset = ImageFolder(root=test_dir, transform=val_transforms)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loss, test_accuracy = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}\t Test Accuracy: {test_accuracy:.2f}%")

# Entry point: start training/testing when this code runs as the top-level program
# (not when it is imported as a module).
if __name__ == "__main__":
    main()