In [23]:
import os
import argparse

import yaml
import torch
import timm
from torch.utils.data import DataLoader

# NOTE: several star imports follow; their relative order is intentionally kept
# unchanged because a later star import can shadow names bound earlier.
from GeospatialFM.data import get_datasets
from GeospatialFM.models import *
# from utils import load_config
from torchgeo.samplers import RandomGeoSampler
from matplotlib import pyplot as plt

from transformers import TrainingArguments, Trainer
from transformers import AdamW, get_linear_schedule_with_warmup
from GeospatialFM.utils import setup, get_eval_fn, get_data, init_distributed_device
from GeospatialFM.data import *
from GeospatialFM.models import *
from GeospatialFM.loss import *

from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import numpy as np
from torch.utils.data import ConcatDataset
import segmentation_models_pytorch as smp
from collections import OrderedDict

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
# Experiment name passed to setup() via `args`.
# Alternative values previously kept here as commented-out code:
#   'mae_cm_reconall_scratch_allTrain', 'mae_unidecoder_scratch_allTrain_norwl'
exp_name = "ft"
# `device` is intentionally not hard-coded here: the setup cell assigns it from
# init_distributed_device(args), which overwrote any value set in this cell.

In [28]:
# Hand-built stand-in for command-line arguments so the project's setup helpers
# can be driven from inside the notebook.
args = argparse.Namespace(
    exp_name=exp_name,
    config_file='GeospatialFM/configs/finetune_vit.yaml',
    opts=None,
    save_config=False,
    finetune_modal='optical',
    debug=True,
    finetune=True,
)
# After these calls `args` additionally carries distributed / world_size / rank /
# local_rank / device fields (see the displayed Namespace); `device` is the
# value returned by init_distributed_device.
device = init_distributed_device(args)
cfg, _ = setup(args)
args

Namespace(exp_name='ft', config_file='GeospatialFM/configs/finetune_vit.yaml', opts=None, save_config=False, finetune_modal='optical', debug=True, finetune=True, distributed=False, world_size=1, rank=0, local_rank=0, device='cuda:0')

In [31]:
# Inspect the MODEL section of the resolved config (per-modality OPTICAL / RADAR
# sub-configs, encoder kwargs and head settings).
cfg.MODEL

{'architecture': 'vit_base_patch16_224', 'cross_modal': True, 'unified_decoder': True, 'mask_ratio': 0.75, 'channel_mask_ratio': 0.0, 'freeze_encoder': False, 'use_siglip': False, 'OPTICAL': {'load_pretrained_from': None, 'pretrained_ckpt': None, 'freeze_encoder': False, 'channel_vit': False, 'kwargs': {'img_size': 224, 'patch_size': 16, 'in_chans': 13, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4.0, 'qkv_bias': True, 'drop_path_rate': 0.0, 'drop_path_uniform': False, 'init_values': None, 'num_register_tokens': 0}, 'use_head': True, 'head_kwargs': {'head_type': 'linear', 'task_type': 'classification', 'use_bias': True, 'in_features': 768, 'num_classes': 17}}, 'RADAR': {'load_pretrained_from': None, 'pretrained_ckpt': None, 'freeze_encoder': False, 'channel_vit': False, 'kwargs': {'img_size': 224, 'patch_size': 16, 'in_chans': 2, 'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4.0, 'qkv_bias': True, 'drop_path_rate': 0.0, 'drop_path_uniform': False, 'init_

In [14]:
# construct_downstream_models returns a mapping (presumably keyed by modality, as
# in cfg.MODEL); only the 'OPTICAL' model is kept and the mapping is released.
_downstream_models = construct_downstream_models(cfg.MODEL)
model = _downstream_models['OPTICAL']
del _downstream_models

In [15]:
# Display the module tree of the constructed downstream model.
# NOTE(review): this repr shows a ChannelViTEncoder with PatchEmbedPerChannel,
# while the cfg.MODEL output (executed later, In[31]) has OPTICAL.channel_vit
# False. Confirm the two agree after a fresh Restart & Run All.
model

ViTModel(
  (encoder): ChannelViTEncoder(
    (patch_embed): PatchEmbedPerChannel(
      (channel_pool): AdaptiveMaxPool1d(output_size=1)
      (proj): Conv3d(1, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
      (channel_embed): Embedding(13, 768)
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
       

In [16]:
# Load timm's MAE-pretrained ViT-Base/16 and keep only its weights; the model
# object itself is dropped once its state dict has been taken.
mae_reference_model = timm.create_model('vit_base_patch16_224.mae', pretrained=True)
mae_state_dict = mae_reference_model.state_dict()
del mae_reference_model

In [20]:
# Replace (not delete) the checkpoint's patch-embedding tensors with the encoder's
# own current ones before loading. The encoder uses a per-channel Conv3d
# projection (see the model repr), which presumably does not match the shape of
# timm's standard ViT patch projection. A shape mismatch makes load_state_dict
# raise even with strict=False. Substituting the encoder's own tensors keeps its
# existing patch-embedding initialisation while every other MAE weight loads.
encoder_state = model.encoder.state_dict()  # build once instead of once per key
for patch_key in ('patch_embed.proj.weight', 'patch_embed.proj.bias'):
    mae_state_dict[patch_key] = encoder_state[patch_key]

In [22]:
# Load the MAE weights into the encoder. strict=False tolerates keys that the
# checkpoint lacks (the per-channel embedding), so verify the returned report
# explicitly rather than letting any other mismatch pass silently.
load_report = model.encoder.load_state_dict(mae_state_dict, strict=False)
expected_missing = {'patch_embed.channel_embed.weight'}
assert set(load_report.missing_keys) <= expected_missing, load_report
assert not load_report.unexpected_keys, load_report
load_report

_IncompatibleKeys(missing_keys=['patch_embed.channel_embed.weight'], unexpected_keys=[])

In [19]:
# List the full model's state-dict keys. They carry the 'encoder.' prefix, unlike
# the keys used when loading into model.encoder directly.
# NOTE(review): this cell ran as In[19], i.e. before the weight-override (In[20])
# and load (In[22]) cells; re-run in order when reproducing.
model.state_dict().keys()

odict_keys(['encoder.cls_token', 'encoder.pos_embed', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.patch_embed.channel_embed.weight', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.2.norm