# I. Import libraries

In [None]:
import os, glob
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from tqdm import tqdm
import pandas as pd
import numpy as np
import re

# II. Configuration

In [None]:
class config:
    """Static settings for the inference run: data location, loader, model and checkpoints."""

    # ---- Data ----
    # Read from the Kaggle dataset mount when it exists, otherwise from the working directory.
    input_dir = (
        '../input/ranzcr-clip-catheter-line-classification'
        if os.path.exists('/kaggle/input')
        else '.'
    )
    img_col = 'StudyInstanceUID'            # column holding the image id (file stem)
    # Label columns; their count defines the classifier output size below.
    label_cols = [
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    ]
    batch_size = 32
    image_size = 512
    num_workers = 2
    pin_memory = True
    seed = 42

    # ---- Model ----
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    in_chans = 3
    num_classes = len(label_cols)
    drop_path_rate = 0.1
    pretrained = False                      # True: load pretrained model, False: train from scratch
    checkpoint_path = ''                    # Path to model's pretrained weights
    # Model name -> one checkpoint file per fold (folds 0..4).
    # An alternate set using the '-best-full.pth' suffix was previously listed here, commented out.
    checkpoint_dirs = {
        'convnext_tiny': [
            f'student_checkpoint/fold={fold}-best-post.pth' for fold in range(5)
        ],
    }
    debug = True                            # when True, only the first 100 test rows are used

# III. Data

