In [1]:
# Render matplotlib figures with the interactive "notebook" backend.
%matplotlib notebook
# Load the autoreload extension and reload all modules before each cell runs,
# so edits to the project's .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Make the repository root (two directories above this notebook's working
# directory) importable so the project modules used below can be found.
module_path = os.path.abspath("../..")
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

## Dataset and hyperparameters loading

In [3]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


# --- Experiment selection ----------------------------------------------------
DATASET = "GSL"
EXPERIMENT_ID = "frosty-haze-24"
SAMPLE_IDX = 509  # index (in the test split) of the sample analysed below

# Root folder that holds the SLT datasets. The original machine-specific
# location stays the default; set the SLT_DATASETS_ROOT environment variable
# to run this notebook elsewhere without editing the cell.
DATASETS_ROOT = os.environ.get("SLT_DATASETS_ROOT", "/mnt/disk3Tb/slt-datasets")

# --- Paths and hyperparameters -----------------------------------------------
dataset_path = os.path.join(DATASETS_ROOT, DATASET)
experiment_path = f"results/{DATASET}/{EXPERIMENT_ID}"
hp = load_hyperparameters_from_json(f"{experiment_path}/hp.json")
# The trailing slash is preserved: output_path is handed to the plotting
# helpers later on, and whether they rely on it is not visible from here.
output_path = f"{experiment_path}/interp/{SAMPLE_IDX}/"
os.makedirs(output_path, exist_ok=True)

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

# --- Datasets ----------------------------------------------------------------
# Both splits share every argument except `split`; building the common keyword
# arguments once keeps the two datasets from drifting apart.
dataset_kwargs = dict(
    data_dir=dataset_path,
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
train_dataset = SLTDataset(split="train", **dataset_kwargs)
test_dataset = SLTDataset(split="test", **dataset_kwargs)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 221875.46it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 190169.41it/s]

Dataset loaded correctly






In [4]:
# Print every test-set index whose raw target text is exactly the sentence of
# interest (several recordings of the same sentence may exist).
TARGET_TEXT = "ΜΩΡΟ ΓΕΝΝΩ ΝΩΡΙΣ"
matching_indices = (
    idx
    for idx in range(len(test_dataset))
    if test_dataset.get_item_raw(idx)[1] == TARGET_TEXT
)
for idx in matching_indices:
    print(idx)

509
534
559


### Display sample

In [5]:
from IPython.display import HTML

# Drop the last transform for visualization, because it flattens the keypoints.
transforms_without_flatten = hp["TRANSFORMS"][:-1]
visual_transforms: Compose = Compose(transforms_without_flatten)

# Animate the selected test sample and embed it in the notebook as JS/HTML.
anim = test_dataset.visualize_pose(SAMPLE_IDX, transforms=visual_transforms)
animation_html = anim.to_jshtml()
HTML(animation_html)

<IPython.core.display.Javascript object>

In [6]:
# Save the sample animation next to the other interpretability outputs.
# output_path already ends with "/", so join the parts with os.path.join
# instead of inserting a second separator by hand.
anim.save(os.path.join(output_path, "sample.mp4"), writer="ffmpeg")

In [7]:
import torch

# Pick the compute device by priority: Apple MPS first, then CUDA, and the CPU
# as the fallback when neither accelerator is available.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

# Fetch the selected test sample, give each tensor a leading dimension of
# size 1 (presumably the batch axis the model expects) and move both to the
# chosen device.
src, tgt = (
    tensor.unsqueeze(0).to(device) for tensor in test_dataset[SAMPLE_IDX]
)

## Model

### Definition

In [8]:
import glob
from LightningKeypointsTransformer import LKeypointsTransformer

# Locate the "best*" checkpoint saved for this experiment. glob returns its
# matches in arbitrary order, so sort them to make the choice deterministic
# when more than one file matches, and fail with a clear message when none do.
checkpoint_candidates = sorted(glob.glob(f"{experiment_path}/best*"))
if not checkpoint_candidates:
    raise FileNotFoundError(
        f"No checkpoint matching 'best*' found in {experiment_path}"
    )
checkpoint_path = checkpoint_candidates[0]

# Prefer the Lightning checkpoint format; fall back to the project's loader for
# old checkpoints if that fails.
try:
    l_model = LKeypointsTransformer.load_from_checkpoint(checkpoint_path)
    model = l_model.model
    translator = l_model.translator
except Exception as load_error:
    # `Exception` rather than a bare `except:` so that a kernel interrupt
    # (KeyboardInterrupt) or SystemExit is not swallowed by the fallback.
    from helpers import load_from_old_checkpoint

    print(
        "Lightning checkpoint loading failed "
        f"({type(load_error).__name__}: {load_error}); "
        "falling back to load_from_old_checkpoint."
    )
    model, translator = load_from_old_checkpoint(
        checkpoint_path, hp, device, landmarks_mask, train_dataset
    )

  checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]


