In [1]:
# Use matplotlib's interactive notebook backend for the figures below.
%matplotlib notebook
# Load the autoreload extension and reload imported modules before each cell
# runs, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Make the directory two levels above the working directory importable so the
# project-local modules used below can be imported.
module_path = os.path.abspath("../..")
# Guard the append so re-running this cell does not pile up duplicate entries.
if module_path not in sys.path:
    sys.path.append(module_path)

## Dataset and hyperparameters loading

In [3]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


DATASET = "GSL"
EXPERIMENT_ID = "frosty-haze-24"
SAMPLE_IDX = 509  # index into the test split of the sample inspected below

# Root folder containing one sub-folder per dataset. The default is the
# original workstation path; set SLT_DATASETS_DIR to run on another machine.
datasets_root = os.environ.get("SLT_DATASETS_DIR", "/mnt/disk3Tb/slt-datasets")
dataset_path = f"{datasets_root}/{DATASET}"
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
# Folder handed to the plotting helpers for this sample (note the trailing "/").
output_path = f"{experiment_path}/interp/{SAMPLE_IDX}/"
os.makedirs(output_path, exist_ok=True)
transparent_plot = False  # forwarded unchanged to the plotting helpers
decoder_attn_weights_layer = 0  # layer argument of the decoder attention-weight plots

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

# Both splits are built with identical settings; only `split` differs.
shared_dataset_kwargs = dict(
    data_dir=dataset_path,
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
train_dataset = SLTDataset(split="train", **shared_dataset_kwargs)
test_dataset = SLTDataset(split="test", **shared_dataset_kwargs)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 226404.89it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 197730.19it/s]

Dataset loaded correctly






### Display sample

In [4]:
from IPython.display import HTML

# The final entry of hp["TRANSFORMS"] flattens the keypoints, so the pipeline
# used for visualisation is built from every transform except that last one.
visual_transforms: Compose = Compose(
    hp["TRANSFORMS"][:-1],
)
anim = test_dataset.visualize_pose(
    SAMPLE_IDX,
    transforms=visual_transforms,
)
# Last expression of the cell: shows the animation inline.
HTML(anim.to_jshtml())

<IPython.core.display.Javascript object>

In [5]:
# output_path already ends with "/", so joining with os.path.join avoids the
# doubled separator ("...//sample.mp4") that plain string formatting produced.
anim.save(os.path.join(output_path, "sample.mp4"), writer="ffmpeg")

In [6]:
import torch

# Pick the compute device: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fetch the (source, target) pair of the chosen sample, prepend a leading
# dimension of size 1 to each tensor and move both onto the selected device.
src, tgt = (
    sample_tensor.unsqueeze(0).to(device)
    for sample_tensor in test_dataset[SAMPLE_IDX]
)

## Model

### Definition

In [7]:
import glob
from LightningKeypointsTransformer import LKeypointsTransformer

# Locate the checkpoint file whose name starts with "best" in the experiment folder.
checkpoint_matches = glob.glob(f"{experiment_path}/best*")
if not checkpoint_matches:
    # Fail with an actionable message instead of an IndexError on `[0]`.
    raise FileNotFoundError(
        f"No checkpoint matching 'best*' found in {experiment_path}"
    )
checkpoint_path = checkpoint_matches[0]

try:
    # Preferred path: restore the Lightning module and take its model/translator.
    l_model = LKeypointsTransformer.load_from_checkpoint(checkpoint_path)
    model = l_model.model
    translator = l_model.translator
except Exception as load_error:
    # Catch Exception rather than using a bare `except:` so KeyboardInterrupt
    # and SystemExit still propagate, and report why the fallback is used
    # instead of hiding the failure.
    print(
        f"Lightning checkpoint loading failed ({load_error!r}); "
        "falling back to load_from_old_checkpoint."
    )
    from helpers import load_from_old_checkpoint

    model, translator = load_from_old_checkpoint(
        checkpoint_path, hp, device, landmarks_mask, train_dataset
    )

  checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]


In [8]:
# Move the model onto the selected device and switch it to evaluation mode.
model = model.to(device).eval()

In [9]:
# The tokenizer's CLS and SEP token ids are used as the begin-of-sequence and
# end-of-sequence ids passed to greedy decoding.
BOS_IDX, EOS_IDX = (
    train_dataset.tokenizer.cls_token_id,
    train_dataset.tokenizer.sep_token_id,
)

### Interpretability

In [10]:
from interp.plot_functions import *

