# Flare Model Variations
***
*Size variants (Tiny → Large) compared by parameter count and luma forward-pass time.*

In [1]:
import torch
import torch.nn

import context
from models.flare import Flare, FlareLuma, FlareChroma
from models.transforms import InverseDCT
from models.arcnn import ARCNN

In [2]:
# One inverse-DCT transform instance, shared by every luma/chroma model built below.
idct = InverseDCT()

# Random stand-in inputs; their values only drive the timing runs below,
# but seeding keeps Restart & Run All reproducible.
torch.manual_seed(0)

IMAGE_SIZE = 128  # spatial size of the test plane passed as `y`
DCT_BLOCK = 8     # assumed 8x8 DCT blocks (64 coefficients per block) -- TODO confirm against InverseDCT

y = torch.randn(1, 1, IMAGE_SIZE, IMAGE_SIZE)
# Shapes are derived from the constants above so the three tensors stay mutually consistent.
dct_y = torch.randn(1, DCT_BLOCK * DCT_BLOCK, IMAGE_SIZE // DCT_BLOCK, IMAGE_SIZE // DCT_BLOCK)
qt_y = torch.randn(1, DCT_BLOCK * DCT_BLOCK)

In [6]:
def num_params(net: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count the scalar parameters of a module.

    Args:
        net: Module whose parameters (as yielded by ``net.parameters()``,
            i.e. including those of its submodules) are counted.
        trainable_only: If True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts every
            parameter.

    Returns:
        Total number of parameter elements (0 for a parameterless module).
    """
    # Parameter.numel() can be called directly; going through the legacy
    # ``.data`` accessor is unnecessary just to read a size.
    return sum(
        p.numel()
        for p in net.parameters()
        if p.requires_grad or not trainable_only
    )

*Depthwise separable convolutions showed weak results, so every variant below uses standard convolutions (`depthwise_separable=False`).*

### Tiny: 1M Luma params, 1M Chroma params

In [7]:
# Tiny variant: base width 16 and a 2x channel multiplier for both models.
luma = FlareLuma(
    idct,
    base_channels=16,
    blocks_per_stage=1,
    channel_multiplier=2,
    depthwise_separable=False,
)
chroma = FlareChroma(
    idct,
    base_channels=16,
    blocks_per_stage=1,
    channel_multiplier=2,
    depthwise_separable=False,
)

tuple(map(num_params, (luma, chroma)))

(1013624, 1023448)

In [8]:
%timeit luma(y, dct_y, qt_y)

16.3 ms ± 273 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# Drop the references to this variant's models before building the next variant.
del luma, chroma

### Small: 4M Luma params, 2M Chroma params

In [10]:
# Small variant: base width 32 for both models; chroma uses a narrower 1.75x multiplier.
luma = FlareLuma(
    idct,
    base_channels=32,
    blocks_per_stage=1,
    channel_multiplier=2,
    depthwise_separable=False,
)
chroma = FlareChroma(
    idct,
    base_channels=32,
    blocks_per_stage=1,
    channel_multiplier=1.75,
    depthwise_separable=False,
)

tuple(map(num_params, (luma, chroma)))

(3973680, 2001036)

In [11]:
%timeit luma(y, dct_y, qt_y)

30.2 ms ± 177 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
# Drop the references to this variant's models before building the next variant.
del luma, chroma

### Medium: 9M Luma params, 4M Chroma params

In [13]:
# Medium variant: luma widens to base 48; chroma matches the Small luma settings (base 32, 2x).
luma = FlareLuma(
    idct,
    base_channels=48,
    blocks_per_stage=1,
    channel_multiplier=2,
    depthwise_separable=False,
)
chroma = FlareChroma(
    idct,
    base_channels=32,
    blocks_per_stage=1,
    channel_multiplier=2,
    depthwise_separable=False,
)

tuple(map(num_params, (luma, chroma)))

(8880232, 3993264)

In [14]:
%timeit luma(y, dct_y, qt_y)

47.6 ms ± 868 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
# Drop the references to this variant's models before building the next variant.
del luma, chroma

### Large: 16M Luma params, 8M Chroma params

In [16]:
# Large variant: base width 64 for both models; chroma uses a narrower 1.75x multiplier.
luma = FlareLuma(
    idct,
    base_channels=64,
    blocks_per_stage=1,
    channel_multiplier=2.0,
    depthwise_separable=False,
)
chroma = FlareChroma(
    idct,
    base_channels=64,
    blocks_per_stage=1,
    channel_multiplier=1.75,
    depthwise_separable=False,
)

tuple(map(num_params, (luma, chroma)))

(15733280, 7901456)

In [17]:
%timeit luma(y, dct_y, qt_y)

68.9 ms ± 753 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Drop the references to this variant's models now that the comparison is done.
del luma, chroma