In [1]:
import os
import math
import numpy as np

import torch
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm
from argparse import Namespace

from GeospatialFM.datasets.GFMBench.utils import get_dataset, get_metadata
from GeospatialFM.data_process import apply_normalization, modal_specific_collate_fn, get_transform
from GeospatialFM.finetune.utils import get_task_model, get_loss_fn
# from GeospatialFM.models import SpatialSpectralLowRankViTConfig, SpatialSpectralMAEViT
# from GeospatialFM.models.low_rank_attention import get_perception_field_mask
# from GeospatialFM.models import PositionalChannelEmbedding
# from GeospatialFM.datasets import SSL4EODataset
# from GeospatialFM.scripts.trainer import MAETrainer

from functools import partial
from accelerate import Accelerator
from transformers import TrainingArguments

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import rasterio
import pandas as pd
import json

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# All experiment settings live in this one Namespace so later cells stay parameter-free.
# Local paths are resolved against GFM_ROOT, which can be overridden from the environment
# (e.g. `export GFM_ROOT=/path/to/GeospatialFM`) instead of editing this cell. The default
# reproduces the original machine-specific location exactly.
GFM_ROOT = os.environ.get("GFM_ROOT", "/home/haozhesi/Dropbox/GeospatialFM")

args = Namespace(
    # --- data ---
    dataset_name="eurosat",
    task_type="classification",
    scale=2,  # forwarded to get_transform together with crop_size
    data_dir=os.path.join(GFM_ROOT, "data", "geospatial"),
    crop_size=None,  # None -> use the dataset's native size from metadata["size"]
    modal="optical",
    return_dict=True,
    # --- pretrained checkpoint (encoder tensors are copied from this file) ---
    pretrained_model_path=os.path.join(
        GFM_ROOT, "results", "models", "LESSVIT_b2_d6", "checkpoint-33000", "model.safetensors"
    ),
    # --- encoder architecture ---
    patch_size=16,
    embed_dim=768,
    channel_embed_dims_per_head=2,
    depth=12,
    num_heads=12,
    # --- decoder architecture ---
    decoder_embed_dim=512,
    decoder_depth=6,
    decoder_num_heads=16,
    decoder_channel_embed_dims_per_head=2,
    # --- attention / embedding options ---
    use_perception_field_mask=True,
    attention_radius=320,
    norm_pix_loss=False,
    decoder_out_chans=15,
    pos_chan_embed_residual=True,
)

In [3]:
metadata = get_metadata(args.dataset_name)

# Per-band normalisation statistics: the "s2c" entry feeds the optical_* names and "s1" the radar_* names.
# In the recorded eurosat output the radar statistics are None.
s2c_stats = metadata["s2c"]
s1_stats = metadata["s1"]
optical_mean = s2c_stats["mean"]
optical_std = s2c_stats["std"]
radar_mean = s1_stats["mean"]
radar_std = s1_stats["std"]

# Means first, then standard deviations, each as an (optical, radar) pair; then image size and class count.
for optical_stat, radar_stat in ((optical_mean, radar_mean), (optical_std, radar_std)):
    print(optical_stat, radar_stat)
print(metadata["size"], metadata["num_classes"])

[1354.40546513, 1118.24399958, 1042.92983953, 947.62620298, 1199.47283961, 1999.79090914, 2369.22292565, 2296.82608323, 732.08340178, 12.11327804, 1819.01027855, 1118.92391149, 2594.14080798] None
[245.71762908, 333.00778264, 395.09249139, 593.75055589, 566.4170017, 861.18399006, 1086.63139075, 1117.98170791, 404.91978886, 4.77584468, 1002.58768311, 761.30323499, 1231.58581042] None
64 10


In [4]:
# Bundle the per-modality statistics once, then freeze them (plus use_8bit=False) into apply_normalization.
norm_stats = {
    "optical_mean": optical_mean,
    "optical_std": optical_std,
    "radar_mean": radar_mean,
    "radar_std": radar_std,
}
standard_transform = partial(apply_normalization, use_8bit=False, **norm_stats)

# Batch collation for the configured modality; the normalisation above is handed to it as `normalization`.
collate_fn = partial(modal_specific_collate_fn, modal=args.modal, normalization=standard_transform)

