In [None]:
import sys
import pandas as pd

In [None]:
# Competition sample submission; its row count is checked by the commit guard below.
sample_sub = pd.read_csv('/kaggle/input/rsna-str-pulmonary-embolism-detection/sample_submission.csv')

In [None]:
# Set to True by the commit guard below just before it stops the run.
commit_flag = False

In [None]:
# for commit
# When the sample submission has exactly 152703 rows (presumably the commit-time,
# public test set — confirm), write it unchanged as submission.csv and stop.
# sys.exit raises SystemExit, which halts a Run-All / notebook-execution run, so
# none of the heavy inference cells below execute in that case.
if len(sample_sub)==152703:
    sample_sub.to_csv('submission.csv', index=False)
    commit_flag = True
    sys.exit(1)

In [None]:
# Offline installs from attached Kaggle datasets (presumably because the
# submission kernel has no internet access — confirm):
#   - GDCM from a conda tarball (commonly needed by pydicom to decode compressed DICOM pixel data),
#   - pinned wheels for timm, omegaconf and kornia (pip output silenced).
!cp ../input/gdcm-conda-install/gdcm.tar . 
!tar -xvzf gdcm.tar
!conda install --offline ./gdcm/gdcm-2.8.9-py37h71b2a6d_0.tar.bz2
!pip install /kaggle/input/software/timm-0.2.1-py3-none-any.whl > /dev/null
!pip install /kaggle/input/software/omegaconf-2.0.2-py3-none-any.whl > /dev/null
!pip install /kaggle/input/software/kornia-0.4.0-py2.py3-none-any.whl > /dev/null

In [None]:
import gc
import math
import pickle
import numpy as np
import pandas as pd
import random
import pydicom
import cv2
import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import kornia
from omegaconf import OmegaConf
from torch.nn import Parameter
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from joblib import Parallel, delayed

In [None]:
seed = 7
# Seed Python, NumPy and PyTorch (CPU, then CUDA) RNGs, in that order.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
del _seed_rng

# cuDNN autotuning favours speed; it may pick non-deterministic kernels, so
# GPU results can differ slightly between runs despite the seeds above.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# The 10 targets in a fixed order that later code slices positionally:
# [0] image-level PE, [1:3] negative / indeterminate, [3:5] chronic types,
# [5:8] PE locations, [8:] RV/LV ratio.
label_cols = [
    "pe_present_on_image",
    "negative_exam_for_pe",
    "indeterminate",
    "chronic_pe",
    "acute_and_chronic_pe",
    "central_pe",
    "leftsided_pe",
    "rightsided_pe",
    "rv_lv_ratio_gte_1",
    "rv_lv_ratio_lt_1",
]
# Matching prediction column names ("<label>_pred").
pred_cols = [name + '_pred' for name in label_cols]


In [None]:
def softmax(x, axis):
    """Numerically stable softmax of ``x`` along ``axis``.

    Args:
        x: array-like of scores.
        axis: axis along which the outputs sum to 1.

    Returns:
        np.ndarray with the same shape as ``x``.
    """
    # Subtracting the per-slice max leaves the result unchanged mathematically
    # but keeps exp() from overflowing to inf (and the ratio from becoming NaN).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def postprocess(x, s=2.0):
    """Shift probabilities by ``s`` in logit space: ``sigmoid(logit(x) + s)``.

    Implemented with the algebraically identical closed form
    ``x / (x + (1 - x) * exp(-s))``. This avoids log(0) and division-by-zero
    when ``x`` is exactly 0 or 1. Those inputs still map to 0 and 1, as in the
    logit form. Positive ``s`` pushes values towards 1 and negative ``s``
    pushes them towards 0.

    Args:
        x: probability (scalar, ndarray or pandas Series) in [0, 1].
        s: logit-space shift.

    Returns:
        Shifted probabilities with the same container type/shape as ``x``.
    """
    return x / (x + (1 - x) * np.exp(-s))

def consistency_check(df):
    """Count rows that violate the exam-level label-consistency rules.

    Each row is assigned to its exam's group by the exam-wide maximum of
    ``pe_present_on_image``: positive if > 0.5, negative if <= 0.5. Each rule
    is then evaluated on the matching group, thresholding every column at 0.5.

    Side effect: adds/overwrites ``df['positive_images_in_exam']``.

    Returns:
        pd.Series from ``value_counts``: number of violating rows per rule id
        ('1a', '1b', '1c', '1d', '2a', '2b'); rules with no violations are absent.
    """
    exam_max = df.groupby(['StudyInstanceUID'])['pe_present_on_image'].max()
    df['positive_images_in_exam'] = df['StudyInstanceUID'].map(exam_max)
    pos = df.loc[df.positive_images_in_exam > 0.5]
    neg = df.loc[df.positive_images_in_exam <= 0.5]

    def hi(frame, col):
        return frame[col] > 0.5

    def lo(frame, col):
        return frame[col] <= 0.5

    rule_masks = [
        # 1a: positive exam must have exactly one RV/LV ratio label.
        ('1a', (hi(pos, 'rv_lv_ratio_lt_1') & hi(pos, 'rv_lv_ratio_gte_1'))
               | (lo(pos, 'rv_lv_ratio_lt_1') & lo(pos, 'rv_lv_ratio_gte_1'))),
        # 1b: positive exam must have at least one PE location.
        ('1b', lo(pos, 'central_pe') & lo(pos, 'rightsided_pe') & lo(pos, 'leftsided_pe')),
        # 1c: chronic and acute-and-chronic cannot both be set.
        ('1c', hi(pos, 'acute_and_chronic_pe') & hi(pos, 'chronic_pe')),
        # 1d: positive exam cannot be indeterminate or negative.
        ('1d', hi(pos, 'indeterminate') | hi(pos, 'negative_exam_for_pe')),
        # 2a: negative exam must be exactly one of indeterminate / negative.
        ('2a', (hi(neg, 'indeterminate') & hi(neg, 'negative_exam_for_pe'))
               | (lo(neg, 'indeterminate') & lo(neg, 'negative_exam_for_pe'))),
        # 2b: negative exam cannot carry any RV/LV, location or chronicity label.
        ('2b', hi(neg, 'rv_lv_ratio_lt_1')
               | hi(neg, 'rv_lv_ratio_gte_1')
               | hi(neg, 'central_pe')
               | hi(neg, 'rightsided_pe')
               | hi(neg, 'leftsided_pe')
               | hi(neg, 'acute_and_chronic_pe')
               | hi(neg, 'chronic_pe')),
    ]

    # One rule id per violating row, in rule order, then tallied.
    broken = []
    for rule_id, mask in rule_masks:
        broken.extend([rule_id] * int(mask.sum()))
    return pd.Series(broken, name='broken_rule', dtype=object).value_counts()

