In [111]:
import torch.nn as nn
import torch
from torch.nn import functional as F

# Seed PyTorch's global generator so the randomly initialised layer weights
# created further down are repeatable from run to run.
torch.manual_seed(0)
# Show six digits after the decimal point whenever a tensor is printed.
torch.set_printoptions(precision=6)

class LearnedDropout(nn.Module):
    """Scale the input by a soft, input-dependent mask with values in (0, 1).

    The input is projected to query/key/value with three bias-free linear
    layers. A channel-by-channel map ``q^T @ k`` is applied to ``v``, the
    result is soft-maxed over all positions of each sample, and every
    position is pushed through a steep sigmoid centred on the per-sample mean
    probability. Positions above the mean get a mask close to 1, positions
    below it a mask close to 0.

    Args:
        channel_dim: Size of the last dimension of the input; also the output
            size of each projection.
        sharpness: Slope of the sigmoid that turns probabilities into the
            mask. Defaults to 60, the value that was previously hard-coded.
        attn_scale: Factor applied to ``q^T @ k``. Defaults to 1.0, which is
            what the previous literal ``1**-0.5`` evaluated to.
        verbose: When True (default, matching the previous behaviour) every
            weight matrix and intermediate tensor is printed on each forward
            call. Pass False to silence the trace.
    """

    def __init__(self, channel_dim, sharpness=60.0, attn_scale=1.0, verbose=True):
        super().__init__()
        # Creation order is kept (query, key, value) so that a given RNG seed
        # produces the same initial weights as before.
        self.query = nn.Linear(channel_dim, channel_dim, bias=False)
        self.key = nn.Linear(channel_dim, channel_dim, bias=False)
        self.value = nn.Linear(channel_dim, channel_dim, bias=False)
        self.sharpness = sharpness
        self.attn_scale = attn_scale
        self.verbose = verbose

    def _log(self, label, value):
        """Print ``label: value`` when tracing is enabled."""
        # The tensor is only formatted inside the branch, so a silent module
        # pays nothing for the trace.
        if self.verbose:
            print(f"{label}: {value}")

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the learned mask.

        Args:
            x: Tensor whose last dimension equals ``channel_dim``. The
                flattening and the mean below treat dimension 0 as the batch
                and the last two dimensions as one sample, i.e. a 3-D
                ``(batch, tokens, channel_dim)`` input as in the notebook.

        Returns:
            A tensor with the same shape as ``x``.
        """
        self._log("query", self.query.weight)
        self._log("key", self.key.weight)
        self._log("value", self.value.weight)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        self._log("q", q)
        self._log("k", k)
        self._log("v", v)

        # Contracting over the token axis yields a (channel, channel) map per
        # sample rather than the usual (token, token) attention matrix.
        attn = (q.transpose(-2, -1) @ k) * self.attn_scale
        self._log("attn", attn)

        out = v @ attn
        self._log("out", out)

        # Softmax over every position of a sample at once, then restore the
        # original layout. The printed labels keep the historical name
        # "logits" although the values are probabilities.
        probs = F.softmax(out.flatten(1), dim=-1).view_as(out)
        self._log("logits", probs)

        # Each sample's probabilities sum to 1, so this mean is simply
        # 1 / (positions per sample); it acts as the sigmoid's midpoint.
        probs_mean = probs.mean(dim=(-1, -2), keepdim=True)
        self._log("logits_mean", probs_mean)

        # torch.sigmoid replaces the hand-written 1 / (1 + exp(-z)), which
        # can overflow in exp() for reduced-precision dtypes.
        dropout_mask = torch.sigmoid(
            probs * self.sharpness - self.sharpness * probs_mean
        )
        self._log("dropout_mask", dropout_mask)
        return x * dropout_mask

In [112]:
# Re-seed inside the cell so the layer weights do not depend on how many
# times, or in what order, other cells were executed. On a fresh
# top-to-bottom run this gives the same weights as seeding once up front,
# because nothing draws random numbers before the layers are created.
torch.manual_seed(0)

# Two samples, three rows each, four channels per row.
x = torch.tensor(
    [
        [[1, -2, 3, 1], [2, -3, -1, -1], [0, -1, 2, 3]],
        [[2, -3, -1, -1], [0, -1, 2, 3], [4, -4, 2, 0]],
    ],
    dtype=torch.float32,
)
ld = LearnedDropout(4)
# Kept as the last expression so the notebook displays the masked tensor.
ld(x)

query: Parameter containing:
tensor([[-0.003743,  0.268222, -0.411523, -0.367970],
        [-0.192577,  0.134079, -0.009907,  0.396445],
        [-0.044372,  0.132306, -0.151107, -0.098283],
        [-0.477674, -0.331141, -0.206112,  0.018522]], requires_grad=True)
key: Parameter containing:
tensor([[ 0.197668,  0.300011, -0.338971, -0.217731],
        [ 0.181609,  0.415194, -0.102900,  0.374156],
        [-0.080592,  0.052907,  0.452738, -0.463835],
        [-0.314769, -0.126583, -0.194900,  0.432000]], requires_grad=True)
value: Parameter containing:
tensor([[-0.324090, -0.230166, -0.349320, -0.468280],
        [-0.291870,  0.429799,  0.223109,  0.242336],
        [ 0.026296, -0.256342,  0.084592, -0.466847],
        [-0.361283, -0.257765,  0.315469,  0.293161]], requires_grad=True)
q: tensor([[[-2.142724, -0.094010, -0.860587, -0.415205],
         [-0.032660, -1.173928, -0.236274,  0.225664],
         [-2.195175,  1.035442, -0.729367, -0.025517]],

        [[-0.032660, -1.173928, -0

tensor([[[ 0.006693, -0.013387,  0.020083,  0.006702],
         [ 2.000000, -0.020079, -0.006695, -0.006693],
         [ 0.000000, -0.006696,  0.013388,  0.020698]],

        [[ 2.000000, -0.020080, -0.008903, -0.006693],
         [ 0.000000, -0.007158,  0.013423,  0.061763],
         [ 0.026857, -0.026772,  0.014605,  0.000000]]],
       grad_fn=<MulBackward0>)

In [113]:
# Plain matrix product of the input with the query weight matrix.
# nn.Linear applies its weight transposed (x @ W.T), so this result is not
# the `q` tensor printed by ld(x).
# NOTE(review): use ld.query.weight.T here if the aim is to reproduce q.
torch.matmul(x, ld.query.weight)

tensor([[[-0.229379,  0.065842, -1.051141, -1.437185],
         [ 1.092291,  0.333042, -0.436107, -1.845512],
         [-1.329190, -0.862889, -0.910641, -0.537445]],

        [[ 1.092291,  0.333042, -0.436107, -1.845512],
         [-1.329190, -0.862889, -0.910641, -0.537445],
         [ 0.666591,  0.801185, -1.908677, -3.254222]]],
       grad_fn=<UnsafeViewBackward0>)