In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducible runs.

    Args:
        seed: value used to seed every random number generator.

    Note:
        ``PYTHONHASHSEED`` only affects Python interpreters started after it is set
        (e.g. spawned worker processes); hash randomisation of the current
        process is fixed at interpreter start-up.
    """
    import random, os, numpy as np
    import torch
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU (a no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN auto-tune and pick convolution algorithms per run,
    # which can be non-deterministic and defeats `deterministic = True` above.
    torch.backends.cudnn.benchmark = False

In [3]:
SEED = 1337  # leet :)
seed_everything(SEED)

In [4]:
import copy
import csv
from enum import Enum
import io
import json
import os
import typing as t

# import cv2
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    roc_curve,
)
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data as td
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms

In [5]:
import timm
import pytorch_lightning as pl
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchmetrics.functional import auroc

In [6]:
from utils import (
    show_photos, 
    create_dataloader,
    train_epoch,
    test_epoch,
    plot_history,
    print_model_params_required_grad,
    PUBLIC_DATA_FOLDER_PATH,
    PUBLIC_DATA_DESCRIPTION_PATH,
)

In [7]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

# Description

Yandex GO is one of the top three ride-hailing services in the world. Our app facilitates over 4 billion trips per year across 32 countries. We are committed to the quality of our services, ensuring thorough checks of both drivers and their vehicles before they go online, based on dozens of criteria. Part of the vehicle inspection process is carried out remotely using photos of the vehicle, which allows us to either block or grant the driver access to orders. This tool ensures that cars do not go online if they are damaged or dirty.

Computer vision algorithms play a significant role in this remote quality control process. Machine learning models act as a filter that processes vehicle inspection requests, automatically approving a portion of requests that, according to the models, contain no violations, and sending suspicious cases for additional manual review.

### How does the photo inspection process work?
As part of vehicle photo inspections, drivers periodically receive a task to take photos of their car, so it can be checked for damage, compliance with service standards, branding presence, etc. Before these checks, we also need to ensure that drivers took the photos honestly and sent what we expected. The driver is required to take 4 photos (front, rear, left side, right side). The photos are taken through the Yandex PRO app, which has an interface that guides them to capture the 4 photos in the correct order and from the required angles.

In the standard process, the photos are first reviewed by the ML pipeline. If the ML pipeline does not find anything suspicious in the photos, the inspection is automatically approved. If the pipeline flags at least one photo, the inspection is sent to an assessor for a final decision. Thus, the object of the decision is the inspection itself, i.e., all 4 photos together.

In this task, the license plate numbers have been blacked out.

In [None]:
pass_id = '000f43a6549ad26d'


def _read_photo_bytes(side):
    """Return the raw bytes of the `side` photo of inspection `pass_id`."""
    with open(f'{PUBLIC_DATA_FOLDER_PATH}/{pass_id}_{side}', 'rb') as photo_file:
        return photo_file.read()


# Photo files are named "<pass_id>_<side>"; show them as front, back, left, right.
photos = [_read_photo_bytes(side) for side in ('front', 'back', 'left', 'right')]
show_photos(photos)

### Data description: 
- **filename** —  name of the photo file, consisting of `pass_id` and `plan_side`.
- **pass_id** — ID of the inspection. Each inspection contains 4 photos.
- **plan_side** — the side of the vehicle that should be in the photo. Possible values: front, back, left, right.
- **fact_side** — the side of the vehicle as determined by assessors. Possible values: front, back, left, right, unknown.
- **fraud_verdict** — the assessor's verdict on what is depicted in the photo. Possible values:
   - ALL_GOOD —  the photo clearly shows one side of the vehicle, which is fully visible and in focus.
   - LACK_OF_PHOTOS — the photo does not contain a vehicle at all.
   - BLURRY_PHOTO — the photo is blurry.
   - SCREEN_PHOTO — not a real vehicle photo, but a photo of a screen.
   - DARK_PHOTO — the photo is too dark.
   - INCOMPLETE_CAPTURE — the vehicle is not fully visible in the photo.
- **fraud_probability** — the proportion of assessors who assigned the given fraud_verdict. If no verdict achieved a majority, a random one is chosen.
- **damage_verdict** — the assessor's verdict on the vehicle's condition. Possible values:
   - NO_DEFECT —  no visible damage.
   - DEFECT — there is some damage.
   - BAD_PHOTO — the damage cannot be assessed because of the photo's quality.
- **damage_probability** — the proportion of assessors who assigned the given damage_verdict. If no verdict achieved a majority, a random one is chosen.

In [9]:
# One row per photo, keyed by file name and sorted so rows of an inspection are adjacent.
description = (
    pd.read_csv(PUBLIC_DATA_DESCRIPTION_PATH, index_col='filename')
      .sort_index()
)
description.head()

Unnamed: 0_level_0,pass_id,plan_side,fact_side,fraud_verdict,fraud_probability,damage_verdict,damage_probability
filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
00015b960a1c013e_back,00015b960a1c013e,back,back,DARK_PHOTO,1.0,BAD_PHOTO,0.8
00015b960a1c013e_front,00015b960a1c013e,front,front,DARK_PHOTO,0.666667,BAD_PHOTO,1.0
00015b960a1c013e_left,00015b960a1c013e,left,unknown,DARK_PHOTO,0.666667,BAD_PHOTO,0.6
00015b960a1c013e_right,00015b960a1c013e,right,unknown,DARK_PHOTO,0.666667,BAD_PHOTO,1.0
0001f673ef360c58_back,0001f673ef360c58,back,back,ALL_GOOD,0.666667,NO_DEFECT,1.0


In addition to fraud that can be identified by looking at an individual photo, there may be cases where each photo individually has a fraud_verdict of 'ALL_GOOD', but the driver took two photos of the same side of the vehicle and failed to capture another side:

In [10]:
# All photos of one inspection whose left/right photos actually show the front/back.
description.loc[description['pass_id'].eq('001c07aa1e3edf7e')]

Unnamed: 0_level_0,pass_id,plan_side,fact_side,fraud_verdict,fraud_probability,damage_verdict,damage_probability
filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
001c07aa1e3edf7e_back,001c07aa1e3edf7e,back,back,ALL_GOOD,1.0,NO_DEFECT,1.0
001c07aa1e3edf7e_front,001c07aa1e3edf7e,front,front,ALL_GOOD,1.0,NO_DEFECT,1.0
001c07aa1e3edf7e_left,001c07aa1e3edf7e,left,front,ALL_GOOD,1.0,NO_DEFECT,1.0
001c07aa1e3edf7e_right,001c07aa1e3edf7e,right,back,ALL_GOOD,1.0,NO_DEFECT,0.8


In [11]:
# !pip install timm

# Objective
To assess the quality of vehicles and photos using machine learning algorithms:  
1. For detecting fraud (incorrect photos, unclear images, or incorrect photo sets).  
2. For detecting vehicle damage.

## Performance Metric and Deliverables
Each inspection has 2 targets and 4 photos (one per side of the vehicle). Ultimately, however, we need to predict whether the inspection should be sent to a human for review (so that the driver gets feedback), or whether it has no defects and can be approved automatically. This means that the metric is calculated not for individual photos and targets, but for the inspection as a whole.

*Evaluation Metric:* ROC AUC (object — inspection)

*Required Deliverables*:
- Model Weights: The trained model's weights for reproducibility and further analysis.
- Executable Script: A script containing all necessary code to run the model, including data reading, preprocessing steps, model architecture, inference code.
   

# Example

Let's train a model with a simplified per-photo target that does not account for cases where two photos in an inspection capture the same side of the vehicle. For this, we use a pretrained timm backbone (`MODEL_ID`, EfficientFormer-L1 by default) and replace its classifier with four heads: fraud verdict, damage verdict, actual vehicle side, and a binary "needs review" target.

In [12]:
class CarSide(Enum):
    """Vehicle side shown in a photo; the values are the fact-side class ids."""

    FRONT, BACK, LEFT, RIGHT, UNKNOWN = range(5)
    
class FraudResolution(Enum):
    """Assessor fraud verdict for a photo; the values are the fraud class ids."""

    ALL_GOOD, LACK_OF_PHOTOS, BLURRY_PHOTO, SCREEN_PHOTO, DARK_PHOTO, INCOMPLETE_CAPTURE = range(6)
    
class DamageResolution(Enum):
    """Assessor damage verdict for a photo; the values are the damage class ids."""

    NO_DEFECT, DEFECT, BAD_PHOTO = range(3)

## Mapping

In [13]:
# String fraud verdict -> integer class id, derived from `FraudResolution` so the enum
# and the mapping cannot drift apart. The result is
# {'ALL_GOOD': 0, 'LACK_OF_PHOTOS': 1, 'BLURRY_PHOTO': 2, 'SCREEN_PHOTO': 3,
#  'DARK_PHOTO': 4, 'INCOMPLETE_CAPTURE': 5}.
fraud_verdict_mapping = {member.name: member.value for member in FraudResolution}

description['fraud_label'] = description['fraud_verdict'].map(fraud_verdict_mapping)

# `Series.map` silently yields NaN for any verdict missing from the mapping. That would turn
# `fraud_label` into a float column, unusable as a class-index target, so fail loudly here.
_unmapped_fraud = description.loc[description['fraud_label'].isna(), 'fraud_verdict'].unique()
assert len(_unmapped_fraud) == 0, f"unmapped fraud verdicts: {list(_unmapped_fraud)}"

In [14]:
# String damage verdict -> integer class id, derived from `DamageResolution`
# ({'NO_DEFECT': 0, 'DEFECT': 1, 'BAD_PHOTO': 2}) so the two cannot drift apart.
damage_verdict_mapping = {member.name: member.value for member in DamageResolution}

description['damage_label'] = description['damage_verdict'].map(damage_verdict_mapping)

# Unknown verdicts would silently become NaN (and make the label column float); fail loudly.
_unmapped_damage = description.loc[description['damage_label'].isna(), 'damage_verdict'].unique()
assert len(_unmapped_damage) == 0, f"unmapped damage verdicts: {list(_unmapped_damage)}"

In [15]:
# Check that the fraud label mapping cell above agrees with `FraudResolution` and
# covered every verdict. (Re-defining the mapping here a second time would silently
# override any later edit made to the cell above.)
assert fraud_verdict_mapping == {m.name: m.value for m in FraudResolution}
assert description['fraud_label'].notna().all(), "some fraud verdicts were not mapped"


In [16]:
# Side name (as used in `fact_side` / `plan_side`) -> class id, derived from `CarSide`:
# {'front': 0, 'back': 1, 'left': 2, 'right': 3, 'unknown': 4}.
fact_side_mapping = {member.name.lower(): member.value for member in CarSide}

description['fact_side_label'] = description['fact_side'].map(fact_side_mapping)

# Unknown side strings would silently become NaN; fail loudly instead.
_unmapped_sides = description.loc[description['fact_side_label'].isna(), 'fact_side'].unique()
assert len(_unmapped_sides) == 0, f"unmapped fact sides: {list(_unmapped_sides)}"

In [17]:
# A photo is positive (final_label = 1) unless it is both ALL_GOOD and NO_DEFECT.
_is_clean = description['fraud_label'].eq(0) & description['damage_label'].eq(0)
description['final_label'] = (~_is_clean).astype(int)

In [18]:
confidence_threshold = 0.0
# Keep photos whose fraud and damage agreement both reach the threshold. With 0.0 this
# essentially only drops rows with a missing probability (NaN compares as False).
_is_confident = (
    description['fraud_probability'].ge(confidence_threshold)
    & description['damage_probability'].ge(confidence_threshold)
)
description = description.loc[_is_confident]

In [19]:
# (rows, columns) of the labelled table remaining after the confidence filter.
description.shape

(181805, 11)

## HPARAMS

In [22]:
# MODEL_ID = "vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k"
MODEL_ID = "efficientformer_l1.snap_dist_in1k"

# Eval-time preprocessing (resize, center crop, normalize) resolved from the
# backbone's registered pretrained config; no model needs to be instantiated.
_pretrained_cfg = timm.get_pretrained_cfg(MODEL_ID)
data_config = timm.data.resolve_data_config(vars(_pretrained_cfg))
TRANSFORMS = timm.data.create_transform(**data_config, is_training=False)

In [23]:
def get_transforms(model_id, is_training):
    """Build timm preprocessing transforms for the pretrained checkpoint `model_id`.

    The data config (input size, interpolation, mean/std, crop) is read from the
    registered pretrained config, exactly as the TRANSFORMS cell does, so no model
    has to be built and no weights downloaded. Only when no pretrained config is
    registered does it fall back to instantiating the model to read its config.

    Args:
        model_id: timm model name, optionally with a pretrained tag ("arch.tag").
        is_training: True for timm's training-time pipeline, False for eval.

    Returns:
        The transform built by ``timm.data.create_transform``.
    """
    pretrained_cfg = timm.get_pretrained_cfg(model_id)
    if pretrained_cfg is not None:
        data_config = timm.data.resolve_data_config(pretrained_cfg.__dict__)
    else:
        model = timm.create_model(model_id, pretrained=True, num_classes=0)
        data_config = timm.data.resolve_data_config({}, model=model)
    return timm.data.create_transform(**data_config, is_training=is_training)

In [24]:
# Display the eval-time preprocessing pipeline resolved from the pretrained config.
TRANSFORMS

Compose(
    Resize(size=235, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [25]:
TEST_SIZE = 0.1  # fraction of the data held out for validation in the train/val split
BATCH_SIZE = 64  # images per DataLoader batch
NUM_WORKERS = 6  # DataLoader worker processes

## Dataset

In [26]:
import imgaug.augmenters as iaa
import random
import imageio

class MultiTargetDatasetWithAugs(Dataset):
    """Training dataset: one photo plus its labels per row, with random augmentations.

    Per sample, independently:

    * p=0.5 horizontal flip; a LEFT ``fact_side_label`` becomes RIGHT and vice
      versa, FRONT / BACK / UNKNOWN are kept;
    * p=0.1 one random pixel-noise / JPEG-compression augmenter;
    * p=0.1 sharpening, skipped for photos labelled BLURRY_PHOTO;
    * p=0.1 strong blur for ALL_GOOD photos, which relabels the sample as
      BLURRY_PHOTO / BAD_PHOTO (so its ``final_label`` becomes 1).

    ``final_label`` is recomputed from the (possibly changed) labels, and
    ``last_aug_type`` names the last augmentation applied ("none" if none).

    NOTE(review): the imgaug augmenters draw their parameters from their own RNGs.
    PyTorch reseeds only Python's `random` and torch per DataLoader worker, so
    augmentation parameters may repeat across workers — consider reseeding the
    augmenters in a `worker_init_fn`. TODO confirm.

    Args:
        img_dir: directory containing the photo files.
        description: DataFrame indexed by ``filename`` with columns pass_id,
            plan_side, fraud_label, damage_label, fact_side_label.
        transform: optional callable applied to the PIL image last.
    """

    def __init__(self, img_dir, description, transform=None):
        self.img_dir = img_dir
        self.description = description.reset_index()
        self.transform = transform

        self.flip_augmenter = iaa.Fliplr(1.0)
        self.noise_augmenters = iaa.OneOf([
            iaa.AdditiveGaussianNoise(scale=(1, 15)),
            iaa.AdditiveLaplaceNoise(scale=(5, 15)),
            iaa.AdditivePoissonNoise(lam=(1, 15)),
            iaa.SaltAndPepper(p=(0.003, 0.005)),
            iaa.Salt(p=(0.003, 0.007)),
            iaa.Pepper(p=(0.003, 0.01)),
            iaa.MultiplyElementwise((0.95, 1.05)),
            iaa.Dropout(p=(0.001, 0.005)),
            iaa.JpegCompression(compression=(5, 10))
        ])
        self.sharpen_augmenter = iaa.Sharpen(alpha=(0.1, 0.5), lightness=(0.75, 2.0))

        self.blur_augmenters = iaa.OneOf([
            iaa.GaussianBlur(sigma=(2.0, 4.0)),
            iaa.MotionBlur(k=(8, 20)),
            iaa.MedianBlur(k=(7, 9))
        ])

    def __len__(self):
        return len(self.description)

    @staticmethod
    def _augment(augmenter, image):
        """Apply an imgaug augmenter to a PIL image and return a PIL image."""
        return Image.fromarray(augmenter(image=np.array(image)))

    def __getitem__(self, idx):
        row = self.description.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        # Close the file handle as soon as the RGB copy has been decoded.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        last_aug_type = "none"
        fraud_label = row['fraud_label']
        damage_label = row['damage_label']
        fact_side_label = row['fact_side_label']

        # The four `random.random()` draws below are unconditional and in a fixed
        # order (each is evaluated before its label condition).
        if random.random() < 0.5:
            image = self._augment(self.flip_augmenter, image)
            # A mirrored left side looks like a right side and vice versa.
            if fact_side_label == CarSide.LEFT.value:
                fact_side_label = CarSide.RIGHT.value
            elif fact_side_label == CarSide.RIGHT.value:
                fact_side_label = CarSide.LEFT.value
            last_aug_type = "flip"

        if random.random() < 0.1:
            image = self._augment(self.noise_augmenters, image)
            last_aug_type = "noise"

        # Sharpening a photo labelled BLURRY_PHOTO would contradict its label.
        if random.random() < 0.1 and fraud_label != fraud_verdict_mapping['BLURRY_PHOTO']:
            image = self._augment(self.sharpen_augmenter, image)
            last_aug_type = "sharpen"

        # Heavy blur on a clean photo turns it into a BLURRY_PHOTO / BAD_PHOTO sample.
        if random.random() < 0.1 and fraud_label == fraud_verdict_mapping['ALL_GOOD']:
            image = self._augment(self.blur_augmenters, image)
            fraud_label = fraud_verdict_mapping['BLURRY_PHOTO']
            damage_label = damage_verdict_mapping['BAD_PHOTO']
            last_aug_type = "blur"

        if self.transform:
            image = self.transform(image)

        # Positive unless the photo is both ALL_GOOD and NO_DEFECT.
        is_clean = (fraud_label == fraud_verdict_mapping['ALL_GOOD']
                    and damage_label == damage_verdict_mapping['NO_DEFECT'])
        final_label = int(not is_clean)

        data = {
            'image': image,
            'fraud_label': fraud_label,
            'damage_label': damage_label,
            'fact_side_label': fact_side_label,
            'final_label': final_label,
            'pass_id': row['pass_id'],
            'plan_side': row['plan_side'],
            'last_aug_type' : last_aug_type
        }
        return data

In [27]:
class MultiTargetDataset(Dataset):
    """Augmentation-free dataset: one photo and its stored labels per description row.

    Args:
        img_dir: directory containing the photo files (named by ``filename``).
        description: DataFrame indexed by ``filename`` providing fraud_label,
            damage_label, fact_side_label, final_label, pass_id and plan_side.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, img_dir, description, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Turn the `filename` index into a regular column for positional lookup.
        self.description = description.reset_index()

    def __len__(self):
        return self.description.shape[0]

    def __getitem__(self, idx):
        record = self.description.iloc[idx]
        picture = Image.open(os.path.join(self.img_dir, record['filename'])).convert('RGB')
        if self.transform:
            picture = self.transform(picture)
        sample = {'image': picture}
        for column in ('fraud_label', 'damage_label', 'fact_side_label',
                       'final_label', 'pass_id', 'plan_side'):
            sample[column] = record[column]
        return sample


