In [None]:
# Install the extra packages from wheel files shipped as Kaggle input datasets
# (local file paths, so nothing is fetched from PyPI).
# NOTE(review): presumably required because the kernel runs without internet access — confirm.
!pip install -q /kaggle/input/pretrainedmodels-0-7-4/pretrainedmodels-0.7.4-py3-none-any.whl
!pip install -q /kaggle/input/timm-0-4-9/timm-0.4.9-py3-none-any.whl
# The GDCM wheel is built for CPython 3.7 (cp37) only.
!pip install -q /kaggle/input/python-gdcm-3-0-9-0-cp37/python_gdcm-3.0.9.0-cp37-cp37m-manylinux2014_x86_64.whl

In [None]:
import pretrainedmodels
import timm
import torch
from torch import nn
import numpy as np
import pandas as pd
import pydicom
import gdcm
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import os
import copy
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from collections import OrderedDict, defaultdict
from tqdm import tqdm
from multiprocessing import Pool

In [None]:
class FrozenBatchNorm2dWithEpsilon(torchvision.ops.misc.FrozenBatchNorm2d):
    """FrozenBatchNorm2d whose ``eps`` falls back to 1e-5 when not given by keyword.

    The fallback is meant to keep the epsilon consistent with the default
    value of torch.nn.BatchNorm2d.
    """

    def __init__(self, *args, **kwargs):
        # An explicit eps= keyword from the caller always wins.
        kwargs.setdefault('eps', 1e-5)
        super().__init__(*args, **kwargs)

In [None]:
# Inference settings: one sub-dict per stage ('cls' = classification,
# 'det' = detection) plus the batch size shared by both dataloaders.
args_dict = {'cls': {}, 'det': {}}
args_dict['batch_size'] = 16

# Classification stage: one checkpoint per cross-validation fold.
args_dict['cls'].update({
    'input_size': 512,
    'ckpt_paths': [
        '../input/covid-cls-model-v6-25-5-fold/covid-cls-model-v6-25-fold-0.pth',
        '../input/covid-cls-model-v6-25-5-fold/covid-cls-model-v6-25-fold-1.pth',
        '../input/covid-cls-model-v6-25-5-fold/covid-cls-model-v6-25-fold-2.pth',
        '../input/covid-cls-model-v6-25-5-fold/covid-cls-model-v6-25-fold-3.pth',
        '../input/covid-cls-model-v6-25-5-fold/covid-cls-model-v6-25-fold-4.pth',
    ],
    'cascade_clses': [0],
    'flip_TTA': True,
})

# Detection stage: one checkpoint per cross-validation fold.
#args_dict['det']['input_size'] = 800
args_dict['det'].update({
    'norm_layer': FrozenBatchNorm2dWithEpsilon,  # misc_nn_ops.FrozenBatchNorm2d
    'ckpt_paths': [
        '../input/covid-det-model-v5-17-5-fold/covid_det_v5_17_cv_0.pth',
        '../input/covid-det-model-v5-17-5-fold/covid_det_v5_17_cv_1.pth',
        '../input/covid-det-model-v5-17-5-fold/covid_det_v5_17_cv_2.pth',
        '../input/covid-det-model-v5-17-5-fold/covid_det_v5_17_cv_3.pth',
        '../input/covid-det-model-v5-17-5-fold/covid_det_v5_17_cv_4.pth',
    ],
    'box_score_thresh': 0.0,
    'flip_TTA': True,
})

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Root folder of the competition test set (contains the .dcm files to convert).
test_root = '/kaggle/input/siim-covid19-detection/test/'
# Study-level class names.
# NOTE(review): presumably listed in the same order as the 4 classifier outputs — confirm.
cls_names = ["negative", "typical", "indeterminate", "atypical"]

