In [46]:
import torch
from torch import nn
import numpy as np
torch.set_printoptions(precision=10)

class LearnedDropoutFunction(torch.autograd.Function):
    """Custom autograd op that turns ``x`` into a cosine-shaped mask.

    Forward evaluates ``0.5 * cos(A * x + B) + 0.5`` elementwise, so every mask
    entry lies in ``[0, 1]``. Backward supplies hand-written gradients and
    prints its intermediate values for inspection.
    """

    @staticmethod
    def forward(ctx, x, A, B):
        """Compute the mask and stash ``x``, ``A`` and ``B`` for the backward pass."""
        ctx.save_for_backward(x, A, B)
        return 0.5 * torch.cos(A * x + B) + 0.5

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, grad_A, grad_B)`` for the incoming ``grad_output``.

        ``grad_A`` and ``grad_B`` follow the chain rule through
        ``phase = A * x + B``. ``grad_x`` multiplies the shared sine term by a
        fixed ``0.04`` instead of by ``A``, so it is not the analytic derivative.
        NOTE(review): presumably a deliberate surrogate to keep the input
        gradient bounded when ``A`` is very large -- confirm with the author.
        """
        print(f"grad_output: {grad_output}")
        x, A, B = ctx.saved_tensors

        # d(mask)/d(phase) = -0.5 * sin(phase); shared by all three gradients.
        phase = A * x + B
        dmask_dphase = -0.5 * torch.sin(phase)

        # Diagnostics: the saved inputs plus the local (pre-upstream) gradients.
        print(f"A: {A}")
        print(f"B: {B}")
        print(f"x: {x}")
        print(f"grad_A : {dmask_dphase * x}")
        print(f"grad_B : {dmask_dphase}")

        # Upstream gradient folded into the shared factor once and reused.
        upstream_scaled = grad_output * dmask_dphase
        surrogate_dx = dmask_dphase * 0.04

        grad_x = grad_output * surrogate_dx
        grad_A = upstream_scaled * x
        grad_B = upstream_scaled
        return grad_x, grad_A, grad_B

class LearnedDropout(nn.Module):
    """Scale the input by a learned cosine mask.

    The module owns two learnable vectors, ``A`` and ``B``, each of shape
    ``(dim_in,)``. ``forward`` obtains a mask from ``LearnedDropoutFunction``
    and returns ``x * mask``.

    Args:
        dim_in: Number of entries in ``A`` and in ``B``.
        a_init: Constant used to initialise every entry of ``A``
            (default ``1e9``).
        b_init: Constant used to initialise every entry of ``B``
            (default ``0.0``).
    """

    def __init__(self, dim_in, a_init=1e9, b_init=0.0):
        super().__init__()
        # Both parameters start as constants. torch.full states that directly
        # and, unlike a zero-std torch.normal, does not advance the RNG.
        # float(...) keeps the tensors floating-point even for int arguments,
        # which nn.Parameter needs in order to require gradients.
        self.A = nn.Parameter(torch.full((dim_in,), float(a_init)))
        self.B = nn.Parameter(torch.full((dim_in,), float(b_init)))

    def forward(self, x):
        """Return ``x`` multiplied elementwise by its mask; prints the mask."""
        dropout_mask = LearnedDropoutFunction.apply(x, self.A, self.B)
        print(f"dropout_mask: {dropout_mask}")
        return x * dropout_mask
    

# One forward/backward pass through a single-feature LearnedDropout layer.
LD = LearnedDropout(dim_in=1)
x = torch.tensor([0.001], requires_grad=True)

# The layer output, doubled, is the quantity that gets differentiated.
loss = LD(x) * 2
print(f"loss: {loss}")
loss.backward()


dropout_mask: tensor([0.9783917665], grad_fn=<LearnedDropoutFunctionBackward>)
loss: tensor([0.0019567837], grad_fn=<MulBackward0>)
grad_output: tensor([0.0020000001])
A: Parameter containing:
tensor([1.0000000000e+09], requires_grad=True)
B: Parameter containing:
tensor([0.], requires_grad=True)
x: tensor([0.0010000000], requires_grad=True)
grad_A : tensor([0.0001454006])
grad_B : tensor([0.1454006284])


In [47]:
# Show the gradients that loss.backward() stored on the input and on both parameters.
x.grad, LD.A.grad, LD.B.grad

(tensor([1.9567952156]), tensor([2.9080129593e-07]), tensor([0.0002908013]))

In [48]:
# Recompute -0.5 * sin(A * x) outside autograd with A = 1e9 and x = 0.001 (B = 0),
# to compare against the `grad_B` value printed inside backward().
# NOTE(review): `a` is created from a Python int whereas the layer's A is a float
# parameter; presumably type promotion makes the product identical -- confirm.
a = torch.tensor(1000000000)
y = torch.tensor(0.001)
print(-0.5 * torch.sin(a * y))

tensor(0.1454006284)
