In [1]:
import math

import numpy as np
import torch
import torch.nn as nn
from rshf.satmae import SatMAE_Pre_MS
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages; H and W are preserved.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels produced by the second convolution.
        mid_channels: channels between the two convolutions. A falsy value
            (None or 0) means "use out_channel".
        bias: forwarded to both Conv2d layers (BatchNorm follows each conv).
    """

    def __init__(self, in_channel, out_channel, mid_channels=None, bias=False):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channel

        def conv_bn_relu(c_in, c_out):
            # kernel_size=3 with padding=1 is a "same" convolution: spatial size unchanged.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1, bias=bias),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Flattened into one Sequential so submodule indices (0..5) stay stable for state_dicts.
        self.doubleconv = nn.Sequential(
            *conv_bn_relu(in_channel, hidden),
            *conv_bn_relu(hidden, out_channel),
        )

    def forward(self, x):
        return self.doubleconv(x)

In [3]:
class TransformerSkipAdapter(nn.Module):
    """Adapt an encoder skip feature map for a decoder stage.

    The map is first resized with bilinear upsampling (align_corners=True) by
    ``scale_factor`` and then refined by a ``DoubleConv`` that maps
    ``in_channel`` -> ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel, scale_factor):
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
        self.dc = DoubleConv(in_channel, out_channel)

    def forward(self, skip):
        # Resize first so the convolutions run at the decoder stage's resolution.
        resized = self.up(skip)
        return self.dc(resized)


In [4]:
class Decoder(nn.Module):
    """One decoder stage: refine with a DoubleConv, then upsample bilinearly.

    Shape contract:
        (B, in_channel, H, W)
            -> DoubleConv -> (B, out_channel, H, W)
            -> Upsample   -> (B, out_channel, H * scale_factor, W * scale_factor)
    """

    def __init__(self, in_channel, out_channel, scale_factor):
        super().__init__()
        # Channel reduction / refinement happens at the current (lower) resolution ...
        self.dc = DoubleConv(in_channel=in_channel, out_channel=out_channel)
        # ... and only then is the map enlarged for the next stage.
        self.up = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)

    def forward(self, img_patch):
        return self.up(self.dc(img_patch))


In [5]:
class outConv(nn.Module):
    """Output head: a 1x1 Conv2d (with bias) mapping ``in_channel`` -> ``out_channel`` per pixel.

    A 1x1 kernel mixes channels only, so the spatial size (H, W) is unchanged.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.outconv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1)

    def forward(self, image):
        logits = self.outconv(image)
        return logits

In [9]:
# Load the pretrained multispectral SatMAE encoder and print a CPU layer summary
# for a single 10-band 96x96 input.
model = SatMAE_Pre_MS.from_pretrained('MVRL/satmae-vitbase-multispec-pretrain')
model.in_c = 10  # NOTE(review): set after construction -- confirm the library reads this at forward time
summary(model, input_size=(10, 96, 96), batch_size=1, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [1, 768, 12, 12]         197,376
          Identity-2              [1, 144, 768]               0
        PatchEmbed-3              [1, 144, 768]               0
            Conv2d-4           [1, 768, 12, 12]         197,376
          Identity-5              [1, 144, 768]               0
        PatchEmbed-6              [1, 144, 768]               0
            Conv2d-7           [1, 768, 12, 12]          99,072
          Identity-8              [1, 144, 768]               0
        PatchEmbed-9              [1, 144, 768]               0
        LayerNorm-10              [1, 109, 768]           1,536
           Linear-11             [1, 109, 2304]       1,771,776
         Identity-12           [1, 16, 109, 48]               0
         Identity-13           [1, 16, 109, 48]               0
         Identity-14              [1, 1

In [118]:
# Capture intermediate transformer-block outputs for inspection, keyed by 1-based block number.
# Hooks left over from a previous run of this cell are removed first so they don't stack up.
for _old_handle in globals().get("hook_handles", []):
    _old_handle.remove()

features = {}
layers = [2, 6, 9, 12]


def _store_block_output(block_no):
    """Return a forward hook that records the block's output under ``block_no``."""
    def hook(module, inputs, output):
        # Overwrite (not setdefault) so every encoder run refreshes the stored tensors.
        features[block_no] = output
    return hook


