In [3]:
import kornia
import torch
import torch.nn as nn
import kornia.contrib as K
 
img = torch.rand(1, 3, 256, 256)
mvit = K.MobileViT(mode='xxs')
out = mvit(img)

In [9]:
mvit

MobileViT(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU()
  )
  (mv2): ModuleList(
    (0): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): SiLU()
        (6): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [10]:

# Access the first convolutional layer
first_conv_layer = mvit.conv1[0]

# Modify the first conv layer to have 12 input channels instead of 3
# Keep other parameters the same
new_first_conv_layer = nn.Conv2d(12, first_conv_layer.out_channels, 
                                 kernel_size=first_conv_layer.kernel_size, 
                                 stride=first_conv_layer.stride, 
                                 padding=first_conv_layer.padding, 
                                 bias=first_conv_layer.bias is not None)

# Replace the original first conv layer with the new one
mvit.conv1[0] = new_first_conv_layer

In [12]:
# Define your model (up to the layer before nn.Linear)
model = nn.Sequential(
    mvit,  # Your mvit model
    nn.AvgPool2d(128 // 32, 1),
    nn.Flatten()
)


# Create a dummy input tensor
dummy_input = torch.randn(1, 12, 128, 128)  # [batch_size, channels, height, width]

# Forward the dummy input through the model
output = model(dummy_input)
print(out.shape)

torch.Size([1, 320, 8, 8])
