In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import time

In [2]:
# Restrict this process to physical GPUs 4 and 5. The variable is read when CUDA
# is first initialised, so it has to be set before any tensor is moved to a GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'
# Fall back to the CPU when no GPU is visible so the notebook still runs from a
# fresh kernel on a machine without those devices.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# All-ones test volume; downstream cells feed it to DropBlock_3D as a 5-D input.
a = torch.ones((64, 15, 4, 4, 4), device=device)

In [18]:
class DropBlock_3D(nn.Module):
    """DropBlock regularisation for 5-D inputs whose last three axes are spatial.

    During training, seed voxels are sampled from a Bernoulli distribution and
    every seed is grown (via a stride-1 max-pool) into a block that is zeroed.
    The surviving activations of each sample are rescaled so that the sample
    keeps its expected magnitude. In eval mode the input is returned untouched.

    Args:
        keep_prob: target fraction of activations to keep.
        block_size: edge length of the dropped cube. It is clamped per axis to
            the input's spatial size, so a block larger than the volume no
            longer produces an invalid (negative or infinite) drop rate.

    Attributes:
        gamma: Bernoulli seed probability used by the most recent training-mode
            forward pass (``None`` before the first one).
        kernel_size, stride, padding: pooling parameters derived from the
            nominal ``block_size``; ``forward`` recomputes the kernel and
            padding from the clamped block for each input.
    """

    def __init__(self, keep_prob=0.9, block_size=7):
        super(DropBlock_3D, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.gamma = None
        self.kernel_size = (block_size, block_size, block_size)
        self.stride = (1, 1, 1)
        self.padding = (block_size // 2, block_size // 2, block_size // 2)

    def _effective_block(self, x):
        """Return the per-axis block size, clamped to [1, spatial extent]."""
        return tuple(max(1, min(self.block_size, int(dim)))
                     for dim in x.shape[-3:])

    def calculate_gamma(self, x):
        """Return the Bernoulli seed probability for an input shaped like ``x``.

        Per axis the factor is ``dim / (block * (dim - block + 1))``; for a
        cubic volume with ``block_size <= dim`` the product equals the original
        ``dim**3 / (block_size**3 * (dim - block_size + 1)**3)`` expression.
        """
        gamma = 1.0 - self.keep_prob
        for dim, block in zip(x.shape[-3:], self._effective_block(x)):
            # float() keeps this a true division under Python 2 as well.
            gamma *= float(dim) / (block * (dim - block + 1))
        return gamma

    def forward(self, x):
        """Apply DropBlock to ``x`` (training mode only).

        Args:
            x: 5-D tensor; the mask is drawn independently for every element of
                the first two axes and blocks extend over the last three.

        Returns:
            A tensor with the shape of ``x``. In eval mode, or for an empty
            input, ``x`` itself is returned.
        """
        if not self.training or x.numel() == 0:
            return x

        block = self._effective_block(x)
        self.gamma = self.calculate_gamma(x)

        # Sample block seeds, then dilate each seed into a full block.
        seeds = torch.bernoulli(torch.ones_like(x) * self.gamma)
        mask = 1 - torch.nn.functional.max_pool3d(
            seeds,
            block,
            self.stride,
            tuple(size // 2 for size in block))

        # An even kernel with padding size // 2 yields one extra element per
        # axis; cropping to the input extent drops it (no-op for odd kernels).
        depth, height, width = x.shape[-3:]
        mask = mask[:, :, :depth, :height, :width]

        # Rescale each sample by (elements per sample) / (kept elements).
        # Clamping the count avoids inf * 0 = NaN when a sample is fully dropped.
        per_sample = float(mask.numel()) / mask.shape[0]
        kept = mask.sum((1, 2, 3, 4)).clamp(min=1)
        normalize_factor = (per_sample / kept).view(-1, 1, 1, 1, 1)
        return x * normalize_factor * mask

In [19]:



# dp = DropBlock_3D(0.9, 1).cuda()
# start = time.time()
# b = dp(a)
# print time.time() - start
# start = time.time()


# Time one training-mode forward pass of DropBlock_3D (block_size=2) on `a`.
# The layer is placed on the same device as its input.
dp = DropBlock_3D(0.9, 2).to(a.device)
if a.is_cuda:
    # Flush pending GPU work so it is not attributed to the timed region.
    torch.cuda.synchronize()
start = time.time()
b = dp(a)
if a.is_cuda:
    # CUDA kernels are launched asynchronously; wait for them to finish so the
    # elapsed time covers the computation and not just the kernel launches.
    torch.cuda.synchronize()
# Parenthesised single-argument form prints the same on Python 2 and Python 3.
print(time.time() - start)
# Re-arm the timer for any measurement that follows.
start = time.time()

0.00173091888428


In [20]:
# Echo the DropBlock output computed in the previous cell. In the recorded
# output, dropped positions show as 0 and kept positions as the rescaled value.
# NOTE(review): this renders the whole tensor; a slice would be easier to read.
b

tensor([[[[[1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]],

          [[1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 0.0000, 0.0000, 1.1650],
           [1.1650, 0.0000, 0.0000, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]],

          [[1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 0.0000, 0.0000, 1.1650],
           [1.1650, 0.0000, 0.0000, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]],

          [[0.0000, 0.0000, 1.1650, 1.1650],
           [0.0000, 0.0000, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]]],


         [[[1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]],

          [[1.1650, 1.1650, 1.1650, 1.1650],
           [1.1650, 1.1650, 1.1650, 1.1650]

In [86]:
# Sum `mask` over axes 1-4, leaving one value per entry of axis 0.
# NOTE(review): `mask` is not assigned at top level in any visible cell (it is
# only a local inside DropBlock_3D.forward) — presumably left over from an
# earlier interactive session; confirm before relying on this cell.
mask.sum((1,2,3,4))

tensor([882., 870., 857., 864., 858., 863., 854., 874., 871., 870., 848., 868.,
        847., 868., 854., 871., 872., 864., 876., 853., 858., 845., 866., 871.,
        855., 879., 855., 861., 866., 874., 875., 853., 880., 876., 871., 880.,
        876., 865., 876., 844., 884., 865., 860., 845., 874., 873., 876., 854.,
        879., 879., 864., 864., 842., 863., 865., 863., 862., 866., 863., 871.,
        859., 854., 869., 869.], device='cuda:0')

In [16]:
# Insert a singleton axis at position 1, then max-pool with a 2x2 window,
# unit stride and a padding of 1 (i.e. 2 // 2).
# NOTE(review): `c` is not defined in any visible cell — confirm its origin.
block_mask = F.max_pool2d(c.unsqueeze(1), kernel_size=2, stride=1, padding=1)

In [18]:
# Inspect the shape of the pooled mask.
block_mask.shape

torch.Size([64, 1, 5, 5])

In [20]:
# Drop axis 1 if it is a singleton and invert the mask (1 -> 0, 0 -> 1).
# NOTE(review): this rebinds `block_mask` from its own value, so re-running the
# cell inverts the mask again.
block_mask = 1 - torch.squeeze(block_mask, dim=1)

In [23]:
# Shape of the index-0 slice along axis 1 of `a`.
# NOTE(review): the recorded output torch.Size([64, 4, 4]) implies a 4-D `a`,
# while the visible cell that defines `a` builds a 5-D tensor — presumably this
# ran against an earlier value of `a`; confirm.
a[:, 0, :, :].shape

torch.Size([64, 4, 4])