In [None]:
import numpy as np
import pandas as pd
import torch, os
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader
from segmentation_models_pytorch.losses import FocalLoss
from transformers import AutoModel, AutoImageProcessor, AutoConfig
from skmultilearn.model_selection import iterative_train_test_split
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorchvideo.transforms.transforms_factory import create_video_transform

from crash_modules.crash_dataset import VideoDataset
from crash_modules.models.CrashEgo import CrashEgo
from crash_modules.models.Weather import Weather
from crash_modules.models.Timing import Timing

# Seed the global RNGs through Lightning before anything stochastic runs.
# NOTE(review): this seed (42) differs from the "seed": 2023 declared in the config dicts.
pl.seed_everything(seed=42)

In [None]:
# Shared inference settings. Every task model uses the same backbone and loader
# settings; the per-task dicts below differ from this one only in "n_classes".
# NOTE(review): "seed" here is not the value passed to pl.seed_everything above.
config = {
    "seed": 2023,
    "model_name": "facebook/timesformer-base-finetuned-k600",
    "batch_size": 2,
    "learning_rate": 1e-5,
    "data_dir": '',  # dataset root holding test.csv / sample_submission.csv
    "checkpoint_dir": './checkpoint',
    "submission_dir": './submission',
    "n_classes": 1,
}

# Each task gets its own independent dict, so mutating one cannot leak into another.
crash_ego_config = {**config, "n_classes": 3}  # 3-way crash/ego head
weather_config = {**config, "n_classes": 3}    # 3-way weather head
timing_config = {**config, "n_classes": 1}     # single-output timing head


In [None]:
# Load the test index. os.path.join keeps paths relative when data_dir is left
# empty, instead of producing root-absolute paths such as "/test.csv".
test_df = pd.read_csv(os.path.join(config['data_dir'], 'test.csv'))
# Keep only the part of sample_id after the first underscore, parsed as an int.
test_df['sample_id'] = test_df['sample_id'].apply(lambda x: int(x.split('_')[1]))
# The first character of video_path (presumably the "." of "./...") is dropped and
# the remainder is re-rooted under data_dir.
test_df['video_path'] = test_df['video_path'].apply(
    lambda x: os.path.join(config['data_dir'], x[1:].lstrip('/'))
)

In [None]:
# Evaluation-time video transforms, normalised/cropped with the settings of the
# pretrained backbone's image processor (same model id as the config).
model_config = AutoConfig.from_pretrained(config['model_name'])
image_processor_config = AutoImageProcessor.from_pretrained(config['model_name'])


def make_val_transform(num_samples):
    """Return a pytorchvideo 'val' transform that samples `num_samples` frames and
    applies the image processor's mean/std normalisation and crop size."""
    return create_video_transform(
        mode='val',
        num_samples=num_samples,
        video_mean=tuple(image_processor_config.image_mean),
        video_std=tuple(image_processor_config.image_std),
        crop_size=tuple(image_processor_config.crop_size.values()),
    )


val_transform = make_val_transform(8)           # 8-frame clips
val_transform_frame16 = make_val_transform(16)  # 16-frame clips

In [None]:
# The test split has no ground truth: add -1 placeholder label columns so each row
# carries the columns VideoDataset reads (presumably the same layout as the training
# frame — confirm against crash_modules.crash_dataset).
for placeholder_col in ('label', 'label_split'):
    test_df[placeholder_col] = -1

# Both loaders share the same settings; inference batches are twice config['batch_size'].
test_loader_kwargs = {
    'batch_size': config['batch_size'] * 2,
    'num_workers': 8,
    'pin_memory': True,
}

# 8-frame clips
test_dataset = VideoDataset(test_df.values, transform=val_transform)
test_dataloader = DataLoader(test_dataset, **test_loader_kwargs)

# 16-frame clips
test_dataset_frame16 = VideoDataset(test_df.values, transform=val_transform_frame16)
test_dataloader_frame16 = DataLoader(test_dataset_frame16, **test_loader_kwargs)

In [None]:
# One module per task, each built from its own task config.
crash_ego_model = CrashEgo(config=crash_ego_config)
weather_model = Weather(config=weather_config)
timing_model = Timing(config=timing_config)

In [None]:
# Ensemble every CrashEgo checkpoint on 16-frame clips: average the per-class sigmoid
# scores across checkpoints, then take the arg-max class for each sample.
crash_ego_ckpt_dir = 'checkpoint_crashego16/facebook'  # model checkpoint folder
crash_ego_ckpts = sorted(f for f in os.listdir(crash_ego_ckpt_dir) if f.endswith('.ckpt'))
assert crash_ego_ckpts, f'no .ckpt files found in {crash_ego_ckpt_dir}'

crash_ego_result = []
for ckpt_name in crash_ego_ckpts:
    # load_from_checkpoint is a classmethod: call it on the class rather than on an
    # instance (recent Lightning releases raise a TypeError for instance calls).
    crash_ego_pretrained = CrashEgo.load_from_checkpoint(
        os.path.join(crash_ego_ckpt_dir, ckpt_name),
        config=crash_ego_config,
    )

    trainer = pl.Trainer(accelerator='auto')
    pred = trainer.predict(crash_ego_pretrained, test_dataloader_frame16)

    result = []
    for step_out in pred:
        result += torch.sigmoid(step_out).tolist()  # one row of class scores per sample
    crash_ego_result.append(result)