In [None]:
class RANZCRDataset(Dataset):
    """Image/label dataset with optional annotation overlays drawn onto the image.

    Args:
        image_dir: Directory holding the image files.
        df: DataFrame with one row per sample. ``img_col`` gives the file stem
            (read as ``<image_dir>/<id>.jpg``) and ``label_cols`` the targets.
            When None, every file in ``image_dir`` is served (in sorted order)
            with an all-zero label vector.
        img_col: Name of the image-id column in ``df`` and ``df_annot``.
        label_cols: List of label column names.
        df_annot: Optional DataFrame of annotations; rows need a ``data`` column
            (string literal of a point list) and a ``label`` column.
        color_map: Mapping from the parts of a label (split on ' - ') to colours.
        transform: Optional callable invoked as ``transform(image=...)`` (plus
            ``annot_image=...`` in 'both' mode); its result is indexed by key.
        prev_transform: Optional callable applied before annotations are drawn.
        return_img: 'image' (plain image), 'annot_image' (annotations drawn on
            the returned image) or 'both' (plain image and annotated copy).
    """

    def __init__(self, image_dir, df, img_col, label_cols, df_annot=None,
                 color_map=None, transform=None, prev_transform=None, return_img='image'):
        super().__init__()
        self.image_dir = image_dir
        self.df = df
        self.df_annot = df_annot
        self.color_map = color_map
        self.img_col = img_col
        self.label_cols = label_cols
        self.transform = transform
        self.prev_transform = prev_transform
        self.return_img = return_img
        # Without a dataframe, list the directory once; sorting gives a stable
        # index -> file mapping (os.listdir order is arbitrary).
        self.image_files = sorted(os.listdir(image_dir)) if df is None else None

    def __len__(self):
        """Return the number of samples (dataframe rows, or files when ``df`` is None)."""
        if self.df is not None:
            return len(self.df)
        return len(self.image_files)

    @staticmethod
    def _read_rgb(path):
        """Read an image from ``path`` and return it as a contiguous RGB uint8 array."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def draw_annotation(self, image, df_annot, img_col, img_id, color_map):
        """Draw every annotation of ``img_id`` onto ``image`` in place and return it.

        Each annotation is drawn as an open polyline coloured by the first label
        part. When a second label part exists and is not 'Incompletely Imaged',
        both polyline end points are marked with circles coloured by that part.
        """
        import ast  # local import: only needed when annotations are drawn

        rows = df_annot[df_annot[img_col] == img_id]
        for _, annot_row in rows.iterrows():
            # literal_eval parses the stored point list without executing
            # arbitrary code (this was a bare eval on CSV content).
            points = ast.literal_eval(annot_row['data'])
            if not points:
                continue
            label = annot_row["label"].split(' - ')
            cv2.polylines(image, np.int32([points]), isClosed=False,
                          color=color_map[label[0]], thickness=15, lineType=cv2.LINE_AA)
            if len(label) > 1 and label[1] != 'Incompletely Imaged':
                for end_point in (points[0], points[-1]):
                    cv2.circle(image, tuple(end_point), radius=15,
                               color=color_map[label[1]], thickness=25)
        return image

    def __getitem__(self, index):
        """Return ``(image, label)``, or ``(image, annot_image, label)`` in 'both' mode."""
        if self.df is not None:
            row = self.df.iloc[index]
            img_id = row[self.img_col]
            image = self._read_rgb(f'{self.image_dir}/{img_id}.jpg')
            label = row[self.label_cols].values.astype('float')
        else:
            file_name = self.image_files[index]
            # The file stem stands in for the image id when looking up annotations
            img_id = os.path.splitext(file_name)[0]
            image = self._read_rgb(f'{self.image_dir}/{file_name}')
            label = np.zeros(len(self.label_cols))

        if self.prev_transform:
            image = self.prev_transform(image=image)['image']

        # Draw annotation on image if available
        annot_image = None
        if self.df_annot is not None and self.color_map:
            if self.return_img == 'both':
                # Draw on a copy so the plain image stays untouched
                annot_image = self.draw_annotation(image.copy(), self.df_annot, self.img_col,
                                                   img_id, self.color_map)
            elif self.return_img == 'annot_image':
                image = self.draw_annotation(image, self.df_annot, self.img_col,
                                             img_id, self.color_map)

        if self.transform:
            if annot_image is not None:
                transformed = self.transform(image=image, annot_image=annot_image)
                image, annot_image = transformed['image'], transformed['annot_image']
            else:
                image = self.transform(image=image)['image']
        return (image, label) if annot_image is None else (image, annot_image, label)

In [None]:
def build_transform(image_size=None, adjust_color=True, is_train=True, include_top=True, additional_targets=None):
    """Assemble an albumentations pipeline from optional stages.

    Stages are appended in this order, each only when enabled:
      1. square resize to ``image_size`` (when ``image_size`` is truthy),
      2. brightness/contrast and hue/saturation jitter (``adjust_color``),
      3. flip, shift-scale-rotate and coarse dropout (``is_train``),
      4. normalisation with mean/std 0.5 followed by tensor conversion (``include_top``).

    ``additional_targets`` is forwarded unchanged to ``A.Compose``.
    """
    stages = [A.Resize(image_size, image_size)] if image_size else []

    # Dropout holes are sized relative to the image side; 40 is the fallback
    # side length when no resize was requested.
    side = image_size or 40
    max_hole = int(side * 0.2)

    if adjust_color:
        stages += [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=0, p=0.5),
        ]

    if is_train:
        stages += [
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.3, rotate_limit=30, border_mode=0, p=0.5),
            A.CoarseDropout(max_height=max_hole, max_width=max_hole,
                            min_holes=1, max_holes=4, p=0.5),
        ]

    if include_top:
        stages += [
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ]

    return A.Compose(stages, additional_targets=additional_targets)

In [None]:
# Load data
test_df = pd.read_csv(f'{config.input_dir}/sample_submission.csv')

if config.debug:
    # Debug runs use only the first 100 rows. Copy them so that the prediction
    # columns assigned to test_df later go to an independent frame rather than
    # a slice of the original (which triggers pandas' SettingWithCopyWarning).
    test_df = test_df.iloc[:100].copy()

# Inference pipeline: resize + normalise + tensor conversion, no random augmentation
test_transform = build_transform(config.image_size, adjust_color=False, is_train=False, include_top=True)
test_trainsform = test_transform  # alias kept for code still using the original misspelled name
test_dataset = RANZCRDataset(image_dir=f"{config.input_dir}/test", df=test_df,
                             img_col=config.img_col, label_cols=config.label_cols,
                             transform=test_transform)
# shuffle=False keeps predictions aligned with the row order of test_df
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         num_workers=config.num_workers, pin_memory=config.pin_memory,
                         shuffle=False)

# IV. Model

In [None]:
class RANZCRClassifier(nn.Module):
    """timm backbone followed by global average pooling and a linear head.

    Args:
        model_name: timm model name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        checkpoint_path: Forwarded to ``timm.create_model``.
        in_chans: Number of input channels, forwarded to ``timm.create_model``.
        num_classes: Output size of the linear head.
        drop_path_rate: Forwarded to ``timm.create_model``.
        return_features: When True, ``forward`` returns ``(features, output)``;
            when False it returns only ``output``.
    """

    def __init__(self, model_name, pretrained=False, checkpoint_path='',
                 in_chans=3, num_classes=1000, drop_path_rate=0.0, return_features=True):
        super().__init__()
        # in_chans was previously accepted but never handed to the backbone
        self.model = timm.create_model(model_name, pretrained=pretrained,
                                       checkpoint_path=checkpoint_path,
                                       in_chans=in_chans,
                                       drop_path_rate=drop_path_rate)
        # Read the feature width before the backbone's own classifier is removed
        n_features = self.model.get_classifier().in_features
        # Drop both the classifier and the pooling so the backbone yields unpooled features
        self.model.reset_classifier(num_classes=0, global_pool='')
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, num_classes)
        self.return_features = return_features

    def forward(self, x):
        """Run the backbone and head on batch ``x``.

        Returns:
            ``(features, output)`` when ``return_features`` is True, otherwise
            ``output`` alone. ``features`` is the backbone output before pooling.
        """
        bs = x.size(0)
        features = self.model(x)
        pooled_features = self.pooling(features).view(bs, -1)
        output = self.fc(pooled_features)
        # Parentheses matter: without them the conditional only applied to the
        # second tuple element and a tuple was returned in both cases.
        return (features, output) if self.return_features else output

In [None]:
def load_checkpoint(checkpoint_path=None, fold=None, checkpoint_dir=None, postfix=''):
    """Load a checkpoint from an explicit path/URL, or the latest one saved for a fold.

    Args:
        checkpoint_path: File path or http(s) URL of the checkpoint. Takes
            precedence over the directory lookup when given.
        fold: Fold number to look up inside ``checkpoint_dir``.
        checkpoint_dir: Directory searched for files named
            ``fold=<n>-epoch=<m><postfix>.pth``.
        postfix: Text expected between the epoch number and ``.pth``.

    Returns:
        The loaded checkpoint (mapped to CPU), or None when no path was given
        and no file for ``fold`` exists in ``checkpoint_dir``.
    """
    checkpoint = None
    if checkpoint_path:
        # Load checkpoint given by the path
        if checkpoint_path.startswith(('http://', 'https://')):
            checkpoint = torch.hub.load_state_dict_from_url(checkpoint_path,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
        print(f"Loaded checkpoint from {checkpoint_path}")
    elif checkpoint_dir and fold is not None:
        # Load checkpoint from the latest one
        fold_pattern = re.compile(r'fold=(\d+)')
        epoch_pattern = re.compile(r'epoch=(\d+)')
        epoch_by_file = {}
        for path in glob.glob(f"{checkpoint_dir}/fold=*-epoch=*{postfix}.pth"):
            # Parse the file name only, so 'fold='/'epoch=' in a directory name cannot interfere
            file_name = os.path.basename(path)
            fold_match = fold_pattern.search(file_name)
            epoch_match = epoch_pattern.search(file_name)
            # Skip files that match the glob but carry no numeric fold/epoch
            if fold_match is None or epoch_match is None:
                continue
            if int(fold_match.group(1)) == fold:
                epoch_by_file[path] = int(epoch_match.group(1))
        if epoch_by_file:
            checkpoint_file = max(epoch_by_file, key=epoch_by_file.get)
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            print(f"Loaded checkpoint from {checkpoint_file}")
    return checkpoint

In [None]:
# Load model
models = []
for model_name, weight_paths in config.checkpoint_dirs.items():
    for weight_path in weight_paths:
        # Build one network per listed checkpoint file and place it on the target device
        model = RANZCRClassifier(
            model_name,
            pretrained=config.pretrained,
            checkpoint_path=config.checkpoint_path,
            in_chans=config.in_chans,
            num_classes=config.num_classes,
            drop_path_rate=config.drop_path_rate,
        ).to(config.device)

        # Restore the parameters stored under the checkpoint's 'model' key;
        # report the stored AUC when the checkpoint carries one.
        checkpoint = load_checkpoint(weight_path)
        if 'auc' in checkpoint:
            print(f"AUC: {checkpoint['auc']}")
        model.load_state_dict(checkpoint['model'])
        models.append(model)

# V. Inference

In [None]:
def predict(model, loader, config):
    """Run ``model`` over ``loader`` and return sigmoid probabilities as a NumPy array.

    Args:
        model: Module returning either the logits tensor or a tuple/list whose
            last element is the logits (e.g. ``(features, output)``).
        loader: Iterable of ``(data, targets)`` batches; targets are ignored.
        config: Object exposing ``device``, the device batches are moved to.

    Returns:
        Array of per-sample probabilities, concatenated over all batches in
        loader order.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        for data, _ in tqdm(loader):
            outputs = model(data.to(config.device))
            # Take the logits explicitly: unpacking a bare tensor with
            # `_, outputs = ...` would silently split it along the batch axis.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[-1]
            preds.append(outputs)
    return torch.cat(preds).sigmoid().cpu().numpy()

In [None]:
# Predict
# One probability array per loaded model, averaged element-wise into the ensemble prediction
preds = np.mean([predict(model, test_loader, config) for model in models], axis=0)
# NOTE(review): probabilities are binarised at 0.5 before being written out. The
# checkpoints report an AUC, so if the submission is scored by AUC, confirm that
# hard 0/1 labels (rather than probabilities) are really what is wanted here.
preds = (preds > 0.5).astype('int')
test_df[config.label_cols] = preds
test_df.to_csv('submission.csv', index=False)
test_df