# Using W&B, CometML, MLflow, and TensorBoard Loggers in LightningTrainer

In [1]:
import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

import os
import torch
import torch.nn.functional as F




In [2]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy binary-classification data: 128 samples, each with 3 random
# features and a random 0/1 integer label.
batch_size = 8
X = torch.randn(128, 3)
y = torch.randint(0, 2, (128,))

# Pair each feature row with its label, then serve shuffled mini-batches
# of `batch_size` samples.
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [3]:
class DummyModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer mapping 3 features to 1 logit.

    Trained with binary cross-entropy on logits; exists only to generate
    metrics for the experiment-tracking loggers.
    """

    def __init__(self):
        super().__init__()
        # 3 input features -> a single logit per sample.
        self.layer = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        # Flatten (batch, 1) logits to (batch,) to line up with the labels.
        logits = self(inputs).flatten()
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        self.log('train_loss', loss)
        # Extra demo metrics (the batch index and a constant), presumably just
        # to show `log_dict` output in each tracker.
        self.log_dict({"batch_idx": batch_idx, "random_metric": 123})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    

In [4]:
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class WandbLoginCallback(pl.Callback):
    """Lightning callback that runs `wandb.login` during the Trainer's `setup` hook.

    The login happens inside the training process rather than on the driver.
    Presumably this is because Ray workers do not inherit the driver's W&B
    login state (confirm for your cluster setup).
    """

    def __init__(self, key):
        super().__init__()
        # W&B API key; may be None/empty, in which case `setup` skips the login.
        self.key = key

    def setup(self, trainer, pl_module, stage) -> None:
        if self.key:
            # Imported locally: the notebook never imports `wandb` at module
            # level, so the bare `wandb.login` call previously raised NameError.
            import wandb

            wandb.login(key=self.key)


def create_loggers(name, project_name, save_dir="./logs", offline=False):
    """Build the W&B, Comet, MLflow and TensorBoard loggers for one run.

    Args:
        name: Run name given to every logger.
        project_name: W&B project, Comet project, and MLflow experiment name.
        save_dir: Root directory; each logger writes under its own subdirectory
            (``<save_dir>/wandb``, ``/comet``, ``/mlflow``, ``/tensorboard``).
        offline: Forwarded to the W&B and Comet loggers. When True, no
            ``WandbLoginCallback`` is created.

    Returns:
        ``(loggers, callbacks)``: the loggers in the order
        [W&B, Comet, MLflow, TensorBoard], and the Trainer callbacks to register.

    Side effect:
        Sets the global ``rank_zero_only.rank`` to None.
    """
    # Clearing the global rank keeps the driver from starting its own
    # experiment run while the loggers are constructed here.
    rank_zero_only.rank = None

    def subdir(tool):
        return f"{save_dir}/{tool}"

    wandb_logger = WandbLogger(
        name=name,
        project=project_name,
        save_dir=subdir("wandb"),
        offline=offline,
    )
    if offline:
        callbacks = []
    else:
        wandb_key = os.environ.get("WANDB_API_KEY")
        callbacks = [WandbLoginCallback(key=wandb_key)]

    comet_logger = CometLogger(
        api_key=os.environ.get("COMET_API_KEY"),
        experiment_name=name,
        project_name=project_name,
        save_dir=subdir("comet"),
        offline=offline,
    )

    mlflow_logger = MLFlowLogger(
        run_name=name,
        experiment_name=project_name,
        save_dir=subdir("mlflow"),
    )

    tensorboard_logger = TensorBoardLogger(
        name=name,
        save_dir=subdir("tensorboard"),
    )

    loggers = [wandb_logger, comet_logger, mlflow_logger, tensorboard_logger]
    return loggers, callbacks


# Online configuration: W&B and Comet are created with offline=False, and a
# WandbLoginCallback is returned in `callbacks`. Without COMET_API_KEY, Comet
# still falls back to offline mode (see the cell output below).
loggers, callbacks = create_loggers(
    name="demo-run",
    project_name="demo-project",
    offline=False,
)


CometLogger will be initialized in offline mode


In [None]:
# FOR SMOKE TEST: rebuild the loggers with offline=True. W&B and Comet are
# created in offline mode and no W&B login callback is registered.
loggers, callbacks = create_loggers(
    name="demo-run",
    project_name="demo-project",
    offline=True,
)

In [6]:
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

# Describe the Lightning run: the module class to instantiate, the
# pl.Trainer arguments (including the loggers and callbacks created above),
# and the fit arguments (presumably forwarded to Trainer.fit).
builder = LightningConfigBuilder()
builder.module(cls=DummyModel)
builder.trainer(
    max_epochs=5,
    accelerator="cpu",
    logger=loggers,
    callbacks=callbacks,
)
builder.fit_params(train_dataloaders=dataloader)
lightning_config = builder.build()


In [7]:
from ray.air.config import RunConfig, ScalingConfig

# Four CPU-only training workers.
scaling_config = ScalingConfig(num_workers=4, use_gpu=False)

# Experiment name and the root directory Ray stores run results under.
run_config = RunConfig(name="ptl-exp-tracking", storage_path="/tmp/ray_results")

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

trainer.fit()
