Wandb logger #926
@@ -0,0 +1,191 @@
""" | ||
MNIST example with training and validation monitoring using Weights & Biases | ||
|
||
Requirements: | ||
Weights & Biases: `pip install wandb` | ||
|
||
Usage: | ||
|
||
Make sure you are logged into Weights & Biases (use the `wandb` command). | ||
|
||
Run the example: | ||
```bash | ||
python mnist_with_wandb_logger.py | ||
``` | ||
|
||
Go to https://wandb.com and explore your experiment. | ||
""" | ||
import sys
from argparse import ArgumentParser
import logging

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint

from ignite.contrib.handlers.wandb_logger import *

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(
        MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True
    )

    val_loader = DataLoader(
        MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False
    )
    return train_loader, val_loader

def run(train_batch_size, val_batch_size, epochs, lr, momentum):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

    if sys.version_info > (3,):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, this example logs GPU information (used memory, utilization). "
                "Since the pynvml python package is not installed, GPU information won't be logged. "
                "To enable it, install pynvml: `pip install pynvml`"
            )

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    wandb_logger = WandBLogger(
        project="pytorch-ignite-integration",
        name="ignite-mnist-example",
        config={
            "train_batch_size": train_batch_size,
            "val_batch_size": val_batch_size,
            "epochs": epochs,
            "lr": lr,
            "momentum": momentum,
        },
    )

    def iteration(engine):
Review discussion on `iteration`:
- We can use …
- The problem is that …
- OK, I see what you mean. So we cannot log any epoch number at all if we log iterations? Anyway, we can do that: `global_step_transform=lambda _, _: trainer.state.iteration`
- Exactly, at least not directly. The workaround (as described in the W&B docs) is to log epochs as a "metric" and then select it as the X-axis in the W&B web interface. Not sure how to facilitate this in the Logger here. Thanks for the simpler …
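A rough sketch of the two ideas from this thread, reusing `trainer`, `wandb_logger`, `OutputHandler` and `Events` from the example below; the `log_epoch` handler name is illustrative only. Note that `lambda _, _: ...` as written is not valid Python, since a lambda cannot repeat a parameter name, so the arguments need distinct names:

```python
# 1) Simpler global_step_transform via a lambda instead of the iteration() helper:
wandb_logger.attach(
    trainer,
    log_handler=OutputHandler(
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
        # same value as iteration(trainer) returns, but without the wrapper factory
        global_step_transform=lambda engine, event_name: trainer.state.iteration,
    ),
    event_name=Events.ITERATION_COMPLETED(every=100),
)

# 2) W&B-docs workaround for an epoch-based X-axis: log the epoch as an ordinary
#    metric, then select "epoch" as the X-axis in the W&B web interface.
import wandb

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_epoch(engine):
    wandb.log({"epoch": engine.state.epoch}, step=engine.state.iteration)
```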
        def wrapper(_, event_name):
            return engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)

        return wrapper

    wandb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all",
            global_step_transform=iteration(trainer),
        ),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )

    wandb_logger.attach(
        train_evaluator,
        log_handler=OutputHandler(
            tag="training", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer)
        ),
        event_name=Events.EPOCH_COMPLETED,
    )

    wandb_logger.attach(
        validation_evaluator,
        log_handler=OutputHandler(
            tag="validation", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer)
        ),
        event_name=Events.EPOCH_COMPLETED,
    )

    wandb_logger.attach(
        trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100)
    )
    wandb_logger.watch(model, log="all")

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        wandb_logger.run.dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=iteration(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
    wandb_logger.close()

if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training (default: 64)") | ||
parser.add_argument( | ||
"--val_batch_size", type=int, default=1000, help="input batch size for validation (default: 1000)" | ||
) | ||
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train (default: 10)") | ||
parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)") | ||
parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum (default: 0.5)") | ||
|
||
args = parser.parse_args() | ||
|
||
# Setup engine logger | ||
logger = logging.getLogger("ignite.engine.engine.Engine") | ||
handler = logging.StreamHandler() | ||
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s") | ||
handler.setFormatter(formatter) | ||
logger.addHandler(handler) | ||
logger.setLevel(logging.INFO) | ||
|
||
run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum) |
Review discussion:
- Maybe we can remove the GPU logging, as W&B does it by itself.
- Good point, will fix.
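If that change is made, the affected part of `run()` would reduce to roughly the following sketch, assuming W&B's own system-metrics tracking (which records GPU utilization and memory for an active run) is considered sufficient:

```python
# Trainer setup without the manual GPU logging: the `if sys.version_info > (3,):`
# / GpuInfo / pynvml block is dropped entirely, since the W&B run created by
# WandBLogger records GPU utilization and memory as system metrics on its own.
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}
```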