In [9]:
# Move the model to the selected device and switch it to evaluation mode.
model = model.to(device).eval()

In [10]:
# Special-token ids handed to the decoder: the tokenizer's CLS token is used as
# beginning-of-sequence and its SEP token as end-of-sequence.
tokenizer = train_dataset.tokenizer
BOS_IDX, EOS_IDX = tokenizer.cls_token_id, tokenizer.sep_token_id

### Interpretability

In [11]:
from interp.plot_functions import *

In [12]:
# Greedy-decode the sample and turn the first (only) translated sentence into a
# token list framed by explicit BOS/EOS markers, used as labels further down.
decoded_sentences = translator.translate(
    src, model, "greedy", train_dataset.tokenizer
)
framed_sentences = [f"BOS {sentence} EOS" for sentence in decoded_sentences]
translation = framed_sentences[0].split()
translation

['BOS', 'ΜΩΡΟ', 'ΓΕΝΝΩ', 'ΝΩΡΙΣ', 'EOS']

#### Encoder Self-Attention

In [13]:
# Attention weights captured during one greedy decode, in the order the hooked
# encoder self-attention modules are called.
attn_output_weights_list = []


def attention_hook(module, input, output):
    """Forward hook that records the attention weights returned by ``module``.

    Args:
        module: The attention module the hook is attached to (unused).
        input: Positional inputs of the forward call, i.e. (query, key, value).
        output: The module's result, ``(attn_output, attn_output_weights)``.
    """
    _, attn_output_weights = output
    # Keep only the first entry along the leading axis (presumably the batch
    # axis; src holds a single sample) as a NumPy array on the CPU.
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Attach the hook to the self-attention module of every encoder layer.
    for layer in range(hp["NUM_ENCODER_LAYERS"]):
        self_attn_module = model.transformer.encoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if registration or decoding raised;
    # otherwise they would stay on the model and record into later runs.
    for handle in hook_handles:
        handle.remove()

In [14]:
# Visualize the encoder self-attention weights collected in
# attn_output_weights_list.
# NOTE(review): output_path is passed along, presumably so the figures are
# saved there; confirm in interp.plot_functions.
plot_encoder_layers(attn_output_weights_list, hp, output_path)

<IPython.core.display.Javascript object>

#### Decoder Self-Attention

In [15]:
# Attention weights captured during one greedy decode, in the order the hooked
# decoder self-attention modules are called.
attn_output_weights_list = []


def attention_hook(module, input, output):
    """Forward hook that records the attention weights returned by ``module``.

    Args:
        module: The attention module the hook is attached to (unused).
        input: Positional inputs of the forward call, i.e. (query, key, value).
        output: The module's result, ``(attn_output, attn_output_weights)``.
    """
    _, attn_output_weights = output
    # Keep only the first entry along the leading axis (presumably the batch
    # axis; src holds a single sample) as a NumPy array on the CPU.
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Attach the hook to the self-attention module of every decoder layer.
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        self_attn_module = model.transformer.decoder.layers[layer].self_attn
        hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if registration or decoding raised;
    # otherwise they would stay on the model and record into later runs.
    for handle in hook_handles:
        handle.remove()

In [16]:
# Visualize the decoder self-attention weights collected in
# attn_output_weights_list. `translation` supplies the decoded tokens
# (presumably used as axis labels) and "self" names the attention type
# being plotted. TODO confirm both in interp.plot_functions.
plot_decoder_layers(attn_output_weights_list, hp, output_path, translation, "self")

<IPython.core.display.Javascript object>

#### Decoder Cross-Attention

In [17]:
# Attention weights captured during one greedy decode, in the order the hooked
# decoder cross-attention (multihead_attn) modules are called.
attn_output_weights_list = []


def attention_hook(module, input, output):
    """Forward hook that records the attention weights returned by ``module``.

    Args:
        module: The attention module the hook is attached to (unused).
        input: Positional inputs of the forward call, i.e. (query, key, value).
        output: The module's result, ``(attn_output, attn_output_weights)``.
    """
    _, attn_output_weights = output
    # Keep only the first entry along the leading axis (presumably the batch
    # axis; src holds a single sample) as a NumPy array on the CPU.
    attn_output_weights_list.append(attn_output_weights[0].cpu().detach().numpy())


hook_handles = []
try:
    # Attach the hook to the multihead_attn module of every decoder layer.
    for layer in range(hp["NUM_DECODER_LAYERS"]):
        multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
        hook_handles.append(
            multihead_attn_module.register_forward_hook(attention_hook)
        )

    # Inference
    translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)