In [None]:
def read_xray(path, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return its pixel data as an image array.

    Original from: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way

    Args:
        path: Path of the DICOM file.
        voi_lut: When True, apply the VOI LUT stored in the file (if any) to
            turn the raw data into the "human-friendly" view.
        fix_monochrome: When True, invert images whose PhotometricInterpretation
            is "MONOCHROME1" so they do not look inverted.

    Returns:
        The pixel array. Unless ``2 ** BitsStored - 1`` is already 255 the
        values are rescaled by that maximum and converted to ``np.uint8``.
    """
    # dcmread is the current name of the reader; read_file is a deprecated alias.
    dicom = pydicom.dcmread(path)

    # VOI LUT (if available by DICOM device) is used to transform raw DICOM data to
    # "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # Largest value representable with the stored bit depth.
    max_value = 2 ** dicom.BitsStored - 1
    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = max_value - data

    if max_value != 255:
        # np.float was only an alias of the builtin float (float64) and has been
        # removed from recent NumPy releases, so spell the dtype out.
        data = data.astype(np.float64) / max_value
        data = (data * 255).astype(np.uint8)

    return data


def dicom2png_kaggler(dicom_path, png_path=None, output_width=None, output_height=None):
    """Load a DICOM as an image array, optionally resizing and saving it.

    Args:
        dicom_path: Path of the DICOM file, passed to ``read_xray``.
        png_path: When given, the (possibly resized) image is written there.
        output_width, output_height: Target size. Either give both or neither;
            giving only one trips the assertion.

    Returns:
        The (possibly resized) image array.
    """
    image = read_xray(dicom_path)
    if output_width is None or output_height is None:
        # No resize requested: a half-specified size is a caller error.
        assert output_width is None and output_height is None
    else:
        image = cv2.resize(image, (output_width, output_height), interpolation=cv2.INTER_LINEAR)
    if png_path is not None:
        cv2.imwrite(png_path, image)
    return image

In [None]:
# Walk the test set and pair every .dcm file with the PNG path it will be
# converted to; the PNG tree mirrors the DICOM tree under target_root_folder.
source_root_folder = test_root
target_root_folder = 'png_1024/'
os.makedirs(target_root_folder, exist_ok=True)
source_paths = []
target_paths = []
for source_folder, folders, files in os.walk(source_root_folder):
    for file in files:
        if file.endswith('.dcm'):
            source_paths.append(os.path.join(source_folder, file))
            target_folder = source_folder.replace(source_root_folder, target_root_folder)
            # Swap only the trailing extension: str.replace('dcm', 'png') would
            # also rewrite any 'dcm' that happens to occur inside the file stem.
            target_file = file[:-len('.dcm')] + '.png'
            target_paths.append(os.path.join(target_folder, target_file))
print(f'Found {len(source_paths)} dcms.')
def transfer(info):
    """Convert one DICOM to a 1024x1024 PNG and report its original size.

    Args:
        info: ``(index, (source_path, target_path))`` as produced by
            ``enumerate(zip(source_paths, target_paths))``; the index is unused.

    Returns:
        ``(accno, (width, height))`` where ``accno`` is the DICOM file name
        without its extension and the size is taken from the last two array
        dimensions of the image before resizing.
    """
    _, (source_path, target_path) = info
    os.makedirs(os.path.dirname(target_path), exist_ok=True)

    image = dicom2png_kaggler(source_path)
    # Remember the native resolution before the image is squashed to 1024x1024.
    height, width = image.shape[-2], image.shape[-1]
    resized = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
    cv2.imwrite(target_path, resized)

    accno = os.path.splitext(os.path.basename(source_path))[0]
    return accno, (width, height)
# Convert all DICOMs with 8 worker processes. imap yields results in input
# order, so accno2WH keeps the order of source_paths.
with Pool(8) as pool:
    accno2WH = OrderedDict(tqdm(
        pool.imap(transfer, enumerate(zip(source_paths, target_paths))),
        # The imap iterator has no length; give tqdm the total so the bar
        # can show real progress instead of a bare counter.
        total=len(source_paths),
    ))

In [None]:
class Dataset():
    """Map-style dataset over the PNG files listed in the module-level ``target_paths``.

    Args:
        transforms: Optional callable applied to each RGB image array. It may
            also be assigned later through the ``transforms`` attribute. When
            it is None the raw image array is returned.
    """

    def __init__(self, transforms=None):
        self.transforms = transforms

    def __len__(self):
        return len(target_paths)

    def __getitem__(self, idx):
        path = target_paths[idx]
        cv2_img = cv2.imread(path)
        if cv2_img is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque error from the colour conversion below.
            raise FileNotFoundError(f'Could not read image: {path}')
        # cv2.imread returns BGR; swap the channels to RGB.
        cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        if self.transforms is None:
            return cv2_img
        return self.transforms(cv2_img)

In [None]:
#class SplitTransformPath(torch.nn.Module):
#    def __init__(self, n_outs):
#        super().__init__()
#        self.n_outs = n_outs
#        
#    def forward(self, inputs):
#        return (inputs, ) + tuple(copy.deepcopy(inputs) for _ in range(self.n_outs - 1))
#    
#    
#class SplitTransformWrapper(torch.nn.Module):
#    def __init__(self, transform, target_copy_idxes):
#        super().__init__()
#        self.transform = transform
#        self.target_copy_idxes = target_copy_idxes
#        
#    def forward(self, inputs):
#        return tuple(self.transform(inp) if i in self.target_copy_idxes else inp for i, inp in enumerate(inputs))

In [None]:
def freeze_layers(model, stage_name, trainable_stages):
    """Freeze (``requires_grad_(False)``) all parameters before the trainable stages.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        stage_name: One entry per stage, in order; each entry is an iterable
            of name prefixes that identify the parameters of that stage.
        trainable_stages: Number of trailing stages to leave trainable. When
            it is at least ``len(stage_name)`` nothing is frozen.
    """
    n_stages = len(stage_name)
    if trainable_stages >= n_stages:
        return

    def _has_prefix(param_name, prefixes):
        return any(param_name.startswith(prefix) for prefix in prefixes)

    named_params = list(model.named_parameters())

    # Preferred strategy: find the last parameter of the last stage to freeze
    # and freeze it together with everything that precedes it.
    boundary = stage_name[n_stages - trainable_stages - 1]
    if any(_has_prefix(name, boundary) for name, _ in named_params) and boundary != '':
        reached_boundary = False
        for name, param in reversed(named_params):
            reached_boundary = reached_boundary or _has_prefix(name, boundary)
            if reached_boundary:
                param.requires_grad_(False)
        return

    # Fallback: freeze from the start until the first parameter that belongs
    # to one of the trainable stages.
    trainable_prefixes = [
        prefix
        for stage in stage_name[n_stages - trainable_stages:]
        for prefix in stage
    ]
    for name, param in named_params:
        if _has_prefix(name, trainable_prefixes):
            break
        param.requires_grad_(False)
    return

def change_layer(model, source_module, target_module, params_map={}, state_dict_map={}):
    """Recursively replace every ``source_module`` instance inside ``model``.

    Args:
        model: Module that is modified in place.
        source_module: Class (or tuple of classes) of the layers to replace.
        target_module: Callable that builds the replacement layer.
        params_map: ``{attribute on old layer: keyword for target_module}``;
            the attribute values are forwarded to the constructor.
        state_dict_map: Optional renaming of state-dict keys applied before
            the old state is loaded (non-strictly) into the replacement.
    """
    def _swap_children(parent):
        for child_name, child in parent.named_children():
            if isinstance(child, source_module):
                # Capture the old state (with renamed keys) before the layer is dropped.
                renamed_state = {
                    state_dict_map.get(key, key): value
                    for key, value in child.state_dict().items()
                }
                ctor_kwargs = {
                    new_param: getattr(child, org_param)
                    for org_param, new_param in params_map.items()
                }
                setattr(parent, child_name, target_module(**ctor_kwargs))
                getattr(parent, child_name).load_state_dict(renamed_state, strict=False)
            # Descend into whatever module now lives under this name.
            _swap_children(getattr(parent, child_name))

    _swap_children(model)
    
class TIMMBackboneBodyWrapper(torch.nn.Module):
    """Turn the sequence of feature maps returned by a model into an OrderedDict.

    The keys are the positions in the sequence as strings ('0', '1', ...).
    """

    def __init__(self, timm_model):
        super().__init__()
        self.timm_model = timm_model

    def forward(self, *args, **kwargs):
        features = self.timm_model(*args, **kwargs)
        return OrderedDict((str(index), feature) for index, feature in enumerate(features))
    
class NormalBackboneWithFPN(torch.nn.Module):
    """Backbone body followed by a torchvision FeaturePyramidNetwork.

    Adapted from torchvision.models.detection.backbone_utils.BackboneWithFPN.

    Args:
        backbone_body: Module returning an OrderedDict of feature maps.
        in_channels_list: Channel count of each feature map fed to the FPN.
        out_channels: Channel count of every FPN output (also exposed as
            the ``out_channels`` attribute).
    """

    def __init__(self, backbone_body, in_channels_list, out_channels=256):
        super().__init__()
        self.body = backbone_body
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(),
        )
        self.out_channels = out_channels

    @staticmethod
    def _first_device(features):
        # Device of the first feature map in the dict.
        return next(iter(features.values())).device

    def forward(self, x):
        body_features = self.body(x)
        # Both dicts get a ``device`` attribute.
        # NOTE(review): presumably read by downstream code — confirm.
        body_features.device = self._first_device(body_features)
        pyramid = self.fpn(body_features)
        pyramid.device = self._first_device(pyramid)
        return pyramid
    
def timm_fpn_backbone(model_name, pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
    """Build a timm feature extractor wrapped with an FPN.

    Args:
        model_name: Name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        norm_layer: Normalisation layer class to use in the backbone.
        trainable_layers: Number of trailing backbone stages left trainable
            (see ``freeze_layers``).

    Returns:
        A ``NormalBackboneWithFPN`` over feature levels 1-4 of the backbone.
    """
    try:
        backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True, norm_layer=norm_layer, out_indices=(1, 2, 3, 4))
    except Exception:
        # Fallback for models whose constructor rejects ``norm_layer``: build
        # the default model and swap its BatchNorm2d layers afterwards.
        # ``except Exception`` (not a bare ``except``) keeps KeyboardInterrupt
        # and SystemExit propagating.
        backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True, out_indices=(1, 2, 3, 4))
        if norm_layer != torch.nn.BatchNorm2d:
            change_layer(backbone, torch.nn.BatchNorm2d, norm_layer, params_map={'num_features':'num_features', 'eps':'eps'})
    # Each stage is matched under two spellings of its module name ('.' and '_').
    # NOTE(review): presumably because parameter names may use either form — confirm.
    stage_name = [[info['module'].replace('.', '_'), info['module']] for info in backbone.feature_info.info]
    freeze_layers(backbone, stage_name, trainable_layers)

    return NormalBackboneWithFPN(TIMMBackboneBodyWrapper(backbone), backbone.feature_info.channels())

In [None]:
class MSCAM(nn.Module):
    """Module needed in AttentionalFeatureFusionFPN.

    Produces a sigmoid attention map with the same shape as the input by
    summing a globally pooled branch and a point-wise branch. Both branches
    are 1x1 conv -> GroupNorm -> ReLU -> 1x1 conv -> GroupNorm with a
    bottleneck of ``num_channels // r`` channels.
    """

    def __init__(self, num_channels, r):
        super().__init__()
        bottleneck = num_channels // r

        def conv1x1(c_in, c_out):
            return nn.Conv2d(c_in, c_out, 1)

        def group_norm(channels):
            return nn.GroupNorm(num_groups=32, num_channels=channels)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Global (pooled) branch convolutions.
        self.w1 = conv1x1(num_channels, bottleneck)
        self.w2 = conv1x1(bottleneck, num_channels)
        # Local (point-wise) branch convolutions.
        self.pwc1 = conv1x1(num_channels, bottleneck)
        self.pwc2 = conv1x1(bottleneck, num_channels)
        self.gn_w1 = group_norm(bottleneck)
        self.gn_w2 = group_norm(num_channels)
        self.gn_pwc1 = group_norm(bottleneck)
        self.gn_pwc2 = group_norm(num_channels)

    def forward(self, x):
        pooled = self.pool(x)
        global_branch = self.gn_w2(self.w2(self.gn_w1(self.w1(pooled)).relu()))
        local_branch = self.gn_pwc2(self.pwc2(self.gn_pwc1(self.pwc1(x)).relu()))
        # The 1x1 global map broadcasts over the spatial dimensions.
        return torch.sigmoid(global_branch + local_branch)


class AttentionalFeatureFusionFPN(nn.Module):
    """ Deprecated, please use ModulerFPN with iAFF=True (i.e. call ``setup_iAFF()``).
    Re-implementation of the paper: "Attentional Feature Fusion"

    An FPN in which the lateral and the top-down feature maps are merged with
    an attention-weighted sum (see ``iAFF``) instead of a plain addition.

    Args:
        in_channels_list: Channel count of each input feature map; entries
            equal to 0 get no inner/layer block.
        out_channels: Channel count of every output feature map.
        extra_blocks: Optional ``ExtraFPNBlock`` applied to the results.
    """

    def __init__(self, in_channels_list, out_channels, extra_blocks=None):
        super().__init__()
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                continue
            inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
            layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        # The convolutions live inside the ModuleLists, so walk self.modules();
        # self.children() only yields the two containers and would never
        # reach a Conv2d.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            assert isinstance(extra_blocks, torchvision.ops.feature_pyramid_network.ExtraFPNBlock)
        self.extra_blocks = extra_blocks

        # One pair of attention modules per entry of in_channels_list; forward
        # indexes them by top-down merge step.
        self.MSCAM1s = nn.ModuleList([MSCAM(out_channels, 4) for _ in range(len(in_channels_list))])
        self.MSCAM2s = nn.ModuleList([MSCAM(out_channels, 4) for _ in range(len(in_channels_list))])

    @staticmethod
    def iAFF(MSCAM1, MSCAM2, x, y):
        """Blend ``x`` and ``y`` (same shape) as ``w * x + (1 - w) * y`` with
        ``w = MSCAM2(MSCAM1(x + y))``."""
        assert x.shape == y.shape, f"input shape is not the same: {x.shape}, {y.shape}"
        att_weight = MSCAM2(MSCAM1(x + y))
        return att_weight * x + (1 - att_weight) * y

    def forward(self, x):
        """
        Computes the FPN for a set of feature maps.

        Arguments:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        last_inner = self.inner_blocks[-1](x[-1])
        results = []
        results.append(self.layer_blocks[-1](last_inner))
        # Top-down pass over the remaining levels, last level first.
        for i, (feature, inner_block, layer_block) in enumerate(zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1],
        )):
            if not inner_block:
                continue
            inner_lateral = inner_block(feature)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = nn.functional.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = self.iAFF(self.MSCAM1s[i], self.MSCAM2s[i], inner_lateral, inner_top_down)
            results.insert(0, layer_block(last_inner))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out


