In [3]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LitSmallCNN(pl.LightningModule):
    """Small CNN classifier built from conv -> activation -> max-pool blocks.

    One block is created per entry of ``conv_channels``; the resulting feature
    map is flattened and passed through a hidden linear layer (``fc1``) and an
    output linear layer (``fc2``) that produces ``num_classes`` logits.

    Args:
        in_channels: Number of channels of the input images.
        conv_channels: Output channels (filters) of each convolution block.
            The number of entries determines the number of blocks.
        kernel_sizes: Either a single int used for every convolution, or a
            list/tuple with one kernel size per convolution block.
        dense_neurons: Number of units in the hidden linear layer.
        num_classes: Number of output logits.
        conv_activation: Zero-argument factory (e.g. ``nn.ReLU``) called once
            per block to create the activation following each convolution.
        dense_activation: Optional activation applied after ``fc1``. May be
            ``None`` (no activation), a callable / module instance, or a class
            such as ``nn.ReLU``, which is instantiated with no arguments.
        image_size: Side length of the (square) input images; used to size
            the first linear layer.

    Raises:
        ValueError: If ``kernel_sizes`` is neither an int nor a list/tuple, if
            its length differs from ``len(conv_channels)``, or if
            ``image_size`` is too small to survive all pooling stages.
    """

    # NOTE: the list default of ``conv_channels`` is shared between calls; it is
    # only ever read here, never mutated, and is kept as-is so the recorded
    # hyperparameters stay unchanged.
    def __init__(self,
                 in_channels=3,
                 conv_channels=[32, 64, 128, 256, 512],
                 kernel_sizes=3,
                 dense_neurons=128,
                 num_classes=10,
                 conv_activation=nn.ReLU,
                 dense_activation=None,
                 image_size=224):
        super().__init__()
        # Record the constructor arguments for logging and reproducibility.
        # Done first so the values are captured exactly as they were passed.
        self.save_hyperparameters()

        # Expand kernel_sizes into one kernel size per convolution block.
        num_blocks = len(conv_channels)
        if isinstance(kernel_sizes, int):
            layer_kernels = [kernel_sizes] * num_blocks
        elif isinstance(kernel_sizes, (list, tuple)):
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            if len(kernel_sizes) != num_blocks:
                raise ValueError(
                    "Length of kernel_sizes must match the number of conv layers."
                )
            layer_kernels = list(kernel_sizes)
        else:
            raise ValueError("kernel_sizes must be an int or a list of ints")

        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels
        # Track the feature-map side length block by block instead of assuming
        # that every block exactly halves it: with padding = k // 2 an even
        # kernel makes the convolution output one pixel larger than its input.
        spatial = image_size

        for out_channels, k in zip(conv_channels, layer_kernels):
            block = nn.Sequential(
                nn.Conv2d(prev_channels, out_channels, kernel_size=k, padding=k // 2),
                conv_activation(),
                nn.MaxPool2d(kernel_size=2),
            )
            self.conv_layers.append(block)
            prev_channels = out_channels

            # Conv2d with stride 1 and dilation 1: out = in + 2 * padding - k + 1
            conv_out = spatial + 2 * (k // 2) - k + 1
            # MaxPool2d(kernel_size=2) uses stride 2 and floor rounding.
            spatial = conv_out // 2

        if spatial < 1:
            raise ValueError(
                f"image_size={image_size} is too small for "
                f"{num_blocks} conv/max-pool blocks."
            )

        self.flatten_dim = prev_channels * spatial * spatial

        # Dense head: one hidden layer followed by the output layer.
        self.fc1 = nn.Linear(self.flatten_dim, dense_neurons)
        self.fc2 = nn.Linear(dense_neurons, num_classes)

        # Accept a class (same convention as conv_activation) as well as an
        # already-constructed module or plain callable.
        if isinstance(dense_activation, type):
            dense_activation = dense_activation()
        self.dense_activation = dense_activation

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        for conv in self.conv_layers:
            x = conv(x)
        # Collapse everything except the batch dimension for the linear layers.
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        # Identity check rather than truthiness: a module that defines
        # ``__len__`` (e.g. an empty nn.Sequential) would otherwise be skipped.
        if self.dense_activation is not None:
            x = self.dense_activation(x)
        return self.fc2(x)

    def _shared_step(self, batch, log_name):
        """Compute the cross-entropy loss for ``batch`` and log it as ``log_name``."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy training loss, logged as ``train_loss``."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the cross-entropy validation loss, logged as ``val_loss``."""
        return self._shared_step(batch, "val_loss")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

if __name__ == "__main__":
    # Build the network; every argument below restates the constructor default.
    model = LitSmallCNN(
        in_channels=3, conv_channels=[32, 64, 128, 256, 512], kernel_sizes=3,
        dense_neurons=128, num_classes=10,
        conv_activation=nn.ReLU,
        dense_activation=None,  # leave the hidden dense layer without an activation
        image_size=224,
    )

    # Show the layer layout
    print(model)

    # Smoke test: push one random 3x224x224 image through the network
    dummy_input = torch.randn((1, 3, 224, 224))
    output = model(dummy_input)
    print("Output shape:", output.shape)

    # Pick the GPU when CUDA is present, otherwise fall back to the CPU
    if torch.cuda.is_available():
        accelerator = "gpu"
    else:
        accelerator = "cpu"
    devices = 1  # this example always runs on a single device

    # One-epoch trainer with logging switched off to keep the example simple
    trainer = pl.Trainer(max_epochs=1, accelerator=accelerator, devices=devices, logger=False)
    # Training is not started here. Once dataloaders exist, run:
    # trainer.fit(model, train_dataloader, val_dataloader)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LitSmallCNN(
  (conv_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      