In [1]:
import os
import sys
sys.path.append('../input/pytorchimagemodels')

import os
import typing as tp
import math
import time
import random
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
import albumentations
from albumentations import *
from albumentations.pytorch import ToTensorV2


import timm

from torch.cuda.amp import autocast, GradScaler

import warnings 
warnings.filterwarnings('ignore')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # prefer GPU, fall back to CPU
from collections import OrderedDict

In [2]:
# Inference configuration.
IMAGE_SIZE, BATCH_SIZE = 640, 32  # resize side length (pixels) / images per DataLoader batch
TEST_PATH = '../input/ranzcr-clip-catheter-line-classification/test'  # holds <StudyInstanceUID>.jpg files
DEBUG = False  # True -> score only a 1% sample of the test set

In [3]:
# The sample submission defines the rows (StudyInstanceUID) and label columns to predict.
test = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/sample_submission.csv')
if DEBUG:
    # Seeded so repeated debug runs score the same 1% subset and are comparable.
    test = test.sample(frac=0.01, random_state=42)

In [4]:
class TestDataset(Dataset):
    """Dataset yielding test images in the row order of ``df``.

    Args:
        df: frame with a ``StudyInstanceUID`` column; each UID names a ``<UID>.jpg`` file.
        transform: optional callable invoked as ``transform(image=...)`` that returns a
            dict with an ``'image'`` entry (albumentations style).
        image_dir: directory holding the JPEGs. ``None`` (default) means the module-level
            ``TEST_PATH``, looked up each time an item is loaded.
    """

    def __init__(self, df, transform=None, image_dir=None):
        self.df = df
        self.file_names = df['StudyInstanceUID'].values
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, colour-convert and (optionally) transform image ``idx``.

        Raises:
            FileNotFoundError: if the JPEG is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        image_dir = TEST_PATH if self.image_dir is None else self.image_dir
        file_path = f'{image_dir}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread reports failure by returning None instead of raising; fail here with
        # the offending path rather than with an opaque error inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return image
    
def get_transforms():
    """Return the deterministic test-time preprocessing pipeline.

    Steps: resize to IMAGE_SIZE x IMAGE_SIZE, apply ``Normalize()`` with its default
    arguments, then convert the array to a torch tensor via ``ToTensorV2``.
    """
    steps = [
        Resize(IMAGE_SIZE, IMAGE_SIZE),
        Normalize(),
        ToTensorV2(),
    ]
    return Compose(steps)
    
def get_activation(activ_name: str = "relu"):
    """Return a new activation module selected by name.

    Args:
        activ_name: one of ``"relu"`` (in-place ReLU), ``"tanh"``, ``"sigmoid"``,
            ``"identity"``.

    Returns:
        A freshly constructed ``nn.Module``; only the requested module is built.

    Raises:
        NotImplementedError: if ``activ_name`` is not one of the supported names.
    """
    # Factories rather than instances, so each call builds just the module it returns.
    act_factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "identity": nn.Identity,
    }
    try:
        factory = act_factories[activ_name]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported activation {activ_name!r}; expected one of {sorted(act_factories)}"
        ) from None
    return factory()


class Conv2dBNActiv(nn.Module):
    """Conv2d -> optional BatchNorm2d -> activation, held in a single ``nn.Sequential``.

    The sublayers live at ``self.layers[0]`` (conv), ``self.layers[1]`` (BN, only when
    ``use_bn``) and the last index (activation). State-dict keys loaded from checkpoints
    (``layers.0.*``, ``layers.1.*``) depend on this layout.
    """

    def __init__(
            self, in_channels: int, out_channels: int,
            kernel_size: int, stride: int = 1, padding: int = 0,
            bias: bool = False, use_bn: bool = True, activ: str = "relu"
    ):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        norm = [nn.BatchNorm2d(out_channels)] if use_bn else []
        self.layers = nn.Sequential(conv, *norm, get_activation(activ))

    def forward(self, x):
        """Apply conv, (BN,) and activation in order."""
        return self.layers(x)


