In [1]:
# Use the interactive notebook backend so figures and animations render as live widgets.
%matplotlib notebook
# Reload imported modules before each cell execution, so edits to the project's
# .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the project root (two levels above this notebook's working directory)
# importable, so the project-local modules used below can be resolved.
module_path = os.path.abspath("../..")
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

## Dataset and hyperparameters loading

In [3]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


# --- Configuration: which dataset / experiment / sample to inspect ---
DATASET = "GSL"
EXPERIMENT_ID = "glorious-grass-31"
SAMPLE_IDX = 0

# Folder that contains one sub-folder per dataset. The default is the original
# machine-specific location; set SLT_DATASETS_ROOT to run the notebook elsewhere.
DATASETS_ROOT = os.environ.get("SLT_DATASETS_ROOT", "/mnt/disk3Tb/slt-datasets")

dataset_path = f"{DATASETS_ROOT}/{DATASET}"
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
# Hyperparameters saved alongside the experiment; they drive every setting below.
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
# All figures/videos produced by this notebook are written here.
output_path = f"{experiment_path}/interp/{SAMPLE_IDX}/"
os.makedirs(output_path, exist_ok=True)

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

# Both splits are built with identical settings; only `split` differs.
dataset_kwargs = dict(
    data_dir=dataset_path,
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
# The train split is loaded because its tokenizer is reused later in the notebook.
train_dataset = SLTDataset(split="train", **dataset_kwargs)
test_dataset = SLTDataset(split="test", **dataset_kwargs)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 223637.73it/s]

Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv



Validating files: 100%|██████████| 881/881 [00:00<00:00, 192327.17it/s]

Dataset loaded correctly






### Display sample

In [4]:
from IPython.display import HTML

# avoid using the last transform as it flattens the keypoints
visual_transforms: Compose = Compose(hp["TRANSFORMS"][:-1])
# Build an animation of the selected test sample using the visualisation-only transforms.
# NOTE(review): `anim` is presumably a matplotlib animation (it exposes
# `to_jshtml` and `save`) — confirm in SLTDataset.visualize_pose.
anim = test_dataset.visualize_pose(SAMPLE_IDX, transforms=visual_transforms)
# Last expression of the cell: embed the animation as an HTML/JS player in the output.
HTML(anim.to_jshtml())

<IPython.core.display.Javascript object>

In [5]:
# Persist the sample animation as a video inside the experiment's output folder.
# NOTE(review): writer="ffmpeg" presumably needs the ffmpeg binary installed — confirm on the target machine.
anim.save(f"{output_path}/sample.mp4", writer="ffmpeg")

In [6]:
import torch

# Accelerator preference: Apple MPS first, then CUDA, falling back to the CPU.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

# Fetch the selected test sample, add a leading batch dimension of size 1
# and move both tensors to the chosen device.
sample_src, sample_tgt = test_dataset[SAMPLE_IDX]
src = sample_src.unsqueeze(0).to(device)
tgt = sample_tgt.unsqueeze(0).to(device)

## Model

### Definition

In [7]:
import glob
from LightningKeypointsTransformer import LKeypointsTransformer


# glob returns matches in arbitrary order; sort so the selected checkpoint is
# deterministic, and fail with a clear message when the experiment has none.
checkpoint_candidates = sorted(glob.glob(f"{experiment_path}/best*"))
if not checkpoint_candidates:
    raise FileNotFoundError(
        f"No checkpoint matching 'best*' found in {experiment_path}"
    )
checkpoint_path = checkpoint_candidates[0]

# map_location keeps the model on the same device as `src`, so the notebook
# also works on machines whose accelerator differs from the training machine.
l_model = LKeypointsTransformer.load_from_checkpoint(
    checkpoint_path, map_location=device
)
# NOTE(review): the module is presumably still in train mode after loading
# (dropout active) — confirm that `greedy_decode` switches to eval mode, or
# call `l_model.eval()` here once verified that the attention hooks still fire.
model = l_model.model
translator = l_model.translator

/mnt/disk3Tb/miniconda3-ostanchi/envs/captum/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [8]:
# Special-token ids passed to `greedy_decode` in the cells below. They are taken
# from the training tokenizer's CLS and SEP ids.
# NOTE(review): presumably CLS/SEP play the begin/end-of-sequence roles for this model — confirm.
BOS_IDX = train_dataset.tokenizer.cls_token_id
EOS_IDX = train_dataset.tokenizer.sep_token_id

### Interpretability

In [9]:
# Greedy translation of the sample; only the first (and only) sentence of the batch is used.
translations = translator.translate(src, model, "greedy", train_dataset.tokenizer)
# Surround the sentence with explicit BOS/EOS markers and split it on whitespace,
# giving the token labels used on the heatmap axes below.
idx2word = ("BOS " + translations[0] + " EOS").split()
idx2word

