In [1]:
# Detect whether the notebook runs on Google Colab and, if so, mount Google Drive.
# Only the import is guarded: `google.colab` is importable exclusively inside a Colab
# runtime, so an ImportError means "not on Colab". Errors raised by `drive.mount`
# are allowed to propagate instead of being mis-reported as "not on Colab".
try:
	from google.colab import drive
except ImportError:
	IN_COLAB = False
	print("Not running on Google Colab")
else:
	IN_COLAB = True
	print("Running on Google Colab")
	drive.mount('/content/drive')

Running on Google Colab
Mounted at /content/drive


## Dataset download

In [2]:
if IN_COLAB:
	!pip install git+https://github.com/sign-language-processing/datasets.git -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.4/85.4 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m57.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.5/60.5 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for sign-language-datasets (setup.py) ... [?25l[?25hdone


In [3]:
import tensorflow_datasets as tfds
# import sign_language_datasets.datasets
from sign_language_datasets.utils.torch_dataset import TFDSTorchDataset
from sign_language_datasets.datasets.config import SignDatasetConfig

In [4]:
DATA_DIR = "." if not IN_COLAB else "/content/drive/MyDrive/Académico/Doctorado/SLT Datasets/RWTH"

In [5]:
# Request the RWTH-PHOENIX-2014-T dataset with holistic pose data and without videos,
# using DATA_DIR as the tensorflow-datasets data directory.
config = SignDatasetConfig(name="rwth_phoenix2014_t_poses", version="3.0.0", include_video=False, include_pose="holistic")
rwth_phoenix2014_t = tfds.load(name='rwth_phoenix2014_t', builder_kwargs=dict(config=config), data_dir=DATA_DIR)



In [6]:
train_dataset = TFDSTorchDataset(rwth_phoenix2014_t["train"])
validation_dataset = TFDSTorchDataset(rwth_phoenix2014_t["validation"])
test_dataset = TFDSTorchDataset(rwth_phoenix2014_t["test"])

In [7]:
# import itertools


# for datum in itertools.islice(train_dataset, 0, 5):
# 	print((datum.keys()))
# 	print(f"Pose shape: {datum['pose']['data'].shape}")
# 	print(f"Text: {datum['text'].decode('utf-8')}")
# 	print()

## Dataset analysis

In [8]:
# src_lenghts = []
# texts = []

# for datum in rwth_phoenix2014_t["train"]:
# 	src_lenghts.append(datum['pose']['data'].shape[0])
# 	texts.append(datum['text'].numpy().decode('utf-8'))

### Frames analysis for padding and truncation

In [9]:
# import pandas as pd


# src_lengths_df = pd.Series(src_lenghts)
# src_lengths_df.describe(percentiles=[.75, .9, .95, .99])

In [10]:
# src_lengths_df.hist()

### Text tokenization and analysis for padding and truncation

In [11]:
from transformers import AutoTokenizer


TEXT_MODEL = "google-bert/bert-base-german-cased"
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/255k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/485k [00:00<?, ?B/s]

In [12]:
# Reuse the tokenizer's CLS / SEP / PAD token ids as BOS / EOS / PAD.
# -1 is a sentinel for a token the tokenizer does not define.
# NOTE(review): -1 is not a valid embedding index, so a tokenizer lacking one of
# these tokens would need explicit handling downstream — confirm before swapping TEXT_MODEL.
BOS_IDX = tokenizer.cls_token_id if tokenizer.cls_token_id is not None else -1
EOS_IDX = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1
PAD_IDX = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1

print(f"BOS_IDX: {BOS_IDX}, EOS_IDX: {EOS_IDX}, PAD_IDX: {PAD_IDX}")

BOS_IDX: 3, EOS_IDX: 4, PAD_IDX: 0


In [13]:
# tokenized_sequences = tokenizer(texts, padding=True)

In [14]:
# tokens_length = [len(tokens) for tokens in tokenized_sequences['input_ids']]
# print(max(tokens_length))

In [15]:
# print(texts[0])
# print(tokenized_sequences[0].ids)

In [16]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch


# Set to True to weight the loss by inverse token frequency ("balanced" class weights).
USE_CLASS_WEIGHTS = False

