In [1]:
from src.models.utae_paps_models.utae import UTAE
import torchvision.models as models

# ImageNet-pretrained ResNet-34. `weights=` replaces the `pretrained=True` flag,
# which torchvision deprecated in 0.13; IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` selected for resnet34.
resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

# U-TAE for 7-channel input. With these encoder widths the conv1/conv2 weights of
# down_blocks[0..2] have the same shapes as ResNet-34 layer1[0], layer1[1] and
# layer2[0] (the transfer cell below checks this before copying).
utae = UTAE(
    input_dim=7,
    encoder_widths=[64, 64, 64, 128],
    decoder_widths=[32, 32, 64, 128],
    out_conv=[32, 1],
    str_conv_k=4,
    str_conv_s=2,
    str_conv_p=1,
    agg_mode="att_group",
    encoder_norm="group",
    n_head=16,
    d_model=256,
    d_k=4,
    encoder=False,
    return_maps=False,
    pad_value=0,
    padding_mode="reflect",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the U-TAE architecture configured in the first cell.
utae

UTAE(
  (in_conv): ConvBlock(
    (conv): ConvLayer(
      (conv): Sequential(
        (0): Conv2d(7, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): GroupNorm(4, 64, eps=1e-05, affine=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (4): GroupNorm(4, 64, eps=1e-05, affine=True)
        (5): ReLU()
      )
    )
  )
  (down_blocks): ModuleList(
    (0-1): 2 x DownConvBlock(
      (down): ConvLayer(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
          (1): GroupNorm(4, 64, eps=1e-05, affine=True)
          (2): ReLU()
        )
      )
      (conv1): ConvLayer(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
          (1): GroupNorm(4, 64, eps=1e-05, affine=True)
          (2): ReLU()
        

In [3]:
# Display the ResNet-34 architecture so its layer1/layer2 convs can be compared with U-TAE's down_blocks.
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
# Sanity check before transferring weights: the first ResNet layer1 conv and the
# first U-TAE down-block conv1 must have identical weight shapes.
resnet_first_conv = resnet.layer1[0].conv1
print("ResNet layer1 conv1 shape:", resnet_first_conv.weight.shape)
utae_first_conv = utae.down_blocks[0].conv1.conv[0]
print("UTAE down_blocks[0] conv1 shape:", utae_first_conv.weight.shape)


ResNet layer1 conv1 shape: torch.Size([64, 64, 3, 3])
UTAE down_blocks[0] conv1 shape: torch.Size([64, 64, 3, 3])


In [5]:
import torch
import torch.nn as nn

def load_resnet_weights_to_utae(resnet, utae_model):
    """Copy ResNet conv weights into the matching U-TAE encoder convs.

    Mapping (the stride-2 ``down`` conv of each U-TAE DownConvBlock has no
    ResNet counterpart and is left untouched):

    * ``resnet.layer1[0]`` -> ``utae_model.down_blocks[0]``
    * ``resnet.layer1[1]`` -> ``utae_model.down_blocks[1]``
    * ``resnet.layer2[0]`` -> ``utae_model.down_blocks[2]``

    Within each pair, the ResNet block's ``conv1``/``conv2`` weights are copied
    into ``conv1.conv[0]``/``conv2.conv[0]`` of the U-TAE block, but only when
    the weight shapes match; mismatched pairs are skipped. A bias is copied only
    when both convs have one (torchvision ResNet convs are built with
    ``bias=False``, so U-TAE biases keep their initialisation). Normalisation
    layers are not transferred.

    Values are copied into the existing U-TAE parameters in place, so the two
    models do not share storage afterwards (assigning ``.data`` would alias them,
    and training one model would silently change the other).

    Args:
        resnet: ResNet exposing ``layer1``/``layer2`` blocks with ``conv1``/``conv2``.
        utae_model: UTAE whose ``down_blocks`` receive the weights (modified in place).
    """

    def _copy_conv(src, dst):
        """Copy src's weight (and bias when both have one) into dst; return True if copied."""
        if src.weight.shape != dst.weight.shape:
            return False
        with torch.no_grad():
            dst.weight.copy_(src.weight)
            src_bias = getattr(src, "bias", None)
            dst_bias = getattr(dst, "bias", None)
            if src_bias is not None and dst_bias is not None:
                dst_bias.copy_(src_bias)
        return True

    # (ResNet block, target U-TAE down_blocks index, ResNet block name for logging)
    block_pairs = [
        (resnet.layer1[0], 0, "layer1[0]"),
        (resnet.layer1[1], 1, "layer1[1]"),
        (resnet.layer2[0], 2, "layer2[0]"),
    ]

    loaded_weights = []
    for res_block, idx, res_name in block_pairs:
        utae_block = utae_model.down_blocks[idx]
        for conv_name in ("conv1", "conv2"):
            src_conv = getattr(res_block, conv_name)
            dst_conv = getattr(utae_block, conv_name).conv[0]
            if _copy_conv(src_conv, dst_conv):
                print(f"ResNet layer: {res_name}.{conv_name} --> UTAE layer: down_blocks[{idx}].{conv_name}")
                loaded_weights.append(f"down_blocks[{idx}].{conv_name}")

    print(f"Total layers loaded: {len(loaded_weights)}")
    print("Loaded weights:")
    for weight in loaded_weights:
        print(weight)

# Example usage:
# resnet = torchvision.models.resnet34(pretrained=True)
# load_resnet_weights_to_utae(resnet, utae_model)


In [11]:
# Initialize a UTAE model with the class defaults for 7-channel input; the
# forward-pass and coverage cells below operate on this `utae_model`.
utae_model = UTAE(input_dim=7)

# ImageNet-pretrained ResNet-34 (`weights=` replaces the deprecated `pretrained=True`).
resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

# Load the weights into `utae_model` (not the cell-1 `utae`) so that the model
# used by the later cells is the one that receives them.
load_resnet_weights_to_utae(resnet, utae_model)

ResNet layer: layer1[0].conv1 --> UTAE layer: down_blocks[0].conv1
ResNet layer: layer1[0].conv2 --> UTAE layer: down_blocks[0].conv2
ResNet layer: layer1[1].conv1 --> UTAE layer: down_blocks[1].conv1
ResNet layer: layer1[1].conv2 --> UTAE layer: down_blocks[1].conv2
ResNet layer: layer2[0].conv1 --> UTAE layer: down_blocks[2].conv1
ResNet layer: layer2[0].conv2 --> UTAE layer: down_blocks[2].conv2
Total layers loaded: 6
Loaded weights:
down_blocks[0].conv1
down_blocks[0].conv2
down_blocks[1].conv1
down_blocks[1].conv2
down_blocks[2].conv1
down_blocks[2].conv2


In [9]:
import torch

# Forward-pass smoke test on random data. Input layout used here:
# (batch, time, channels, H, W) — channels=7 matches `input_dim`, and the time
# axis matches the length of `batch_positions`.
BATCH, N_TIMESTEPS, N_CHANNELS, HEIGHT, WIDTH = 1, 5, 7, 128, 128
generator = torch.Generator().manual_seed(0)  # reproducible random input
random_input = torch.randn(BATCH, N_TIMESTEPS, N_CHANNELS, HEIGHT, WIDTH, generator=generator)
# One position index per time step, tiled across the batch.
batch_positions = torch.arange(N_TIMESTEPS).unsqueeze(0).repeat(BATCH, 1)

utae_model.eval()
with torch.no_grad():
    output = utae_model(random_input, batch_positions)

# The prediction should keep the batch size and the spatial resolution of the input.
assert output.shape[0] == BATCH and tuple(output.shape[-2:]) == (HEIGHT, WIDTH), output.shape
print("Output shape:", output.shape)

Output shape: torch.Size([1, 1, 128, 128])


In [13]:
def calculate_loaded_percentage_in_encoder(resnet, utae_model):
    """Report how many of the mapped U-TAE encoder conv weights hold ResNet values.

    Uses the ResNet -> U-TAE mapping of the weight-transfer cell:
    ``layer1[0]``/``layer1[1]``/``layer2[0]`` -> ``down_blocks[0]``/``[1]``/``[2]``,
    with ``conv1``/``conv2`` -> ``conv1.conv[0]``/``conv2.conv[0]``.

    A U-TAE weight counts as loaded only if it is element-wise equal to the
    corresponding ResNet weight. Checking shapes alone reports 100% for a model
    that never received the weights.

    The denominator covers only these six mapped conv weights. The stem
    (``in_conv``), the stride-2 ``down`` convs, biases and normalisation
    parameters are NOT counted, so this is not a fraction of the whole encoder.

    Returns:
        float: percentage of mapped conv weights that were loaded
        (0.0 when nothing is mapped).
    """
    block_pairs = [
        (resnet.layer1[0], utae_model.down_blocks[0]),
        (resnet.layer1[1], utae_model.down_blocks[1]),
        (resnet.layer2[0], utae_model.down_blocks[2]),
    ]

    total_params = 0
    loaded_params = 0
    for res_block, utae_block in block_pairs:
        for conv_name in ("conv1", "conv2"):
            src = getattr(res_block, conv_name).weight
            dst = getattr(utae_block, conv_name).conv[0].weight
            total_params += dst.numel()
            # torch.equal is False for mismatched shapes, so no separate shape check
            # is needed; .cpu() avoids a device mismatch between the two models.
            if torch.equal(src.detach().cpu(), dst.detach().cpu()):
                loaded_params += dst.numel()

    percentage_loaded = 100.0 * loaded_params / total_params if total_params else 0.0
    print(f"Total mapped conv weights in encoder: {total_params}")
    print(f"Loaded mapped conv weights in encoder: {loaded_params}")
    print(f"Percentage of mapped conv weights loaded in encoder: {percentage_loaded:.2f}%")
    return percentage_loaded

# Example usage:
# resnet = torchvision.models.resnet34(pretrained=True)
# calculate_loaded_percentage_in_encoder(resnet, utae_model)


In [15]:
# Weight-transfer coverage report for `utae_model` (the model created in the weight-loading cell).
calculate_loaded_percentage_in_encoder(resnet, utae_model)

Total weights in encoder: 368640
Loaded weights in encoder: 368640
Percentage of weights loaded in encoder: 100.00%
