In [1]:
import os

import av
import cv2
import numpy as np
import pandas as pd
from pathlib import Path

from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import albumentations as A

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoProcessor, AutoModel

In [2]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
def apply_video_augmentations(video, transform):
    targets={'image': video[0]}
    for i in range(1, video.shape[0]):
        targets[f'image{i}'] = video[i]
    transformed = transform(**targets)
    transformed = np.concatenate(
        [np.expand_dims(transformed['image'], axis=0)] 
        + [np.expand_dims(transformed[f'image{i}'], axis=0) for i in range(1, video.shape[0])]
    )
    return transformed

In [4]:
def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, seg_len):
    start_idx, end_idx = 0, seg_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

In [5]:
batch_size = 16
root_dir = '../data/sibur_data/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Dataset preparation

In [6]:
id2label = {0: "bridge_down", 1: "bridge_up", 2: "no_action", 3: "train_in_out"}
label2id = {l:i for i, l in id2label.items()}
labels = list(id2label.values())

In [7]:
video_paths = list(Path(root_dir).rglob("*.mp4"))
targets = [vp.parent.name for vp in video_paths]
train = pd.DataFrame({
    "video_path": [v.as_posix() for v in video_paths],
    "label": targets,
})

In [8]:
train.label.value_counts()

bridge_down     306
bridge_up        75
train_in_out     66
no_action        49
Name: label, dtype: int64

In [9]:
train['label_id'] = train.label.map(label2id)

In [10]:
X_train, X_val, _, _ = train_test_split(train, train['label'], test_size=0.2)

In [11]:
transform = A.Compose([
    A.ShiftScaleRotate(
        shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5
    ),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
], additional_targets={
    f'image{i}': 'image'
    for i in range(1, 8)
})

In [12]:
class ActionDataset(Dataset):

    def __init__(self, meta, transform=None):
        self.meta = meta
        self.transform = transform

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()


        file_path = self.meta['video_path'].iloc[idx]
        container = av.open(file_path)
        indices = sample_frame_indices(clip_len=8, seg_len=container.streams.video[0].frames)
            
        video = read_video_pyav(container, indices)
        while video.shape[0] < 8:
            video = np.vstack([video, video[-1:]])

        if self.transform:
            video = apply_video_augmentations(video, self.transform)
            

        inputs = processor(
            text=[''],
            videos=list(video),
            return_tensors="pt",
            padding=True,
        )
        for i in inputs:
            inputs[i] = inputs[i][0]

        return inputs["pixel_values"], self.meta['label_id'].iloc[idx]

In [13]:
train_dataset = ActionDataset(meta=X_train, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

val_dataset = ActionDataset(meta=X_val)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Load model

In [20]:
processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
xclip = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
display(xclip)
# model.to(device)
# classifier = nn.Linear(512, len(labels))
# classifier.to(device)

class Net(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        xclip = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
        self.xclip_vision_model = xclip.vision_model
        self.xclip_projector = xclip.visual_projection
        self.classifier = torch.nn.Linear(512, len(labels))
    
    def forward(self, x):
        x = self.xclip_vision_model(x)
        x = self.xclip_projector(x)
        x = self.classifier(x)
        return x

model = Net()

XCLIPModel(
  (text_model): XCLIPTextTransformer(
    (embeddings): XCLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): XCLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x XCLIPEncoderLayer(
          (self_attn): XCLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): XCLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps

In [15]:
# (8, 240, 320, 3) --> (8, 3, 224, 224)

# Full XClip training

In [17]:
for param in model.parameters():
    param.requires_grad = True
# for param in model.text_model.parameters():
#     param.requires_grad = False

In [18]:
epochs = 7
model_lr = 1e-5
classifier_lr = 1e-3

param_groups = [
    {
        "params": model.xclip_vision_model.parameters(),
        "lr": model_lr,
    },
    {
        "params": model.xclip_projector.parameters(),
        "lr": model_lr,
    },
    {
        "params": model.classifier.parameters(),
        "lr": classifier_lr,
    },
]

optimizer = optim.AdamW(param_groups)
criterion = nn.CrossEntropyLoss()

In [19]:
for epoch in range(epochs):

    model.train() 

    train_loss = []
    for i, (batch, targets) in enumerate(tqdm(train_dataloader, desc=f"Epoch: {epoch}")):
        optimizer.zero_grad()

        batch = batch.to(device)
        targets = targets.to(device)

        logits = model(batch)

        loss = criterion(logits, targets) 
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())


    model.eval()

    val_loss = []
    val_targets = []
    val_preds = []
    for i, (batch, targets) in enumerate(tqdm(val_dataloader, desc=f"Epoch: {epoch}")):
        with torch.no_grad():

            batch = batch.to(device)
            targets = targets.to(device)

            outputs = model(batch)

            loss = criterion(logits, targets) 

            val_loss.append(loss.item())
            val_targets.extend(targets.cpu().numpy())
            val_preds.extend(logits.argmax(axis=1).cpu().numpy())           

    print('Training loss:', np.mean(train_loss))
    print('Val loss:', np.mean(val_loss))
    print('F1:', f1_score(val_targets, val_preds, average='macro'))

Epoch: 0:   0%|          | 0/25 [00:00<?, ?it/s]

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [16, 8, 3, 224, 224]

In [18]:
model.save_pretrained("xclip_classifier")
torch.save(classifier.state_dict(), "xclip_classifier/classifier.pth")