# Train Embeddings Using a Transformer

> Module for training a transformer classifier on a dataset of precomputed frame embeddings

In [None]:
from clip_video_classifier.cli import cli
from functools import lru_cache
from torch_snippets import *
from torch.nn.utils.rnn import pad_sequence

In [None]:
class FeaturesDataset(Dataset):
    """Dataset pairing stored per-clip feature tensors with their labels.

    Each item is looked up by the annotation row's index label: the features
    are read from ``<features_dir>/<index>.frames.features.tensor`` and the
    label comes from the row's ``label`` column.

    Args:
        features_dir: Folder holding the serialized feature tensors.
        annotations: Either a path to a CSV file (read with ``pd.read_csv``)
            or an already-loaded DataFrame.
        average_features: Stored on the instance; not used by this class.
        frames_dir: Stored on the instance; not used by this class.
        binary_mode: When true, targets are ``0`` for ``"others"`` and ``1``
            for every other label; otherwise targets come from ``label2id``.
    """

    labels = ["ArmFlapping", "HeadBanging", "Spinning", "others"]
    label2id = {name: ix for ix, name in enumerate(labels)}
    id2label = {ix: name for name, ix in label2id.items()}

    def __init__(
        self,
        features_dir: str,
        annotations: "str | P | pd.DataFrame",
        average_features: bool = False,
        frames_dir: "str | None" = None,
        binary_mode: bool = False,
    ):
        self.average_features = average_features
        self.features_dir = P(features_dir)
        if isinstance(annotations, (str, P)):
            self.annotations = pd.read_csv(annotations)
        else:
            self.annotations = annotations
        # Feature files are named "<annotation index>.<suffixes>"; recover the
        # integer index from each file name.
        available_features = [
            int(stem(f).split(".")[0]) for f in self.features_dir.ls()
        ]
        available_annotations = self.annotations.index.tolist()
        # Keep only annotation rows that have a matching feature file.
        self.annotations = self.annotations.loc[
            list(common(available_annotations, available_features))
        ]
        self.frames_dir = frames_dir
        self.binary_mode = binary_mode
        Info(f"created dataset of {len(self)} items")

    def __len__(self):
        """Return the number of annotation rows that have features."""
        return len(self.annotations)

    def preprocess(self, features):
        """Hook applied to the loaded features; the base class is a no-op."""
        return features

    def __getitem__(self, index):
        """Load one item by position.

        Returns:
            A dict with ``"features"`` (detached CPU tensor) and ``"targets"``
            (an ``int`` class id).
        """
        row = self.annotations.iloc[index]
        # ``row.name`` is the row's index label, which doubles as the file id.
        features = loaddill(
            self.features_dir / f"{row.name}.frames.features.tensor"
        ).cpu()
        features = self.preprocess(features)
        label = row["label"]
        if self.binary_mode:
            label = int(label != "others")
        else:
            label = self.label2id[label]
        return {
            "features": features.cpu().detach(),
            "targets": label,
        }

In [None]:
from sklearn.model_selection import train_test_split
import collections

# Dataset root on the local machine (hard-coded absolute path).
root = P("/mnt/347832F37832B388/ml-datasets/ssbd/")

# Value later passed to the model constructor as embedding/feature width.
features_dim = 512
# Folder containing the serialized per-clip feature tensors.
features_dir = root / "ssbd-frames-features/10fps/vgg19"
# Output directory later handed to TrainingArguments.
model_output_dir = "./transformers_model_trained_vgg19"

annotations_path = root / "annotations.csv"
annotations = pd.read_csv(annotations_path)
# False -> multi-class targets via label2id; True -> "others" vs. the rest.
BINARY_MODE = False

# Build a dataset over all annotations and load the first item as a smoke
# test (as the last expression of a notebook cell its value is displayed).
ds = FeaturesDataset(features_dir, annotations)
ds[0]

In [None]:
# Keep only rows labelled with something other than "others".
annotations = annotations.query('label != "others"')

# Split by unique video id so that all rows of one video land on the same
# side of the train/validation split.
video_ids = annotations.video.unique()
trn_items, val_items = train_test_split(video_ids, test_size=0.15, random_state=11)

# Alternative kept for reference: one row per training video (earliest start).
# trn_df = annotations.loc[
#     annotations.query("video in @trn_items").groupby("video")["start"].idxmin()
# ]
trn_df = annotations.query("video in @trn_items")
val_df = annotations.query("video in @val_items")

# One dataset over everything, one per split; all share the same mode.
ds, trn_ds, val_ds = (
    FeaturesDataset(features_dir, frame, binary_mode=BINARY_MODE)
    for frame in (annotations, trn_df, val_df)
)

