In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from einops import rearrange
from decord import VideoReader
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from segmentation_models_pytorch.losses import FocalLoss
from transformers import AutoModel, AutoImageProcessor, AutoConfig
from skmultilearn.model_selection import iterative_train_test_split
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorchvideo.transforms.transforms_factory import create_video_transform

from crash_modules.crash_dataset import VideoDataset

In [None]:
# Central experiment settings read by the cells below.
#
# label_dict folds the raw labels 1..18 into three consecutive groups of six
# (1-6 -> 0, 7-12 -> 1, 13-18 -> 2); label_reverse_dict is the identity map
# over the three grouped classes.
config = dict(
    seed=2023,
    model_name="facebook/timesformer-base-finetuned-k600",
    batch_size=3,
    learning_rate=1e-5,
    data_dir='',
    checkpoint_dir='./checkpoint_crashego16',
    submission_dir='./submission',
    n_classes=3,
    label_dict={raw_label: (raw_label - 1) // 6 for raw_label in range(1, 19)},
    label_reverse_dict={class_id: class_id for class_id in range(3)},
)

In [None]:
# Seed Python, NumPy and PyTorch RNGs for reproducibility. workers=True additionally
# makes Lightning seed each DataLoader worker process, which matters here because
# the loaders below run with num_workers=8.
pl.seed_everything(config['seed'], workers=True)

In [None]:
# Read the training manifest and keep only the three columns used downstream.
# The column order is significant: the cross-validation cell converts the frame
# to a plain array with `.values`, so columns are consumed by position.
train_csv_path = f"{config['data_dir']}/train_ver2.csv"
manifest_columns = ['sample_id', 'video_path', 'label']
train_df = pd.read_csv(train_csv_path)[manifest_columns]

In [None]:
# Normalise the identifier columns in place:
#   * sample_id  -> the integer found after the first underscore
#   * video_path -> data_dir + the path with its first character dropped
# NOTE: not idempotent. Re-running this cell fails once sample_id already
# holds integers (int has no .split), so re-run from the CSV load instead.
train_df['sample_id'] = train_df['sample_id'].map(lambda raw_id: int(raw_id.split('_')[1]))
train_df['video_path'] = train_df['video_path'].map(lambda rel_path: config['data_dir'] + rel_path[1:])

In [None]:
# Translate each raw label through config['label_dict'] and store the result in
# both `label_split` and `label` (the raw label is overwritten).
# dict.get returns None for any raw label that is not a key of the mapping.
# NOTE: not idempotent. Re-running would push the already-grouped labels
# through the mapping a second time.
label_lookup = config['label_dict'].get
train_df['label_split'] = train_df['label'].map(label_lookup)
train_df['label'] = train_df['label_split']

In [None]:
# Load the model configuration and the matching image processor for the checkpoint
# named in config['model_name'], so changing the model in the config cell is
# picked up here instead of being silently overridden by a hard-coded name.
# NOTE: despite its name, `image_processor_config` is the image-processor object
# itself; its image_mean / image_std / crop_size attributes are read below.
model_config = AutoConfig.from_pretrained(config['model_name'])
image_processor_config = AutoImageProcessor.from_pretrained(config['model_name'])

In [None]:
def _make_video_transform(mode):
    """Build a pytorchvideo transform for ``mode`` ('train' or 'val').

    Both modes share the same settings: 16 temporally sampled frames, plus the
    normalisation statistics and crop size taken from the image processor.
    """
    return create_video_transform(
        mode=mode,
        num_samples=16,
        video_mean=tuple(image_processor_config.image_mean),
        video_std=tuple(image_processor_config.image_std),
        crop_size=tuple(image_processor_config.crop_size.values()),
    )


train_transform = _make_video_transform('train')

val_transform = _make_video_transform('val')

In [None]:
class PLVideoModel(pl.LightningModule):
    """Video classifier: pretrained backbone + linear head, trained with cross-entropy.

    The backbone's ``last_hidden_state`` is mean-pooled over dimension 1 and passed
    through a linear layer that emits one logit per class.

    Args:
        config: Mapping of experiment settings. ``learning_rate`` is required;
            ``model_name`` and ``n_classes`` are used when present and otherwise
            fall back to the previously hard-coded values.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.learning_rate = config['learning_rate']
        self.model = AutoModel.from_pretrained(
            config.get('model_name', 'facebook/timesformer-base-finetuned-k600')
        )
        # A regular Linear (instead of LazyLinear) has fully initialised parameters
        # by the time configure_optimizers() hands self.parameters() to the optimizers.
        # assumes the backbone's last_hidden_state width equals config.hidden_size
        self.classifier = nn.Linear(
            self.model.config.hidden_size,
            config.get('n_classes', 3),
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of videos ``x``."""
        # Average the backbone's output embeddings over dimension 1 before classifying.
        x = self.model(x).last_hidden_state.mean(dim=1)
        x_out = self.classifier(x)
        return x_out

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute and log the cross-entropy loss for one training batch.

        ``optimizer_idx`` is supplied by Lightning because configure_optimizers()
        returns more than one optimizer; the same loss is used for each of them.
        """
        y_hats = self.forward(batch["video"])
        loss = self.loss(y_hats, batch["label"])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and return ``[logits, labels]`` for epoch-end scoring."""
        label = batch['label']
        y_hats = self.forward(batch["video"])

        with torch.no_grad():
            loss = self.loss(y_hats, label)

        self.log("valid_loss", loss)

        step_output = [y_hats, label]
        return step_output

    def predict_step(self, batch, batch_idx):
        """Return the logits for one batch; only ``batch['video']`` is required."""
        # Labels are not needed for inference, so unlabeled batches are accepted.
        return self.forward(batch["video"])

    def validation_epoch_end(self, step_outputs):
        """Compute macro-averaged F1 over all validation batches and log it as ``val_score``."""
        preds = []
        labels = []

        for step_output in step_outputs:
            pred, label = step_output
            preds += pred.argmax(1).detach().cpu().tolist()
            labels += label.tolist()

        score = f1_score(labels, preds, average='macro')
        self.log("val_score", score)
        return score

    def post_preproc(self, step_outputs):
        """Convert per-batch outputs into a flat list of predicted class indices.

        Accepts both output shapes produced by this module: the ``[logits, labels]``
        pairs from validation_step and the bare logits tensors from predict_step.
        """
        preds = []
        for step_output in step_outputs:
            # Indexing a bare logits tensor with [0] would select the first sample,
            # not the logits, so only unwrap list/tuple outputs.
            if isinstance(step_output, (list, tuple)):
                logits = step_output[0]
            else:
                logits = step_output
            preds += logits.argmax(1).detach().cpu().tolist()

        return preds

    def configure_optimizers(self):
        """Return the optimizers and the plateau scheduler driven by ``valid_loss``."""
        # NOTE(review): two optimizers over the same parameters make Lightning 1.x
        # (automatic optimization) run training_step once per optimizer for every
        # batch. Kept as-is to preserve training behaviour; confirm it is intended.
        opt1 = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        opt2 = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        optimizers = [opt1, opt2]

        # The monitored quantity is a loss, so the scheduler must treat lower as
        # better; with mode='max' the LR would be halved while the loss still improves.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt1,
            mode='min',
            factor=0.5,
            patience=2,
            threshold_mode='abs',
            min_lr=1e-8,
            verbose=True,
        )
        lr_schedulers = {"scheduler": scheduler, "monitor": "valid_loss"}

        return optimizers, lr_schedulers

