In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

module_path = os.path.abspath(os.path.join("../.."))
sys.path.append(module_path)

In [3]:
pose_index = 0

In [4]:
OUTPUT_PATH = f"results/interp/{pose_index}/"
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [5]:
import torch

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

## Dataset and hyperparameters loading

In [6]:
from torchvision.transforms.v2 import Compose
from hyperparameters import HP_DICT

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


DATASET = "GSL"
dataset_path = f"/mnt/disk3Tb/slt-datasets/{DATASET}"
hp = HP_DICT[DATASET]

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

train_dataset = SLTDataset(
    data_dir=dataset_path,
    split="train",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
test_dataset = SLTDataset(
    data_dir=dataset_path,
    split="test",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 271295.73it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 248082.03it/s]

Dataset loaded correctly






### Display sample

In [7]:
from IPython.display import HTML

from posecraft.transforms import (
    CenterToKeypoint,
    FillMissing,
    FilterLandmarks,
    ReplaceNansWithZeros,
    InterpolateFrames,
)

visual_transforms = Compose(
    [
        FilterLandmarks(landmarks_mask),
        # CenterToKeypoint(center_keypoint=0),
        # NormalizeDistances(indices=(11, 12), distance_factor=0.2),
        FillMissing(),
        # InterpolateFrames(30),
        ReplaceNansWithZeros(),
    ]
)

anim = test_dataset.visualize_pose(pose_index, transforms=visual_transforms)
HTML(anim.to_jshtml())

<IPython.core.display.Javascript object>

In [8]:
anim.save(f"{OUTPUT_PATH}/sample.mp4", writer="ffmpeg")

In [9]:
src, tgt = test_dataset[pose_index]
src = src.unsqueeze(0)
tgt = tgt.unsqueeze(0)

## Model

### Definition

In [10]:
from KeypointsTransformer import KeypointsTransformer


num_keypoints = landmarks_mask.sum().item()
in_features = int(num_keypoints * (3 if hp["USE_3D"] else 2))

model = KeypointsTransformer(
    src_len=hp["MAX_FRAMES"],
    tgt_len=hp["MAX_TOKENS"],
    in_features=in_features,
    tgt_vocab_size=train_dataset.tokenizer.vocab_size,
    d_model=hp["D_MODEL"],
    num_encoder_layers=hp["NUM_ENCODER_LAYERS"],
    num_decoder_layers=hp["NUM_DECODER_LAYERS"],
    dropout=hp["DROPOUT"],
    interp=True,
)



In [11]:
from Translator import Translator

translator = Translator(device, hp["MAX_TOKENS"])

In [12]:
best_model_path = (
    "checkpoint/GSL-frosty-haze-24-best-epoch=201-step=27876-val_loss=0.40.ckpt"
)

checkpoint = torch.load(best_model_path, map_location=device)["state_dict"]
adjusted_checkpoint = {}

for key, value in checkpoint.items():
    if key.startswith("model."):
        adjusted_key = key[len("model.") :]  # Elimina el prefijo 'model.'
        adjusted_checkpoint[adjusted_key] = value
    else:
        adjusted_checkpoint[key] = value

model.load_state_dict(adjusted_checkpoint)

<All keys matched successfully>

In [13]:
model = model.to(device)
model = model.eval()
src = src.to(device)
tgt = tgt.to(device)

In [14]:
BOS_IDX = train_dataset.tokenizer.cls_token_id
EOS_IDX = train_dataset.tokenizer.sep_token_id

### Interpretability

In [15]:
idx2word = translator.translate(src, model, "greedy", train_dataset.tokenizer)
idx2word = ["BOS " + s + " EOS" for s in idx2word]
idx2word = idx2word[0].split()
idx2word

['BOS', 'ΓΕΙΑ', 'ΕΓΩ(1)', 'ΜΠΟΡΩ', 'ΒΟΗΘΕΙΑ', 'ΠΩΣ;', 'EOS']

#### Encoder Self-Attention

In [16]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_ENCODER_LAYERS"]):
    self_attn_module = model.transformer.encoder.layers[layer].self_attn
    hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

fig, axes = plt.subplots(1, hp["NUM_ENCODER_LAYERS"], figsize=(10, 5), sharey=True)
for layer, attn_weights in enumerate(attn_output_weights_list):
    ax = (
        axes[layer] if hp["NUM_ENCODER_LAYERS"] > 1 else axes
    )  # Handle case with only one layer
    sent = np.arange(hp["MAX_FRAMES"])
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_title(f"Layer {layer+1}")

plt.savefig(f"{OUTPUT_PATH}/attn_self_heatmaps_encoder_layers.png")
plt.show()

<IPython.core.display.Javascript object>

#### Decoder Self-Attention

In [17]:
def reorganize_list(input_list, N):
    grouped_list = []
    for i in range(N):
        grouped_list.extend(input_list[i::N])
    return grouped_list

In [18]:
import seaborn as sns

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_DECODER_LAYERS"]):
    self_attn_module = model.transformer.decoder.layers[layer].self_attn
    hook_handles.append(self_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
fig, axes = plt.subplots(rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True)
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
for layer, attn_weights in enumerate(attn_output_weights_list):
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=tgt_sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(tgt_sent, rotation=90)
    ax.set_title(f"Layer {i+1}") if layer % rows == 0 else None

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{OUTPUT_PATH}/attn_self_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 16]) torch.Size([1, 1, 1])
torch.Size([1, 1, 16]) torch.Size([1, 1, 1])
torch.Size([1, 1, 16]) torch.Size([1, 1, 1])
torch.Size([1, 1, 16]) torch.Size([1, 1, 1])
torch.Size([1, 2, 16]) torch.Size([1, 2, 2])
torch.Size([1, 2, 16]) torch.Size([1, 2, 2])
torch.Size([1, 2, 16]) torch.Size([1, 2, 2])
torch.Size([1, 2, 16]) torch.Size([1, 2, 2])
torch.Size([1, 3, 16]) torch.Size([1, 3, 3])
torch.Size([1, 3, 16]) torch.Size([1, 3, 3])
torch.Size([1, 3, 16]) torch.Size([1, 3, 3])
torch.Size([1, 3, 16]) torch.Size([1, 3, 3])
torch.Size([1, 4, 16]) torch.Size([1, 4, 4])
torch.Size([1, 4, 16]) torch.Size([1, 4, 4])
torch.Size([1, 4, 16]) torch.Size([1, 4, 4])
torch.Size([1, 4, 16]) torch.Size([1, 4, 4])
torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
torch.Size([1, 6, 16]) torch.Size([1, 6, 6])
torch.Size([1, 6, 16]) torch.Size([1, 6, 6])
torch.Size

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23


