In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

module_path = os.path.abspath(os.path.join("../.."))
sys.path.append(module_path)

## Dataset and hyperparameters loading

In [3]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


DATASET = "GSL"
EXPERIMENT_ID = "frosty-haze-24"

dataset_path = f"/mnt/disk3Tb/slt-datasets/{DATASET}"
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
output_path = f"{experiment_path}/interp/avg"
os.makedirs(output_path, exist_ok=True)
transparent_plot = False
decoder_attn_weights_layer = 0

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

train_dataset = SLTDataset(
    data_dir=dataset_path,
    split="train",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
test_dataset = SLTDataset(
    data_dir=dataset_path,
    split="test",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 230650.51it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 195677.92it/s]

Dataset loaded correctly






In [4]:
import torch

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

## Model

### Definition

In [5]:
import glob
from LightningKeypointsTransformer import LKeypointsTransformer

checkpoint_path = glob.glob(f"{experiment_path}/best*")[0]
try:
    l_model = LKeypointsTransformer.load_from_checkpoint(checkpoint_path)
    model = l_model.model
    translator = l_model.translator
except:
    from helpers import load_from_old_checkpoint

    model, translator = load_from_old_checkpoint(
        checkpoint_path, hp, device, landmarks_mask, train_dataset
    )

  checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]


In [6]:
model = model.to(device)
model = model.eval()

In [7]:
BOS_IDX = train_dataset.tokenizer.cls_token_id
EOS_IDX = train_dataset.tokenizer.sep_token_id

### Interpretability

In [8]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt

from typing import List, Callable

from KeypointsTransformer import KeypointsTransformer
from Translator import Translator
from hyperparameters import HyperParameters

from interp.plot_functions import reorganize_list

In [9]:
def get_pose_text(dataset: SLTDataset, index: int, device: torch.device):
    src, _ = test_dataset[index]
    _, tgt = test_dataset.get_item_raw(index)
    src = src.unsqueeze(0)
    src = src.to(device)
    return src, tgt

In [10]:
translation_outputs = []
for i in range(len(test_dataset)):
    src, tgt = get_pose_text(test_dataset, i, device)
    translation = translator.translate(src, model, "greedy", test_dataset.tokenizer)[0]
    translation_outputs.append((i, tgt, translation))

In [11]:
df = pd.DataFrame(translation_outputs, columns=["id", "expected", "predicted"])
df.set_index("id", inplace=True)

# Calculate the percentage of each class
class_counts = df[
    "expected"
].value_counts()  # Count the frequency of each class in the 'expected' column
class_percentages = (class_counts / len(df)) * 100

# Calculate the accuracy percentage for each class
correct_predictions = df[df["expected"] == df["predicted"]]
correct_class_counts = correct_predictions["expected"].value_counts()
accuracy_percentages = (correct_class_counts / class_counts) * 100

result = pd.DataFrame(
    {
        "Class": class_counts.index,
        "Count": class_counts.values,
        "Percentage": class_percentages.values,
        "Accuracy Percentage": accuracy_percentages.reindex(
            class_counts.index, fill_value=0
        ).values,
    }
)

result["Accuracy Percentage"] = result["Accuracy Percentage"].fillna(0)
result = result.sort_values(by="Accuracy Percentage", ascending=True)

# Count the number of words in the 'Class' column
result["Word Count"] = result["Class"].apply(lambda x: len(x.split()))

# Create the index mapping
class_to_indices = {}
for i in range(len(test_dataset)):
    _, text = test_dataset.get_item_raw(i)
    if text in class_to_indices:
        class_to_indices[text].append(i)
    else:
        class_to_indices[text] = [i]

result.index = result["Class"].map(lambda cls: tuple(class_to_indices.get(cls, [])))
result.index.name = None

result.to_csv(f"{output_path}/class_percentages.csv")

In [12]:
# https://stackoverflow.com/a/70712921
# https://stackoverflow.com/a/76377041
from IPython.display import display, HTML

pd.set_option("display.max_rows", None)

display(HTML("<div style='height: 200px'>" + result.style.to_html() + "</div>"))

