In [None]:
# ---------------- Imports ----------------
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import time
import os
from google.colab import files

# ---------------- Model Definition ----------------
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers plus a skip path.

    The skip path is an identity unless the block changes the stride or the
    channel count; in that case a 1x1 convolution followed by batch norm
    projects the input so it can be added to the main path.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection).
        dropout_rate: Drop probability of the dropout applied after each
            batch-norm stage of the main path.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.05):
        super().__init__()
        # Submodules are registered in a fixed order so that state_dict keys
        # (conv1, bn1, conv2, bn2, shortcut.*) match saved checkpoints.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout_rate)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        # First conv stage: conv -> BN -> ReLU -> dropout.
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.dropout(main)

        # Second conv stage: conv -> BN -> dropout (ReLU comes after the add).
        main = self.conv2(main)
        main = self.bn2(main)
        main = self.dropout(main)

        main += skip
        return self.relu(main)

class HFCLayer(nn.Module):
    """Classification head producing one score per class.

    The input is summed over dim 1, multiplied elementwise with a learnable
    per-class weight row, batch-normalised, passed through ReLU and finally
    summed over the last dimension.

    Args:
        num_classes: Number of output scores (rows of ``V``).
        D_b: Length of each weight row; after dim 1 of the input is summed
            out, the remaining trailing dimension must broadcast against it.
    """

    def __init__(self, num_classes, D_b):
        super().__init__()
        self.V = nn.Parameter(torch.randn(num_classes, D_b))
        self.bn = nn.BatchNorm1d(num_classes * D_b)

    def forward(self, x):
        """Return a ``(batch, num_classes)`` tensor of class scores."""
        n_classes = self.V.size(0)

        # Collapse dim 1, then broadcast against every class row of V:
        # (batch, 1, D_b) * (1, num_classes, D_b) -> (batch, num_classes, D_b)
        pooled = x.sum(dim=1)
        weighted = pooled[:, None] * self.V[None]
        n_samples = weighted.size(0)

        # BatchNorm1d works on 2-D input, so flatten, normalise, unflatten.
        normalised = self.bn(weighted.view(n_samples, -1))
        per_class = normalised.view(n_samples, n_classes, -1)

        return torch.relu(per_class).sum(dim=2)

class MergingLayer(nn.Module):
    """Blend several branch outputs with learnable softmax-normalised weights.

    Args:
        num_branches: Number of mixing weights; they start out equal.
    """

    def __init__(self, num_branches=3):
        super().__init__()
        self.w = nn.Parameter(torch.ones(num_branches) / num_branches)

    def forward(self, inputs):
        """Return the weighted sum of ``inputs`` (an iterable of tensors)."""
        # Softmax keeps the mixing weights positive and summing to one.
        mix = torch.softmax(self.w, dim=0)

        merged = 0
        for weight, branch in zip(mix, inputs):
            merged = merged + weight * branch
        return merged

class BMCNNBase(nn.Module):
    """Convolutional backbone exposing the output of each of its three stages.

    Every stage is a stack of three ``ResidualBlock`` modules (stride 1); the
    stages widen from 1 input channel to 128, 256 and 512 channels, with a
    2x2 max-pool between consecutive stages.

    Args:
        dropout_rate: Dropout probability forwarded to every residual block.
    """

    def __init__(self, dropout_rate=0.05):
        super().__init__()

        def stage(c_in, c_out):
            # The first block changes the width, the other two keep it.
            widths = ((c_in, c_out), (c_out, c_out), (c_out, c_out))
            return nn.Sequential(
                *[ResidualBlock(a, b, 1, dropout_rate) for a, b in widths]
            )

        self.conv_block1 = stage(1, 128)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_block2 = stage(128, 256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_block3 = stage(256, 512)

    def forward(self, x):
        """Return the feature maps of stage 1, stage 2 and stage 3 as a tuple."""
        stage1_out = self.conv_block1(x)
        stage2_in = self.pool1(stage1_out)
        stage2_out = self.conv_block2(stage2_in)
        stage3_in = self.pool2(stage2_out)
        stage3_out = self.conv_block3(stage3_in)
        return stage1_out, stage2_out, stage3_out

class EnhancedBMCNNwHFCs(BMCNNBase):
    """Backbone with one HFC head per stage, merged into a single score vector.

    Each of the three backbone feature maps is flattened over its spatial
    dimensions and fed to its own ``HFCLayer``; the three score vectors are
    then blended by a ``MergingLayer``.

    Args:
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to the backbone.
        input_size: Spatial size of the input images, either a single int
            (square input) or a ``(height, width)`` pair. The default of 32
            reproduces the original hard-coded head widths 32*32, 16*16 and
            8*8, so existing checkpoints keep loading unchanged.

    Raises:
        ValueError: If either spatial dimension is smaller than 4, because
            nothing would be left after the two 2x2 poolings.
    """

    def __init__(self, num_classes=46, dropout_rate=0.05, input_size=32):
        super().__init__(dropout_rate)

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        if height < 4 or width < 4:
            raise ValueError(
                "input_size must be at least 4x4 to survive two 2x2 "
                f"poolings, got {input_size!r}"
            )

        # Each 2x2/stride-2 max-pool halves a dimension and floors odd sizes,
        # so the stage-2 and stage-3 maps are size // 2 and size // 4.
        self.hfc1 = HFCLayer(num_classes, height * width)
        self.hfc2 = HFCLayer(num_classes, (height // 2) * (width // 2))
        self.hfc3 = HFCLayer(num_classes, (height // 4) * (width // 4))
        self.merging = MergingLayer(3)

    def forward(self, x):
        """Return merged class scores for a batch of images.

        The spatial size of ``x`` must match the ``input_size`` the model was
        built with; otherwise the head weights do not broadcast.
        """
        x1, x2, x3 = super().forward(x)
        # Collapse (H, W) into one axis: (batch, channels, H*W).
        logit1 = self.hfc1(x1.flatten(start_dim=2))
        logit2 = self.hfc2(x2.flatten(start_dim=2))
        logit3 = self.hfc3(x3.flatten(start_dim=2))
        return self.merging((logit1, logit2, logit3))

# ---------------- Main ----------------
def main():
    """Load the trained model, classify one image and report the latency.

    Missing inputs are requested through the Colab upload dialog. The
    predicted class index and the wall-clock time of a single forward pass
    (in milliseconds) are printed to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = "/content/best_model.pth"
    image_path = "/content/1339.png"

    # Upload model and image if they do not exist yet.
    for required in (model_path, image_path):
        if not os.path.exists(required):
            print(f"Upload {os.path.basename(required)}")
            files.upload()  # select the requested file in the dialog

    # Load model
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    model = EnhancedBMCNNwHFCs(num_classes=46, dropout_rate=0.0).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    print("Model loaded successfully!")

    # Preprocess image
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img = Image.open(image_path).convert("L")
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Inference & timing
    on_gpu = device.type == "cuda"
    with torch.no_grad():
        # CUDA kernels are launched asynchronously: without synchronising
        # before and after the forward pass the timer would only measure
        # launch overhead, not the actual computation.
        if on_gpu:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        output = model(img_tensor)
        if on_gpu:
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start_time) * 1000

    pred_class = output.argmax(dim=1).item()
    print(f"Predicted class: {pred_class}")
    print(f"Inference time: {elapsed_ms:.2f} ms")

if __name__ == "__main__":
    main()


OSError: [Errno 30] Read-only file system: '/content'

In [None]:
Model loaded successfully!
Predicted class: 10
Inference time: 3.14 ms