In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import json

In [2]:
RESULTS_ROOT = './results/models'
model_name = "LESSVIT"

In [3]:
def clear_ckpts(dataset):
    import shutil
    cnt = 0
    dataset_dir = f"{RESULTS_ROOT}/{dataset}"
    for dir in os.listdir(dataset_dir):
        if os.path.exists(os.path.join(dataset_dir, dir, "test_results.json")):
            # remove all ckpts dirs
            for sub_dir in os.listdir(os.path.join(dataset_dir, dir)):
                if "checkpoint" in sub_dir:
                    shutil.rmtree(os.path.join(dataset_dir, dir, sub_dir))
                    cnt += 1
    print(f"Cleared {cnt} ckpts for {dataset}")

In [16]:
def collect_results(model_name, dataset, metric_name, filter_key:dict=None):
    dataset_dir = f"{RESULTS_ROOT}/{dataset}"
    
    df_all = []
    for dir in os.listdir(dataset_dir):
        target_file = os.path.join(dataset_dir, dir, "test_results.json")
        try:
            with open(target_file, 'r') as f:
                log = json.load(f)
        except:
            continue
        lp = False
        model_config = dir.split("_")
        if len(model_config) == 7:
            _, embed_dims, depth, _, lr, scale, moe = tuple(model_config) 
        elif len(model_config) == 8:
            _, embed_dims, depth, _, lr, scale, moe, lp = tuple(model_config)
            if lp == "lp":
                lp = True
        else:
            raise ValueError(f"Invalid model config: {dir}")
        log['embed_dims'] = int(embed_dims.replace("b", ""))
        log['depth'] = int(depth.replace("d", ""))
        log['lr'] = float(lr.replace("lr", ""))
        log['scale'] = float(scale.replace("scale", ""))
        log['moe'] = int(moe.replace("moe", ""))
        log['lp'] = lp
        df_all.append(log)
            
    df_all = pd.DataFrame(df_all)
    if filter_key is not None:
        for key, value in filter_key.items():
            df_all = df_all.loc[df_all[key] == value]
    df_all = df_all.reset_index()
    df_all_stat = df_all.groupby(['embed_dims', 'depth', 'scale', 'moe', 'lp']).apply(lambda x: x.loc[x[metric_name].idxmax()])
    df_all_stat = df_all_stat.drop(columns=['embed_dims', 'depth', 'scale', 'moe', 'lp'])
    df_all_stat = df_all_stat.reset_index().drop(columns=['index'])
    df_all_stat.sort_values(by=[metric_name], ascending=False)
    return df_all, df_all_stat

In [56]:
dataset = 'eurosat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": True, "scale": 2.0, "moe": 0})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,2.0,0,True,100.0,0.944444,0.173747,1.2154,4443.059,139.051,0.008
1,1,8,2.0,0,True,100.0,0.942222,0.191339,1.9857,2719.381,85.107,0.05
2,2,4,2.0,0,True,100.0,0.945556,0.172206,1.3675,3948.748,123.581,0.05
3,2,8,2.0,0,True,100.0,0.947222,0.205646,1.0972,4921.607,154.028,0.08


In [63]:
dataset = 'eurosat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "scale": 2.0})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,2.0,0,False,20.0,0.977593,0.105504,61.7935,87.388,0.696,5e-05
1,1,8,2.0,0,False,20.0,0.977407,0.099328,46.5186,116.083,0.924,5e-05


In [42]:
df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,1.0,0,False,9.612245,0.530126,0.668704,50.6193,122.147,0.968,0.0001
1,1,8,1.0,0,False,9.612245,0.547406,0.644652,48.763,126.797,1.005,8e-05
2,2,4,1.0,0,False,9.612245,0.503099,0.70816,47.8818,129.13,1.023,8e-05
3,2,8,1.0,0,False,9.612245,0.533339,0.720019,51.1565,120.864,0.958,0.0001


In [61]:
dataset = 'bigearthnet'
metric_name = 'eval_micro_mAP'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_loss,eval_micro_mAP,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,2,4,1.0,0,False,9.909953,0.16715,0.856657,915.419,137.496,1.075,8e-05
1,2,8,1.0,0,False,9.909953,0.168938,0.857083,1063.7797,118.32,0.925,0.0003


In [62]:
dataset = 'bigearthnet'
metric_name = 'eval_micro_mAP'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": True})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_loss,eval_micro_mAP,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,1.0,0,True,100.0,0.184154,0.822091,16.1357,7800.477,243.808,0.05
1,1,8,1.0,0,True,100.0,0.183818,0.82234,27.3648,4599.56,143.761,0.05
2,2,4,1.0,0,True,100.0,0.185422,0.820152,19.2884,6525.487,203.957,0.08
3,2,8,1.0,0,True,100.0,0.186232,0.818101,16.0193,7857.131,245.578,0.08


In [38]:
dataset = 'segmunich'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name)
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,1.0,0,False,9.994595,0.409236,0.709748,152.4886,64.569,0.505,0.0001
1,1,8,1.0,0,False,9.994595,0.412143,0.721112,489.7486,20.104,0.157,8e-05
2,2,4,1.0,0,False,9.994595,0.410334,0.724537,160.3658,61.397,0.48,0.0001
3,2,8,1.0,0,False,9.994595,0.410663,0.718143,142.164,69.258,0.542,0.0001


In [25]:
dataset = 'dfc2020'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name)
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,1.0,0,False,9.941538,0.448969,0.628994,85.5797,103.693,0.818,8e-05
1,1,8,1.0,0,False,9.941538,0.485194,0.597362,84.8937,104.531,0.825,3e-05
2,2,4,1.0,0,False,9.941538,0.487938,0.594833,80.0677,110.831,0.874,3e-05
3,2,8,1.0,0,False,9.941538,0.476684,0.521245,80.8232,109.795,0.866,0.0008


In [26]:
dataset = 'so2sat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": True, "scale": 4.0})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,4.0,0,True,100.0,0.617219,1.363282,6.5719,7350.488,229.765,0.005
1,1,8,4.0,0,True,100.0,0.608173,1.406054,6.1894,7804.733,243.964,0.005
2,2,4,4.0,0,True,100.0,0.635539,1.27124,6.3292,7632.399,238.577,0.005
3,2,8,4.0,0,True,100.0,0.625789,1.306871,8.0083,6032.11,188.554,0.005


In [40]:
dataset = 'marida'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name)
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,lr
0,1,4,1.0,0,False,9.612245,0.530126,0.668704,50.6193,122.147,0.968,0.0001
1,1,8,1.0,0,False,9.612245,0.547406,0.644652,48.763,126.797,1.005,8e-05
2,2,4,1.0,0,False,9.612245,0.503099,0.70816,47.8818,129.13,1.023,8e-05
3,2,8,1.0,0,False,9.612245,0.533339,0.720019,51.1565,120.864,0.958,0.0001


In [41]:
# clear_ckpts("bigearthnet")
# clear_ckpts("segmunich") 
# clear_ckpts("dfc2020")
# clear_ckpts("eurosat")
# clear_ckpts("marida")

Cleared 32 ckpts for marida
