-
Notifications
You must be signed in to change notification settings - Fork 3
/
functions.py
37 lines (27 loc) · 1.07 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch.autograd import Function, grad
class ReverseLayerF(Function):
    """Gradient-reversal layer (Ganin & Lempitsky, DANN).

    Acts as the identity in the forward pass; in the backward pass it
    multiplies the incoming gradient by ``-alpha``, reversing its sign.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for the backward pass.
        ctx.alpha = alpha
        # Identity mapping; view_as keeps the result tied into autograd.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; alpha itself receives no gradient.
        return grad_output.neg() * ctx.alpha, None
def gradient_penalty(critic, h_s, h_t):
    """WGAN-style gradient penalty between source and target features.

    based on: https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L116

    Args:
        critic: callable mapping a feature tensor to critic scores.
        h_s: source features — assumes shape (batch, dim); TODO confirm.
        h_t: target features, same shape as ``h_s``.

    Returns:
        Scalar tensor: mean squared deviation of the gradient norm from 1.
    """
    # Draw interpolation coefficients on the same device as the inputs.
    # (Previously hard-coded ``.cuda()``, which crashed on CPU-only machines
    # and ignored the device the features actually live on.)
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    differences = h_t - h_s
    interpolates = h_s + (alpha * differences)
    # Evaluate the penalty at the random interpolates and at both endpoints.
    interpolates = torch.stack([interpolates, h_s, h_t]).requires_grad_()
    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    # NOTE(review): after the stack, dim=1 is the batch dimension, so the
    # 2-norm is taken across the batch rather than per sample — preserved
    # as-is; confirm against the training code this is intended.
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1)**2).mean()
    return gradient_penalty
def set_requires_grad(model, requires_grad=True):
    """Enable or freeze gradient tracking for every parameter of *model*."""
    for parameter in model.parameters():
        parameter.requires_grad = requires_grad