# Train a Transformer Classifier on Clip Embeddings

> Module for training a transformer-encoder classifier on a dataset of precomputed clip embeddings

In [None]:
# | default_exp models.transformer_encoder

In [None]:
# Notebook-only conveniences: load the nb_black extension and make edited
# modules reload automatically before each cell runs.
%reload_ext nb_black
%reload_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import sys

# Put the directory two levels up on the import path; presumably the repo
# root, so that `clip_video_classifier` can be imported from this notebook.
__root = "../../"
sys.path.append(__root)

In [None]:
# | export
from clip_video_classifier.cli import cli
from clip_video_classifier.data.dataset import ClipEmbeddingsDataset
from torch_snippets import *
from torch.nn.utils.rnn import pad_sequence

In [None]:
# | export


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    Even feature columns hold ``sin(pos * 10000^(-2i / hidden_size))`` and odd
    columns the matching ``cos``. The table is precomputed for ``max_len``
    positions and registered as a buffer (not a trainable parameter).

    Args:
        hidden_size: Feature size of the inputs. Odd sizes are supported.
        dropout: Dropout probability applied after adding the encodings.
        max_len: Longest sequence length ``forward`` can handle.
    """

    def __init__(self, hidden_size: int, dropout=0.1, max_len=512):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(
            1e4, -torch.arange(0, hidden_size, 2).float() / hidden_size
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd hidden_size there is one fewer odd column than even
        # column, so only use as many frequencies as there are cos slots.
        pe[:, 1::2] = torch.cos(position * div_term[: hidden_size // 2])
        # Leading singleton axis so the table broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, seq_len, hidden_size)."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    """Sequence classifier over padded embedding sequences.

    Pipeline: sinusoidal positions -> ``nn.TransformerEncoder`` -> mean over
    the non-padded positions -> linear head producing ``num_classes`` logits.

    Args:
        transformer_layers: Number of stacked ``nn.TransformerEncoderLayer``.
        emb_size: Feature size of the incoming embeddings. The positional
            encoding is built with this size while the encoder uses
            ``d_model``, so the two must be equal.
        max_len: Longest sequence the positional encoding supports.
        num_classes: Number of output logits.
        d_model: Width of the encoder layers and input size of the head.
        n_head: Number of attention heads (must divide ``d_model``).
    """

    def __init__(
        self,
        transformer_layers: int,
        emb_size: int,
        max_len: int,
        num_classes: int = 4,
        d_model: int = 512,
        n_head: int = 8,
    ):
        super().__init__()
        self.positional_encoding = PositionalEncoding(
            hidden_size=emb_size, max_len=max_len
        )
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=transformer_layers
        )
        self.linear = nn.Linear(d_model, num_classes)
        # Name of the primary input; Hugging Face's Trainer looks this up.
        self.main_input_name = "embeddings"
        self.loss_fn = nn.CrossEntropyLoss()
        # Only referenced by the commented-out F.dropout calls in forward.
        self.dropout = 0.25

    def forward(self, embeddings, attention_mask, targets=None):
        """Classify a padded batch of embedding sequences.

        Args:
            embeddings: Float tensor (batch, seq_len, emb_size).
            attention_mask: Tensor (batch, seq_len) holding 1 for real
                positions and 0 for padding (the convention of ``collate_fn``).
            targets: Optional class indices of shape (batch,).

        Returns:
            dict with ``"logits"`` (batch, num_classes) and ``"loss"``: the
            cross-entropy against ``targets``, or ``-1`` when no targets are given.
        """
        embeddings = self.positional_encoding(embeddings)
        # embeddings = F.dropout(embeddings, self.dropout)
        keep = attention_mask.bool()
        # nn.TransformerEncoder's src_key_padding_mask is True where a position
        # must be IGNORED, the opposite of the 1-for-valid attention_mask.
        # The encoder is seq-first (batch_first=False), hence the swapaxes.
        transformer_output = self.transformer_encoder(
            embeddings.swapaxes(1, 0), src_key_padding_mask=~keep
        ).swapaxes(1, 0)
        # transformer_output = F.dropout(transformer_output, self.dropout)
        # Average only over real positions so padding doesn't dilute the
        # pooled vector; clamp guards the division for an all-padding row.
        weights = keep.unsqueeze(-1).to(transformer_output.dtype)
        pooled_output = (transformer_output * weights).sum(dim=1) / weights.sum(
            dim=1
        ).clamp(min=1)
        logits = self.linear(pooled_output)
        if targets is not None:
            loss = self.loss_fn(logits, targets)
        else:
            loss = -1
        return {"loss": loss, "logits": logits}


def collate_fn(batch, seq_len: int = 128):
    """Collate dataset items into a padded, masked batch for ``TransformerEncoder``.

    Sequences longer than ``seq_len`` are cropped to a random contiguous window
    of ``seq_len`` rows. Every window start is possible, including the last
    one. Shorter sequences are kept whole. All sequences are then zero-padded
    at the end to the longest one in the batch.

    Args:
        batch: List of dicts, each with an ``"embeddings"`` tensor of shape
            (length, emb_size) and optionally an integer ``"targets"`` label.
        seq_len: Maximum number of rows kept per sequence.

    Returns:
        dict with ``"embeddings"`` (float, (batch, max_len, emb_size)),
        ``"attention_mask"`` (float, (batch, max_len), 1 for real rows and 0
        for padding) and, when the first item has ``"targets"``, a long
        ``"targets"`` tensor of shape (batch,).
    """
    embeddings = []
    for item in batch:
        emb = item["embeddings"]
        if len(emb) > seq_len:
            # randint's upper bound is exclusive; +1 makes the final window
            # (start == len(emb) - seq_len) reachable.
            start = np.random.randint(len(emb) - seq_len + 1)
            emb = emb[start : start + seq_len]
        embeddings.append(emb)

    batched_embeddings = pad_sequence(embeddings, batch_first=True).float()
    # Mark the first len(e) positions of each row as real, the rest as padding.
    lengths = torch.tensor([len(e) for e in embeddings])
    positions = torch.arange(batched_embeddings.shape[1])
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    collated = {
        "embeddings": batched_embeddings,
        "attention_mask": attention_mask,
    }
    if "targets" in batch[0]:
        collated["targets"] = torch.tensor(
            [item["targets"] for item in batch], dtype=torch.long
        )
    return collated

In [None]:
# |eval: false

from sklearn.model_selection import train_test_split
import collections

# Dataset locations (machine-specific absolute path).
root = P("/mnt/347832F37832B388/ml-datasets/ssbd/")
annotations_path = root / "annotations.csv"
embeddings_folder = root / "ssbd-embeddings/10fps"
# Defined but not used in the code below.
frames_folder = root / "ssbd-frames/10fps"
annotations = pd.read_csv(annotations_path)
# Passed to every ClipEmbeddingsDataset below; later cells also read it.
BINARY_MODE = False
# Drop all segments labelled "others".
annotations = annotations.query('label != "others"')
# Split by video id (not by segment) so no video lands in both splits.
trn_items, val_items = train_test_split(
    annotations.video.unique(), test_size=0.35, random_state=11
)

# Training keeps ONE row per training video: the annotation with the
# smallest `start`.
trn_df = annotations.loc[
    annotations.query("video in @trn_items").groupby("video")["start"].idxmin()
]
# trn_df = annotations.query("video in @trn_items")
# Validation keeps every annotation of the validation videos.
val_df = annotations.query("video in @val_items")

# NOTE: `ds` covers all remaining annotations, i.e. training AND validation
# videos, so metrics computed on it are not a held-out estimate.
ds = ClipEmbeddingsDataset(embeddings_folder, annotations, binary_mode=BINARY_MODE)
trn_ds = ClipEmbeddingsDataset(embeddings_folder, trn_df, binary_mode=BINARY_MODE)
val_ds = ClipEmbeddingsDataset(embeddings_folder, val_df, binary_mode=BINARY_MODE)
# Per-split label counts. This iterates every item of both datasets
# (presumably loading each embedding file), so it can be slow.
print(
    "train",
    collections.Counter([i["label"] for i in track2(trn_ds)]),
    "validation",
    collections.Counter([i["label"] for i in track2(val_ds)]),
)

```python
# Quick CPU smoke test: collate one small training batch and run a forward
# pass (transformer_layers=4, emb_size=512, max_len=128).
dl = DataLoader(trn_ds, batch_size=3, shuffle=True, collate_fn=collate_fn)
b = next(iter(dl))
model = TransformerEncoder(4, 512, 128)
model(**b)
```

In [None]:
# |eval: false

# import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments
from torch_snippets.charts import CM

# 4 encoder layers over 512-d embeddings; positional table sized for the
# 128-row crops produced by collate_fn. Moved to the GPU for training.
model = TransformerEncoder(transformer_layers=4, emb_size=512, max_len=128).cuda()


def compute_metrics(input):
    """Score an evaluation result and display its confusion matrix.

    ``input.predictions`` holds per-sample class scores (argmax over axis 1
    picks the prediction) and ``input.label_ids`` the integer targets. Both
    are mapped to label names before plotting: ``trn_ds.id2label`` normally,
    ``["others", "action"]`` when ``BINARY_MODE`` is set.

    Returns:
        ``{"accuracy": share of samples whose predicted name equals the target name}``.
    """
    scores, label_ids = input.predictions, input.label_ids
    predicted_ids = scores.argmax(1)
    if BINARY_MODE:
        names = ["others", "action"]
    else:
        names = trn_ds.id2label
    predicted = np.array([names[idx] for idx in predicted_ids])
    actual = np.array([names[idx] for idx in label_ids])
    show(CM(pred=predicted, truth=actual))
    return {"accuracy": (actual == predicted).mean()}


# Schedule: cosine with 30 hard restarts, no warmup.
# The scheduler must be configured through the constructor. TrainingArguments
# has no `learning_rate_scheduler*` fields, so assigning such attributes after
# construction was silently ignored and the default linear schedule ran.
# Warmup goes in `warmup_steps`, not in `lr_scheduler_kwargs`: the Trainer
# already passes `num_warmup_steps` to the scheduler itself.
# NOTE: `lr_scheduler_kwargs` requires a transformers release that supports it.
training_args = TrainingArguments(
    output_dir="./transformers_model_trained",
    evaluation_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    max_steps=2000,
    # eval_steps defaults to logging_steps, so evaluation and checkpointing
    # stay aligned, as load_best_model_at_end requires.
    logging_steps=200,
    save_steps=200,
    save_total_limit=2,
    label_names=["targets"],
    # include_inputs_for_metrics is omitted: compute_metrics never reads the
    # inputs, and gathering them only costs memory during evaluation.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    lr_scheduler_type="cosine_with_restarts",
    warmup_steps=0,  # Adjust this as needed
    lr_scheduler_kwargs={"num_cycles": 30},  # Adjust this as needed
    # learning_rate=1e-4,  # peak LR is a TrainingArguments field, not a scheduler kwarg
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=trn_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

# trainer.train()
# trainer.predict(val_ds)

In [None]:
# |eval: false
BINARY_MODE = False

# Rebuild the trained architecture and load a saved checkpoint into it.
model = TransformerEncoder(4, 512, 128).cuda()
load_torch_model_weights_to(model, "saved-models/a/pytorch_model.bin")
model.eval()

# Evaluation-only Trainer: no datasets are bound here; each one is passed
# explicitly to `evaluate` below.
training_args = TrainingArguments(
    output_dir="./linear_model_trained",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    label_names=["targets"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# `Trainer.evaluate` returns a metrics dict (not predictions). Keep both
# results instead of overwriting the validation numbers.
# collate_fn crops sequences longer than 128 rows at random, so repeated
# evaluations may differ slightly.
val_metrics = trainer.evaluate(val_ds)
# NOTE: `ds` is built from every remaining annotation, including the training
# videos, so these numbers are optimistic rather than held-out.
all_metrics = trainer.evaluate(ds)
predictions = all_metrics  # backward-compatible alias for the old name

In [None]:
import nbdev

# Write this notebook's `# | export` cells to the module named by its
# `default_exp` directive (models.transformer_encoder).
nbdev.nbdev_export()