if USE_CLASS_WEIGHTS:
	# The exploratory cells that used to create `texts` / `tokenized_sequences` are
	# commented out, so build them here; otherwise enabling the flag on a fresh
	# kernel fails with a NameError.
	train_texts = [datum['text'].numpy().decode('utf-8') for datum in rwth_phoenix2014_t["train"]]
	tokenized_sequences = tokenizer(train_texts, padding=True)
	# Every non-padding token id of the training targets, flattened into one list.
	flattened_tgts = [item for sublist in tokenized_sequences["input_ids"] for item in sublist if item != PAD_IDX]
	token_ids = sorted(set(flattened_tgts))
	class_weights = compute_class_weight("balanced", classes=np.array(token_ids), y=flattened_tgts)
	# Tokens that never occur in the training targets keep a neutral weight of 1.
	class_weights_complete = torch.ones(tokenizer.vocab_size)
	class_weights_complete[token_ids] = torch.from_numpy(class_weights).float()

In [17]:
if USE_CLASS_WEIGHTS:
	print(tokenizer.convert_ids_to_tokens([i for i in range(10)]))
	print(class_weights_complete[:10].tolist())

## Preprocessing and dataloader generation

In [18]:
import torch


# Fixed sizes used by the collate function and the model.
MAX_FRAMES = 259
MAX_TOKENS = 80
BATCH_SIZE = 64

# Landmark groups that are kept as model input.
KEYPOINTS_USED = ["pose", "lhand", "rhand"]

# Group label of every landmark, laid out as pose (33), face (468), left hand (21), right hand (21).
holistic_landmarks = ["pose"] * 33 + ["face"] * 468 + ["lhand"] * 21 + ["rhand"] * 21
# Boolean mask over the landmarks: True where the landmark's group is in KEYPOINTS_USED.
keypoints_mask = torch.tensor([group in KEYPOINTS_USED for group in holistic_landmarks])

In [19]:
import torch
from torch import Tensor
import torch.utils.data as utils


def flatten_keypoints(datum: Tensor):
	'''
		Keep the first dimension S (sequence length) and merge all remaining dimensions into one feature axis.
		Args:
			datum: Tensor whose first dimension is S, e.g. (S, K, D) for K keypoints with D values each
		Returns:
			Tensor of shape (S, K * D)
	'''
	n_frames = datum.size(0)
	return torch.reshape(datum, (n_frames, -1))

def filter_keypoints(datum: Tensor, mask: Tensor):
	'''
		Keep only the keypoints selected by mask.
		Args:
			datum: Tensor of shape (S, 1, K, D); the singleton dimension at index 1 is removed before filtering
			mask: boolean Tensor of shape (K,), True for the keypoints to keep
		Returns:
			Tensor of shape (S, K_new, D)
	'''
	# drop the dummy dimension, then select along the keypoint axis
	frames = datum.squeeze(1)
	return frames[:, mask]

def pad_truncate_src(datum: Tensor, max_len: int):
	'''Append zero rows when datum has fewer than max_len rows; otherwise keep only its first max_len rows.'''
	n_missing = max_len - datum.size(0)
	if n_missing <= 0:
		return datum[:max_len]
	padding = torch.zeros(n_missing, datum.size(1))
	return torch.cat([datum, padding])

def collate_fn(batch):
	'''
		Turn a list of dataset items into one (src, tgt) training batch.
		Args:
			batch: list of items exposing item['pose']['data'] (pose tensor) and item['text'] (UTF-8 encoded bytes)
		Returns:
			src: Tensor of shape (N, MAX_FRAMES, F), the filtered, flattened and padded/truncated keypoints
			tgt: Tensor of shape (N, MAX_TOKENS) with the token ids of the target sentences
	'''
	src = torch.stack([
		pad_truncate_src(
			flatten_keypoints(
				filter_keypoints(item['pose']['data'], keypoints_mask)
			), MAX_FRAMES)
		for item in batch
	])
	texts = [item['text'].decode('utf-8') for item in batch]
	# truncation=True guarantees exactly MAX_TOKENS ids per sentence; with padding alone a
	# longer sentence would produce rows of different lengths and exceed the target positional encoding.
	tgt = tokenizer(texts, padding='max_length', truncation=True, max_length=MAX_TOKENS, return_tensors='pt').input_ids
	return src, tgt

