In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import json

In [2]:
RESULTS_ROOT = './results/models'
model_name = "LESSVIT"

In [3]:
def clear_ckpts(dataset):
    import shutil
    cnt = 0
    dataset_dir = f"{RESULTS_ROOT}/archived/{dataset}"
    for dir in os.listdir(dataset_dir):
        if os.path.exists(os.path.join(dataset_dir, dir, "test_results.json")):
            # remove all ckpts dirs
            for sub_dir in os.listdir(os.path.join(dataset_dir, dir)):
                if "checkpoint" in sub_dir:
                    shutil.rmtree(os.path.join(dataset_dir, dir, sub_dir))
                    cnt += 1
    print(f"Cleared {cnt} ckpts for {dataset}")

In [43]:
def collect_results(model_name, dataset, metric_name, filter_key:dict=None):
    dataset_dir = f"{RESULTS_ROOT}/{dataset}"
    
    df_all = []
    for dir in os.listdir(dataset_dir):
        if model_name not in dir:
            continue
        target_file = os.path.join(dataset_dir, dir, "test_results.json")
        try:
            with open(target_file, 'r') as f:
                log = json.load(f)
        except:
            continue
        lp = "lp" in dir
        modal = "optical"
        if "radar" in dir:
            modal = "radar"
        elif "multi" in dir:
            modal = "multi"
        model_config = dir.split("_")
            
        moe_idx = next((i for i, x in enumerate(model_config) if x.startswith("moe")), None)
        if moe_idx is not None:
            log['moe'] = int(model_config.pop(moe_idx).replace("moe", ""))
        else:
            log['moe'] = 0
            
        topk_idx = next((i for i, x in enumerate(model_config) if x.startswith("topk")), None)
        if topk_idx is not None:
            log['topk'] = int(model_config.pop(topk_idx).replace("topk", ""))
        else:
            log['topk'] = 3
        
        # find the config start with ckpt
        ckpt_idx = next((i for i, x in enumerate(model_config) if x.startswith("ckpt")), None)
        if ckpt_idx is not None:
            log['ckpt'] = int(model_config.pop(ckpt_idx).replace("ckpt", ""))
        else:
            log['ckpt'] = 24600
        
        try:
            model_config = model_config[:7]
            _, embed_dims, depth, rank, _, lr, scale = tuple(model_config) 
            log['embed_dims'] = int(embed_dims.replace("b", ""))
            log['depth'] = int(depth.replace("d", ""))
            log['lr'] = float(lr.replace("lr", ""))
            log['scale'] = float(scale.replace("scale", ""))
            log['lp'] = lp
            log['modal'] = modal
            log['rank'] = int(rank.replace("r", ""))
        except:
            model_config = model_config[:6]
            _, embed_dims, depth, _, lr, scale = tuple(model_config) 
        
            log['embed_dims'] = int(embed_dims.replace("b", ""))
            log['depth'] = int(depth.replace("d", ""))
            log['lr'] = float(lr.replace("lr", ""))
            log['scale'] = float(scale.replace("scale", ""))
            log['lp'] = lp
            log['modal'] = modal
            log['rank'] = 1
        df_all.append(log)
            
    df_all = pd.DataFrame(df_all)
    if filter_key is not None:
        for key, value in filter_key.items():
            df_all = df_all.loc[df_all[key] == value]
    df_all = df_all.reset_index()
    df_all_stat = df_all.groupby(['embed_dims', 'depth', 'scale', 'moe', 'lp', 'modal', 'rank']).apply(lambda x: x.loc[x[metric_name].idxmax()])
    df_all_stat = df_all_stat.drop(columns=['embed_dims', 'depth', 'scale', 'moe', 'lp', 'modal', 'rank'])
    df_all_stat = df_all_stat.reset_index().drop(columns=['index'])
    df_all_stat.sort_values(by=[metric_name], ascending=False)
    return df_all, df_all_stat

## Classification

In [51]:
dataset = 'eurosat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "scale": 2.0, "ckpt": 24600})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,modal,rank,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,topk,ckpt,lr
0,2,4,2.0,0,False,optical,2,19.84252,0.970185,0.109059,201.7549,26.765,0.838,3,24600,3e-05


In [45]:
dataset = 'bigearthnet'
metric_name = 'eval_micro_mAP'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": True, "ckpt": 73800, "modal": "optical"})
df_all_stat

FileNotFoundError: [Errno 2] No such file or directory: './results/models/bigearthnet'

In [46]:
# df_plot_stat = df_all_stat.copy()
# df_plot_stat.rename(columns={"modal": "Modal"}, inplace=True)
# df_plot_stat.rename(columns={"moe": "# Experts"}, inplace=True)
# df_plot_stat.rename(columns={"eval_micro_mAP": "mAP"}, inplace=True)
# # change multi to Optical+Radar
# df_plot_stat.loc[df_plot_stat["Modal"] == "multi", "Modal"] = "Optical+Radar"
# df_plot_stat.loc[df_plot_stat["Modal"] == "optical", "Modal"] = "Optical"

