Wandb logger #926

Merged: 14 commits, Apr 22, 2020
191 changes: 191 additions & 0 deletions examples/contrib/mnist/mnist_with_wandb_logger.py
@@ -0,0 +1,191 @@
"""
MNIST example with training and validation monitoring using Weights & Biases

Requirements:
Weights & Biases: `pip install wandb`

Usage:

Make sure you are logged into Weights & Biases (run `wandb login`).

Run the example:
```bash
python mnist_with_wandb_logger.py
```

Go to https://wandb.com and explore your experiment.
"""
import sys
from argparse import ArgumentParser
import logging

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint

from ignite.contrib.handlers.wandb_logger import *


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=-1)


def get_data_loaders(train_batch_size, val_batch_size):
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True
)

val_loader = DataLoader(
MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False
)
return train_loader, val_loader


def run(train_batch_size, val_batch_size, epochs, lr, momentum):
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
model = Net()
device = "cpu"

if torch.cuda.is_available():
device = "cuda"

optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

if sys.version_info > (3,):
Collaborator: Maybe we can remove GPU logging, as W&B does it by itself.

Contributor Author: Good point, will fix.

from ignite.contrib.metrics.gpu_info import GpuInfo

try:
GpuInfo().attach(trainer)
except RuntimeError:
print(
"INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
"As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
"install it : `pip install pynvml`"
)

metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

@trainer.on(Events.EPOCH_COMPLETED)
def compute_metrics(engine):
train_evaluator.run(train_loader)
validation_evaluator.run(val_loader)

wandb_logger = WandBLogger(
project="pytorch-ignite-integration",
name="ignite-mnist-example",
config={
"train_batch_size": train_batch_size,
"val_batch_size": val_batch_size,
"epochs": epochs,
"lr": lr,
"momentum": momentum,
},
)

def iteration(engine):
Collaborator: We can use global_step_from_engine instead of iteration:
global_step_transform=global_step_from_engine(trainer)

Contributor Author: The problem is that global_step_from_engine will use the iteration number for handlers called on Events.ITERATION_COMPLETED, but the epoch number for handlers called on Events.EPOCH_COMPLETED. W&B does not allow logging events with a smaller step than previously logged, so I need to make sure to always use the iteration number.
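To make the ordering problem concrete, here is a minimal sketch of the clash (not part of this PR; the project name, iteration count, and metric values are made up):

```python
# Sketch only: why mixing epoch-based and iteration-based steps breaks W&B logging.
import wandb

run = wandb.init(project="step-ordering-demo")  # hypothetical throwaway project

iterations_per_epoch = 938  # e.g. ~60000 MNIST samples / batch size 64

# A handler attached on Events.ITERATION_COMPLETED logs with the iteration count:
wandb.log({"training/batchloss": 0.42}, step=iterations_per_epoch)

# With global_step_from_engine(trainer), a handler on Events.EPOCH_COMPLETED
# would log with step = trainer.state.epoch = 1; W&B rejects this because the
# step is smaller than the one already logged.
wandb.log({"training/accuracy": 0.91}, step=1)

run.finish()
```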

Collaborator (@vfdev-5, Apr 19, 2020): OK, I see what you mean. So we cannot log any epoch number at all if we log iterations?

Anyway, we can do that iteration(engine) in a simpler way:

global_step_transform=lambda *_: trainer.state.iteration

Contributor Author: Exactly, at least not directly. The workaround (as described in the W&B docs) is to log epochs as a "metric" and then select it as the X-axis in the W&B web interface. Not sure how to facilitate this in the Logger here.

Thanks for the simpler global_step_transform, will add it!
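For completeness, a minimal sketch of that workaround (hypothetical, not part of this PR; it reuses the `trainer` and `Events` already defined in this example): log the epoch number as an ordinary metric while keeping `step` on the iteration counter, then choose `epoch` as the X-axis in the W&B UI.

```python
import wandb

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_as_metric(engine):
    # Log the epoch as a plain metric; keep `step` on the iteration counter
    # so the logged step stays monotonically increasing.
    wandb.log({"epoch": engine.state.epoch}, step=engine.state.iteration)
```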

def wrapper(_, event_name):
return engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)

return wrapper

wandb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="training",
output_transform=lambda loss: {"batchloss": loss},
metric_names="all",
global_step_transform=iteration(trainer),
),
event_name=Events.ITERATION_COMPLETED(every=100),
)

wandb_logger.attach(
train_evaluator,
log_handler=OutputHandler(
tag="training", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer)
),
event_name=Events.EPOCH_COMPLETED,
)

wandb_logger.attach(
validation_evaluator,
log_handler=OutputHandler(
tag="validation", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer)
),
event_name=Events.EPOCH_COMPLETED,
)

wandb_logger.attach(
trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100)
)
wandb_logger.watch(model, log="all")

def score_function(engine):
return engine.state.metrics["accuracy"]

model_checkpoint = ModelCheckpoint(
wandb_logger.run.dir,
n_saved=2,
filename_prefix="best",
score_function=score_function,
score_name="validation_accuracy",
global_step_transform=iteration(trainer),
)
validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

# kick everything off
trainer.run(train_loader, max_epochs=epochs)
wandb_logger.close()


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training (default: 64)")
parser.add_argument(
"--val_batch_size", type=int, default=1000, help="input batch size for validation (default: 1000)"
)
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train (default: 10)")
parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)")
parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum (default: 0.5)")

args = parser.parse_args()

# Setup engine logger
logger = logging.getLogger("ignite.engine.engine.Engine")
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum)
1 change: 1 addition & 0 deletions ignite/contrib/handlers/__init__.py
@@ -15,5 +15,6 @@
from ignite.contrib.handlers.visdom_logger import VisdomLogger
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.contrib.handlers.wandb_logger import WandBLogger
from ignite.contrib.handlers.base_logger import global_step_from_engine
from ignite.contrib.handlers.lr_finder import FastaiLRFinder