# Ranking des concepts par LIG-score

---
Goal of the notebook: Ranker des concepts par LIG-score.

Inputs of the notebook:
-.
Output of the notebook:

-.
Takeaways: 
- .
- .

In [1]:
# !pip install -r requirements.txt

In [2]:
import os
import sys
sys.path.append('../../run_experiments/')
sys.path.append('../../run_experiments/scripts')
sys.path.append('../../run_experiments/models')
sys.path.append('../../run_experiments/data')

import torch
import numpy as np
import pandas as pd

# import fonction for getting PLM and tokenizer
from models.utils import load_model_and_tokenizer

# library for managing memory RAM
import gc

import pickle
import json


In [3]:
torch.cuda.empty_cache()
gc.collect()

44

# 0. autoreload

In [4]:
#code for autoreload script associated with jupyter notebook
%load_ext autoreload
%autoreload 2

## 1.SETUP ENVIRONMENT VARIABLES

In [5]:
# import config
from load_config import load_config

model_name = 'gemma'    # 'bert-base-uncased' ou 'deberta-large' or 'gemma'
dataset    = 'movies'               # 'movies' / 'agnews' / 'dbpedia' / 'medical'/ 'ledgar'/ n24news
annotation = 'C3M'       # 'C3M' ou 'our_annotation' ou 'combined_annotation'
config = load_config(model_name, dataset)
config.annotation = annotation

# additionnal parameter for this notebook
agg_mode = "abs"
agg_scope = "all"

# 2. Import Dataframe et dataloader

In [6]:
# charger les données de train (facultatif : validation et de test augmented)
df_aug_train = pd.read_csv(f"{config.SAVE_PATH_CONCEPTS}/df_with_topics_v4.csv")

#### Black Box Model

In [7]:
from models.BaselineModel import BaselineModel

# import the PLM model and tokenizer and bottleneck layer
embedder_model, embedder_tokenizer, _, classifier = load_model_and_tokenizer(config, n_concepts = 4)

black_box_model = BaselineModel(embedder_model, classifier, None, None, None, config, save_path = config.SAVE_PATH)
black_box_model.load_model()

# Fonction pour vérifier si les paramètres du modèle sont en float64
def check_parameters_dtype(model):
    for param in model.parameters():
        if param.dtype != torch.float64:
            return False
    return True
    
# Fonction pour convertir les paramètres du modèle en float64
def convert_parameters_to_float64(model):
    for param in model.parameters():
        param.data = param.data.double()

# Supposons que `black_box_model` est l'instance du modèle
convert_parameters_to_float64(black_box_model)

# Vérifier si tous les paramètres sont maintenant en float64
is_float64 = check_parameters_dtype(black_box_model)

print(f"Tous les paramètres sont maintenant en float64 : {is_float64}")



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.classifier.load_state_dict(torch.load(f"{self.save_path}blue_checkpoints/{self.config.model_name}/BaselineModel/{self.embedder_model_name}_classifier_state_dict.pth"))
  self.embedder_model.load_state_dict(torch.load(f"{self.save_path}blue_checkpoints/{self.config.model_name}/BaselineModel/{self.embedder_model_name}_state_dict.pth"))


Aucune performance enregistrée trouvée.
Tous les paramètres sont maintenant en float64 : True


#### define function to wrap LIG

In [17]:
###############################################
# Fonction 0 : Wrapper le modèle black box pour captum
##############################################
# !!! attention black_box_model utiliser comme variable local ci-dessous !!!

from captum.attr import LayerIntegratedGradients

if config.model_name == 'gemma':
    def forward_LIG_black_box(input_ids, attention_mask=None):
        outputs = black_box_model.embedder_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs[0][:, -1, :]
        logits = black_box_model.classifier(pooled)
    lig = LayerIntegratedGradients(
        forward_LIG_black_box,
        layer=black_box_model.embedder_model.layers[-1],
        # attribute_to_layer_input=True 
    )
else:   
    def forward_LIG_black_box(input_ids, attention_mask):
        outputs = black_box_model.embedder_model(input_ids=input_ids.to(config.device), attention_mask=attention_mask.to(config.device))
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits = black_box_model.classifier(pooled_output)
        return logits
    lig = LayerIntegratedGradients(
        forward_LIG_black_box,
        layer = black_box_model.embedder_model.encoder.layer[-1]
    )

# Automation