finally:
    # Always detach the hooks, even if registration or decoding raised;
    # otherwise they would stay on the model and record into later runs.
    for handle in hook_handles:
        handle.remove()

In [18]:
# Visualize the decoder cross-attention weights collected in
# attn_output_weights_list. `translation` supplies the decoded tokens
# (presumably used as axis labels) and "cross" names the attention type
# being plotted. TODO confirm both in interp.plot_functions.
plot_decoder_layers(attn_output_weights_list, hp, output_path, translation, "cross")

<IPython.core.display.Javascript object>

#### Decoder `sa_block` & `mha_block`

In [19]:
from interp.InterpTransformer import clear_intermediate_outputs

In [20]:
# Reset whatever interp.InterpTransformer has recorded so far, so the decode
# below starts from a clean slate.
# NOTE(review): presumably this empties the module-level
# `intermediate_outputs`; confirm in InterpTransformer.
clear_intermediate_outputs()

In [21]:
# Decode the sample once more (no hooks attached this time). As the cell's
# last expression, the returned token-id tensor is shown as the cell output.
# NOTE(review): presumably this run is what fills InterpTransformer's
# intermediate outputs read in the next cells; confirm.
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

tensor([[  2, 274, 189, 276,   3]], device='cuda:0')

In [22]:
from interp.InterpTransformer import intermediate_outputs

In [23]:
# Convert every recorded tensor to a NumPy array on the CPU, keeping only the
# first entry along the leading axis (presumably the batch axis).
#
# The source is read through the module attribute rather than through the name
# imported earlier: this cell rebinds `intermediate_outputs` to the NumPy
# version, so looking it up on the module again keeps the cell safe to re-run
# (re-running the original would call `.cpu()` on NumPy arrays and fail).
import interp.InterpTransformer as interp_transformer

intermediate_outputs = {
    key: [tensor[0].cpu().detach().numpy() for tensor in value]
    for key, value in interp_transformer.intermediate_outputs.items()
}

In [24]:
# Visualize the recorded intermediate outputs (one list of arrays per key of
# `intermediate_outputs`, e.g. "sa_block" and "mha_block").
# NOTE(review): `translation` is presumably used for token labels and
# output_path for saving the figures; confirm in interp.plot_functions.
plot_intermediate_outputs(intermediate_outputs, hp, output_path, translation)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [25]:
# Element-wise difference between each recorded self-attention block output and
# the matching cross-attention (mha) block output. The differences are kept on
# their raw scale (no rescaling to [-1, 1]) and wrapped in a single-key dict so
# they can be passed to plot_intermediate_outputs like the other outputs.
sa_block_outputs = intermediate_outputs["sa_block"]
mha_block_outputs = intermediate_outputs["mha_block"]
diff_sa_mha_block = {
    "diff_sa_block_mha_block": [
        sa_out - mha_out
        for sa_out, mha_out in zip(sa_block_outputs, mha_block_outputs)
    ]
}

In [26]:
# Visualize the sa_block - mha_block differences computed above, reusing the
# same plotting helper as for the raw intermediate outputs.
plot_intermediate_outputs(diff_sa_mha_block, hp, output_path, translation)

<IPython.core.display.Javascript object>

In [27]:
# Number of decoding steps per layer: one per input token, from BOS up to the
# token just before EOS.
tgt_length = len(translation) - 1
attn_output_weights_list = reorganize_list(
    diff_sa_mha_block["diff_sa_block_mha_block"], hp["NUM_DECODER_LAYERS"]
)

# The reorganized list is walked in consecutive chunks of `tgt_length` entries,
# one chunk per layer; print each entry's shape and mean.
chunk_starts = range(0, len(attn_output_weights_list), tgt_length)
for layer_number, start in enumerate(chunk_starts, start=1):
    print(f"Layer {layer_number}:")
    for step in range(tgt_length):
        w = attn_output_weights_list[start + step]
        print(f"{w.shape} → {w.mean():>7.3f}")

Layer 1:
(1, 16) →  -0.002
(2, 16) →   0.008
(3, 16) →  -0.011
(4, 16) →  -0.012
Layer 2:
(1, 16) →  -0.028
(2, 16) →  -0.046
(3, 16) →  -0.034
(4, 16) →  -0.032
Layer 3:
(1, 16) →  -0.021
(2, 16) →   0.006
(3, 16) →  -0.009
(4, 16) →  -0.019
Layer 4:
(1, 16) →   0.052
(2, 16) →   0.040
(3, 16) →   0.025
(4, 16) →   0.025
