### Level switching

Flex modules behave much like regular modules, but they can also switch between levels. For example, this is how you would create and run a regular `Conv2d`:

In [1]:
import torch
from torch import nn

# A plain (non-flexible) Conv2d mapping 3 input channels to 20 output
# channels with a 3x3 kernel.
conv = nn.Conv2d(3, 20, kernel_size=3)

# A batch of 10 random 3-channel images, each 100x100 pixels.
batch_shape = (10, 3, 100, 100)
x = torch.rand(batch_shape)

y = conv(x)
print(y.shape)

torch.Size([10, 20, 98, 98])


To make and run a flexible `Conv2d`:

In [2]:
import flex_modules as fm

# A flexible Conv2d with three levels (0, 1, 2). Level i takes in_channels[i]
# input channels and produces out_channels[i] output channels:
#   level 0: 1 -> 10, level 1: 2 -> 15, level 2: 3 -> 20.
in_channels = [1, 2, 3]
out_channels = [10, 15, 20]
conv = fm.Conv2d(in_channels, out_channels, kernel_size=3)

# Switch to the highest level, which expects 3 input channels.
top_level = conv.max_level()
conv.set_level_use(top_level)

x = torch.rand((10, 3, 100, 100))
y = conv(x)
print(y.shape)


torch.Size([10, 20, 98, 98])


In [3]:
# Drop to level 1, which expects only 2 input channels. Feeding it the same
# 3-channel batch fails; the exception is printed rather than raised so the
# notebook keeps running.
conv.set_level_use(1)
try:
    y = conv(x)
except Exception as err:
    print(err)

# For this reason the first layer of a flexible model should accept the same
# number of input channels at every level; only the output width may vary.
first_conv = fm.Conv2d([3] * 3, [10, 15, 20], kernel_size=3)

for level in range(first_conv.max_level() + 1):
    first_conv.set_level_use(level)
    y = first_conv(x)
    print(first_conv.current_level())
    print(y.shape)


Given groups=1, weight of size [20, 3, 3, 3], expected input[10, 4, 100, 100] to have 3 channels, but got 4 channels instead
0
torch.Size([10, 10, 98, 98])
1
torch.Size([10, 15, 98, 98])
2
torch.Size([10, 20, 98, 98])


The same level-switching API works for whole flexible models, because flexible models also implement the `fm.Module` interface.

In [4]:
from networks import flexvit
import utils

# Build a flexible vision transformer from the default config. Its levels run
# from 0 to model.max_level() (0, 1 and 2 in the output below).
FLEXVIT_CONFIG = flexvit.ViTConfig()

device = utils.get_device()
model = FLEXVIT_CONFIG.make_model().to(device)
x = torch.rand(10, 3, 224, 224).to(device)

# Run the same batch at every level; the printed output shape is the same at
# each level.
for i in range(model.max_level() + 1):
    model.set_level_use(i)
    y = model(x)
    print(model.current_level())
    print(y.shape)

0
torch.Size([10, 1000])
1
torch.Size([10, 1000])
2
torch.Size([10, 1000])


### Copying

Weights can be copied between regular and flexible layers in both directions, and a flexible layer can also create a new regular copy of itself at its current level.

In [5]:
def randomize_nets(*nets: nn.Module) -> None:
    """Overwrite every parameter of each module in *nets* in place with
    uniform random values in [0, 1).

    The new values are drawn with ``torch.rand_like`` so they already have the
    parameter's shape, dtype and device (no CPU allocation plus transfer for
    models that live on a GPU). The copy runs under ``torch.no_grad()`` so it
    is not recorded by autograd. Zero-dimensional parameters are handled too;
    slice assignment (``p.data[:] = ...``) fails on them.
    """
    with torch.no_grad():
        for net in nets:
            for p in net.parameters():
                p.copy_(torch.rand_like(p))

reg = nn.Conv2d(10, 20, kernel_size=3)
flex = fm.Conv2d([5, 7, 10], [10, 15, 20], kernel_size=3)
x = torch.rand(10, 10, 20, 20)

# At level 2 the flexible conv maps 10 -> 20 channels, the same as `reg`, so
# the two layers can be compared directly on `x`.
flex.set_level_use(2)

# Regular -> flexible: load reg's weights into the flexible layer.
randomize_nets(flex, reg)
flex.load_from_base(reg)
assert torch.isclose(flex(x), reg(x)).all()

# Flexible -> regular: write the flexible layer's weights into reg.
randomize_nets(flex, reg)
flex.copy_to_base(reg)
assert torch.isclose(flex(x), reg(x)).all()

# Build a brand-new regular layer from the flexible one.
randomize_nets(flex, reg)
reg2 = flex.make_base_copy()
assert torch.isclose(flex(x), reg2(x)).all()

Whole networks can also be copied, in either direction, with `utils.flexible_model_copy(source, destination)`:

In [6]:
from networks import vit

reg = vit.ViTConfig().make_model().to(utils.get_device())
flex = flexvit.ViTConfig().make_model().to(utils.get_device())
x = torch.rand(10, 3, 224, 224).to(utils.get_device())

# Compare outputs in eval mode so that any dropout in the models cannot make
# two networks with identical weights disagree. Run under no_grad because only
# forward outputs are compared, so no autograd graph is needed.
reg.eval()
flex.eval()

with torch.no_grad():
    # copying from regular to flexible
    randomize_nets(flex, reg)
    utils.flexible_model_copy(reg, flex)
    assert torch.isclose(flex(x), reg(x)).all()

    # copying from flexible to regular
    randomize_nets(reg, flex)
    utils.flexible_model_copy(flex, reg)
    assert torch.isclose(flex(x), reg(x)).all()


### Level Deltas
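To switch a deployed regular model between levels quickly, flexible models can export level deltas, which are applied to a regular model in place. `export_level_delta()` returns a `(down_delta, up_delta)` pair. Below, the up delta exported at level 2 turns a level-1 regular model into a level-2 one, and the down delta exported at level 1 turns it back.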

In [7]:
# Build a fresh flexible ViT and take a regular (non-flexible) copy of it at
# level 1.
model = FLEXVIT_CONFIG.make_model().to(utils.get_device())
model.set_level_use(1)

reg_model = model.make_base_copy()
print(utils.count_parameters(reg_model))

# Export the deltas with the flexible model at level 2. Only the up delta is
# used in this cell.
model.set_level_use(2)
down_delta, up_delta = model.export_level_delta()

# Apply the up delta exported at level 2 to the level-1 regular model. This
# modifies reg_model in place (its parameter count changes), making it level 2.
up_delta.apply(reg_model)
print(utils.count_parameters(reg_model))

38837736
86567656


In [8]:
# Export the deltas with the flexible model at level 1. This time the down
# delta is the one used.
model.set_level_use(1)
down_delta, up_delta = model.export_level_delta()

# reg_model is still at level 2 after the previous cell.
print(utils.count_parameters(reg_model))

# Apply the down delta exported at level 1 to the level-2 regular model to
# bring it back to level 1 (again in place).
down_delta.apply(reg_model)
print(utils.count_parameters(reg_model))

86567656
38837736


### Delta managers

Delta managers make working with deltas simpler: a manager provides a regular model and moves it to any level with `move_to`.

In [9]:
import networks.level_delta_utils as delta

# Create an in-memory manager from the flexible model at level 0. The regular
# model it manages is returned by managed_model().
model.set_level_use(0)
manager = delta.InMemoryDeltaManager(model)
reg_model = manager.managed_model()

print(utils.count_parameters(reg_model))

# move_to changes the managed model in place: the same reg_model object
# reports a different parameter count after each move.
for target_level in (2, 1, 0, 2):
    manager.move_to(target_level)
    print(utils.count_parameters(reg_model))

22050664
86567656
38837736
22050664
86567656


In [11]:
import os

# Instead of keeping all deltas in memory, they can be written to a delta
# file and read back by a file-based delta manager.

DELTA_FILENAME = "vit.delta"

try:
    # First write the delta file, starting from level 0.
    with open(DELTA_FILENAME, "wb") as delta_file:
        delta.FileDeltaManager.make_delta_file(
            delta_file, model, starting_level=0)

    # Here move_to returns the regular model at the requested level, so
    # reg_model is rebound after every move.
    reg_config = FLEXVIT_CONFIG.create_base_config(0).no_prebuilt()
    with delta.file_delta_manager(DELTA_FILENAME, reg_config) as manager:
        reg_model = manager.move_to(0)
        print(utils.count_parameters(reg_model))

        for target_level in (2, 1, 0, 2):
            reg_model = manager.move_to(target_level)
            print(utils.count_parameters(reg_model))

    print()

    # There is some overhead in using delta files, but not a lot
    print(f"Delta file size: {os.path.getsize(DELTA_FILENAME) / 1000000: .2f} MB")
    print(
        f"Flexible model size: {utils.model_size_in_mb(model) / 1000000: .2f} MB")
    print(
        f"Regular model size: {utils.model_size_in_mb(reg_model) / 1000000: .2f} MB")
finally:
    # Always remove the (large) delta file, even if a step above raised.
    if os.path.exists(DELTA_FILENAME):
        os.remove(DELTA_FILENAME)

22050664
86567656
38837736
22050664
86567656

Delta file size:  354.03 MB
Flexible model size:  350.14 MB
Regular model size:  346.33 MB
