In [2]:
import os
import yaml
from torch.utils.data import DataLoader
import argparse

from GeospatialFM.data import get_datasets
from GeospatialFM.models import *
# from utils import load_config
from torchgeo.samplers import RandomGeoSampler
from matplotlib import pyplot as plt

from transformers import TrainingArguments, Trainer
from transformers import AdamW, get_linear_schedule_with_warmup
from GeospatialFM.utils import setup, get_eval_fn
from GeospatialFM.data import *
from GeospatialFM.models import *

from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import numpy as np
from torch.utils.data import ConcatDataset
import segmentation_models_pytorch as smp

%load_ext autoreload
%autoreload 2

In [None]:
device = torch.device('cuda:0')

In [3]:
args = {'exp_name': None,
        'config_file': 'GeospatialFM/configs/crop.yaml',
        'opts': None, 
        'save_config': False}
args = argparse.Namespace(**args)
args.debug = True
args

Namespace(exp_name=None, config_file='GeospatialFM/configs/crop.yaml', opts=None, save_config=False, debug=True)

In [4]:
cfg, _ = setup(args)

In [5]:
cfg['MODEL']
optical_model = construct_model(cfg['MODEL'])

In [6]:
hasattr(optical_model, 'base_model')

True

In [7]:
cfg['SAR_MODEL']
sar_model = construct_model(cfg['SAR_MODEL'])

In [8]:
next(sar_model.parameters()).dtype

torch.float32

In [9]:
# crop = build_crop(cfg)

In [10]:
crop.encode_optical(torch.randn(1, 13, 256, 256)).shape

torch.Size([1, 2048])

In [11]:
train_ds, val_ds, test_ds = get_datasets(cfg['DATASET'])
dataloader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=8)

In [12]:
crop.to(device)
for batch in dataloader:
    optical = batch['image'].to(device)
    radar = batch['radar'].to(device)
    label = batch['label'].to(device)
    
    logits = crop(optical, radar)

    

In [13]:
logits = crop(optical, radar)

In [15]:
logits[0].shape

torch.Size([128, 128])

In [None]:
training_args = TrainingArguments(**cfg['TRAINER'])
model = model.to(device)
compute_metrics = get_eval_fn(cfg['DATASET'])

In [17]:
len(train_ds), len(test_ds)

(39341, 125866)

In [18]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params}')

Number of parameters: 23578323


In [21]:
train_ds[0]['image'].shape

torch.Size([3, 224, 224])

In [36]:
train_ds[0]['label']

tensor(4)

In [14]:
ret = model(train_ds[0]['image'].unsqueeze(0).to(device))

In [17]:
train_ds = ConcatDataset([train_ds, val_ds])
train_dl = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=8)
test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=8)

In [16]:
ret.shape

torch.Size([1, 19])

In [18]:
trainer = Trainer(
    model=model,                # the instantiated 🤗 Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    train_dataset=train_ds,    # training dataset
    eval_dataset=test_ds,      # evaluation dataset
    compute_metrics=compute_metrics,
)

trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mehzoahis[0m. Use [1m`wandb login --relogin`[0m to force relogin


KeyboardInterrupt: 

In [7]:
def extract_features(model, dataloader, device):
    x_all = []
    y_all = []

    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].numpy()
        
        with torch.inference_mode():
            features = model(images).cpu().numpy()
        
        x_all.append(features)
        y_all.append(labels)

    x_all = np.concatenate(x_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)

    return x_all, y_all

In [8]:
x_all, y_all = extract_features(model.base_model, train_dl, device)

100%|██████████| 43/43 [01:22<00:00,  1.92s/it]


In [9]:
x_test, y_test = extract_features(model.base_model, test_dl, device)

100%|██████████| 11/11 [00:30<00:00,  2.75s/it]


In [10]:
linear_model = LogisticRegression(C=50.0, max_iter=1000)
linear_model.fit(x_all, y_all)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [11]:
linear_model.score(x_test, y_test)

0.9377777777777778

In [4]:
model_cfg = cfg['MODEL']
model_cfg

{'architecture': 'vit_small_patch16_224', 'bands': 13, 'num_classes': 10, 'pretrained_ckpt': 'ViTSmall16_Weights.SENTINEL2_ALL_DINO', 'lp': False, 'head_extra_kwargs': {'use_bias': True}, 'load_pretrained_from': 'torchgeo'}

In [5]:
model = construct_model(model_cfg)

In [6]:
model

EncoderDecoder(
  (base_model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(13, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate

In [12]:
weights = tgm.get_weight(model_cfg['pretrained_ckpt'])
encoder = tgm.get_model(model_cfg['architecture'], weights=weights)

In [14]:
encoder.head.in_features

384

In [1]:
# get the last layer of the encoder

import timm

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
timm.list_models('vit_small_patch16_224')

['vit_small_patch16_224']

In [18]:
model = timm.create_model('vit_small_patch16_224', pretrained=False)

In [23]:
# state_dict = torch.load('/data/common/huggingface_models/SSL4EO/B2_vits16_mae_ep99.pth')['state_dict']
# state_dict = {k.replace('module.encoder_q.', ''): v for k, v in state_dict.items() if not k.startswith('module.encoder_q.fc')}

# torch.save(state_dict, '/data/common/huggingface_models/SSL4EO/B2_vits16_mae.pth')

state_dict = torch.load('/data/common/huggingface_models/SSL4EO/B2_vits16_mae_ep99.pth')['model']
torch.save(state_dict, '/data/common/huggingface_models/SSL4EO/B2_vits16_mae.pth')

In [9]:
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "fc.weight", "fc.bias". 
	size mismatch for conv1.weight: copying a param with shape torch.Size([64, 2, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).

In [21]:
for key in state_dict.keys():
    print(key)
    

cls_token
pos_embed
mask_token
decoder_pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qk

In [22]:
for key in model.state_dict().keys():
    print(key)

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qkv.bias
blocks.3.attn.proj.wei

In [29]:
# list all the parameter names in resnet

for (name, _), (weight_name, weight) in zip(resnet.named_parameters(), state_dict.items()):
    print(name, weight_name)


conv1.weight module.encoder_q.conv1.weight
bn1.weight module.encoder_q.bn1.weight
bn1.bias module.encoder_q.bn1.bias
layer1.0.conv1.weight module.encoder_q.bn1.running_mean
layer1.0.bn1.weight module.encoder_q.bn1.running_var
layer1.0.bn1.bias module.encoder_q.bn1.num_batches_tracked
layer1.0.conv2.weight module.encoder_q.layer1.0.conv1.weight
layer1.0.bn2.weight module.encoder_q.layer1.0.bn1.weight
layer1.0.bn2.bias module.encoder_q.layer1.0.bn1.bias
layer1.0.conv3.weight module.encoder_q.layer1.0.bn1.running_mean
layer1.0.bn3.weight module.encoder_q.layer1.0.bn1.running_var
layer1.0.bn3.bias module.encoder_q.layer1.0.bn1.num_batches_tracked
layer1.0.downsample.0.weight module.encoder_q.layer1.0.conv2.weight
layer1.0.downsample.1.weight module.encoder_q.layer1.0.bn2.weight
layer1.0.downsample.1.bias module.encoder_q.layer1.0.bn2.bias
layer1.1.conv1.weight module.encoder_q.layer1.0.bn2.running_mean
layer1.1.bn1.weight module.encoder_q.layer1.0.bn2.running_var
layer1.1.bn1.bias module.e

TypeError: 'dict_items' object is not subscriptable