In [None]:
# 5-fold stratified cross-validation: a fresh model is trained for every fold.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=config['seed'])

for k, (t_idx, v_idx) in enumerate(kf.split(X=train_df['sample_id'], y=train_df['label']), start=1):
    # split() yields positional indices, so select rows with iloc rather than by label.
    train_df_for_dataset = train_df.iloc[t_idx].reset_index(drop=True).values
    val_df_for_dataset = train_df.iloc[v_idx].reset_index(drop=True).values

    train_dataset = VideoDataset(train_df_for_dataset, transform=train_transform, mode='train')
    val_dataset = VideoDataset(val_df_for_dataset, transform=val_transform, mode='valid')

    # Shuffle the training samples every epoch; the fold indices arrive in sorted order.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    # Validation needs no gradients, so a larger batch is used.
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,
        num_workers=8,
        pin_memory=True,
    )

    # Keep the best checkpoint by validation score. The fold number is part of the
    # file name so checkpoints from different folds can be told apart.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_score',
        dirpath=config['checkpoint_dir'],
        filename=f'{config["model_name"]}-fold{k}'
        + '-{epoch:02d}-{train_loss:.4f}-{valid_loss:.4f}-{val_score:.4f}',
        mode='max'
    )

    # Stop a fold once the validation score has not improved for 3 epochs.
    early_stop_callback = EarlyStopping(
        monitor="val_score",
        patience=3,
        verbose=False,
        mode="max")

    pl_video_model = PLVideoModel(config)

    trainer = pl.Trainer(
        max_epochs=100,
        accelerator='auto',
        precision=16,
        callbacks=[early_stop_callback, checkpoint_callback],
    )
    trainer.fit(pl_video_model, train_dataloader, val_dataloader)