### train/val split

In [28]:
from sklearn.model_selection import train_test_split

# Split by inspection, not by photo: all photos of one `pass_id` must land in the same
# split, otherwise other views of the same car leak from train into validation. This
# also matches the evaluation unit (ROC AUC per inspection). Stratify on the
# inspection-level label: an inspection is positive if any of its photos is.
_inspection_labels = description.groupby('pass_id')['final_label'].max()
_train_ids, _val_ids = train_test_split(
    _inspection_labels.index,
    test_size=TEST_SIZE,
    stratify=_inspection_labels.to_numpy(),
    random_state=42
)
train_df = description[description['pass_id'].isin(_train_ids)]
val_df = description[description['pass_id'].isin(_val_ids)]

## Consistency filter

In [None]:
# Minimum fraud_probability / damage_probability for a training photo to be kept.
conf_thresh = 0.5

In [29]:
# Keep only training photos whose fraud and damage verdicts both reached at least
# `conf_thresh` assessor agreement; `val_df` is left unfiltered.
_keep = train_df['fraud_probability'].ge(conf_thresh) & train_df['damage_probability'].ge(conf_thresh)
train_df = train_df.loc[_keep]

## Loaders

In [30]:
# Create datasets
train_dataset = MultiTargetDatasetWithAugs(img_dir=PUBLIC_DATA_FOLDER_PATH, description=train_df, transform=TRANSFORMS)
val_dataset = MultiTargetDataset(img_dir=PUBLIC_DATA_FOLDER_PATH, description=val_df, transform=TRANSFORMS)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Model