class SSEBlock(nn.Module):
    """Channel `S`queeze and `s`patial `E`xcitation block.

    A bias-free 1x1 convolution collapses all channels into a single map; a sigmoid
    turns it into per-location gates that rescale the input, broadcast over channels.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.channel_squeeze = nn.Conv2d(
            in_channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (bs, ch, h, w) -> gate: (bs, 1, h, w) -> result: (bs, ch, h, w)
        gate = self.sigmoid(self.channel_squeeze(x))
        return gate * x


class SpatialAttentionBlock(nn.Module):
    """Spatial attention for (C, H, W) feature maps.

    Builds 3x3 ``Conv2dBNActiv`` layers ``conv1 .. convN`` following the channel sequence
    ``[in_channels] + out_channels_list``. Every layer uses ReLU except the last, which
    must output a single channel and uses sigmoid. That one-channel map then multiplies
    the input (broadcast over channels).
    """

    def __init__(
            self, in_channels: int,
            out_channels_list: tp.List[int],
    ):
        super().__init__()
        self.n_layers = len(out_channels_list)
        channels_list = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels_list[-1] == 1

        last_idx = self.n_layers - 1
        for idx, (c_in, c_out) in enumerate(zip(channels_list[:-1], channels_list[1:])):
            activ = "sigmoid" if idx == last_idx else "relu"
            # The names conv1..convN are the state-dict keys; keep them stable.
            setattr(self, f"conv{idx + 1}", Conv2dBNActiv(c_in, c_out, 3, 1, 1, activ=activ))

    def forward(self, x):
        attn = x
        for idx in range(1, self.n_layers + 1):
            attn = getattr(self, f"conv{idx}")(attn)
        return attn * x
    
class MultiHeadModel(nn.Module):
    """Shared timm backbone feeding one spatial-attention + MLP head per output group.

    Head ``i`` produces ``out_dims_head[i]`` logits and ``forward`` concatenates them
    along dim 1, so the output width is ``sum(out_dims_head)``. The attribute names
    (``backbone``, ``head_0``, ``head_1``, ...) are the state-dict keys checkpoints are
    loaded against; do not rename them.
    """

    def __init__(
            self, base_name: str = 'resnext50_32x4d',
            out_dims_head: tp.Sequence[int] = (3, 4, 3, 1), pretrained=False):
        """
        Args:
            base_name: timm model name for the shared backbone.
            out_dims_head: number of output logits for each head (any sequence of ints).
            pretrained: forwarded to ``timm.create_model``.
        """
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super(MultiHeadModel, self).__init__()
        self.base_name = base_name
        self.n_heads = len(out_dims_head)

        base_model = timm.create_model(base_name, pretrained=pretrained)
        in_features = base_model.num_features

        # Remove global pooling and the classifier: the heads consume the raw feature map.
        base_model.reset_classifier(0, '')
        self.backbone = base_model

        for i, out_dim in enumerate(out_dims_head):
            head = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim))
            setattr(self, f"head_{i}", head)

    def forward(self, x):
        """Run the backbone once and concatenate every head's logits along dim 1."""
        features = self.backbone(x)
        logits = [getattr(self, f"head_{i}")(features) for i in range(self.n_heads)]
        return torch.cat(logits, dim=1)