['BOS', 'ΓΕΙΑ', 'ΕΓΩ(1)', 'ΜΠΟΡΩ', 'ΒΟΗΘΕΙΑ', 'ΠΩΣ;', 'EOS']

#### Encoder Self-Attention

In [10]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Attention maps captured by the hook below, one per hooked encoder layer call.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook: store the attention weights of the first item in the batch."""
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
try:
    for layer in range(hp["NUM_ENCODER_LAYERS"]):
        self_attn_module = model.transformer.encoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding fails; otherwise they would stay
    # on the model and keep firing (and accumulating) in later cells.
    for handle in hook_handles:
        handle.remove()

# squeeze=False keeps `axes` two-dimensional even when there is a single layer,
# so no special case is needed when indexing it.
fig, axes = plt.subplots(
    1, hp["NUM_ENCODER_LAYERS"], figsize=(10, 5), sharey=True, squeeze=False
)
# Frame indices used as tick labels on both axes (same for every layer).
sent = np.arange(hp["MAX_FRAMES"])
for layer, attn_weights in enumerate(attn_output_weights_list):
    ax = axes[0, layer]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_title(f"Layer {layer+1}")

plt.savefig(f"{output_path}/attn_self_heatmaps_encoder_layers.png")
plt.show()

<IPython.core.display.Javascript object>

#### Decoder Self-Attention

In [11]:
def reorganize_list(input_list, N):
    """Regroup an interleaved list so that items sharing the same offset become contiguous.

    The input is read as consecutive rounds of ``N`` items. The result lists
    every item found at offset 0 of each round, then every item at offset 1,
    and so on up to offset ``N - 1``.

    Args:
        input_list: Sequence of items stored round by round.
        N: Number of items per round (the stride used for regrouping).

    Returns:
        A new list with the same items, grouped by their offset within a round.
        An empty list is returned when ``N`` is not positive.
    """
    return [item for offset in range(N) for item in input_list[offset::N]]

In [12]:
import seaborn as sns

# Attention maps captured by the hook below, appended in the order the hooked
# decoder layers are called during decoding.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook: log the output shapes and store the first batch item's attention weights."""
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
try:
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        self_attn_module = model.transformer.decoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding fails; otherwise they would stay
    # on the model and keep firing (and printing) in later cells.
    for handle in hook_handles:
        handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
# squeeze=False keeps `axes` two-dimensional even with a single row or a single
# layer, so `axes[j, i]` below is always valid.
fig, axes = plt.subplots(
    rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True, squeeze=False
)
# Regroup the captured maps so that all maps of one layer are contiguous.
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
for layer, attn_weights in enumerate(attn_output_weights_list):
    # i: grid column (layer index), j: grid row (position within that layer's group).
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    # Label the axes with as many tokens as this map has rows, skipping BOS.
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=tgt_sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(tgt_sent, rotation=90)
    # Only the top subplot of each column carries the layer title.
    if j == 0:
        ax.set_title(f"Layer {i+1}")

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{output_path}/attn_self_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 128]) torch.Size([1, 1, 1])
torch.Size([1, 1, 128]) torch.Size([1, 1, 1])
torch.Size([1, 1, 128]) torch.Size([1, 1, 1])
torch.Size([1, 1, 128]) torch.Size([1, 1, 1])
torch.Size([1, 2, 128]) torch.Size([1, 2, 2])
torch.Size([1, 2, 128]) torch.Size([1, 2, 2])
torch.Size([1, 2, 128]) torch.Size([1, 2, 2])
torch.Size([1, 2, 128]) torch.Size([1, 2, 2])
torch.Size([1, 3, 128]) torch.Size([1, 3, 3])
torch.Size([1, 3, 128]) torch.Size([1, 3, 3])
torch.Size([1, 3, 128]) torch.Size([1, 3, 3])
torch.Size([1, 3, 128]) torch.Size([1, 3, 3])
torch.Size([1, 4, 128]) torch.Size([1, 4, 4])
torch.Size([1, 4, 128]) torch.Size([1, 4, 4])
torch.Size([1, 4, 128]) torch.Size([1, 4, 4])
torch.Size([1, 4, 128]) torch.Size([1, 4, 4])
torch.Size([1, 5, 128]) torch.Size([1, 5, 5])
torch.Size([1, 5, 128]) torch.Size([1, 5, 5])
torch.Size([1, 5, 128]) torch.Size([1, 5, 5])
torch.Size([1, 5, 128]) torch.Size([1, 5, 5])
torch.Size([1, 6, 128]) torch.Size([1, 6, 6])
torch.Size([1, 6, 128]) torch.Size

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


#### Decoder Cross-Attention

In [13]:
import seaborn as sns

