In [1]:
import os
import yaml
from torch.utils.data import DataLoader
import argparse

from GeospatialFM.data import get_datasets
from GeospatialFM.models import *
# from utils import load_config
from torchgeo.samplers import RandomGeoSampler
from matplotlib import pyplot as plt

from transformers import TrainingArguments, Trainer
from transformers import AdamW, get_linear_schedule_with_warmup
from GeospatialFM.utils import setup, get_eval_fn
from GeospatialFM.data import *

from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import numpy as np
from torch.utils.data import ConcatDataset

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda:0')

In [2]:
args = {'exp_name': None,
        'config_file': 'GeospatialFM/configs/eurosat.yaml',
        'opts': None, 
        'output_dir': './results/configs/eurosat', 
        'save_config': False}
args = argparse.Namespace(**args)
args.debug = True
args

Namespace(exp_name=None, config_file='GeospatialFM/configs/eurosat.yaml', opts=None, output_dir='./results/configs/eurosat', save_config=False, debug=True)

In [3]:
cfg, _ = setup(args)

In [13]:
cfg['MODEL']

{'architecture': 'vit_small_patch16_224', 'image_size': 224, 'patch_size': 16, 'bands': 13, 'num_classes': 10, 'load_pretrained': '', 'lp': False, 'load_pretrained_from': 'torchgeo', 'pretrained_ckpt': 'ViTSmall16_Weights.SENTINEL2_ALL_DINO'}

In [None]:
training_args = TrainingArguments(**cfg['TRAINER'])
model = construct_model(cfg['MODEL'])
model = model.to(device)
train_ds, val_ds, test_ds = get_datasets(cfg['DATASET'])
compute_metrics = get_eval_fn(cfg['DATASET'])

In [6]:
train_ds = ConcatDataset([train_ds, val_ds])
train_dl = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=8)
test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=8)

In [7]:
def extract_features(model, dataloader, device):
    x_all = []
    y_all = []

    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].numpy()
        
        with torch.inference_mode():
            features = model(images).cpu().numpy()
        
        x_all.append(features)
        y_all.append(labels)

    x_all = np.concatenate(x_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)

    return x_all, y_all

In [8]:
x_all, y_all = extract_features(model.base_model, train_dl, device)

100%|██████████| 43/43 [01:22<00:00,  1.92s/it]


In [9]:
x_test, y_test = extract_features(model.base_model, test_dl, device)

100%|██████████| 11/11 [00:30<00:00,  2.75s/it]


In [10]:
linear_model = LogisticRegression(C=50.0, max_iter=1000)
linear_model.fit(x_all, y_all)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [11]:
linear_model.score(x_test, y_test)

0.9377777777777778

In [4]:
model_cfg = cfg['MODEL']
model_cfg

{'architecture': 'vit_small_patch16_224', 'bands': 13, 'num_classes': 10, 'pretrained_ckpt': 'ViTSmall16_Weights.SENTINEL2_ALL_DINO', 'lp': False, 'head_extra_kwargs': {'use_bias': True}, 'load_pretrained_from': 'torchgeo'}

In [5]:
model = construct_model(model_cfg)

In [6]:
model

EncoderDecoder(
  (base_model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(13, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate

In [12]:
weights = tgm.get_weight(model_cfg['pretrained_ckpt'])
encoder = tgm.get_model(model_cfg['architecture'], weights=weights)

In [14]:
encoder.head.in_features

384

In [None]:
# get the last layer of the encoder
