In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import json

In [2]:
RESULTS_ROOT = './results/models'
model_name = "LESSVIT"

In [3]:
def clear_ckpts(dataset):
    import shutil
    cnt = 0
    dataset_dir = f"{RESULTS_ROOT}/{dataset}"
    for dir in os.listdir(dataset_dir):
        if os.path.exists(os.path.join(dataset_dir, dir, "test_results.json")):
            # remove all ckpts dirs
            for sub_dir in os.listdir(os.path.join(dataset_dir, dir)):
                if "checkpoint" in sub_dir:
                    shutil.rmtree(os.path.join(dataset_dir, dir, sub_dir))
                    cnt += 1
    print(f"Cleared {cnt} ckpts for {dataset}")

In [4]:
def collect_results(model_name, dataset, metric_name, filter_key:dict=None):
    dataset_dir = f"{RESULTS_ROOT}/{dataset}"
    
    df_all = []
    for dir in os.listdir(dataset_dir):
        target_file = os.path.join(dataset_dir, dir, "test_results.json")
        try:
            with open(target_file, 'r') as f:
                log = json.load(f)
        except:
            continue
        lp = False
        model_config = dir.split("_")
        if model_config[-1].startswith("ckpt"):
            log['ckpt'] = int(model_config.pop().replace("ckpt", ""))
        else:
            log['ckpt'] = 24600
        if len(model_config) == 7:
            _, embed_dims, depth, _, lr, scale, moe = tuple(model_config) 
        elif len(model_config) == 8:
            _, embed_dims, depth, _, lr, scale, moe, lp = tuple(model_config)
            if lp == "lp":
                lp = True
        else:
            raise ValueError(f"Invalid model config: {dir}")
        log['embed_dims'] = int(embed_dims.replace("b", ""))
        log['depth'] = int(depth.replace("d", ""))
        log['lr'] = float(lr.replace("lr", ""))
        log['scale'] = float(scale.replace("scale", ""))
        log['moe'] = int(moe.replace("moe", ""))
        log['lp'] = lp
        df_all.append(log)
            
    df_all = pd.DataFrame(df_all)
    if filter_key is not None:
        for key, value in filter_key.items():
            df_all = df_all.loc[df_all[key] == value]
    df_all = df_all.reset_index()
    df_all_stat = df_all.groupby(['embed_dims', 'depth', 'scale', 'moe', 'lp']).apply(lambda x: x.loc[x[metric_name].idxmax()])
    df_all_stat = df_all_stat.drop(columns=['embed_dims', 'depth', 'scale', 'moe', 'lp'])
    df_all_stat = df_all_stat.reset_index().drop(columns=['index'])
    df_all_stat.sort_values(by=[metric_name], ascending=False)
    return df_all, df_all_stat

## Classification

In [5]:
dataset = 'eurosat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "scale": 2.0, "ckpt": 73800})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,2.0,0,False,20.0,0.981111,0.101385,52.0294,103.788,0.826,73800,8e-05
1,2,4,2.0,3,False,20.0,0.979815,0.101409,53.6486,100.655,0.802,73800,5e-05


In [6]:
dataset = 'bigearthnet'
metric_name = 'eval_micro_mAP'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "ckpt": 73800})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_loss,eval_micro_mAP,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,1.0,0,False,9.909953,0.164787,0.860342,1003.9924,125.365,0.98,73800,8e-05
1,2,4,1.0,3,False,9.909953,0.163694,0.862077,963.0682,130.693,1.022,73800,0.0005


In [23]:
dataset = 'so2sat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "scale": 4.0, "ckpt": 73800})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,4.0,0,False,20.0,0.635891,2.228079,361.368,133.678,1.046,73800,5e-05


## Segmentation

In [8]:
dataset = 'segmunich'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 73800})
df_all_stat

6,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,1.0,0,False,9.994595,0.419522,0.694597,120.4717,81.729,0.639,73800,0.0003


In [9]:
dataset = 'dfc2020'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 73800})
df_all_stat

2,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,1.0,0,False,9.941538,0.441486,0.781353,76.1034,116.604,0.92,73800,1e-05


In [11]:
dataset = 'marida'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 73800})
df_all_stat

6,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,ckpt,lr
0,2,4,1.0,0,False,9.612245,0.531024,0.738966,53.3294,115.94,0.919,73800,5e-05


In [11]:
dataset = 'landsat'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name)
df_all_stat

2,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,2,4,1.0,0,False,9.861314,0.241159,1.055351,67.7246,55.371,0.443,0.0001


In [12]:
# clear_ckpts("bigearthnet")
# clear_ckpts("segmunich") 
# clear_ckpts("dfc2020")
# clear_ckpts("eurosat")
# clear_ckpts("marida")
clear_ckpts("so2sat")

Cleared 32 ckpts for so2sat
