## Setup

In [None]:
from anomalib.models import Fastflow
from anomalib.engine import Engine
from anomalib.data import Folder
from pytorch_lightning import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torchvision.transforms import v2
import os
import logging
from dotenv import load_dotenv
import s3data
import mlflow
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve, auc

# Load variables from a .env file (if present) into the process environment.
load_dotenv()

# Root logger configuration: INFO and above, each record carrying its
# timestamp, logger name and level ahead of the message.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

# Fix the random seed so repeated runs of the notebook are comparable.
seed_everything(42)

## Set Up MLflow Logging

In [None]:
from anomalib.loggers import AnomalibMLFlowLogger

# MLflow tracking server and experiment name.
# Both can be overridden through the environment (same mechanism as
# BUCKET_NAME); when the variable is unset or empty the notebook's original
# values are used, so existing setups behave exactly as before.
mlflow_tracking_uri = os.getenv("MLFLOW_TRACKING_URI") or "http://mlflow.local:30080"
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME") or "fastflow_experiment"

# Logger handed to the training engine; log_model=True also uploads the
# model checkpoints to MLflow.
mlflow_logger = AnomalibMLFlowLogger(
    experiment_name=experiment_name,
    tracking_uri=mlflow_tracking_uri,
    log_model=True,
)

# Point the module-level MLflow API at the same server and experiment so that
# direct mlflow.* calls made elsewhere in the notebook land in the same place
# as the logger's output.
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(experiment_name)

In [None]:
# Training callbacks: keep the best checkpoint and stop early when the
# monitored metric stops improving for `patience` validation runs.
#
# Both callbacks watch "image_AUROC" (higher is better, hence mode="max").
# The previous key, "val_loss", is not a metric anomalib's Fastflow logs
# (it logs "train_loss" during training and evaluation metrics such as
# "image_AUROC" during validation), and Lightning raises when a monitored
# key is missing from the logged metrics.
# NOTE(review): confirm the metric name against the installed anomalib version.
callbacks = [
    ModelCheckpoint(
        monitor="image_AUROC",
        mode="max",
    ),
    EarlyStopping(
        monitor="image_AUROC",
        mode="max",
        patience=3,
    ),
]

## Prepare Data

In [None]:
category = 'Toonie Anomaly'
# Directory name handed to s3data.cache_dataset and reused as the parent of
# the Folder datamodule's root, so both always refer to the same location.
cache_dir = 'cached_dataset'

# The bucket name comes from the environment; fail early with a clear message
# when it is missing or empty.
bucket_name = os.getenv('BUCKET_NAME')
if not bucket_name:
    raise ValueError("BUCKET_NAME environment variable not set")

print(f"Using bucket: {bucket_name}")
dataset_structure = s3data.get_dataset_structure(bucket_name, main_category=category)

s3data.cache_dataset(bucket_name, category, dataset_structure, cache_dir)

# Folder-based datamodule rooted at <cache_dir>/<category>; the normal and
# abnormal image directories are given relative to that root.
datamodule = Folder(
    name=category,
    root=os.path.join(cache_dir, category),
    normal_dir='train/good',
    # normal_test_dir='test/good',
    abnormal_dir='test/anomaly',
)

print("Preparing data...")
datamodule.prepare_data()

# Report the local `category` label (the value passed as `name` above) rather
# than reading an attribute back from the datamodule, which was never given a
# category of its own.
print(f"Setting up '{category}' datasets...")
datamodule.setup()

print(f"Training samples: {len(datamodule.train_data)}")
print(f"Test samples: {len(datamodule.test_data)}")

## Training the Model

In [None]:
# FastFlow model on a resnet18 backbone with pretrained weights.
# NOTE(review): anomalib's FastFlow documents resnet18, wide_resnet50_2,
# cait_m48_448 and deit_base_distilled_patch16_384 as backbones; a plain
# "resnet50" does not appear to be among them - confirm before switching.
model = Fastflow(backbone="resnet18", pre_trained=True)

# Engine settings gathered in one place: the callbacks and MLflow logger
# defined in earlier cells, automatic accelerator selection, a single device
# and at most 5 epochs.
engine_options = {
    "callbacks": callbacks,
    "logger": mlflow_logger,
    "accelerator": "auto",
    "devices": 1,
    "max_epochs": 5,
}
engine = Engine(**engine_options)

# Train the model on the prepared datamodule.
engine.fit(model=model, datamodule=datamodule)