class AttentionGuidedFPN(nn.Module):
    """ Deprecated, please use ModulerFPN with ACFPN=True (i.e. call ``setup_ACFPN()``).
    Re-implementation of the CEM module in the paper:
    "Attention-guided Context Feature Pyramid Network for Object Detection".
    This implementation is based on this repository:
    https://github.com/Caojunxu/AC-FPN

    A plain FPN whose top level additionally receives the output of a stack
    of densely connected dilated-convolution blocks (``dense_aspp_forward``)
    computed from the last input feature map.

    Args:
        in_channels_list: Channel count of each input feature map; entries
            equal to 0 get no inner/layer block.
        out_channels: Channel count of every output feature map.
        extra_blocks: Optional ``ExtraFPNBlock`` applied to the results.
    """
    def __init__(self, in_channels_list, out_channels, extra_blocks=None):
        super().__init__()
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                continue
            inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
            layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        # The convolutions live inside the ModuleLists, so walk self.modules();
        # self.children() only yields the two containers and would never
        # reach a Conv2d. The ASPP blocks are created afterwards and keep
        # their default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            assert isinstance(extra_blocks, torchvision.ops.feature_pyramid_network.ExtraFPNBlock)
        self.extra_blocks = extra_blocks
        self.num_dilations = [3, 6, 12, 18, 24]
        aspp_blocks = []
        dropout0 = 0.1
        d_feature0 = 512
        d_feature1 = 256
        dim_in = in_channels_list[-1]
        for i, dilation in enumerate(self.num_dilations):
            # Block i sees the input plus the outputs of all i earlier blocks.
            aspp_blocks.append(self.dense_aspp_block(
                input_num=dim_in + d_feature1 * i,
                num1=d_feature0,
                num2=d_feature1,
                dilation_rate=dilation,
                drop_out=dropout0
            ))
        self.aspp_blocks = torch.nn.ModuleList(aspp_blocks)
        self.CEM_final_conv = torch.nn.Conv2d(len(self.num_dilations) * d_feature1, out_channels, 1)
        self.CEM_final_gn = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels)

    @staticmethod
    def dense_aspp_block(input_num, num1, num2, dilation_rate, drop_out):
        """Build one block: 1x1 conv -> GroupNorm -> ReLU -> dilated 3x3 conv -> ReLU -> Dropout.

        ``padding == dilation`` keeps the spatial size unchanged.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(input_num, num1, kernel_size=1),
            torch.nn.GroupNorm(num_groups=32, num_channels=num1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(num1, num2, kernel_size=3, padding=dilation_rate, dilation=dilation_rate),
            torch.nn.ReLU(),
            torch.nn.Dropout(drop_out)
        )

    @staticmethod
    def iAFF(MSCAM1, MSCAM2, x, y):
        """Blend ``x`` and ``y`` (same shape) as ``w * x + (1 - w) * y`` with
        ``w = MSCAM2(MSCAM1(x + y))``. Not used by ``forward`` of this class."""
        assert x.shape == y.shape, f"input shape is not the same: {x.shape}, {y.shape}"
        att_weight = MSCAM2(MSCAM1(x + y))
        return att_weight * x + (1 - att_weight) * y

    def dense_aspp_forward(self, _input):
        """Run the densely connected dilated blocks on ``_input``.

        Every block's output is collected; all but the last are also
        concatenated in front of the running input of the next block. The
        collected outputs are concatenated and reduced to ``out_channels``.
        """
        x = _input
        conv_outs = []
        last_index = len(self.aspp_blocks) - 1
        for i, block in enumerate(self.aspp_blocks):
            conv_out = block(x)
            # Always keep the output, also when there is only a single block
            # (the previous code dropped it in that case).
            conv_outs.append(conv_out)
            if i != last_index:
                x = torch.cat((conv_out, x), dim=1)
        x = torch.cat(conv_outs, dim=1)
        x = self.CEM_final_conv(x)
        x = self.CEM_final_gn(x)
        return x

    def forward(self, x):
        """
        Computes the FPN for a set of feature maps.

        Arguments:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        # The top level is the lateral projection plus the context features.
        last_inner = self.inner_blocks[-1](x[-1]) + self.dense_aspp_forward(x[-1])
        results = []
        results.append(self.layer_blocks[-1](last_inner))
        # Top-down pass over the remaining levels, last level first.
        for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            if not inner_block:
                continue
            inner_lateral = inner_block(feature)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = torch.nn.functional.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, layer_block(last_inner))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out