def satisfy_label_consistency(df, delta=1):
    """Nudge predictions in ``df`` until they satisfy the label-consistency rules.

    Works in place on ``df`` and also returns it. Rule ids are the ones
    reported by ``consistency_check``. Each broken rule is repaired by
    repeatedly shifting the offending probabilities in logit space with
    ``postprocess`` (steps of +/-0.1). Rule 1c is the exception: it is repaired
    by a softmax over the two chronic columns.

    Args:
        df: image-level prediction frame with ``StudyInstanceUID`` and every
            column in ``label_cols``. Repaired subsets are written back by
            index label, so the index presumably must be unique — confirm.
        delta: NOTE(review): unused in this body.

    Returns:
        The same ``df`` object. ``positive_images_in_exam`` is always added by
        ``consistency_check``. ``positive_exam_for_pe`` is added only when some
        rule was broken.

    NOTE(review): ``postprocess`` maps exactly 0.0 and 1.0 to themselves. A row
    selected by one of the while-loops below whose nudged values are exactly
    0/1 would presumably never change, so the loop would not terminate.
    Confirm that predictions are strictly inside (0, 1).
    """
    rule_breaks = consistency_check(df).index
    # NOTE(review): leftover debug output — prints the broken rule ids.
    print(rule_breaks)
    if len(rule_breaks) > 0:
        # Exam-level positive probability; each image's PE probability is capped at it.
        df["positive_exam_for_pe"] = 1 - df["negative_exam_for_pe"]
        df.loc[
            df.query("positive_exam_for_pe <= pe_present_on_image").index,
            "pe_present_on_image",
        ] = df.loc[
            df.query("positive_exam_for_pe <= pe_present_on_image").index,
            "positive_exam_for_pe",
        ]
        # Re-evaluate after capping: the cap can change which exams count as positive.
        rule_breaks = consistency_check(df).index
        df["positive_images_in_exam"] = df["StudyInstanceUID"].map(
            df.groupby(["StudyInstanceUID"])["pe_present_on_image"].max()
        )
        # Rows of positive / negative exams. query() returns separate frames, so
        # every repair below is explicitly copied back into df by index.
        df_pos = df.query("positive_images_in_exam > 0.5")
        df_neg = df.query("positive_images_in_exam <= 0.5")
        if "1a" in rule_breaks:
            # 1a: exactly one RV/LV column (label_cols[8:]) must be > 0.5.
            # Both high -> repeatedly lower whichever is smaller (logit step -0.1).
            rv_filter = "rv_lv_ratio_gte_1 > 0.5 & rv_lv_ratio_lt_1 > 0.5"
            while len(df_pos.query(rv_filter)) > 0:
                # Helper column: per-row minimum of the two RV/LV probabilities.
                df_pos.loc[df_pos.query(rv_filter).index, "rv_min"] = df_pos.query(
                    rv_filter
                )[label_cols[8:]].min(1)
                for rv_col in label_cols[8:]:
                    df_pos.loc[
                        df_pos.query(rv_filter + f" & {rv_col} == rv_min").index, rv_col
                    ] = postprocess(
                        df_pos.query(rv_filter + f" & {rv_col} == rv_min")[
                            rv_col
                        ].values,
                        s=-0.1,
                    )
            # Both low -> repeatedly raise whichever is larger (logit step +0.1).
            rv_filter = "rv_lv_ratio_gte_1 <= 0.5 & rv_lv_ratio_lt_1 <= 0.5"
            while len(df_pos.query(rv_filter)) > 0:
                df_pos.loc[df_pos.query(rv_filter).index, "rv_max"] = df_pos.query(
                    rv_filter
                )[label_cols[8:]].max(1)
                for rv_col in label_cols[8:]:
                    df_pos.loc[
                        df_pos.query(rv_filter + f" & {rv_col} == rv_max").index, rv_col
                    ] = postprocess(
                        df_pos.query(rv_filter + f" & {rv_col} == rv_max")[
                            rv_col
                        ].values,
                        s=0.1,
                    )
            df.loc[df_pos.index, label_cols[8:]] = df_pos[label_cols[8:]]
        if "1b" in rule_breaks:
            # 1b: raise central/left/right (label_cols[5:8]) on positive-exam rows where
            # all three are <= 0.5, until consistency_check stops reporting 1b.
            pe_filter = " & ".join([f"{col} <= 0.5" for col in label_cols[5:8]])
            while "1b" in consistency_check(df).index:
                for col in label_cols[5:8]:
                    df_pos.loc[df_pos.query(pe_filter).index, col] = postprocess(
                        df_pos.loc[df_pos.query(pe_filter).index, col], s=0.1
                    )
                df.loc[df_pos.index, label_cols[5:8]] = df_pos[label_cols[5:8]].values
        if "1c" in rule_breaks:
            # 1c: softmax over (chronic_pe, acute_and_chronic_pe) makes the pair sum
            # to 1, so at most one of them can stay above 0.5.
            chronic_filter = "chronic_pe > 0.5 & acute_and_chronic_pe > 0.5"
            df_pos.loc[df_pos.query(chronic_filter).index, label_cols[3:5]] = softmax(
                df_pos.query(chronic_filter)[label_cols[3:5]].values, axis=1
            )
            df.loc[df_pos.index, label_cols[3:5]] = df_pos[label_cols[3:5]]
        if "1d" in rule_breaks:
            # 1d: lower negative_exam_for_pe and indeterminate (label_cols[1:3]) on
            # positive-exam rows where either is > 0.5, until 1d is no longer reported.
            neg_filter = "negative_exam_for_pe > 0.5 | indeterminate > 0.5"
            while "1d" in consistency_check(df).index:
                for col in label_cols[1:3]:
                    df_pos.loc[df_pos.query(neg_filter).index, col] = postprocess(
                        df_pos.loc[df_pos.query(neg_filter).index, col], s=-0.1
                    )
                df.loc[df_pos.index, label_cols[1:3]] = df_pos[label_cols[1:3]].values
        if "2a" in rule_breaks:
            # 2a: on negative exams exactly one of negative / indeterminate must be > 0.5.
            # Both high -> repeatedly lower the smaller one.
            neg_filter = "negative_exam_for_pe > 0.5 & indeterminate > 0.5"
            while len(df_neg.query(neg_filter)) > 0:
                df_neg.loc[df_neg.query(neg_filter).index, "neg_min"] = df_neg.query(
                    neg_filter
                )[label_cols[1:3]].min(1)
                for neg_col in label_cols[1:3]:
                    df_neg.loc[
                        df_neg.query(neg_filter + f" & {neg_col} == neg_min").index,
                        neg_col,
                    ] = postprocess(
                        df_neg.query(neg_filter + f" & {neg_col} == neg_min")[
                            neg_col
                        ].values,
                        s=-0.1,
                    )
            # Both low -> repeatedly raise the larger one.
            neg_filter = "negative_exam_for_pe <= 0.5 & indeterminate <= 0.5"
            while len(df_neg.query(neg_filter)) > 0:
                df_neg.loc[df_neg.query(neg_filter).index, "neg_max"] = df_neg.query(
                    neg_filter
                )[label_cols[1:3]].max(1)
                for neg_col in label_cols[1:3]:
                    df_neg.loc[
                        df_neg.query(neg_filter + f" & {neg_col} == neg_max").index,
                        neg_col,
                    ] = postprocess(
                        df_neg.query(neg_filter + f" & {neg_col} == neg_max")[
                            neg_col
                        ].values,
                        s=0.1,
                    )
            df.loc[df_neg.index, label_cols[1:3]] = df_neg[label_cols[1:3]]
        if "2b" in rule_breaks:
            # 2b: on negative exams lower every chronicity/location/RV-LV column
            # (label_cols[3:]) that is > 0.5, until 2b is no longer reported.
            while "2b" in consistency_check(df).index:
                for col in label_cols[3:]:
                    df_neg.loc[df_neg.query(f"{col} > 0.5").index, col] = postprocess(
                        df_neg.loc[df_neg.query(f"{col} > 0.5").index, col], s=-0.1
                    )
                df.loc[df_neg.index, label_cols[3:]] = df_neg[label_cols[3:]].values
    return df

In [None]:
def load_scans(dcm_dir_path):
    """Read every ``*.dcm`` file directly inside ``dcm_dir_path`` with pydicom.

    Args:
        dcm_dir_path: ``pathlib.Path`` of a series directory.

    Returns:
        list of pydicom datasets in ``Path.glob`` order (unsorted).
    """
    return list(map(pydicom.dcmread, dcm_dir_path.glob('*.dcm')))

def get_data(dcm_dir_path):
    """Read one DICOM series and return its rescaled slices plus per-slice metadata.

    Slices are ordered by the z component of ``ImagePositionPatient``. The
    first (lowest-z) slice's RescaleSlope/RescaleIntercept is applied to every
    slice.

    Returns:
        (np.ndarray, pd.DataFrame):
            - the stacked ``pixel_array * slope + intercept`` for each slice;
            - one metadata row per slice with the columns listed below.

        The metadata frame is built from a single ``np.array``, so mixed-type
        values share one dtype (presumably strings — convert before numeric use).
    """
    meta_columns = [
        "SOPInstanceUID",
        "KVP",
        "XRayTubeCurrent",
        "Exposure",
        "SliceThickness",
        "ImagePositionPatient_x",
        "ImagePositionPatient_y",
        "ImagePositionPatient_z",
    ]
    header_tags = meta_columns[:-3]
    slices = sorted(load_scans(dcm_dir_path), key=lambda ds: float(ds.ImagePositionPatient[2]))
    slope = float(slices[0].RescaleSlope)
    intercept = float(slices[0].RescaleIntercept)

    rows, volume = [], []
    for ds in slices:
        position = ds['ImagePositionPatient'].value
        rows.append([ds[tag].value for tag in header_tags] + [position[axis] for axis in range(3)])
        volume.append(ds.pixel_array * slope + intercept)
    return np.array(volume), pd.DataFrame(np.array(rows), columns=meta_columns)

