In [51]:
import torch
from torch import nn
import numpy as np

class LearnedDropoutFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, A, B):
        ctx.save_for_backward(x, A, B)
        dropout_mask = 0.5 * torch.cos(A * x + B) + 0.5
        return dropout_mask

    @staticmethod
    def backward(ctx, grad_output):
        x, A, B = ctx.saved_tensors
        base_gradient = -0.5 * torch.sin(A* x + B)
        x_grad_wrt_dropout_mask = base_gradient * A
        x_grad_wrt_dropout_mask = x_grad_wrt_dropout_mask % (np.pi / A)

        grad_x = grad_output * x_grad_wrt_dropout_mask
        grad_A = base_gradient * x
        grad_B = base_gradient
        return grad_x, grad_A, grad_B

class LearnedDropout(nn.Module):
    def __init__(self, dim_in):
        super(LearnedDropout, self).__init__()
        self.A = nn.Parameter(
            torch.normal(
                1000000000,
                0,
                size=(dim_in,),
            )
        )
        self.B = nn.Parameter(
            torch.normal(
                0,
                0,
                size=(dim_in,),
            )
        )

    def forward(self, x):
        dropout_mask = LearnedDropoutFunction.apply(x, self.A, self.B)
        return x * dropout_mask
    

LD = LearnedDropout(1)
x = torch.tensor([2.0], requires_grad=True)
loss = LD(x)
loss = loss
print(loss)
loss.backward()


tensor([1.4041], grad_fn=<MulBackward0>)


In [52]:
x.grad, LD.A.grad, LD.B.grad

(tensor([0.7021]), tensor([-0.9147]), tensor([-0.4574]))