#### Decoder Cross-Attention

In [20]:
import seaborn as sns

attn_output_weights_list = []


def attention_hook(module, input, output):  # input: (query, key, value)
    _, attn_output_weights = output  # output: (attn_output, attn_output_weights)
    print(output[0].shape, output[1].shape)
    attn_output_weights_list.append(attn_output_weights[0].detach().cpu().numpy())


hook_handles = []
for layer in range(hp["NUM_DECODER_LAYERS"]):
    multihead_attn_module = model.transformer.decoder.layers[layer].multihead_attn
    hook_handles.append(multihead_attn_module.register_forward_hook(attention_hook))

# Inference
translator.greedy_decode(src, model, BOS_IDX, EOS_IDX)

for handle in hook_handles:
    handle.remove()

rows = len(idx2word) - 1  # from BOS to EOS-1
fig, axes = plt.subplots(rows, hp["NUM_DECODER_LAYERS"], figsize=(20, 20), sharey=True)
attn_output_weights_list = reorganize_list(
    attn_output_weights_list, hp["NUM_DECODER_LAYERS"]
)
for layer, attn_weights in enumerate(attn_output_weights_list):
    i, j = divmod(layer, rows)
    print(i, j, layer)
    ax = axes[j, i]
    sent = np.arange(hp["MAX_FRAMES"])
    tgt_sent = idx2word[1 : attn_weights.shape[0] + 1]
    sns.heatmap(
        attn_weights,
        ax=ax,
        xticklabels=sent,
        yticklabels=tgt_sent,
        square=True,
        cbar=False,
    )  # vmin=0.0, vmax=1.0)
    ax.set_aspect("auto")
    ax.set_yticklabels(tgt_sent, rotation=0)
    ax.set_xticklabels(sent, rotation=90)
    ax.set_title(f"Layer {i+1}") if layer % rows == 0 else None

plt.subplots_adjust(wspace=0.4, hspace=0.4)

plt.savefig(f"{OUTPUT_PATH}/attn_cross_heatmaps_decoder_layers.png", dpi=150)
plt.show()

torch.Size([1, 1, 16]) torch.Size([1, 1, 30])
torch.Size([1, 1, 16]) torch.Size([1, 1, 30])
torch.Size([1, 1, 16]) torch.Size([1, 1, 30])
torch.Size([1, 1, 16]) torch.Size([1, 1, 30])
torch.Size([1, 2, 16]) torch.Size([1, 2, 30])
torch.Size([1, 2, 16]) torch.Size([1, 2, 30])
torch.Size([1, 2, 16]) torch.Size([1, 2, 30])
torch.Size([1, 2, 16]) torch.Size([1, 2, 30])
torch.Size([1, 3, 16]) torch.Size([1, 3, 30])
torch.Size([1, 3, 16]) torch.Size([1, 3, 30])
torch.Size([1, 3, 16]) torch.Size([1, 3, 30])
torch.Size([1, 3, 16]) torch.Size([1, 3, 30])
torch.Size([1, 4, 16]) torch.Size([1, 4, 30])
torch.Size([1, 4, 16]) torch.Size([1, 4, 30])
torch.Size([1, 4, 16]) torch.Size([1, 4, 30])
torch.Size([1, 4, 16]) torch.Size([1, 4, 30])
torch.Size([1, 5, 16]) torch.Size([1, 5, 30])
torch.Size([1, 5, 16]) torch.Size([1, 5, 30])
torch.Size([1, 5, 16]) torch.Size([1, 5, 30])
torch.Size([1, 5, 16]) torch.Size([1, 5, 30])
torch.Size([1, 6, 16]) torch.Size([1, 6, 30])
torch.Size([1, 6, 16]) torch.Size(

<IPython.core.display.Javascript object>

0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
0 5 5
1 0 6
1 1 7
1 2 8
1 3 9
1 4 10
1 5 11
2 0 12
2 1 13
2 2 14
2 3 15
2 4 16
2 5 17
3 0 18
3 1 19
3 2 20
3 3 21
3 4 22
3 5 23
