Skip to content

Commit

Permalink
add requirements, update ot
Browse files Browse the repository at this point in the history
  • Loading branch information
ruixv committed Dec 17, 2021
1 parent efe9052 commit 1df711a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
9 changes: 9 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
dominate==2.4.0
geomloss==0.2.4
matplotlib==2.2.2
numpy==1.19.5
Pillow==8.4.0
POT==0.8.0
torch==1.7.1
torchvision==0.8.2
visdom==0.1.8.9
20 changes: 11 additions & 9 deletions step2/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ def get_l1loss(self,latent_i,latent_t):
return output

def get_otloss(self, latent_i,latent_t):
loss_ot = 0
pi_loss = nn.L1Loss(reduction='sum')
batchsize = latent_i.shape[0]
if batchsize != latent_t.shape[0]:
raise ValueError('The length of the two latent codes must be the same.')
for ii in range(batchsize):
for jj in range(batchsize):
M = torch.zeros(batchsize,batchsize)
M_metric = nn.L1Loss(reduction='sum')
for ii in range(0, batchsize):
for jj in range(0, batchsize):
if ii == jj:
c_loss = 1
M[ii,jj] = M_metric(latent_i[ii], latent_t[jj].detach())
else:
c_loss = 0
loss_ot = loss_ot + c_loss * pi_loss(latent_i[ii],latent_t[jj].detach())
M[ii,jj] = 10e10
aa = torch.ones([batchsize, ])
bb = torch.ones((batchsize, ))
# require ot.__version__ >= 0.8.0
gamma = ot.emd(aa, bb, M)
loss_ot = torch.sum(gamma*M).cuda()
return loss_ot


Expand Down

0 comments on commit 1df711a

Please sign in to comment.