In [11]:
# Greedy-translate the sample, wrap each returned sentence in literal BOS/EOS
# markers, then keep the first sentence split into whitespace-separated tokens.
translation = [
    f"BOS {sentence} EOS"
    for sentence in translator.translate(
        src, model, "greedy", train_dataset.tokenizer
    )
][0].split()
translation

['BOS', 'ΜΩΡΟ', 'ΓΕΝΝΩ', 'ΝΩΡΙΣ', 'EOS']

#### Encoder Self-Attention

In [12]:
# Filled by `attention_hook` with one array per hooked forward call.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook that records the attention weights of a hooked module.

    `output` is unpacked as (attn_output, attn_output_weights). Index 0 of the
    weights (presumably the single batch item; confirm) is moved to the CPU,
    detached, converted to NumPy and appended to `attn_output_weights_list`.
    """
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Hook the self-attention module of every encoder layer.
    for layer in range(hp["NUM_ENCODER_LAYERS"]):
        self_attn_module = model.transformer.encoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding raises; otherwise they would
    # stay registered on the model and fire again in later cells.
    for handle in hook_handles:
        handle.remove()

In [13]:
# Plot the encoder self-attention weights collected by the forward hooks.
# NOTE(review): presumably writes its figures under output_path; confirm in
# interp.plot_functions.
plot_encoder_layers(attn_output_weights_list, hp, output_path, transparent_plot)

<IPython.core.display.Javascript object>

#### Decoder Self-Attention

In [14]:
# Filled by `attention_hook` with one array per hooked forward call.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook that records the attention weights of a hooked module.

    `output` is unpacked as (attn_output, attn_output_weights). Index 0 of the
    weights (presumably the single batch item; confirm) is moved to the CPU,
    detached, converted to NumPy and appended to `attn_output_weights_list`.
    """
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Hook the self-attention module of every decoder layer.
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        self_attn_module = model.transformer.decoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding raises; otherwise they would
    # stay registered on the model and fire again in later cells.
    for handle in hook_handles:
        handle.remove()

In [15]:
# Plot the decoder self-attention weights collected by the forward hooks; the
# "self" argument selects that variant and `translation` supplies the tokens.
# NOTE(review): presumably writes its figures under output_path; confirm.
plot_decoder_layers(
    attn_output_weights_list, hp, output_path, translation, "self", transparent_plot
)

<IPython.core.display.Javascript object>

#### Decoder Cross-Attention

