# Demo of a Model Transformation API

This is a small demo of how the existing Dropout/DropConnect representations might be implemented in a more extensible manner.

In [1]:
%load_ext autoreload
%autoreload 2

Consider the following nested model.
Here, we would expect `Dropout` layers to be inserted before the linear layers `1` and `3.0`, but not before `0.0`, which is the first layer of the model.

In [2]:
import torch

import probly.representation.dropout as probly_dropout
import probly.traverse_representation.dropout as probly_traverse_dropout


def showModel(m):
    print("\n".join(map(str, m.named_children())))


s = torch.nn.Sequential(
    torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.ReLU(),
    ),
    torch.nn.Linear(10, 10),
    torch.nn.ReLU(),
    torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.ReLU(),
    ),
)

showModel(s)
s.eval()
print(s(torch.tensor([1.0] * 10)))

('0', Sequential(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): ReLU()
))
('1', Linear(in_features=10, out_features=10, bias=True))
('2', ReLU())
('3', Sequential(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): ReLU()
))
tensor([0.0951, 0.2140, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4059,
        0.0000], grad_fn=<ReluBackward0>)


## Current Implementation

The current implementation has a bug: the recursion does not correctly propagate the `first_layer` state through the call stack.
As a result, no `Dropout` is inserted before layer `1`.

In [18]:
s2 = probly_dropout.Dropout(s, p=0.4)

showModel(s2)
s2.eval()
print(s2.predict_representation(torch.tensor([[1.0] * 10]), 2))

('model', Sequential(
  (0): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
  )
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
  (3): Sequential(
    (0): Sequential(
      (0): Dropout(p=0.4, inplace=False)
      (1): Linear(in_features=10, out_features=10, bias=True)
    )
    (1): ReLU()
  )
))
tensor([[[0.0954, 0.1076, 0.0954, 0.1024, 0.0982, 0.0954, 0.0954, 0.0954,
          0.1194, 0.0954],
         [0.1091, 0.1149, 0.0893, 0.0893, 0.0893, 0.0893, 0.0893, 0.0893,
          0.1508, 0.0893]]], grad_fn=<StackBackward0>)
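
The structure printed above is consistent with a first-layer flag that is passed into nested calls but never handed back to the caller. The following sketch reproduces that pattern under stated assumptions: nested lists stand in for `Sequential` containers, strings stand in for layers, and both function names are hypothetical. It is not taken from probly's source.

```python
# Hypothetical sketch: lists model Sequential containers, strings model layers.
MODEL = [["Linear", "ReLU"], "Linear", "ReLU", ["Linear", "ReLU"]]


def insert_dropout_buggy(model, first_layer=True):
    out = []
    for child in model:
        if isinstance(child, list):
            # Bug: the nested call gets a copy of the flag, so the caller never
            # learns that the first linear layer was already found inside it.
            out.append(insert_dropout_buggy(child, first_layer))
        elif child == "Linear":
            if first_layer:
                first_layer = False  # skip the very first linear layer
                out.append(child)
            else:
                out.append(["Dropout", child])
        else:
            out.append(child)
    return out


def insert_dropout_fixed(model, first_layer=True):
    out = []
    for child in model:
        if isinstance(child, list):
            # Fix: the nested call returns the updated flag to its caller.
            new_child, first_layer = insert_dropout_fixed(child, first_layer)
            out.append(new_child)
        elif child == "Linear":
            if first_layer:
                first_layer = False
                out.append(child)
            else:
                out.append(["Dropout", child])
        else:
            out.append(child)
    return out, first_layer


buggy = insert_dropout_buggy(MODEL)
fixed, _ = insert_dropout_fixed(MODEL)
```

In the buggy variant, the top-level loop still believes it has not seen a linear layer when it reaches layer `1`, so that layer is skipped. Returning the flag from each nested call is one possible fix; keeping the state in a shared traversal context would be another.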


## New Traverser-based Implementation
 
The same API can be realized using the extensible traverser approach.
The new implementation fixes the bug, flattens `Sequential` layers and is much more flexible in general (as we will show next).

In [15]:
s3 = probly_traverse_dropout.Dropout(s, p=0.4)

showModel(s3)
s3.eval()
print(s3.predict_representation(torch.tensor([[1.0] * 10]), 2))

('model', Sequential(
  (0_0): Linear(in_features=10, out_features=10, bias=True)
  (0_1): ReLU()
  (1_0): Dropout(p=0.4, inplace=False)
  (1_1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
  (3_0_0): Dropout(p=0.4, inplace=False)
  (3_0_1): Linear(in_features=10, out_features=10, bias=True)
  (3_1): ReLU()
))
tensor([[[0.1093, 0.0939, 0.0939, 0.0939, 0.0939, 0.1055, 0.0939, 0.0939,
          0.1279, 0.0939],
         [0.0977, 0.1066, 0.0958, 0.0966, 0.0958, 0.0958, 0.0958, 0.0958,
          0.1241, 0.0958]]], grad_fn=<StackBackward0>)
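
The flattened names in the output above (`0_0`, `1_0`, `3_0_0`, ...) look like the path of each layer in the nested model, joined with underscores. A minimal sketch of such a flattening step, again using nested lists as a stand-in for `Sequential` containers, could look as follows. The `flatten` helper is hypothetical and only illustrates the naming scheme.

```python
def flatten(model, prefix=()):
    """Flatten nested lists into a dict keyed by the underscore-joined path."""
    flat = {}
    for index, child in enumerate(model):
        path = prefix + (str(index),)
        if isinstance(child, list):
            # Recurse into containers and merge their entries in order.
            flat.update(flatten(child, path))
        else:
            flat["_".join(path)] = child
    return flat


# Nested structure after each non-first linear layer was wrapped with Dropout.
nested = [
    ["Linear", "ReLU"],
    ["Dropout", "Linear"],
    "ReLU",
    [["Dropout", "Linear"], "ReLU"],
]
flat = flatten(nested)
names = list(flat)
```

Because Python dictionaries preserve insertion order, `names` lists the layers in execution order, which is what a flat `Sequential` needs.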


To demonstrate the extensibility of the new approach, let's implement our own `Linear` layer, which we would like to be considered during rewriting.

In [19]:
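class MyLinear(torch.nn.Module):
    """Minimal bias-free linear layer used to demonstrate extensibility."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # torch.empty would leave the weights uninitialized, so sample them.
        self.K = torch.nn.Parameter(
            torch.randn(out_features, in_features) / in_features**0.5,
        )

    def forward(self, x):
        # K has shape (out_features, in_features), hence the transpose.
        return x @ self.K.T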


myS = torch.nn.Sequential(
    torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.ReLU(),
    ),
    torch.nn.Linear(10, 10),
    torch.nn.ReLU(),
    torch.nn.Sequential(
        MyLinear(10, 10),
        torch.nn.ReLU(),
    ),
)


showModel(myS)

('0', Sequential(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): ReLU()
))
('1', Linear(in_features=10, out_features=10, bias=True))
('2', ReLU())
('3', Sequential(
  (0): MyLinear()
  (1): ReLU()
))


Using the current implementation, the custom linear layer is ignored:

In [21]:
myS2 = probly_dropout.Dropout(myS, p=0.4)

showModel(myS2)

('model', Sequential(
  (0): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
  )
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
  (3): Sequential(
    (0): MyLinear()
    (1): ReLU()
  )
))


Using the extensible approach, we can dynamically extend the implementation in user-space:

In [24]:
probly_traverse_dropout.register(MyLinear)  # Registers the new layer type

myS3 = probly_traverse_dropout.Dropout(myS, p=0.4)

showModel(myS3)

('model', Sequential(
  (0_0): Linear(in_features=10, out_features=10, bias=True)
  (0_1): ReLU()
  (1_0): Dropout(p=0.4, inplace=False)
  (1_1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
  (3_0_0): Dropout(p=0.4, inplace=False)
  (3_0_1): MyLinear()
  (3_1): ReLU()
))
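
One common way to obtain this kind of user-space extensibility in Python is type-based dispatch. The sketch below uses `functools.singledispatch` with plain stand-in classes instead of Torch modules. It is an assumption about how a `register` function could work, not probly's implementation, and it leaves out the first-layer handling for brevity.

```python
from functools import singledispatch


class Dropout:
    def __init__(self, p):
        self.p = p


class Linear: ...
class ReLU: ...
class MyLinear: ...


@singledispatch
def rewrite(layer, p):
    """Default rule: layers of unknown type are kept unchanged."""
    return [layer]


def _prepend_dropout(layer, p):
    """Rule for registered types: put a Dropout in front of the layer."""
    return [Dropout(p), layer]


def register(layer_type):
    """Hypothetical user-facing hook that enables the rule for a new type."""
    rewrite.register(layer_type, _prepend_dropout)


def rewrite_all(layers, p):
    return [type(new).__name__ for layer in layers for new in rewrite(layer, p)]


register(Linear)  # built-in rule
layers = [Linear(), ReLU(), MyLinear()]

before = rewrite_all(layers, p=0.4)  # MyLinear is not yet known
register(MyLinear)                   # user-space extension
after = rewrite_all(layers, p=0.4)
```

Dispatching on the layer type keeps the traversal code closed for modification while letting users add rules for their own classes.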


## Next Steps and Discussion

This is just a very basic demo of what an extensible API could look like.
The underlying traverser system is significantly more powerful.

Right now, it combines several extensible recursive traversers for Python data structures and Torch modules with problem-specific replacement code.

It was designed with the goal of enabling a cross-framework implementation, i.e., the same `Dropout` class could be adapted for JAX or TensorFlow.

A potential future API could then look something like this:
```python
my_torch = MyTorch()
my_jax = MyJAX()

import future.representation.dropout as dropout

my_torch_dropout = dropout.Dropout(my_torch) # Works
my_jax_dropout = dropout.Dropout(my_jax) # => Error: Unknown model type!

import future.representation.extension.jax

my_jax_dropout = dropout.Dropout(my_jax) # Works
```
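
The behavior in the snippet above, where a model type only works once its extension has been imported, could be realized with a backend registry that extension modules fill at import time. The following is a minimal sketch under that assumption; `register_backend`, `dropout`, and the model classes are hypothetical names, and the handlers return placeholder tuples instead of real models.

```python
_BACKENDS = {}


def register_backend(model_type, handler):
    """Called by an extension module (e.g. at import time) to add support."""
    _BACKENDS[model_type] = handler


def dropout(model, p=0.25):
    # Walk the MRO so that subclasses of a registered type are supported too.
    for cls in type(model).__mro__:
        if cls in _BACKENDS:
            return _BACKENDS[cls](model, p)
    raise TypeError(f"Unknown model type: {type(model).__name__}")


class TorchLikeModel: ...
class JaxLikeModel: ...


# The core package ships with one backend.
register_backend(TorchLikeModel, lambda model, p: ("torch-dropout", p))
torch_result = dropout(TorchLikeModel())

try:
    dropout(JaxLikeModel())
    jax_error = None
except TypeError as exc:
    jax_error = str(exc)

# Stand-in for importing a JAX extension module, which would register itself.
register_backend(JaxLikeModel, lambda model, p: ("jax-dropout", p))
jax_result = dropout(JaxLikeModel())
```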

The advantage of such an approach would be that extensions of our package to entirely different settings could be provided by the community as external plugins.

The main risk with such an approach is reduced maintainability due to complexity.
While some complexity cannot be avoided when going down this path, it can be hidden behind a well-designed Traverser API which (ideally) would require little to no maintenance once it is completed and tested.
What such an API should look like is still to be discussed.

A first alpha attempt can be found in the `traverse` (generic data structure traversal) and `nn_traverse` (extensions for traversing neural networks) modules.
While not perfect, those modules showcase how extensibility could be achieved.