(lightning_experiment_tracking)=

# Using W&B, CometML, and MLFlow Loggers with LightningTrainer

W&B, CometML, and MLFlow are popular tools for managing, visualizing, and tracking machine learning experiments.

PyTorch Lightning ships with built-in loggers for all three (`WandbLogger`, `CometLogger`, and `MLFlowLogger`), as well as `TensorBoardLogger`, so you can track model performance over time during training. MLFlow can also store artifacts such as trained models and hyperparameters. With AIR's `LightningTrainer`, you keep using these built-in loggers: pass them through the `logger` argument of the PyTorch Lightning `Trainer`, just as you would without Ray.


## Define your model and dataloader

No code changes are needed here. For this demo, we create a dummy model and a dummy dataset.

```python
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import TensorDataset, DataLoader

# create dummy data
X = torch.randn(128, 3)  # 128 samples, 3 features
y = torch.randint(0, 2, (128,))  # 128 binary labels

# create a TensorDataset to wrap the data
dataset = TensorDataset(X, y)

# create a DataLoader to iterate over the dataset
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
```

```python
# Define a dummy model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float())

        # The metrics below will be reported to Loggers
        self.log("train_loss", loss)
        self.log_dict({"metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```
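`metric_1` and `metric_2` are deterministic functions of `batch_idx`, which restarts at 0 every epoch, so they trace a predictable sawtooth you can look for in each dashboard. The plain-Python sketch below mirrors the `log_dict` formulas; the helper name `dummy_metrics` is hypothetical, and the choice of 4 batches per epoch is an assumption that matches how many batches each worker ends up seeing in the run later on this page.

```python
def dummy_metrics(batch_idx: int) -> dict:
    """Hypothetical helper: same formulas as the log_dict call in DummyModel.training_step."""
    return {"metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100}


# Assumption: 4 batches per epoch on each worker
per_batch = [dummy_metrics(i) for i in range(4)]
for idx, metrics in enumerate(per_batch):
    print(idx, metrics)
```

Under that assumption, `metric_1` spans 0.25 to 1.0 and `metric_2` spans 0 to 300 in every epoch.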

## Define your loggers

You don't need to change your original Lightning code for offline logging. To upload your logs online to W&B or CometML, set the API key environment variables (`WANDB_API_KEY` and `COMET_API_KEY`, which `create_loggers` reads with `os.environ.get`) on the head node.
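If you want the same script to fall back to offline logging when the keys are missing, a small check on the environment is enough. This is a minimal sketch: `missing_api_keys` and `should_log_offline` are hypothetical helpers, not part of Ray or Lightning, and they only look at the two variable names used in this guide.

```python
import os
from typing import List, Mapping, Optional

# Environment variables read by create_loggers in this guide
REQUIRED_KEYS = ("WANDB_API_KEY", "COMET_API_KEY")


def missing_api_keys(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Hypothetical helper: return the required API key variables that are unset or empty."""
    env = os.environ if env is None else env
    return [key for key in REQUIRED_KEYS if not env.get(key)]


def should_log_offline(env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: log offline unless every online backend has a key."""
    return bool(missing_api_keys(env))
```

For example, `create_loggers(name="demo-run", project_name="demo-project", offline=should_log_offline())` would only try to upload when both keys are present.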

```python
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import wandb


# A callback that logs in to W&B on each worker
class WandbLoginCallback(pl.Callback):
    def __init__(self, key):
        self.key = key

    def setup(self, trainer, pl_module, stage) -> None:
        wandb.login(key=self.key)


def create_loggers(name, project_name, save_dir="./logs", offline=False):
    # Avoid creating a new experiment run on the driver node.
    rank_zero_only.rank = None

    # W&B
    wandb_api_key = os.environ.get("WANDB_API_KEY", None)
    wandb_logger = WandbLogger(
        name=name, project=project_name, save_dir=f"{save_dir}/wandb", offline=offline
    )
    # Only log in to W&B on the workers when logging online
    callbacks = [] if offline else [WandbLoginCallback(key=wandb_api_key)]

    # CometML
    comet_api_key = os.environ.get("COMET_API_KEY", None)
    comet_logger = CometLogger(
        api_key=comet_api_key,
        experiment_name=name,
        project_name=project_name,
        save_dir=f"{save_dir}/comet",
        offline=offline,
    )

    # MLFlow (always writes to a local file store)
    mlflow_logger = MLFlowLogger(
        run_name=name,
        experiment_name=project_name,
        tracking_uri=f"file:{save_dir}/mlflow",
    )

    # TensorBoard (always local)
    tensorboard_logger = TensorBoardLogger(
        name=name, save_dir=f"{save_dir}/tensorboard"
    )

    return [wandb_logger, comet_logger, mlflow_logger, tensorboard_logger], callbacks
```
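The `rank_zero_only.rank = None` line relies on Lightning's rank-zero gating: logger actions such as creating the experiment run are wrapped so that only the global-rank-0 process executes them, and on the workers Lightning is expected to set the real rank. The toy decorator below illustrates that idea in plain Python. It is a simplified sketch, not Lightning's implementation; the names `RankState`, `rank_zero_only_sketch`, and `start_experiment_run` are hypothetical, and the real utility may handle an unset rank differently (for example, by raising an error instead of skipping).

```python
from functools import wraps
from typing import Any, Callable, List, Optional


class RankState:
    """Hypothetical holder for the current process rank (stand-in for rank_zero_only.rank)."""
    rank: Optional[int] = None


def rank_zero_only_sketch(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Simplified sketch: run fn only when the current rank is 0."""
    @wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if RankState.rank == 0:
            return fn(*args, **kwargs)
        # Non-zero ranks, and an unset (None) rank in this sketch, skip the call.
        return None
    return wrapped


created_runs: List[str] = []


@rank_zero_only_sketch
def start_experiment_run(name: str) -> str:
    created_runs.append(name)
    return name


RankState.rank = None  # driver: rank unset, no run is created
start_experiment_run("driver")
RankState.rank = 1  # non-zero worker: no run is created
start_experiment_run("worker-1")
RankState.rank = 0  # rank-0 worker: the run is created
start_experiment_run("worker-0")
```

Only the rank-0 call is recorded, which is why each backend ends up with a single run per training job rather than one per worker.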



```python
YOUR_SAVE_DIR = "./logs"
loggers, callbacks = create_loggers(
    name="demo-run", project_name="demo-project", save_dir=YOUR_SAVE_DIR, offline=False
)
```
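Because the loggers run inside the Ray Train workers, a relative `YOUR_SAVE_DIR` such as `"./logs"` is resolved against each worker's working directory, which may differ from the driver's (an assumption that depends on your Ray version and settings). On a multi-node cluster, the directory must also exist on every node or live on shared storage. The sketch below builds absolute paths from one root; `resolve_log_dirs` is a hypothetical helper, and its `mlflow_tracking_uri` entry is an absolute `file://` URI that you could pass to `MLFlowLogger(tracking_uri=...)` in place of the relative one.

```python
from pathlib import Path
from typing import Dict


def resolve_log_dirs(save_dir: str) -> Dict[str, str]:
    """Hypothetical helper: absolute per-backend log locations under one root."""
    root = Path(save_dir).expanduser().resolve()
    return {
        "wandb": str(root / "wandb"),
        "comet": str(root / "comet"),
        "tensorboard": str(root / "tensorboard"),
        # as_uri() requires an absolute path, which resolve() guarantees
        "mlflow_tracking_uri": (root / "mlflow").as_uri(),
    }


dirs = resolve_log_dirs("./logs")
print(dirs)
```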

```python
# For smoke tests: log everything offline so that no API keys are required
loggers, callbacks = create_loggers(
    name="demo-run", project_name="demo-project", offline=True
)
```

## Train the model and check out your logs

```python
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

builder = LightningConfigBuilder()
builder.module(cls=DummyModel)
builder.trainer(
    max_epochs=5,
    accelerator="cpu",
    logger=loggers,
    callbacks=callbacks,
    log_every_n_steps=1,
)
builder.fit_params(train_dataloaders=dataloader)
lightning_config = builder.build()
```
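The builder itself does not create the model or the Lightning `Trainer`; it only collects keyword arguments that `LightningTrainer` later uses on each worker, roughly: instantiate the module class, build a `pl.Trainer` with the trainer arguments (including `logger=loggers`), and call `fit` with the fit parameters. The toy class below illustrates that builder pattern in plain Python. `ToyConfigBuilder` and its dictionary keys are hypothetical and do not reflect the structure Ray's `build()` actually returns.

```python
from typing import Any, Dict, Optional


class ToyConfigBuilder:
    """Hypothetical, simplified builder that only records arguments for later use."""

    def __init__(self) -> None:
        self._module_cls: Optional[type] = None
        self._module_kwargs: Dict[str, Any] = {}
        self._trainer_kwargs: Dict[str, Any] = {}
        self._fit_kwargs: Dict[str, Any] = {}

    def module(self, cls: type, **kwargs: Any) -> "ToyConfigBuilder":
        self._module_cls = cls
        self._module_kwargs.update(kwargs)
        return self  # returning self allows method chaining

    def trainer(self, **kwargs: Any) -> "ToyConfigBuilder":
        self._trainer_kwargs.update(kwargs)
        return self

    def fit_params(self, **kwargs: Any) -> "ToyConfigBuilder":
        self._fit_kwargs.update(kwargs)
        return self

    def build(self) -> Dict[str, Any]:
        if self._module_cls is None:
            raise ValueError("Call module(cls=...) before build().")
        # Shallow copies: objects like dataloaders are shared, not duplicated
        return {
            "module_cls": self._module_cls,
            "module_kwargs": dict(self._module_kwargs),
            "trainer_kwargs": dict(self._trainer_kwargs),
            "fit_kwargs": dict(self._fit_kwargs),
        }


config = ToyConfigBuilder().module(cls=dict).trainer(max_epochs=5, log_every_n_steps=1).build()
```

Deferring construction to the workers is why the module class, loggers, and callbacks passed to the real builder need to be serializable.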

```python
from ray.air.config import RunConfig, ScalingConfig

scaling_config = ScalingConfig(num_workers=4, use_gpu=False)

run_config = RunConfig(
    name="ptl-exp-tracking",
    storage_path="/tmp/ray_results",
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

trainer.fit()
```

| Current time | Running for | Memory |
|---|---|---|
| 2023-04-28 12:46:53 | 00:01:20.41 | 10.1/186.6 GiB |

| Trial name | status | loc | iter | total time (s) | train_loss | metric_1 | metric_2 |
|---|---|---|---|---|---|---|---|
| LightningTrainer_3d398_00000 | TERMINATED | 10.0.40.230:212988 | 5 | 42.5942 | 0.773232 | 0.25 | 300 |


```text
2023-04-28 12:45:33,508	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
(TrainTrainable pid=212988) 2023-04-28 12:45:49,042	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
(LightningTrainer pid=212988) 2023-04-28 12:45:49,047	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
(LightningTrainer pid=212988) 2023-04-28 12:45:54,854	INFO backend_executor.py:128 -- Starting distributed worker processes: ['213736 (10.0.40.230)', '213737 (10.0.40.230)', '213738 (10.0.40.230)', '213739 (10.0.40.230)']
(RayTrainWorker pid=2
```

| Field | Value |
|---|---|
| Trial name | LightningTrainer_3d398_00000 |
| _report_on | train_epoch_end |
| date | 2023-04-28_12-46-31 |
| done | True |
| epoch | 4 |
| experiment_tag | 0 |
| hostname | ip-10-0-40-230 |
| iterations_since_restore | 5 |
| metric_1 | 0.25 |
| metric_2 | 300 |
| node_ip | 10.0.40.230 |
| pid | 212988 |
| should_checkpoint | True |
| step | 20 |
| time_since_restore | 42.5942 |
| time_this_iter_s | 0.188582 |
| time_total_s | 42.5942 |
| timestamp | 1682711191 |
| train_loss | 0.773232 |
| training_iteration | 5 |
| trial_id | 3d398_00000 |


```text
(RayTrainWorker pid=213736) COMET INFO: ---------------------------
(RayTrainWorker pid=213736) COMET INFO: Comet.ml Experiment Summary
(RayTrainWorker pid=213736) COMET INFO: ---------------------------
(RayTrainWorker pid=213736) COMET INFO:   Data:
(RayTrainWorker pid=213736) COMET INFO:     display_summary_level : 1
(RayTrainWorker pid=213736) COMET INFO:     url                   : https://www.comet.com/woshiyyya/demo-project/ad042d654a5348ab896cb41041354344
(RayTrainWorker pid=213736) COMET INFO:   Metrics [count] (min, max):
(RayTrainWorker pid=213736) COMET INFO:     metric_1 [20]   : (0.25, 1.0)
(RayTrainWorker pid=213736) COMET INFO:     metric_2 [20]   : (0.0, 300.0)
(RayTrainWorker pid=213736) COMET INFO:     train_loss [20] : (0.6661834716796875, 0.9504892826080322)
(RayTrainWorker pid=213736) COMET INFO:   Others:
(RayTrainWorker pid=213736) COMET INFO:     Name : demo-run
(RayTrainWorker pid=213736) COMET INFO:   Uploads:
(RayTrainWorker pid=213736) COMET INFO:     conda
```

```text
Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.7732318043708801, 'metric_1': 0.25, 'metric_2': 300.0, 'epoch': 4, 'step': 20, 'should_checkpoint': True, 'done': True, 'trial_id': '3d398_00000', 'experiment_tag': '0'},
  path='/tmp/ray_results/ptl-exp-tracking/LightningTrainer_3d398_00000_0_2023-04-28_12-45-33',
  checkpoint=LightningCheckpoint(local_path=/tmp/ray_results/ptl-exp-tracking/LightningTrainer_3d398_00000_0_2023-04-28_12-45-33/checkpoint_000004)
)
```
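The final `step`, `metric_1`, and `metric_2` values in the result can be reproduced by hand. The sketch below assumes that the 128-sample dataset is split evenly across the 4 workers, as a `DistributedSampler` would do (32 samples each), so every worker sees 4 batches of 8 per epoch. The function name `steps_per_worker` is hypothetical.

```python
import math
from typing import Tuple


def steps_per_worker(num_samples: int, num_workers: int, batch_size: int, max_epochs: int) -> Tuple[int, int]:
    """Hypothetical helper: (batches per epoch, total optimizer steps) on one worker.

    Assumes an even DistributedSampler-style split and no dropped last batch.
    """
    samples_per_worker = math.ceil(num_samples / num_workers)
    batches_per_epoch = math.ceil(samples_per_worker / batch_size)
    return batches_per_epoch, batches_per_epoch * max_epochs


batches, total_steps = steps_per_worker(num_samples=128, num_workers=4, batch_size=8, max_epochs=5)
last_batch_idx = batches - 1
# Same formulas as the log_dict call in DummyModel.training_step
final_metric_1 = 1 / (last_batch_idx + 1)
final_metric_2 = last_batch_idx * 100
print(batches, total_steps, final_metric_1, final_metric_2)
```

This gives 4 batches per epoch, 20 steps in total, and final values of 0.25 and 300, consistent with `step`, `metric_1`, and `metric_2` above and with the 20 logged values per metric in the Comet summary.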

Now let's check out our experiment results!

| **W&B (Online)** | **CometML (Online)** |
| - | - |
| ![W&B dashboard](https://user-images.githubusercontent.com/26745457/235216924-ed27f820-3f2e-4812-bc62-982c3a1748c7.png) | ![CometML dashboard](https://user-images.githubusercontent.com/26745457/235216949-72d80d7d-4460-480a-b20d-f154594507fc.png) |
| **TensorBoard (Offline)** | **MLFlow (Offline)** |
| ![TensorBoard dashboard](https://user-images.githubusercontent.com/26745457/235227957-7c2ee93b-91ab-494c-a241-7b106cf9a5e6.png) | ![MLFlow UI](https://user-images.githubusercontent.com/26745457/235241099-6850bcae-8843-4bbb-8268-c04b04a09e68.png) |