# Keep the handles: call ``h.remove()`` on each to detach these hooks, e.g. before the
# same ``model`` object is reused inside another module.
hook_handles = [
    model.blocks[block_no - 1].register_forward_hook(_store_block_output(block_no))
    for block_no in layers
]
    

In [119]:
# Dummy input (batch=1, 10 bands, 96x96) from a fixed-seed local generator, so re-runs
# reproduce the same features without touching the global RNG state.
random_inp = torch.rand(1, 10, 96, 96, generator=torch.Generator().manual_seed(0))

In [120]:
# Run the encoder once with no masking; the inspection hooks fill `features`.
# Inspection only, so skip building an autograd graph over the whole ViT.
# NOTE(review): the third output (`ids`) looks like MAE-style restore indices -- if the
# encoder permutes tokens even at mask_ratio=0, use it before reading tokens spatially.
with torch.no_grad():
    latent, mask, ids = model.forward_encoder(random_inp, mask_ratio=0)

In [121]:
# Hooked output of block 12: (batch, 1 cls + 3 groups x 12x12 patches = 433 tokens, 768).
features[12].size()

torch.Size([1, 433, 768])

In [65]:
class Satmae(nn.Module):
    """U-Net-style segmentation head on top of a frozen SatMAE group-channel ViT encoder.

    The encoder's final tokens and the outputs of three intermediate transformer
    blocks (``skip_blocks``, 1-based) are turned into ``(B, streams * C, h, w)``
    feature maps by ``parse_enc_output``. The final map passes through three
    ``Decoder`` stages (decoder3 -> decoder2 -> decoder1). After each stage, the
    matching skip map, resized and refined by a ``TransformerSkipAdapter``, is
    concatenated on the channel axis. A final ``DoubleConv`` and a 1x1 ``outConv``
    produce ``num_classes`` channels.

    With the defaults and a 10x96x96 input (8x8 patches, 3 channel groups -> a
    12x12 grid per group), the output is ``(B, num_classes, 96, 96)``.

    Args:
        model: pretrained encoder exposing ``blocks`` (indexable) and
            ``forward_encoder(img, mask_ratio)``, which returns a 3-tuple whose
            first item is the token sequence. Its parameters are frozen IN PLACE,
            so the caller's object is modified.
        num_classes: channels of the output map.
        TSA_scale_factor_list: upsampling factors of skip adapters 1..3.
        decoder_scale_factor: upsampling factor of every decoder stage.
        decoder_in_channel_list: input channels of decoder stages 1..3.
        decoder_out_channel_list: output channels of decoder stages 1..3.
        skip_blocks: 1-based encoder block numbers tapped for skips 1..3.
        TSA_out_channels: output channels of skip adapters 1..3.
        TSA_in_channels: channels of a parsed encoder map (streams * embed dim).
        DoubleConv_out_channels: output channels of the final DoubleConv.
        useMidChannel: if True, the final DoubleConv's intermediate width is the
            mean of its in/out widths; otherwise DoubleConv's default (its out width).
    """

    def __init__(
            self,
            model,
            num_classes=5,
            TSA_scale_factor_list=[8, 4, 2],
            decoder_scale_factor=2,
            decoder_in_channel_list=[384, 768, 768 * 3],
            decoder_out_channel_list=[128, 256, 512],
            skip_blocks=[4, 7, 10],
            TSA_out_channels=[64, 128, 256],
            TSA_in_channels=768 * 3,
            DoubleConv_out_channels=32,
            useMidChannel=True
    ):
        super().__init__()
        self.pretrainedSatmae = model

        # Copy the list arguments so instances never share (or mutate) the default lists.
        self.skip_blocks = list(skip_blocks)
        self.decoder_scale_factor = decoder_scale_factor
        self.TSA_scale_factors = list(TSA_scale_factor_list)
        self.decoder_in_c = list(decoder_in_channel_list)
        self.decoder_out_c = list(decoder_out_channel_list)
        self.TSA_out_c = list(TSA_out_channels)
        self.TSA_in_c = TSA_in_channels

        # The final DoubleConv consumes decoder1's output concatenated with skip 1.
        self.dc_in_c = self.TSA_out_c[0] + self.decoder_out_c[0]
        self.dc_out_c = DoubleConv_out_channels
        # Always defined: None lets DoubleConv fall back to dc_out_c when useMidChannel=False.
        self.dc_mid_channel = (self.dc_in_c + self.dc_out_c) // 2 if useMidChannel else None

        # The encoder is frozen; see StartFineTuning to unfreeze its last blocks.
        for param in self.pretrainedSatmae.parameters():
            param.requires_grad = False

        # decoder{i} / skipConnection{i} for i = 1..3 (stage 3 is the deepest, applied first).
        for i in range(1, 4):
            setattr(
                self, f"decoder{i}",
                Decoder(
                    in_channel=self.decoder_in_c[i - 1],
                    out_channel=self.decoder_out_c[i - 1],
                    scale_factor=self.decoder_scale_factor,
                ),
            )
            setattr(
                self, f"skipConnection{i}",
                TransformerSkipAdapter(
                    in_channel=self.TSA_in_c,
                    out_channel=self.TSA_out_c[i - 1],
                    scale_factor=self.TSA_scale_factors[i - 1],
                ),
            )

        self.dc = DoubleConv(in_channel=self.dc_in_c, out_channel=self.dc_out_c, mid_channels=self.dc_mid_channel)
        self.outConv = outConv(in_channel=self.dc_out_c, out_channel=num_classes)

    def _run_encoder(self, img):
        """Run the encoder once and return (final tokens, {block_no: tokens}, restore ids).

        Forward hooks are attached only for the duration of this call and are always
        removed afterwards. Otherwise they would accumulate on every forward pass and
        keep old activations alive.
        """
        captured = {}
        handles = []
        blocks = self.pretrainedSatmae.blocks
        for block_no in self.skip_blocks:
            def store(module, inputs, output, block_no=block_no):
                captured[block_no] = output
            handles.append(blocks[block_no - 1].register_forward_hook(store))
        try:
            tokens, _, ids_restore = self.pretrainedSatmae.forward_encoder(img, mask_ratio=0.0)
        finally:
            for handle in handles:
                handle.remove()
        return tokens, captured, ids_restore

    def forward(self, img):
        """Segment ``img`` and return a ``num_classes``-channel map."""
        enc_tokens, skip_tokens, ids_restore = self._run_encoder(img)
        # NOTE(review): MAE-style random masking permutes tokens even at mask_ratio=0;
        # ids_restore (third encoder output) is used to put them back in spatial order.
        x = self.parse_enc_output(enc_tokens, ids_restore)

        # Deepest first: decoder3 pairs with skip_blocks[2], ..., decoder1 with skip_blocks[0].
        for stage in (3, 2, 1):
            x = getattr(self, f"decoder{stage}")(x)
            skip = self.parse_enc_output(skip_tokens[self.skip_blocks[stage - 1]], ids_restore)
            skip = getattr(self, f"skipConnection{stage}")(skip)
            x = torch.cat((x, skip), dim=1)

        x = self.dc(x)
        return self.outConv(x)

    @staticmethod
    def _restore_token_order(tokens, ids_restore, streams):
        """Undo an encoder token permutation by gathering with ``ids_restore``.

        Supports a permutation over all ``L`` patch tokens (ids of shape (B, L)) or
        one permutation of spatial positions shared by every stream (ids of shape
        (B, L // streams)).

        Raises:
            ValueError: if ``ids_restore`` matches neither shape.
        """
        B, L, C = tokens.shape
        ids = ids_restore.to(device=tokens.device, dtype=torch.long)
        if tuple(ids.shape) == (B, L):
            return torch.gather(tokens, 1, ids.unsqueeze(-1).expand(B, L, C))
        per_stream = L // streams
        if tuple(ids.shape) == (B, per_stream):
            grouped = tokens.reshape(B, streams, per_stream, C)
            index = ids[:, None, :, None].expand(B, streams, per_stream, C)
            return torch.gather(grouped, 2, index).reshape(B, L, C)
        raise ValueError(
            f"ids_restore has shape {tuple(ids.shape)}; expected ({B}, {L}) or ({B}, {per_stream})"
        )

    def parse_enc_output(self, feat, ids_restore=None, streams=3):
        """Turn encoder tokens ``(B, 1 + streams*h*w, C)`` into a map ``(B, streams*C, h, w)``.

        The first (cls) token is dropped. The rest are read as ``streams`` consecutive
        groups of ``h*w`` row-major patches (h == w). Output channel ``s*C + c`` holds
        embedding dim ``c`` of stream ``s``.

        Args:
            feat: token tensor of shape (B, N, C).
            ids_restore: optional indices undoing an encoder token permutation
                (see ``_restore_token_order``); None means tokens are already ordered.
            streams: number of token streams (channel groups).

        Raises:
            ValueError: if N - 1 is not ``streams`` times a non-zero perfect square.
        """
        B, N, C = feat.shape
        tokens = feat[:, 1:, :]  # drop the cls token

        per_stream, remainder = divmod(N - 1, streams)
        side = math.isqrt(per_stream)
        if remainder or per_stream == 0 or side * side != per_stream:
            raise ValueError(
                f"cannot arrange {N - 1} patch tokens as {streams} square grids"
            )

        if ids_restore is not None:
            tokens = self._restore_token_order(tokens, ids_restore, streams)

        grid = tokens.reshape(B, streams, side, side, C)
        # (B, streams, h, w, C) -> (B, streams, C, h, w) -> (B, streams*C, h, w)
        return grid.permute(0, 1, 4, 2, 3).reshape(B, streams * C, side, side)

    def StartFineTuning(self, blocks_to_unfreeze=1):
        """Unfreeze the last ``blocks_to_unfreeze`` transformer blocks of the encoder.

        Values above the number of blocks unfreeze all of them; values <= 0 do nothing.
        """
        blocks = self.pretrainedSatmae.blocks
        first = max(0, len(blocks) - blocks_to_unfreeze)
        for block in blocks[first:]:
            for p in block.parameters():
                p.requires_grad = True
            
        