In [16]:
# Filled by `attention_hook` with one array per hooked forward call.
attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    """Forward hook that records the attention weights of a hooked module.

    `output` is unpacked as (attn_output, attn_output_weights). Index 0 of the
    weights (presumably the single batch item; confirm) is moved to the CPU,
    detached, converted to NumPy and appended to `attn_output_weights_list`.
    """
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Hook the `multihead_attn` module (the cross-attention, per this section's
    # heading) of every decoder layer.
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
        hook_handles.append(multihead_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if decoding raises; otherwise they would
    # stay registered on the model and fire again in later cells.
    for handle in hook_handles:
        handle.remove()

In [17]:
# Plot the decoder cross-attention weights collected by the forward hooks; the
# "cross" argument selects that variant and `translation` supplies the tokens.
# NOTE(review): presumably writes its figures under output_path; confirm.
plot_decoder_layers(
    attn_output_weights_list, hp, output_path, translation, "cross", transparent_plot
)

<IPython.core.display.Javascript object>

In [18]:
def norm_min_max_lambda(t):
    """Rescale `t` linearly so its minimum maps to 0 and its maximum to 1.

    `t` must provide `.min()`, `.max()` and element-wise arithmetic (for
    example a NumPy array). When every element is equal the range is zero;
    the result is then all zeros instead of NaN from a division by zero.
    """
    lowest = t.min()
    value_range = t.max() - lowest
    # Dividing by 1 in the degenerate case keeps the usual true-division result
    # type while leaving the already-zero numerator untouched.
    return (t - lowest) / (value_range if value_range != 0 else 1)

In [19]:
# Draw the cross-attention weights of the selected decoder layer once per plot
# mode. `attn_weights` is rebound on every iteration, so the value kept for
# the following cells is the one returned by the final ("lineplot") call.
for mode in ["heatmap", "lineplot"]:
    attn_weights = plot_decoder_attn_weights(
        attn_output_weights_list,
        hp,
        output_path,
        translation,
        decoder_attn_weights_layer,
        norm_min_max_lambda,  # normalisation callable handed to the helper
        mode,
        transparent_plot,
        # NOTE(review): presumably extra plotting options; a fresh dict is
        # created for each call.
        {"square": True, "cbar": False},
    )

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [20]:
# Build the bar animation from the source sample (moved to the CPU first) and
# the attention weights returned by the previous cell. This rebinds `anim`,
# replacing the pose animation created earlier in the notebook.
anim = plot_decoder_attn_weights_bars(
    src.cpu(), attn_weights, hp, output_path, translation, decoder_attn_weights_layer
)

<IPython.core.display.Javascript object>

In [21]:
# Show the animation inline by rendering the markup returned by to_jshtml().
HTML(anim.to_jshtml())

#### Decoder `sa_block` & `mha_block`

In [22]:
from interp.InterpTransformer import clear_intermediate_outputs

In [23]:
# Reset the intermediate outputs kept by interp.InterpTransformer before the
# next decode. NOTE(review): presumably empties the stored values; confirm.
clear_intermediate_outputs()

In [24]:
# Decode the sample once more; the returned token ids are this cell's output.
# NOTE(review): presumably this run also fills
# interp.InterpTransformer.intermediate_outputs as a side effect; confirm.
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

tensor([[  2, 274, 189, 276,   3]], device='cuda:0')

In [25]:
from interp.InterpTransformer import intermediate_outputs

In [26]:
import interp.InterpTransformer as interp_transformer

# Convert every recorded tensor to a NumPy array, keeping index 0 of each one
# (presumably the single batch item; confirm).
# The source is read from the module attribute rather than from the notebook
# name being rebound here, so re-running this cell starts again from the
# module's stored tensors instead of failing on values that are already NumPy
# arrays.
intermediate_outputs = {
    key: [tensor[0].detach().cpu().numpy() for tensor in value]
    for key, value in interp_transformer.intermediate_outputs.items()
}

In [27]:
# Plot the NumPy-converted intermediate outputs, one entry per dictionary key.
# NOTE(review): presumably writes its figures under output_path; confirm.
plot_intermediate_outputs(
    intermediate_outputs, hp, output_path, translation, transparent_plot
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [28]:
# Pair each "sa_block" output with the matching "mha_block" output, subtract
# them and rescale every difference to the interval [-1, 1].
normalized_diffs = []
for sa_output, mha_output in zip(
    intermediate_outputs["sa_block"], intermediate_outputs["mha_block"]
):
    diff = sa_output - mha_output

    # normalize to [-1,1]
    diff -= diff.min()
    value_range = diff.max()
    if value_range > 0:
        diff /= value_range
        diff *= 2
        diff -= 1
    # Otherwise every element was equal: `diff` is already all zeros (the
    # centre of the target interval), and dividing by the zero range would
    # only turn it into NaN.

    normalized_diffs.append(diff)
# Wrap the list in a dict so it can be handed to plot_intermediate_outputs.
diff_sa_mha_block = {"diff_sa_block_mha_block": normalized_diffs}

In [29]:
# Plot the normalised sa_block - mha_block differences with the same helper
# that was used for the raw intermediate outputs.
plot_intermediate_outputs(
    diff_sa_mha_block, hp, output_path, translation, transparent_plot
)

<IPython.core.display.Javascript object>

In [30]:
# Number of entries per layer: every token from BOS up to the one before EOS.
tgt_length = len(translation) - 1
attn_output_weights_list = reorganize_list(
    diff_sa_mha_block["diff_sa_block_mha_block"],
    hp["NUM_DECODER_LAYERS"],
)
# Walk the reorganised list in consecutive groups of `tgt_length` entries and
# print the shape and mean of every entry under a 1-based layer heading.
layer_starts = range(0, len(attn_output_weights_list), tgt_length)
for layer_number, layer_start in enumerate(layer_starts, start=1):
    print(f"Layer {layer_number}:")
    for position in range(layer_start, layer_start + tgt_length):
        w = attn_output_weights_list[position]
        print(f"{w.shape} → {w.mean():>7.3f}")

Layer 1:
(1, 16) →   0.145
(2, 16) →   0.053
(3, 16) →  -0.268
(4, 16) →  -0.229
Layer 2:
(1, 16) →   0.232
(2, 16) →   0.181
(3, 16) →   0.095
(4, 16) →   0.022
Layer 3:
(1, 16) →   0.236
(2, 16) →   0.142
(3, 16) →   0.057
(4, 16) →  -0.054
Layer 4:
(1, 16) →   0.188
(2, 16) →   0.216
(3, 16) →  -0.153
(4, 16) →  -0.044
