In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import json
import numpy as np

In [2]:
# Root directory that holds one sub-directory per dataset; clear_ckpts and
# collect_results both look for run folders under f"{RESULTS_ROOT}/{dataset}".
RESULTS_ROOT = '/data/common/geospatial_fm/models'
# Substring a run directory name must contain for collect_results to include it.
model_name = "LESSVIT"

In [3]:
def clear_ckpts(dataset, results_root=None):
    """Delete the checkpoint directories of runs that already have test results.

    Every sub-directory of ``<results_root>/<dataset>`` that contains a
    ``test_results.json`` file is scanned, and each of its entries that is a
    real directory with ``"checkpoint"`` in its name is removed recursively.
    Plain files and symbolic links whose names contain ``"checkpoint"`` are
    left in place (``shutil.rmtree`` would raise on them and abort the
    clean-up half-way). Runs without ``test_results.json`` are not touched.

    Args:
        dataset: Name of the dataset sub-directory to clean.
        results_root: Directory holding one sub-directory per dataset.
            Defaults to the module-level ``RESULTS_ROOT``.

    Returns:
        None. The number of removed directories is printed.

    Raises:
        OSError: If the dataset directory cannot be listed (for example it
            does not exist) or a directory cannot be removed.
    """
    import shutil  # only needed here; not part of the notebook's import cell

    if results_root is None:
        results_root = RESULTS_ROOT
    dataset_dir = f"{results_root}/{dataset}"

    cnt = 0
    for run_name in sorted(os.listdir(dataset_dir)):
        run_dir = os.path.join(dataset_dir, run_name)
        # Only runs that produced test results are cleaned.
        if not os.path.exists(os.path.join(run_dir, "test_results.json")):
            continue
        for entry in os.listdir(run_dir):
            if "checkpoint" not in entry:
                continue
            entry_path = os.path.join(run_dir, entry)
            # rmtree refuses symlinks and non-directories, so skip them.
            if os.path.islink(entry_path) or not os.path.isdir(entry_path):
                continue
            shutil.rmtree(entry_path)
            cnt += 1
    print(f"Cleared {cnt} ckpts for {dataset}")

In [4]:
def _pop_int_option(tokens, prefix, default):
    """Pop the first token starting with ``prefix`` and return its integer value.

    ``tokens`` is modified in place. When no token starts with ``prefix`` the
    list is left untouched and ``default`` is returned.
    """
    idx = next((i for i, tok in enumerate(tokens) if tok.startswith(prefix)), None)
    if idx is None:
        return default
    return int(tokens.pop(idx).replace(prefix, ""))


def _parse_run_name(run_name):
    """Decode the hyper-parameters encoded in a run directory name.

    Returns a dict with the keys ``moe``, ``topk``, ``ckpt``, ``embed_dims``,
    ``depth``, ``lr``, ``scale``, ``lp``, ``modal`` and ``rank`` (in that
    order). Raises ``ValueError`` when the name matches neither known layout.
    """
    # Flags are substring tests on the untouched name.
    lp = "lp" in run_name
    if "radar" in run_name:
        modal = "radar"
    elif "multi" in run_name:
        modal = "multi"
    else:
        modal = "optical"

    # "LESSVIT-S..." becomes "LESSVIT_s..." so the size tag lands in its own token.
    tokens = run_name.replace("LESSVIT-S", "LESSVIT_s").split("_")

    # Optional tokens may appear anywhere; remove them before positional parsing.
    moe = _pop_int_option(tokens, "moe", 0)
    topk = _pop_int_option(tokens, "topk", 3)
    ckpt = _pop_int_option(tokens, "ckpt", 24600)

    try:
        # Layout with a rank token:
        #   <name>, s<embed>, d<depth>, r<rank>, <ignored>, lr<lr>, scale<scale>
        _, embed_tok, depth_tok, rank_tok, _, lr_tok, scale_tok = tokens[:7]
        parsed = {
            'embed_dims': int(embed_tok.replace("s", "")),  # TODO: choose between s and b
            'depth': int(depth_tok.replace("d", "")),
            'lr': float(lr_tok.replace("lr", "")),
            'scale': float(scale_tok.replace("scale", "")),
            'lp': lp,
            'modal': modal,
            'rank': int(rank_tok.replace("r", "")),
        }
    except ValueError:
        try:
            # Layout without a rank token:
            #   <name>, b<embed>, d<depth>, <ignored>, lr<lr>, scale<scale>
            _, embed_tok, depth_tok, _, lr_tok, scale_tok = tokens[:6]
            parsed = {
                'embed_dims': int(embed_tok.replace("b", "")),
                'depth': int(depth_tok.replace("d", "")),
                'lr': float(lr_tok.replace("lr", "")),
                'scale': float(scale_tok.replace("scale", "")),
                'lp': lp,
                'modal': modal,
                'rank': 1,
            }
        except ValueError as exc:
            raise ValueError(
                f"Cannot parse hyper-parameters from run directory name {run_name!r}"
            ) from exc

    return {'moe': moe, 'topk': topk, 'ckpt': ckpt, **parsed}