# One DataLoader per split, all batching through collate_fn. No shuffling is requested.
# NOTE(review): the datasets appear to be iterable-style (Lightning warns about an
# `IterableDataset` during fit), so sample order comes from the underlying dataset — confirm.
train_loader = utils.DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
validation_loader = utils.DataLoader(validation_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_loader = utils.DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [20]:
# for src, tgt in train_loader:
#   print(src.shape)
#   print(tgt.shape)
#   break

## Model

### Model definition

In [21]:
if IN_COLAB:
	!pip install lightning -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.7/731.7 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

In [22]:
import torch
from torch import Tensor


def generate_square_subsequent_mask(size: int, device: torch.device):
    '''
        Build the additive causal (size, size) mask for the transformer model:
        0.0 on and below the main diagonal, -inf strictly above it.
    '''
    blocked = torch.full((size, size), float('-inf'))
    # triu(diagonal=1) keeps the -inf entries above the diagonal and zeroes the rest
    return torch.triu(blocked, diagonal=1).to(device)

def create_target_mask(tgt: Tensor, pad_idx: int, device: torch.device):
    '''
        Build the two masks the decoder needs for a batch of target token ids.
        Args:
            tgt: (N, T) where N is the batch size and T is the target sequence length
            pad_idx: padding index
            device: torch device on which the causal mask is created
        Returns:
            tgt_mask: (T, T) causal mask, so position i can only look at positions 0..i
            tgt_padding_mask: (N, T) boolean mask, True where tgt holds the padding index
    '''
    causal_mask = generate_square_subsequent_mask(tgt.size(1), device)
    is_padding = tgt.eq(pad_idx)
    return causal_mask, is_padding

In [23]:
from torch import Tensor, nn
from torch.nn.functional import relu


class Conv1DEmbedder(nn.Module):
	'''
		Per-time-step embedder: four kernel-size-1 convolutions
		(in_channels -> 512 -> 256 -> 128 -> out_channels), each followed by a ReLU.
	'''

	def __init__(self, in_channels: int, out_channels: int):
		super().__init__()
		self.conv1d_1 = nn.Conv1d(in_channels, 512, kernel_size=1)
		self.conv1d_2 = nn.Conv1d(512, 256, kernel_size=1)
		self.conv1d_3 = nn.Conv1d(256, 128, kernel_size=1)
		self.conv1d_4 = nn.Conv1d(128, out_channels, kernel_size=1)

	def forward(self, x: Tensor) -> Tensor:
		'''
			Args:
				x: (N, S, E_in) where N is the batch size, S is the sequence length and E_in equals in_channels
			Returns:
				(N, S, E_out) where E_out equals out_channels
		'''
		# Conv1d expects channels before the sequence axis: (N, E, S)
		hidden = x.transpose(1, 2)
		for conv in (self.conv1d_1, self.conv1d_2, self.conv1d_3, self.conv1d_4):
			hidden = relu(conv(hidden))
		return hidden.transpose(1, 2)

In [24]:
import math
import torch
from torch import nn, Tensor


class PositionalEncoding(nn.Module):
    '''
        Adds fixed sinusoidal positional encodings to a batch-first sequence and applies dropout.
        Even feature indices hold sines and odd indices hold cosines. Both even and odd
        values of d_model are supported.
    '''

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        '''
        Args:
            d_model: size E of the feature dimension
            dropout: dropout probability applied after adding the encoding
            max_len: longest sequence length the precomputed table covers
        '''
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not parameter): moves with the module across devices but is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        '''
        Apply positional encoding to the input tensor.
        Args:
            x: (N, S, E) with S <= max_len
        Returns:
            Tensor of shape (N, S, E)
        '''
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [25]:
import math
from torch import nn, Tensor


class TokenEmbedding(nn.Module):
    '''
        Learned token embedding scaled by sqrt(emb_size).
        Code taken from https://pytorch.org/tutorials/beginner/translation_transformer.html
    '''

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        '''
            Look up the embedding of every token id and scale it by sqrt(emb_size).
            Args:
                tokens: (N, T) token ids (cast to int64 before the lookup)
            Returns:
                Tensor of shape (N, T, E)
        '''
        scale = math.sqrt(self.emb_size)
        token_ids = tokens.long()
        return self.embedding(token_ids) * scale

In [26]:
from torch import Tensor, nn
from torch import nn, Tensor
from transformers import AutoModel
from torch.nn.functional import softmax


