This notebook runs a set of experiments to show the impact of the **margin** hyperparameter of *PiToMe*.

The **margin** *m* in the energy score of each node (token) acts as a dynamic threshold that determines whether two tokens belong to the same object.

In [1]:
import os
# Ask PyTorch to fall back to the CPU for operators the MPS backend does not implement.
# NOTE(review): this is set before the project imports below, presumably so it is in
# place before torch gets loaded — confirm the ordering requirement.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
# change working directory otherwise accessing data folder fails
# NOTE(review): the target is relative to the current directory, so re-running this
# cell in the same kernel moves up one more level each time.
os.chdir('..')
from evaluate import EvaluateArgs, evaluate
from tome_sam.utils.tome_presets import SAMToMeSetting, BSMToMe, ToMeConfig, PiToMe
from flops import get_flops
from tome_sam.build_tome_sam import SAM_CONFIGS
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product

In [2]:
# Empty, typed results table; update_experiment_results() adds one row per evaluated run.
EXPERIMENT_RESULTS = pd.DataFrame(
    {
        column_name: pd.Series(dtype=column_dtype)
        for column_name, column_dtype in (
            # run identification
            ("model", "str"),
            ("dataset", "str"),
            ("image_size", "object"),
            ("tome_settings", "object"),
            # quality metrics
            ("mask_iou", "float"),
            ("boundary_iou", "float"),
            # cost metrics
            ("FLOPS", "float"),
            ("im/s", "float"),
            # free-text label of the experiment the row belongs to
            ("Experiment", "str"),
        )
    }
)

def update_experiment_results(*, result_df: pd.DataFrame, evaluate_args: EvaluateArgs, eval_results, flops: dict, experiment: str) -> pd.DataFrame:
    """Return ``result_df`` extended by one row describing a finished evaluation run.

    ``result_df`` itself is not modified; the caller must rebind the returned frame.

    Args:
        result_df: Table of results collected so far.
        evaluate_args: Arguments the run was evaluated with. ``model_type``,
            ``dataset``, ``input_size`` and ``tome_setting`` are copied into the row.
        eval_results: Evaluation output, indexed by the keys ``"mask_iou"``,
            ``"boundary_iou"`` and ``"im/s"``.
        flops: FLOPs report, indexed by ``"flops/img(image_encoder)"``. It is
            subscripted with a string key, so it is a mapping rather than a bare
            float (annotated ``dict`` — confirm against ``get_flops``).
        experiment: Free-text label stored in the ``"Experiment"`` column.

    Returns:
        A new DataFrame with the extra row appended and the index renumbered.
    """
    new_row = {
        "model": evaluate_args.model_type,
        "dataset": evaluate_args.dataset,
        "image_size": evaluate_args.input_size,
        "tome_settings": evaluate_args.tome_setting,
        "mask_iou": eval_results["mask_iou"],
        "boundary_iou": eval_results["boundary_iou"],
        # Only the image-encoder cost is recorded.
        "FLOPS": flops["flops/img(image_encoder)"],
        "im/s": eval_results["im/s"],
        "Experiment": experiment,
    }
    return pd.concat([result_df, pd.DataFrame([new_row])], ignore_index=True)



