In [12]:
import torch
from torchvision import models

Load ResNet-50 (RN50) backbones for simple SAR-optical late fusion

In [13]:
class LateFusionModel(torch.nn.Module):
    """Two-branch ResNet-50 classifier that fuses features after the backbones.

    ``net1`` consumes a 2-channel input and ``net2`` a 13-channel input; the
    first convolution of each backbone is replaced accordingly and both ``fc``
    heads are replaced by identities.  The two pooled feature vectors are
    concatenated and mapped by one linear layer (``ffc``, 4096 -> 19).

    Args:
        da: When True, a ``DAttentionBaseline`` block is added residually after
            ``layer3`` and after ``layer4`` of each backbone.  When False the
            four attention attributes are ``None`` and the plain backbones are
            used.
    """

    # Stage-specific DAttentionBaseline arguments; all remaining arguments are
    # shared between stages (see ``_build_da_block``).
    # NOTE(review): n_heads * n_head_channels is 1024 for l3 and 2048 for l4,
    # and the q/kv sizes (14x14, 7x7) presumably assume 224x224 inputs -- confirm.
    _DA_STAGE_KWARGS = {
        "l3": dict(q_size=(14, 14), kv_size=(14, 14), n_heads=8, n_groups=4, stride=2, ksize=5),
        "l4": dict(q_size=(7, 7), kv_size=(7, 7), n_heads=16, n_groups=8, stride=1, ksize=3),
    }

    def __init__(self, da=False):
        super().__init__()
        self.net1 = self._build_backbone(in_channels=2)
        self.net2 = self._build_backbone(in_channels=13)
        self.ffc = torch.nn.Linear(4096, 19)
        self.da = da
        if self.da:
            self.da1_l3 = self._build_da_block("l3")
            self.da1_l4 = self._build_da_block("l4")
            self.da2_l3 = self._build_da_block("l3")
            self.da2_l4 = self._build_da_block("l4")
        else:
            # forward() reads these attributes unconditionally, so they must
            # exist even when attention is disabled.
            self.da1_l3 = self.da1_l4 = None
            self.da2_l3 = self.da2_l4 = None

    @staticmethod
    def _build_backbone(in_channels):
        """Return a randomly initialised ResNet-50 with a custom input conv and no head."""
        net = models.resnet50(pretrained=False)
        net.conv1 = torch.nn.Conv2d(
            in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        net.fc = torch.nn.Identity()
        return net

    @classmethod
    def _build_da_block(cls, stage):
        """Build the DAttentionBaseline block for ``stage`` ("l3" or "l4")."""
        # Imported lazily so the project-local package is only required when
        # da=True.  This resolves the top-level ``models`` package through the
        # import system; it is unrelated to the ``torchvision.models`` name.
        from models.dat.dat_blocks import DAttentionBaseline

        return DAttentionBaseline(
            n_head_channels=128,
            attn_drop=0,
            proj_drop=0,
            offset_range_factor=-1,
            use_pe=True,
            dwc_pe=False,
            no_off=False,
            fixed_pe=False,
            log_cpb=False,
            **cls._DA_STAGE_KWARGS[stage],
        )

    def forward_backbone(self, x, backbone, da_l3, da_l4):
        """Run ``x`` through ``backbone`` stage by stage and return its features.

        Args:
            x: Input batch for this branch.
            backbone: ResNet-style module exposing ``conv1``, ``bn1``, ``relu``,
                ``maxpool``, ``layer1``..``layer4``, ``avgpool`` and ``fc``.
            da_l3: Attention block added residually after ``layer3``; only
                used when ``self.da`` is true.
            da_l4: Attention block added residually after ``layer4``; only
                used when ``self.da`` is true.

        Returns:
            The output of ``backbone.fc`` applied to the flattened pooled features.
        """
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        if self.da:
            # The attention block returns a 3-tuple; only the first element is used.
            attended, _, _ = da_l3(x)
            x = x + attended
        x = backbone.layer4(x)
        if self.da:
            attended, _, _ = da_l4(x)
            x = x + attended
        x = backbone.avgpool(x)
        x = torch.flatten(x, 1)
        x = backbone.fc(x)
        return x

    def forward(self, s1, s2):
        """Classify a pair of inputs.

        Args:
            s1: Batch for ``net1`` (2 channels).
            s2: Batch for ``net2`` (13 channels).

        Returns:
            Output of ``ffc`` on the concatenated branch features (19 values per sample).
        """
        z1 = self.forward_backbone(s1, self.net1, self.da1_l3, self.da1_l4)
        z2 = self.forward_backbone(s2, self.net2, self.da2_l3, self.da2_l4)
        z12 = torch.cat((z1, z2), -1)
        return self.ffc(z12)

In [4]:
# Build the late-fusion model with the DAttentionBaseline blocks enabled
# (da=True), so the instance has the da*_l3 / da*_l4 submodules.
net = LateFusionModel(da=True)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Path to the pretrained checkpoint for the two-branch ResNet-50 model.
pretrained = 'utils/rn50_rda_ssl4eo-s12_joint_decur_ep100.pth'
# map_location="cpu" lets the checkpoint load on hosts without a GPU even if
# its tensors were saved from CUDA; load_state_dict copies the values into
# the model's own parameters afterwards.
state_dict = torch.load(pretrained, map_location="cpu")

# Rename checkpoint keys to this model's attribute names in a single pass:
# "module.backbone_1" -> "net1", "module.backbone_2" -> "net2", and any
# remaining "module." prefix is dropped.
state_dict = {
    k.replace("module.backbone_1", "net1")
     .replace("module.backbone_2", "net2")
     .replace("module.", ""): v
    for k, v in state_dict.items()
}

# strict=False because the fusion head (ffc) is not part of the checkpoint;
# the assert verifies that nothing else is missing and shows the full
# incompatible-key report on failure.
msg = net.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"ffc.weight", "ffc.bias"}, msg

Load ViT-S/16 for multispectral input

In [21]:
import timm

# Path to the pretrained ViT-S/16 checkpoint.
pretrained = 'utils/vits16_ssl4eo-s12_ms_decur_ep100.pth'
# map_location="cpu" lets the checkpoint load on hosts without a GPU even if
# its tensors were saved from CUDA.
state_dict = torch.load(pretrained, map_location="cpu")

# Randomly initialised ViT-S/16 whose patch embedding is replaced so that it
# accepts 13 input channels.
vit = timm.create_model('vit_small_patch16_224', pretrained=False)
vit.patch_embed.proj = torch.nn.Conv2d(13, 384, kernel_size=(16, 16), stride=(16, 16))

# strict=False because the classification head is not part of the checkpoint;
# the assert verifies that nothing else is missing and shows the full
# incompatible-key report on failure.
msg = vit.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"head.weight", "head.bias"}, msg

