PrintReportNotebook does not seem to work properly #820

Open
yor-dev opened this issue May 29, 2024 · 2 comments · May be fixed by #821

yor-dev commented May 29, 2024

I may have found a bug, so I am reporting it. If I am using the library incorrectly, please let me know.

I tried running the MNIST example at https://github.com/pfnet/pytorch-pfn-extras/blob/master/example/mnist.py in a Jupyter Notebook, replacing only the command-line arguments with dummy arguments. However, I got the following error:

KeyError: "None of [Index(['epoch', 'iteration', 'train/loss', 'lr', 'model/fc2.bias/grad/min',\n 'val/loss', 'val/acc'],\n dtype='object')] are in the [columns]"
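The text of this KeyError matches what pandas raises when a DataFrame is indexed with a list of column labels it does not contain. A plausible reading (an assumption, not confirmed in this thread) is that the notebook printer turns the LogReport log into a table and selects the requested entries before any row has been logged, so the table has no columns yet. The plain-Python sketch below models only that failure; `select_columns` and the list-of-dicts log are hypothetical stand-ins, not pytorch-pfn-extras code.

```python
from typing import Dict, List, Optional

# Hypothetical stand-in for a LogReport log: one dict per logged interval.
Log = List[Dict[str, float]]


def select_columns(log: Log, entries: List[str]) -> List[List[Optional[float]]]:
    """Mimic DataFrame column selection: every requested column must exist."""
    # The available columns are the union of keys seen in any logged row.
    columns = {key for row in log for key in row}
    missing = [e for e in entries if e not in columns]
    if missing:
        # pandas also raises KeyError here (with a different message).
        raise KeyError(f"missing columns: {missing}")
    # Use None for cells that a particular row did not report.
    return [[row.get(e) for e in entries] for row in log]


entries = ["epoch", "iteration", "train/loss", "val/loss"]

# Before the first epoch ends, the log is still empty, so no column exists.
try:
    select_columns([], entries)
except KeyError as err:
    print("KeyError:", err)

# After one row has been logged, the same selection succeeds.
# Placeholder values, not real training results.
log = [{"epoch": 1, "iteration": 938, "train/loss": 0.5, "val/loss": 0.4}]
print(select_columns(log, entries))
```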

Here is the code I ran in the notebook. The only difference from the original is how the command-line arguments are handled.

import pytorch_pfn_extras as ppe
import pytorch_pfn_extras.training.extensions as extensions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        ppe.nn.ensure(x, shape=(None, 10))
        return F.log_softmax(x, dim=1)


def train(manager, args, model, device, train_loader):
    while not manager.stop_trigger:
        model.train()
        for _, (data, target) in enumerate(train_loader):
            with manager.run_iteration(step_optimizers=["main"]):
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = F.nll_loss(output, target)
                ppe.reporting.report({"train/loss": loss.item()})
                loss.backward()


def test(args, model, device, data, target):
    """The extension loops over the iterator in order to
    drive the evaluator progress bar and reporting
    averages
    """
    model.eval()
    data, target = data.to(device), target.to(device)
    output = model(data)
    # Final result will be average of averages of the same size
    test_loss = F.nll_loss(output, target, reduction="mean").item()
    ppe.reporting.report({"val/loss": test_loss})
    pred = output.argmax(dim=1, keepdim=True)

    correct = pred.eq(target.view_as(pred)).sum().item()
    ppe.reporting.report({"val/acc": correct / len(data)})


def main():
    # Training settings

    class DummyArgs:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    args = DummyArgs(
        batch_size=64, test_batch_size=1000, epochs=10, lr=0.01, momentum=0.5,
        cuda=True, seed=1, save_model=False, snapshot=None, slack=None,
    )

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,  # type: ignore[arg-type]
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,  # type: ignore[arg-type]
    )

    model = Net()
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # manager.extend(...) also works
    my_extensions = [
        extensions.LogReport(),
        # Enables TensorBoard support.
        # Run `tensorboard --logdir runs` to launch the TensorBoard.
        extensions.LogReport(
            writer=ppe.writing.TensorBoardWriter(out_dir="runs"),
            trigger=(1, "iteration"),
        ),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            test_loader,
            model,
            eval_func=lambda data, target: test(args, model, device, data, target),
            progress_bar=True,
        ),
        extensions.PlotReport(["train/loss", "val/loss"], "epoch", filename="loss.png"),
        extensions.PrintReport(
            [
                "epoch",
                "iteration",
                "train/loss",
                "lr",
                "model/fc2.bias/grad/min",
                "val/loss",
                "val/acc",
            ]
        ),
        extensions.snapshot(),
    ]

    if args.slack is not None:
        my_extensions.append(
            extensions.Slack(
                channel=args.slack,
                msg="Epoch #{manager.epoch}: val/loss = {val/loss}",
                # Surround the username with <> to mention.
                end_msg="{default}\n<@your_slack_user_name>",
                # Upload any artifacts generated during the training.
                filenames=["result/statistics.png"],
                # You can specify when to upload these files.
                # e.g., only at the final epoch:
                # upload_trigger=(args.epochs, 'epoch'),
            )
        )

    # Custom stop triggers can be added to the manager and
    # their status accessed through `manager.stop_trigger`
    trigger = None
    # trigger = ppe.training.triggers.EarlyStoppingTrigger(
    #     check_trigger=(1, 'epoch'), monitor='val/loss')
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=my_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger,
    )
    # Let's load the snapshot
    if args.snapshot is not None:
        state = torch.load(args.snapshot)
        manager.load_state_dict(state)
    train(manager, args, model, device, train_loader)
    # Test function is called from the evaluator extension
    # to get access to the reporter and other facilities
    # test(args, model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    main()

@linshokaku (Member)

Thank you for your report. We have confirmed on our end that PrintReport does not work properly in Jupyter Notebook.
As a workaround, PrintReport works if you wrap it in an ExtensionEntry with an explicit trigger, as follows:

        ppe.training.ExtensionEntry(
            extensions.PrintReport(
                [
                    "epoch",
                    "iteration",
                    "train/loss",
                    "lr",
                    "model/fc2.bias/grad/min",
                    "val/loss",
                    "val/acc",
                ]
            ),
            trigger=(1, "epoch")
        ),
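The effect of the trigger can be sketched with a small timing simulation. It assumes that the default `LogReport()` appends one row per epoch, that without an explicit trigger the printer fires every iteration, and that the logger runs before the printer within an iteration; `IntervalTrigger`, `run` and `ITERS_PER_EPOCH` are hypothetical names, not the pytorch-pfn-extras implementation. With a per-iteration trigger the printer fires while the log is still empty, whereas with `(1, "epoch")` it fires only at epoch boundaries, after a row exists.

```python
from typing import Dict, List, Tuple

ITERS_PER_EPOCH = 3  # hypothetical: tiny epochs keep the simulation short


class IntervalTrigger:
    """Fires every `period` units, where the unit is "iteration" or "epoch"."""

    def __init__(self, period: int, unit: str) -> None:
        self.period, self.unit = period, unit

    def __call__(self, iteration: int) -> bool:
        if self.unit == "iteration":
            return iteration % self.period == 0
        # An epoch boundary is reached every ITERS_PER_EPOCH iterations.
        return iteration % (self.period * ITERS_PER_EPOCH) == 0


def run(printer_trigger: IntervalTrigger, epochs: int = 2) -> Tuple[int, int]:
    """Return (renders with data, renders that hit an empty log)."""
    log: List[Dict[str, int]] = []
    log_trigger = IntervalTrigger(1, "epoch")  # assumed per-epoch logging
    rendered = failed = 0
    for iteration in range(1, epochs * ITERS_PER_EPOCH + 1):
        # Assumed ordering: the logger runs before the printer.
        if log_trigger(iteration):
            log.append({"epoch": iteration // ITERS_PER_EPOCH, "iteration": iteration})
        if printer_trigger(iteration):
            if log:
                rendered += 1
            else:
                failed += 1  # selecting columns here would raise KeyError
    return rendered, failed


print("per iteration:", run(IntervalTrigger(1, "iteration")))
print("per epoch:", run(IntervalTrigger(1, "epoch")))
```

In the real extension the first render against an empty log raises and stops training; the simulation counts such renders instead so the two trigger settings can be compared side by side.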

However, the problem you reported still needs a proper fix. We will open a separate PR for it, so please keep this issue open until the fix is complete.

linshokaku linked a pull request on Jun 3, 2024 that will close this issue

yor-dev (Author) commented Jun 9, 2024

Thank you for the workaround. It works in my environment, too.
