In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys

# Make the directory two levels above the working directory importable, so the
# project-local modules imported further down can be resolved.
module_path = os.path.abspath("../..")
sys.path.append(module_path)

## Dataset and hyperparameters loading

In [None]:
import torch
from torchvision.transforms.v2 import Compose

from SLTDataset import SLTDataset
from posecraft.Pose import Pose
from helpers import create_src_mask
from hyperparameters import load_hyperparameters_from_json


# Name of the dataset; used to build both the dataset path and the results path.
DATASET = "GSL"
# Must be filled in by hand before running: the name of a folder inside
# results/<DATASET>.
EXPERIMENT_ID = ""
if EXPERIMENT_ID == "":
    # Fail fast so the cells below never read from an incomplete experiment path.
    raise ValueError(
        f"Set EXPERIMENT_ID to a valid experiment id inside results/{DATASET} folder"
    )

# NOTE(review): absolute, machine-specific location — adjust when running elsewhere.
dataset_path = f"/mnt/disk3Tb/slt-datasets/{DATASET}"
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
# Hyperparameters stored with the experiment; they configure the dataset below.
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
# Output folder of this notebook, created if it does not exist yet.
output_path = f"{experiment_path}/interp/avg"
os.makedirs(output_path, exist_ok=True)
# NOTE(review): not read by any visible cell — presumably meant for saving figures; confirm.
transparent_plot = False

# NOTE(review): landmarks_mask is not read by any visible cell.
landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

# Test split, built with the input/output modes, transforms and token limit
# recorded in the experiment's hyperparameters.
test_dataset = SLTDataset(
    data_dir=dataset_path,
    split="test",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)

In [None]:
import pandas as pd

# Translations stored with the experiment: "y" is compared against "trans_greedy" below.
results = pd.read_csv(f"{experiment_path}/translations.csv")
# Number of whitespace-separated words in each "y" entry.
results["length"] = results["y"].apply(lambda sentence: len(sentence.split()))
# Keep only the rows whose greedy translation equals "y" exactly.
is_exact_match = results["trans_greedy"] == results["y"]
correct = results[is_exact_match]
correct

In [None]:
# Number of rows whose greedy translation matches "y" exactly.
len(correct)

In [None]:
# Sorted distinct word counts of the correctly translated sentences, each shifted by one.
# NOTE(review): the +1 presumably accounts for one extra special token in the
# decoded sequences — confirm against the tokenizer.
# .tolist() yields plain Python ints, so the printed list stays readable under
# NumPy >= 2 (NumPy scalars would otherwise print as np.int64(...)).
lenghts = sorted((correct["length"].unique() + 1).tolist())
print(lenghts)

# Interpretability

In [None]:
import pickle
from torch import Tensor
import seaborn as sns
import matplotlib.pyplot as plt


def _load_interp_pickle(file_name: str):
    """Load one pickled dump from the experiment's ``interp`` folder.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read.
    """
    # NOTE: unpickling can execute arbitrary code — only load files produced by
    # your own experiment runs.
    with open(f"{experiment_path}/interp/{file_name}", "rb") as pickle_file:
        return pickle.load(pickle_file)


# each dict contains for each sample, a dict that contains for each layer, a list of attention weights per call (one for the encoder, one per word for the decoder)
encoder_sa: dict[int, dict[int, list[Tensor]]] = _load_interp_pickle("encoder_sa.pkl")
decoder_sa: dict[int, dict[int, list[Tensor]]] = _load_interp_pickle("decoder_sa.pkl")
decoder_ca: dict[int, dict[int, list[Tensor]]] = _load_interp_pickle("decoder_ca.pkl")
intermediate_outputs: dict[int, dict[str, list[Tensor]]] = _load_interp_pickle(
    "intermediate_outputs.pkl"
)

In [None]:
from torch.nn.functional import interpolate