def window(img, WL=50, WW=350):
    """Apply a CT intensity window and rescale the result to uint8.

    Values are clipped to ``[WL - WW//2, WL + WW//2]``. The clipped image is
    then min-max scaled to 0..255 using its own min and max, not the window
    bounds. If the clipped image is constant (e.g. a slice lying entirely
    outside the window), the range is zero. In that case an all-zero image is
    returned instead of dividing 0 by 0, which produced NaN before the uint8
    cast.

    Args:
        img: array of intensities (presumably HU after rescaling — confirm).
        WL: window level (centre).
        WW: window width.

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``img``.
    """
    upper, lower = WL + WW // 2, WL - WW // 2
    # np.clip returns a new array, so the caller's img is not modified.
    X = np.clip(img, lower, upper)
    X = X - np.min(X)
    peak = np.max(X)
    if peak == 0:
        return np.zeros(X.shape, dtype='uint8')
    X = X / peak
    return (X * 255.0).astype('uint8')

def convert_image(img, WL_list, WW_list):
    """Stack one windowed view of ``img`` per (level, width) pair along axis 2.

    Args:
        img: 2-D intensity image.
        WL_list: window levels.
        WW_list: window widths (paired with ``WL_list`` positionally).

    Returns:
        np.ndarray of shape (H, W, n_windows).
    """
    channels = []
    for level, width in zip(WL_list, WW_list):
        channels.append(window(img, level, width))
    return np.stack(channels, axis=2)

In [None]:
# Which test studies to run: 'public' = only the public-LB studies,
# 'private' = test studies not in the public set, anything else = the full test set.
MODE = 'private'

In [None]:
# Competition data root: test.csv plus DICOM series under test/<StudyInstanceUID>/<SeriesInstanceUID>/.
data_root = Path('/kaggle/input/rsna-str-pulmonary-embolism-detection')

In [None]:
# One row per (study, series) in the competition test set, plus the series' DICOM directory.
test_df = (
    pd.read_csv(data_root / 'test.csv')[['StudyInstanceUID', 'SeriesInstanceUID']]
    .drop_duplicates()
    .reset_index(drop=True)
)
test_df['dcm_dir_path'] = test_df.apply(
    lambda row: data_root / 'test' / row["StudyInstanceUID"] / row["SeriesInstanceUID"],
    axis=1,
)

In [None]:
# Public-LB subset: study/series ids come from a separate dataset, but the DICOM
# files live in the competition's test/ directory.
public_test_df = (
    pd.read_csv('/kaggle/input/rsna2020-public-dataset/test.csv')[['StudyInstanceUID', 'SeriesInstanceUID']]
    .drop_duplicates()
    .reset_index(drop=True)
)
public_test_df['dcm_dir_path'] = public_test_df.apply(
    lambda row: data_root / 'test' / row["StudyInstanceUID"] / row["SeriesInstanceUID"],
    axis=1,
)
public_sample_sub = pd.read_csv('/kaggle/input/rsna2020-public-dataset/sample_submission.csv')

In [None]:
# Pick the (study, series) rows to run inference on, according to MODE.
if MODE == 'private':
    # Private = competition test studies that are not in the public-LB subset.
    ids = public_test_df['StudyInstanceUID'].unique()
    test_df_ = test_df[~test_df['StudyInstanceUID'].isin(ids)].reset_index(drop=True)
elif MODE == 'public':
    test_df_ = public_test_df
else:
    test_df_ = test_df

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset over an indexable stack of 0..255 images for inference.

    Each item is cast to float32, scaled to [0, 1] and converted with
    ``kornia.image_to_tensor``. It is returned as ``{'image': tensor}``.
    """

    def __init__(self, image):
        # presumably an array of HxWxC uint8 images (e.g. from convert_image) — confirm
        self.image = image

    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        scaled = self.image[index].astype(np.float32) / 255
        return {'image': kornia.image_to_tensor(scaled)}

In [None]:
def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two (spatial) dimensions.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` per channel. The trailing
    width axis is squeezed, so an (N, C, H, W) input yields (N, C, 1).

    Args:
        x: 4-D tensor.
        p: exponent (number or tensor, e.g. a learnable Parameter).
        eps: lower clamp that keeps the power well-defined for non-positive inputs.
    """
    height, width = x.shape[-2:]
    powered = torch.clamp(x, min=eps) ** p
    pooled = F.avg_pool2d(powered, (height, width))
    return (pooled ** (1. / p)).squeeze(-1)


class GeM(nn.Module):
    """GeM pooling module with a learnable exponent ``p`` (a 1-element Parameter)."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = Parameter(p * torch.ones(1))

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)
    
class MeanMaxPooling(nn.Module):
    """Global average and global max pooling, concatenated along channels.

    An (N, C, H, W) input yields (N, 2C, 1, 1): the first C channels are
    means and the last C are maxima.
    """

    def __init__(self):
        super().__init__()
        self.pool_mean = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_max = nn.AdaptiveMaxPool2d((1, 1))

    def forward(self, x):
        return torch.cat([self.pool_mean(x), self.pool_max(x)], dim=1)

class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)`` (elementwise)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [None]:
class ImageModel(nn.Module):
    """Per-slice image classifier: timm backbone -> global pooling -> 2-layer head with 10 outputs.

    The backbone is ``cfg.model.name``, built with ``pretrained=False``; weights
    come from fold checkpoints loaded separately. The pooling is chosen by
    ``cfg.model.pool``:
        - ``'MeanMax'``: mean and max concatenated, doubling the feature width;
        - ``'GeM'``: generalized-mean pooling;
        - anything else: global average pooling.
    """

    def __init__(self, cfg):
        super(ImageModel, self).__init__()
        self.arch = timm.create_model(cfg.model.name, pretrained=False)
        out_channel = self.arch.num_features
        if cfg.model.pool == 'MeanMax':
            self.pool = MeanMaxPooling()
            out_channel = out_channel * 2
        elif cfg.model.pool == 'GeM':
            self.pool = GeM()
        else:
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.l0 = nn.Linear(out_channel, 256)
        self.bn0 = nn.BatchNorm1d(256)
        self.act = Swish()
        self.l1 = nn.Linear(256, 10)

    def forward(self, batch):
        """Return the raw (pre-activation) 10-way outputs, shape (batch, 10).

        ``batch['image']`` is fed to ``arch.forward_features``; presumably an
        (N, C, H, W) float tensor — confirm.
        """
        image = batch['image']
        out = self.arch.forward_features(image)
        out = self.pool(out)
        # Flatten pooled maps ((N, C, 1, 1) or GeM's (N, C, 1)) to (N, C).
        out = out.view(out.shape[0], -1)
        out = self.act(self.bn0(self.l0(out)))
        out = self.l1(out)
        return out

    def get_feature(self, batch):
        """Return flattened pooled backbone features (no head), computed without gradients."""
        with torch.no_grad():
            image = batch['image']
            out = self.arch.forward_features(image)
            out = self.pool(out)
            out = out.view(out.shape[0], -1)
            return out


In [None]:
class ResNet(nn.Module):
    """1-D residual block: two Conv1d-BN-LeakyReLU-Dropout stages with a skip around the second.

    Length is preserved (odd kernel with "same" padding). The channel count
    goes from ``in_features`` to ``out_features``.
    """

    def __init__(self, in_features, out_features, kernel_size, dropout=0.0):
        super().__init__()
        assert kernel_size % 2 == 1
        self.conv1 = self._conv_stage(in_features, out_features, kernel_size, dropout)
        self.conv2 = self._conv_stage(out_features, out_features, kernel_size, dropout)

    @staticmethod
    def _conv_stage(c_in, c_out, kernel_size, dropout):
        # Child order (0 conv, 1 bn, 2 act, 3 dropout) fixes state_dict keys such as "conv1.0.weight".
        return nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size, stride=1, padding=kernel_size // 2),
            nn.BatchNorm1d(c_out),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden) + hidden

    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    The stored buffer ``pe`` has shape (max_len, 1, d_model). It broadcasts
    over the batch axis of a (seq_len, batch, d_model) input, with
    seq_len <= max_len.
    """

    def __init__(self, d_model, dropout=0.1, max_len=201):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # Odd columns: there are d_model // 2 of them (one fewer than freqs when d_model is odd).
        table[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x):
        return self.dropout(x + self.pe[: x.size(0)])

    
    
