In [None]:
import torch
import torch.nn as nn
from torchvision import models

In [None]:
class MyModel(nn.Module):
    """ResNet18 backbone whose final FC layer is replaced by a two-layer head.

    The head is ``Linear -> ReLU -> Linear`` and is stored as
    ``self.resnet18.fc``, so the whole network runs with a single call to
    the backbone.
    """

    def __init__(self, num_classes=10, fc_hidden=1024, *, pretrained=True):
        """
        Initialize MyModel with ResNet18 backbone and custom FC layers.

        Args:
            num_classes (int): Number of output classes. Default: 10
            fc_hidden (int): Width of the hidden layer in the head
                (e.g. 512 or 1024; the value is not validated here).
                Default: 1024
            pretrained (bool): Load ImageNet weights into the backbone.
                Pass ``False`` to skip the weight download, e.g. when a
                checkpoint is loaded right afterwards. Default: True
        """
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favor of
        # ``weights=``. Prefer the new API when the weights enum exists and
        # fall back to the legacy flag so older installs keep working.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` loaded.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet18 = models.resnet18(weights=weights)
        else:
            self.resnet18 = models.resnet18(pretrained=pretrained)

        # Input width of the stock classifier, read before it is replaced.
        num_features = self.resnet18.fc.in_features

        # Replace the final FC layer with the custom head. The layer order is
        # kept as-is so state_dict keys (fc.0.*, fc.2.*) stay stable.
        self.resnet18.fc = nn.Sequential(
            nn.Linear(num_features, fc_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(fc_hidden, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the custom head.

        Args:
            x: Input batch, passed unchanged to the ResNet18 backbone
                (presumably image tensors shaped (N, 3, H, W) -- confirm
                against the caller).

        Returns:
            Output of the custom head; its last dimension has size
            ``num_classes``.
        """
        return self.resnet18(x)


# Demo: build the classifier with a 10-way output and a 1024-unit hidden
# layer in the head, then print its layer structure.
model = MyModel(fc_hidden=1024, num_classes=10)
print(model)