def interpolate_all(
    attn_weights: dict[int, dict[int, list[Tensor]]],
    dataset: SLTDataset,
    out_size: int = 100,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> dict[int, dict[int, list[Tensor]]]:
    """Drop padded source positions and resample attention weights to a fixed length.

    For every sample the source padding mask is obtained from ``create_src_mask``
    and used to keep only the non-padded positions of each attention tensor. The
    result is then linearly interpolated along its last dimension to ``out_size``.

    Args:
        attn_weights: sample id -> layer -> list of attention tensors (one per call).
        dataset: dataset indexed by sample id; only the source element of each
            item is used.
        out_size: length of the last dimension after interpolation.
        device: device on which masking and interpolation are performed.

    Returns:
        A mapping with the same sample/layer/call structure as ``attn_weights``
        holding the interpolated tensors (moved to ``device``).
    """
    interpolated_weights: dict[int, dict[int, list[Tensor]]] = {}
    for sample_id, sample in attn_weights.items():
        interpolated_weights[sample_id] = {}
        if not sample:
            # No layers recorded for this sample: nothing to mask or interpolate.
            continue
        # The padding mask depends only on the sample, so the dataset item is
        # loaded once per sample instead of once per layer.
        src, _ = dataset[sample_id]
        _, src_padding_mask = create_src_mask(src.to(device), device)
        keep_positions = ~src_padding_mask
        for layer, cas in sample.items():
            # NOTE(review): assumes indexing with the mask leaves a 3-D tensor
            # after stacking, as required by mode="linear" — confirm shapes.
            weights = [
                torch.stack([c.to(device)[:, keep_positions] for c in t])
                for t in cas
            ]
            interpolated_weights[sample_id][layer] = [
                interpolate(aw, size=out_size, mode="linear").squeeze(0)
                for aw in weights
            ]
    return interpolated_weights

## Decoder CA average

Interpolate the decoder cross-attention weights to a fixed length of 100 frames, so that samples with different numbers of frames can be averaged together.

In [None]:
# Decoder cross-attention weights with padded source positions removed and the
# last dimension resampled to 100 points (see interpolate_all).
int_decoder_ca = interpolate_all(decoder_ca, test_dataset, out_size=100)

In [None]:
from interp.plot_functions import preprocess_attn_weights

# Layer whose attention weights are analysed, and the flag forwarded to
# preprocess_attn_weights.
LAYER = 0
TAKE_LAST = True

# One processed attention map per correctly translated sample.
processed_decoder_ca = []
for sample_id, sample in int_decoder_ca.items():
    if sample_id not in correct.index:
        continue
    # preprocess_attn_weights receives each per-call tensor with a leading
    # singleton dimension added.
    per_call_weights = [s.unsqueeze(0) for s in sample[LAYER]]
    processed_decoder_ca.append(
        preprocess_attn_weights(per_call_weights, take_last=TAKE_LAST)
    )

In [None]:
# For every length in `lenghts`, average the processed attention maps whose
# first dimension equals that length.
# NOTE(review): torch.stack raises if no map has the requested length.
mean_decoder_ca_per_lengths = []
for token_length in lenghts:
    processed_decoder_ca_per_length = []
    for attn_map in processed_decoder_ca:
        if attn_map.shape[0] == token_length:
            processed_decoder_ca_per_length.append(attn_map)
    stacked_maps = torch.stack(processed_decoder_ca_per_length)
    mean_decoder_ca_per_lengths.append(stacked_maps.mean(dim=0))

In [None]:
# Plot at most the first five lengths, one heatmap per length.
mean_decoder_ca_per_lengths_to_plot = mean_decoder_ca_per_lengths[:5]

# squeeze=False always returns a 2-D array of Axes, so indexing by panel also
# works when only a single length is plotted.
fig, axs = plt.subplots(
    1, len(mean_decoder_ca_per_lengths_to_plot), figsize=(30, 5), squeeze=False
)
axs = axs[0]
for i, mean_decoder_ca in enumerate(mean_decoder_ca_per_lengths_to_plot):
    sns.heatmap(mean_decoder_ca.cpu(), ax=axs[i], cbar=False)
    # `lenghts` stores the sentence length plus one, hence the -1 in the title.
    axs[i].set_title(f"Sentence Length {lenghts[i] - 1}")
    # set xticks labels to 0, 25, 50, 75, 100
    axs[i].set_xticks([0, 25, 50, 75, 100])
    axs[i].set_xticklabels([0, 25, 50, 75, 100], rotation=0)
    axs[i].title.set_fontsize(20)
# The trailing semicolon hides the returned Text object in the cell output
# without overwriting `fig`.
axs[0].set_ylabel("Predicted token", fontsize=20);

## Intermediate outputs

In [None]:
def reorganize_list(input_list, N):
    """Regroup ``input_list`` by position modulo ``N``.

    The result lists every N-th element starting at offset 0, then every N-th
    element starting at offset 1, and so on up to offset ``N - 1``. An ``N``
    of zero or less yields an empty list.
    """
    return [item for offset in range(N) for item in input_list[offset::N]]


# For every sample, one value per recorded call: the mean of
# (sa_block output - mha_block output). Entries follow the iteration order of
# `intermediate_outputs`.
attn_diffs: list[Tensor] = []
for sample in intermediate_outputs:
    sample_outputs = intermediate_outputs[sample]
    # The dataset item and its target tokens are not needed here, so the sample
    # is no longer loaded from the dataset on every iteration.
    attn_diffs.append(
        torch.tensor(
            [
                float((sa_out - mha_out).mean())
                for sa_out, mha_out in zip(
                    sample_outputs["sa_block"],
                    sample_outputs["mha_block"],
                )
            ]
        )
    )

In [None]:
# Average the per-call differences over the correctly translated samples of
# each length; lengths without any matching sample are skipped.
mean_differences = []
for token_length in lenghts:
    # `attn_diffs` was built by iterating over `intermediate_outputs`, so zipping
    # the two pairs every entry with its real sample id instead of relying on
    # the list position being equal to the id.
    attn_diffs_per_length = [
        diff
        for sample_id, diff in zip(intermediate_outputs, attn_diffs)
        if len(diff) == token_length and sample_id in correct.index
    ]
    if len(attn_diffs_per_length) > 0:
        mean_differences.append(torch.stack(attn_diffs_per_length).mean(dim=0))

In [None]:
import matplotlib.pyplot as plt


def is_monotonically_increasing(tensor):
    """Return True when no element is smaller than the one before it.

    Consecutive differences are taken along the first dimension; a tensor with
    fewer than two elements along it counts as increasing.
    """
    later, earlier = tensor[1:], tensor[:-1]
    non_negative_steps = (later - earlier) >= 0
    return non_negative_steps.all().item()


# Plot at most the first ten averaged difference vectors, one column each.
mean_differences_to_plot = mean_differences[:10]
# squeeze=False always returns a 2-D array of Axes, so indexing by panel also
# works when only a single vector is plotted.
fig, axs = plt.subplots(
    1, len(mean_differences_to_plot), figsize=(40, 5), squeeze=False
)
axs = axs[0]
last_panel = len(mean_differences_to_plot) - 1
for i, mean_diff in enumerate(mean_differences_to_plot):
    sns.heatmap(
        mean_diff.unsqueeze(1).cpu(),  # column vector: one row per entry
        ax=axs[i],
        vmin=-0.15,
        vmax=0.15,
        cmap="coolwarm",
        cbar=(i == last_panel),  # shared colour scale, shown once on the right
    )
    # axs[i].set_title(f"Sentence Length: {lenghts[i] - 1}")
    axs[i].set_xticks([])
    axs[i].title.set_fontsize(20)
    axs[i].yaxis.set_tick_params(labelsize=15)
# The trailing semicolon hides the returned Text object in the cell output
# without overwriting `fig`.
axs[0].set_ylabel("Predicted token", fontsize=20);

## Encoder SA average

In [None]:
# Encoder self-attention weights with padded source positions removed and the
# last dimension resampled to 100 points (see interpolate_all).
int_encoder_sa = interpolate_all(encoder_sa, test_dataset, out_size=100)

In [None]:
LAYER = 0

# Traverse the dataset once and group sample indices by their number of
# non-padding target tokens minus one, instead of reloading every sample for
# each length.
pad_token_id = test_dataset.tokenizer.pad_token_id
sample_ids_per_length: dict[int, list[int]] = {}
for idx in range(len(test_dataset)):
    src, tgt = test_dataset[idx]
    n_tokens = len([t for t in tgt if t != pad_token_id]) - 1
    sample_ids_per_length.setdefault(n_tokens, []).append(idx)

# Average the first (only) encoder call of the chosen layer over all samples of
# each length; lengths without any sample are skipped.
mean_encoder_sa_per_lengths = []
for token_length in lenghts:
    processed_encoder_sa_per_length: list[Tensor] = [
        int_encoder_sa[idx][LAYER][0]
        for idx in sample_ids_per_length.get(int(token_length), [])
    ]
    if len(processed_encoder_sa_per_length) != 0:
        mean_encoder_sa_per_lengths.append(
            torch.stack(processed_encoder_sa_per_length).mean(dim=0)
        )

In [None]:
# Plot at most the first five averaged encoder self-attention maps.
mean_encoder_sa_per_lengths_to_plot = mean_encoder_sa_per_lengths[:5]

# squeeze=False always returns a 2-D array of Axes, so indexing by panel also
# works when only a single map is plotted.
fig, axs = plt.subplots(
    1, len(mean_encoder_sa_per_lengths_to_plot), figsize=(30, 5), squeeze=False
)
axs = axs[0]
for i, mean_encoder_ca in enumerate(mean_encoder_sa_per_lengths_to_plot):
    sns.heatmap(mean_encoder_ca.cpu(), ax=axs[i], cbar=False)
    # NOTE(review): the +4 offset is unexplained; it presumably compensates for
    # lengths that were skipped when the means were computed — confirm that the
    # titles match the plotted lengths.
    axs[i].set_title(f"Sentence Length: {lenghts[i] + 4}")
    # # set xticks labels to rotation=0
    # axs[i].set_xticks([i for i in range(0, 101, 5)])
    # axs[i].set_xticklabels([i for i in range(0, 101, 5)], rotation=0)
    axs[i].title.set_fontsize(20)

In [None]:
LAYER = 0

# Collect the first call of the chosen layer for every sample, regardless of
# sentence length, and average them into a single map.
int_encoder_sa_list = []
for per_layer_weights in int_encoder_sa.values():
    int_encoder_sa_list.append(per_layer_weights[LAYER][0])
stacked_encoder_sa = torch.stack(int_encoder_sa_list)
mean_encoder_ca = stacked_encoder_sa.mean(dim=0)
sns.heatmap(mean_encoder_ca.cpu(), cbar=False)