# Visualize lambda weights

Visualize the calculated lambda weights ot understand whether the weight network always favors the original rotation or not.
Reproduce and extend the image in the Appendix D.5, p23.

In [1]:
%cd DL2-2024/

/teamspace/studios/this_studio/DL2-2024


In [2]:
%load_ext autoreload
%autoreload 2

import wandb
run = wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madamdivak[0m ([33mCV2-project[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
wandb_models = [
    {
        "model_name": "Lambda equitune (equivariant)",
        "model_link": "dl2-2024/dl-2024/Weighting_model:v39",
        "method": "equitune",
        "dataset_name": "CIFAR100"
    },
    {
        "model_name": "Lambda equiattention",
        "model_link": "dl2-2024/dl-2024/Weighting_model:v25",
        "method": "attention",
        "dataset_name": "CIFAR100"
    },
    {
        "model_name": "Lambda equitune (equivariant)",
        "model_link": "dl2-2024/dl-2024/Weighting_model:v37",
        "method": "equitune",
        "dataset_name": "ISIC2018"
    },
    {
        "model_name": "Lambda equiattention",
        "model_link": "dl2-2024/dl-2024/Weighting_model:v42",
        "method": "attention",
        "dataset_name": "ISIC2018"
    },
]

from EquiCLIP.visualize_lambda import main as visualize_lambda_main

for details in wandb_models:
    artifact = run.use_artifact(details["model_link"], type='model')
    artifact_dir = artifact.download()
    visualize_lambda_main([
        "--dataset_name", details["dataset_name"],
        "--method", details["method"],
        "--group_name", "rot90",
        "--data_transformations", "rot90",
        "--model_file", artifact.file(),
        "--model_display_name", details["model_name"],
        "--output_filename_suffix", details["model_name"]
    ])


Torch version: 2.0.1


[34m[1mwandb[0m:   1 of 1 files downloaded.  
Global seed set to 0


Namespace(seed=0, device='cuda:0', img_num=0, num_prefinetunes=10, data_transformations='rot90', group_name='rot90', method='equitune', model_name='RN50', dataset_name='CIFAR100', verbose=True, softmax=False, use_underscore=False, load=False, full_finetune=False, model_file='./artifacts/Weighting_model:v39/CIFAR100_RN50_aug_rot90_eq_rot90_steps_20.pt', output_filename_suffix='Lambda equitune (equivariant)', model_display_name='Lambda equitune (equivariant)', undersample=False, oversample=False, kaggle=False)
Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408
Files already downloaded and verified
Files already downloaded and verified
loaded zeroshot weights!
Loading model from ./artifacts/Weighting_model:v39/CIFAR100_RN50_aug_rot90_eq_rot90_steps_20.pt


100%|██████████| 1250/1250 [00:43<00:00, 28.52it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Global seed set to 0


Namespace(seed=0, device='cuda:0', img_num=0, num_prefinetunes=10, data_transformations='rot90', group_name='rot90', method='attention', model_name='RN50', dataset_name='CIFAR100', verbose=True, softmax=False, use_underscore=False, load=False, full_finetune=False, model_file='./artifacts/Weighting_model:v25/CIFAR100_RN50_aug_rot90_eq_rot90_steps_20.pt', output_filename_suffix='Lambda equiattention', model_display_name='Lambda equiattention', undersample=False, oversample=False, kaggle=False)


In [7]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from pathlib import Path
pd.options.plotting.backend = "plotly"

def plot_all_weights(full_df, output_dir):
    # Normalize weights
    # The raw values don't matter, as in the end the lambda * feature values are divided by their sum, 
    # so essentially lambda values are normalized
    group_columns = ["0", "90", "180", "270"]
    other_columns = set(full_df.columns) - set(group_columns)
    full_df[group_columns] = full_df[group_columns].div(full_df[group_columns].sum(axis=1), axis=0)
    
    dataset_names = full_df["dataset_name"].unique()
    # We could do it more elegantly, but originally I wrote the code for a single dataset and I want
    # to make the smallest changes now..
    for dataset_name in dataset_names:
        df = full_df[full_df["dataset_name"] == dataset_name]

        df_statistics = df.groupby("model_display_name")[group_columns].agg(["mean", "std"])
        df_statistics = df_statistics.stack(level=0).reset_index().rename({"level_1": "group_transformation"}, axis=1)
        fig = px.bar(
            df_statistics, 
            x="group_transformation",
            y="mean", 
            # error_y="std", 
            facet_col="model_display_name",
            title=f"Normalized λ weight values for each input of {dataset_name}",
            labels={
                "model_display_name": "Model", 
                # "value": "Lambda weights mean±std", 
                "group_transformation": "Group transformation (rotation, deg)",
                "mean": "Mean λ weight"
            },
        )
        fig.write_image(f"{output_dir}/lamba_weight_means_{dataset_name}.svg")
        display(fig)

        fig = df[group_columns + ["model_display_name"]].plot(
            kind='box', 
            title=f"Normalized λ weight values for each input of {dataset_name}",
            labels={
                "model_display_name": "Model", 
                "value": "Normalized λ weights", 
                "variable": "Group transformation (rotation, deg)"},
            facet_col="model_display_name"
        )
        fig.write_image(f"{output_dir}/lamba_weight_box_{dataset_name}.svg")
        display(fig)

        fig = df[group_columns + ["model_display_name"]].plot(
            kind='histogram', 
            title=f"Histogram of λ weight values for each input of {dataset_name}",
            labels={
                "model_display_name": "Model", 
                "value": "λ weights", 
                "variable": "Group"},
            facet_col="model_display_name",
            facet_row="variable"
        )
        fig.write_image(f"{output_dir}/lamba_weight_histogram_{dataset_name}.svg")
        display(fig)

        def get_nonstandard_rotation_has_highest_weight(df):
            df_nonstandard_rotation_has_highest_weight = df[df["0"] < df[["90", "180", "270"]].max(axis=1)]
            ratio_nonstandard_rotation_has_highest_weight = df_nonstandard_rotation_has_highest_weight.shape[0] / df.shape[0]
            return f"For {ratio_nonstandard_rotation_has_highest_weight * 100 :.2f}% of samples the highest lambda weight is not for the original rotation"

        nonstandard_rotation_has_highest_weight = df.groupby("model_display_name").apply(get_nonstandard_rotation_has_highest_weight)
        print(nonstandard_rotation_has_highest_weight)

output_dir = Path("results/lambda_weights")
all_dfs = []
for df_path in output_dir.glob("*.csv"):
    df = pd.read_csv(df_path, index_col=0)
    all_dfs.append(df)
df = pd.concat(all_dfs)

df_statistics = plot_all_weights(df, output_dir / "../plots")
df_statistics
#df





model_display_name
Lambda equiattention             For 79.79% of samples the highest lambda weigh...
Lambda equitune (equivariant)    For 82.90% of samples the highest lambda weigh...
dtype: object








model_display_name
Lambda equiattention             For 30.54% of samples the highest lambda weigh...
Lambda equitune (equivariant)    For 27.75% of samples the highest lambda weigh...
dtype: object






# 