# Cross-attention maps captured by the hook below, appended in the order the
# hooked decoder layers are called during decoding.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook: log the output shapes and store the first batch item's attention weights."""
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
try:
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
        hook_handles.append(
            multihead_attn_module.register_forward_hook(attention_hook)
        )

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding fails; otherwise they would stay
    # on the model and keep firing (and printing) in later cells.
    for handle in hook_handles:
        handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
# squeeze=False keeps `axes` two-dimensional even with a single row or a single
# layer, so `axes[j, i]` below is always valid.
fig, axes = plt.subplots(
    rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True, squeeze=False
)
# Regroup the captured maps so that all maps of one layer are contiguous.
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
# Source-frame indices used as x tick labels (same for every subplot).
sent = np.arange(hp["MAX_FRAMES"])
for layer, attn_weights in enumerate(attn_output_weights_list):
    # i: grid column (layer index), j: grid row (position within that layer's group).
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    # Label the y axis with as many tokens as this map has rows, skipping BOS.
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    # The maps are much wider than tall, so let each cell stretch to fill its subplot.
    ax.set_aspect("auto")
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(sent, rotation=90)
    # Only the top subplot of each column carries the layer title.
    if j == 0:
        ax.set_title(f"Layer {i+1}")

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{output_path}/attn_cross_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 128]) torch.Size([1, 1, 30])
torch.Size([1, 1, 128]) torch.Size([1, 1, 30])
torch.Size([1, 1, 128]) torch.Size([1, 1, 30])
torch.Size([1, 1, 128]) torch.Size([1, 1, 30])
torch.Size([1, 2, 128]) torch.Size([1, 2, 30])
torch.Size([1, 2, 128]) torch.Size([1, 2, 30])
torch.Size([1, 2, 128]) torch.Size([1, 2, 30])
torch.Size([1, 2, 128]) torch.Size([1, 2, 30])
torch.Size([1, 3, 128]) torch.Size([1, 3, 30])
torch.Size([1, 3, 128]) torch.Size([1, 3, 30])
torch.Size([1, 3, 128]) torch.Size([1, 3, 30])
torch.Size([1, 3, 128]) torch.Size([1, 3, 30])
torch.Size([1, 4, 128]) torch.Size([1, 4, 30])
torch.Size([1, 4, 128]) torch.Size([1, 4, 30])
torch.Size([1, 4, 128]) torch.Size([1, 4, 30])
torch.Size([1, 4, 128]) torch.Size([1, 4, 30])
torch.Size([1, 5, 128]) torch.Size([1, 5, 30])
torch.Size([1, 5, 128]) torch.Size([1, 5, 30])
torch.Size([1, 5, 128]) torch.Size([1, 5, 30])
torch.Size([1, 5, 128]) torch.Size([1, 5, 30])
torch.Size([1, 6, 128]) torch.Size([1, 6, 30])
torch.Size([1

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


In [14]:
from interp.InterpTransformer import clear_intermediate_outputs

In [15]:
# Reset the recorder before decoding again.
# NOTE(review): presumably this empties `intermediate_outputs` so that only the
# next decode is recorded — confirm in interp/InterpTransformer.
clear_intermediate_outputs()

In [16]:
# Run one greedy decode of the sample; the cell output is the predicted token-id tensor.
# NOTE(review): presumably this call is what populates `intermediate_outputs` — confirm.
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

tensor([[ 2,  4,  5,  6,  7, 52,  3]], device='cuda:0')

In [17]:
from interp.InterpTransformer import intermediate_outputs

In [18]:
# Dump the shape of every recorded tensor, grouped under a centred banner per block name.
for block_name, recorded in intermediate_outputs.items():
    print(f"{block_name:-^20}")
    for tensor in recorded:
        print(tensor.shape)

------sa_block------
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
torch.Size([1, 5, 128])
torch.Size([1, 5, 128])
torch.Size([1, 5, 128])
torch.Size([1, 5, 128])
torch.Size([1, 6, 128])
torch.Size([1, 6, 128])
torch.Size([1, 6, 128])
torch.Size([1, 6, 128])
-----mha_block------
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 3, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
torch.Size([1, 4, 128]

In [19]:
# One figure per recorded block (e.g. sa_block / mha_block): rows follow the
# positions within each layer's group, columns are the decoder layers.
rows = len(idx2word) - 1  # from BOS to EOS-1
sent = np.arange(hp["D_MODEL"])  # embed_dim
for k, v in intermediate_outputs.items():
    # Keep the first item of the batch and move it to NumPy for plotting.
    v_cpu = [e[0].cpu().detach().numpy() for e in v]
    # squeeze=False keeps `axes` two-dimensional even with a single row or a
    # single layer, so `axes[j, i]` below is always valid.
    fig, axes = plt.subplots(
        rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True, squeeze=False
    )
    # Regroup the recorded outputs so that all entries of one layer are contiguous.
    attn_output_weights_list = reorganize_list(v_cpu, hp["NUM_DECODER_LAYERS"])
    for layer, attn_weights in enumerate(attn_output_weights_list):
        # i: grid column (layer index), j: grid row (position within that layer's group).
        i, j = divmod(layer, rows)
        print(i, j, layer)
        ax = axes[j, i]
        # Label the y axis with as many tokens as this map has rows, skipping BOS.
        tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
        sns.heatmap(
            attn_weights,
            ax=ax,
            xticklabels=sent,
            yticklabels=tgt_sent,
            square=True,
            cbar=True,
            annot=True,
            fmt=".2f",
            annot_kws={"size": 8},
        )  # vmin=0.0, vmax=1.0)
        # The maps are much wider than tall, so let each cell stretch to fill its subplot.
        ax.set_aspect("auto")
        ax.set_yticklabels(tgt_sent, rotation=0)
        ax.set_xticklabels(sent, rotation=90)
        # Only the top subplot of each column carries the layer title.
        if j == 0:
            ax.set_title(f"Layer {i+1}")

        # Rotate text annotations
        for text in ax.texts:
            text.set_rotation(90)

    plt.subplots_adjust(wspace=0.4, hspace=0.4)

    plt.savefig(f"{output_path}/attn_{k}_heatmaps_decoder_layers.png", dpi=150)
    plt.show()

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


In [20]:
# Element-wise difference between the recorded sa_block and mha_block outputs,
# rescaled per tensor to the range [-1, 1].
diff_sa_mha_block = []
for t1, t2 in zip(intermediate_outputs["sa_block"], intermediate_outputs["mha_block"]):
    diff = t1 - t2

    # normalize to [-1,1]
    lowest = diff.min()
    span = diff.max() - lowest
    if span == 0:
        # A constant difference has no range to rescale; dividing by the zero
        # span would fill the map with NaNs, so show it as all zeros instead.
        diff = torch.zeros_like(diff)
    else:
        diff = (diff - lowest) / span * 2 - 1

    diff_sa_mha_block.append(diff)

# Keep the first item of the batch and move it to NumPy for plotting.
v_cpu = [e[0].cpu().detach().numpy() for e in diff_sa_mha_block]
rows = len(idx2word) - 1  # from BOS to EOS-1
# squeeze=False keeps `axes` two-dimensional even with a single row or a single
# layer, so `axes[j, i]` below is always valid.
fig, axes = plt.subplots(
    rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True, squeeze=False
)
# Regroup the differences so that all entries of one layer are contiguous.
attn_output_weights_list = reorganize_list(v_cpu, hp["NUM_DECODER_LAYERS"])
sent = np.arange(hp["D_MODEL"])  # embed_dim
for layer, attn_weights in enumerate(attn_output_weights_list):
    # i: grid column (layer index), j: grid row (position within that layer's group).
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    # Label the y axis with as many tokens as this map has rows, skipping BOS.
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=True,
        annot=True,
        fmt=".2f",
        annot_kws={"size": 8},
    )  # vmin=0.0, vmax=1.0)
    # The maps are much wider than tall, so let each cell stretch to fill its subplot.
    ax.set_aspect("auto")
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(sent, rotation=90)
    # Only the top subplot of each column carries the layer title.
    if j == 0:
        ax.set_title(f"Layer {i+1}")

    # Rotate text annotations
    for text in ax.texts:
        text.set_rotation(90)

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{output_path}/attn_diff_sa_mha_block_heatmaps_decoder_layers.png", dpi=150)
plt.show()

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


In [25]:
# Print, layer by layer, the shape and mean of each map in `attn_output_weights_list`.
# This cell reuses `rows` and `attn_output_weights_list` exactly as left by the
# most recently executed plotting cell, so run it right after that cell.
n_maps = len(attn_output_weights_list)
for first in range(0, n_maps, rows):
    layer_number = first // rows + 1
    print(f"Layer {layer_number}:")
    for offset in range(rows):
        w = attn_output_weights_list[first + offset]
        print(f"{w.shape} → {w.mean():>7.3f}")

Layer 1:
(1, 128) →  -0.071
(2, 128) →   0.083
(3, 128) →   0.123
(4, 128) →  -0.015
(5, 128) →   0.113
(6, 128) →  -0.023
Layer 2:
(1, 128) →   0.043
(2, 128) →   0.056
(3, 128) →   0.003
(4, 128) →   0.121
(5, 128) →   0.107
(6, 128) →   0.110
Layer 3:
(1, 128) →  -0.007
(2, 128) →   0.249
(3, 128) →   0.001
(4, 128) →  -0.035
(5, 128) →  -0.013
(6, 128) →  -0.007
Layer 4:
(1, 128) →  -0.032
(2, 128) →   0.169
(3, 128) →   0.234
(4, 128) →   0.135
(5, 128) →   0.155
(6, 128) →   0.178