In [66]:
# Wrap the pretrained encoder in the segmentation head.
# Note: Satmae freezes `model`'s parameters in place (requires_grad=False).
satmae_model = Satmae(model=model)

In [67]:
# Display the nn.Module tree (repr) of the assembled segmentation model.
satmae_model

Satmae(
  (pretrainedSatmae): MaskedAutoencoderGroupChannelViT(
    (patch_embed): ModuleList(
      (0-1): 2 x PatchEmbed(
        (proj): Conv2d(4, 768, kernel_size=(8, 8), stride=(8, 8))
        (norm): Identity()
      )
      (2): PatchEmbed(
        (proj): Conv2d(2, 768, kernel_size=(8, 8), stride=(8, 8))
        (norm): Identity()
      )
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): M

In [68]:
# Layer summary of the full segmentation model on CPU, matching the encoder summary above
# (the default device is 'cuda', which fails for a CPU model on a GPU machine).
# torchsummary runs a forward pass to record shapes; use eval mode so the decoder's
# BatchNorm running statistics are not updated by random inputs, then restore the mode.
_was_training = satmae_model.training
satmae_model.eval()
try:
    summary(satmae_model, input_size=(10, 96, 96), device='cpu')
finally:
    satmae_model.train(_was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 12, 12]         197,376
          Identity-2             [-1, 144, 768]               0
        PatchEmbed-3             [-1, 144, 768]               0
            Conv2d-4          [-1, 768, 12, 12]         197,376
          Identity-5             [-1, 144, 768]               0
        PatchEmbed-6             [-1, 144, 768]               0
            Conv2d-7          [-1, 768, 12, 12]          99,072
          Identity-8             [-1, 144, 768]               0
        PatchEmbed-9             [-1, 144, 768]               0
        LayerNorm-10             [-1, 433, 768]           1,536
           Linear-11            [-1, 433, 2304]       1,771,776
         Identity-12          [-1, 16, 433, 48]               0
         Identity-13          [-1, 16, 433, 48]               0
         Identity-14             [-1, 4