In [31]:
from torchmetrics.functional import auroc, f1_score

class CarInspectionModel(pl.LightningModule):
    """Multi-task car-inspection classifier on top of a pretrained timm backbone.

    The backbone (created with ``num_classes=0``) feeds four MLP heads:

    * ``fraud_classifier``     -> ``num_fraud_classes`` logits (cross-entropy)
    * ``damage_classifier``    -> ``num_damage_classes`` logits (cross-entropy)
    * ``fact_side_classifier`` -> ``num_fact_side_classes`` logits (cross-entropy)
    * ``final_classifier``     -> one logit for the binary ``final_label`` (BCE)

    The training objective is the weighted sum of the four losses.

    Args:
        model_id: timm model name used for the backbone.
        num_fraud_classes / num_damage_classes / num_fact_side_classes: head sizes.
        weight_fraud / weight_damage / weight_fact_side / weight_final: loss weights.
        learning_rate: AdamW learning rate.
        dropout_rate: dropout probability inside every head.
    """

    def __init__(self, model_id, num_fraud_classes, num_damage_classes, num_fact_side_classes,
                 weight_fraud=1.0, weight_damage=1.0, weight_fact_side=1.0, weight_final=1.0,
                 learning_rate=1e-3, dropout_rate=0.1):
        super().__init__()
        self.save_hyperparameters()
        # Local import keeps this cell self-contained.
        from torchmetrics.classification import BinaryAUROC

        self.learning_rate = learning_rate

        # Backbone model from timm, without its classifier.
        self.backbone = timm.create_model(
            model_id,
            pretrained=True,
            num_classes=0
        )
        num_features = self.backbone.num_features  # Number of features from the backbone

        # Parameter-free activation and dropout layers, shared by all heads.
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

        # Classification heads
        self.fraud_classifier = self._make_head(num_features, num_fraud_classes)
        self.damage_classifier = self._make_head(num_features, num_damage_classes)
        self.fact_side_classifier = self._make_head(num_features, num_fact_side_classes)
        # Final binary "needs review" head (single logit).
        self.final_classifier = self._make_head(num_features, 1)

        # Loss functions
        self.fraud_loss_fn = nn.CrossEntropyLoss()
        self.damage_loss_fn = nn.CrossEntropyLoss()
        self.fact_side_loss_fn = nn.CrossEntropyLoss()
        self.final_loss_fn = nn.BCEWithLogitsLoss()

        # Loss weights
        self.weight_fraud = weight_fraud
        self.weight_damage = weight_damage
        self.weight_fact_side = weight_fact_side
        self.weight_final = weight_final

        # Epoch-level validation ROC AUC for the final head. Logging a per-batch AUROC
        # tensor would make Lightning report the *mean of batch AUROCs*, which is not the
        # AUROC over the validation set (and is what ModelCheckpoint monitors).
        self.val_final_auroc = BinaryAUROC()

    def _make_head(self, in_features, out_features):
        """Two-layer MLP head: Linear -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_features, in_features),
            self.activation,
            self.dropout,
            nn.Linear(in_features, out_features)
        )

    def forward(self, x):
        features = self.backbone(x)

        fraud_logits = self.fraud_classifier(features)
        damage_logits = self.damage_classifier(features)
        fact_side_logits = self.fact_side_classifier(features)
        final_logits = self.final_classifier(features)

        return fraud_logits, damage_logits, fact_side_logits, final_logits

    def _shared_step(self, batch, stage):
        """Compute the weighted multi-task loss for one batch and log losses / F1 scores.

        Args:
            batch: dict with 'image', 'fraud_label', 'damage_label',
                'fact_side_label' and 'final_label'.
            stage: 'train' or 'val'; prefix of every logged key.

        Returns:
            (total_loss, final_probs, final_labels as int) for the batch.
        """
        images = batch['image']
        fraud_labels = batch['fraud_label']
        damage_labels = batch['damage_label']
        fact_side_labels = batch['fact_side_label']
        final_labels = batch['final_label']

        fraud_logits, damage_logits, fact_side_logits, final_logits = self(images)
        # squeeze(-1), not squeeze(): (B, 1) -> (B,) and stays 1-D when B == 1, so
        # BCEWithLogitsLoss does not fail on a last batch of size one.
        final_logits = final_logits.squeeze(-1)

        loss_fraud = self.fraud_loss_fn(fraud_logits, fraud_labels)
        loss_damage = self.damage_loss_fn(damage_logits, damage_labels)
        loss_fact_side = self.fact_side_loss_fn(fact_side_logits, fact_side_labels)
        loss_final = self.final_loss_fn(final_logits, final_labels.float())
        total_loss = (self.weight_fraud * loss_fraud +
                      self.weight_damage * loss_damage +
                      self.weight_fact_side * loss_fact_side +
                      self.weight_final * loss_final)

        self.log(f'{stage}_loss', total_loss, prog_bar=True)
        self.log(f'{stage}_loss_fraud', loss_fraud)
        self.log(f'{stage}_loss_damage', loss_damage)
        self.log(f'{stage}_loss_fact_side', loss_fact_side)
        self.log(f'{stage}_loss_final', loss_final)

        for head, logits, labels in (
            ('fraud', fraud_logits, fraud_labels),
            ('damage', damage_logits, damage_labels),
            ('fact_side', fact_side_logits, fact_side_labels),
        ):
            preds = torch.argmax(logits, dim=1)
            f1 = f1_score(preds, labels, num_classes=logits.size(1), task="multiclass", average='macro')
            self.log(f'{stage}_f1_{head}', f1, prog_bar=True)

        return total_loss, torch.sigmoid(final_logits), final_labels.int()

    def training_step(self, batch, batch_idx):
        total_loss, final_probs, final_labels = self._shared_step(batch, 'train')
        # Per-batch AUROC: only a noisy progress indicator during training.
        auc_final = auroc(final_probs, final_labels, task="binary")
        self.log('train_final_auc', auc_final, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        total_loss, final_probs, final_labels = self._shared_step(batch, 'val')
        self.val_final_auroc.update(final_probs, final_labels)
        # Logging the metric object lets Lightning compute() and reset() it once per
        # validation run, giving the AUROC over all validation photos.
        self.log('val_final_auc', self.val_final_auroc, prog_bar=True)
        return total_loss

    def configure_optimizers(self):
        """AdamW with weight decay 1e-4 on weight matrices, none on biases / 1-D params."""
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            # Exclude biases and every 1-D tensor (norm scales, layer-scale gammas). The
            # old 'LayerNorm.weight' substring is a Hugging Face naming convention that
            # typically does not occur in timm parameter names, so norm weights were decayed.
            if param.ndim <= 1 or 'bias' in name:
                no_decay.append(param)
            else:
                decay.append(param)
        param_groups = [
            {'params': decay, 'weight_decay': 1e-4},
            {'params': no_decay, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=self.learning_rate)

        # NOTE(review): Lightning steps this scheduler once per epoch by default; with
        # T_0=50 and only a few training epochs the LR barely decays — confirm intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=50, T_mult=1, eta_min=1e-5,
        )
        return [optimizer], [scheduler]


## Trainer

In [32]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

In [33]:
LR = 1e-4  # AdamW learning rate passed to CarInspectionModel

In [34]:
# Relative weights of the four head losses; the binary "final" head counts 3x.
_loss_weights = dict(weight_fraud=1.0, weight_damage=1.0, weight_fact_side=1.0, weight_final=3.0)

model = CarInspectionModel(
    model_id=MODEL_ID,
    # One output per class of each label mapping.
    num_fraud_classes=len(fraud_verdict_mapping),
    num_damage_classes=len(damage_verdict_mapping),
    num_fact_side_classes=len(fact_side_mapping),
    learning_rate=LR,
    **_loss_weights,
)


In [35]:
EXP_NAME = MODEL_ID  # experiment name used by the CSV logger

In [None]:
# CSV logs go to logs_effformer/<EXP_NAME>/test.
logger = CSVLogger("logs_effformer", name=EXP_NAME, version="test")
# Log the plain model id string (the previous `{MODEL_ID}` wrapped it in a set).
logger.log_hyperparams({"model_id": MODEL_ID})

In [37]:
# Keep the 3 checkpoints with the highest logged `val_final_auc`.
checkpoint_callback = ModelCheckpoint(
    monitor='val_final_auc',  # Monitor the final ROC AUC
    mode='max',
    save_top_k=3,
)

In [36]:
# Up to 5 epochs. An integer `val_check_interval` means validation runs every 100
# training batches; metrics are written every 100 steps.
trainer = Trainer(
    logger=logger,
    max_epochs=5,
    callbacks=[checkpoint_callback],
    log_every_n_steps=100,
    val_check_interval=100
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_loader, val_loader)

# After fit, `model` holds the weights of the *last* training step, while
# `checkpoint_callback` kept the checkpoints with the best `val_final_auc`.
# Continue (export / inference below) with the best checkpoint when one was saved.
if checkpoint_callback.best_model_path:
    model = CarInspectionModel.load_from_checkpoint(checkpoint_callback.best_model_path)

In [None]:
# Switch to eval mode (disables the heads' dropout). The scripting cell below calls eval() again.
model.eval()

In [None]:
model.eval()
# Compile the model to TorchScript so it can be saved and run without this Python class.
script = torch.jit.script(model)

In [110]:
# Serialize the scripted model (code + weights) to model.pt.
torch.jit.save(script, 'model.pt')

In [111]:
# Reload the exported TorchScript model from disk.
script = torch.jit.load('model.pt')

**NB**: The private data description contains only the **filename**, **pass_id**, and **plan_side** columns.

In [98]:
from sklearn.metrics import roc_auc_score

from utils import get_predictions

In [99]:
def get_predictions(model, dataloader, device, fact_side_mapping=fact_side_mapping, side_confidence_threshold=0.7):
    """Run `model` over `dataloader` and return one prediction row per photo.

    The model must return ``(fraud_logits, damage_logits, fact_side_logits, final_logits)``
    like ``CarInspectionModel.forward``; only the last two are used. A photo's
    ``prediction`` is ``sigmoid(final_logit)``, overridden to ``1.0`` when the
    predicted actual side differs from ``plan_side`` with softmax confidence
    ``>= side_confidence_threshold`` (a confidently wrong side is treated as suspicious).

    Note: this definition shadows the ``get_predictions`` imported from ``utils``.

    Args:
        model: eager or TorchScript model; switched to eval mode and moved to `device`.
        dataloader: yields dicts with 'image', 'pass_id' and 'plan_side'.
        device: torch device used for inference.
        fact_side_mapping: side name -> class id; inverted to name the predicted side.
        side_confidence_threshold: minimum side confidence for the mismatch override.

    Returns:
        DataFrame with columns pass_id, prediction, plan_side, predicted_fact_side and
        fact_side_confidence, one row per photo in dataloader order.
    """
    model.eval()
    model.to(device)

    final_predictions = []
    pass_ids = []
    plan_sides = []
    predicted_fact_sides = []
    fact_side_confidences = []

    with torch.inference_mode():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)

            _, _, fact_side_logits, final_logits = model(images)

            # squeeze(-1), not squeeze(): keeps a 1-D result for a batch of one,
            # otherwise `.numpy()` gives a 0-d array that `extend` cannot iterate.
            final_probs = torch.sigmoid(final_logits.squeeze(-1))
            max_probs, fact_side_preds = torch.softmax(fact_side_logits, dim=1).max(dim=1)

            final_predictions.extend(final_probs.cpu().numpy())
            fact_side_confidences.extend(max_probs.cpu().numpy())
            predicted_fact_sides.extend(fact_side_preds.cpu().numpy())
            pass_ids.extend(batch['pass_id'])
            plan_sides.extend(batch['plan_side'])

    inverse_fact_side_mapping = {v: k for k, v in fact_side_mapping.items()}

    predictions_df = pd.DataFrame({
        'pass_id': pass_ids,
        'prediction': final_predictions,
        'plan_side': plan_sides,
        'predicted_fact_side': [inverse_fact_side_mapping[label] for label in predicted_fact_sides],
        'fact_side_confidence': fact_side_confidences,
    })

    # Vectorised override; unlike a row-wise `apply`, also works on an empty frame.
    confident_side_mismatch = (
        (predictions_df['plan_side'] != predictions_df['predicted_fact_side'])
        & (predictions_df['fact_side_confidence'] >= side_confidence_threshold)
    )
    predictions_df['prediction'] = predictions_df['prediction'].mask(confident_side_mismatch, 1.0)

    return predictions_df

In [None]:
# Per-photo predictions on the validation split, including the side-mismatch override.
predictions_df = get_predictions(model, val_loader, device, fact_side_mapping=fact_side_mapping)