In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

module_path = os.path.abspath(os.path.join("../.."))
sys.path.append(module_path)

## Dataset and hyperparameters loading

In [5]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


DATASET = "GSL"
EXPERIMENT_ID = "glorious-grass-31"
SAMPLE_IDX = 0

dataset_path = f"/mnt/disk3Tb/slt-datasets/{DATASET}"
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
output_path = f"{experiment_path}/interp/{SAMPLE_IDX}/"
os.makedirs(output_path, exist_ok=True)

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

train_dataset = SLTDataset(
    data_dir=dataset_path,
    split="train",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
test_dataset = SLTDataset(
    data_dir=dataset_path,
    split="test",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 273770.96it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 230588.57it/s]

Dataset loaded correctly






### Display sample

In [6]:
from IPython.display import HTML

# avoid using the last transform as it flattens the keypoints
visual_transforms: Compose = Compose(hp["TRANSFORMS"][:-1])
anim = test_dataset.visualize_pose(SAMPLE_IDX, transforms=visual_transforms)
HTML(anim.to_jshtml())

<IPython.core.display.Javascript object>

In [7]:
anim.save(f"{output_path}/sample.mp4", writer="ffmpeg")

In [8]:
import torch

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

src, tgt = test_dataset[SAMPLE_IDX]
src = src.unsqueeze(0)
tgt = tgt.unsqueeze(0)
src = src.to(device)
tgt = tgt.to(device)

## Model

### Definition

In [15]:
import glob
from LightningKeypointsTransformer import LKeypointsTransformer


checkpoint_path = glob.glob(f"{experiment_path}/best*")[0]
l_model = LKeypointsTransformer.load_from_checkpoint(checkpoint_path)
model = l_model.model
translator = l_model.translator

/home/pdalbianco/anaconda3/envs/slt_datasets/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [11]:
BOS_IDX = train_dataset.tokenizer.cls_token_id
EOS_IDX = train_dataset.tokenizer.sep_token_id

### Interpretability

In [16]:
idx2word = translator.translate(src, model, "greedy", train_dataset.tokenizer)
idx2word = ["BOS " + s + " EOS" for s in idx2word]
idx2word = idx2word[0].split()
idx2word

['BOS', 'ΕΣΥ', 'ΕΧΩ', 'EOS']

#### Encoder Self-Attention

In [17]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_ENCODER_LAYERS"]):
    self_attn_module = model.transformer.encoder.layers[layer].self_attn
    hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

fig, axes = plt.subplots(1, hp["NUM_ENCODER_LAYERS"], figsize=(10, 5), sharey=True)
for layer, attn_weights in enumerate(attn_output_weights_list):
    ax = (
        axes[layer] if hp["NUM_ENCODER_LAYERS"] > 1 else axes
    )  # Handle case with only one layer
    sent = np.arange(hp["MAX_FRAMES"])
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_title(f"Layer {layer+1}")

plt.savefig(f"{output_path}/attn_self_heatmaps_encoder_layers.png")
plt.show()

<IPython.core.display.Javascript object>

#### Decoder Self-Attention

In [18]:
def reorganize_list(input_list, N):
    grouped_list = []
    for i in range(N):
        grouped_list.extend(input_list[i::N])
    return grouped_list

In [19]:
import seaborn as sns

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_DECODER_LAYERS"]):
    self_attn_module = model.transformer.decoder.layers[layer].self_attn
    hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
fig, axes = plt.subplots(rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True)
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
for layer, attn_weights in enumerate(attn_output_weights_list):
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=tgt_sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(tgt_sent, rotation=90)
    ax.set_title(f"Layer {i+1}") if layer % rows == 0 else None

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{output_path}/attn_self_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 64]) torch.Size([1, 1, 1])
torch.Size([1, 1, 64]) torch.Size([1, 1, 1])
torch.Size([1, 1, 64]) torch.Size([1, 1, 1])
torch.Size([1, 1, 64]) torch.Size([1, 1, 1])
torch.Size([1, 2, 64]) torch.Size([1, 2, 2])
torch.Size([1, 2, 64]) torch.Size([1, 2, 2])
torch.Size([1, 2, 64]) torch.Size([1, 2, 2])
torch.Size([1, 2, 64]) torch.Size([1, 2, 2])
torch.Size([1, 3, 64]) torch.Size([1, 3, 3])
torch.Size([1, 3, 64]) torch.Size([1, 3, 3])
torch.Size([1, 3, 64]) torch.Size([1, 3, 3])
torch.Size([1, 3, 64]) torch.Size([1, 3, 3])


<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
2 0 6
2 1 7
2 2 8
3 0 9
3 1 10
3 2 11


#### Decoder Cross-Attention

In [20]:
import seaborn as sns

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_DECODER_LAYERS"]):
    multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
    hook_handles.append(multihead_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
fig, axes = plt.subplots(rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True)
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
for layer, attn_weights in enumerate(attn_output_weights_list):
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    sent = np.arange(hp["MAX_FRAMES"])
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_aspect("auto")
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(sent, rotation=90)
    ax.set_title(f"Layer {i+1}") if layer % rows == 0 else None

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{output_path}/attn_cross_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 64]) torch.Size([1, 1, 30])
torch.Size([1, 1, 64]) torch.Size([1, 1, 30])
torch.Size([1, 1, 64]) torch.Size([1, 1, 30])
torch.Size([1, 1, 64]) torch.Size([1, 1, 30])
torch.Size([1, 2, 64]) torch.Size([1, 2, 30])
torch.Size([1, 2, 64]) torch.Size([1, 2, 30])
torch.Size([1, 2, 64]) torch.Size([1, 2, 30])
torch.Size([1, 2, 64]) torch.Size([1, 2, 30])
torch.Size([1, 3, 64]) torch.Size([1, 3, 30])
torch.Size([1, 3, 64]) torch.Size([1, 3, 30])
torch.Size([1, 3, 64]) torch.Size([1, 3, 30])
torch.Size([1, 3, 64]) torch.Size([1, 3, 30])


<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
2 0 6
2 1 7
2 2 8
3 0 9
3 1 10
3 2 11