# # drop rows with # Experts = 0
# # df_plot_stat = df_plot_stat[df_plot_stat["# Experts"] != 0]
# # all the map x 100
# df_plot_stat["mAP"] = df_plot_stat["mAP"] * 100

# # set the figure size
# plt.figure(figsize=(10, 4))
# sns.set_theme()

# # change modal to Modal
# sns.lineplot(data=df_plot_stat, x="# Experts", y="mAP", hue="Modal", marker="o")
# # change x tick labels
# plt.xticks(df_plot_stat["# Experts"].unique(), df_plot_stat["# Experts"].unique())
# plt.xlabel("# Experts", fontweight="bold")
# plt.ylabel("mAP", fontweight="bold")
# # bold the title
# plt.title("Linear Probing With Mixture of Experts (MoE)", fontweight="bold")
# # # add an horizontal line at y = 82.87, the color is the same as Optical+Radar
# plt.show()
# # save as pdf
# # plt.savefig("moe_mAP.pdf", bbox_inches="tight", dpi=300, pad_inches=0)

In [47]:
dataset = 'so2sat'
metric_name = 'eval_accuracy'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"lp": False, "scale": 4.0, "ckpt": 94200})
df_all_stat

2,embed_dims,depth,scale,moe,lp,modal,rank,epoch,eval_accuracy,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,topk,ckpt,lr
0,2,8,4.0,0,False,optical,1,20.0,0.632496,1.405526,1142.4258,42.285,1.322,3,94200,1e-05


## Segmentation

In [48]:
dataset = 'segmunich'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 94200})
df_all_stat

5,embed_dims,depth,scale,moe,lp,modal,rank,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,topk,ckpt,lr
0,2,8,1.0,0,False,optical,1,9.954914,0.422907,0.687459,499.0044,19.731,0.617,3,94200,0.0003


In [49]:
dataset = 'dfc2020'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 24600})
df_all_stat

Unnamed: 0,embed_dims,depth,scale,moe,lp,modal,rank,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,topk,ckpt,lr
0,2,4,1.0,0,False,optical,2,9.984592,0.386357,0.864857,210.5819,42.14,1.32,3,24600,3e-05
1,2,4,1.0,0,False,optical,3,9.984592,0.390126,0.922564,203.5464,43.597,1.366,3,24600,3e-05
2,2,4,1.0,0,False,optical,5,9.984592,0.424618,0.798978,198.0917,44.797,1.403,3,24600,3e-05
3,2,4,1.0,0,False,optical,8,9.984592,0.473565,0.638685,197.644,44.899,1.407,3,24600,3e-05


In [50]:
dataset = 'marida'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 94200})
df_all_stat

7,embed_dims,depth,scale,moe,lp,modal,rank,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,topk,ckpt,lr
0,2,8,1.0,0,False,optical,1,9.795918,0.556429,0.698661,134.5889,45.94,1.441,3,94200,8e-05


In [31]:
dataset = 'landsat'
metric_name = 'eval_IoU'

df_all, df_all_stat = collect_results(model_name, dataset, metric_name, filter_key={"ckpt": 73800})
df_all_stat

KeyError: 'ckpt'

In [13]:
df_all

Unnamed: 0,index,epoch,eval_IoU,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,moe,topk,ckpt,embed_dims,depth,lr,scale,lp,modal
0,0,9.989339,0.232021,1.058459,155.5005,20.624,0.65,0,3,73800,2,4,0.0005,1.0,False,optical
1,1,9.989339,0.240693,1.038104,160.0591,20.036,0.631,0,3,73800,2,4,0.0001,1.0,False,optical
2,2,9.989339,0.113345,1.672956,180.3185,17.785,0.56,0,3,73800,2,4,0.0008,1.0,False,optical
3,3,9.836713,0.244034,1.017071,152.9754,20.964,0.66,0,3,73800,2,4,0.0003,1.0,False,optical
4,4,9.989339,0.237204,1.056266,195.0529,16.442,0.518,0,3,73800,2,4,8e-05,1.0,False,optical
5,5,9.836713,0.228334,1.091147,218.4693,14.679,0.462,0,3,73800,2,4,5e-05,1.0,False,optical
6,6,9.989339,0.229531,1.086186,164.8778,19.451,0.613,0,3,73800,2,4,5e-05,1.0,False,optical
7,7,9.836713,0.209672,1.146307,147.1823,21.789,0.686,0,3,73800,2,4,3e-05,1.0,False,optical
8,8,9.836713,0.234644,1.049432,188.5231,17.011,0.536,0,3,73800,2,4,0.0005,1.0,False,optical
9,9,9.989339,0.213849,1.132244,165.4537,19.383,0.61,0,3,73800,2,4,3e-05,1.0,False,optical


In [13]:
# clear_ckpts("bigearthnet")
# clear_ckpts("segmunich") 
clear_ckpts("dfc2020")
clear_ckpts("eurosat")
# clear_ckpts("marida")
# clear_ckpts("so2sat")
# clear_ckpts("landsat")

Cleared 0 ckpts for dfc2020
Cleared 8 ckpts for eurosat
