# Train a Linear Probe on Embeddings

> Module for training a classifier (probe) on a dataset of precomputed embeddings

In [None]:
# Notebook setup: IPython extensions and import path.
# nb_black formats cells with black; autoreload mode 2 reloads imported
# modules before every cell runs, so edits to the package are picked up.
%reload_ext nb_black
%reload_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import sys

# Put the repository root on the import path, presumably so the local
# `clip_video_classifier` package can be imported (assumes this notebook sits
# two directories below the repo root — confirm).
__root = "../../"
sys.path.append(__root)

In [None]:
import torch
import torch.multiprocessing as mp

# Use the "spawn" start method for any worker subprocesses (PyTorch requires
# spawn/forkserver when subprocesses use CUDA; this notebook trains on GPU
# when one is available).
# `force=True` makes the cell safe to re-run in the same kernel: without it a
# second call raises RuntimeError("context has already been set").
mp.set_start_method("spawn", force=True)

In [None]:
from clip_video_classifier.cli import cli
from clip_video_classifier.data.dataset import ClipEmbeddingsDataset
from torch_snippets import *

In [None]:
class LinearModel(nn.Module):
    """Small MLP probe that classifies fixed embedding vectors.

    Architecture: ``Linear(in_features, hidden_features) -> ReLU ->
    Linear(hidden_features, num_classes)``, trained with cross-entropy.
    The defaults reproduce the original hard-coded 512 -> 64 -> 4 network.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 512
            (presumably the CLIP embedding width — confirm for other encoders).
        hidden_features: Width of the hidden layer.
        num_classes: Number of output logits / classes.
    """

    def __init__(self, in_features=512, hidden_features=64, num_classes=4):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()
        # Name of the forward() argument carrying the model input; the
        # Hugging Face Trainer reads this attribute, so it must match the
        # parameter name below.
        self.main_input_name = "input"

    def forward(self, input, targets=None):
        """Compute logits and, when targets are given, the loss.

        Args:
            input: Float tensor whose last dimension is ``in_features``.
            targets: Optional tensor of class indices. When omitted (pure
                inference) no loss is computed.

        Returns:
            ``{"loss": ..., "logits": ...}`` if ``targets`` is given,
            otherwise ``{"logits": ...}``.
        """
        logits = self.model(input)
        if targets is None:
            # No labels: return only logits instead of failing.
            return {"logits": logits}
        loss = self.loss_fn(logits, targets)
        return {"loss": loss, "logits": logits}

In [None]:
import collections

from sklearn.model_selection import train_test_split

# --- Dataset locations -------------------------------------------------------
# NOTE(review): machine-specific absolute mount point; adjust per environment.
root = P("/mnt/347832F37832B388/ml-datasets/ssbd/")
annotations_path = root / "annotations.csv"
embeddings_folder = root / "ssbd-embeddings/10fps"
frames_folder = root / "ssbd-frames/10fps"

# --- Annotations -------------------------------------------------------------
# Read the annotation table and discard rows labelled "others".
annotations = pd.read_csv(annotations_path)
annotations = annotations.query('label != "others"')

# Split at the video level (15% of unique videos held out, fixed seed) so no
# video contributes rows to both splits.
trn_items, val_items = train_test_split(
    annotations.video.unique(), test_size=0.15, random_state=11
)

# Training split: only the single earliest-starting annotation of each video.
# NOTE(review): validation keeps *all* annotations per video — confirm this
# asymmetry is intentional.
_trn_rows = annotations.query("video in @trn_items")
_first_row_per_video = _trn_rows.groupby("video")["start"].idxmin()
trn_df = annotations.loc[_first_row_per_video]
# Validation split: every annotation belonging to a held-out video.
val_df = annotations.query("video in @val_items")

# --- Datasets ----------------------------------------------------------------
trn_ds, val_ds = (
    ClipEmbeddingsDataset(embeddings_folder, split_df, frames_dir=frames_folder)
    for split_df in (trn_df, val_df)
)

# Report per-label sample counts of each split (iterates every item).
_trn_label_counts = collections.Counter(
    sample["label"] for sample in track2(trn_ds)
)
_val_label_counts = collections.Counter(
    sample["label"] for sample in track2(val_ds)
)
print("train", _trn_label_counts, "validation", _val_label_counts)


def collate_fn(batch):
    """Merge a list of dataset samples into a single model batch.

    Each sample must provide an ``"input"`` tensor and a ``"targets"`` value;
    any other keys are ignored.

    Returns:
        ``{"input": float32 tensor stacked along a new leading dim,
        "targets": tensor built from the per-sample targets}``.
    """
    inputs = [sample["input"] for sample in batch]
    labels = [sample["targets"] for sample in batch]
    return {
        "input": torch.stack(inputs).float(),
        "targets": torch.tensor(labels),
    }

```python
dl = DataLoader(trn_ds, shuffle=True, batch_size=3, collate_fn=collate_fn)
model = LinearModel()
i = next(iter(dl))
model(**i)
```

In [None]:
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments
from torch_snippets.charts import CM

# Put the probe on the GPU when one is available and fall back to the CPU
# otherwise; an unconditional `.cuda()` raises on CPU-only machines.
# (The Trainer normally moves the model to `args.device` itself as well.)
model = LinearModel().to("cuda" if torch.cuda.is_available() else "cpu")


def compute_metrics(input):
    """Score a Trainer evaluation pass and display its confusion matrix.

    Args:
        input: Evaluation result handed over by the Hugging Face ``Trainer``.
            ``input.predictions`` holds per-class scores (arg-maxed over
            axis 1) and ``input.label_ids`` holds the integer targets.

    Returns:
        ``{"accuracy": fraction of samples whose predicted label name equals
        the true label name}``.

    Side effects:
        Shows a confusion matrix via torch_snippets' ``show``/``CM``.

    Note:
        Class ids are translated to names with the global
        ``trn_ds.id2label``, for both predictions and targets.
    """
    id_to_name = trn_ds.id2label
    predicted_ids = input.predictions.argmax(1)
    pred = np.array([id_to_name[class_id] for class_id in predicted_ids])
    targets = np.array([id_to_name[class_id] for class_id in input.label_ids])
    show(CM(pred=pred, truth=targets))
    return {"accuracy": (targets == pred).mean()}


# Hugging Face training configuration. Evaluation uses the "steps" schedule;
# with no explicit eval_steps it follows logging_steps.
training_args = TrainingArguments(
    output_dir="./linear_model_trained",
    # Schedule
    num_train_epochs=1500,
    evaluation_strategy="steps",
    logging_steps=1500,
    # Checkpointing: save every 200 steps and cap the number kept at 2.
    save_steps=200,
    save_total_limit=2,
    # Batch sizes
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # The batch key holding the labels ("targets", as emitted by collate_fn).
    label_names=["targets"],
    include_inputs_for_metrics=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=trn_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Predict on the training split. `metric_key_prefix="train"` labels the
# returned metrics `train_*` instead of the default, misleading `test_*`.
trainer.predict(trn_ds, metric_key_prefix="train")

In [None]:
# Predict on the validation split. The "eval" prefix matches the names of the
# metrics logged during training (e.g. `eval_accuracy`) instead of `test_*`.
trainer.predict(val_ds, metric_key_prefix="eval")

In [None]:
# Regenerate the library modules from the project's notebooks via nbdev.
from nbdev import nbdev_export

nbdev_export()