# Testing splits from timm models in fastai
> Requieres `timm` from master

In [269]:
import timm
from timm.models.helpers import group_modules
from fastai.vision.all import *

In [270]:
arch = "resnet50"
m = timm.create_model(arch)
modules_names = group_modules(m, m.group_matcher(coarse=True))

In [271]:
def get_module_names(m):
    modules_names = group_modules(m, m.group_matcher(coarse=True))
    return list(modules_names.values())[:-1] #cut head

In [272]:
module_names = get_module_names(m)
module_names

[['conv1', 'bn1'],
 ['layer1.0.conv1',
  'layer1.0.bn1',
  'layer1.0.conv2',
  'layer1.0.bn2',
  'layer1.0.conv3',
  'layer1.0.bn3',
  'layer1.0.downsample.0',
  'layer1.0.downsample.1',
  'layer1.1.conv1',
  'layer1.1.bn1',
  'layer1.1.conv2',
  'layer1.1.bn2',
  'layer1.1.conv3',
  'layer1.1.bn3',
  'layer1.2.conv1',
  'layer1.2.bn1',
  'layer1.2.conv2',
  'layer1.2.bn2',
  'layer1.2.conv3',
  'layer1.2.bn3'],
 ['layer2.0.conv1',
  'layer2.0.bn1',
  'layer2.0.conv2',
  'layer2.0.bn2',
  'layer2.0.conv3',
  'layer2.0.bn3',
  'layer2.0.downsample.0',
  'layer2.0.downsample.1',
  'layer2.1.conv1',
  'layer2.1.bn1',
  'layer2.1.conv2',
  'layer2.1.bn2',
  'layer2.1.conv3',
  'layer2.1.bn3',
  'layer2.2.conv1',
  'layer2.2.bn1',
  'layer2.2.conv2',
  'layer2.2.bn2',
  'layer2.2.conv3',
  'layer2.2.bn3',
  'layer2.3.conv1',
  'layer2.3.bn1',
  'layer2.3.conv2',
  'layer2.3.bn2',
  'layer2.3.conv3',
  'layer2.3.bn3'],
 ['layer3.0.conv1',
  'layer3.0.bn1',
  'layer3.0.conv2',
  'layer3.0.bn2

In [273]:
def get_layers_from_names(module_names, m):
    layers = set()
    for l_name in L(module_names).concat():
        if "." not in l_name:
            layers.add(getattr(m, l_name))
        else:
            first_level_name = l_name.split(".")[0]
            layers.add(getattr(m, first_level_name))
    return L(layers)

In [274]:
get_layers_from_names(module_names, m)

(#6) [Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (drop_block): Identity()
    (act2): ReLU(inplace=True)
    (aa): Identity()
    (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act3): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), b

## Now integrating with fastai:

In [336]:
timm.list_models("*regnet*")

['haloregnetz_b',
 'nf_regnet_b0',
 'nf_regnet_b1',
 'nf_regnet_b2',
 'nf_regnet_b3',
 'nf_regnet_b4',
 'nf_regnet_b5',
 'regnetv_040',
 'regnetv_064',
 'regnetx_002',
 'regnetx_004',
 'regnetx_006',
 'regnetx_008',
 'regnetx_016',
 'regnetx_032',
 'regnetx_040',
 'regnetx_064',
 'regnetx_080',
 'regnetx_120',
 'regnetx_160',
 'regnetx_320',
 'regnety_002',
 'regnety_004',
 'regnety_006',
 'regnety_008',
 'regnety_016',
 'regnety_032',
 'regnety_040',
 'regnety_040s_gn',
 'regnety_064',
 'regnety_080',
 'regnety_120',
 'regnety_160',
 'regnety_320',
 'regnetz_005',
 'regnetz_040',
 'regnetz_040h',
 'regnetz_b16',
 'regnetz_b16_evos',
 'regnetz_c16',
 'regnetz_c16_evos',
 'regnetz_d8',
 'regnetz_d8_evos',
 'regnetz_d32',
 'regnetz_e8']

In [306]:
r50 = create_timm_model("resnet50", 20, pretrained=False)

In [307]:
vit = create_timm_model("vit_base_patch16_224", 20, pretrained=False)

In [329]:
convnext = create_timm_model("convnext_base_in22k", 20, pretrained=False)

In [317]:
def timm_split(m, cut=-1):
    body, head = m[0].model, m[1]
    module_names = get_module_names(body)
    groups = L(module_names[0:cut], module_names[cut:]).map(partial(get_layers_from_names, m=body))
    return [g.map(params).concat() for g in groups] + [params(head)]

Here it's arbitrary the value of cut, but for resnets we recover the same defaults as fastai using -1

In [318]:
def timm_default_split(m): return timm_split(m)
def timm_resnet_split(m):  return timm_split(m, cut=-1)
def timm_vit_split(m):     return timm_split(m, cut=-3) # maybe more, don't know....

In [323]:
get_module_names(r50[0].model)

[['conv1', 'bn1'],
 ['layer1.0.conv1',
  'layer1.0.bn1',
  'layer1.0.conv2',
  'layer1.0.bn2',
  'layer1.0.conv3',
  'layer1.0.bn3',
  'layer1.0.downsample.0',
  'layer1.0.downsample.1',
  'layer1.1.conv1',
  'layer1.1.bn1',
  'layer1.1.conv2',
  'layer1.1.bn2',
  'layer1.1.conv3',
  'layer1.1.bn3',
  'layer1.2.conv1',
  'layer1.2.bn1',
  'layer1.2.conv2',
  'layer1.2.bn2',
  'layer1.2.conv3',
  'layer1.2.bn3'],
 ['layer2.0.conv1',
  'layer2.0.bn1',
  'layer2.0.conv2',
  'layer2.0.bn2',
  'layer2.0.conv3',
  'layer2.0.bn3',
  'layer2.0.downsample.0',
  'layer2.0.downsample.1',
  'layer2.1.conv1',
  'layer2.1.bn1',
  'layer2.1.conv2',
  'layer2.1.bn2',
  'layer2.1.conv3',
  'layer2.1.bn3',
  'layer2.2.conv1',
  'layer2.2.bn1',
  'layer2.2.conv2',
  'layer2.2.bn2',
  'layer2.2.conv3',
  'layer2.2.bn3',
  'layer2.3.conv1',
  'layer2.3.bn1',
  'layer2.3.conv2',
  'layer2.3.bn2',
  'layer2.3.conv3',
  'layer2.3.bn3'],
 ['layer3.0.conv1',
  'layer3.0.bn1',
  'layer3.0.conv2',
  'layer3.0.bn2

In [319]:
groups = timm_resnet_split(r50)
len(groups)

3

In [320]:
get_module_names(vit[0].model)

[['patch_embed.proj'],
 ['blocks.0.norm1',
  'blocks.0.attn.qkv',
  'blocks.0.attn.proj',
  'blocks.0.norm2',
  'blocks.0.mlp.fc1',
  'blocks.0.mlp.fc2'],
 ['blocks.1.norm1',
  'blocks.1.attn.qkv',
  'blocks.1.attn.proj',
  'blocks.1.norm2',
  'blocks.1.mlp.fc1',
  'blocks.1.mlp.fc2'],
 ['blocks.2.norm1',
  'blocks.2.attn.qkv',
  'blocks.2.attn.proj',
  'blocks.2.norm2',
  'blocks.2.mlp.fc1',
  'blocks.2.mlp.fc2'],
 ['blocks.3.norm1',
  'blocks.3.attn.qkv',
  'blocks.3.attn.proj',
  'blocks.3.norm2',
  'blocks.3.mlp.fc1',
  'blocks.3.mlp.fc2'],
 ['blocks.4.norm1',
  'blocks.4.attn.qkv',
  'blocks.4.attn.proj',
  'blocks.4.norm2',
  'blocks.4.mlp.fc1',
  'blocks.4.mlp.fc2'],
 ['blocks.5.norm1',
  'blocks.5.attn.qkv',
  'blocks.5.attn.proj',
  'blocks.5.norm2',
  'blocks.5.mlp.fc1',
  'blocks.5.mlp.fc2'],
 ['blocks.6.norm1',
  'blocks.6.attn.qkv',
  'blocks.6.attn.proj',
  'blocks.6.norm2',
  'blocks.6.mlp.fc1',
  'blocks.6.mlp.fc2'],
 ['blocks.7.norm1',
  'blocks.7.attn.qkv',
  'blocks.

In [321]:
groups = timm_vit_split(vit)
len(groups)

3

In [322]:
timm.list_models("*vit*")

['convit_base',
 'convit_small',
 'convit_tiny',
 'crossvit_9_240',
 'crossvit_9_dagger_240',
 'crossvit_15_240',
 'crossvit_15_dagger_240',
 'crossvit_15_dagger_408',
 'crossvit_18_240',
 'crossvit_18_dagger_240',
 'crossvit_18_dagger_408',
 'crossvit_base_240',
 'crossvit_small_240',
 'crossvit_tiny_240',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_256d',
 'levit_384',
 'mobilevit_s',
 'mobilevit_xs',
 'mobilevit_xxs',
 'semobilevit_s',
 'vit_base_patch8_224',
 'vit_base_patch8_224_dino',
 'vit_base_patch8_224_in21k',
 'vit_base_patch16_18x2_224',
 'vit_base_patch16_224',
 'vit_base_patch16_224_dino',
 'vit_base_patch16_224_in21k',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_224_miil_in21k',
 'vit_base_patch16_224_sam',
 'vit_base_patch16_384',
 'vit_base_patch16_plus_240',
 'vit_base_patch16_rpn_224',
 'vit_base_patch32_224',
 'vit_base_patch32_224_in21k',
 'vit_base_patch32_224_sam',
 'vit_base_patch32_384',
 'vit_base_patch32_plus_256',
 'vit_base_r26_s32

In [327]:
for arch in timm.list_models("*res*"):
    print(f"Testing {arch}:")
    timm_model = create_timm_model(arch, 20, pretrained=False)
    if "vit" in arch:
        groups = timm_vit_split(timm_model)
    else:
        groups = timm_resnet_split(timm_model)

Testing bat_resnext26ts:
Testing cspresnet50:
Testing cspresnet50d:
Testing cspresnet50w:
Testing cspresnext50:
Testing cspresnext50_iabn:
Testing dla60_res2net:
Testing dla60_res2next:
Testing eca_resnet33ts:
Testing eca_resnext26ts:
Testing ecaresnet26t:
Testing ecaresnet50d:
Testing ecaresnet50d_pruned:
Testing ecaresnet50t:
Testing ecaresnet101d:
Testing ecaresnet101d_pruned:
Testing ecaresnet200d:
Testing ecaresnet269d:
Testing ecaresnetlight:
Testing ecaresnext26t_32x4d:
Testing ecaresnext50t_32x4d:
Testing ens_adv_inception_resnet_v2:
Testing gcresnet33ts:
Testing gcresnet50t:
Testing gcresnext26ts:
Testing gcresnext50ts:
Testing gluon_resnet18_v1b:
Testing gluon_resnet34_v1b:
Testing gluon_resnet50_v1b:
Testing gluon_resnet50_v1c:
Testing gluon_resnet50_v1d:
Testing gluon_resnet50_v1s:
Testing gluon_resnet101_v1b:
Testing gluon_resnet101_v1c:
Testing gluon_resnet101_v1d:
Testing gluon_resnet101_v1s:
Testing gluon_resnet152_v1b:
Testing gluon_resnet152_v1c:
Testing gluon_resnet1