In [None]:
import torch
from torch import nn

class LearnedDropoutFunction(torch.autograd.Function):
    """Gate ``x`` with a smooth cosine mask, using a hand-written backward.

    Forward computes ``x * mask`` with ``mask = 0.5 * cos(A * x + B) + 0.5``,
    so every mask entry lies in ``[0, 1]``.

    Backward returns:

    * for ``x`` -- a deliberately customised gradient,
      ``0.5 * grad_output * mask``.  This is NOT the true derivative: it
      ignores the dependence of the mask on ``x`` and applies an extra 0.5
      scale (the original placeholder scaling rule).
    * for ``A`` and ``B`` -- the analytical gradients of ``x * mask``,
      summed down to each parameter's own shape so that broadcasting over
      leading (e.g. batch) dimensions is handled.
    """

    @staticmethod
    def forward(ctx, x, A, B):
        """Return ``x * (0.5 * cos(A * x + B) + 0.5)``.

        ``x``, ``A`` and ``B`` are saved on ``ctx`` because the backward pass
        recomputes the mask from them.
        """
        ctx.save_for_backward(x, A, B)
        dropout_mask = 0.5 * torch.cos(A * x + B) + 0.5
        return x * dropout_mask

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, grad_A, grad_B)`` for the inputs of ``forward``.

        Entries whose input does not require a gradient are returned as
        ``None`` and are not computed.
        """
        x, A, B = ctx.saved_tensors
        need_x, need_A, need_B = ctx.needs_input_grad
        phase = A * x + B

        grad_x = grad_A = grad_B = None

        if need_x:
            dropout_mask = 0.5 * torch.cos(phase) + 0.5
            # Custom gradient for x: treat the mask as a constant and apply
            # the example 0.5 scaling. Replace this rule with your own method
            # if a different scaling is needed.
            grad_x = (grad_output * dropout_mask * 0.5).sum_to_size(x.shape)

        if need_A or need_B:
            # d(x * mask)/d(phase) = -0.5 * x * sin(phase); the chain rule
            # then gives d(phase)/dB = 1 and d(phase)/dA = x.
            grad_phase = grad_output * x * (-0.5 * torch.sin(phase))
            if need_B:
                # Parameters may be broadcast against x, so reduce the
                # gradient back to the parameter's shape.
                grad_B = grad_phase.sum_to_size(B.shape)
            if need_A:
                grad_A = (grad_phase * x).sum_to_size(A.shape)

        return grad_x, grad_A, grad_B

class LearnedDropout(nn.Module):
    """Module wrapper that applies ``LearnedDropoutFunction`` to its input.

    Holds two trainable vectors, ``A`` and ``B``, each of length ``dim_in``
    and initialised from a standard normal distribution.
    """

    def __init__(self, dim_in):
        """Create the ``A`` and ``B`` parameters, each of shape ``(dim_in,)``."""
        super().__init__()
        # Registration order (A, then B) also fixes the order in which the
        # random initial values are drawn.
        for param_name in ("A", "B"):
            self.register_parameter(param_name, nn.Parameter(torch.randn(dim_in)))

    def forward(self, x):
        """Return ``LearnedDropoutFunction.apply(x, A, B)`` for input ``x``."""
        scale, shift = self.A, self.B
        return LearnedDropoutFunction.apply(x, scale, shift)
    

# Smoke test: push one random 10-element vector through the layer and run a
# backward pass.
x = LearnedDropout(10)(torch.randn(10))
# Tensor.backward() with no `gradient` argument only works on scalar tensors;
# the layer output has 10 elements, so reduce it to a scalar first.
x.sum().backward()