In [16]:
import os
import math
import numpy as np

import torch
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.utils.data import DataLoader
from datasets import load_dataset, load_dataset_builder, get_dataset_infos
from tqdm import tqdm
from argparse import Namespace

from GeospatialFM.datasets.GFMBench.utils import get_dataset, get_metadata
from GeospatialFM.data_process import modal_specific_collate_fn, get_transform
from GeospatialFM.finetune.utils import get_task_model, get_loss_fn
from GeospatialFM.finetune.args import parse_args
from GeospatialFM.models import SpatialSpectralLowRankViTConfig, SpatialSpectralMAEViT
from GeospatialFM.finetune.linear_probe import compute_encoding

from functools import partial
from accelerate import Accelerator
from transformers import TrainingArguments
from safetensors import safe_open

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import rasterio
import pandas as pd
import json
from sklearn.decomposition import PCA
import cv2

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def pca_vis_patch(patch, n_components=3):
    """Project patch embeddings onto principal components for visualization.

    One PCA is fitted on all ``B * L`` tokens of the batch at once, so the
    resulting components are shared by every image in ``patch``.

    Args:
        patch: Array-like or torch tensor of shape ``(B, L, D)``. ``L`` must be
            a perfect square so the tokens can be arranged on a square grid.
        n_components: Forwarded unchanged to ``sklearn.decomposition.PCA``.

    Returns:
        Array of shape ``(B, sqrt(L), sqrt(L), C)`` where ``C`` is the number
        of components produced by the PCA (``n_components`` when it is an int).

    Raises:
        ValueError: If ``patch`` is not 3-dimensional or ``L`` is not a
            perfect square.
    """
    B, L, D = patch.shape

    # Exact integer square root: int(L ** 0.5) silently truncates when L is
    # not a perfect square and only fails later with an opaque reshape error.
    side = math.isqrt(L)
    if side * side != L:
        raise ValueError(
            f"pca_vis_patch expects a square token grid, got L={L} tokens"
        )

    tokens = patch.reshape(B * L, D)
    # Torch tensors must be detached and moved to host memory before sklearn
    # can consume them; anything else is handed to sklearn as-is.
    if hasattr(tokens, "detach"):
        tokens = tokens.detach().cpu().numpy()

    pca = PCA(n_components=n_components)
    pca.fit(tokens)
    projected = pca.transform(tokens)

    # -1 keeps the channel dimension in sync with however many components
    # the PCA actually produced instead of assuming exactly three.
    return projected.reshape(B, side, side, -1)

def norm_image(image):
    """Min-max rescale ``image`` so its values span ``[0, 1]``.

    Args:
        image: Array or tensor exposing ``min()`` / ``max()`` and elementwise
            arithmetic.

    Returns:
        A new object of the same shape with the minimum mapped to 0 and the
        maximum mapped to 1. A constant image (zero dynamic range) is mapped
        to all zeros instead of NaN.
    """
    shifted = image - image.min()
    peak = shifted.max()
    # Guard against a flat image: dividing by a zero range would produce NaNs.
    scale = peak if peak != 0 else 1
    return shifted / scale

In [3]:
# Project root used to build the data/result paths below. Set the
# GEOSPATIALFM_ROOT environment variable to run the notebook on another
# machine; the original author's checkout location remains the default.
ROOT_DIR = os.environ.get("GEOSPATIALFM_ROOT", "/home/haozhesi/Dropbox/GeospatialFM")

