### Level switching

Flex modules behave much like regular modules, with the added ability to switch between levels. For a regular Conv2d, you might create and run it like this:

In [None]:
from torch import nn
import torch

# A plain Conv2d layer mapping 3 input channels to 20 output channels.
conv = nn.Conv2d(in_channels=3, out_channels=20, kernel_size=3)
# A random batch of 10 three-channel images, each 100 x 100 pixels.
x = torch.rand(10, 3, 100, 100)
y = conv(x)
# With a 3x3 kernel and no padding the output shape is (10, 20, 98, 98).
print(y.shape)

To make and run a flexible conv2d:

In [None]:
import flex_modules as fm

# A flexible Conv2d is built from per-level channel lists. Here there are
# three levels whose input channels are 1, 2 and 3 and whose output
# channels are 10, 15 and 20 respectively.
conv = fm.Conv2d([1, 2, 3], [10, 15, 20], kernel_size=3)

# Switch to the highest level, the one that accepts 3 input channels.
top_level = conv.max_level()
conv.set_level_use(top_level)
x = torch.rand(10, 3, 100, 100)
y = conv(x)
print(y.shape)


In [None]:
# Level 1 only accepts 2 input channels, so feeding it the same 3-channel
# batch as before is expected to fail; the error is caught and printed.
conv.set_level_use(1)
try:
    y = conv(x)
except Exception as err:
    print(err)

# For this reason the first layer of a flexible model should use the same
# input dimensions at every level; only the output dimensions may differ.
first_conv = fm.Conv2d([3, 3, 3], [10, 15, 20], kernel_size=3)

# Run the same batch through every level, from 0 up to the maximum.
for level in range(first_conv.max_level() + 1):
    first_conv.set_level_use(level)
    y = first_conv(x)
    print(first_conv.current_level())
    print(y.shape)


The same level-switching interface also applies to entire flexible models:

In [None]:
from networks import flexvit
import utils

# Build a flexible vision transformer from its default config (2 levels)
# and move both the model and a random input batch to the chosen device.
config = flexvit.ViTConfig()
model = config.make_model().to(utils.get_device())
x = torch.rand(10, 3, 224, 224).to(utils.get_device())

# Run the same batch through every level, from 0 up to the maximum.
for level in range(model.max_level() + 1):
    model.set_level_use(level)
    y = model(x)
    print(model.current_level())
    print(y.shape)

### Copying

Being able to copy between regular layers and flexible layers can be very useful.

In [None]:
def randomize_nets(*nets: nn.Module) -> None:
    """Overwrite every parameter of the given modules with uniform noise.

    Each parameter is filled in place with fresh samples from ``torch.rand``,
    i.e. uniform values in ``[0, 1)``. The parameter objects themselves are
    kept, so their shape, dtype, device and ``requires_grad`` flag are
    unchanged.

    Args:
        *nets: Any number of modules whose parameters should be randomized.
            Passing no modules is a no-op.
    """
    # no_grad permits in-place writes to leaf parameters without going
    # through the discouraged `.data` attribute.
    with torch.no_grad():
        for net in nets:
            for p in net.parameters():
                # Passing the shape as one size object (rather than unpacking
                # it) and using copy_ (rather than slice assignment) also
                # handles 0-dim scalar parameters, which cannot be sliced.
                p.copy_(torch.rand(p.shape))

reg = nn.Conv2d(10, 20, kernel_size=3)
flex = fm.Conv2d([5, 7, 10], [10, 15, 20], kernel_size=3)
x = torch.rand((10, 10, 20, 20))

# Level 2 is the level whose channel counts (10 in, 20 out) match `reg`.
flex.set_level_use(2)

# Regular -> flexible: after randomizing both layers, copy `reg` into
# `flex`; the two should then produce the same output.
randomize_nets(flex, reg)
flex.load_from_base(reg)
assert torch.isclose(flex(x), reg(x)).all()

# Flexible -> regular: same check, copying in the other direction.
randomize_nets(flex, reg)
flex.copy_to_base(reg)
assert torch.isclose(flex(x), reg(x)).all()

# Flexible -> brand-new regular layer created from the flexible one.
randomize_nets(flex, reg)
reg2 = flex.make_base_copy()
assert torch.isclose(flex(x), reg2(x)).all()

Entire networks can also be copied, using `utils.flexible_model_copy()`:

In [None]:
from networks import vit

# Build a regular ViT and a flexible ViT from their default configs and
# place them, together with a random input batch, on the chosen device.
reg = vit.ViTConfig().make_model().to(utils.get_device())
flex = flexvit.ViTConfig().make_model().to(utils.get_device())
x = torch.rand((10, 3, 224, 224)).to(utils.get_device())

# Regular -> flexible: after randomizing both networks, copy `reg` into
# `flex`; the two should then produce the same output.
randomize_nets(flex, reg)
utils.flexible_model_copy(reg, flex)
assert torch.isclose(flex(x), reg(x)).all()

# Flexible -> regular: same check, copying in the other direction.
randomize_nets(reg, flex)
utils.flexible_model_copy(flex, reg)
assert torch.isclose(flex(x), reg(x)).all()
