In [3]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

module_path = os.path.abspath(os.path.join("../.."))
sys.path.append(module_path)

## Dataset and hyperparameters loading

In [1]:
from torchvision.transforms.v2 import Compose
from hyperparameters import load_hyperparameters_from_json

from SLTDataset import SLTDataset
from posecraft.Pose import Pose


DATASET = "GSL"
dataset_path = f"/mnt/disk3Tb/slt-datasets/{DATASET}"
hp = load_hyperparameters_from_json(f"config/{DATASET}.json")

landmarks_mask = Pose.get_components_mask(hp["LANDMARKS_USED"])
transforms: Compose = Compose(hp["TRANSFORMS"])

train_dataset = SLTDataset(
    data_dir=dataset_path,
    split="train",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
val_dataset = SLTDataset(
    data_dir=dataset_path,
    split="val",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)
test_dataset = SLTDataset(
    data_dir=dataset_path,
    split="test",
    input_mode=hp["INPUT_MODE"],
    output_mode=hp["OUTPUT_MODE"],
    transforms=transforms,
    max_tokens=hp["MAX_TOKENS"],
)

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded train annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 8821/8821 [00:00<00:00, 275462.77it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded val annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 588/588 [00:00<00:00, 229333.34it/s]


Dataset loaded correctly

Loaded metadata for dataset: The Greek Sign Language (GSL) Dataset
Loaded test annotations at /mnt/disk3Tb/slt-datasets/GSL/annotations.csv


Validating files: 100%|██████████| 881/881 [00:00<00:00, 251287.44it/s]

Dataset loaded correctly






### Display sample

In [4]:
from IPython.display import HTML
from random import randint


idx = randint(0, len(train_dataset))
# avoid using the last transform as it flattens the keypoints
visual_transforms: Compose = Compose(hp["TRANSFORMS"][:-1])
anim = train_dataset.visualize_pose(idx, transforms=visual_transforms)
HTML(anim.to_jshtml())

<IPython.core.display.Javascript object>

### Text tokenization

In [4]:
import torch
import numpy as np
from sklearn.utils.class_weight import compute_class_weight


class_weights_complete = None

if hp["USE_CLASS_WEIGHTS"]:
    texts = train_dataset.annotations[hp["OUTPUT_MODE"]].tolist()
    tokenized_sequences = train_dataset.tokenizer(
        texts, padding="max_length", max_length=25
    )
    flattened_tgts: list[list[int]] = [
        item for sublist in tokenized_sequences for item in sublist
    ]  # type: ignore
    token_ids = sorted(list(set(flattened_tgts)))
    class_weights = compute_class_weight(
        "balanced", classes=np.array(token_ids), y=flattened_tgts
    )
    class_weights_complete = torch.ones(train_dataset.tokenizer.vocab_size)
    class_weights_complete[token_ids] = torch.from_numpy(class_weights).float()

### Dataloader generation

In [5]:
import torch
from torch.utils.data import DataLoader


NUM_WORKERS = 4

train_loader = DataLoader(
    train_dataset,
    batch_size=hp["BATCH_SIZE"],
    num_workers=NUM_WORKERS,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=hp["BATCH_SIZE"],
    num_workers=NUM_WORKERS,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=hp["BATCH_SIZE"],
    num_workers=NUM_WORKERS,
    shuffle=True,
)

In [6]:
for src, tgt in train_loader:
    print(f"Source shape (Batch, Frames, Keypoints): {src.shape}")
    print(f"Target shape (Batch, Tokens): {tgt.shape}")
    break

Source shape (Batch, Frames, Keypoints): torch.Size([64, 220, 150])
Target shape (Batch, Tokens): torch.Size([64, 20])


## Model

### Definition

In [6]:
import lightning.pytorch.utilities.model_summary.model_summary as model_summary

from LightningKeypointsTransformer import LKeypointsTransformer


device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

l_model = LKeypointsTransformer(
    hp,
    device,
    train_dataset.tokenizer,
    interp=True,
)
model_summary.summarize(l_model, max_depth=10)

  from .autonotebook import tqdm as notebook_tqdm


True


   | Name                                                       | Type                            | Params | In sizes                                                                     | Out sizes                   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0  | model                                                      | KeypointsTransformer            | 3.4 M  | [[1, 220, 150], [1, 20], [220, 220], [1, 220], [20, 20], [1, 20]]            | [1, 20, 402]                
1  | model.src_keyp_emb                                         | Conv1DEmbedder                  | 35.8 K | [1, 220, 150]                                                                | [1, 220, 128]               
2  | model.src_keyp_emb.conv1d_1                                | Conv1d                          | 19.3 K | [1, 150, 220]    

### Training

In [9]:
import json

import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger


wandb_logger = WandbLogger(project=DATASET)
wandb_logger.experiment.config.update(hp)
results_path = f"results/{DATASET}/{wandb_logger.experiment.name}"
os.makedirs(results_path, exist_ok=True)

with open(f"{results_path}/hp.json", "w") as f:
    json.dump(hp, f, default=str, indent=4)

trainer = L.Trainer(
    logger=wandb_logger,
    callbacks=[
        EarlyStopping(monitor="val_accuracy", mode="max", patience=30),
        ModelCheckpoint(
            monitor="val_loss",
            dirpath=results_path,
            filename=f"best-{{epoch:02d}}-{{step:02d}}-{{val_loss:.2f}}",
            mode="min",
        ),
    ],
    max_epochs=3,
)

/home/pdalbianco/anaconda3/envs/slt_datasets/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:390: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(
    model=l_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

/home/pdalbianco/anaconda3/envs/slt_datasets/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /home/pdalbianco/Github/slt_models_tryout/src/results/GSL/lunar-dust-33 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                 | Params | In sizes                                                          | Out sizes   
-------------------------------------------------------------------------------------------------------------------------------------
0 | model    | KeypointsTransformer | 3.4 M  | [[1, 220, 150], [1, 20], [220, 220], [1, 220], [20, 20], [1, 20]] | [1, 20, 402]
1 | accuracy | MulticlassAccuracy   | 0      | ?                                                                 | ?           
-------------------------------------------------------------------------------------------------------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/pdalbianco/anaconda3/envs/slt_datasets/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Epoch 2: 100%|██████████| 138/138 [00:08<00:00, 15.52it/s, v_num=d3gf]     

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 138/138 [00:08<00:00, 15.51it/s, v_num=d3gf]


In [11]:
import glob


checkpoint = glob.glob(f"{results_path}/best*")[0]

trainer.test(
    model=l_model,
    dataloaders=test_loader,
    ckpt_path=checkpoint,
)

if l_model.translation_results_df is not None:
    l_model.translation_results_df.to_csv(
        f"{results_path}/translations.csv", index=False
    )

Restoring states from the checkpoint path at results/GSL/lunar-dust-33/best-epoch=00-step=138-val_loss=4.64.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at results/GSL/lunar-dust-33/best-epoch=00-step=138-val_loss=4.64.ckpt
/home/pdalbianco/anaconda3/envs/slt_datasets/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing DataLoader 0: 100%|██████████| 14/14 [00:20<00:00,  0.68it/s]