Unnamed: 0,Class,Count,Percentage,Accuracy Percentage,Word Count
"(625,)",NAI ΣΥΝ ΕΓΩ(1) ΖΑΛΙΖΟΜΑΙ,1,0.113507,0.0,4
"(336, 357, 378)",ΕΣΥ ΤΑΥΤΟΤΗΤΑ ΦΩΤΟΤΥΠΙΑ ΕΧΩ,3,0.340522,0.0,4
"(273, 297, 321)",ΚΑΛΟ ΕΓΩ(1) ΧΡΕΙΑΖΟΜΑΙ ΤΙΠΟΤΑ,3,0.340522,0.0,4
"(0, 15, 30)",ΓΕΙΑ ΕΓΩ(1) ΜΠΟΡΩ ΒΟΗΘΕΙΑ,3,0.340522,0.0,4
"(270, 294, 318)",ΠΑΡΑΒΟΛΟ ΕΓΩ(1) ΠΛΗΡΩΝΩ ΠΟΥ;,3,0.340522,0.0,4
"(398, 417, 435)",ΝΑΙ ΟΡΙΣΤΕ,3,0.340522,0.0,2
"(267, 291, 315)",ΕΣΥ ΠΡΕΠΕΙ ΠΛΗΡΩΝΩ ΓΙΑ ΟΝΟΜΑ ΜΕΤΑΦΟΡΑ ΕΣΥ ΑΥΤΟΚΙΝΗΤΟ ΠΟΣΟ; ΚΥΒΙΚΑ,3,0.340522,0.0,10
"(261, 285, 309)",ΕΓΩ(1) ΧΡΕΙΑΖΟΜΑΙ ΤΑΥΤΟΤΗΤΑ ΔΙΚΟ_ΣΟΥ ΣΥΝ 1 ΦΩΤΟΤΥΠΙΑ,3,0.340522,0.0,7
"(834, 857, 880)",ΕΣΥ ΚΑΛΟ ΚΑΛΟ ΣΥΝΕΧΕΙΑ,3,0.340522,0.0,4
"(48, 71, 94)",ΠΑΡΕΛΘΟΝ_ΠΡΟΣΦΑΤΟ,3,0.340522,0.0,1


In [13]:
result = pd.read_csv(f"{output_path}/class_percentages.csv")

# Group by 'Word Count' and calculate the number of classes and average accuracy percentage
stats = (
    result.groupby("Word Count")
    .agg(
        Class_Count=("Class", "count"),
        Average_Accuracy_Percentage=("Accuracy Percentage", "mean"),
    )
    .reset_index()
)
stats = stats.sort_values(by="Average_Accuracy_Percentage", ascending=True)

print(stats)

    Word Count  Class_Count  Average_Accuracy_Percentage
12          17            1                     0.000000
10          11            1                    33.333333
11          12            1                    33.333333
0            1           13                    53.620401
8            9            6                    61.111111
6            7           21                    61.904762
1            2           24                    62.908497
3            4           49                    66.326531
9           10            6                    66.666667
5            6           40                    67.500000
4            5           37                    72.522523
2            3           38                    79.824561
7            8           10                    80.000000
13          18            1                   100.000000


In [14]:
def combine_tuples(string):
    # Replace semicolons with commas to create a list of tuples separated by commas
    string = string.replace("; ", ", ")
    # Use eval to convert the string into a list of tuples
    tuple_list = ast.literal_eval(f"[{string}]")
    # Flatten all tuples into one
    combined_tuple = tuple(item for tup in tuple_list for item in tup)
    return combined_tuple


result = pd.read_csv(f"{output_path}/class_percentages.csv", index_col=0)

# Ensure that the index is in string format
result.index = result.index.astype(str)

# Group by 'Word Count' and calculate the number of classes and the average accuracy percentage
# Concatenate the group indices into a single string separated by a delimiter (e.g., ";")
stats = (
    result.groupby("Word Count")
    .agg(
        Class_Count=("Class", "count"),
        Average_Accuracy_Percentage=("Accuracy Percentage", "mean"),
        Index_Concat=(("Class", lambda x: "; ".join(x.index))),
    )
    .reset_index()
)

stats["Index_Concat"] = stats["Index_Concat"].apply(combine_tuples)

# Set the 'Index_Concat' column as the new index
stats = stats.set_index("Index_Concat")
stats.index.name = None

# Sort by 'Average_Accuracy_Percentage' if desired
stats = stats.sort_values(by="Word Count", ascending=False)

print(stats.index[0], type(stats.index[0]))

display(HTML("<div style='height: 200px'>" + stats.style.to_html() + "</div>"))

(506, 556) <class 'tuple'>