def plot_heatmap(param1, param2, target_values, xlabel="param1", ylabel="param2", target_label="target_label",
                 cmap="RdYlGn", annot=True, fmt=".2f", title="Heatmap"):
    """Show a 2-D grid of target values as an annotated heatmap.

    Args:
        param1: Tick labels for the x axis (one per column of ``target_values``).
        param2: Tick labels for the y axis (one per row of ``target_values``).
        target_values: 2-D data handed to ``sns.heatmap``.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        target_label: Label shown next to the colour bar.
        cmap: Colormap name.
        annot: Whether to write the value into each cell.
        fmt: Format string used for the cell annotations.
        title: Figure title; defaults to the previously hard-coded ``"Heatmap"``.
    """
    plt.figure(figsize=(8, 6))
    ax = sns.heatmap(
        target_values,
        xticklabels=param1,
        yticklabels=param2,
        cmap=cmap,
        annot=annot,
        fmt=fmt,
        cbar_kws={'label': target_label},
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.show()


def plot_correlation_matrix(data, title="Correlation Matrix", cmap="RdYlGn", annot=True, fmt=".2f"):
    """Plot the pairwise correlations of ``data`` (via ``data.corr()``) as a heatmap.

    Args:
        data: Object exposing ``.corr()``, e.g. a pandas DataFrame.
        title: Figure title.
        cmap: Colormap name.
        annot: Whether to write the coefficient into each cell.
        fmt: Format string used for the cell annotations.
    """
    # Compute first, so a failing corr() does not leave an empty figure behind.
    coefficients = data.corr()
    colorbar_options = {'label': 'Correlation Coefficient'}

    plt.figure(figsize=(8, 6))
    sns.heatmap(coefficients, cmap=cmap, annot=annot, fmt=fmt, cbar_kws=colorbar_options)
    plt.title(title)
    plt.show()


def plot_line_graph(x_values, y_values, title, x_label, y_label):
    """Plot ``y_values`` against ``x_values`` as a single blue line with round markers.

    Args:
        x_values: Values along the x axis.
        y_values: Values along the y axis.
        title: Figure title.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
    """
    line_style = {'marker': 'o', 'linestyle': '-', 'color': 'blue', 'label': 'Line'}

    plt.figure(figsize=(8, 6))
    plt.plot(x_values, y_values, **line_style)

    # Decorations; the legend picks up the 'Line' label set on the plot call.
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()

    plt.tight_layout()
    plt.show()

#### Pure SAM

In [3]:
%%capture pure_sam
# Baseline: evaluate SAM ViT-B on DIS with token merging disabled (tome_setting=None).
evaluate_args = EvaluateArgs(
    # what to evaluate
    model_type="vit_b",
    checkpoint="checkpoints/sam_vit_b_01ec64.pth",
    tome_setting=None,
    # data
    dataset="dis",
    input_size=[1024, 1024],
    batch_size=1,
    multiple_masks=False,
    # runtime
    device="mps",
    seed=0,
    output="",
)

# Accuracy/throughput and FLOPs are measured from the same arguments.
eval_results = evaluate(evaluate_args)
flops_per_image = get_flops(evaluate_args)

EXPERIMENT_RESULTS = update_experiment_results(
    result_df=EXPERIMENT_RESULTS,
    evaluate_args=evaluate_args,
    eval_results=eval_results,
    flops=flops_per_image,
    experiment='baseline',
)

#### Global Attention 

Apply PiToMe, with different reduce rates **r** and **margin** values, only to the ViT blocks that perform global attention inside the SAM image encoder.

In [10]:
# Look up the ViT-B encoder layout and report which layers use global attention.
vit_b_config = SAM_CONFIGS["vit-b"]
print(
    f"{vit_b_config.model_type} has in total {vit_b_config.depth} vit layers.",
    f"Among all layers, it only calculates GLOBAL ATTENTION in {vit_b_config.global_attn_indexes}.",
    sep="\n",
)

vit-b has in total 12 vit layers.
Among all layers, it only calculates GLOBAL ATTENTION in [2, 5, 8, 11].


Apply the same combination of reduce rate **r** and **margin** to all global attention layers.

In [8]:
%%capture global_attn_1

# Sweep grids, built from integer tenths: every entry is the exact literal value
# (0.3 rather than the 0.30000000000000004 a float-step np.arange yields), is a plain
# Python float rather than a NumPy scalar, and the endpoints need no padded stop value.
r_values = [tenth / 10 for tenth in range(1, 6)]  # 0.1 <= r <= 0.5, step size 0.1
margin_values = [tenth / 10 for tenth in range(0, 11)]  # 0.0 <= margin <= 1.0, step size 0.1
alpha_value = 1.0  # alpha has a fixed value

# Grid sweep: each (r, margin) pair is applied unchanged to every global-attention layer.
# Nested loops visit the pairs in the same order as itertools.product (margin varies fastest).
for r in r_values:
    for margin in margin_values:
        test_pitome_settings: SAMToMeSetting = {
            global_layer_idx: ToMeConfig(mode='pitome', params=PiToMe(r=r, margin=margin, alpha=alpha_value))
            for global_layer_idx in vit_b_config.global_attn_indexes
        }

        evaluate_args = EvaluateArgs(
            dataset="dis",
            output="",
            model_type="vit_b",
            checkpoint="checkpoints/sam_vit_b_01ec64.pth",
            device="mps",
            seed=0,
            input_size=[1024, 1024],
            batch_size=1,
            multiple_masks=False,
            tome_setting=test_pitome_settings,
        )

        # Measure quality/throughput and FLOPs for this configuration, then record them.
        eval_results = evaluate(evaluate_args)
        flops_per_image = get_flops(evaluate_args)
        EXPERIMENT_RESULTS = update_experiment_results(
            result_df=EXPERIMENT_RESULTS,
            evaluate_args=evaluate_args,
            eval_results=eval_results,
            flops=flops_per_image,
            experiment='same r and margin for all global attn',
        )

--- Create valid dataloader with dataset dis ---
------------------------------ valid --------------------------------
--->>> dataset:  DIS5K-VD <<<---
-im- DIS5K-VD ./data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD ./data/DIS5K/DIS-VD/gt :  470
--- Valid dataloader with dataset dis created ---
--- Create SAM vit_b with token merging in layers {2: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.0, alpha=1.0)), 5: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.0, alpha=1.0)), 8: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.0, alpha=1.0)), 11: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.0, alpha=1.0))} ---
--- Start evaluation ---
valid dataloader length: 470
  [  0/470]  eta: 0:09:04  mask_iou: 0.7477 (0.7477)  boundary_iou: 0.5125 (0.5125)  time: 1.1596  data: 0.0531
  [200/470]  eta: 0:05:54  mask_iou: 0.5691 (0.4796)  boundary_iou: 0.4713 (0.4090)  time: 1.1937  data: 0.0857
  [400/470]  eta: 0:01:32  mask_iou: 0.4486 (0.4856)  boundary_iou: 0.4057 (

                                                                                                                                                                  

--- Create valid dataloader with dataset dis ---
------------------------------ valid --------------------------------
--->>> dataset:  DIS5K-VD <<<---
-im- DIS5K-VD ./data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD ./data/DIS5K/DIS-VD/gt :  470
--- Valid dataloader with dataset dis created ---
--- Create SAM vit_b with token merging in layers {2: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.1, alpha=1.0)), 5: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.1, alpha=1.0)), 8: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.1, alpha=1.0)), 11: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.1, alpha=1.0))} ---


  state_dict = torch.load(f)