class ModulerFPN(nn.Module):
    """
    Combination of ACFPN and iAFF

    After construction this is a plain FPN. Calling ``setup_ACFPN()`` adds
    the dense dilated-convolution context branch to the top level, and
    calling ``setup_iAFF()`` replaces the additive lateral/top-down merge
    with an attention-weighted one. The two options are independent.

    Args:
        in_channels_list: Channel count of each input feature map; entries
            equal to 0 get no inner/layer block.
        out_channels: Channel count of every output feature map.
        extra_blocks: Optional ``ExtraFPNBlock`` applied to the results.
    """
    def __init__(self, in_channels_list, out_channels, extra_blocks=None):
        super().__init__()
        # Feature switches, turned on by setup_ACFPN() / setup_iAFF().
        self.ACFPN = self.iAFF = False
        self.in_channels_list = in_channels_list
        self.out_channels = out_channels
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in self.in_channels_list:
            if in_channels == 0:
                continue
            inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
            layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        # The convolutions live inside the ModuleLists, so walk self.modules();
        # self.children() only yields the two containers and would never
        # reach a Conv2d.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            assert isinstance(extra_blocks, torchvision.ops.feature_pyramid_network.ExtraFPNBlock)
        self.extra_blocks = extra_blocks

    def setup_ACFPN(self):
        """Create the context branch modules and enable them in ``forward``."""
        self.ACFPN = True
        self.num_dilations = [3, 6, 12, 18, 24]
        aspp_blocks = []
        dropout0 = 0.1
        d_feature0 = 512
        d_feature1 = 256
        dim_in = self.in_channels_list[-1]
        for i, dilation in enumerate(self.num_dilations):
            # Block i sees the input plus the outputs of all i earlier blocks.
            aspp_blocks.append(self.dense_aspp_block(
                input_num=dim_in + d_feature1 * i,
                num1=d_feature0,
                num2=d_feature1,
                dilation_rate=dilation,
                drop_out=dropout0
            ))
        self.aspp_blocks = torch.nn.ModuleList(aspp_blocks)
        self.CEM_final_conv = torch.nn.Conv2d(len(self.num_dilations) * d_feature1, self.out_channels, 1)
        self.CEM_final_gn = torch.nn.GroupNorm(num_groups=32, num_channels=self.out_channels)

    def setup_iAFF(self):
        """Create the attention modules and enable the attentional merge in ``forward``."""
        self.iAFF = True
        # One pair of attention modules per entry of in_channels_list; forward
        # indexes them by top-down merge step.
        self.MSCAM1s = nn.ModuleList([MSCAM(self.out_channels, 4) for _ in range(len(self.in_channels_list))])
        self.MSCAM2s = nn.ModuleList([MSCAM(self.out_channels, 4) for _ in range(len(self.in_channels_list))])

    @staticmethod
    def apply_iAFF(MSCAM1, MSCAM2, x, y):
        """Blend ``x`` and ``y`` (same shape) as ``w * x + (1 - w) * y`` with
        ``w = MSCAM2(MSCAM1(x + y))``."""
        assert x.shape == y.shape, f"input shape is not the same: {x.shape}, {y.shape}"
        att_weight = MSCAM2(MSCAM1(x + y))
        return att_weight * x + (1 - att_weight) * y

    @staticmethod
    def dense_aspp_block(input_num, num1, num2, dilation_rate, drop_out):
        """Build one block: 1x1 conv -> GroupNorm -> ReLU -> dilated 3x3 conv -> ReLU -> Dropout.

        ``padding == dilation`` keeps the spatial size unchanged.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(input_num, num1, kernel_size=1),
            torch.nn.GroupNorm(num_groups=32, num_channels=num1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(num1, num2, kernel_size=3, padding=dilation_rate, dilation=dilation_rate),
            torch.nn.ReLU(),
            torch.nn.Dropout(drop_out)
        )

    def dense_aspp_forward(self, _input):
        """Run the densely connected dilated blocks on ``_input``.

        Requires ``setup_ACFPN()`` to have been called. Every block's output
        is collected; all but the last are also concatenated in front of the
        running input of the next block. The collected outputs are
        concatenated and reduced to ``out_channels``.
        """
        x = _input
        conv_outs = []
        last_index = len(self.aspp_blocks) - 1
        for i, block in enumerate(self.aspp_blocks):
            conv_out = block(x)
            # Always keep the output, also when there is only a single block
            # (the previous code dropped it in that case).
            conv_outs.append(conv_out)
            if i != last_index:
                x = torch.cat((conv_out, x), dim=1)
        x = torch.cat(conv_outs, dim=1)
        x = self.CEM_final_conv(x)
        x = self.CEM_final_gn(x)
        return x

    def forward(self, x):
        """
        Computes the FPN for a set of feature maps.

        Arguments:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        last_inner = self.inner_blocks[-1](x[-1])

        if self.ACFPN:
            # Add the context features to the top level (out of place, so the
            # convolution output itself is left untouched).
            last_inner = last_inner + self.dense_aspp_forward(x[-1])

        results = []
        results.append(self.layer_blocks[-1](last_inner))
        # Top-down pass over the remaining levels, last level first.
        for i, (feature, inner_block, layer_block) in enumerate(zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        )):
            if not inner_block:
                continue
            inner_lateral = inner_block(feature)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = torch.nn.functional.interpolate(last_inner, size=feat_shape, mode="nearest")

            if not self.iAFF:
                last_inner = inner_lateral + inner_top_down
            else:
                last_inner = self.apply_iAFF(self.MSCAM1s[i], self.MSCAM2s[i], inner_lateral, inner_top_down)

            results.insert(0, layer_block(last_inner))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out


