In [8]:
import torch
import torch.nn as nn
from torch.autograd import Variable

class KSparseDropout(nn.Module):
    """Keep only the ``k`` largest entries of each row of a 2-D input; zero the rest.

    Despite the name, ``p`` is the fraction of features *kept*, not dropped:
    with ``p=0.5`` and 10 features, 5 values survive per row.

    Args:
        p: Fraction of features along dim 1 to keep when ``k`` is not given.
        k: Explicit number of features to keep per row. Used when truthy;
            a falsy value (``None`` or 0) falls back to ``int(p * width)``.
        infer_ratio: Accepted for API compatibility; ``forward`` does not use it.
    """

    def __init__(self, p=0.5, k=None, infer_ratio=1.0):
        super().__init__()
        self.p = p
        self.k = k
        # Stored so it is inspectable; forward() does not currently use it.
        self.infer_ratio = infer_ratio

    def _num_kept(self, num_features):
        """Return how many features to keep for an input with ``num_features`` columns."""
        # Computed per call rather than cached into self.k, so inputs of
        # different widths each get the ratio-based k they should.
        if self.k:
            return self.k
        return int(self.p * num_features)

    def forward(self, x):
        """Return a copy of ``x`` with all but the top-k entries of each row zeroed.

        Args:
            x: Tensor of shape ``(batch, features)``.

        Returns:
            Tensor with the same shape, dtype and device as ``x``.

        Raises:
            RuntimeError: (from ``torch.topk``) if ``k`` exceeds ``x.shape[1]``.
        """
        k = self._num_kept(x.shape[1])
        # Select along dim 1 explicitly so topk and scatter use the same axis.
        topk, indices = torch.topk(x, k, dim=1)
        # zeros_like inherits x's dtype/device, so float64 and CUDA inputs
        # work (torch.zeros would default to float32 on the CPU).
        return torch.zeros_like(x).scatter(1, indices, topk)
    
    
class KSparseDropout2d(nn.Module):
    """Channel-wise k-sparsity for 4-D inputs.

    For each sample, channels are ranked by the sum of their values over the two
    trailing (spatial) dims; every channel outside the top ``k`` is set to zero.
    ``p`` is the fraction of channels *kept*, not dropped.

    Args:
        p: Fraction of channels (dim 1) to keep when ``k`` is not given.
        k: Explicit number of channels to keep per sample. Used when truthy;
            a falsy value (``None`` or 0) falls back to ``int(p * channels)``.
        infer_ratio: Accepted for API compatibility; ``forward`` does not use it.
    """

    def __init__(self, p=0.5, k=None, infer_ratio=1.0):
        super().__init__()
        self.p = p
        self.k = k
        # Stored so it is inspectable; forward() does not currently use it.
        self.infer_ratio = infer_ratio

    def _num_kept(self, num_channels):
        """Return how many channels to keep for an input with ``num_channels`` channels."""
        # Computed per call rather than cached into self.k, so inputs with a
        # different channel count each get the ratio-based k they should.
        if self.k:
            return self.k
        return int(self.p * num_channels)

    def forward(self, x):
        """Return a copy of ``x`` with all but the top-k channels of each sample zeroed.

        Args:
            x: Tensor of shape ``(batch, channels, H, W)``.

        Returns:
            New tensor of the same shape; ``x`` itself is not modified.
        """
        k = self._num_kept(x.shape[1])
        # Score each channel by its total activation over the spatial dims.
        activation = x.sum(dim=(2, 3))
        _, indices = torch.topk(activation, k, dim=1)
        # Boolean (batch, channels) mask marking the selected channels.
        keep = torch.zeros_like(activation, dtype=torch.bool).scatter(
            1, indices, torch.ones_like(indices, dtype=torch.bool)
        )
        # masked_fill is out-of-place: the caller's tensor is left intact and
        # autograd works even for leaf inputs (the former in-place writes broke
        # both). Kept values are copied unchanged; it also replaces an
        # O(batch * channels * k) Python loop.
        return x.masked_fill(~keep[:, :, None, None], 0)

In [9]:
# Demo: keep the top p=0.5 fraction (5 of 10) of each row of a random batch.
torch.manual_seed(0)  # make the random input, and this cell's output, reproducible

ks_dropout = KSparseDropout()

x_dense = torch.rand(2, 10)
print(x_dense)

# Keep the raw input in x_dense; x holds the sparsified result.
x = ks_dropout(x_dense)
print(x)

# Each row keeps at most int(0.5 * 10) = 5 non-zero entries.
assert (x != 0).sum(dim=1).le(5).all()

tensor([[0.7582, 0.8596, 0.7181, 0.4069, 0.5181, 0.3860, 0.8217, 0.4648, 0.6444,
         0.5185],
        [0.0300, 0.3980, 0.4545, 0.2809, 0.1544, 0.1968, 0.6695, 0.7141, 0.3331,
         0.6274]])
tensor([[0.7582, 0.8596, 0.7181, 0.0000, 0.0000, 0.0000, 0.8217, 0.0000, 0.6444,
         0.0000],
        [0.0000, 0.3980, 0.4545, 0.0000, 0.0000, 0.0000, 0.6695, 0.7141, 0.0000,
         0.6274]])
