Skip to content

Commit

Permalink
Cleanup of helpers_torch.py and ncf/pytorch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vatai committed Jul 7, 2020
1 parent 18d6979 commit 0c394df
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
11 changes: 6 additions & 5 deletions benchmarker/modules/problems/helpers_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class Net4Inference(nn.Module):
Expand All @@ -20,8 +22,9 @@ def __init__(self, net, criterion):

def __call__(self, x, t):
outs = self.net(x)
# TODO: figure this out. there's a reason why backward finction is returned
# precompiled? is it correct to ignore it?
# TODO: figure this out. there's a reason why backward
# finction is returned precompiled? is it correct to ignore
# it?
if isinstance(outs, OrderedDict):
outs = outs["out"]
loss = self.criterion(outs, t)
Expand All @@ -38,8 +41,6 @@ def Net4Both(params, net, inference, training):
class ClassifierInference(Net4Inference):
def __call__(self, x):
outs = self.net(x)
# TODO: figure this out. there's a reason why backward finction is returned
# precompiled? is it correct to ignore it?
if isinstance(outs, OrderedDict):
outs = outs["out"]
return F.softmax(outs, dim=-1)
Expand Down
1 change: 1 addition & 0 deletions benchmarker/modules/problems/ncf/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This class is from https://github.com/mlperf.
# https://github.com/mlperf/training/blob/master/recommendation/pytorch/ncf.py
import argparse

import numpy as np
Expand Down

0 comments on commit 0c394df

Please sign in to comment.