In [None]:
def get_cls_model(model_name):
    """Create a timm feature extractor with its BatchNorm2d layers frozen.

    Args:
        model_name: Must be one of ``timm.list_models()``.

    Returns:
        The ``features_only`` timm model in which every ``BatchNorm2d`` has
        been replaced by torchvision's ``FrozenBatchNorm2d``.
    """
    assert model_name in timm.list_models(), 'Other model is not implemented now.'
    feature_model = timm.create_model(model_name, features_only=True)
    # Carry the layer size and epsilon over to the frozen replacement.
    bn_params = {'num_features':'num_features', 'eps':'eps'}
    change_layer(
        feature_model,
        torch.nn.BatchNorm2d,
        torchvision.ops.misc.FrozenBatchNorm2d,
        params_map=bn_params,
    )
    return feature_model

class FuseClsFeatsBackboneBody(torch.nn.Module):
    """Detection backbone body whose features are gated by classifier features.

    Each selected classifier feature map goes through a 1x1 convolution and a
    sigmoid, is resized to the matching detection feature map and multiplied
    with it.

    ``forward`` reads ``self.org_image``, which is not set here; it has to be
    assigned before every call (see ``add_backup_org_image``).

    Args:
        det_backbone_body: Module returning an ordered mapping of detection features.
        cls_model: Module returning an indexable collection of classifier features.
        cls_size: Side length the images are resized to for the classifier.
        cls_norm_mean, cls_norm_std: Normalisation applied for the classifier.
        cls_feat_name: Indices/keys of the classifier features to use.
        det_feat_ch: Channel counts of the detection features.
        cls_feat_ch: Channel counts of the selected classifier features.
    """

    def __init__(self,
                 det_backbone_body,
                 cls_model,
                 cls_size,
                 cls_norm_mean,
                 cls_norm_std,
                 cls_feat_name=[1, 2, 3, 4],
                 det_feat_ch=[],
                 cls_feat_ch=[],
                ):
        super().__init__()
        self.det_backbone_body = det_backbone_body
        self.cls_model = cls_model
        self.cls_feat_name = cls_feat_name
        # One 1x1 conv per level, mapping classifier channels to detection channels.
        self.att_convs = torch.nn.ModuleList([
            torch.nn.Conv2d(in_ch, out_ch, 1)
            for in_ch, out_ch in zip(cls_feat_ch, det_feat_ch)
        ])
        self.cls_normalize = torchvision.transforms.Normalize(cls_norm_mean, cls_norm_std)
        self.cls_resize = torchvision.transforms.Resize((cls_size, cls_size))

    def forward(self, image):
        det_output = self.det_backbone_body(image)
        level_names = det_output.keys()
        det_maps = list(det_output.values())

        # The classifier sees the backed-up original images, resized and normalised.
        cls_batch = torch.stack([self.cls_normalize(self.cls_resize(img)) for img in self.org_image])
        with torch.no_grad():
            cls_output = self.cls_model(cls_batch)
        selected_cls = [cls_output[k] for k in self.cls_feat_name]

        gated = OrderedDict()
        for cls_feat, att_conv, det_feat, level_name in zip(selected_cls, self.att_convs, det_maps, level_names):
            gate = torch.nn.functional.interpolate(att_conv(cls_feat).sigmoid(), det_feat.shape[2:])
            gated[level_name] = gate * det_feat
        return gated

