<a href="https://colab.research.google.com/github/soumik12345/wandb-addons/blob/docs/docs/ciclo/examples/Ciclo_Wandb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone wandb-addons (skipped if the checkout already exists, so the cell can be
# re-run) and install it with its JAX extras *from the cloned directory*.
# A bare `pip install .[jax]` would install whatever is in the current working
# directory instead of the checkout. The requirement is quoted so shells that
# glob `[...]` (e.g. zsh) don't mangle it; `%pip` targets the kernel's env.
!test -d wandb-addons || git clone https://github.com/soumik12345/wandb-addons
%pip install -q "./wandb-addons[jax]"

In [None]:
from pathlib import Path
from time import time
from typing import Optional, Callable
from collections.abc import MutableMapping

import flax.linen as nn
import jax.numpy as jnp
import jax_metrics as jm
import matplotlib.pyplot as plt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import ciclo
from ciclo.logging import Logs
from ciclo.types import Batch, S
from ciclo.timetracking import Elapsed
from ciclo.loops.loop import LoopCallbackBase
from ciclo.callbacks import LoopState, CallbackOutput

import wandb
from wandb_addons.ciclo import WandbLogger

In [None]:
# Start the W&B run; presumably WandbLogger (used in the training cell) logs into
# this active run. With entity=None, W&B falls back to the WANDB_ENTITY env var or
# the logged-in user's default entity, so the notebook works for anyone, not
# only for the original author's account.
WANDB_ENTITY = None  # e.g. "geekyrakshit" — set to log into a specific user/team
wandb.init(project="ciclo-integration", entity=WANDB_ENTITY, job_type="test")

In [None]:
# Training schedule. Step counts are measured in batches of `batch_size`.
batch_size = 32
num_epochs = 10
total_samples = batch_size * 100  # 100 training steps' worth of samples
total_steps = total_samples // batch_size
steps_per_epoch = total_steps // num_epochs
test_steps = 10  # number of test batches per evaluation (passed as test_duration)

In [None]:
# Load MNIST as (image, label) pairs.
# Training stream: repeated indefinitely and shuffled; the training loop's
# `stop` argument bounds how many batches are actually drawn from it.
ds_train: tf.data.Dataset = tfds.load("mnist", split="train", shuffle_files=True)
ds_train = ds_train.map(lambda x: (x["image"], x["label"]))
ds_train = ds_train.repeat().shuffle(1024).batch(batch_size).prefetch(1)
# Test stream: batched with the same `batch_size` as training. drop_remainder
# keeps every batch the same shape (presumably so the jitted step is not
# re-traced for a smaller final batch).
ds_test: tf.data.Dataset = tfds.load("mnist", split="test")
ds_test = ds_test.map(lambda x: (x["image"], x["label"]))
ds_test = ds_test.batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
# Define model
class Linear(nn.Module):
    """Single-Dense-layer classifier that produces 10 logits per example.

    Inputs are divided by 255 (assumes raw 0-255 pixel intensities, as in the
    uint8 MNIST images used here). Everything after the batch axis is
    flattened before the Dense projection.
    """

    @nn.compact
    def __call__(self, x):
        batch = x.shape[0]
        pixels = (x / 255.0).reshape((batch, -1))  # keep batch axis, flatten the rest
        return nn.Dense(features=10)(pixels)

In [None]:
# Initialize state: combine the model with an AdamW optimizer, a cross-entropy
# loss and an accuracy metric under ciclo's "jit" strategy.
model = Linear()
# Dummy MNIST-shaped batch (1, 28, 28, 1) used to initialize the parameters.
init_inputs = jnp.empty((1, 28, 28, 1))
optimizer = optax.adamw(1e-3)
state = ciclo.create_flax_state(
    model,
    inputs=init_inputs,
    tx=optimizer,
    losses=dict(loss=jm.losses.Crossentropy()),
    metrics=dict(accuracy=jm.metrics.Accuracy()),
    strategy="jit",
)

In [None]:
# Train the model. The callbacks are:
#   - keras_bar: a progress bar sized to `total_steps`
#   - checkpoint: saves into a timestamped directory, keeping the best `accuracy_test`
#     (mode="max")
#   - WandbLogger: logs to W&B (presumably into the run started by wandb.init above)
# epoch_duration / test_duration / stop are step counts taken from the config cell.
checkpoint_dir = f"logdir/checkpoint/{int(time())}"
training_callbacks = [
    ciclo.keras_bar(total=total_steps),
    ciclo.checkpoint(checkpoint_dir, monitor="accuracy_test", mode="max"),
    WandbLogger(),
]
state, history, _ = ciclo.train_loop(
    state,
    ds_train.as_numpy_iterator(),
    callbacks=training_callbacks,
    # A factory, so every evaluation pass gets a fresh iterator over the test set.
    test_dataset=lambda: ds_test.as_numpy_iterator(),
    epoch_duration=steps_per_epoch,
    test_duration=test_steps,
    stop=total_steps,
)

In [None]:
# Close the W&B run opened by wandb.init so its logged data is finalized.
wandb.finish()