--- Start evaluation ---
valid dataloader length: 470
  [  0/470]  eta: 0:10:13  mask_iou: 0.7659 (0.7659)  boundary_iou: 0.5468 (0.5468)  time: 1.3062  data: 0.0634
  [200/470]  eta: 0:05:39  mask_iou: 0.5793 (0.4796)  boundary_iou: 0.4831 (0.4093)  time: 1.1572  data: 0.0803


KeyboardInterrupt: 

The PiToMe paper also suggests *dynamically* adjusting the **margin** according to the layer order: $m = 0.9 - 0.9 \times l_i / l$, where $l_i$ is the current layer index and $l$ is the total number of encoder layers. With this formula the margin decreases linearly as tokens move to deeper layers (for the 12-layer ViT-B used here: from 0.75 at layer 2 down to 0.075 at layer 11).

In [12]:
# %%capture global_dynamic_attn

# Built from integer tenths so each entry is the exact literal value (0.3 rather than
# the 0.30000000000000004 a float-step np.arange yields) and a plain Python float.
r_values = [tenth / 10 for tenth in range(1, 6)]  # 0.1 <= r <= 0.5, step size 0.1
alpha_value = 1.0  # alpha has a fixed value

def get_dynamic_margin(current_layer_idx, total_layer_idx) -> float:
    """Return the layer-dependent margin ``0.9 - 0.9 * current_layer_idx / total_layer_idx``.

    The result is 0.9 for index 0 and falls linearly to 0.0 when the index equals
    ``total_layer_idx``. An index larger than ``total_layer_idx`` trips the assertion.
    """
    assert total_layer_idx >= current_layer_idx
    depth_fraction = current_layer_idx / total_layer_idx
    return 0.9 - 0.9 * depth_fraction
    
