This notebook demonstrates the workflow for adding sparsity to an existing model. As an example, we take a
single encoder layer of BERT.

In [1]:
import torch
input_shape = (8, 128, 768) # batch, sequence, features
model = torch.hub.load('huggingface/pytorch-transformers',
    'model', 'bert-base-uncased').encoder.layer[0]
input = torch.rand(input_shape)
output = model(input)
print(output[0].shape)

Using cache found in /users/aivanov/.cache/torch/hub/huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([8, 128, 768])


We target all linear layers in this model, including the feedforward and attention projection layers.
A linear layer computes $y = xA^T + b$ and is implemented by the `torch.nn.Linear` module.
Specifically, we sparsify each weight tensor $A$ by magnitude pruning, which zeroes the $90\%$ of its values with the smallest magnitude, and store the result in the CSR format.
The following snippet collects the six weight tensors of the linear layers and assigns sparsifiers to them. It then prints the fully qualified name that PyTorch assigns to each of these tensors.

In [2]:
import sten
weights_to_sparsify = []
sb = sten.SparsityBuilder()
for module_name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        weight = module_name + ".weight"
        weights_to_sparsify.append(weight)
        sb.set_weight(
            name=weight,
            initial_sparsifier=sten.ScalarFractionSparsifier(0.9),
            inline_sparsifier=sten.KeepAll(),
            tmp_format=torch.Tensor,
            external_sparsifier=sten.KeepAll(),
            out_format=sten.CsrTensor,
        )
print(weights_to_sparsify)



['attention.self.query.weight', 'attention.self.key.weight', 'attention.self.value.weight', 'attention.output.dense.weight', 'intermediate.dense.weight', 'output.dense.weight']
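
To make the two ideas above concrete, here is a minimal sketch in plain PyTorch of magnitude pruning followed by conversion to CSR. It is not STen's implementation: the helper `magnitude_prune` and the small layer sizes are assumptions chosen for illustration, and `sten.ScalarFractionSparsifier` may handle details such as ties differently.

```python
import torch


def magnitude_prune(dense: torch.Tensor, fraction: float) -> torch.Tensor:
    """Zero the `fraction` of entries with the smallest absolute value (hypothetical helper)."""
    numel = dense.numel()
    n_keep = numel - int(fraction * numel)
    # Indices of the n_keep largest-magnitude entries in the flattened tensor.
    keep_idx = torch.topk(dense.abs().flatten(), n_keep).indices
    mask = torch.zeros(numel, dtype=torch.bool)
    mask[keep_idx] = True
    return dense * mask.view_as(dense)


torch.manual_seed(0)
layer = torch.nn.Linear(in_features=32, out_features=64)
x = torch.rand(4, 32)

# torch.nn.Linear stores A with shape (out_features, in_features)
# and computes y = x A^T + b.
assert torch.allclose(layer(x), x @ layer.weight.T + layer.bias, atol=1e-6)

pruned = magnitude_prune(layer.weight.detach(), 0.9)
csr = pruned.to_sparse_csr()  # only the surviving values are stored
print(csr.values().numel(), "of", pruned.numel(), "values stored")
```

CSR keeps one array of stored values, one array of their column indices, and one array of row offsets (`csr.crow_indices()`), so the storage cost scales with the number of surviving values rather than with the dense shape.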


Next, we repeat the same process for intermediate tensors.
In this example, we target only the output of the GELU activation.
However, referring to this intermediate tensor is challenging: we treat the module as a black box that we do not modify, and its internal operators may have varying names or no name at all, depending on the implementation.
Printing the model shows its module structure:

In [3]:
print(model)

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


From this we see that the `model.intermediate` submodule contains the GELU activation, but we still do not know the name of its output tensor.
We use the `torch.fx` tracer to assign deterministic names to the intermediate tensors. Tracing `model.intermediate` shows that the output of `<built-in function gelu>` (accessible as `torch.nn.functional.gelu`) is assigned the name `gelu` inside the `model.intermediate` module, so we refer to it as `intermediate.gelu` relative to the layer.

In [4]:
torch.fx.symbolic_trace(model.intermediate).graph.print_tabular()

opcode         name           target                    args              kwargs
-------------  -------------  ------------------------  ----------------  --------
placeholder    hidden_states  hidden_states             ()                {}
call_module    dense          dense                     (hidden_states,)  {}
call_function  gelu           <built-in function gelu>  (dense,)          {}
output         output         output                    (gelu,)           {}
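
The same naming behavior can be reproduced without downloading BERT. The sketch below traces a small stand-in module; `ToyIntermediate` and its layer sizes are hypothetical and only mirror the linear-then-GELU structure shown above.

```python
import torch
import torch.fx


class ToyIntermediate(torch.nn.Module):
    """Hypothetical stand-in for BertIntermediate: a linear layer followed by GELU."""

    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(16, 64)

    def forward(self, hidden_states):
        return torch.nn.functional.gelu(self.dense(hidden_states))


traced = torch.fx.symbolic_trace(ToyIntermediate())
# Each node carries the opcode and the name that the tracer assigned to its result.
node_names = [(node.op, node.name) for node in traced.graph.nodes]
print(node_names)
```

The names come from the forward argument (`hidden_states`), the submodule attribute (`dense`), and the called function (`gelu`), which is why they stay stable across runs of the same code.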


We now assign a random fraction sparsifier with a $90\%$ zeroing probability to the intermediate tensor that holds the GELU output.
The sparsified tensor is stored in the COO format.

In [5]:
sb.set_interm(
    name="intermediate.gelu",
    inline_sparsifier=sten.RandomFractionSparsifier(0.9),
    tmp_format=sten.CooTensor,
    external_sparsifier=sten.KeepAll(),
    out_format=sten.CooTensor,
)
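
For intuition, random sparsification and COO storage can be sketched in plain PyTorch as follows. This is an illustration under assumptions, not STen's code: the helper `random_fraction_mask` is hypothetical, it drops each element independently with the given probability, and `sten.RandomFractionSparsifier` may select elements differently.

```python
import torch


def random_fraction_mask(dense, fraction, generator=None):
    """Zero each element independently with probability `fraction` (hypothetical helper)."""
    keep = torch.rand(dense.shape, generator=generator) >= fraction
    return dense * keep


gen = torch.Generator().manual_seed(0)
# Stand-in for the GELU output; the shape is smaller than in the BERT layer.
activations = torch.nn.functional.gelu(torch.randn(8, 128, 64, generator=gen))

sparse_act = random_fraction_mask(activations, 0.9, generator=gen)
coo = sparse_act.to_sparse().coalesce()  # COO: explicit indices plus values

zero_fraction = 1.0 - coo.values().numel() / sparse_act.numel()
print(f"fraction of zeros: {zero_fraction:.3f}")
print("indices shape:", tuple(coo.indices().shape))
```

COO stores one index per dimension for every surviving value, which suits an activation tensor whose sparsity pattern changes on every forward pass.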

Finally, we create a new sparse model from the original dense model and run it with the same arguments as before:

In [6]:
sparse_model = sb.get_sparse_model(model)
output = sparse_model(input)
print(output[0].shape)



torch.Size([8, 128, 768])