class KeypointsTransformer(nn.Module):
    '''
        Transformer model for sign language translation. It uses a stack of 1D convolutions to embed the keypoints and a transformer to translate the sequence.
        S refers to the source sequence length, T to the target sequence length, N to the batch size, and E is the features number.
    '''

    def __init__(self,
                src_max_len: int,
                tgt_max_len: int,
                in_features: int,
                tgt_vocab_size: int,
                d_model: int = 64,
                num_encoder_layers: int = 6,
                dropout: float = 0.1,
                use_bert_embeddings = False,
                ):
        '''
            Args:
                src_max_len: max length of the source sequence
                tgt_max_len: max length of the target sequence
                in_features: number of features of the input (amount of keypoints * amount of coordinates)
                tgt_vocab_size: size of the target vocabulary
                d_model: number of dimensions of the encoding vectors (default=64)
                num_encoder_layers: number of layers of the transformer encoder (default=6)
                dropout: dropout probability of the transformer (default=0.1)
                use_bert_embeddings: if True, target tokens are embedded with the frozen pretrained
                    TEXT_MODEL followed by a Conv1DEmbedder; otherwise a trainable TokenEmbedding is used
        '''
        super().__init__()

        self.src_keyp_emb = Conv1DEmbedder(in_channels=in_features, out_channels=d_model)
        # The positional encodings are built with their own default dropout; `dropout` is only passed to nn.Transformer.
        self.src_pe = PositionalEncoding(d_model=d_model, max_len=src_max_len)
        self.use_bert_embeddings = use_bert_embeddings
        if self.use_bert_embeddings:
            self.tgt_tok_emb = AutoModel.from_pretrained(TEXT_MODEL)
            # The pretrained text model is used as a frozen feature extractor.
            self.tgt_tok_emb.requires_grad_(False)
            # 768 must match the hidden size of TEXT_MODEL.
            self.tgt_tok_conv_emb = Conv1DEmbedder(in_channels=768, out_channels=d_model)
        else:
            self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, d_model)
        self.tgt_pe = PositionalEncoding(d_model=d_model, max_len=tgt_max_len)
        # The number of decoder layers is not set here, so nn.Transformer's default applies.
        self.transformer = nn.Transformer(d_model=d_model, num_encoder_layers=num_encoder_layers, dropout=dropout, batch_first=True)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

    def embed_tgt(self, tgt: Tensor):
        '''
            Embed target token ids.
            Args:
                tgt: (N, T) token ids
            Returns:
                Tensor of shape (N, T, E)
        '''
        if self.use_bert_embeddings:
            # Hugging Face attention masks are 1 for real tokens and 0 for padding,
            # so the mask must select the NON-padding positions.
            attention_mask = (tgt != PAD_IDX).long()
            tgt_emb = self.tgt_tok_emb(tgt, attention_mask=attention_mask).last_hidden_state
            tgt_emb = self.tgt_tok_conv_emb(tgt_emb)
        else:
            tgt_emb = self.tgt_tok_emb(tgt)
        return tgt_emb


    def forward(self,
                src: Tensor,
                tgt: Tensor,
                tgt_mask: Tensor,
                tgt_padding_mask: Tensor
    ):
        '''
            Forward pass of the model.
            Args:
                src: (N, S, in_features) keypoint features
                tgt: (N, T) target token ids
                tgt_mask: (T, T) causal mask
                tgt_padding_mask: (N, T), True at padding positions
            Returns:
                Tensor of shape (N, T, tgt_vocab_size) with unnormalized scores (logits)
        '''
        src_emb = self.src_keyp_emb(src)
        src_emb = self.src_pe(src_emb)
        tgt_emb = self.embed_tgt(tgt)
        tgt_emb = self.tgt_pe(tgt_emb)
        # src_mask and src_key_padding_mask are set to none as we use the whole input at every timestep
        outs = self.transformer(
            src = src_emb,
            tgt = tgt_emb,
            src_mask = None,
            tgt_mask = tgt_mask,
            src_key_padding_mask = None,
            tgt_key_padding_mask = tgt_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor):
        '''Embed src (N, S, in_features) and run it through the transformer encoder; returns the memory (N, S, E).'''
        src_emb = self.src_pe(self.src_keyp_emb(src))
        return self.transformer.encoder(src_emb, None)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        '''
            Run the transformer decoder on the tokens generated so far.
            Args:
                tgt: (N, T) token ids (cast to int64)
                memory: (N, S, E) encoder output
                tgt_mask: (T, T) causal mask
            Returns:
                Tensor of shape (N, T, E); apply self.generator to obtain vocabulary scores
        '''
        tgt = tgt.to(torch.int64)
        tgt_emb = self.embed_tgt(tgt)
        tgt_emb = self.tgt_pe(tgt_emb)
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