In [27]:
# Build the run configuration by feeding an explicit argv list to the
# fine-tuning argument parser (one option per line, order preserved).
args = parse_args([
    # data / task
    "--data_dir", f"{ROOT_DIR}/data/geospatial",
    "--dataset_name", "eurosat",
    "--task_type", "classification",
    "--scale", "2",
    "--modal", "multi",
    "--return_dict",
    # optimization
    "--per_device_train_batch_size", "64",
    "--gradient_accumulation_steps", "4",
    "--num_train_epochs", "20",
    "--learning_rate", "3e-4",
    "--adam_weight_decay", "0.01",
    "--warmup_steps", "0",
    "--warmup_ratio", "0.2",
    # tracking / runtime
    "--report_to", "wandb",
    "--save_total_limit", "5",
    "--seed", "42",
    "--mixed_precision", "bf16",
    "--dataloader_num_workers", "32",
    "--dataloader_pin_memory",
    # output locations
    "--output_dir", f"{ROOT_DIR}/results/models",
    "--logging_dir", f"{ROOT_DIR}/results/logs",
    "--wandb_dir", f"{ROOT_DIR}/results/",
    "--run_name", "LESSVIT_b2_d4_eurosat",
    # schedule / model options
    "--lr_scheduler_type", "cosine",
    "--channel_embed_dims_per_head", "2",
    "--use_perception_field_mask",
    "--use_moe",
    "--num_experts", "3",
])

In [29]:
# Fetch the metadata for the selected dataset; when no crop size was given,
# fall back to the size recorded in that metadata.
metadata = get_metadata(args.dataset_name)
if args.crop_size is None:
    args.crop_size = metadata["size"]

# Normalization statistics: the "s2c" entry feeds the optical arguments and
# the "s1" entry feeds the radar arguments of get_transform.
optical_mean = metadata["s2c"]["mean"]
optical_std = metadata["s2c"]["std"]
radar_mean = metadata["s1"]["mean"]
radar_std = metadata["s1"]["std"]

# Collate function with the modality fixed to the one chosen in args.
collate_fn = partial(modal_specific_collate_fn, modal=args.modal)

# Build the train/eval transforms, then the dataset that uses them.
train_transform, eval_transform = get_transform(
    args.task_type,
    args.crop_size,
    args.scale,
    args.random_rotation,
    optical_mean,
    optical_std,
    radar_mean,
    radar_std,
)
dataset = get_dataset(args, train_transform, eval_transform)

In [30]:
# Show the object returned by get_dataset (its splits, features and row counts).
dataset

{'train': Dataset({
     features: ['optical', 'label', 'optical_channel_wv', 'spatial_resolution'],
     num_rows: 16200
 }),
 'val': Dataset({
     features: ['optical', 'label', 'optical_channel_wv', 'spatial_resolution'],
     num_rows: 5400
 }),
 'test': Dataset({
     features: ['optical', 'label', 'optical_channel_wv', 'spatial_resolution'],
     num_rows: 5400
 })}

In [10]:
# Decode the JSON stored in the train split's `info.description` field.
# Uses `dataset` (built by get_dataset above); the previous name `datasets`
# is never defined in this notebook and fails on a fresh kernel.
json.loads(dataset['train'].info.description)



{'s2c': {'bands': ['B1',
   'B2',
   'B3',
   'B4',
   'B5',
   'B6',
   'B7',
   'B8',
   'B8A',
   'B9',
   'B10',
   'B11',
   'B12'],
  'channel_wv': [442.7,
   492.4,
   559.8,
   664.6,
   704.1,
   740.5,
   782.8,
   832.8,
   864.7,
   945.1,
   1373.5,
   1613.7,
   2202.4],
  'mean': [1354.40546513,
   1118.24399958,
   1042.92983953,
   947.62620298,
   1199.47283961,
   1999.79090914,
   2369.22292565,
   2296.82608323,
   732.08340178,
   12.11327804,
   1819.01027855,
   1118.92391149,
   2594.14080798],
  'std': [245.71762908,
   333.00778264,
   395.09249139,
   593.75055589,
   566.4170017,
   861.18399006,
   1086.63139075,
   1117.98170791,
   404.91978886,
   4.77584468,
   1002.58768311,
   761.30323499,
   1231.58581042]},
 's1': {'bands': None, 'channel_wv': None, 'mean': None, 'std': None},
 'size': 64,
 'num_classes': 10,
 'spatial_resolution': 10}