class DeconvFeatureModel(nn.Module):
    """Sequence model over per-slice features with exam-level and per-position heads.

    The input is ``batch['feature']`` concatenated with ``batch['meta_feature']``
    on the last axis, so ``model_config.num_feature`` presumably equals the sum
    of both widths — confirm. Each backbone stage is followed by a
    ``ConvTranspose1d(kernel_size=3, stride=2, padding=1)``, which maps a
    sequence of length L to 2L - 1. The per-stage upsampled outputs are
    concatenated along channels.

    Attribute names (including the misspelled ``exam_classfifier`` /
    ``image_classfifier``) are state_dict keys for ``load_state_dict``;
    renaming them would break checkpoint loading.
    """

    def __init__(self, model_config):
        super(DeconvFeatureModel, self).__init__()
        self.model_config = model_config
        if model_config.backbone in ["cnn", "cnn_rnn"]:
            # Five residual 1-D conv stages (512 -> 256 -> 128 -> 64 -> 32 channels),
            # each followed by its own transposed conv; outputs are concatenated.
            self.resnet1 = ResNet(
                model_config.num_feature, 512, kernel_size=3, dropout=model_config.dropout_rate
            )
            self.deconv1 = nn.ConvTranspose1d(
                512, 512, kernel_size=3, stride=2, padding=1
            )
            self.resnet2 = ResNet(
                512, 256, kernel_size=5, dropout=model_config.dropout_rate
            )
            self.deconv2 = nn.ConvTranspose1d(
                256, 256, kernel_size=3, stride=2, padding=1
            )
            self.resnet3 = ResNet(
                256, 128, kernel_size=7, dropout=model_config.dropout_rate
            )
            self.deconv3 = nn.ConvTranspose1d(
                128, 128, kernel_size=3, stride=2, padding=1
            )
            self.resnet4 = ResNet(
                128, 64, kernel_size=9, dropout=model_config.dropout_rate
            )
            self.deconv4 = nn.ConvTranspose1d(
                64, 64, kernel_size=3, stride=2, padding=1
            )
            self.resnet5 = ResNet(
                64, 32, kernel_size=11, dropout=model_config.dropout_rate
            )
            self.deconv5 = nn.ConvTranspose1d(
                32, 32, kernel_size=3, stride=2, padding=1
            )
            out_features = 512 + 256 + 128 + 64 + 32
        elif model_config.backbone == "lstm":
            # Two stacked bidirectional LSTMs (1024-wide outputs), each followed by a
            # transposed conv to 512 channels; the two 512-channel outputs are concatenated.
            self.rnn1 = nn.LSTM(
                model_config.num_feature,
                512,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            self.deconv1 = nn.ConvTranspose1d(
                1024, 512, kernel_size=3, stride=2, padding=1
            )
            self.rnn2 = nn.LSTM(
                1024,
                512,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            self.deconv2 = nn.ConvTranspose1d(
                1024, 512, kernel_size=3, stride=2, padding=1
            )
            out_features = 1024
        elif model_config.backbone == "gru":
            # Same layout as "lstm" with GRU cells.
            self.rnn1 = nn.GRU(
                model_config.num_feature,
                512,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            self.deconv1 = nn.ConvTranspose1d(
                1024, 512, kernel_size=3, stride=2, padding=1
            )
            self.rnn2 = nn.GRU(
                1024,
                512,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            self.deconv2 = nn.ConvTranspose1d(
                1024, 512, kernel_size=3, stride=2, padding=1
            )
            out_features = 1024
        elif model_config.backbone in ["transformer", "transformer_rnn"]:
            # Project to 2048, add positional encodings, then transformer encoder stages.
            self.linear = nn.Linear(model_config.num_feature, 2048)
            # NOTE(review): self.scale is not used in forward().
            self.scale = math.sqrt(2048)
            self.pe = PositionalEncoding(2048, model_config.dropout_rate)
            encoder_layer1 = nn.TransformerEncoderLayer(
                2048,
                nhead=8,
                dim_feedforward=1024,
                dropout=model_config.dropout_rate,
                activation="gelu",
            )
            self.transformer1 = nn.TransformerEncoder(encoder_layer1, 1)
            self.deconv1 = nn.ConvTranspose1d(
                2048, 1024, kernel_size=3, stride=2, padding=1
            )
            # NOTE(review): transformer2 (d_model=1024) is built but never called in
            # forward(); forward() applies transformer1 twice, which matches deconv2's
            # 2048 input channels. Checkpoints presumably depend on this — do not "fix"
            # without retraining.
            encoder_layer2 = nn.TransformerEncoderLayer(
                1024,
                nhead=8,
                dim_feedforward=1024,
                dropout=model_config.dropout_rate,
                activation="gelu",
            )
            self.transformer2 = nn.TransformerEncoder(encoder_layer2, 1)
            self.deconv2 = nn.ConvTranspose1d(
                2048, 1024, kernel_size=3, stride=2, padding=1
            )
            out_features = 2048
        else:
            raise NotImplementedError()
        if model_config.backbone in ["cnn_rnn", "transformer_rnn"]:
            # Optional bidirectional LSTM on top of the concatenated backbone output.
            self.rnn = nn.LSTM(
                out_features,
                512,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            out_features = 1024
        self.exam_classfifier = nn.Linear(out_features, model_config.num_classes)
        self.image_classfifier = nn.Linear(out_features, 1)

    def pooling(self, feature, seq_lens):
        """Mean over the first ``seq_lens[i]`` positions of each sample: (bs, L, C) -> (bs, C).

        NOTE(review): after the transposed convs the sequence length is 2L - 1;
        confirm that ``seq_len`` is expressed in that upsampled length.
        """
        pool_out = torch.stack(
            [feature[idx, :seq_len].mean(0) for idx, seq_len in enumerate(seq_lens)]
        )
        # pool_out = feature.mean(1)
        return pool_out

    def forward(self, batch):
        """
        Input:
            data (torch.Tensor): shape [bs, seq_len, n_feature]

        Returns:
            (exam_out, image_out): exam_out is (bs, num_classes) from the pooled
            sequence; image_out is (bs, upsampled_len, 1), one value per position.
        """
        feature = batch["feature"].float()
        meta_feature = batch["meta_feature"].float()
        feature = torch.cat([feature, meta_feature], dim=-1)
        if self.model_config.backbone in ["cnn", "cnn_rnn"]:
            # Conv1d expects (bs, channels, length).
            feature = feature.permute(0, 2, 1)
            outs = []
            for i in range(1, 6):
                feature = getattr(self, f"resnet{i}")(feature)
                out = getattr(self, f"deconv{i}")(feature)
                outs.append(out)
            feature = torch.cat(outs, dim=1).permute(0, 2, 1)
        elif self.model_config.backbone in ["gru", "lstm"]:
            outs = []
            feature, _ = self.rnn1(feature)
            out = self.deconv1(feature.permute(0, 2, 1)).permute(0, 2, 1)
            outs.append(out)
            feature, _ = self.rnn2(feature)
            out = self.deconv2(feature.permute(0, 2, 1)).permute(0, 2, 1)
            outs.append(out)
            feature = torch.cat(outs, dim=-1)
        elif self.model_config.backbone in ["transformer", "transformer_rnn"]:
            feature = torch.relu(self.linear(feature))
            # The encoder and positional encoding work sequence-first: (L, bs, C).
            feature = self.pe(feature.permute(1, 0, 2))
            outs = []
            feature = self.transformer1(feature)
            out = self.deconv1(feature.permute(1, 2, 0)).permute(2, 0, 1)
            outs.append(out)
            # NOTE(review): transformer1 again (not transformer2) — see __init__.
            feature = self.transformer1(feature)
            out = self.deconv2(feature.permute(1, 2, 0)).permute(2, 0, 1)
            outs.append(out)
            feature = torch.cat(outs, dim=-1).permute(1, 0, 2)
        else:
            pass
        if self.model_config.backbone in ["cnn_rnn", "transformer_rnn"]:
            feature, _ = self.rnn(feature)
        # feature, _ = self.rnn(feature)
        pool_out = self.pooling(feature, batch["seq_len"])
        exam_out = self.exam_classfifier(pool_out)
        image_out = self.image_classfifier(feature)
        return exam_out, image_out

    
class StackingModel(nn.Module):
    """Second-stage sequence model with an exam-level head and a per-position head.

    Backbones:
        - ``"cnn"`` / ``"cnn_rnn"``: five stacked 1-D ResNet blocks whose outputs
          are concatenated. NOTE(review): no recurrent layer is built for
          ``"cnn_rnn"``, so it behaves like ``"cnn"``.
        - ``"lstm"`` / ``"gru"``: 2-layer bidirectional RNN with
          ``in_features`` hidden units.

    Any other backbone raises ``NotImplementedError``, as in
    DeconvFeatureModel. Attribute names (including the misspelled
    ``*_classfifier``) are state_dict keys and must not change.
    """

    def __init__(self, model_config, in_features=120):
        """
        Args:
            model_config: config with ``backbone``, ``dropout_rate`` and ``num_classes``.
            in_features: per-position input width (default 120).

        Raises:
            NotImplementedError: for an unsupported ``model_config.backbone``.
        """
        super(StackingModel, self).__init__()
        self.model_config = model_config
        backbone = model_config.backbone
        if backbone in ["cnn", "cnn_rnn"]:
            self.resnet1 = ResNet(
                in_features, 256, kernel_size=3, dropout=model_config.dropout_rate
            )
            self.resnet2 = ResNet(
                256, 128, kernel_size=3, dropout=model_config.dropout_rate
            )
            self.resnet3 = ResNet(
                128, 64, kernel_size=3, dropout=model_config.dropout_rate
            )
            self.resnet4 = ResNet(
                64, 32, kernel_size=3, dropout=model_config.dropout_rate
            )
            self.resnet5 = ResNet(
                32, 16, kernel_size=3, dropout=model_config.dropout_rate
            )
            out_features = 256 + 128 + 64 + 32 + 16
        elif backbone == "lstm":
            self.rnn = nn.LSTM(
                in_features,
                in_features,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            out_features = in_features * 2
        elif backbone == "gru":
            self.rnn = nn.GRU(
                in_features,
                in_features,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
            out_features = in_features * 2
        else:
            # Without this, `out_features` would be unbound below and fail with an
            # opaque UnboundLocalError.
            raise NotImplementedError(f"Unsupported backbone: {backbone!r}")
        self.exam_classfifier = nn.Linear(out_features, model_config.num_classes)
        self.image_classfifier = nn.Linear(out_features, 1)

    def pooling(self, feature, seq_lens):
        """Mean over the first ``seq_lens[i]`` positions of each sample: (bs, L, C) -> (bs, C)."""
        pool_out = torch.stack(
            [feature[idx, :seq_len].mean(0) for idx, seq_len in enumerate(seq_lens)]
        )
        return pool_out

    def forward(self, batch):
        """
        Input:
            batch["feature"]: shape [bs, seq_len, n_feature]
            batch["seq_len"]: per-sample valid lengths used for pooling.

        Returns:
            (exam_out, image_out): exam_out is (bs, num_classes); image_out is
            (bs, seq_len, 1).
        """
        feature = batch["feature"].float()
        if self.model_config.backbone in ["cnn", "cnn_rnn"]:
            # Conv1d expects (bs, channels, length); every block's output is kept.
            feature = feature.permute(0, 2, 1)
            outs = []
            for i in range(1, 6):
                feature = getattr(self, f"resnet{i}")(feature)
                outs.append(feature)
            feature = torch.cat(outs, dim=1).permute(0, 2, 1)
        elif self.model_config.backbone in ["gru", "lstm"]:
            feature, _ = self.rnn(feature)
        pool_out = self.pooling(feature, batch["seq_len"])
        exam_out = self.exam_classfifier(pool_out)
        image_out = self.image_classfifier(feature)
        return exam_out, image_out

In [None]:
def load_image_512_model(cfg, data_path):
    fold_num = 5
    models = [ImageModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/image_model/fold_{}.ckpt'.format(i+1), map_location=device)['state_dict'] for i in range(fold_num)]
    states = [{key[6:]: value for key, value in state.items()} for state in states]
    [models[i].load_state_dict(states[i]) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_image_384_model(cfg, data_path):
    fold_num = 5
    models = [ImageModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/fold_{}.ckpt'.format(i+1), map_location=device)['state_dict'] for i in range(fold_num)]
    states = [{key[6:]: value for key, value in state.items()} for state in states]
    [models[i].load_state_dict(states[i]) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq1_512_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b5_ns_feature_deconv_cnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq2_512_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b5_ns_feature_deconv_rnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq3_512_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b5_ns_feature_deconv_gru_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models


def load_seq4_512_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b5_ns_feature_deconv_cnn_rnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models


def load_seq1_384_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b3_ns_feature_deconv_cnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq2_384_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b3_ns_feature_deconv_rnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq3_384_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b3_ns_feature_deconv_gru_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models


def load_seq4_384_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/tf_efficientnet_b3_ns_feature_deconv_cnn_rnn_scaler_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq1_concat_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/concat_512_384_cnn_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq2_concat_model(cfg, data_path):
    fold_num = 5
    models = [DeconvFeatureModel(cfg) for i in range(fold_num)]
    [m.to(device) for m in models]
    states = [torch.load(data_path + '/concat_512_384_rnn_fold{}_state_dict.ckpt'.format(i), map_location=device) for i in range(fold_num)]
    [models[i].load_state_dict(states[i], strict=False) for i in range(fold_num)]
    [m.eval() for m in models]
    del states; gc.collect()
    return models

def load_seq3_concat_model(cfg, data_path, fold_num=5):
    """Load the per-fold GRU sequence models built on concatenated 512+384 features.

    For each fold, builds ``DeconvFeatureModel(cfg)``, moves it to ``device``, loads
    ``<data_path>/concat_512_384_gru_fold{i}_state_dict.ckpt`` and switches it to eval mode.

    Args:
        cfg: model config passed to ``DeconvFeatureModel``.
        data_path: directory containing the fold checkpoints.
        fold_num: number of folds to load (default 5).

    Returns:
        list of ``fold_num`` models in eval mode, ordered by fold index.
    """
    models = []
    for fold in range(fold_num):
        model = DeconvFeatureModel(cfg)
        model.to(device)
        # Load one checkpoint at a time so only a single state dict is resident.
        state = torch.load(
            data_path + '/concat_512_384_gru_fold{}_state_dict.ckpt'.format(fold),
            map_location=device,
        )
        # strict=False: missing/unexpected keys are ignored silently (shape mismatches still raise).
        model.load_state_dict(state, strict=False)
        model.eval()
        models.append(model)
        del state
    gc.collect()
    return models

def load_seq4_concat_model(cfg, data_path, fold_num=5):
    """Load the per-fold CNN+RNN sequence models built on concatenated 512+384 features.

    For each fold, builds ``DeconvFeatureModel(cfg)``, moves it to ``device``, loads
    ``<data_path>/concat_512_384_cnn_rnn_fold{i}_state_dict.ckpt`` and switches it to eval mode.

    Args:
        cfg: model config passed to ``DeconvFeatureModel``.
        data_path: directory containing the fold checkpoints.
        fold_num: number of folds to load (default 5).

    Returns:
        list of ``fold_num`` models in eval mode, ordered by fold index.
    """
    models = []
    for fold in range(fold_num):
        model = DeconvFeatureModel(cfg)
        model.to(device)
        # Load one checkpoint at a time so only a single state dict is resident.
        state = torch.load(
            data_path + '/concat_512_384_cnn_rnn_fold{}_state_dict.ckpt'.format(fold),
            map_location=device,
        )
        # strict=False: missing/unexpected keys are ignored silently (shape mismatches still raise).
        model.load_state_dict(state, strict=False)
        model.eval()
        models.append(model)
        del state
    gc.collect()
    return models


def load_stacking_cnn_model(cfg, data_path, fold_num=5):
    """Load the per-fold CNN stacking networks.

    For each fold, builds ``StackingModel(cfg)``, moves it to ``device``, loads
    ``<data_path>/stacking_cnn_final_fold{i}_state_dict.ckpt`` and switches it to eval mode.

    Args:
        cfg: model config passed to ``StackingModel``.
        data_path: directory containing the fold checkpoints.
        fold_num: number of folds to load (default 5).

    Returns:
        list of ``fold_num`` models in eval mode, ordered by fold index.
    """
    models = []
    for fold in range(fold_num):
        model = StackingModel(cfg)
        model.to(device)
        # Load one checkpoint at a time so only a single state dict is resident.
        state = torch.load(
            data_path + '/stacking_cnn_final_fold{}_state_dict.ckpt'.format(fold),
            map_location=device,
        )
        # strict=False: missing/unexpected keys are ignored silently (shape mismatches still raise).
        model.load_state_dict(state, strict=False)
        model.eval()
        models.append(model)
        del state
    gc.collect()
    return models

def load_stacking_rnn_model(cfg, data_path, fold_num=5):
    """Load the per-fold RNN stacking networks.

    For each fold, builds ``StackingModel(cfg)``, moves it to ``device``, loads
    ``<data_path>/stacking_lstm_concat_fold{i}_state_dict.ckpt`` and switches it to eval mode.

    NOTE(review): the checkpoint names say "lstm", but the caller passes a GRU config.
    Because ``strict=False`` ignores unmatched keys, a backbone mismatch could leave RNN
    weights at their random init. Confirm which backbone the checkpoints hold.

    Args:
        cfg: model config passed to ``StackingModel``.
        data_path: directory containing the fold checkpoints.
        fold_num: number of folds to load (default 5).

    Returns:
        list of ``fold_num`` models in eval mode, ordered by fold index.
    """
    models = []
    for fold in range(fold_num):
        model = StackingModel(cfg)
        model.to(device)
        # Load one checkpoint at a time so only a single state dict is resident.
        state = torch.load(
            data_path + '/stacking_lstm_concat_fold{}_state_dict.ckpt'.format(fold),
            map_location=device,
        )
        # strict=False: missing/unexpected keys are ignored silently (shape mismatches still raise).
        model.load_state_dict(state, strict=False)
        model.eval()
        models.append(model)
        del state
    gc.collect()
    return models

In [None]:
# ===== Image Model ======
# First-stage 2D encoders: backbone name (presumably a timm model id), pooling head
# and input resolution.

cfg_384_b3 = OmegaConf.create(
    {'model': {'name': 'tf_efficientnet_b3_ns', 'pool': 'GeM', 'figsize': 384}}
)

cfg_512_b5 = OmegaConf.create(
    {'model': {'name': 'tf_efficientnet_b5_ns', 'pool': 'GeM', 'figsize': 512}}
)

# ===== 512 Sequence Model ======
# Second-stage sequence heads over the 512px (b5) features. They share every
# hyper-parameter except the backbone. num_feature=2055 is presumably the b5 embedding
# width plus the 7 per-slice meta features; TODO confirm.

(
    cfg_cnn_512,
    cfg_lstm_512,
    cfg_gru_512,
    cfg_cnn_rnn_512,
    cfg_transformer_512,
) = (
    OmegaConf.create({
        'model': {
            'backbone': backbone,
            'dropout_rate': 0.2,
            'num_classes': 9,
            'num_feature': 2055,
        }
    })
    for backbone in ('cnn', 'lstm', 'gru', 'cnn_rnn', 'transformer')
)


# ===== 384 Sequence Model ======
# Second-stage sequence heads over the 384px (b3) features. They share every
# hyper-parameter except the backbone. num_feature=1543 is presumably the b3 embedding
# width plus the 7 per-slice meta features; TODO confirm.

(
    cfg_cnn_384,
    cfg_lstm_384,
    cfg_gru_384,
    cfg_cnn_rnn_384,
    cfg_transformer_384,
) = (
    OmegaConf.create({
        'model': {
            'backbone': backbone,
            'dropout_rate': 0.2,
            'num_classes': 9,
            'num_feature': 1543,
        }
    })
    for backbone in ('cnn', 'lstm', 'gru', 'cnn_rnn', 'transformer')
)


# ===== 512 + 384 Sequence Model ======
# Second-stage sequence heads over the concatenated 384px + 512px features.
# num_feature=3591 presumably equals both CNN embeddings plus the 7 meta features
# counted once; TODO confirm.

(
    cfg_cnn_concat,
    cfg_lstm_concat,
    cfg_gru_concat,
    cfg_cnn_rnn_concat,
) = (
    OmegaConf.create({
        'model': {
            'backbone': backbone,
            'dropout_rate': 0.2,
            'num_classes': 9,
            'num_feature': 3591,
        }
    })
    for backbone in ('cnn', 'lstm', 'gru', 'cnn_rnn')
)


# ===== Stacking Model ======
# Third-stage stacking networks. They differ only in backbone and take no num_feature.

cfg_stacking_cnn, cfg_stacking_lstm, cfg_stacking_gru = (
    OmegaConf.create({
        'model': {
            'backbone': backbone,
            'dropout_rate': 0.2,
            'num_classes': 9,
        }
    })
    for backbone in ('cnn', 'lstm', 'gru')
)

In [None]:
# Load every fold of the first-stage image encoders and the second-stage sequence models.
# Skipped on the commit run, where commit_flag is set after writing the sample submission.
if not commit_flag:
    SEQ_MODEL_DIR = '/kaggle/input/rsna-shimacos-sequence-models-final'

    image_models_512 = load_image_512_model(cfg_512_b5, data_path='/kaggle/input/tf-efficientnet-b5-ns-512')
    image_models_384 = load_image_384_model(cfg_384_b3, data_path='/kaggle/input/tf-efficientnet-b3-ns-384')

    # Sequence models over 512px (b5) features.
    cnn_models_512 = load_seq1_512_model(cfg_cnn_512.model, data_path=SEQ_MODEL_DIR)
    lstm_models_512 = load_seq2_512_model(cfg_lstm_512.model, data_path=SEQ_MODEL_DIR)
    gru_models_512 = load_seq3_512_model(cfg_gru_512.model, data_path=SEQ_MODEL_DIR)
    cnn_rnn_models_512 = load_seq4_512_model(cfg_cnn_rnn_512.model, data_path=SEQ_MODEL_DIR)

    # Sequence models over 384px (b3) features.
    cnn_models_384 = load_seq1_384_model(cfg_cnn_384.model, data_path=SEQ_MODEL_DIR)
    lstm_models_384 = load_seq2_384_model(cfg_lstm_384.model, data_path=SEQ_MODEL_DIR)
    gru_models_384 = load_seq3_384_model(cfg_gru_384.model, data_path=SEQ_MODEL_DIR)
    cnn_rnn_models_384 = load_seq4_384_model(cfg_cnn_rnn_384.model, data_path=SEQ_MODEL_DIR)

    # Sequence models over concatenated 384px + 512px features.
    cnn_models_concat = load_seq1_concat_model(cfg_cnn_concat.model, data_path=SEQ_MODEL_DIR)
    lstm_models_concat = load_seq2_concat_model(cfg_lstm_concat.model, data_path=SEQ_MODEL_DIR)
    gru_models_concat = load_seq3_concat_model(cfg_gru_concat.model, data_path=SEQ_MODEL_DIR)
    cnn_rnn_models_concat = load_seq4_concat_model(cfg_cnn_rnn_concat.model, data_path=SEQ_MODEL_DIR)

In [None]:
def _load_pickle(path):
    """Unpickle and return the single object stored at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only use it on trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# One stacking-model container per label column, in label_cols order,
# plus the feature scaler used by the stacking networks.
lgb_dicts = [
    _load_pickle('/kaggle/input/stacking-models-final/stacking_{}.pkl'.format(label_col))
    for label_col in label_cols
]
stacking_std = _load_pickle('/kaggle/input/stacking-models-final/stacking_std.pkl')

In [None]:
# Third-stage stacking networks (one per fold).
# NOTE(review): the RNN checkpoints are named stacking_lstm_concat_* but are built from
# cfg_stacking_gru. The loaders use load_state_dict(strict=False), which ignores
# unmatched keys, so a backbone mismatch could go unnoticed. Confirm the trained backbone.
stacking_cnn_models = load_stacking_cnn_model(
    cfg=cfg_stacking_cnn.model,
    data_path='/kaggle/input/stacking-models-final',
)
stacking_rnn_models = load_stacking_rnn_model(
    cfg=cfg_stacking_gru.model,
    data_path='/kaggle/input/stacking-models-final',
)


In [None]:
# ImageNet RGB statistics used to normalise the 3-channel input images.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# GPU-side resize + normalise for each encoder's input resolution.
# nn.Module.to() moves the module in place and returns it, so chaining is equivalent.
transform_512_b5 = nn.Sequential(
    kornia.geometry.transform.Resize((512, 512)),
    kornia.augmentation.Normalize(mean, std),
).to(device)
transform_384_b3 = nn.Sequential(
    kornia.geometry.transform.Resize((384, 384)),
    kornia.augmentation.Normalize(mean, std),
).to(device)

In [None]:
def convert_slice_image(image_array, max_len=401):
    """Build 3-channel inputs from neighbouring CT slices.

    Every second slice among the first ``max_len`` slices is used as a centre. Its three
    channels are the windowed slices ``idx-1, idx, idx+1``, each produced by
    ``window(slice, 100, 700)``. At the volume edges the 3-slice window is shifted
    inward so it stays in range: ``0,1,2`` for the first slice and ``n-3,n-2,n-1`` for
    the last. Volumes with fewer than 3 slices repeat their edge slices.

    Fixes over the previous version: the centre index ``idx`` was never defined (NameError,
    or a leaked global was used), and the "last slice" branch could never trigger, so
    the final slice indexed past the end of the volume.

    Args:
        image_array: sequence of slices. Assumes ``window`` returns a 2-D (H, W) array
            per slice and slices are ordered along the scan axis; TODO confirm.
        max_len: number of leading slices considered (default 401).

    Returns:
        np.ndarray of shape (num_centres, H, W, 3) for 2-D windowed slices.
    """
    n_slices = len(image_array)
    seq_len = min(n_slices, max_len)
    pe_array = np.array([window(i, 100, 700) for i in image_array])
    image = []
    for idx in range(0, seq_len, 2):
        # Centre the window on idx, shifted inward at either end of the volume;
        # clipping only matters for volumes shorter than 3 slices.
        start = max(0, min(idx - 1, n_slices - 3))
        idx_range = np.clip(np.arange(start, start + 3), 0, n_slices - 1)
        image.append(np.transpose(pe_array[idx_range], (1, 2, 0)))
    return np.array(image)


def second_level_predict(models, cnn_feature_dict, meta_feature, seq_len):
    """Run one family of per-fold sequence models and average their sigmoid outputs.

    Args:
        models: callables, one per fold. Each takes a dict with keys ``"feature"``,
            ``'meta_feature'`` and ``"seq_len"`` and returns ``(exam_logits, image_logits)``.
        cnn_feature_dict: fold index -> feature tensor for that fold's model.
        meta_feature: per-slice meta features shared by all folds.
        seq_len: true (unpadded) sequence length, passed to the models as ``[seq_len]``.

    Returns:
        (exam_prob, image_prob): fold-averaged squeezed exam probabilities, and
        fold-averaged per-image probabilities with a trailing axis added (shape (n, 1)).
    """
    def _to_prob(logit):
        return torch.sigmoid(logit.squeeze()).cpu().numpy()

    outputs = []
    for fold_idx, seq_model in enumerate(models):
        model_input = {
            "feature": cnn_feature_dict[fold_idx],
            'meta_feature': meta_feature,
            "seq_len": [seq_len],
        }
        outputs.append(seq_model(model_input))

    exam_logits, image_logits = zip(*outputs)
    exam_prob = np.mean([_to_prob(logit) for logit in exam_logits], axis=0)
    image_prob = np.mean([_to_prob(logit) for logit in image_logits], axis=0)
    return exam_prob, image_prob[:, None]

def stacking_predict(lgb_dict, feat, n_folds=5):
    """Average the predictions of the per-fold stacking models.

    Args:
        lgb_dict: container indexable by fold number ``0..n_folds-1`` (dict or list).
            Each entry exposes ``predict(feat)``.
        feat: feature matrix passed unchanged to every fold model's ``predict``.
        n_folds: number of folds to average (default 5, the previously hard-coded count).

    Returns:
        Element-wise mean of the fold predictions.

    Raises:
        ValueError: if ``n_folds`` < 1. Averaging zero folds would silently yield NaN.
    """
    if n_folds < 1:
        raise ValueError('n_folds must be >= 1, got {}'.format(n_folds))
    return np.mean([lgb_dict[fold].predict(feat) for fold in range(n_folds)], axis=0)

def inference_2_stage(image_array, meta_data):
    """Predict exam-level and per-image PE probabilities for a single CT series.

    Stages (models, transforms and stacking objects are notebook globals):

    1. Every second slice among the first 401 is turned into a 3-window image with
       ``convert_image``. It is encoded by each fold of both image encoders
       (``image_models_512`` and ``image_models_384``).
    2. Three feature sets are built per fold: 512, 384, and their concatenation. Each is
       zero-padded to 201 steps and, with normalised per-slice meta features, fed to
       4 sequence-model families. This yields 12 fold-averaged (exam, image) predictions.
    3. Those 12 predictions are stacked by the ``lgb_dicts`` models and by the stacking
       CNN/RNN networks. The three stackers are blended with weights 2:1:2.

    Args:
        image_array: per-slice images accepted by ``convert_image``. Presumably sorted
            along the scan axis by ``get_data``; TODO confirm.
        meta_data: DataFrame with one row per slice, aligned with ``image_array``,
            holding at least the columns in ``meta_feature_cols``.

    Returns:
        (per_exam_x, per_image_x): blended exam-level label probabilities, and
        per-image probabilities for the first ``min(len(image_array), 401)`` slices.
    """
    N_MODEL = 12  # 4 sequence-model families x 3 feature sets (512, 384, concat)
    MAX_SLICES = 401  # only the first 401 slices of a series are modelled
    PADDED_STEPS = 201  # length of the every-second-slice sequence for 401 slices
    BLEND_WEIGHTS = [2, 1, 2]  # lgb_dicts stacker, stacking CNN, stacking RNN
    # Per-channel settings passed to convert_image, presumably window level/width pairs.
    WL_list = [-600, 100, 40]
    WW_list = [1500, 700, 400]
    meta_feature_cols = [
        "KVP",
        "XRayTubeCurrent",
        "Exposure",
        "SliceThickness",
        "ImagePositionPatient_x",
        "ImagePositionPatient_y",
        "ImagePositionPatient_z",
    ]
    meta_mean = np.array(
        [
            114.08353157,
            419.09953533,
            108.65329098,
            1.0090654,
            -172.34524724,
            -141.6034326,
            -45.66431551,
        ]
    )
    meta_std = np.array(
        [
            1.09305001e01,
            1.92887125e02,
            4.97377464e02,
            2.65628402e-01,
            2.62277483e01,
            7.11791545e01,
            4.41005297e02,
        ]
    )
    seq_len = min(len(image_array), MAX_SLICES)
    image = np.array(
        [convert_image(i, WL_list, WW_list) for i in image_array[:MAX_SLICES][0::2]]
    )
    pad_steps = PADDED_STEPS - len(image)
    dataset = ImageDataset(image)
    dataloader = DataLoader(dataset, batch_size=32, drop_last=False, shuffle=False)
    # One feature list per encoder fold, keyed by the encoders that fill them.
    cnn_feature_dict_512 = {i: [] for i in range(len(image_models_512))}
    cnn_feature_dict_384 = {i: [] for i in range(len(image_models_384))}
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: batch[k].to(device) for k in batch.keys()}
            with autocast():
                batch_512_image = transform_512_b5(batch["image"])
                for i in range(len(image_models_512)):
                    cnn_feature_dict_512[i].append(
                        image_models_512[i].get_feature({"image": batch_512_image})
                    )
                batch_384_image = transform_384_b3(batch["image"])
                for i in range(len(image_models_384)):
                    cnn_feature_dict_384[i].append(
                        image_models_384[i].get_feature({"image": batch_384_image})
                    )
        # Stack the batches into a single leading-batch sequence per fold, then zero-pad
        # the sequence axis (dim -2) up to PADDED_STEPS.
        cnn_feature_dict_512 = {
            i: F.pad(torch.cat(feature, dim=0).unsqueeze(0), (0, 0, 0, pad_steps))
            for i, feature in cnn_feature_dict_512.items()
        }
        cnn_feature_dict_384 = {
            i: F.pad(torch.cat(feature, dim=0).unsqueeze(0), (0, 0, 0, pad_steps))
            for i, feature in cnn_feature_dict_384.items()
        }
        cnn_feature_dict_concat = {
            i: torch.cat([cnn_feature_dict_384[i], cnn_feature_dict_512[i]], dim=-1)
            for i in cnn_feature_dict_512
        }

        meta_feature = meta_data[meta_feature_cols].values[:MAX_SLICES][0::2].astype(np.float32)
        meta_feature = (meta_feature - meta_mean) / meta_std
        meta_feature = F.pad(
            torch.tensor(meta_feature, device=cnn_feature_dict_512[0].device), (0, 0, 0, pad_steps)
        ).unsqueeze(0)
        preds_512 = [second_level_predict(models, cnn_feature_dict_512, meta_feature, seq_len) for models in [cnn_models_512, lstm_models_512, gru_models_512, cnn_rnn_models_512]]
        preds_384 = [second_level_predict(models, cnn_feature_dict_384, meta_feature, seq_len) for models in [cnn_models_384, lstm_models_384, gru_models_384, cnn_rnn_models_384]]
        preds_concat = [second_level_predict(models, cnn_feature_dict_concat, meta_feature, seq_len) for models in [cnn_models_concat, lstm_models_concat, gru_models_concat, cnn_rnn_models_concat]]
    exam_preds_512, image_preds_512 = zip(*preds_512)
    exam_preds_384, image_preds_384 = zip(*preds_384)
    exam_preds_concat, image_preds_concat = zip(*preds_concat)
    exam_preds = exam_preds_512 + exam_preds_384 + exam_preds_concat
    # One column per model. The second-stage image outputs presumably span the full slice
    # resolution (not the 201 subsampled steps), so keep the first seq_len rows; TODO confirm.
    image_preds = np.concatenate(image_preds_512 + image_preds_384 + image_preds_concat, axis=1)[:seq_len]

    # Per-image stacking features: for each model, its image prob followed by its exam probs.
    image_feats = []
    for i in range(N_MODEL):
        image_feats.append(np.concatenate([image_preds[:, [i]], exam_preds[i][None, :].repeat(seq_len, axis=0)], axis=1))
    image_feats = np.concatenate(image_feats, axis=1)
    # Exam stacking features: all exam probs + mean/std/min/max of each model's image probs.
    exam_preds = np.concatenate(exam_preds)
    image_agg = pd.DataFrame(image_preds).agg(['mean', 'std', 'min', 'max'])
    exam_feat = np.concatenate([exam_preds] + [image_agg[i].values for i in range(N_MODEL)])

    # lgb_dicts stacker: entry 0 is the per-image label, the rest are exam-level labels.
    lgb_image_pred = stacking_predict(lgb_dicts[0], image_feats)
    lgb_exam_pred = np.squeeze(np.concatenate([stacking_predict(lgb_dict, exam_feat[None, :]) for lgb_dict in lgb_dicts[1:]]))

    def _mean_prob(logits):
        # Fold-average of sigmoid probabilities with singleton dims squeezed away.
        return np.mean([torch.sigmoid(pred).detach().cpu().numpy().squeeze() for pred in logits], axis=0)

    # Stacking networks on scaler-transformed features, padded to MAX_SLICES steps.
    # Inference only: run under no_grad (previously an autograd graph was built) and on
    # the configured `device` rather than a hard-coded .cuda().
    image_feats = stacking_std.transform(image_feats)
    with torch.no_grad():
        nn_feature = F.pad(torch.tensor(image_feats).to(device).unsqueeze(0), (0, 0, 0, MAX_SLICES - seq_len))
        cnn_exam_preds, cnn_image_preds = zip(*[model({'feature': nn_feature, 'seq_len': [seq_len]}) for model in stacking_cnn_models])
        cnn_exam_pred = _mean_prob(cnn_exam_preds)
        cnn_image_pred = _mean_prob(cnn_image_preds)[:seq_len]

        rnn_exam_preds, rnn_image_preds = zip(*[model({'feature': nn_feature, 'seq_len': [seq_len]}) for model in stacking_rnn_models])
        rnn_exam_pred = _mean_prob(rnn_exam_preds)
        rnn_image_pred = _mean_prob(rnn_image_preds)[:seq_len]

    # Weighted blend of the three stackers.
    per_exam_x = np.average([lgb_exam_pred, cnn_exam_pred, rnn_exam_pred], axis=0, weights=BLEND_WEIGHTS)
    per_image_x = np.average([lgb_image_pred, cnn_image_pred, rnn_image_pred], axis=0, weights=BLEND_WEIGHTS)
    return per_exam_x, per_image_x


In [None]:
def inference_dir(dcm_dir_path):
    """Run the two-stage pipeline on one DICOM series directory.

    Args:
        dcm_dir_path: directory passed to ``get_data``.

    Returns:
        (per_exam_x, per_image_x, sop_uids): the outputs of ``inference_2_stage`` plus
        the ``SOPInstanceUID`` column of the metadata returned by ``get_data``.
    """
    slices, slice_meta = get_data(dcm_dir_path)
    exam_probs, image_probs = inference_2_stage(slices, slice_meta)
    sop_uids = slice_meta['SOPInstanceUID']
    return exam_probs, image_probs, sop_uids

In [None]:
# Every label after the first (per-image "pe_present_on_image") is an exam-level target.
per_exam_label_col = [label_col for label_col in label_cols[1:]]

In [None]:
# Run inference per test series and write submission.csv (skipped on the commit run).
if not commit_flag:
    submits = []
    for _, data in tqdm(test_df_.iterrows(), total=len(test_df_)):
        per_exam_x, per_image_x, SOPInstanceUID = inference_dir(data['dcm_dir_path'])
        n_images = len(SOPInstanceUID)
        if n_images > 401:
            # Only the first 401 slices are predicted; the remaining slices get 0.
            out = np.zeros(n_images)
            out[:401] = per_image_x
            per_image_x = out
        else:
            per_image_x = per_image_x[:n_images]
        # One row per image; the exam-level predictions are broadcast to every row.
        study_df = pd.DataFrame({'SOPInstanceUID': SOPInstanceUID, 'pe_present_on_image': per_image_x})
        study_df['StudyInstanceUID'] = data['StudyInstanceUID']
        for i, label_col in enumerate(per_exam_label_col):
            study_df[label_col] = per_exam_x[i]
        submits.append(study_df)
    submits = pd.concat(submits).reset_index(drop=True)
    # Exam values are constant within a series, so this gives one row per study
    # (assumes one series per study in test_df_).
    exam_pred = submits[['StudyInstanceUID'] + per_exam_label_col].drop_duplicates()
    submit_dict = {'id': [], 'label': []}
    for label_col in per_exam_label_col:
        submit_dict['id'].extend((exam_pred['StudyInstanceUID'] + f'_{label_col}').values.tolist())
        submit_dict['label'].extend(exam_pred[label_col].values.tolist())
    submit_dict['id'].extend(submits['SOPInstanceUID'].values.tolist())
    submit_dict['label'].extend(submits['pe_present_on_image'].values.tolist())
    submit = pd.DataFrame(submit_dict)

    # MODE is set earlier in the notebook. Any id we did not predict keeps the
    # sample-submission label.
    if MODE in ['public', 'private']:
        submit = pd.merge(submit, sample_sub.rename(columns={'label': 'old_label'}), on='id', how='outer')
        submit['label'] = submit['label'].fillna(submit['old_label'])
        submit = submit[['id', 'label']]
    submit.to_csv('submission.csv', index=False)