Unnamed: 0,Word Count,Class_Count,Average_Accuracy_Percentage
"(506, 556)",18,1,100.0
"(531,)",17,1,0.0
"(741, 768, 795)",12,1,33.333333
"(448, 467, 486)",11,1,33.333333
"(267, 291, 315, 10, 505, 530, 555, 25, 40, 454, 473, 492, 504, 529, 554)",10,6,66.666667
"(545, 570, 520, 749, 776, 803, 512, 537, 562, 445, 464, 483, 214, 227, 240)",9,6,61.111111
"(329, 510, 328, 349, 370, 265, 289, 313, 535, 560, 350, 371, 342, 363, 384, 634, 656, 678, 824, 847, 870, 453, 472, 491)",8,10,80.0
"(261, 285, 309, 163, 181, 199, 739, 766, 793, 459, 478, 497, 789, 614, 388, 407, 425, 401, 420, 438, 581, 595, 609, 822, 845, 868, 259, 283, 307, 444, 463, 482, 516, 541, 566, 586, 600, 735, 762, 733, 760, 787, 737, 764, 791, 56, 79, 102, 117, 132, 147, 123, 138, 153, 124, 139, 154)",7,21,61.904762
"(107, 686, 702, 718, 416, 229, 84, 145, 346, 61, 732, 759, 786, 220, 233, 246, 745, 772, 799, 54, 77, 100, 118, 133, 148, 254, 278, 302, 501, 526, 551, 126, 141, 156, 514, 539, 564, 165, 183, 201, 828, 851, 874, 218, 231, 244, 332, 353, 374, 330, 351, 372, 115, 130, 216, 242, 58, 81, 104, 293, 317, 325, 367, 684, 700, 716, 633, 655, 677, 393, 412, 430, 814, 837, 860, 817, 840, 863, 743, 770, 797, 820, 843, 866, 826, 849, 872, 753, 780, 807, 456, 475, 494, 455, 474, 493, 169, 187, 205, 217, 230, 243)",6,40,67.5
"(397, 434, 122, 137, 152, 636, 658, 680, 812, 835, 858, 613, 129, 144, 443, 683, 269, 249, 65, 88, 111, 450, 469, 488, 585, 599, 628, 650, 672, 334, 355, 376, 578, 592, 606, 221, 234, 247, 258, 282, 306, 400, 419, 437, 262, 286, 310, 813, 836, 859, 114, 159, 177, 195, 213, 226, 239, 462, 481, 575, 589, 603, 699, 715, 731, 758, 785, 519, 544, 569, 518, 543, 568, 223, 236, 584, 598, 612, 685, 701, 717, 818, 841, 864, 832, 855, 878, 116, 131, 146, 3, 18, 33, 12, 27, 42, 63, 86, 109, 219, 232, 245, 160, 178, 196, 162, 180, 198, 502, 527, 552, 167, 185, 203, 168, 186, 204)",5,37,72.522523


In [15]:
def get_decoder_cross_attn_output_weights_list(
    model: KeypointsTransformer,
    src: torch.Tensor,
    translator: Translator,
    hp: HyperParameters,
    BOS_IDX: int,
    EOS_IDX: int,
):
    attn_output_weights_list = []

    def attention_hook(module, input, output):  # input: (query, key, value)
        _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
        # print(output[0].shape, output[1].shape)
        attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())

    hook_handles = []
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
        hook_handles.append(multihead_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

    for handle in hook_handles:
        handle.remove()

    return attn_output_weights_list

In [16]:
def get_attn_weights(
    attn_output_weights: List[torch.Tensor],
    hp: HyperParameters,
    translation: List[str],
    layer: int,
    norm_func: Callable,
):
    attn_output_weights = reorganize_list(attn_output_weights, hp["NUM_DECODER_LAYERS"])
    lower = (len(translation) - 1) * layer
    upper = lower + (len(translation) - 1)
    attn_output_weights = attn_output_weights[lower:upper]
    attn_weights = np.zeros_like(attn_output_weights[-1])
    for i, attn_output_weights in enumerate(attn_output_weights):
        attn_weights[i, :] = norm_func(attn_output_weights[i, :])
    return attn_weights

In [17]:
norm_min_max_lambda = lambda t: (t - t.min()) / (t.max() - t.min())

In [19]:
for i in range(len(stats)):
    wc = stats["Word Count"].iloc[i]
    print(f"Processing poses with {wc} words...")
    weights = []
    poses = stats.index[i]
    for j in poses:
        src, tgt = get_pose_text(test_dataset, j, device)

        translation = translator.translate(
            src, model, "greedy", train_dataset.tokenizer
        )[0]

        if translation != tgt:
            continue

        attn_output_weights_list = get_decoder_cross_attn_output_weights_list(
            model, src, translator, hp, BOS_IDX, EOS_IDX
        )

        translation = ("BOS " + translation + " EOS").split()
        attn_weights = get_attn_weights(
            attn_output_weights_list,
            hp,
            translation,
            decoder_attn_weights_layer,
            norm_min_max_lambda,
        )

        weights.append(attn_weights)

    if len(weights) == 0:
        print(f"No match found between tgt and translation for poses with {wc} words")
        continue

    weights = np.stack(weights, axis=0)
    weights_avg = np.mean(weights, axis=0)

    sns.heatmap(
        weights_avg,
        xticklabels=np.arange(hp["MAX_FRAMES"]),
        cbar=False,
    )

    plt.savefig(
        f"{output_path}/attn_weights_heatmap_decoder_layer{decoder_attn_weights_layer}_wc{wc}.jpg",
        dpi=150,
        bbox_inches="tight",
        transparent=transparent_plot,
    )

    plt.close()

Processing poses with 18 words...
Processing poses with 17 words...
No match found between tgt and translation for poses with 17 words
Processing poses with 12 words...
Processing poses with 11 words...
Processing poses with 10 words...
Processing poses with 9 words...
Processing poses with 8 words...
Processing poses with 7 words...
Processing poses with 6 words...
Processing poses with 5 words...
Processing poses with 4 words...
Processing poses with 3 words...
Processing poses with 2 words...
Processing poses with 1 words...
