Skip to content

Commit

Permalink
Merge pull request #8 from szerintedmi/resolve_0-dim_tensor_warnings
Browse files Browse the repository at this point in the history
Resolve invalid index of a 0-dim tensor warnings
  • Loading branch information
xuebinqin committed Jan 24, 2021
2 parents 527af34 + 7c62623 commit 3ef086a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 4 additions & 1 deletion u2net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,12 @@ def main():
elif(model_name=='u2netp'):
print("...load U2NEP---4.7 MB")
net = U2NETP(3,1)
net.load_state_dict(torch.load(model_dir))

if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
net.cuda()
else:
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()

# --------- 4. inference for each image ---------
Expand Down
6 changes: 3 additions & 3 deletions u2net_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
loss6 = bce_loss(d6,labels_v)

loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))

return loss0, loss

Expand Down Expand Up @@ -144,8 +144,8 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
optimizer.step()

# # print statistics
running_loss += loss.data[0]
running_tar_loss += loss2.data[0]
running_loss += loss.data.item()
running_tar_loss += loss2.data.item()

# del temporary outputs and loss
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
Expand Down

0 comments on commit 3ef086a

Please sign in to comment.