In [1]:
import torch
import torch.nn as nn

class MaskedLinear(nn.Linear):
    """A linear layer whose weight is multiplied element-wise by a mask.

    The mask is stored as a non-trainable buffer named ``weight_mask`` with the
    same shape as ``weight``, i.e. ``(out_features, in_features)``. It starts as
    all ones, so a freshly constructed layer computes exactly what
    ``nn.Linear`` computes. Overwrite it in place, typically with 0/1 values
    (e.g. ``layer.weight_mask.copy_(mask)``), to cut selected input->output
    connections. Because it is a registered buffer, the mask is included in
    ``state_dict()`` and is moved/cast together with the module.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if ``False``, the layer has no additive bias.
        device: optional device for weight, bias and mask (forwarded to
            ``nn.Linear``).
        dtype: optional floating dtype for weight, bias and mask (forwarded to
            ``nn.Linear``).
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Create the mask with the weight's device and dtype so the two always
        # agree, even when the layer is constructed on a non-default device/dtype.
        self.register_buffer(
            "weight_mask",
            torch.ones(
                out_features,
                in_features,
                device=self.weight.device,
                dtype=self.weight.dtype,
            ),
        )

    def forward(self, input):
        """Return ``linear(input, weight * weight_mask, bias)``.

        The mask is applied inside the autograd graph, so weights at masked-out
        (zero) positions contribute nothing to the output and get zero gradient.
        """
        masked_weight = self.weight * self.weight_mask
        return nn.functional.linear(input, masked_weight, self.bias)

# Create a simple model with a masked linear layer.
# Seed first: nn.Linear initialises its weights randomly, and a fixed seed makes
# the model identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = nn.Sequential(
    MaskedLinear(3, 4),
    nn.ReLU(),
    nn.Linear(4, 2),
)

# Set a specific weight mask for the masked linear layer. Row i corresponds to
# output unit i and column j to input j (the mask has the weight's
# (out_features, in_features) shape): 1 keeps the connection, 0 drops it.
masked_layer = model[0]
new_mask = torch.tensor(
    [
        [1, 0, 1],
        [1, 1, 0],
        [0, 1, 1],
        [1, 1, 1],
    ],
    dtype=masked_layer.weight_mask.dtype,  # match the buffer instead of relying on an implicit int -> float cast
)
# `copy_` silently broadcasts a smaller mask (e.g. shape (3,)), so check the shape explicitly.
assert new_mask.shape == masked_layer.weight_mask.shape, (
    f"mask shape {tuple(new_mask.shape)} != weight shape {tuple(masked_layer.weight_mask.shape)}"
)
masked_layer.weight_mask.copy_(new_mask)

tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.],
        [1., 1., 1.]])

In [7]:
# Look the mask up through the layer's registered buffers; a KeyError here would
# mean `weight_mask` is not tracked as a buffer (and so not saved in state_dict).
dict(model[0].named_buffers())["weight_mask"]

tensor([[1., 0., 1.],
        [1., 1., 0.],
        [0., 1., 1.],
        [1., 1., 1.]])