In [27]:
# Model hyperparameters (also recorded in the experiment config later in the notebook).
D_MODEL = 128
NUM_ENCODER_LAYERS = 2
DROPOUT = 0.1
USE_BERT_EMBEDDINGS = False

# Number of keypoints kept by keypoints_mask.
num_keypoints = keypoints_mask.sum().item()
# 3 values per keypoint — presumably the x, y, z coordinates; confirm against the pose format.
IN_FEATURES = num_keypoints*3

model = KeypointsTransformer(
    src_max_len=MAX_FRAMES,
    tgt_max_len=MAX_TOKENS,
    in_features=IN_FEATURES,
    tgt_vocab_size=tokenizer.vocab_size,
    d_model=D_MODEL,
	  num_encoder_layers=NUM_ENCODER_LAYERS,
    dropout=DROPOUT,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)

In [28]:
if IN_COLAB:
	!pip install modelsummary -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for modelsummary (setup.py) ... [?25l[?25hdone


In [29]:
from modelsummary import summary


# Pick the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
BATCH_SIZE_TEST = 1


# Smoke test: random inputs with the shapes the model receives during training.
# NOTE(review): src, tgt, tgt_mask and tgt_padding_mask stay defined at notebook level after this cell.
src = torch.randn(BATCH_SIZE_TEST, MAX_FRAMES, IN_FEATURES).to(DEVICE)
tgt = torch.randint(0, tokenizer.vocab_size, (BATCH_SIZE_TEST, MAX_TOKENS)).to(DEVICE)
# An all-zero additive mask masks nothing; it is only used to exercise the forward pass.
tgt_mask = torch.zeros(MAX_TOKENS, MAX_TOKENS).to(DEVICE)
tgt_padding_mask = torch.randint(0, 2, (BATCH_SIZE_TEST, MAX_TOKENS)).bool().to(DEVICE)
print(src.shape, tgt.shape, tgt_mask.shape, tgt_padding_mask.shape)

# Move the model to DEVICE and print a per-layer summary for the sample inputs.
model = model.to(DEVICE)
summary(model, src, tgt, tgt_mask, tgt_padding_mask)

torch.Size([1, 259, 225]) torch.Size([1, 80]) torch.Size([80, 80]) torch.Size([1, 80])
-----------------------------------------------------------------------
             Layer (type)                Input Shape         Param #
         Conv1DEmbedder-1             [-1, 259, 225]               0
                 Conv1d-2             [-1, 225, 259]         115,712
                 Conv1d-3             [-1, 512, 259]         131,328
                 Conv1d-4             [-1, 256, 259]          32,896
                 Conv1d-5             [-1, 128, 259]          16,512
     PositionalEncoding-6             [-1, 259, 128]               0
                Dropout-7             [-1, 259, 128]               0
         TokenEmbedding-8                   [-1, 80]               0
              Embedding-9                   [-1, 80]       3,840,000
    PositionalEncoding-10              [-1, 80, 128]               0
               Dropout-11              [-1, 80, 128]               0
           Tr



## Model training

In [30]:
if IN_COLAB:
	!pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.1/266.1 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [31]:
from typing import Literal


def generate_linear_mask(batch_size: int, tgt_len: int, start_index: int, device: torch.device) -> Tensor:
    '''
        Build a (batch_size, tgt_len) additive mask: 0 for the first start_index
        positions of every row and -inf for the remaining ones.
    '''
    mask = torch.full((batch_size, tgt_len), float('-inf')).to(device)
    mask[:, :start_index] = 0
    return mask

class Translator:
    '''
        Translates one source sequence at a time into text with a trained KeypointsTransformer,
        using either greedy or beam search decoding.
    '''
    # TODO: implement batch_greedy_decode and batch_beam_decode

    def __init__(self, model: KeypointsTransformer, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def translate(self, src, method: Literal["greedy", "beam"], k: int = 5) -> str:
        '''
            Args:
                src: (1, S, in_features) keypoint features of a single sample
                method: decoding strategy, "greedy" or "beam"
                k: beam width (only used by "beam")
            Returns:
                The decoded sentence, without special tokens
            Raises:
                ValueError: if method is neither "greedy" nor "beam"
        '''
        with torch.no_grad():
            if method == "greedy":
                out = self.greedy_decode(src)
            elif method == "beam":
                out = self.beam_decode(src, k)
            else:
                raise ValueError("Invalid method. Choose between 'greedy' and 'beam'.")
        # Decode with the tokenizer this translator was built with, not a notebook-level global.
        return self.tokenizer.decode([int(x) for x in out.tolist()], skip_special_tokens=True)

    def greedy_decode(self, src: Tensor) -> Tensor:
        '''
            Generate token ids by always picking the highest-scoring next token.
            Returns a 1-D tensor of token ids starting with BOS_IDX and holding at most MAX_TOKENS ids.
        '''
        # Work on the device of the input instead of relying on a global device.
        device = src.device
        memory = self.model.encode(src)
        ys = torch.full((1, 1), BOS_IDX, dtype=torch.long, device=device)
        for _ in range(MAX_TOKENS - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1), device)
            out = self.model.decode(ys, memory, tgt_mask)
            logits = self.model.generator(out[:, -1])
            next_word = logits.argmax(dim=1)
            ys = torch.cat([ys, next_word.view(1, 1)], dim=1)
            if next_word.item() == EOS_IDX:
                break
        return ys.squeeze(0)

    def beam_decode(self, src: Tensor, k: int) -> Tensor:
        '''
            Generate token ids with beam search of width k.
            Hypotheses are ranked by the sum of their token log-probabilities (no length normalization).
            Returns a 1-D tensor of token ids starting with BOS_IDX and cut right after the first EOS_IDX, if any.
            Raises:
                ValueError: if k is smaller than 1
        '''
        if k < 1:
            raise ValueError("Beam width k must be at least 1.")
        device = src.device
        # The batch dimension is used to score all live hypotheses at once.
        memory = self.model.encode(src).repeat(k, 1, 1)
        # Start from a single hypothesis: k identical BOS beams would yield k identical candidates.
        beams = torch.full((1, 1), BOS_IDX, dtype=torch.long, device=device)
        scores = torch.zeros(1, device=device)  # cumulative log-probability of each hypothesis
        finished = torch.zeros(1, dtype=torch.bool, device=device)
        for _ in range(MAX_TOKENS - 1):
            n_beams = beams.size(0)
            # Same causal mask as in training and greedy decoding.
            tgt_mask = generate_square_subsequent_mask(beams.size(1), device)
            out = self.model.decode(beams, memory[:n_beams], tgt_mask)
            log_probs = torch.log_softmax(self.model.generator(out[:, -1]), dim=-1)
            vocab_size = log_probs.size(-1)

            # A finished hypothesis may only be "extended" with EOS at no cost, so it
            # keeps competing with its final score instead of growing further.
            eos_only = torch.full_like(log_probs, float('-inf'))
            eos_only[:, EOS_IDX] = 0.0
            log_probs = torch.where(finished.unsqueeze(1), eos_only, log_probs)

            # Joint score of every (hypothesis, next token) pair; keep the k best.
            candidates = (scores.unsqueeze(1) + log_probs).reshape(-1)
            scores, top_idx = torch.topk(candidates, k=min(k, candidates.numel()))
            parent = torch.div(top_idx, vocab_size, rounding_mode='floor')
            token = top_idx % vocab_size

            # Each surviving candidate extends the hypothesis it actually came from.
            beams = torch.cat([beams[parent], token.unsqueeze(1)], dim=1)
            finished = finished[parent] | (token == EOS_IDX)
            if finished.all():
                break

        best = beams[scores.argmax()]
        # Drop the EOS repetitions appended to an already finished hypothesis.
        eos_positions = (best == EOS_IDX).nonzero()
        if eos_positions.numel() > 0:
            best = best[: eos_positions[0].item() + 1]
        return best

In [32]:
import torch
from torch import Tensor
from torch.optim import Adam
from torch.nn.functional import cross_entropy
from torchmetrics import Accuracy
import lightning as L
from typing import Literal
import pandas as pd
from torchmetrics.functional.text import bleu_score
import wandb


class LKeypointsTransformer(L.LightningModule):
    '''
        Lightning wrapper around KeypointsTransformer: teacher-forced training/validation with
        cross-entropy and token accuracy, plus greedy and beam translations scored with BLEU at test time.
    '''

    def __init__(self, model: KeypointsTransformer, num_classes: int):
        '''
            Args:
                model: the KeypointsTransformer to train/evaluate
                num_classes: size of the target vocabulary (number of classes for the accuracy metric)
        '''
        super().__init__()
        self.model = model
        self.loss_fn = cross_entropy
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes, ignore_index=PAD_IDX)
        self.translator = Translator(model, tokenizer)
        self.save_hyperparameters(ignore=['model'])

        # Accumulators filled by test_step and consumed in on_test_epoch_end.
        self.ys_step = []
        self.beam_translations_step = []
        self.greedy_translations_step = []

    def forward(self, src: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor):
        return self.model(src, tgt, tgt_mask, tgt_padding_mask)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer

    def run_on_batch(self, batch):
        '''
            Run a teacher-forced forward pass on a (src, tgt) batch.
            Returns:
                (loss, accuracy), both computed ignoring padding positions
        '''
        src, tgt = batch
        # tgt_input and tgt_output are displaced by one position, so tgt_input[i] is the input to the model and tgt_output[i] is the expected output
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        # Build the masks on the device of the batch rather than on a global device.
        tgt_mask, tgt_padding_mask = create_target_mask(tgt_input, PAD_IDX, tgt_input.device)
        logits = self.model(src, tgt_input, tgt_mask, tgt_padding_mask)
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_targets = tgt_output.reshape(-1)
        loss = self.loss_fn(
            flat_logits,
            flat_targets,
            ignore_index=PAD_IDX,
            weight=class_weights_complete.to(logits.device) if USE_CLASS_WEIGHTS else None,
        )
        accuracy = self.accuracy(flat_logits, flat_targets)
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self.run_on_batch(batch)
        batch_size = batch[0].size(0)
        self.log("train_loss", loss, on_epoch=True, batch_size=batch_size)
        self.log("train_accuracy", accuracy, on_epoch=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.run_on_batch(batch)
        # `batch` is the (src, tgt) tuple, so len(batch) is always 2; the number of
        # samples is the first dimension of src.
        batch_size = batch[0].size(0)
        self.log("val_loss", loss, on_epoch=True, batch_size=batch_size)
        self.log("val_accuracy", accuracy, on_epoch=True, batch_size=batch_size)
        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.run_on_batch(batch)
        batch_size = batch[0].size(0)
        self.log("test_loss", loss, batch_size=batch_size)
        self.log("test_accuracy", accuracy, batch_size=batch_size)
        ys, preds_greedy, preds_beam = self.get_translations(batch, batch_idx)
        self.ys_step.extend(ys)
        self.greedy_translations_step.extend(preds_greedy)
        self.beam_translations_step.extend(preds_beam)
        return loss, accuracy

    def on_test_epoch_end(self):
        '''
            Score every collected translation with BLEU-1..4, log the table and the mean scores,
            and write the table to a CSV named after the logger's experiment.
            Note: the logged BLEU values are means of sentence-level scores, not corpus-level BLEU.
        '''
        bleu_columns = [f"bleu_{n}_{method}" for method in ("greedy", "beam") for n in range(1, 5)]
        columns = ["y", "trans_greedy", "trans_beam"] + bleu_columns
        translation_results = [
            (y, trans_greedy, trans_beam)
            + tuple(bleu_score(trans_greedy, [y], n_gram=n).item() for n in range(1, 5))
            + tuple(bleu_score(trans_beam, [y], n_gram=n).item() for n in range(1, 5))
            for y, trans_greedy, trans_beam in zip(self.ys_step, self.greedy_translations_step, self.beam_translations_step)
        ]
        # Reset the accumulators so a second test run does not mix results.
        self.ys_step = []
        self.greedy_translations_step = []
        self.beam_translations_step = []
        translation_results_df = pd.DataFrame(translation_results, columns=columns)
        self.logger.log_table(key="translation-results", columns=list(translation_results_df.columns), data=translation_results)
        for column in bleu_columns:
            self.log(column, translation_results_df[column].mean())
        # Use the logger attached to this module instead of a notebook-level global.
        translation_results_df.to_csv(f"translation-results-{self.logger.experiment.name}.csv", index=False)

    def get_translations(self, batch, batch_idx):
        '''
            Translate every sample of the batch with greedy and beam search.
            Returns:
                (ys, preds_greedy, preds_beam): reference sentences and the two lists of translations
        '''
        src, tgt = batch
        preds_greedy = []
        preds_beam = []
        ys = []
        for i in range(len(src)):
            # adds extra dimension representing the batch
            src_0 = src[i].unsqueeze(0)
            preds_greedy.append(self.translator.translate(src_0, method="greedy"))
            preds_beam.append(self.translator.translate(src_0, method="beam"))
            ys.append(self.translator.tokenizer.decode([int(x) for x in tgt[i].tolist()], skip_special_tokens=True, clean_up_tokenization_spaces=True))
        return ys, preds_greedy, preds_beam


l_model = LKeypointsTransformer(model, tokenizer.vocab_size)

In [33]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger


# Numeric precision passed to the Lightning Trainer.
PRECISION = 32

# Weights & Biases logger; the run config records the hyperparameters defined across the notebook.
wandb_logger = WandbLogger(project="rwth", log_model="all")
wandb_logger.experiment.config.update({
	# System hyperparameters
	"DEVICE": DEVICE,
	"PRECISION": PRECISION,
	# Data hyperparameters
	"BATCH_SIZE": BATCH_SIZE,
	"MAX_FRAMES": MAX_FRAMES,
	"MAX_TOKENS": MAX_TOKENS,
	"TEXT_MODEL": TEXT_MODEL,
	"KEYPOINTS_USED": str(KEYPOINTS_USED),
	# Model hyperparameters
	"D_MODEL": D_MODEL,
	"DROPOUT": DROPOUT,
	"USE_BERT_EMBEDDINGS": USE_BERT_EMBEDDINGS,
	"NUM_ENCODER_LAYERS": NUM_ENCODER_LAYERS,
	# Training hyperparameters
	"USE_CLASS_WEIGHTS": USE_CLASS_WEIGHTS,
})

# Checkpoint on the lowest val_loss and also keep the last checkpoint, both under checkpoints/
# and named after the W&B run.
checkpoint_callback = ModelCheckpoint(
	monitor='val_loss',
	dirpath='checkpoints/',
	filename=f'rwth-{wandb_logger.experiment.name}-best-{{epoch:02d}}-{{step:02d}}-{{val_loss:.2f}}',
	mode='min',
	save_last=True
)
checkpoint_callback.CHECKPOINT_NAME_LAST = f"rwth-{wandb_logger.experiment.name}-last"

# max_epochs is not set, so training is bounded by Lightning's default epoch limit and by
# early stopping on val_loss (patience=5).
# NOTE(review): default_root_dir is "./checkpoint" while ModelCheckpoint writes to "checkpoints/" — confirm this is intended.
trainer = L.Trainer(
    logger=wandb_logger,
    default_root_dir="./checkpoint",
		precision=PRECISION,
    callbacks=[
		EarlyStopping(monitor="val_loss", mode="min", patience=5),
		checkpoint_callback,],
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(
    model=l_model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name     | Type                 | Params
--------------------------------------------------
0 | model    | KeypointsTransformer | 13.1 M
1 | accuracy | MulticlassAccuracy   | 0     
--------------------------------------------------
13.1 M    Trainable params
0         Non-trainable params
13.1 M    Total params
52.596    Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name     | Type                 | Params
--------------------------------------------------
0 | model    | KeypointsTransformer | 13.1 M
1 | accuracy | MulticlassAccuracy   | 0     
--------------------------------------------------
13.1 M    Trainable pa

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:121: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
import os
import glob


# Locate the best checkpoint written by ModelCheckpoint for the current W&B run.
best_checkpoints = sorted(
    glob.glob(f"checkpoints/rwth-{wandb_logger.experiment.name}-best*"),
    key=os.path.getmtime,
)
if not best_checkpoints:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f"No checkpoint matching 'checkpoints/rwth-{wandb_logger.experiment.name}-best*' was found; run trainer.fit first."
    )
# If several files match, use the most recently written one.
CHKP = best_checkpoints[-1]
l_model = LKeypointsTransformer.load_from_checkpoint(CHKP, model=model, num_classes=tokenizer.vocab_size)

# Evaluate on the test split; test-time translations and BLEU scores are produced by the module's test hooks.
trainer.test(
    model=l_model,
    dataloaders=test_loader,
    ckpt_path=CHKP,
)

In [None]:
translation_results_df = pd.read_csv(f"translation-results-{wandb_logger.experiment.name}.csv")
translation_results_df.head()