crash_ego_result = np.array(crash_ego_result)  # [N_ckpt, N_samples, N_classes]
crash_ego_result = crash_ego_result.mean(0).argmax(1)

In [None]:
# Ensemble every Weather checkpoint on 8-frame clips: average the per-class sigmoid
# scores across checkpoints, then map them to a weather class with the rule below.
weather_ckpt_dir = 'checkpoint_weather/facebook'
weather_ckpts = sorted(f for f in os.listdir(weather_ckpt_dir) if f.endswith('.ckpt'))
assert weather_ckpts, f'no .ckpt files found in {weather_ckpt_dir}'

weather_result = []
for ckpt_name in weather_ckpts:
    # load_from_checkpoint is a classmethod: call it on the class rather than on an
    # instance (recent Lightning releases raise a TypeError for instance calls).
    weather_pretrained = Weather.load_from_checkpoint(
        os.path.join(weather_ckpt_dir, ckpt_name),
        config=weather_config,
    )

    trainer = pl.Trainer(accelerator='auto')
    pred = trainer.predict(weather_pretrained, test_dataloader)

    result = []
    for step_out in pred:
        result += torch.sigmoid(step_out).tolist()  # one row of class scores per sample
    weather_result.append(result)
weather_result = np.array(weather_result)  # [N_ckpt, N_samples, 3]

# Decision rule over the averaged (normal, snow, rain) scores:
# snow wins outright at >= 0.5; otherwise rain is chosen only if it beats normal.
weather_preds = []
for normal, snow, rain in weather_result.mean(0):
    if snow >= 0.5:
        weather_preds.append(1)
    elif rain > normal:
        weather_preds.append(2)
    else:
        weather_preds.append(0)

In [None]:
# Ensemble every Timing checkpoint on 8-frame clips: average the sigmoid score across
# checkpoints and threshold it at 0.5 to get a binary timing label per sample.
timing_ckpt_dir = 'checkpoint_timing/facebook'
timing_ckpts = sorted(f for f in os.listdir(timing_ckpt_dir) if f.endswith('.ckpt'))
assert timing_ckpts, f'no .ckpt files found in {timing_ckpt_dir}'

timing_result = []
for ckpt_name in timing_ckpts:
    # load_from_checkpoint is a classmethod: call it on the class rather than on an
    # instance (recent Lightning releases raise a TypeError for instance calls).
    timing_pretrained = Timing.load_from_checkpoint(
        os.path.join(timing_ckpt_dir, ckpt_name),
        config=timing_config,
    )

    trainer = pl.Trainer(accelerator='auto')
    pred = trainer.predict(timing_pretrained, test_dataloader)

    result = []
    for step_out in pred:
        result += torch.sigmoid(step_out).tolist()
    timing_result.append(result)
timing_result = np.array(timing_result)  # [N_ckpt, N_samples] or [N_ckpt, N_samples, 1]

# Collapse to exactly one score per sample. If the model emits [B, 1], mean(0) is
# [N_samples, 1]; its rows would be 1-element arrays, which are unhashable when used
# as dictionary keys downstream, so flatten to a 1-D array of scalars.
timing_scores = timing_result.mean(0).reshape(timing_result.shape[1], -1)
assert timing_scores.shape[1] == 1, 'expected a single timing score per sample'
timing_result = np.where(timing_scores[:, 0] > 0.5, 1, 0)

In [None]:
# Map each (crash_ego, weather, timing) prediction triple to the final label.
# crash_ego == 2 always maps to label 0 regardless of weather/timing; the remaining
# 12 combinations are numbered 1..12 in lexicographic order of the triple.
label_reverse_dict = {
    (crash_ego, weather, timing): (
        0 if crash_ego == 2 else 1 + 6 * crash_ego + 2 * weather + timing
    )
    for crash_ego in range(3)
    for weather in range(3)
    for timing in range(2)
}

In [None]:
# Combine the three task predictions into the final label (0..12) and write the submission.
# os.path.join keeps the path relative when data_dir is empty (no "/sample_submission.csv").
submit = pd.read_csv(os.path.join(config['data_dir'], 'sample_submission.csv'))
assert len(submit) == len(crash_ego_result) == len(weather_preds) == len(timing_result), \
    'prediction count does not match the number of sample_submission rows'

# NOTE(review): label_for_submit is neither defined nor imported anywhere in this
# notebook — it must come from a cell that no longer exists; define it before running.
# Keys are coerced to plain ints so numpy scalars / 1-element arrays hash correctly.
submit['label'] = [
    label_for_submit(label_reverse_dict[(int(i), int(j), int(np.asarray(k).item()))])
    for i, j, k in zip(crash_ego_result, weather_preds, timing_result)
]
submit.to_csv('sub.csv', index=False)