In [5]:
# Fall back to the dataset's native image size unless a crop size was configured explicitly.
if args.crop_size is None:
    crop_size = metadata["size"]
else:
    crop_size = args.crop_size

train_transform, eval_transform = get_transform(
    args.task_type,
    crop_size=crop_size,
    scale=args.scale,
)


In [6]:
# Build the dataset splits from the configuration and the transforms created above.
# Later cells index the result by split name (e.g. "val"); the returned container type is
# defined by get_dataset — TODO confirm which splits it provides.
dataset = get_dataset(args, train_transform, eval_transform)

In [7]:
# dataset["train"].set_format(type='torch', columns=['optical', 'label', 'optical_channel_wv', 'spatial_resolution'])
# dataset["train"].set_transform(train_transform)

In [8]:
# Peek at a single validation sample before batch collation. Normalisation is applied later by
# collate_fn, so the recorded output still shows raw sensor values here.
SAMPLE_INDEX = 1
dataset["val"][SAMPLE_INDEX]

{'optical': tensor([[[1179.0000, 1179.0217, 1179.0686,  ..., 1200.2825, 1199.2920,
           1198.8234],
          [1179.0000, 1179.0219, 1179.0688,  ..., 1200.2639, 1199.2874,
           1198.8254],
          [1179.0000, 1179.0219, 1179.0687,  ..., 1200.2242, 1199.2775,
           1198.8296],
          ...,
          [1220.5801, 1220.6960, 1220.9436,  ..., 1216.6517, 1216.1025,
           1215.8430],
          [1221.6642, 1221.7920, 1222.0649,  ..., 1216.7422, 1216.1633,
           1215.8899],
          [1222.1764, 1222.3098, 1222.5947,  ..., 1216.7845, 1216.1918,
           1215.9117]],
 
         [[ 866.0000,  865.9124,  865.7252,  ...,  895.4809,  876.9927,
            868.2646],
          [ 866.1971,  866.1090,  865.9208,  ...,  895.3577,  876.8964,
            868.1810],
          [ 866.6183,  866.5292,  866.3388,  ...,  895.0944,  876.6904,
            868.0020],
          ...,
          [ 894.3511,  894.6378,  895.2505,  ...,  919.2675,  904.9047,
            898.1255],
      

In [23]:
# Pull one validation batch to smoke-test collation, the model forward pass and the loss.
dataloader = DataLoader(dataset["val"], batch_size=4, collate_fn=collate_fn)

# next(iter(...)) fetches exactly one batch. The previous `for ...: break` loop (with an unused
# enumerate index) never rebinds `batch` when the loader is empty, so a stale batch from an earlier
# run of this cell could silently survive. Here an empty loader raises StopIteration instead.
batch = next(iter(dataloader))
for key, value in batch.items():
    print(key, value.shape)

# Split off the targets so the remaining entries can be passed to the model as keyword arguments.
labels = batch.pop("labels")

spatial_resolution ()
labels torch.Size([4])
optical torch.Size([4, 13, 128, 128])
optical_channel_wv torch.Size([1, 13])


In [24]:
model = get_task_model(args, metadata["num_classes"], metadata["size"])

if args.pretrained_model_path:
    from safetensors import safe_open

    # Fetch the target state dict once. state_dict() builds a fresh mapping on every call, so calling
    # it inside the loop made loading quadratic in the number of tensors. With the default
    # keep_vars=False its values are detached tensors that share storage with the model's
    # parameters/buffers, so copy_ below writes straight into the model.
    model_state = model.state_dict()
    loaded_keys = []
    with safe_open(args.pretrained_model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            # Transfer only tensors under the "encoder." prefix. Every other entry keeps whatever
            # get_task_model initialised it to.
            if not key.startswith("encoder."):
                continue
            if key not in model_state:
                raise KeyError(f"checkpoint tensor {key!r} has no counterpart in the task model")
            param = f.get_tensor(key)
            target = model_state[key]
            # Tensor.copy_ broadcasts its source, so a mismatched-but-broadcastable checkpoint tensor
            # would otherwise be loaded silently and wrongly. Require an exact shape match instead.
            if param.shape != target.shape:
                raise ValueError(
                    f"shape mismatch for {key!r}: checkpoint {tuple(param.shape)} "
                    f"vs model {tuple(target.shape)}"
                )
            target.copy_(param)
            loaded_keys.append(key)

    # A wrong path or key prefix would otherwise leave the encoder at its freshly constructed weights.
    assert loaded_keys, f"no 'encoder.' tensors found in {args.pretrained_model_path}"
    print(f"loaded {len(loaded_keys)} encoder tensors from {args.pretrained_model_path}")

In [25]:
# Inspect the smoke-test batch: the popped targets, plus the model inputs produced by collate_fn
# (which applied the normalisation configured earlier).
labels, batch

(tensor([3, 9, 9, 0]),
 {'spatial_resolution': array(5.),
  'optical': tensor([[[[0.6135, 0.6137, 0.6140,  ..., 0.6462, 0.6446, 0.6438],
            [0.6135, 0.6137, 0.6140,  ..., 0.6462, 0.6446, 0.6438],
            [0.6135, 0.6137, 0.6140,  ..., 0.6464, 0.6447, 0.6439],
            ...,
            [0.5748, 0.5749, 0.5749,  ..., 0.5578, 0.5596, 0.5604],
            [0.5787, 0.5788, 0.5789,  ..., 0.5567, 0.5584, 0.5592],
            [0.5806, 0.5806, 0.5807,  ..., 0.5562, 0.5579, 0.5587]],
  
           [[0.5081, 0.5078, 0.5071,  ..., 0.4602, 0.4941, 0.5101],
            [0.5079, 0.5076, 0.5069,  ..., 0.4607, 0.4945, 0.5104],
            [0.5075, 0.5072, 0.5065,  ..., 0.4617, 0.4953, 0.5111],
            ...,
            [0.5898, 0.5890, 0.5873,  ..., 0.5208, 0.5304, 0.5350],
            [0.5272, 0.5226, 0.5126,  ..., 0.5287, 0.5325, 0.5343],
            [0.4977, 0.4912, 0.4774,  ..., 0.5324, 0.5335, 0.5340]],
  
           [[0.4823, 0.4823, 0.4823,  ..., 0.3961, 0.4425, 0.4645],
     

In [26]:
# Forward pass on the single batch; every remaining batch entry becomes a keyword argument.
# Gradients are tracked here (the recorded logits carry a grad_fn), so the loss computed later is differentiable.
# NOTE(review): model.train()/model.eval() is never called in this notebook, so the model runs in
# whatever mode get_task_model returned it in — confirm that is intended.
outputs = model(**batch)

In [27]:
# Model output: a dict with a "logits" entry, one row per sample and one column per class
# (4 x 10 for this batch in the recorded output).
outputs

{'logits': tensor([[ 2.1950,  2.2815,  2.1812,  0.4910,  0.4998, -0.5561, -0.6597, -0.0917,
           1.2278,  0.2942],
         [ 0.4912,  2.7523, -0.9729,  1.8996,  2.9218,  0.2955, -0.5348, -0.0702,
           0.9565, -0.9569],
         [ 0.1277,  1.4904, -3.4178,  0.4776,  0.1071,  0.5749, -0.0490, -0.7131,
          -0.8056, -0.8754],
         [ 1.6149,  0.4109, -1.0782, -0.2697, -1.3952,  2.5292,  1.7501, -0.7728,
           0.3800, -1.5824]], grad_fn=<AddmmBackward0>)}

In [28]:
# Look up the loss for the configured task type. For "classification", the recorded loss below
# carries an NllLossBackward0 grad_fn, suggesting a cross-entropy-style loss — confirm in get_loss_fn.
loss_fn = get_loss_fn(args.task_type)

In [29]:
# Evaluate the loss on the smoke-test batch. The full output dict (not just outputs["logits"]) is passed.
# NOTE(review): the third positional argument is passed as None; what it controls is defined by
# get_loss_fn's returned callable — TODO confirm.
loss_fn(outputs, labels, None)


tensor(3.2763, grad_fn=<NllLossBackward0>)