In [18]:
import json
from LIG_ranking import postprocess_cosine, compute_cosine_similarities
if config.model_name == 'gemma':
    from LIG_ranking import compute_attributions_on_gemma as compute_attributions
else:
    from LIG_ranking import compute_attributions

In [22]:
def main(mode="clip", agg_scope="present"):
    BATCH_SIZE = 1
    # Chemin pour sauvegarder le DataFrame des attributions
    path_attr = f"{config.SAVE_PATH}/blue_checkpoints/{config.model_name}/cavs/{config.cavs_type}/attributions_df_{config.cavs_type}_{config.annotation}.pkl"
    os.makedirs(os.path.dirname(path_attr), exist_ok=True)
    
    # Si le fichier existe déjà, on le charge, sinon on calcule et on sauvegarde
    if os.path.exists(path_attr):
        print("attributions_df déjà présent, chargement...")
        with open(path_attr, "rb") as f:
            attributions_df = pickle.load(f)
    else:
        attributions_df = compute_attributions(df_aug_train, BATCH_SIZE, embedder_tokenizer, lig, config.device)
        with open(path_attr, "wb") as f:
            pickle.dump(attributions_df, f)
        print("attributions_df sauvegardé à", path_attr)
    
    # Chargement des vecteurs CAV à partir d'un fichier JSON et conversion sur GPU
    file_path = f"{config.SAVE_PATH}/blue_checkpoints/{config.model_name}/cavs/{config.cavs_type}/cavs_{config.cavs_type}_{config.annotation}.json"
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            cavs_vectors = json.load(f)
        cavs = {k: torch.tensor(v, dtype=torch.float32).to(config.device) for k, v in cavs_vectors.items()}
        print("cavs chargés à partir de", file_path)
    else:
        print("Fichier cavs introuvable :", file_path)
        return  # Arrêter l'exécution si le fichier n'existe pas

    # Chemin pour sauvegarder le DataFrame mis à jour avec les similarités cosinus
    path_cosine_df = f"{config.SAVE_PATH}/blue_checkpoints/{config.model_name}/cavs/{config.cavs_type}/df_aug_train_updated_{config.annotation}.pkl"
    if os.path.exists(path_cosine_df):
        print("DataFrame mis à jour déjà présent, chargement...")
        with open(path_cosine_df, "rb") as f:
            df_aug_train_updated = pickle.load(f)
    else:
        df_aug_train_updated = compute_cosine_similarities(attributions_df, df_aug_train, cavs, config.device)
        with open(path_cosine_df, "wb") as f:
            pickle.dump(df_aug_train_updated, f)
        print("df_aug_train_updated sauvegardé à", path_cosine_df)

    # Chemin pour sauvegarder le tri des concepts par moyenne des similarités cosinus
    file_path_2 = (f"{config.SAVE_PATH}/blue_checkpoints/{config.model_name}/"
                    f"cavs/{config.cavs_type}/sorted_macro_concepts_cosine_sm_{config.cavs_type}_{config.annotation}_{mode}_{agg_scope}.json"
                    )
    if os.path.exists(file_path_2):
        with open(file_path_2, "r") as f:
            sorted_concepts = json.load(f)
        print("Fichier sorted_macro_concepts_cosine_sm déjà présent, chargement...")
    else:
        df_aug_train_updated, sorted_concepts = postprocess_cosine(
            df_aug_train_updated, list(cavs.keys()), mode=mode, agg_scope=agg_scope
        )
    with open(file_path_2, "w") as f:
        json.dump(sorted_concepts, f, indent=4)
    print("Fichier sorted_macro_concepts_cosine_sm sauvegardé à", file_path_2)

    # Libération finale de la mémoire
    del attributions_df, cavs, cavs_vectors
    gc.collect()


In [23]:
main(mode="abs", agg_scope="all")

attributions_df déjà présent, chargement...
cavs chargés à partir de /home/bhan/Yann_CBM/Launch/dbfs/results_movies//blue_checkpoints/gemma/cavs/mean/cavs_mean_C3M.json
DataFrame mis à jour déjà présent, chargement...
Fichier sorted_macro_concepts_cosine_sm déjà présent, chargement...
Fichier sorted_macro_concepts_cosine_sm sauvegardé à /home/bhan/Yann_CBM/Launch/dbfs/results_movies//blue_checkpoints/gemma/cavs/mean/sorted_macro_concepts_cosine_sm_mean_C3M_abs_all.json
