modify GANloss
xinntao committed Apr 21, 2018
1 parent 231c96f commit 5448d2b
Showing 1 changed file with 10 additions and 13 deletions.
codes/models/modules/loss.py
@@ -3,25 +3,24 @@
 from torch.autograd import Variable
 
 
-# Define GAN loss: [vanilla | lsgan | wgan]
+# Define GAN loss: [vanilla | lsgan | wgan-gp]
 class GANLoss(nn.Module):
-    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, tensor=torch.FloatTensor):
         super(GANLoss, self).__init__()
         self.gan_type = gan_type.lower()
         self.real_label_val = real_label_val
         self.fake_label_val = fake_label_val
+        self.Tensor = tensor
 
         if self.gan_type in ['vanilla', 'lsgan']:
-            self.register_buffer('real_label', torch.Tensor())
-            self.register_buffer('fake_label', torch.Tensor())
-
-            # print(type(real_label), type(fake_label))
+            self.register_buffer('real_label', self.Tensor())
+            self.register_buffer('fake_label', self.Tensor())
 
         if self.gan_type == 'vanilla':
             self.loss = nn.BCEWithLogitsLoss()
         elif self.gan_type == 'lsgan':
             self.loss = nn.MSELoss()
-        elif self.gan_type == 'wgan':
+        elif self.gan_type == 'wgan-gp':
             def wgan_loss(input, target):
                 # target is boolean
                 return -1 * input.mean() if target else input.mean()
@@ -30,22 +29,20 @@ def wgan_loss(input, target):
             raise NotImplementedError('GAN type [%s] is not found' % self.gan_type)
 
     def get_target_label(self, input, target_is_real):
-        if self.gan_type == 'wgan':
+        if self.gan_type == 'wgan-gp':
             return target_is_real
         if target_is_real:
             if self.real_label.size() != input.size():  # check if new label needed
                 self.real_label.resize_(input.size()).fill_(self.real_label_val)
-            return Variable(self.real_label)
+            return Variable(self.real_label, requires_grad=False)
         else:
             if self.fake_label.size() != input.size():  # check if new label needed
                 self.fake_label.resize_(input.size()).fill_(self.fake_label_val)
-            return Variable(self.fake_label)
+            return Variable(self.fake_label, requires_grad=False)
 
     def forward(self, input, target_is_real):
-        # start_time = time.time()
         target_label = self.get_target_label(input, target_is_real)
         loss = self.loss(input, target_label)
-        # print('GANLoss time:', time.time() - start_time)
         return loss
 
 
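For context, a minimal usage sketch of the class above under the renamed 'wgan-gp' type. This is not part of the commit: the import path and the random stand-in tensors are assumptions, and real training code would feed actual discriminator outputs.

import torch
from torch.autograd import Variable

# Assumed import path for this repo layout (codes/models/modules/loss.py).
from models.modules.loss import GANLoss

cri_gan = GANLoss('wgan-gp', real_label_val=1.0, fake_label_val=0.0)

pred_real = Variable(torch.randn(4, 1))  # stand-in for D(real)
pred_fake = Variable(torch.randn(4, 1))  # stand-in for D(fake)

# For 'wgan-gp', get_target_label returns the boolean itself, and
# wgan_loss flips the sign for real samples:
l_d_real = cri_gan(pred_real, True)   # equals -pred_real.mean()
l_d_fake = cri_gan(pred_fake, False)  # equals pred_fake.mean()
l_d_total = l_d_real + l_d_fake

For 'vanilla' and 'lsgan' the same two calls would instead build label tensors from the registered buffers and apply BCEWithLogitsLoss or MSELoss, which is why the returned Variable now carries requires_grad=False.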
@@ -67,4 +64,4 @@ def forward(self, interp, interp_crit):
         grad_interp_norm = grad_interp.norm(2, dim=1)
         # print(grad_interp_norm)
         loss = ((grad_interp_norm - 1) ** 2).mean()
-        return loss
+        return loss
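This last hunk touches GradientPenaltyLoss, whose forward penalizes the critic's gradient norm on interpolated samples, loss = ((||grad||_2 - 1) ** 2).mean(), as in WGAN-GP (Gulrajani et al., 2017). A sketch of the wiring it expects; the tiny critic and the no-argument constructor here are assumptions, not code from this commit:

import torch
import torch.nn as nn
from torch.autograd import Variable

from models.modules.loss import GradientPenaltyLoss  # assumed import path

netD = nn.Sequential(nn.Linear(16, 1))  # stand-in critic, not the project's model

real = torch.rand(4, 16)
fake = torch.rand(4, 16)

# Random interpolation between real and fake samples, with gradients enabled
# so the penalty can differentiate the critic output w.r.t. its input.
alpha = torch.rand(real.size(0), 1).expand_as(real)
interp = Variable(alpha * real + (1 - alpha) * fake, requires_grad=True)
interp_crit = netD(interp)

cri_gp = GradientPenaltyLoss()  # assumes the default constructor suffices in this revision
l_d_gp = cri_gp(interp, interp_crit)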