# Class balance per split, obtained by iterating every item of each dataset.
trn_counts = collections.Counter(item["targets"] for item in track2(trn_ds))
val_counts = collections.Counter(item["targets"] for item in track2(val_ds))
print("train", trn_counts, "validation", val_counts)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of sequences.

    A table of shape ``(1, max_len, hidden_size)`` is precomputed and stored
    as a buffer (so it follows the module across devices but is not trained).
    Even columns hold sines and odd columns hold cosines of
    ``position * 10000 ** (-2i / hidden_size)``.

    Args:
        hidden_size: Size of the last dimension of the inputs. Both even and
            odd sizes are supported.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions in the table; inputs must not have more
            than this many steps along dimension 1.
    """

    def __init__(self, hidden_size: int, dropout=0.1, max_len=512):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(
            1e4, -torch.arange(0, hidden_size, 2).float() / hidden_size
        )
        # (max_len, ceil(hidden_size / 2)) table of angles.
        angles = position * div_term

        pe = torch.zeros(max_len, hidden_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd hidden_size there is one fewer odd column than even
        # columns, so drop the surplus frequency instead of failing.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, steps, hidden_size)."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerEncoderForTimmFeatures(nn.Module):
    """Transformer-encoder classifier over sequences of feature vectors.

    Pipeline: ``Linear + ReLU`` projection -> positional encoding ->
    ``nn.TransformerEncoder`` -> mean over the valid time steps -> linear
    classification head.

    Args:
        transformer_layers: Number of encoder layers.
        emb_size: Output width of the input projection. The projected
            sequence is fed straight into the encoder, so this must equal
            ``d_model``.
        max_len: Maximum number of time steps (size of the positional table).
        features_dim: Size of the last dimension of the incoming features.
        num_classes: Number of output logits.
        d_model: Width of the transformer encoder.
        n_head: Number of attention heads.

    Raises:
        ValueError: If ``emb_size`` and ``d_model`` differ.
    """

    def __init__(
        self,
        transformer_layers: int,
        emb_size: int,
        max_len: int,
        features_dim: int,
        num_classes: int = 4,
        d_model: int = 512,
        n_head: int = 8,
    ):
        super().__init__()
        if emb_size != d_model:
            # Fail at construction rather than with a shape error in forward.
            raise ValueError(
                f"emb_size ({emb_size}) must equal d_model ({d_model})"
            )
        self.positional_encoding = PositionalEncoding(
            hidden_size=emb_size, max_len=max_len
        )
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=transformer_layers
        )
        self.lin = nn.Sequential(
            nn.Linear(features_dim, emb_size),
            nn.ReLU(inplace=True),
        )
        self.linear = nn.Linear(d_model, num_classes)
        # Name of the primary input; presumably consumed by the Hugging Face
        # Trainer -- confirm.
        self.main_input_name = "features"
        self.loss_fn = nn.CrossEntropyLoss()
        # Not read anywhere in this class at the moment.
        self.dropout = 0.25

    def forward(self, features, attention_mask, targets=None):
        """Classify a batch of (padded) feature sequences.

        Args:
            features: Tensor of shape (batch, steps, features_dim).
            attention_mask: Tensor of shape (batch, steps); non-zero marks a
                real time step, zero marks padding.
            targets: Optional class ids of shape (batch,).

        Returns:
            Dict with ``"logits"`` of shape (batch, num_classes) and
            ``"loss"`` (cross-entropy, or the sentinel ``-1`` when no targets
            are given).
        """
        Debug(f"{features=}")
        features = self.lin(features)
        embeddings = self.positional_encoding(features)

        valid = attention_mask.bool()
        # The encoder layers are sequence-first, hence the axis swaps.
        # ``src_key_padding_mask`` expects True at positions to IGNORE, which
        # is the inverse of the 1-means-valid attention mask.
        transformer_output = self.transformer_encoder(
            embeddings.swapaxes(1, 0), src_key_padding_mask=~valid
        ).swapaxes(1, 0)

        # Average over real time steps only, so a sample's logits do not
        # depend on how much padding its batch happens to need.
        keep = valid.unsqueeze(-1)
        summed = transformer_output.masked_fill(~keep, 0.0).sum(dim=1)
        counts = keep.sum(dim=1).clamp(min=1)
        pooled_output = summed / counts

        logits = self.linear(pooled_output)
        if targets is not None:
            loss = self.loss_fn(logits, targets)
        else:
            loss = -1
        return {"loss": loss, "logits": logits}


def collate_fn(batch, seq_len: int = 128, feature_stride: int = 8):
    """Collate dataset items into a padded batch for the model.

    Sequences longer than ``seq_len`` are cut to a randomly positioned window
    of ``seq_len`` steps; shorter ones are zero-padded up to the longest
    sequence in the batch. The feature dimension is subsampled by keeping
    every ``feature_stride``-th value.

    Args:
        batch: List of dicts with a ``"features"`` tensor of shape
            (steps, dim) and, optionally, an integer ``"targets"``.
        seq_len: Maximum number of time steps kept per item.
        feature_stride: Step used when subsampling the last dimension.

    Returns:
        Dict with ``"features"`` (batch, steps, dim / stride, float32),
        ``"attention_mask"`` (batch, steps, float; 1 for real steps and 0 for
        padding) and, when the items carry them, ``"targets"`` (batch, long).
    """
    embeddings = [item["features"] for item in batch]
    Debug(f"{embeddings=}")
    # Random window start for over-long sequences, 0 otherwise.
    starts = [
        randint(len(e) - seq_len) if len(e) > seq_len else 0 for e in embeddings
    ]
    Debug(f"{starts=}")
    embeddings = [
        e if len(e) < seq_len else e[start : start + seq_len]
        for e, start in zip(embeddings, starts)
    ]
    Debug(f"{embeddings=}")
    batched_embeddings = pad_sequence(embeddings, batch_first=True).float()
    Debug(f"{batched_embeddings=}")

    # Mark the first len(sequence) steps of every row as valid; using the
    # per-item length (not the batch size) is what makes the mask meaningful.
    padded_len = batched_embeddings.shape[1]
    lengths = torch.tensor([len(e) for e in embeddings])
    attention_mask = (
        torch.arange(padded_len).unsqueeze(0) < lengths.unsqueeze(1)
    ).float()
    Debug(f"{attention_mask=}")

    # The same key is used with and without targets so the result can always
    # be splatted into ``model(**batch)``.
    collated = {
        "features": batched_embeddings[..., ::feature_stride],
        "attention_mask": attention_mask,
    }
    if "targets" in batch[0]:
        Debug(f"has targets")
        labels = [item["targets"] for item in batch]
        Debug(f"{labels=}")
        collated["targets"] = torch.tensor(labels, dtype=torch.long)
    else:
        Debug(f"not has targets")
    return collated

```python
# Smoke test: push one small batch through a freshly built model with debug
# logging switched on.
dl = DataLoader(trn_ds, batch_size=3, shuffle=True, collate_fn=collate_fn)
# `features_dim` is a required constructor argument; pass it explicitly.
model = TransformerEncoderForTimmFeatures(
    transformer_layers=4,
    emb_size=features_dim,
    max_len=128,
    features_dim=features_dim,
    d_model=features_dim,
)
with debug_mode():
    b = next(iter(dl))
    o = model(**b)
print(o)
```

In [None]:
# |eval: false

# import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments
from torch_snippets.charts import CM

# Four encoder layers; projection width, input width and encoder width all
# equal `features_dim`, and sequences are capped at 128 steps.
model = TransformerEncoderForTimmFeatures(
    transformer_layers=4,
    emb_size=features_dim,
    max_len=128,
    features_dim=features_dim,
    d_model=features_dim,
)
reset_logger()


def compute_metrics(input):
    """Show a confusion matrix and report accuracy for one evaluation pass.

    Args:
        input: Object exposing ``predictions`` (per-class scores, one row per
            sample) and ``label_ids`` (true class ids).

    Returns:
        ``{"accuracy": <fraction of samples predicted correctly>}``.
    """
    # Pick the id -> name mapping that matches how the targets were encoded.
    if BINARY_MODE:
        id2label = ["others", "action"]
    else:
        id2label = trn_ds.id2label

    predicted_ids = input.predictions.argmax(1)
    pred = np.array([id2label[class_id] for class_id in predicted_ids])
    targets = np.array([id2label[class_id] for class_id in input.label_ids])

    show(CM(pred=pred, truth=targets))
    accuracy = (targets == pred).mean()
    return {"accuracy": accuracy}


# Evaluate, log and checkpoint every 64 steps for 512 steps in total, keep
# the two most recent checkpoints and reload the most accurate one at the end.
training_args = TrainingArguments(
    output_dir=model_output_dir,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    max_steps=512,
    logging_steps=64,
    save_steps=64,
    save_total_limit=2,
    label_names=["targets"],
    include_inputs_for_metrics=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    eval_accumulation_steps=1,
    # The scheduler has to be configured through real TrainingArguments
    # fields; attributes assigned after construction under other names are
    # never read by the Trainer, so the schedule would silently stay at the
    # default. NOTE(review): `lr_scheduler_kwargs` needs a transformers
    # release that supports it -- confirm against the pinned version.
    lr_scheduler_type="cosine_with_restarts",
    lr_scheduler_kwargs={"num_cycles": 16},  # Adjust this as needed
    warmup_steps=0,  # Adjust this as needed
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=trn_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Run the trainer's prediction pass over the validation dataset; the returned
# value is deliberately discarded.
_ = trainer.predict(val_ds)

In [None]:
# Same prediction pass over the training dataset; the returned value is
# deliberately discarded.
_ = trainer.predict(trn_ds)

In [None]:
# Back up this notebook under the given file name (helper is not defined in
# this file; presumably provided by the torch_snippets star import -- confirm).
backup_this_notebook("07.01_transformer_timm_vgg19.ipynb")