def collect_results(model_name, dataset, metric_name, filter_key: dict = None, results_root=None):
    """Load the test results of every run of a model and summarise the best ones.

    Each sub-directory of ``<results_root>/<dataset>`` whose name contains
    ``model_name`` and that holds a readable ``test_results.json`` contributes
    one row: the JSON content plus the hyper-parameters decoded from the
    directory name (``moe``, ``topk``, ``ckpt``, ``embed_dims``, ``depth``,
    ``lr``, ``scale``, ``lp``, ``modal``, ``rank``). Runs without a readable
    result file are skipped silently; a result file that is not valid JSON is
    skipped with a printed notice. Directories are visited in sorted order so
    the row order is reproducible.

    Args:
        model_name: Substring that a run directory name must contain.
        dataset: Name of the dataset sub-directory.
        metric_name: Column used to pick the best run of each group.
        filter_key: Optional ``{column: value}`` mapping; only rows where
            every listed column equals its value are kept.
        results_root: Directory holding one sub-directory per dataset.
            Defaults to the module-level ``RESULTS_ROOT``.

    Returns:
        ``(df_all, df_all_stat)``. ``df_all`` holds every selected run, with
        the pre-filter row number in an ``index`` column. ``df_all_stat``
        holds, for each combination of ``embed_dims``, ``depth``, ``scale``,
        ``moe``, ``lp``, ``modal`` and ``rank``, the run with the highest
        ``metric_name``; those seven columns come first and the rows are
        sorted by ``metric_name`` in descending order. Runs whose metric is
        missing are ignored for the summary. Both frames are empty when no run
        is selected.

    Raises:
        OSError: If the dataset directory cannot be listed.
        ValueError: If a run directory name cannot be parsed.
        KeyError: If ``metric_name`` or a ``filter_key`` column is absent.
    """
    if results_root is None:
        results_root = RESULTS_ROOT
    dataset_dir = f"{results_root}/{dataset}"
    group_keys = ['embed_dims', 'depth', 'scale', 'moe', 'lp', 'modal', 'rank']

    records = []
    for run_name in sorted(os.listdir(dataset_dir)):
        if model_name not in run_name:
            continue
        target_file = os.path.join(dataset_dir, run_name, "test_results.json")
        try:
            with open(target_file, 'r') as f:
                log = json.load(f)
        except OSError:
            # No result file: the run has not been evaluated, or this is not a run folder.
            continue
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            print(f"Skipping {target_file}: not valid JSON")
            continue
        log.update(_parse_run_name(run_name))
        records.append(log)

    df_all = pd.DataFrame(records)
    if filter_key is not None and not df_all.empty:
        for key, value in filter_key.items():
            df_all = df_all.loc[df_all[key] == value]
    df_all = df_all.reset_index()
    if df_all.empty:
        # Nothing to group: return empty frames instead of failing on missing columns.
        return df_all, df_all.drop(columns=['index'])

    # Select the best row of each group by label. This avoids groupby.apply,
    # whose handling of the grouping columns differs between pandas versions.
    scored = df_all.dropna(subset=[metric_name])
    if scored.empty:
        best = scored
    else:
        best_rows = scored.groupby(group_keys)[metric_name].idxmax()
        best = scored.loc[best_rows.values]

    stat_columns = group_keys + [
        col for col in df_all.columns if col not in group_keys and col != 'index'
    ]
    # sort_values returns a new frame; the result must be kept for the sort to apply.
    df_all_stat = (
        best[stat_columns]
        .sort_values(by=[metric_name], ascending=False, kind="mergesort")
        .reset_index(drop=True)
    )
    return df_all, df_all_stat

## Classification

In [None]:
# EuroSAT: keep runs with lp=False, scale=2.0 and ckpt=24600, then show the
# run with the highest eval_accuracy for each configuration group.
dataset, metric_name = 'eurosat', 'eval_accuracy'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"lp": False, "scale": 2.0, "ckpt": 24600},
)
df_all_stat

In [None]:
# BigEarthNet: keep runs with lp=False, ckpt=24600 and modal="optical", then
# show the run with the highest eval_micro_mAP for each configuration group.
dataset, metric_name = 'bigearthnet', 'eval_micro_mAP'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"lp": False, "ckpt": 24600, "modal": "optical"},
)
df_all_stat

In [None]:
# So2Sat: keep runs with lp=False, scale=4.0 and ckpt=94200, then show the
# run with the highest eval_accuracy for each configuration group.
dataset, metric_name = 'so2sat', 'eval_accuracy'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"lp": False, "scale": 4.0, "ckpt": 94200},
)
df_all_stat

## Segmentation

In [None]:
# SegMunich: keep runs with ckpt=94200, then show the run with the highest
# eval_IoU for each configuration group.
dataset, metric_name = 'segmunich', 'eval_IoU'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"ckpt": 94200},
)
df_all_stat

In [None]:
# DFC2020: keep runs with ckpt=24600, then show the run with the highest
# eval_IoU for each configuration group.
dataset, metric_name = 'dfc2020', 'eval_IoU'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"ckpt": 24600},
)
df_all_stat

In [None]:
# MARIDA: keep runs with ckpt=94200, then show the run with the highest
# eval_IoU for each configuration group.
dataset, metric_name = 'marida', 'eval_IoU'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"ckpt": 94200},
)
df_all_stat

In [None]:
# Landsat: keep runs with ckpt=73800, then show the run with the highest
# eval_IoU for each configuration group.
dataset, metric_name = 'landsat', 'eval_IoU'

df_all, df_all_stat = collect_results(
    model_name,
    dataset,
    metric_name,
    filter_key={"ckpt": 73800},
)
df_all_stat

In [None]:
# Remove the checkpoint directories of runs that already have a
# test_results.json, for the datasets listed in the tuple below.
# Currently disabled: bigearthnet, segmunich, marida, so2sat, landsat
# (add a name to the tuple to include it).
for ckpt_dataset in ("dfc2020", "eurosat"):
    clear_ckpts(ckpt_dataset)