
How do you use the new Pytorch Profiler with Ignite #1916

Closed
ryanwongsa opened this issue Apr 8, 2021 · 9 comments

@ryanwongsa
Contributor

❓ Questions/Help/Support

Recently PyTorch announced a new profiling tool. Is there a way to use it with Ignite? The example code given in the blog post looks like:

with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    with_stack=True
) as profiler:
    for step, data in enumerate(trainloader, 0):
        print("step:{}".format(step))
        inputs, labels = data[0].to(device=device), data[1].to(device=device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        profiler.step()

But the issue is that we don't have access to the main for loop with Ignite. Is there an easy way to get it to work with Ignite? I am assuming there should be a way using the event handlers (maybe Events.EPOCH_STARTED)?

@ryanwongsa
Contributor Author

ryanwongsa commented Apr 8, 2021

Never mind, I see: I need to add an event handler for Events.EPOCH_STARTED with the following:

profiler = torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    with_stack=True
)

and then add to the update step:

profiler.step()

@sdesrozis
Contributor

@ryanwongsa Thank you for your feedback about this new profiling tool. In your opinion, does it make sense to provide a specific handler for this, as you did?

@ryanwongsa
Contributor Author

ryanwongsa commented Apr 8, 2021

@sdesrozis Yeah so I did something like:

import torch
from ignite.engine import Events

profiler = None

def init_profiler(engine):
    global profiler
    profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(
            wait=2,
            warmup=2,
            active=6,
            repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name="profiler")
    )

profiler_handler = trainer.add_event_handler(Events.EPOCH_STARTED, init_profiler)

and added profiler.step() in the training update step.

It works for now, but using a global profiler variable is probably not the best solution; I am not sure what the alternatives are.
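
One alternative, sketched below (untested): keep the profiler in a closure and attach the per-iteration step as its own handler, so neither a global nor a modified update step is needed. The attach_profiler helper name is made up:

import torch
from ignite.engine import Events

def attach_profiler(trainer, dir_name="profiler"):
    # Hypothetical helper: the profiler lives in this closure, not in a global.
    profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name=dir_name),
    )
    # Advance the profiling schedule once per training iteration.
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: profiler.step())
    return profiler

This has the same caveat as the global version: the profiler context is never actually entered.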

@sdesrozis
Contributor

@ryanwongsa Thank you!

@schuhschuh
Contributor

schuhschuh commented Aug 5, 2021

Would it make sense, and is it possible, to add the profiler to engine.state?

I think you can do

def init_profiler(engine):
    engine.state.profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(
            wait=2,
            warmup=2,
            active=6,
            repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name="profiler")
    ) 

trainer.add_event_handler(Events.EPOCH_STARTED, init_profiler)

And in the process_function of the engine:

def process_function(engine: Engine, batch):
    # ... forward/backward/optimizer steps producing `output` ...
    engine.state.profiler.step()
    return output

But otherwise, I am surprised this works without entering the context manager; some things happen in __enter__:
https://github.com/pytorch/pytorch/blob/5c431981b5b36da6dba61f0e5d5101e72d2fd726/torch/autograd/profiler.py#L170-L177

It may also be necessary to have the __exit__ routine called at the end:
https://github.com/pytorch/pytorch/blob/5c431981b5b36da6dba61f0e5d5101e72d2fd726/torch/autograd/profiler.py#L187-L200

The following might be a better solution (if it works; I haven't tried it):

with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name="profiler")
) as profiler:

    def process_function(engine: Engine, batch):
        ...
        profiler.step()
        return output

    trainer = ignite.engine.Engine(process_function)

    trainer.run(...)
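
Alternatively, if torch.profiler.profile exposes start() and stop() methods (recent PyTorch versions do; they wrap the same enter/exit logic), the profiler lifetime could be driven entirely from event handlers, as in this untested sketch:

import torch
from ignite.engine import Events

profiler = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name="profiler"),
)

# start()/stop() perform the same initialization/finalization as __enter__/__exit__.
trainer.add_event_handler(Events.STARTED, lambda engine: profiler.start())
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: profiler.step())
trainer.add_event_handler(Events.COMPLETED, lambda engine: profiler.stop())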

@vfdev-5
Collaborator

vfdev-5 commented Aug 6, 2021

cc @Priyansi for visibility and tutorials

@arisliang

Does profiler.step() work the same if it's put in an Events.GET_BATCH_COMPLETED handler?

@vfdev-5
Collaborator

vfdev-5 commented Oct 27, 2021

@arisliang It would work if everything is inside the with torch.profiler.profile(...) context manager.
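
For example, something like this sketch (train_loader and trainer are assumed to be defined already):

import torch
from ignite.engine import Events

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=2, warmup=2, active=6, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name="profiler"),
) as profiler:

    @trainer.on(Events.GET_BATCH_COMPLETED)
    def step_profiler(engine):
        # Note: this advances the schedule after each batch is fetched, so step
        # boundaries differ slightly from stepping inside the update function.
        profiler.step()

    trainer.run(train_loader, max_epochs=2)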

@arisliang

@vfdev-5 thanks for the prompt reply!
