In [None]:
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union

from hyperopt import STATUS_OK, Trials, fmin, hp, space_eval, tpe
from loguru import logger
from mads_datasets import DatasetFactoryProvider, DatasetType
from matplotlib import pyplot as plt
import mlflow
from mltrainer import ReportTypes, Trainer, TrainerSettings, metrics
from mltrainer.preprocessors import BasePreprocessor
from pydantic import BaseModel
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Factory for the flowers dataset; the data streamers are created from it.
flowers_dataset_factory = DatasetFactoryProvider.create_factory(DatasetType.FLOWERS)

In [None]:
class BasePreprocessor(BasePreprocessor):
    """Collate a list of ``(X, y)`` samples into two batched tensors.

    NOTE(review): this class reuses the name of the imported
    ``mltrainer.preprocessors.BasePreprocessor`` it inherits from, so the
    imported name is shadowed once this cell has run, and re-running the cell
    subclasses the previous definition. Consider giving it a distinct name.
    """

    def __call__(self, batch: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        """Stack the samples of ``batch`` field by field.

        Args:
            batch: List of ``(X, y)`` pairs; every ``X`` and every ``y`` must
                be a tensor that ``torch.stack`` can combine with its peers.

        Returns:
            ``(stacked X, stacked y)``, each with a new leading batch axis.
        """
        # Transpose [(X0, y0), (X1, y1), ...] into (X0, X1, ...), (y0, y1, ...).
        features, targets = zip(*batch)
        stacked_features = torch.stack(features)
        stacked_targets = torch.stack(targets)
        return stacked_features, stacked_targets

In [None]:
# Build the data streamers; every batch of 64 samples is collated by the
# preprocessor defined above.
data_streamer = flowers_dataset_factory.create_datastreamer(
    preprocessor=BasePreprocessor(),
    batchsize=64,
)

# One stream per split, created in the order train -> valid.
train_streamer, valid_streamer = (
    data_streamer[split].stream() for split in ("train", "valid")
)

In [None]:
# Sanity check: pull one batch from the training stream and show one sample.
# NOTE: this consumes a batch from the shared train_streamer.
image, label = next(train_streamer)

index = 2
# permute(1, 2, 0) moves the first axis to the end; presumably this turns a
# channels-first image into the channels-last layout imshow expects (confirm).
first_image = image[index].permute(1, 2, 0)

plt.imshow(first_image)
plt.show()

In [None]:
# Where mlflow stores runs, and the experiment all hyperopt trials are filed under.
MLFLOW_TRACKING_URI = "sqlite:///mlflow.db"
MLFLOW_EXPERIMENT_NAME = "mlflow-flowers-hyperopt"

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

In [None]:
class ModelConfig(BaseModel):
    """Hyperparameters consumed by ``FlowersModel``."""

    # Input channels (e.g., 1 for BW, 3 for RGB)
    features: int
    # Size of the final Linear layer, i.e. number of output classes
    num_classes: int
    # Kernel size shared by all three convolution layers
    kernel_size: int
    # Output channels of the first convolution layer
    filter1: int
    # Output channels of the second convolution layer
    filter2: int
    # Probability passed to nn.Dropout in the dense block. The default is a
    # float literal: pydantic does not validate defaults unless asked to, so
    # an int default would leave an int behind a float annotation.
    dropout: float = 0.0


class FlowersModel(nn.Module):
    """Three-block CNN classifier with global average pooling.

    Each convolutional block is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2).
    The last block is followed by ``AdaptiveAvgPool2d((1, 1))``, so the dense
    head always receives 32 features regardless of the input resolution.

    Args:
        config: Object exposing ``features``, ``num_classes``, ``kernel_size``,
            ``filter1``, ``filter2`` and ``dropout`` (see ``ModelConfig``).
    """

    def __init__(
        self,
        config: "ModelConfig",
    ) -> None:
        super().__init__()

        # Derive the padding from the kernel size instead of hard-coding 1, so
        # that odd kernels keep the spatial size unchanged through each
        # convolution. For kernel_size=3 this is the same padding of 1 as
        # before; larger kernels no longer shrink the feature map at every
        # layer (which could run out of pixels on small inputs).
        padding = config.kernel_size // 2

        # Width of the last convolution; it fixes the input size of the dense
        # block below.
        filter3 = 32

        # 1. Convolutional Block
        # BatchNorm: Stabilizes training and allows for higher learning rates
        self.convolutions = nn.Sequential(
            # Layer 1
            nn.Conv2d(
                in_channels=config.features,
                out_channels=config.filter1,
                kernel_size=config.kernel_size,
                stride=1,
                padding=padding,
            ),
            nn.BatchNorm2d(config.filter1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Layer 2
            nn.Conv2d(
                in_channels=config.filter1,
                out_channels=config.filter2,
                kernel_size=config.kernel_size,
                stride=1,
                padding=padding,
            ),
            nn.BatchNorm2d(config.filter2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Layer 3
            nn.Conv2d(
                in_channels=config.filter2,
                out_channels=filter3,
                kernel_size=config.kernel_size,
                stride=1,
                padding=padding,
            ),
            nn.BatchNorm2d(filter3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Global Average Pooling (GAP)
            # This makes the model size-agnostic and reduces parameter count
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # 2. Dense Block
        # Since we use AdaptiveAvgPool2d, the input to Linear is always filter3
        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(filter3, 64),
            nn.ReLU(),
            nn.Dropout(config.dropout),  # Prevents overfitting
            nn.Linear(64, config.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional stack followed by the dense head.

        Args:
            x: Image batch, expected as ``(batch, features, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with the raw outputs of
            the final Linear layer (no softmax is applied here).
        """
        x = self.convolutions(x)
        return self.dense(x)

In [None]:
# Trainer configuration shared by every hyperopt trial.
# NOTE(review): train_steps / valid_steps are fixed at 180; presumably these
# are batches per epoch - confirm they match the streamer lengths.
settings = TrainerSettings(
    epochs=20,
    train_steps=180,
    valid_steps=180,
    metrics=[metrics.Accuracy()],
    reporttypes=[ReportTypes.MLFLOW, ReportTypes.TOML],
    logdir="modellogs",
)

In [None]:
def _select_device() -> torch.device:
    """Return the best available torch device: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def objective(params: Dict[str, Union[int, float]]) -> Dict[str, Any]:
    """Train one ``FlowersModel`` for a hyperopt trial and report its loss.

    The trial runs inside its own mlflow run: the parameters are logged, the
    model is trained with the shared ``settings`` and streamers, and the
    trained model is saved locally and attached to the run as an artifact.

    Args:
        params: Hyperparameters for this trial; they are unpacked into
            ``ModelConfig(**params)``, so the keys must match its fields.

    Returns:
        ``{"loss": trainer.test_loss, "status": STATUS_OK}`` as expected by
        ``hyperopt.fmin``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "cnn")
        mlflow.set_tag("dev", "vanesterik")
        mlflow.log_params(params)

        model = FlowersModel(ModelConfig(**params))
        trainer = Trainer(
            model=model,
            settings=settings,
            loss_fn=nn.CrossEntropyLoss(),
            optimizer=optim.Adam,
            traindataloader=train_streamer,
            validdataloader=valid_streamer,
            scheduler=optim.lr_scheduler.ReduceLROnPlateau,
            # Fall back to CUDA/CPU instead of failing on machines without MPS.
            device=_select_device(),
        )
        trainer.loop()

        # Second resolution, so trials finishing within the same minute do not
        # overwrite each other's checkpoint.
        tag = datetime.now().strftime("%Y%m%d-%H%M%S")
        models_dir = Path("models").resolve()

        if not models_dir.exists():
            # exist_ok avoids a crash if the directory appears between the
            # check and the call.
            models_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created {models_dir}")

        models_path = models_dir / (tag + "model.pt")
        # NOTE: this pickles the whole module, so loading the file later
        # requires the FlowersModel class to be importable at load time.
        torch.save(model, models_path)

        mlflow.log_artifact(
            local_path=str(models_path), artifact_path="pytorch_models"
        )

        return {"loss": trainer.test_loss, "status": STATUS_OK}


In [None]:
# Candidate widths for the first two convolution layers.
filter_options = [32, 64, 128]

# Keys mirror the ModelConfig fields, because objective() unpacks the sampled
# values straight into ModelConfig(**params). Plain ints are fixed settings;
# hp.* entries are the dimensions hyperopt searches over.
search_space = dict(
    dropout=hp.uniform("dropout", 0.0, 0.5),
    features=3,
    filter1=hp.choice("filter1", filter_options),
    filter2=hp.choice("filter2", filter_options),
    kernel_size=3,
    num_classes=5,
)

In [None]:
# Named Trials object so the per-trial history is still reachable afterwards.
trials = Trials()

# Run 50 TPE-guided evaluations of objective() over search_space.
# Note: for hp.choice dimensions fmin reports the index of the chosen option,
# not the option itself.
results = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials,
)

In [None]:
# fmin reports hp.choice parameters as indices into their option lists (e.g.
# filter1: 2 rather than 128); space_eval maps the assignment back onto the
# search space so the log shows the values that were actually used.
best_params = space_eval(search_space, results)
logger.info(f"\n\n{best_params}")