
Predicting multiple outputs #819

Closed
mashu opened this issue Mar 2, 2020 · 6 comments
mashu commented Mar 2, 2020

I have a model that predicts 4 distinct multi-label classes. I have 4 independent losses which are weighted and summed into a single loss, but also 4 distinct accuracies. It seems like the Ignite engine returns only one output (y_hat) and one target y. Is it possible to do this with multiple model outputs, and if so, how?

Thanks

@mashu mashu added the question label Mar 2, 2020

vfdev-5 commented Mar 2, 2020

@mashu you can define your own model update function and return anything you wish. Please see here. Maybe this will help:

def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss1 = criterion1(y_pred, y)
    loss2 = criterion2(y_pred, y)
    # ...
    total_loss = ...
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return {
        'loss1': loss1.item(),
        'loss2': loss2.item(),
        ...
    }

trainer = Engine(update)
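The `total_loss = ...` line above is deliberately left open. For the weighted-and-summed setup described in the question, one possible sketch is the helper below (the helper name and the example weights are hypothetical, not part of Ignite):

```python
# Hypothetical sketch of the elided `total_loss = ...` step: a weighted
# sum of the per-head losses. It only uses `*` and `sum`, so it works
# the same for plain floats and for torch tensors.
def combine_losses(losses, weights):
    """Return the weighted sum w_1*loss_1 + ... + w_n*loss_n."""
    assert len(losses) == len(weights)
    return sum(w * l for w, l in zip(weights, losses))

# In the update function this would become, e.g.:
# total_loss = combine_losses([loss1, loss2, loss3, loss4],
#                             weights=[0.4, 0.3, 0.2, 0.1])
```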


mashu commented Mar 2, 2020

In the docs https://pytorch.org/ignite/metrics.html#ignite.metrics.Accuracy
I see that

If the engine’s output is not in the format (y_pred, y) or {‘y_pred’: y_pred, ‘y’: y, …}, the user can use the output_transform argument to transform it:

That would imply I need to return very specific parameters, and it would support only a single accuracy metric, or am I missing something? How does the Accuracy() metric know which pair of parameters belongs to it?
Is there a more complete example somewhere with multiple accuracies being updated?


vfdev-5 commented Mar 2, 2020

Please provide an example/code snippet of what you are talking about; it would help to understand the problem.

IMO, it makes sense to compute accuracies on a "fixed" model, i.e. during the evaluation phase. You can do the following:

def inference(engine, batch):
    x, (y1, y2, y3, y4) = batch
    model.eval()
    with torch.no_grad():
        y_pred1, y_pred2, y_pred3, y_pred4 = model(x)

    return {
        'out1': (y_pred1, y1),
        'out2': (y_pred2, y2),
        'out3': (y_pred3, y3),
        'out4': (y_pred4, y4),
    }

evaluator = Engine(inference)

# Note: if we wanted to attach these in a for loop, plain lambdas would
# not work (Python's late binding would make every transform use the last key) ...
Accuracy(output_transform=lambda out: out['out1']).attach(evaluator, 'acc1')
Accuracy(output_transform=lambda out: out['out2']).attach(evaluator, 'acc2')
Accuracy(output_transform=lambda out: out['out3']).attach(evaluator, 'acc3')
Accuracy(output_transform=lambda out: out['out4']).attach(evaluator, 'acc4')

state = evaluator.run(val_loader)
assert 'acc1' in state.metrics
assert 'acc2' in state.metrics
assert 'acc3' in state.metrics
assert 'acc4' in state.metrics

This will compute 4 accuracies on the validation dataset.
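On the lambdas-in-a-loop caveat from the code comment: a `lambda out: out[key]` created inside a `for` loop captures `key` by reference, so every metric would end up reading the last key. A minimal sketch of a loop that avoids this, using `operator.itemgetter` (a default argument, `lambda out, k=key: out[k]`, works too — the loop body shown in the comment is a hypothetical adaptation of the attach calls above):

```python
from operator import itemgetter

# Build one output_transform per output key. itemgetter binds the key at
# creation time, so each transform keeps its own key even inside a loop;
# a bare `lambda out: out[key]` would resolve `key` only when called.
transforms = {}
for i in range(1, 5):
    key = f'out{i}'
    transforms[key] = itemgetter(key)

# Applied to the evaluator above, this would become (hypothetical):
# for i in range(1, 5):
#     Accuracy(output_transform=itemgetter(f'out{i}')).attach(evaluator, f'acc{i}')

# Demonstration with a fake engine output dict:
fake_output = {f'out{i}': (f'y_pred{i}', f'y{i}') for i in range(1, 5)}
```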


mashu commented Mar 2, 2020

Ok, so the solution is to return a (y_pred, y) tuple for each accuracy. That makes sense, thanks a lot!


vfdev-5 commented Mar 2, 2020

Yes, you can return anything you wish: a tuple, a dict, etc. You just need to adapt output_transform accordingly, so that Accuracy receives either a (y_pred, y) tuple or a dict with the specific keys.

I'll let you close the issue if it's solved for you. Thanks!


vfdev-5 commented Mar 3, 2020

@mashu I'm closing the issue. Feel free to reopen it if needed.

@vfdev-5 vfdev-5 closed this as completed Mar 3, 2020