def add_backup_org_image(model):
    """Patch ``model.forward`` so it first stores a deep copy of the input images.

    The copy is written to ``model.backbone.body.org_image`` and the original
    forward (kept as ``model.org_forward``) is then called with the same
    ``image`` and ``target``.
    """
    model.org_forward = model.forward

    def forward_with_backup(self, image, target=None):
        snapshot = copy.deepcopy(image)
        self.backbone.body.org_image = snapshot
        return self.org_forward(image, target)

    # Bind the wrapper to this instance only; the class itself is untouched.
    model.forward = forward_with_backup.__get__(model, type(model))

In [None]:
def GeneralEnsemble(dets, iou_thresh = 0.5, weights=None):
    """Merge the detections of several detectors for one image.

    Every box is ``[xc, yc, w, h, class, confidence]``. For each box, the
    best-overlapping unused box of the same class (IoU above ``iou_thresh``)
    is looked up in every other detector. Matched boxes are averaged with
    the detector weights; the confidence is the weighted sum, so detectors
    that missed the box lower it. A box without any match keeps its
    coordinates and has its confidence divided by the number of detectors.

    Args:
        dets: One list of boxes per detector.
        iou_thresh: Minimum IoU (a ``float``) for two boxes to be merged.
        weights: Optional weight per detector; normalised to sum to 1 on a
            copy, the caller's list is left untouched. Defaults to equal weights.

    Returns:
        List of merged boxes in the same 6-element format.
    """
    assert(type(iou_thresh) == float)

    ndets = len(dets)
    if ndets == 0:
        # Nothing to merge (and 1 / ndets below would divide by zero).
        return []

    if weights is None:
        weights = [1 / float(ndets)] * ndets
    else:
        assert(len(weights) == ndets)
        total = sum(weights)
        # Normalise into a new list instead of mutating the argument.
        weights = [w / total for w in weights]

    out = list()
    # Boxes are tracked by (detector index, box index). Tracking them by value
    # would treat identical boxes from different detectors as already used,
    # so perfect agreement would never be merged.
    used = set()

    for idet, det in enumerate(dets):
        for ibox, box in enumerate(det):
            if (idet, ibox) in used:
                continue

            used.add((idet, ibox))
            # Search the other detectors for overlapping box of same class
            found = []
            for iodet, odet in enumerate(dets):
                # Skip the detector itself by index: comparing the lists by
                # value would also skip another detector with equal output.
                if iodet == idet:
                    continue

                best_index = None
                bestiou = iou_thresh
                for iobox, obox in enumerate(odet):
                    if (iodet, iobox) in used:
                        continue
                    if box[4] != obox[4]:
                        # Different class
                        continue
                    iou = computeIOU(box, obox)
                    if iou > bestiou:
                        bestiou = iou
                        best_index = iobox

                if best_index is not None:
                    found.append((odet[best_index], weights[iodet]))
                    used.add((iodet, best_index))

            # Now we've gone through all other detectors
            if len(found) == 0:
                new_box = list(box)
                new_box[5] /= ndets
                out.append(new_box)
            else:
                allboxes = [(box, weights[idet])]
                allboxes.extend(found)

                xc = 0.0
                yc = 0.0
                bw = 0.0
                bh = 0.0
                conf = 0.0

                wsum = 0.0
                for b, w in allboxes:
                    wsum += w
                    xc += w*b[0]
                    yc += w*b[1]
                    bw += w*b[2]
                    bh += w*b[3]
                    conf += w*b[5]

                # Geometry is averaged over the matched detectors only; the
                # confidence is deliberately not renormalised.
                xc /= wsum
                yc /= wsum
                bw /= wsum
                bh /= wsum

                new_box = [xc, yc, bw, bh, box[4], conf]
                out.append(new_box)
    return out
    
def getCoords(box):
    """Convert a centre-format box ``[xc, yc, w, h, ...]`` to ``(x1, x2, y1, y2)``.

    Note the return order: both x edges first, then both y edges.
    """
    center_x, center_y = float(box[0]), float(box[1])
    half_w, half_h = float(box[2])/2, float(box[3])/2
    return center_x - half_w, center_x + half_w, center_y - half_h, center_y + half_h
    
def computeIOU(box1, box2):
    """Intersection over union of two centre-format boxes ``[xc, yc, w, h, ...]``.

    Returns 0.0 when the boxes do not overlap or when their union has no area.
    """
    x11, x12, y11, y12 = getCoords(box1)
    x21, x22, y21, y22 = getCoords(box2)

    x_left   = max(x11, x21)
    y_top    = max(y11, y21)
    x_right  = min(x12, x22)
    y_bottom = min(y12, y22)

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersect_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (x12 - x11) * (y12 - y11)
    box2_area = (x22 - x21) * (y22 - y21)

    union_area = box1_area + box2_area - intersect_area
    if union_area <= 0:
        # Degenerate (zero-area) boxes that coincide pass the overlap test
        # above; avoid dividing by zero for them.
        return 0.0

    iou = intersect_area / union_area
    return iou