In [5]:
def inference(models, test_loader, device):
    """Score ``test_loader`` with every model using horizontal-flip TTA and average.

    For each batch, every model is run on the images as-is and flipped along the last
    axis. The two sigmoid outputs are averaged, then the per-model results are averaged
    across models.

    Args:
        models: sequence of ``nn.Module``s producing logits of equal width. It is
            iterated once per batch, so it must be re-iterable (e.g. a list).
        test_loader: sized iterable of image-tensor batches (e.g. a ``DataLoader``).
        device: device the models and the batches are moved to.

    Returns:
        ``np.ndarray`` of shape (n_samples, n_outputs), rows in loader order.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # Prepare each model once (not on every batch), and honour ``device`` instead of
    # forcing CUDA so the function also works on CPU-only machines.
    for model in models:
        model.to(device)
        model.eval()

    probs = []
    for images in tqdm(test_loader, total=len(test_loader)):
        images = images.to(device)
        batch_preds = []
        with torch.no_grad():
            for model in models:
                plain = model(images).sigmoid()
                flipped = model(images.flip(-1)).sigmoid()
                batch_preds.append(((plain + flipped) / 2).cpu().numpy())
        probs.append(np.mean(batch_preds, axis=0))

    if not probs:
        raise ValueError("test_loader yielded no batches")
    return np.concatenate(probs)

In [6]:
def load_multi_head_model(model_name, checkpoint, out_dims_head=(3, 4, 3, 1)):
    """Build a ``MultiHeadModel`` and load its weights from a saved state dict.

    Args:
        model_name: timm backbone name forwarded to ``MultiHeadModel``.
        checkpoint: path to a ``.pth``/``.pt`` file holding the model's state dict.
        out_dims_head: per-head output widths; must match the checkpoint.

    Returns:
        The model, on CPU, with the checkpoint weights loaded.
    """
    # Load tensors onto the CPU: the freshly built model lives on CPU and
    # load_state_dict copies into it anyway, and GPU-saved files then also load on
    # CPU-only machines. The model is moved to the target device at inference time.
    # NOTE: torch.load unpickles the file -- only use with trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location='cpu')

    # Keys saved from a wrapped model (e.g. nn.DataParallel) carry a 'module.' prefix.
    # Strip it only where present so unwrapped checkpoints are not corrupted.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

    model = MultiHeadModel(model_name, list(out_dims_head), False)
    model.load_state_dict(new_state_dict)
    return model

# Test-set loader. shuffle=False keeps predictions in the same row order as `test`,
# which later cells rely on when assigning predictions back positionally.
test_dataset = TestDataset(test, transform=get_transforms())
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [7]:
# Label columns: positions 1..11, i.e. everything after StudyInstanceUID.
target_cols = list(test.columns[1:12])

# One independent copy of the submission frame per model family, plus `res` for the blend.
(resnet200d_public, regnety_080, resnest101e, dm_nfnet_f0,
 seresnet152d, resnet200d, tf_b5, res) = (test.copy() for _ in range(8))

In [8]:
# Multi-head resnet200d, 5 folds (folds 0-1 and 2-4 live in different input datasets).
resnet200d_mh_checkpoints = [
    '../input/checkpoint-test/resnet200d_b32_lr0.0004_fold0_mh_best_AUC.pth',
    '../input/checkpoint-test/resnet200d_b32_lr0.0004_fold1_mh_best_AUC.pth',
    '../input/razerck/resnet200d_b32_lr0.0004_cx_mh/resnet200d_b32_lr0.0004_fold2_mh_best_AUC.pth',
    '../input/razerck/resnet200d_b32_lr0.0004_cx_mh/resnet200d_b32_lr0.0004_fold3_mh_best_AUC.pth',
    '../input/razerck/resnet200d_b32_lr0.0004_cx_mh/resnet200d_b32_lr0.0004_fold4_mh_best_AUC.pth',
]
models = [load_multi_head_model('resnet200d', ckpt) for ckpt in resnet200d_mh_checkpoints]
predictions = inference(models, test_loader, device)
resnet200d[target_cols] = predictions

  0%|          | 0/112 [00:00<?, ?it/s]

In [9]:
class ResNet200D(nn.Module):
    """timm backbone with its own pooling/classifier disabled, plus average pooling and
    an 11-way linear head.

    Submodule names (``model``, ``pooling``, ``fc``) must match the checkpoint keys
    loaded by ``LoadResNet200D``.
    """

    def __init__(self, model_name='resnet200d'):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=False)
        n_features = backbone.fc.in_features
        # Replace timm's pooling and classifier with no-ops so the backbone returns its feature map.
        backbone.global_pool = nn.Identity()
        backbone.fc = nn.Identity()
        self.model = backbone
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, 11)

    def forward(self, x):
        feature_map = self.model(x)
        pooled = self.pooling(feature_map).view(x.size(0), -1)
        return self.fc(pooled)

def LoadResNet200D(path):
    """Build a ``ResNet200D`` and load the state dict stored at ``path``.

    Returns the model on CPU; it is moved to the target device at inference time.
    """
    model = ResNet200D()
    # Map tensors to CPU: the model is built on CPU and load_state_dict copies into its
    # parameters anyway, whereas hard-coding 'cuda:0' fails on machines without a GPU.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model

# Public resnet200d baseline, 5 folds.
resnet200d_public_checkpoints = [
    '../input/resnet200d-baseline-benchmark-public/resnet200d_fold0_cv953.pth',
    '../input/resnet200d-baseline-benchmark-public/resnet200d_fold1_cv955.pth',
    '../input/resnet200d-baseline-benchmark-public/resnet200d_fold2_cv955.pth',
    '../input/resnet200d-baseline-benchmark-public/resnet200d_fold3_cv957.pth',
    '../input/resnet200d-baseline-benchmark-public/resnet200d_fold4_cv954.pth',
]
models = list(map(LoadResNet200D, resnet200d_public_checkpoints))

predictions = inference(models, test_loader, device)
resnet200d_public[target_cols] = predictions

  0%|          | 0/112 [00:00<?, ?it/s]

In [10]:
# Multi-head seresnet152d, 5 folds.
seresnet152d_checkpoints = (
    '../input/razerck/seresnet152d_b32_lr0.0004_cx_mh/seresnet152d_b32_lr0.0004_fold0_mh_best_AUC.pth',
    '../input/razerck/seresnet152d_b32_lr0.0004_cx_mh/seresnet152d_b32_lr0.0004_fold1_mh_best_AUC.pth',
    '../input/razerck/seresnet152d_b32_lr0.0004_cx_mh/seresnet152d_b32_lr0.0004_fold2_mh_best_AUC.pth',
    '../input/razerck/seresnet152d_b32_lr0.0004_cx_mh/seresnet152d_b32_lr0.0004_fold3_mh_best_AUC.pth',
    '../input/razerck/seresnet152d_b32_lr0.0004_cx_mh/seresnet152d_b32_lr0.0004_fold4_mh_best_AUC.pth',
)
models = []
for ckpt_path in seresnet152d_checkpoints:
    models.append(load_multi_head_model('seresnet152d', ckpt_path))

predictions = inference(models, test_loader, device)
seresnet152d[target_cols] = predictions

  0%|          | 0/112 [00:00<?, ?it/s]

In [11]:
# Multi-head dm_nfnet_f0, 5 folds.
nfnet_checkpoints = [
    '../input/razerck/dm_nfnet_f0_b32_lr0.0003/dm_nfnet_f0_b32_lr0.0003_fold0_mh_best_AUC.pth',
    '../input/razerck/dm_nfnet_f0_b32_lr0.0003/dm_nfnet_f0_b32_lr0.0003_fold1_mh_best_AUC.pth',
    '../input/razerck/dm_nfnet_f0_b32_lr0.0003/dm_nfnet_f0_b32_lr0.0003_fold2_mh_best_AUC.pth',
    '../input/razerck/dm_nfnet_f0_b32_lr0.0003/dm_nfnet_f0_b32_lr0.0003_fold3_mh_best_AUC.pth',
    '../input/razerck/dm_nfnet_f0_b32_lr0.0003/dm_nfnet_f0_b32_lr0.0003_fold4_mh_best_AUC.pth',
]
models = [load_multi_head_model('dm_nfnet_f0', path) for path in nfnet_checkpoints]

predictions = inference(models, test_loader, device)
dm_nfnet_f0[target_cols] = predictions

  0%|          | 0/112 [00:00<?, ?it/s]

In [12]:
# Multi-head tf_efficientnet_b5_ns, 5 folds.
effnet_b5_checkpoints = [
    '../input/razerck/tf_efficientnet_b5_ns_b32_lr0.0004_cx_mh/tf_efficientnet_b5_ns_b32_lr0.0004_fold0_mh_best_AUC.pth',
    '../input/razerck/tf_efficientnet_b5_ns_b32_lr0.0004_cx_mh/tf_efficientnet_b5_ns_b32_lr0.0004_fold1_mh_best_AUC.pth',
    '../input/razerck/tf_efficientnet_b5_ns_b32_lr0.0004_cx_mh/tf_efficientnet_b5_ns_b32_lr0.0004_fold2_mh_best_AUC.pth',
    '../input/razerck/tf_efficientnet_b5_ns_b32_lr0.0004_cx_mh/tf_efficientnet_b5_ns_b32_lr0.0004_fold3_mh_best_AUC.pth',
    '../input/razerck/tf_efficientnet_b5_ns_b32_lr0.0004_cx_mh/tf_efficientnet_b5_ns_b32_lr0.0004_fold4_mh_best_AUC.pth',
]
models = [
    load_multi_head_model('tf_efficientnet_b5_ns', ckpt)
    for ckpt in effnet_b5_checkpoints
]

predictions = inference(models, test_loader, device)
tf_b5[target_cols] = predictions

  0%|          | 0/112 [00:00<?, ?it/s]

In [13]:
# models = []
# models.append(load_multi_head_model('regnety_080', '../input/razerck/regnety_080_b32_lr0.0003/regnety_080_b32_lr0.0003_fold0_mh_best_AUC.pth'))
# models.append(load_multi_head_model('regnety_080', '../input/razerck/regnety_080_b32_lr0.0003/regnety_080_b32_lr0.0003_fold1_mh_best_AUC.pth'))
# models.append(load_multi_head_model('regnety_080', '../input/razerck/regnety_080_b32_lr0.0003/regnety_080_b32_lr0.0003_fold2_mh_best_AUC.pth'))
# models.append(load_multi_head_model('regnety_080', '../input/razerck/regnety_080_b32_lr0.0003/regnety_080_b32_lr0.0003_fold3_mh_best_AUC.pth'))
# models.append(load_multi_head_model('regnety_080', '../input/razerck/regnety_080_b32_lr0.0003/regnety_080_b32_lr0.0003_fold4_mh_best_AUC.pth'))

# predictions = inference(models, test_loader, device)
# regnety_080[target_cols] = predictions

In [14]:
# models = []
# models.append(load_multi_head_model('resnest101e', '../input/razerck/resnest101e_b32_lr0.0003/resnest101e_b32_lr0.0003_fold0_mh_best_AUC.pth'))
# models.append(load_multi_head_model('resnest101e', '../input/razerck/resnest101e_b32_lr0.0003/resnest101e_b32_lr0.0003_fold1_mh_best_AUC.pth'))
# models.append(load_multi_head_model('resnest101e', '../input/razerck/resnest101e_b32_lr0.0003/resnest101e_b32_lr0.0003_fold2_mh_best_AUC.pth'))
# models.append(load_multi_head_model('resnest101e', '../input/razerck/resnest101e_b32_lr0.0003/resnest101e_b32_lr0.0003_fold3_mh_best_AUC.pth'))
# models.append(load_multi_head_model('resnest101e', '../input/razerck/resnest101e_b32_lr0.0003/resnest101e_b32_lr0.0003_fold4_mh_best_AUC.pth'))

# predictions = inference(models, test_loader, device)
# resnest101e[target_cols] = predictions

In [15]:
# Blend the five model families: raise each probability to the power `a` (a power
# below 1 flattens the probability scale) and take a weighted sum with weights `w`.
a = 0.25
w = [0.25, 0.2, 0.2, 0.2, 0.15]
print(sum(w))  # sanity check: the weights should add up to 1
blend_members = [resnet200d, resnet200d_public, seresnet152d, dm_nfnet_f0, tf_b5]
res[target_cols] = sum(
    weight * member[target_cols] ** a
    for weight, member in zip(w, blend_members)
)
res.head()

1.0


Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.46923145579096002617...,0.390576,0.79456,0.837104,0.197227,0.232169,0.35668,0.993572,0.383404,0.529913,0.960729,0.998912
1,1.2.826.0.1.3680043.8.498.84006870182611080091...,0.06746,0.079404,0.106998,0.091488,0.0983,0.079383,0.107517,0.343738,0.268361,0.994874,0.053898
2,1.2.826.0.1.3680043.8.498.12219033294413119947...,0.082558,0.092475,0.11443,0.112398,0.109725,0.09458,0.116534,0.310796,0.814369,0.843613,0.076752
3,1.2.826.0.1.3680043.8.498.84994474380235968109...,0.318444,0.43172,0.476027,0.375899,0.257676,0.988783,0.344332,0.423249,0.496492,0.964559,0.204014
4,1.2.826.0.1.3680043.8.498.35798987793805669662...,0.087376,0.104922,0.131174,0.139403,0.131821,0.092458,0.163586,0.262122,0.431157,0.991224,0.051398


In [16]:
# Keep only the ID and the label columns, in submission order.
submission = res[['StudyInstanceUID'] + target_cols]
submission.to_csv('submission.csv', index=False)
res.head()

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.46923145579096002617...,0.390576,0.79456,0.837104,0.197227,0.232169,0.35668,0.993572,0.383404,0.529913,0.960729,0.998912
1,1.2.826.0.1.3680043.8.498.84006870182611080091...,0.06746,0.079404,0.106998,0.091488,0.0983,0.079383,0.107517,0.343738,0.268361,0.994874,0.053898
2,1.2.826.0.1.3680043.8.498.12219033294413119947...,0.082558,0.092475,0.11443,0.112398,0.109725,0.09458,0.116534,0.310796,0.814369,0.843613,0.076752
3,1.2.826.0.1.3680043.8.498.84994474380235968109...,0.318444,0.43172,0.476027,0.375899,0.257676,0.988783,0.344332,0.423249,0.496492,0.964559,0.204014
4,1.2.826.0.1.3680043.8.498.35798987793805669662...,0.087376,0.104922,0.131174,0.139403,0.131821,0.092458,0.163586,0.262122,0.431157,0.991224,0.051398
