# Weight Selection

In [None]:
import torch
import timm

In [None]:
def uniform_element_selection(wt, s_shape):
    """Subsample a teacher tensor to a smaller student shape with evenly spaced indices.

    Each dimension is reduced independently. When the teacher size is a
    multiple of the student size, every ``step``-th element is kept starting
    at index 0. Otherwise the kept indices are ``s_shape[dim]`` evenly spaced
    positions between the first and the last valid index, rounded to the
    nearest integer.

    Args:
        wt: Teacher tensor to select from.
        s_shape: Target (student) shape. Must have as many entries as ``wt``
            has dimensions, each no larger than the matching teacher size.

    Returns:
        A new tensor of shape ``s_shape`` holding the selected elements;
        ``wt`` itself is not modified.

    Raises:
        AssertionError: If the number of dimensions differs, or the student
            is larger than the teacher along any dimension.
    """
    assert wt.dim() == len(s_shape), "Tensors have different number of dimensions"
    ws = wt.clone()
    for dim, (t_size, s_size) in enumerate(zip(wt.shape, s_shape)):
        # The teacher must be at least as large as the student on this dimension.
        assert t_size >= s_size, "Teacher's dimension should not be smaller than student's dimension"
        if t_size % s_size == 0:
            step = t_size // s_size
            indices = torch.arange(s_size, device=wt.device) * step
        else:
            # The last valid index is t_size - 1, and index_select needs an
            # integer (long) index tensor, so round and cast explicitly.
            indices = torch.round(
                torch.linspace(0, t_size - 1, s_size, device=wt.device)
            ).long()
        ws = torch.index_select(ws, dim, indices)
    # Compare as plain tuples so list/tuple/torch.Size targets all work.
    assert tuple(ws.shape) == tuple(s_shape)
    return ws

In [None]:
# ViT-T weight selection from ImageNet-21K pretrained ViT-S
# Teacher: pretrained ViT-S fetched through timm; its state_dict is the source of weights.
teacher = timm.create_model('vit_small_patch16_224_in21k', pretrained=True)
teacher_weights = teacher.state_dict()
# Student: ViT-T built from the project's own model definition; only the
# keys and tensor shapes of its state_dict are used by the selection loop.
from models.vision_transformer import vit_tiny
student = vit_tiny()
student_weights = student.state_dict()
# ConvNeXt-F weight selection from ImageNet-21K pretrained ConvNeXt-T
# Uncomment below for ConvNeXt
# teacher = timm.create_model('convnext_tiny_in22k', pretrained=True)
# teacher_weights = teacher.state_dict()
# from models.convnext import convnext_femto
# student = convnext_femto()
# student_weights = student.state_dict()

In [None]:
weight_selection = {}
for key in student_weights:
    # The classification head is left out by default. Drop this check if the
    # target dataset is the same as the teacher's.
    if "head" not in key:
        # Teacher tensors are looked up by the student's keys, so first-N
        # layer selection is implicitly applied here.
        weight_selection[key] = uniform_element_selection(
            teacher_weights[key],
            student_weights[key].shape,
        )

In [None]:
# Write the selected tensors to weight_selection.pth, wrapped in a dict under the 'model' key.
torch.save({'model':weight_selection}, "weight_selection.pth")

# Weight Interpolation/Pooling

In [1]:
import torch
import torch.nn.functional as F
import timm
from models.vision_transformer import vit_tiny

In [5]:
def weight_interpolate(weight, shape, mode='nearest'):
    """Resize a weight tensor to ``shape`` with ``F.interpolate``.

    ``F.interpolate`` is always called on a 4-D tensor and resizes only its
    last two axes, so the input is first brought into that layout:

    * 4-D targets: the axis order is reversed so that original dims 1 and 0
      end up last and are the ones resized; original dims 3 and 2 are passed
      through unchanged.
    * 1-D to 3-D targets: leading singleton axes are added until the tensor
      is 4-D. A 1-D target is treated as a ``(1, n)`` size.

    The extra axes / permutation are undone before returning.

    Args:
        weight: Source tensor with the same number of dimensions as ``shape``.
        shape: Target shape (``torch.Size`` or tuple) with 1 to 4 entries.
        mode: Interpolation mode forwarded to ``F.interpolate``. Defaults to
            ``'nearest'`` so the function can be called with just a tensor
            and a target shape.

    Returns:
        The resized tensor with ``len(shape)`` dimensions. For 4-D inputs the
        result is a permuted view and may be non-contiguous.
    """
    s_len = len(shape)
    if s_len == 4:
        # Reverse the axes so dims 1 and 0 become the trailing, resized axes.
        weight = weight.permute(3, 2, 1, 0)
        shape = shape[::-1]
    else:
        # A 1-D target has no second-to-last entry; pad it to (1, n).
        shape = torch.Size([1, shape[0]]) if s_len == 1 else shape
        while weight.dim() < 4:
            weight = weight.unsqueeze(0)

    ws = F.interpolate(weight, size=shape[-2:], mode=mode, align_corners=None)

    if s_len == 4:
        # Restore the original axis order.
        ws = ws.permute(3, 2, 1, 0)
    else:
        # Drop the leading singleton axes that were added above.
        while ws.dim() > s_len:
            ws = ws.squeeze(0)
    return ws