def ensemble_outputs(outputs):
    """Merge several detection outputs for one image with ``GeneralEnsemble``.

    Args:
        outputs: Iterable of dicts with tensors under 'boxes' (corner format
            ``[x1, y1, x2, y2]``), 'labels' and 'scores'.

    Returns:
        A dict with the same three keys holding the merged detections as a
        float tensor of shape (N, 4), a long tensor and a float tensor.
    """
    # Corner format -> centre format [xc, yc, w, h, label, score].
    dets = []
    for output in outputs:
        rows = zip(output['boxes'].tolist(), output['labels'].tolist(), output['scores'].tolist())
        dets.append([
            [(box[0]+box[2])/2, (box[1]+box[3])/2, box[2]-box[0], box[3]-box[1], label, score]
            for box, label, score in rows
        ])

    ensemble_dets = GeneralEnsemble(dets)

    # Centre format -> corner format again.
    corner_boxes = [
        [det[0]-det[2]/2, det[1]-det[3]/2, det[0]+det[2]/2, det[1]+det[3]/2]
        for det in ensemble_dets
    ]
    ensemble_outs = {
        # view(-1, 4) keeps the (N, 4) shape even when nothing was detected.
        'boxes': torch.FloatTensor(corner_boxes).view(-1, 4),
        'labels': torch.LongTensor([det[4] for det in ensemble_dets]),
        'scores': torch.FloatTensor([det[5] for det in ensemble_dets]),
    }
    return ensemble_outs

In [None]:
# The detector receives plain image tensors; resizing/normalising is left to the model.
det_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])


def det_collate(batch):
    """Identity collate: pass the list of per-image samples through unchanged."""
    return batch


# The classification dataset starts without transforms; they are assigned
# per checkpoint before the dataloader is iterated.
cls_dataset = Dataset()
det_dataset = Dataset(det_transforms)

loader_kwargs = dict(batch_size=args_dict['batch_size'], num_workers=4, pin_memory=True)
cls_dataloader = torch.utils.data.DataLoader(cls_dataset, **loader_kwargs)
det_dataloader = torch.utils.data.DataLoader(det_dataset, collate_fn=det_collate, **loader_kwargs)

In [None]:
def get_cls_outs(model, dataloader, use_sigmoid=False, flipH=False):
    """Run ``model`` over every batch of ``dataloader`` and return probabilities.

    Args:
        model: callable applied to each batch after it is moved to the
            module-level ``device``.
        dataloader: iterable yielding batch tensors.
        use_sigmoid: if True apply an element-wise sigmoid to the concatenated
            outputs, otherwise a softmax over ``dim=1``.
        flipH: if True flip every batch along its last dimension before the
            forward pass.

    Returns:
        CPU tensor holding the concatenated, activated outputs of all batches.
    """
    batch_outs = []
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        inputs = batch.flip(-1) if flipH else batch
        batch_outs.append(model(inputs).detach().cpu())

    logits = torch.cat(batch_outs)
    return logits.sigmoid() if use_sigmoid else logits.softmax(dim=1)

# Classification inference: run every classification checkpoint over the whole
# dataset (optionally averaged with a horizontally flipped pass) and stack the
# per-checkpoint probabilities into `cls_outs`.
with torch.no_grad():
    cls_outs = []
    for ckpt_path in args_dict['cls']['ckpt_paths']:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        ckpt_args = ckpt['args_dict']
        args_dict['cls']['norm_mean'] = ckpt['norm_mean']
        args_dict['cls']['norm_std'] = ckpt['norm_std']

        # The preprocessing is rebuilt per checkpoint because each checkpoint
        # carries its own normalization statistics.
        cls_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((args_dict['cls']['input_size'], args_dict['cls']['input_size'])),
            torchvision.transforms.Normalize(args_dict['cls']['norm_mean'], args_dict['cls']['norm_std']),
        ])
        # presumably Dataset (defined elsewhere) reads `.transforms` when it
        # loads a sample -- confirm against its definition.
        cls_dataloader.dataset.transforms = cls_transforms

        if ckpt_args['model_name'] == 'inceptionresnetv2':
            model = pretrainedmodels.__dict__['inceptionresnetv2'](pretrained=False, num_classes=4)
        elif ckpt_args['model_name'] in timm.list_models():
            model = timm.create_model(ckpt_args['model_name'], pretrained=False, num_classes=4)
        else:
            # Without this branch an unknown architecture would either hit a
            # NameError or silently reuse the model built for the previous
            # checkpoint.
            raise ValueError(
                f"Unsupported classification model {ckpt_args['model_name']!r} in checkpoint {ckpt_path}")

        model.load_state_dict(ckpt['model'])
        model.eval().to(device)

        use_sigmoid = ckpt_args.get('use_sigmoid', False)
        outs = get_cls_outs(model, cls_dataloader, use_sigmoid=use_sigmoid, flipH=False)
        if args_dict['cls']['flip_TTA']:
            # Flip test-time augmentation: average the plain and flipped passes.
            outs2 = get_cls_outs(model, cls_dataloader, use_sigmoid=use_sigmoid, flipH=True)
            outs = (outs + outs2) / 2
        cls_outs.append(outs)
    # One leading entry per checkpoint.
    cls_outs = torch.stack(cls_outs)