# Sweep r only: each global-attention layer gets its own margin from get_dynamic_margin().
for r in r_values:
    test_pitome_settings: SAMToMeSetting = {
        global_layer_idx: ToMeConfig(
            mode='pitome',
            params=PiToMe(
                r=r,
                margin=get_dynamic_margin(global_layer_idx, vit_b_config.depth),
                alpha=alpha_value,
            ),
        )
        for global_layer_idx in vit_b_config.global_attn_indexes
    }

    evaluate_args = EvaluateArgs(
        dataset="dis",
        output="",
        model_type="vit_b",
        checkpoint="checkpoints/sam_vit_b_01ec64.pth",
        device="mps",
        seed=0,
        input_size=[1024, 1024],
        batch_size=1,
        multiple_masks=False,
        tome_setting=test_pitome_settings,
    )

    # Measure quality/throughput and FLOPs for this configuration, then record them.
    eval_results = evaluate(evaluate_args)
    flops_per_image = get_flops(evaluate_args)
    EXPERIMENT_RESULTS = update_experiment_results(
        result_df=EXPERIMENT_RESULTS,
        evaluate_args=evaluate_args,
        eval_results=eval_results,
        flops=flops_per_image,
        experiment='same r but dynamic margin for all global attn',
    )

--- Create valid dataloader with dataset dis ---
------------------------------ valid --------------------------------
--->>> dataset:  DIS5K-VD <<<---
-im- DIS5K-VD ./data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD ./data/DIS5K/DIS-VD/gt :  470
--- Valid dataloader with dataset dis created ---
--- Create SAM vit_b with token merging in layers {2: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.75, alpha=1.0)), 5: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.525, alpha=1.0)), 8: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.30000000000000004, alpha=1.0)), 11: ToMeConfig(mode='pitome', params=PiToMe(r=0.1, margin=0.07500000000000007, alpha=1.0))} ---


  state_dict = torch.load(f)


--- Start evaluation ---
valid dataloader length: 470


  dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)


  [  0/470]  eta: 0:26:35  mask_iou: 0.7843 (0.7843)  boundary_iou: 0.5698 (0.5698)  im/s: 0.3100 (0.3100)  time: 3.3937  data: 0.0566
  [200/470]  eta: 0:05:27  mask_iou: 0.5834 (0.4828)  boundary_iou: 0.4677 (0.4106)  im/s: 1.1200 (1.0839)  time: 1.1553  data: 0.0868
  [400/470]  eta: 0:01:27  mask_iou: 0.5295 (0.4879)  boundary_iou: 0.3887 (0.4129)  im/s: 1.1200 (1.0750)  time: 1.2319  data: 0.1019
  [469/470]  eta: 0:00:01  mask_iou: 0.5906 (0.4875)  boundary_iou: 0.4451 (0.4128)  im/s: 1.1200 (1.0754)  time: 1.0970  data: 0.0658
 Total time: 0:09:46 (1.2477 s / it)
Averaged stats: mask_iou: 0.5906 (0.4875)  boundary_iou: 0.4451 (0.4128)  im/s: 1.1200 (1.0754)
--- Create valid dataloader with dataset dis ---
------------------------------ valid --------------------------------
--->>> dataset:  DIS5K-VD <<<---
-im- DIS5K-VD ./data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD ./data/DIS5K/DIS-VD/gt :  470
--- Valid dataloader with dataset dis created ---
--- Create SAM vit_b with token mergi

                                                                                                                                                                                

--- Create valid dataloader with dataset dis ---
------------------------------ valid --------------------------------
--->>> dataset:  DIS5K-VD <<<---
-im- DIS5K-VD ./data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD ./data/DIS5K/DIS-VD/gt :  470
--- Valid dataloader with dataset dis created ---
--- Create SAM vit_b with token merging in layers {2: ToMeConfig(mode='pitome', params=PiToMe(r=0.2, margin=0.75, alpha=1.0)), 5: ToMeConfig(mode='pitome', params=PiToMe(r=0.2, margin=0.525, alpha=1.0)), 8: ToMeConfig(mode='pitome', params=PiToMe(r=0.2, margin=0.30000000000000004, alpha=1.0)), 11: ToMeConfig(mode='pitome', params=PiToMe(r=0.2, margin=0.07500000000000007, alpha=1.0))} ---


  state_dict = torch.load(f)


--- Start evaluation ---
valid dataloader length: 470
  [  0/470]  eta: 0:21:39  mask_iou: 0.7567 (0.7567)  boundary_iou: 0.5126 (0.5126)  im/s: 0.3800 (0.3800)  time: 2.7646  data: 0.0561


KeyboardInterrupt: 