def weight_pooling(weight, shape):
    """Shrink a weight tensor to ``shape`` by adaptive max pooling.

    The tensor is brought into a 4-D layout whose last two axes are the ones
    to be reduced:

    * 4-D targets: the axis order is reversed so that original dims 1 and 0
      end up last and are pooled; original dims 3 and 2 are passed through.
    * 1-D to 3-D targets: leading singleton axes are added until the tensor
      is 4-D. A 1-D target is treated as a ``(1, n)`` size.

    The extra axes / permutation are undone before returning.

    Args:
        weight: Source tensor with the same number of dimensions as ``shape``.
        shape: Target shape (``torch.Size`` or tuple) with 1 to 4 entries.

    Returns:
        The pooled tensor with ``len(shape)`` dimensions. For 4-D inputs the
        result is a permuted view and may be non-contiguous.
    """
    s_len = len(shape)
    if s_len == 4:
        # Reverse the axes so dims 1 and 0 become the trailing, pooled axes.
        weight = weight.permute(3, 2, 1, 0)
        shape = shape[::-1]
    else:
        # A 1-D target has no second-to-last entry; pad it to (1, n).
        shape = torch.Size([1, shape[0]]) if s_len == 1 else shape
        while weight.dim() < 4:
            weight = weight.unsqueeze(0)

    # Pool straight to the target size of the last two axes. The tensor is
    # 4-D here, which F.max_pool1d does not accept (2-D/3-D only), and a
    # kernel spanning the whole last axis followed by squeeze(-1) would
    # remove that axis instead of shrinking it to the requested length.
    ws = F.adaptive_max_pool2d(weight, tuple(shape[-2:]))

    if s_len == 4:
        # Restore the original axis order.
        ws = ws.permute(3, 2, 1, 0)
    else:
        # Drop the leading singleton axes that were added above.
        while ws.dim() > s_len:
            ws = ws.squeeze(0)
    return ws

In [6]:
# Teacher: pretrained ViT-S fetched through timm; its state_dict is the source of weights.
teacher = timm.create_model('vit_small_patch16_224_in21k', pretrained=True)
teacher_weights = teacher.state_dict()
# Student: ViT-T from the project's model definition; its state_dict supplies
# the parameter names and target shapes.
student = vit_tiny()
student_weights = student.state_dict()

In [7]:
mode = 'nearest'
weight_selection = {}
for key in student_weights:
    # The classification head is left out by default. Drop this check if the
    # target dataset is the same as the teacher's.
    if "head" not in key:
        # Teacher tensors are looked up by the student's keys, so first-N
        # layer selection is implicitly applied here.
        print(key)
        print(teacher_weights[key].shape, student_weights[key].shape)
        weight_selection[key] = weight_interpolate(
            teacher_weights[key],
            student_weights[key].shape,
            mode=mode,
        )
        # Teacher vs. resized statistics, printed side by side for comparison.
        print('Mean: ', teacher_weights[key].mean(), weight_selection[key].mean())
        print('Std: ', teacher_weights[key].std(), weight_selection[key].std())
        print('Max: ', teacher_weights[key].max(), weight_selection[key].max())
        print('Min: ', teacher_weights[key].min(), weight_selection[key].min())
        print('')

cls_token
torch.Size([1, 1, 384]) torch.Size([1, 1, 192])
Mean:  tensor(-0.0312) tensor(-0.0271)
Std:  tensor(0.6031) tensor(0.7755)
Max:  tensor(3.5577) tensor(3.5577)
Min:  tensor(-6.5319) tensor(-6.5319)

pos_embed
torch.Size([1, 197, 384]) torch.Size([1, 197, 192])
Mean:  tensor(-0.0014) tensor(-0.0015)
Std:  tensor(0.4151) tensor(0.4453)
Max:  tensor(4.7052) tensor(4.7052)
Min:  tensor(-6.5487) tensor(-6.5487)

patch_embed.proj.weight
torch.Size([384, 3, 16, 16]) torch.Size([192, 3, 16, 16])
Mean:  tensor(3.5103e-05) tensor(5.8455e-05)
Std:  tensor(0.0465) tensor(0.0459)
Max:  tensor(0.5049) tensor(0.5049)
Min:  tensor(-0.4606) tensor(-0.4606)

patch_embed.proj.bias
torch.Size([384]) torch.Size([192])
Mean:  tensor(-0.0030) tensor(-0.0198)
Std:  tensor(0.2306) tensor(0.2796)
Max:  tensor(0.9889) tensor(0.8221)
Min:  tensor(-2.3685) tensor(-2.3685)

blocks.0.norm1.weight
torch.Size([384]) torch.Size([192])
Mean:  tensor(0.4652) tensor(0.4612)
Std:  tensor(0.2491) tensor(0.2460)
Max

In [None]:
# 4-D probe input: same source/target shapes as patch_embed.proj.weight (384 -> 192 on dim 0).
wt = torch.rand((384, 3, 16, 16))
shape = torch.Size([192, 3, 16, 16])

In [None]:
# 3-D probe input: same source/target shapes as cls_token (384 -> 192 on the last dim).
wt = torch.rand((1, 1, 384))
shape = torch.Size([1, 1, 192])

In [None]:
# 2-D probe input: both dimensions are halved.
# NOTE(review): presumably mirrors a fused qkv linear weight (3*384 x 384) — confirm.
wt = torch.rand((1152, 384))
shape = torch.Size([576, 192])

In [None]:
# 1-D probe input: same source/target shapes as patch_embed.proj.bias.
# Note that (384) is just the int 384, not a one-element tuple.
wt = torch.rand((384))
shape = torch.Size([192])

In [None]:
# Pass the interpolation mode explicitly; 'nearest' matches the `mode`
# setting used by the conversion loop in this notebook.
ws = weight_interpolate(wt, shape, mode='nearest')
ws.shape