In [None]:
# default_exp adaptive.layerdrop
%load_ext autoreload
%autoreload 2

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export 
import torch
from torch import nn

In [None]:
#export 
class LayerDrop(nn.Module):
    """
    Implements Reducing Transformer Depth on Demand with Structured Dropout (https://arxiv.org/abs/1909.11556)

    In training mode, every forward pass picks `layers_to_drop` *distinct* layers of
    `module_list` uniformly at random and skips them; the remaining layers are applied
    in their original order. In eval mode no layer is skipped, so inference is
    deterministic.

    Arguments:
        module_list (nn.ModuleList): List from which layers are to be dropped.
        layers_to_drop (int): number of layers to drop on each training forward pass.
            Values larger than `len(module_list)` drop every layer.

    Raises:
        ValueError: if `layers_to_drop` is negative.
    """
    def __init__(self, module_list, layers_to_drop):
        super(LayerDrop, self).__init__()
        if layers_to_drop < 0:
            raise ValueError(
                "layers_to_drop must be non-negative, got {}".format(layers_to_drop)
            )
        self.module_list = module_list
        self.layers_to_drop = layers_to_drop
        self.length = len(module_list)

    def forward(self, feats, mask=None):
        """
        Apply the layers that were not selected for dropping, in order.

        Arguments:
            feats: input passed to the first surviving layer; each layer's output is
                fed to the next surviving layer.
            mask (optional): when not None, forwarded to every surviving layer as a
                second positional argument.

        Returns:
            Output of the last surviving layer, or `feats` unchanged if all layers
            were dropped.
        """
        if self.training and self.layers_to_drop > 0:
            # randperm yields distinct indices, so exactly `layers_to_drop` layers
            # (capped at the number of layers) are skipped -- randint could repeat.
            dropped = set(torch.randperm(self.length)[:self.layers_to_drop].tolist())
        else:
            dropped = set()

        for index, layer in enumerate(self.module_list):
            if index in dropped:
                continue
            # Compare against None explicitly: truth-testing a tensor mask raises for
            # multi-element tensors and misfires for an all-zero single-element mask.
            if mask is None:
                feats = layer(feats)
            else:
                feats = layer(feats, mask)
        return feats

In [None]:
# Toy stack for the demo: three independent 2 -> 2 linear layers.
net = nn.ModuleList([nn.Linear(2, 2) for i in range(3)])

In [None]:
# Wrap the stack with layers_to_drop=2: two drop indices are sampled per forward call.
layerdrop = LayerDrop(net,2)

In [None]:
# Optimise the ModuleList's parameters; LayerDrop stores that same list as module_list.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

In [None]:
# Parameter values before the update, for comparison with the values shown after optimizer.step().
list(layerdrop.module_list.parameters())

[Parameter containing:
 tensor([[-0.4580, -0.3658],
         [ 0.6844, -0.1322]], requires_grad=True), Parameter containing:
 tensor([0.1487, 0.4848], requires_grad=True), Parameter containing:
 tensor([[ 0.3933,  0.3460],
         [-0.1324,  0.3969]], requires_grad=True), Parameter containing:
 tensor([-0.3601,  0.0282], requires_grad=True), Parameter containing:
 tensor([[-0.2810, -0.1985],
         [ 0.2533, -0.6242]], requires_grad=True), Parameter containing:
 tensor([ 0.6768, -0.2451], requires_grad=True)]

In [None]:
# Forward a random (10, 2) batch and sum to a scalar so it can be backpropagated.
loss = layerdrop(torch.rand(10,2)).sum()

tensor([0, 2])


In [None]:
# Clear any previously accumulated gradients before backpropagating this loss.
optimizer.zero_grad()

In [None]:
# Backpropagate; only layers that took part in the forward pass are in the graph.
loss.backward()

In [None]:
# Apply the SGD update using the gradients computed above.
optimizer.step()

In [None]:
# Parameter values after the step: layers skipped in the forward pass should be unchanged.
list(layerdrop.module_list.parameters())

[Parameter containing:
 tensor([[-0.4580, -0.3658],
         [ 0.6844, -0.1322]], requires_grad=True), Parameter containing:
 tensor([0.1487, 0.4848], requires_grad=True), Parameter containing:
 tensor([[ 0.3432,  0.3078],
         [-0.1824,  0.3587]], requires_grad=True), Parameter containing:
 tensor([-0.4601, -0.0718], requires_grad=True), Parameter containing:
 tensor([[-0.2810, -0.1985],
         [ 0.2533, -0.6242]], requires_grad=True), Parameter containing:
 tensor([ 0.6768, -0.2451], requires_grad=True)]