In [None]:
# Detection inference: build one Faster R-CNN per detection checkpoint, run it
# over the detection loader (optionally merged with a horizontally flipped
# pass) and collect the per-image outputs of every checkpoint in `det_outs`.
with torch.no_grad():
    # Remember the FPN class bound at cell entry.  copy.deepcopy() returns a
    # class object unchanged, so a plain reference is equivalent to the copy.
    fpn_bak = FeaturePyramidNetwork
    det_outs = []
    try:
        for ckpt_path in args_dict['det']['ckpt_paths']:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            # Rebind the module-level FPN class for this checkpoint; presumably
            # timm_fpn_backbone (defined elsewhere) looks this global up when it
            # builds the backbone -- confirm against its definition.
            if ckpt['args_dict']['ACFPN'] or ckpt['args_dict']['iAFF']:
                FeaturePyramidNetwork = ModulerFPN
            else:
                FeaturePyramidNetwork = fpn_bak

            args_dict['det']['norm_mean'] = ckpt['args_dict']['normalize_mean']
            args_dict['det']['norm_std'] = ckpt['args_dict']['normalize_std']
            backbone = timm_fpn_backbone(ckpt['args_dict']['backbone_body'], pretrained=False, norm_layer=args_dict['det']['norm_layer'])
            model = FasterRCNN(
                backbone,
                num_classes=2,
                min_size=ckpt['args_dict']['input_size'],
                max_size=ckpt['args_dict']['input_size'],
                image_mean=args_dict['det']['norm_mean'],
                image_std=args_dict['det']['norm_std'],
                box_score_thresh=args_dict['det']['box_score_thresh'],
            )
            if ckpt['args_dict']['ACFPN']:
                model.backbone.fpn.setup_ACFPN()
            if ckpt['args_dict']['iAFF']:
                model.backbone.fpn.setup_iAFF()

            # Wrap the detector's backbone body so that it is combined with a
            # classification model (helpers defined elsewhere in the notebook).
            cls_model = get_cls_model(ckpt['cls_model_name'])
            cls_feat_name = [1, 2, 3, 4]
            model.backbone.body = FuseClsFeatsBackboneBody(
                model.backbone.body,
                cls_model,
                ckpt['cls_size'],
                ckpt['cls_norm_mean'],
                ckpt['cls_norm_std'],
                cls_feat_name=cls_feat_name,
                det_feat_ch=[model.backbone.body.timm_model.feature_info.info[name]['num_chs'] for name in [1,2,3,4]],
                cls_feat_ch=[cls_model.feature_info.info[name]['num_chs'] for name in cls_feat_name],
            )
            add_backup_org_image(model)
            model.load_state_dict(ckpt['model'])
            model.eval().to(device)

            outs = []
            for data in tqdm(det_dataloader):
                data = [d.to(device) for d in data]
                out = model(data)
                out = [{k: v.detach().cpu() for k, v in o.items()} for o in out]
                if args_dict['det']['flip_TTA']:
                    # (H, W) of each unflipped image, needed to mirror boxes back.
                    shapes = [d.shape[-2:] for d in data]
                    data = [d.flip(-1) for d in data]
                    out2 = model(data)
                    out2 = [{k: v.detach().cpu() for k, v in o.items()} for o in out2]
                    for i, shape in enumerate(shapes):
                        # Undo the flip: new x1 = W - old x2, new x2 = W - old x1.
                        out2[i]['boxes'][:, [0, 2]] = shape[1] - out2[i]['boxes'][:, [2, 0]]
                    out = [ensemble_outputs([o, o2]) for o, o2 in zip(out, out2)]
                outs += out
            det_outs.append(outs)
    finally:
        # Restore the global even if a checkpoint fails, so that re-running this
        # cell starts from the original FPN class instead of taking a previously
        # swapped-in ModulerFPN as the "backup".
        FeaturePyramidNetwork = fpn_bak

In [None]:
# Group dataset indices by study id (the third-from-last path component).
study2idxes = defaultdict(list)
for i, path in enumerate(target_paths):
    study2idxes[path.split('/')[-3]].append(i)

# Per-image detection confidence: the highest box score of each detection
# model (0 when a model found nothing), averaged over the models.
det_confs = torch.tensor(
    [[out['scores'].max().item() if len(out['scores']) > 0 else 0. for out in outs] for outs in zip(*det_outs)],
    dtype=cls_outs.dtype, device=cls_outs.device).mean(dim=1)
# Checkpoint-averaged class probabilities before cascading; not referenced
# again in this cell.
cls_confs = cls_outs.mean(dim=0)

# Cascade the detection confidence into the selected class probabilities.
# Work on a copy so that `cls_outs` keeps the raw classifier outputs and
# re-running this cell does not apply the cascade a second time.
cascaded_cls_outs = cls_outs.clone()
for cls_idx in args_dict['cls']['cascade_clses']:
    if cls_idx == 0:
        # p0 <- 1 - det_conf * (1 - p0)
        cascaded_cls_outs[:, :, 0] = 1 - det_confs[None] * (1 - cascaded_cls_outs[:, :, 0])
    else:
        cascaded_cls_outs[:, :, cls_idx] *= det_confs[None]


# Study-level rows: average over checkpoints and over the images of the study.
ids, pred_strs = [], []
for study, idxes in study2idxes.items():
    ids.append(study + '_study')
    scores = cascaded_cls_outs[:, idxes].mean(dim=(0, 1))
    pred_strs.append(' '.join([f'{name} {score} 0 0 1 1' for name, score in zip(cls_names, scores)]))

# Image-level rows: merge the boxes of all detection models for each image.
# NOTE(review): relies on accno2WH iterating in the same order as the
# detection outputs -- confirm against where accno2WH is built.
#assert len(det_outs) == 1, 'ensembles not implemented.'
for i, (accno, org_WH) in enumerate(accno2WH.items()):
    ids.append(accno + '_image')
    #det_out = det_outs[0][i]
    det_out = ensemble_outputs([model_outs[i] for model_outs in det_outs])
    if len(det_out['boxes']) == 0:
        pred_strs.append('none 1 0 0 1 1')
    else:
        # Rescale boxes to the original width / height; presumably the
        # detection inputs were 1024 px on each side -- TODO confirm.
        xmins = (det_out['boxes'][:, 0] * (org_WH[0] / 1024)).tolist()
        ymins = (det_out['boxes'][:, 1] * (org_WH[1] / 1024)).tolist()
        xmaxes = (det_out['boxes'][:, 2] * (org_WH[0] / 1024)).tolist()
        ymaxes = (det_out['boxes'][:, 3] * (org_WH[1] / 1024)).tolist()
        # The "none" score is the complement of the most confident box.
        pred_strs.append(
            f"none {1 - det_out['scores'].max().item()} 0 0 1 1 " +
            ' '.join([f'opacity {score.item()} {xmin} {ymin} {xmax} {ymax}' for score, xmin, ymin, xmax, ymax in zip(det_out['scores'], xmins, ymins, xmaxes, ymaxes)]))

out_df = pd.DataFrame(OrderedDict([('Id', ids), ('PredictionString', pred_strs)]))

In [None]:
# Write the assembled predictions to submission.csv in the working directory,
# omitting the DataFrame index column.
out_df.to_csv('submission.csv', index=False)

In [None]:
# Delete the folder tree at `target_root_folder` (presumably a scratch
# directory created earlier in the notebook -- confirm).  The path is handed
# to shutil.rmtree directly instead of being interpolated into a shell
# command, so whitespace or shell metacharacters in it cannot change what is
# removed.  ignore_errors=True keeps the call silent when the folder does not
# exist.
import shutil

shutil.rmtree(target_root_folder, ignore_errors=True)