In [None]:
!pip install /kaggle/input/rsna-2022-whl/{pydicom-2.3.0-py3-none-any.whl,pylibjpeg-1.4.0-py3-none-any.whl,python_gdcm-3.0.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl}
!pip install -q /kaggle/input/nvidia-dali-nightly-cuda110-1230dev/nvidia_dali_nightly_cuda110-1.23.0.dev20230203-7187866-py3-none-manylinux2014_x86_64.whl
# !pip install /kaggle/input/nvidia-dali-wheel/nvidia_dali_nightly_cuda110-1.22.0.dev20221213-6757685-py3-none-manylinux2014_x86_64.whl
!pip install /kaggle/input/nvidia-dali-wheel/dicomsdl-0.109.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl

In [None]:
%%writefile /opt/conda/lib/python3.7/site-packages/nvidia/dali/plugin/pytorch.py

# Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali.backend import TensorGPU, TensorListGPU
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
from nvidia.dali import types
from nvidia.dali.plugin.base_iterator import _DaliBaseIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
import torch
import torch.utils.dlpack as torch_dlpack
import ctypes
import numpy as np

# Lookup table from DALI element types to the torch dtype used for the
# destination tensor in `feed_ndarray` / `DALIGenericIterator`.
to_torch_type = {
    types.DALIDataType.FLOAT:   torch.float32,
    types.DALIDataType.FLOAT64: torch.float64,
    types.DALIDataType.FLOAT16: torch.float16,
    types.DALIDataType.UINT8:   torch.uint8,
    types.DALIDataType.INT8:    torch.int8,
    # UINT16 is mapped onto the *signed* 16-bit torch dtype. `feed_ndarray`
    # copies through a raw data pointer, so the destination only has to have the
    # same element width.
    # NOTE(review): presumably values >= 32768 then read back as negative
    # numbers on the torch side - confirm callers account for that.
    types.DALIDataType.UINT16:  torch.int16,
    types.DALIDataType.INT16:   torch.int16,
    types.DALIDataType.INT32:   torch.int32,
    types.DALIDataType.INT64:   torch.int64
}


def feed_ndarray(dali_tensor, arr, cuda_stream=None):
    """
    Copy the contents of a DALI tensor into an already allocated PyTorch tensor.

    Parameters
    ----------
    `dali_tensor` : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
                    Source of the copy
    `arr` : torch.Tensor
            Destination of the copy; its dtype must equal
            ``to_torch_type[dali_tensor.dtype]`` and its shape must equal the
            shape of `dali_tensor`
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
                    CUDA stream used for the copy of GPU data
                    (if not provided, an internal user stream will be selected).
                    Usually this should be PyTorch's current stream, i.e. the one
                    the destination tensor was allocated on

    Returns
    -------
    torch.Tensor
        The same `arr` object, now holding the copied data.

    Raises
    ------
    AssertionError
        If the element types or the shapes of source and destination differ.
    """
    expected_dtype = to_torch_type[dali_tensor.dtype]
    assert expected_dtype == arr.dtype, (
        "The element type of DALI Tensor/TensorList"
        " doesn't match the element type of the target PyTorch Tensor: "
        "{} vs {}".format(expected_dtype, arr.dtype))

    source_shape = dali_tensor.shape()
    target_shape = list(arr.size())
    assert source_shape == target_shape, (
        "Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".format(
            source_shape, target_shape))

    raw_stream = types._raw_cuda_stream(cuda_stream)

    # The copy API wants the destination address as a C void pointer
    destination = ctypes.c_void_p(arr.data_ptr())

    if not isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
        # CPU data: plain copy, no stream involved
        dali_tensor.copy_to_external(destination)
        return arr

    # GPU data: asynchronous copy on the requested stream (None lets DALI choose)
    stream_ptr = ctypes.c_void_p(raw_stream) if raw_stream is not None else None
    dali_tensor.copy_to_external(destination, stream_ptr, non_blocking=True)
    return arr


class DALIGenericIterator(_DaliBaseIterator):
    """
    General DALI iterator for PyTorch. It can return any number of
    outputs from the DALI pipeline in the form of PyTorch's Tensors.

    Each call to ``next()`` returns a list with one entry per pipeline; every entry
    is a dictionary mapping the names given in `output_map` to PyTorch tensors.

    Parameters
    ----------
    pipelines : list of nvidia.dali.Pipeline
                List of pipelines to use
    output_map : list of str
                List of strings which maps consecutive outputs
                of DALI pipelines to user specified name.
                Outputs will be returned from iterator as dictionary
                of those names.
                Each name should be distinct
    size : int, default = -1
                Number of samples in the shard for the wrapped pipeline (if there is more than
                one it is a sum)
                Providing -1 means that the iterator will work until StopIteration is raised
                from the inside of iter_setup(). The options `last_batch_policy` and
                `last_batch_padded` don't work in such case. It works with only one pipeline inside
                the iterator.
                Mutually exclusive with `reader_name` argument
    reader_name : str, default = None
                Name of the reader which will be queried to the shard size, number of shards and
                all other properties necessary to count properly the number of relevant and padded
                samples that iterator needs to deal with. It automatically sets `last_batch_policy`
                to PARTIAL when the FILL is used, and `last_batch_padded` accordingly to match
                the reader's configuration
    auto_reset : string or bool, optional, default = False
                Whether the iterator resets itself for the next epoch or it requires reset() to be
                called explicitly.

                It can be one of the following values:

                * ``"no"``, ``False`` or ``None`` - at the end of epoch StopIteration is raised
                  and reset() needs to be called
                * ``"yes"`` or ``"True"``- at the end of epoch StopIteration is raised but reset()
                  is called internally automatically

    dynamic_shape : any, optional,
                Parameter used only for backward compatibility.
    fill_last_batch : bool, optional, default = None
                **Deprecated** Please use ``last_batch_policy`` instead

                Whether to fill the last batch with data up to 'self.batch_size'.
                The iterator would return the first integer multiple
                of self._num_gpus * self.batch_size entries which exceeds 'size'.
                Setting this flag to False will cause the iterator to return
                exactly 'size' entries.
    last_batch_policy: optional, default = LastBatchPolicy.FILL
                What to do with the last batch when there are not enough samples in the epoch
                to fully fill it. See :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
    last_batch_padded : bool, optional, default = False
                Whether the last batch provided by DALI is padded with the last sample
                or it just wraps up. In the conjunction with ``last_batch_policy`` it tells
                if the iterator returning last batch with data only partially filled with
                data from the current epoch is dropping padding samples or samples from
                the next epoch. If set to ``False`` next
                epoch will end sooner as data from it was consumed but dropped. If set to
                True next epoch would be the same length as the first one. For this to happen,
                the option `pad_last_batch` in the reader needs to be set to True as well.
                It is overwritten when `reader_name` argument is provided
    prepare_first_batch : bool, optional, default = True
                Whether DALI should buffer the first batch right after the creation of the iterator,
                so one batch is already prepared when the iterator is prompted for the data

    Example
    -------
    With the data set ``[1,2,3,4,5,6,7]`` and the batch size 2:

    last_batch_policy = LastBatchPolicy.PARTIAL, last_batch_padded = True  -> last batch = ``[7]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.PARTIAL, last_batch_padded = False -> last batch = ``[7]``,
    next iteration will return ``[2, 3]``

    last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True   -> last batch = ``[7, 7]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False  -> last batch = ``[7, 1]``,
    next iteration will return ``[2, 3]``

    last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True   -> last batch = ``[5, 6]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False  -> last batch = ``[5, 6]``,
    next iteration will return ``[2, 3]``
    """

    def __init__(self,
                 pipelines,
                 output_map,
                 size=-1,
                 reader_name=None,
                 auto_reset=False,
                 fill_last_batch=None,
                 dynamic_shape=False,
                 last_batch_padded=False,
                 last_batch_policy=LastBatchPolicy.FILL,
                 prepare_first_batch=True):
        # NOTE: `dynamic_shape` is accepted but never read in this constructor.

        # check the assert first as _DaliBaseIterator would run the prefetch
        assert len(set(output_map)) == len(output_map), "output_map names should be distinct"
        self._output_categories = set(output_map)
        self.output_map = output_map

        _DaliBaseIterator.__init__(self,
                                   pipelines,
                                   size,
                                   reader_name,
                                   auto_reset,
                                   fill_last_batch,
                                   last_batch_padded,
                                   last_batch_policy,
                                   prepare_first_batch=prepare_first_batch)

        # Holds a batch fetched eagerly at construction time; handed out (once) by the
        # first user-visible call to __next__.
        self._first_batch = None
        if self._prepare_first_batch:
            try:
                self._first_batch = DALIGenericIterator.__next__(self)
                # call to `next` sets _ever_consumed to True but if we are just calling it from
                # here we should set it to False again
                self._ever_consumed = False
            except StopIteration:
                assert False, "It seems that there is no data in the pipeline. This may happen " \
                       "if `last_batch_policy` is set to PARTIAL and the requested batch size is " \
                       "greater than the shard size."

    def __next__(self):
        """
        Return the next batch as a list (one entry per pipeline) of dictionaries
        mapping output names to PyTorch tensors.

        The last batch of an epoch may be trimmed, depending on the reader
        configuration / ``last_batch_policy`` handled at the end of this method.
        """
        self._ever_consumed = True
        # Serve the batch that was prefetched in __init__, if there is one
        if self._first_batch is not None:
            batch = self._first_batch
            self._first_batch = None
            return batch

        # Gather outputs
        outputs = self._get_outputs()

        data_batches = [None for i in range(self._num_gpus)]
        for i in range(self._num_gpus):
            dev_id = self._pipes[i].device_id
            # initialize dict for all output categories
            category_outputs = dict()
            # segregate outputs into categories
            for j, out in enumerate(outputs[i]):
                category_outputs[self.output_map[j]] = out

            # Change DALI TensorLists into Tensors
            category_tensors = dict()
            category_shapes = dict()
            for category, out in category_outputs.items():
                category_tensors[category] = out.as_tensor()
                category_shapes[category] = category_tensors[category].shape()

            category_torch_type = dict()
            category_device = dict()
            torch_gpu_device = None
            torch_cpu_device = torch.device('cpu')
            # check category and device
            for category in self._output_categories:
                category_torch_type[category] = to_torch_type[category_tensors[category].dtype]
                if type(category_tensors[category]) is TensorGPU:
                    # the torch device object is created lazily and shared by all GPU outputs
                    if not torch_gpu_device:
                        torch_gpu_device = torch.device('cuda', dev_id)
                    category_device[category] = torch_gpu_device
                else:
                    category_device[category] = torch_cpu_device

            # Allocate (uninitialised) destination tensors matching each output
            pyt_tensors = dict()
            for category in self._output_categories:
                pyt_tensors[category] = torch.empty(category_shapes[category],
                                                    dtype=category_torch_type[category],
                                                    device=category_device[category])

            data_batches[i] = pyt_tensors

            # Copy data from DALI Tensors to torch tensors
            for category, tensor in category_tensors.items():
                if isinstance(tensor, (TensorGPU, TensorListGPU)):
                    # Using the current torch stream of the device on which the
                    # destination tensor was allocated with torch.empty
                    stream = torch.cuda.current_stream(device=pyt_tensors[category].device)
                    feed_ndarray(tensor, pyt_tensors[category], cuda_stream=stream)
                else:
                    feed_ndarray(tensor, pyt_tensors[category])

        # Base-class bookkeeping (defined in _DaliBaseIterator, not shown here).
        # NOTE(review): presumably these queue the next pipeline run and update the
        # sample counter / epoch-end state - confirm against the base class.
        self._schedule_runs()

        self._advance_and_check_drop_last()

        if self._reader_name:
            # Reader-driven mode: the base class reports, per pipeline, whether samples
            # have to be dropped and how many are to be kept
            if_drop, left = self._remove_padded()
            if np.any(if_drop):
                output = []
                for batch, to_copy in zip(data_batches, left):
                    # shallow copy so the dictionaries in data_batches stay untouched
                    batch = batch.copy()
                    for category in self._output_categories:
                        batch[category] = batch[category][0:to_copy]
                    output.append(batch)
                return output

        else:
            # Size-driven mode: trim the final batch when PARTIAL policy is requested
            # and the counter ran past the declared size
            if self._last_batch_policy == LastBatchPolicy.PARTIAL and (
                                          self._counter > self._size) and self._size > 0:
                # First calculate how much data is required to return exactly self._size entries.
                diff = self._num_gpus * self.batch_size - (self._counter - self._size)
                # Figure out how many GPUs to grab from.
                numGPUs_tograb = int(np.ceil(diff / self.batch_size))
                # Figure out how many results to grab from the last GPU
                # (as a fractional GPU batch may be required to bring us
                # right up to self._size).
                mod_diff = diff % self.batch_size
                data_fromlastGPU = mod_diff if mod_diff else self.batch_size

                # Grab the relevant data.
                # 1) Grab everything from the relevant GPUs.
                # 2) Grab the right data from the last GPU.
                # 3) Append data together correctly and return.
                output = data_batches[0:numGPUs_tograb]
                output[-1] = output[-1].copy()
                for category in self._output_categories:
                    output[-1][category] = output[-1][category][0:data_fromlastGPU]
                return output

        return data_batches


class DALIClassificationIterator(DALIGenericIterator):
    """
    DALI iterator for PyTorch classification tasks.

    This is a thin specialisation of :class:`DALIGenericIterator` whose output
    map is fixed to ``["data", "label"]``: every batch entry is a dictionary with
    exactly those two PyTorch tensors.

    Calling

    .. code-block:: python

       DALIClassificationIterator(pipelines, reader_name=reader_name)

    is equivalent to calling

    .. code-block:: python

       DALIGenericIterator(pipelines, ["data", "label"], reader_name=reader_name)

    Parameters
    ----------
    pipelines : list of nvidia.dali.Pipeline
                Pipelines to iterate over
    size : int, default = -1
                Number of samples in the shard of the wrapped pipeline (summed when there
                are several pipelines). With -1 the iterator runs until StopIteration is
                raised from inside iter_setup(); `last_batch_policy` and
                `last_batch_padded` have no effect in that case and only a single
                pipeline is supported.
                Mutually exclusive with `reader_name`
    reader_name : str, default = None
                Name of the reader queried for the shard size, number of shards and the
                other properties needed to count relevant and padded samples. It
                switches `last_batch_policy` to PARTIAL when FILL is used and sets
                `last_batch_padded` to match the reader's configuration
    auto_reset : string or bool, optional, default = False
                Whether the iterator resets itself for the next epoch or reset() has to
                be called explicitly:

                * ``"no"``, ``False`` or ``None`` - StopIteration is raised at the end of
                  the epoch and reset() needs to be called
                * ``"yes"`` or ``"True"``- StopIteration is raised at the end of the epoch
                  but reset() is called internally automatically

    dynamic_shape : any, optional,
                Kept only for backward compatibility.
    fill_last_batch : bool, optional, default = None
                **Deprecated** Please use ``last_batch_policy`` instead

                Whether to fill the last batch with data up to 'self.batch_size'.
                When filling, the iterator returns the first integer multiple of
                self._num_gpus * self.batch_size entries which exceeds 'size';
                with False it returns exactly 'size' entries.
    last_batch_policy: optional, default = LastBatchPolicy.FILL
                What to do with the last batch when the epoch does not have enough
                samples to fill it. See
                :meth:`nvidia.dali.plugin.base_iterator.LastBatchPolicy`
    last_batch_padded : bool, optional, default = False
                Whether the last batch provided by DALI is padded with the last sample
                or wraps around into the next epoch. Together with
                ``last_batch_policy`` this decides whether a partially filled last batch
                drops padding samples or samples of the next epoch. With ``False`` the
                next epoch ends sooner (its data was consumed but dropped); with True
                every epoch has the same length, which additionally requires
                `pad_last_batch` to be set in the reader.
                Overwritten when `reader_name` is provided
    prepare_first_batch : bool, optional, default = True
                Whether the first batch is buffered right after the iterator is created,
                so it is ready as soon as data is requested

    Example
    -------
    With the data set ``[1,2,3,4,5,6,7]`` and the batch size 2:

    last_batch_policy = LastBatchPolicy.PARTIAL, last_batch_padded = True  -> last batch = ``[7]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.PARTIAL, last_batch_padded = False -> last batch = ``[7]``,
    next iteration will return ``[2, 3]``

    last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True   -> last batch = ``[7, 7]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False  -> last batch = ``[7, 1]``,
    next iteration will return ``[2, 3]``

    last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True   -> last batch = ``[5, 6]``,
    next iteration will return ``[1, 2]``

    last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False  -> last batch = ``[5, 6]``,
    next iteration will return ``[2, 3]``
    """

    def __init__(self,
                 pipelines,
                 size=-1,
                 reader_name=None,
                 auto_reset=False,
                 fill_last_batch=None,
                 dynamic_shape=False,
                 last_batch_padded=False,
                 last_batch_policy=LastBatchPolicy.FILL,
                 prepare_first_batch=True):
        # Everything is forwarded unchanged; only the output names are fixed here.
        super().__init__(
            pipelines,
            output_map=["data", "label"],
            size=size,
            reader_name=reader_name,
            auto_reset=auto_reset,
            fill_last_batch=fill_last_batch,
            dynamic_shape=dynamic_shape,
            last_batch_padded=last_batch_padded,
            last_batch_policy=last_batch_policy,
            prepare_first_batch=prepare_first_batch,
        )


class TorchPythonFunction(ops.PythonFunctionBase):
    """
    DALI operator that runs a user supplied function on PyTorch tensors.

    Data is exchanged with DALI through DLPack. For non-CPU devices the function
    is executed on a dedicated torch CUDA stream that is created on first use
    and synchronized after every invocation.
    """
    schema_name = "TorchPythonFunction"
    ops.register_cpu_op('TorchPythonFunction')
    ops.register_gpu_op('TorchPythonFunction')

    def __init__(self, function, num_outputs=1, device='cpu', batch_processing=False, **kwargs):
        # The CUDA stream is created lazily in __call__, once the pipeline (and
        # therefore the device id) is known.
        self.stream = None
        super().__init__(
            impl_name="DLTensorPythonFunctionImpl",
            function=lambda *ins: self.torch_wrapper(batch_processing, function, device, *ins),
            num_outputs=num_outputs,
            device=device,
            batch_processing=batch_processing,
            **kwargs)

    def __call__(self, *inputs, **kwargs):
        current = Pipeline.current()
        if current is None:
            Pipeline._raise_no_current_pipeline("TorchPythonFunction")
        if self.stream is None:
            self.stream = torch.cuda.Stream(device=current.device_id)
        return super().__call__(*inputs, **kwargs)

    def _torch_stream_wrapper(self, function, *ins):
        """Run `function` on this operator's stream and wait for it to finish."""
        with torch.cuda.stream(self.stream):
            result = function(*ins)
        self.stream.synchronize()
        return result

    def torch_wrapper(self, batch_processing, function, device, *args):
        """Invoke `function` through the batch or per-sample DLPack wrapper."""
        if device == 'cpu':
            func = function
        else:
            func = lambda *ins: self._torch_stream_wrapper(function, *ins)

        if batch_processing:
            dlpack_wrapper = ops.PythonFunction.function_wrapper_batch
        else:
            dlpack_wrapper = ops.PythonFunction.function_wrapper_per_sample

        return dlpack_wrapper(func,
                              self.num_outputs,
                              torch_dlpack.from_dlpack,
                              torch_dlpack.to_dlpack,
                              *args)


# Expose the operator through DALI's functional ("fn") API as well
ops._wrap_op(TorchPythonFunction, "fn", __name__)

In [None]:
import os
import sys
import cv2
import glob
import gdcm
import json
import shutil
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import gc

from tqdm.notebook import tqdm
from joblib import Parallel, delayed

In [None]:
# Collect every DICOM of the test set (one sub-folder per patient).
IMG_DIR = Path("/kaggle/input/rsna-breast-cancer-detection/test_images/")
test_images = list(IMG_DIR.glob("*/*.dcm"))

# Exactly 4 test images switches the notebook to debug mode, in which a fixed
# sample of training patients is processed instead.
DEBUG = len(test_images) == 4

if DEBUG:
    DEBUG_FOLD = 0
    fold_df = pd.read_csv('/kaggle/input/rsna-misc/train_with_fold.csv')
#     SAMPLE_ID = sorted(fold_df.query('fold==@DEBUG_FOLD').patient_id.unique().tolist())[:1000]
    SAMPLE_ID = {
        42624, 48001, 48514, 2179, 31107, 23554, 13185, 53255, 29192, 59530,
        64908, 32527, 13845, 59552, 54816, 49954, 55330, 59307, 21934, 63536,
        23729, 61490, 61874, 16955, 46014, 38727, 64456, 50375, 9162, 55755,
        25550, 15696, 50002, 58195, 10198, 13016, 25050, 31581, 26333, 29664,
        8289, 3305, 6637, 48493, 58610, 42231, 12282, 9083, 32252, 39677,
    }
    IMG_DIR = Path("/kaggle/input/rsna-breast-cancer-detection/train_images/")
    test_images = []
    for pid in SAMPLE_ID:
        test_images += IMG_DIR.glob(f"{pid}/*.dcm")

print("Number of images :", len(test_images))

In [None]:
# Output resolution (square side, in pixels) of the exported PNGs.
SIZE = 2048

# Destination of the exported PNGs and scratch folder for extracted bitstreams.
EXPORT_DIR = Path("/tmp/output/")
EXPORT_DIR.mkdir(exist_ok=True)
J2K_FOLDER = Path("/tmp/j2k/")

# Larger runs are split into 4 chunks, small ones are handled in one go.
N_CHUNKS = 4 if len(test_images) > 100 else 1

# One (start, stop) row of indices into `test_images` per chunk; the fractional
# boundaries are truncated to integers.
CHUNKS = np.array(
    [[len(test_images) / N_CHUNKS * bound for bound in (k, k + 1)] for k in range(N_CHUNKS)]
).astype(int)

In [None]:
import dicomsdl
import torch
import torch.nn.functional as F
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def
from nvidia.dali.types import DALIDataType
from pydicom.filebase import DicomBytesIO
from nvidia.dali.plugin.pytorch import feed_ndarray, to_torch_type


def convert_dicom_to_j2k(file, save_folder):
    """
    Extract the compressed image bitstream of a DICOM file into ``save_folder``.

    Only files using one of the two supported transfer syntaxes are handled;
    for every other file nothing is written. The bitstream is stored as
    ``{patient}_{image}.jp2`` (the ``.jp2`` extension is used for both syntaxes),
    where ``patient`` is the name of the parent folder and ``image`` the file stem.

    Parameters
    ----------
    file : pathlib.Path
        Path of the DICOM file, laid out as ``<patient>/<image>.dcm``.
    save_folder : pathlib.Path
        Existing folder receiving the extracted bitstream.

    Returns
    -------
    None
    """
    # Transfer syntax UID -> byte marker at which the embedded bitstream starts
    stream_markers = {
        '1.2.840.10008.1.2.4.90': b"\x00\x00\x00\x0C",  # the jpeg2000 header info we're looking for
        '1.2.840.10008.1.2.4.70': b"\xff\xd8\xff\xe0",  # the jpeg lossless header info we're looking for
    }

    patient, image = file.parents[0].stem, file.stem
    # The file is parsed once; the same dataset provides both the transfer
    # syntax and the pixel data.
    dcmfile = pydicom.dcmread(str(file))

    marker = stream_markers.get(str(dcmfile.file_meta.TransferSyntaxUID))
    if marker is None:
        return

    pixel_data = dcmfile.PixelData
    offset = pixel_data.find(marker)
    if offset < 0:
        # Marker not found: writing PixelData[-1:] would only produce an
        # undecodable one-byte file, so skip this image instead.
        return

    with open(save_folder/f"{patient}_{image}.jp2", "wb") as binary_file:
        binary_file.write(pixel_data[offset:])

            
@pipeline_def
def j2k_decode_pipeline(j2kfiles):
    """
    DALI pipeline that reads the given bitstream files and decodes them with the
    'mixed' backend into UINT16 images, keeping the data as stored (ANY_DATA).

    `j2kfiles` is the list of files handed to the file reader; the labels the
    reader also produces are discarded.
    """
    encoded, _unused_labels = fn.readers.file(files=j2kfiles)
    return fn.experimental.decoders.image(
        encoded, device='mixed', output_type=types.ANY_DATA, dtype=DALIDataType.UINT16)


def load_image_pydicom(dataset, voi_lut=False, size=1024):
    """
    Build a normalised, resized image from an already parsed pydicom dataset.

    Parameters
    ----------
    dataset : pydicom dataset
        Must provide ``pixel_array`` and ``PhotometricInterpretation``.
    voi_lut : bool, default = False
        Apply the VOI LUT / windowing stored in the dataset first.
    size : int, default = 1024
        Side length of the square output image.

    Returns
    -------
    numpy.ndarray
        ``size`` x ``size`` floating point image scaled to ``[0, 1]``.
        A constant image yields all zeros.
    """
    img = dataset.pixel_array
    if voi_lut:
        img = apply_voi_lut(img, dataset)
    if dataset.PhotometricInterpretation == "MONOCHROME1":
        # MONOCHROME1 stores inverted intensities
        img = np.max(img) - img

    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # Constant image: avoid 0/0 (NaN output), normalise to all zeros
        value_range = 1
    img = (img - lowest) / value_range

    img = cv2.resize(img, (size, size))
    return img


def load_image_dicomsdl(img_path, voi_lut=False, size=1024):
    """
    Read a DICOM file with dicomsdl and return a normalised, resized image.

    Parameters
    ----------
    img_path : path of the DICOM file, passed to ``dicomsdl.open``
    voi_lut : bool, default = False
        Apply windowing from ``WindowCenter`` / ``WindowWidth`` / ``BitsStored``,
        using the sigmoid formula when ``VOILUTFunction`` is ``"SIGMOID"`` and the
        linear one otherwise.
    size : int, default = 1024
        Side length of the square output image.

    Returns
    -------
    numpy.ndarray
        ``size`` x ``size`` floating point image scaled to ``[0, 1]``.
        A constant image yields all zeros.
    """
    dataset = dicomsdl.open(img_path)
    img = dataset.pixelData()

    if voi_lut:
        # Load only the variables we need
        center = dataset["WindowCenter"]
        width = dataset["WindowWidth"]
        bits_stored = dataset["BitsStored"]
        voi_lut_function = dataset["VOILUTFunction"]

        # For sigmoid it's a list, otherwise a single value
        if isinstance(center, list):
            center = center[0]
        if isinstance(width, list):
            width = width[0]

        # Set y_min, max & range
        y_min = 0
        y_max = float(2**bits_stored - 1)
        y_range = y_max

        # Function with default LINEAR (so for Nan, it will use linear)
        if voi_lut_function == "SIGMOID":
            img = y_range / (1 + np.exp(-4 * (img - center) / width)) + y_min
        else:
            # Checks width for < 1 (in our case not necessary, always >= 750)
            center -= 0.5
            width -= 1

            # The masks are computed before any pixel is overwritten
            below = img <= (center - width / 2)
            above = img > (center + width / 2)
            between = np.logical_and(~below, ~above)

            # Written in place, so results are stored with the dtype of `img`
            img[below] = y_min
            img[above] = y_max
            if between.any():
                img[between] = (
                    ((img[between] - center) / width + 0.5) * y_range + y_min
                )

    if dataset["PhotometricInterpretation"] == "MONOCHROME1":
        # MONOCHROME1 stores inverted intensities
        img = np.max(img) - img

    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # Constant image: avoid 0/0 (NaN output), normalise to all zeros
        value_range = 1
    img = (img - lowest) / value_range

    img = cv2.resize(img, (size, size))

    return img


def torch_voi_lut(img, dataset):
    """
    Apply DICOM windowing (VOI LUT) to a torch tensor.

    ``dataset`` is indexed with ``"WindowCenter"``, ``"WindowWidth"``,
    ``"BitsStored"`` and ``"VOILUTFunction"``. The output range is
    ``[0, 2**BitsStored - 1]``.

    For ``"SIGMOID"`` a new tensor is returned and ``img`` is left untouched.
    For every other function value the linear window is used: ``img`` is
    modified in place and returned.
    """
    def _scalar(value):
        # For sigmoid it's a list, otherwise a single value
        return value[0] if isinstance(value, list) else value

    window_center = _scalar(dataset["WindowCenter"])
    window_width = _scalar(dataset["WindowWidth"])

    # Output bounds
    out_min = 0
    out_max = float(2 ** dataset["BitsStored"] - 1)
    out_range = out_max

    if dataset["VOILUTFunction"] == "SIGMOID":
        return out_range / (1 + torch.exp(-4 * (img - window_center) / window_width)) + out_min

    # LINEAR is the default for anything that is not SIGMOID.
    # No check for width < 1 (in our case not necessary, always >= 750)
    window_center -= 0.5
    window_width -= 1
    half_width = window_width / 2

    # All masks are built before the tensor is overwritten
    is_below = img <= (window_center - half_width)
    is_above = img > (window_center + half_width)
    in_window = ~(is_below | is_above)

    img[is_below] = out_min
    img[is_above] = out_max
    if in_window.any():
        img[in_window] = (
            ((img[in_window] - window_center) / window_width + 0.5) * out_range + out_min
        )
    return img

In [None]:
# GPU path: for every chunk of DICOMs, dump the embedded compressed bitstreams to
# J2K_FOLDER, decode them with DALI, post-process on the GPU and save PNGs to
# EXPORT_DIR. Images that fail here simply get no PNG.
for chunk in tqdm(CHUNKS):
    J2K_FOLDER.mkdir(exist_ok=True)

    _ = Parallel(n_jobs=2)(
        delayed(convert_dicom_to_j2k)(img, save_folder=J2K_FOLDER)
        for img in test_images[chunk[0]: chunk[1]]
    )

    j2kfiles = list(J2K_FOLDER.glob("*.jp2"))

    if not j2kfiles:
        continue

    pipe = j2k_decode_pipeline(j2kfiles, batch_size=1, num_threads=2, device_id=0, debug=True)
    pipe.build()

    for i, f in enumerate(j2kfiles):
        patient, image = f.stem.split('_')
        dicom = dicomsdl.open(str(IMG_DIR/f"{patient}/{image}.dcm"))

        try:
            out = pipe.run()

            # Dali -> Torch
            img = out[0][0]
            img_torch = torch.empty(img.shape(), dtype=torch.int16, device="cuda")
            feed_ndarray(img, img_torch, cuda_stream=torch.cuda.current_stream(device=0))
            # The pipeline decodes to UINT16 but the bytes land in a signed int16
            # tensor, so values >= 32768 would read as negative. Widen to int32 and
            # keep the low 16 bits to recover the unsigned value before converting.
            img = (img_torch.int() & 0xFFFF).float()

            # Scale, resize, invert on GPU !
            img = torch_voi_lut(img, dicom)

            if dicom.PhotometricInterpretation == "MONOCHROME1":
                img = img.amax() - img

            min_, max_ = img.amin(), img.amax()
            img = (img - min_) / (max_ - min_)

            if SIZE:
                img = F.interpolate(
                    img.view(1, 1, img.size(0), img.size(1)), (SIZE, SIZE), mode="bilinear")[0, 0]

            # Back to CPU + SAVE
            img = (img * 255).cpu().numpy().astype(np.uint8)

            cv2.imwrite(str(EXPORT_DIR /f"{patient}_{image}.png"), img)

        except Exception as e:

            print(i, e)
            # Restart the pipeline on the files after the failing one so that
            # pipeline outputs and loop index stay aligned. When the failing file
            # was the last one there is nothing left to decode, and building a
            # pipeline on an empty file list must be avoided.
            remaining = j2kfiles[i+1:]
            if remaining:
                pipe = j2k_decode_pipeline(remaining, batch_size=1, num_threads=2, device_id=0, debug=True)
                pipe.build()
            continue

    shutil.rmtree(J2K_FOLDER)

In [None]:
# PNGs exported so far, i.e. the images handled by the GPU path above.
fns = glob.glob(str(EXPORT_DIR / "*.png"))
n_saved = len(fns)
print(f'Image on disk count : {n_saved}')

In [None]:
# Map every exported "<patient>_<image>.png" back to its "<patient>/<image>.dcm" key.
gpu_processed_files = [
    png_path.split('/')[-1].replace('_','/').replace('png','dcm') for png_path in fns
]
# Set for O(1) membership tests; checking against the list is quadratic overall.
gpu_processed_lookup = set(gpu_processed_files)
# Everything without a PNG yet still has to go through the CPU path.
to_process_images = [
    f for f in test_images
    if '/'.join(str(f).split('/')[-2:]) not in gpu_processed_lookup
]
len(gpu_processed_files), len(to_process_images)

In [None]:
def process(f, size=2048, save_folder=""):
    """Decode one DICOM file on the CPU and write it as ``<patient>_<image>.png``.

    The patient id is taken from the name of the file's parent directory and
    the image id from the file stem. Decoding is attempted with
    ``load_image_dicomsdl`` first; if that raises, the file is read with
    pydicom and decoded with ``load_image_pydicom`` instead.

    Args:
        f: ``pathlib.Path`` to the ``.dcm`` file.
        size: Forwarded unchanged to the loader as ``size``.
        save_folder: Output directory (``str`` or ``Path``). The default ``""``
            writes relative to the current working directory.
    """
    # Local import so the function does not depend on a `Path` name being in
    # the notebook's globals.
    from pathlib import Path

    patient, image = f.parents[0].stem, f.stem

    try:
        img = load_image_dicomsdl(f, voi_lut=True, size=size)
    except Exception:
        # Only parse the file with pydicom when the dicomsdl path failed, so the
        # common case does not read the file a second time.
        dicom = pydicom.dcmread(f)
        img = load_image_pydicom(dicom, voi_lut=True, size=size)

    # presumably the loaders return a float image scaled to [0, 1] — TODO confirm
    out_path = Path(save_folder) / f"{patient}_{image}.png"
    cv2.imwrite(str(out_path), (img * 255).astype(np.uint8))

In [None]:
# Run `process` over every image in `to_process_images` using two joblib
# workers, writing the PNGs to EXPORT_DIR. The return values are discarded.
_ = Parallel(n_jobs=2)(
    delayed(process)(img, size=SIZE, save_folder=EXPORT_DIR)
    for img in tqdm(to_process_images)
)

In [None]:
!cp ../input/rsna-mammo-2023/dataset/*.py ./
!ln -s ../input/rsna-mammo-2023/dataset/kuma_utils/ 
!ln -s ../input/rsna-mammo-2023/dataset/timm/ 
!ln -s ../input/rsna-mammo-2023/dataset/iterstrat/ 
!ln -s ../input/rsna-mammo-2023/dataset/segmentation_models_pytorch/ 
!ln -s ../input/rsna-mammo-2023/dataset/global_objectives/

from configs import *
from kuma_utils.utils import sigmoid
from timm.layers import convert_sync_batchnorm
from metrics import Pfbeta, PercentilePfbeta

from torch.cuda import amp
from tqdm.auto import tqdm
import pickle

In [None]:
!cp ../input/ishikei-mammo/*.py ./
from ishikei_configs import *

In [None]:
!pip install /kaggle/input/einops-download/einops-0.6.0-py3-none-any.whl > /dev/null

from functools import partial

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
import timm
from timm.models.layers import DropPath, trunc_normal_
from timm.models import register_model
from torch import nn
import warnings
warnings.simplefilter('ignore')

def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
    """Fold one or two preceding BatchNorm layers into ``module`` in place.

    The BatchNorm running statistics and affine parameters are rewritten as a
    per-input-channel scale and shift, which are then absorbed into
    ``module.weight`` and ``module.bias`` so the BatchNorm can be skipped at
    inference time.

    Args:
        module: ``nn.Linear`` or 1x1 ``nn.Conv2d`` that consumes the normalized
            input. A bias parameter is created when the module has none.
        pre_bn_1: BatchNorm applied first. Must track running stats and be affine.
        pre_bn_2: Optional second BatchNorm applied after ``pre_bn_1``, under
            the same requirements.

    Raises:
        AssertionError: if a BatchNorm does not track running stats or is not
            affine, or if a ``Conv2d`` kernel is not 1x1.
    """
    weight = module.weight.data
    if module.bias is None:
        # weight.size(0) is the output width for both nn.Linear and nn.Conv2d.
        # The previous `module.out_channels` lookup only exists on conv layers
        # and raised AttributeError for a bias-free nn.Linear.
        zeros = torch.zeros(weight.size(0), device=weight.device, dtype=weight.dtype)
        module.bias = nn.Parameter(zeros)
    bias = module.bias.data

    assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
    assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"
    scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)

    if pre_bn_2 is None:
        # bn1(x) = extra_weight * x + extra_bias
        extra_weight = scale_invstd_1 * pre_bn_1.weight
        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1
    else:
        assert pre_bn_2.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_2.affine is True, "Unsupport bn_module.affine is False"
        scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)

        # bn2(bn1(x)) = extra_weight * x + extra_bias
        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
        extra_bias = scale_invstd_2 * pre_bn_2.weight *(pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + pre_bn_2.bias

    if isinstance(module, nn.Linear):
        # The shift must be projected with the *unscaled* weight, so compute it
        # before the in-place column scaling below.
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
    elif isinstance(module, nn.Conv2d):
        assert weight.shape[2] == 1 and weight.shape[3] == 1
        # Treat the 1x1 kernel as an (out, in) matrix.
        weight = weight.reshape(weight.shape[0], weight.shape[1])
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
        weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
    bias.add_(extra_bias)

    module.weight.data = weight
    module.bias.data = bias



NORM_EPS = 1e-5


class ConvBNReLU(nn.Module):
    """Bias-free 2D convolution followed by BatchNorm and an in-place ReLU.

    The convolution always uses ``padding=1``; the BatchNorm uses ``NORM_EPS``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> norm -> activation and return the result."""
        return self.act(self.norm(self.conv(x)))


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):
    """Optional 2x average-pool downsample plus a 1x1 conv + BatchNorm projection.

    * ``stride == 2``: average-pool by 2, then 1x1 conv and BatchNorm.
    * ``stride != 2`` and the channel counts differ: 1x1 conv and BatchNorm only.
    * otherwise: every stage is ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        downsample = stride == 2
        if downsample or in_channels != out_channels:
            if downsample:
                self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            else:
                self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        """Apply pool -> conv -> norm in that order."""
        pooled = self.avgpool(x)
        projected = self.conv(pooled)
        return self.norm(projected)


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention

    A grouped 3x3 convolution (``out_channels // head_dim`` groups) followed by
    BatchNorm, ReLU and a bias-free 1x1 projection. Channel count is preserved.
    """

    def __init__(self, out_channels, head_dim):
        super().__init__()
        n_groups = out_channels // head_dim
        self.group_conv3x3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=n_groups, bias=False
        )
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return projection(relu(norm(group_conv3x3(x))))."""
        features = self.act(self.norm(self.group_conv3x3(x)))
        return self.projection(features)


class Mlp(nn.Module):
    """Two-layer feed-forward block built from 1x1 convolutions.

    The hidden width is ``in_features * mlp_ratio`` rounded with
    ``_make_divisible(..., 32)``. Dropout is applied after the activation and
    again after the second convolution.
    """

    def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def merge_bn(self, pre_norm):
        """Fold the BatchNorm ``pre_norm`` into the first convolution."""
        merge_pre_bn(self.conv1, pre_norm)

    def forward(self, x):
        """Run conv1 -> ReLU -> dropout -> conv2 -> dropout."""
        hidden = self.drop(self.act(self.conv1(x)))
        return self.drop(self.conv2(hidden))


class NCB(nn.Module):
    """
    Next Convolution Block

    Patch embedding, then a residual MHCA branch, then a residual MLP branch
    that is preceded by a BatchNorm. Once ``merge_bn`` has been called the
    BatchNorm is folded into the MLP and skipped in ``forward``.
    """

    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
                 drop=0, head_dim=32, mlp_ratio=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def merge_bn(self):
        """Fold ``self.norm`` into the MLP once; later calls do nothing."""
        if self.is_bn_merged:
            return
        self.mlp.merge_bn(self.norm)
        self.is_bn_merged = True

    def forward(self, x):
        """Apply the embedding and both residual branches."""
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        # The explicit BatchNorm is bypassed during ONNX export and after merging.
        bypass_norm = torch.onnx.is_in_onnx_export() or self.is_bn_merged
        mlp_input = x if bypass_norm else self.norm(x)
        return x + self.mlp_path_dropout(self.mlp(mlp_input))


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention

    Queries come from the full token sequence. When ``sr_ratio > 1`` the keys
    and values come from a sequence that is average-pooled along the token
    axis by ``sr_ratio ** 2`` and batch-normalized.
    """

    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = dim if out_dim is None else out_dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        """Fold ``pre_bn`` (and, for k/v, the internal norm) into q, k and v."""
        merge_pre_bn(self.q, pre_bn)
        # k and v additionally see self.norm when the sequence is reduced.
        inner_norm = (self.norm,) if self.sr_ratio > 1 else ()
        merge_pre_bn(self.k, pre_bn, *inner_norm)
        merge_pre_bn(self.v, pre_bn, *inner_norm)
        self.is_bn_merged = True

    def forward(self, x):
        """Attend over ``x`` of shape (B, N, C) and return the projected tokens."""
        B, N, C = x.shape
        heads = self.num_heads
        per_head = int(C // heads)

        q = self.q(x).reshape(B, N, heads, per_head).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            kv_source = self.sr(x.transpose(1, 2))
            if not (torch.onnx.is_in_onnx_export() or self.is_bn_merged):
                kv_source = self.norm(kv_source)
            kv_source = kv_source.transpose(1, 2)

        # k is laid out as (B, heads, per_head, M) so that q @ k needs no transpose.
        k = self.k(kv_source).reshape(B, -1, heads, per_head).permute(0, 2, 3, 1)
        v = self.v(kv_source).reshape(B, -1, heads, per_head).permute(0, 2, 1, 3)

        attn = ((q @ k) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))


class NTB(nn.Module):
    """
    Next Transformer Block

    A self-attention (E_MHSA) branch produces ``mhsa_out_channels`` channels, a
    convolutional (MHCA) branch produces the remaining
    ``out_channels - mhsa_out_channels``; the two are concatenated and passed
    through a residual MLP.
    """

    def __init__(
            self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
            mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio

        # Channel budget for the attention branch, rounded to a multiple of 32;
        # the convolutional branch receives whatever is left.
        attn_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
        conv_channels = out_channels - attn_channels
        self.mhsa_out_channels = attn_channels
        self.mhca_out_channels = conv_channels

        self.patch_embed = PatchEmbed(in_channels, attn_channels, stride)
        self.norm1 = nn.BatchNorm2d(attn_channels, eps=NORM_EPS)
        self.e_mhsa = E_MHSA(attn_channels, head_dim=head_dim, sr_ratio=sr_ratio,
                             attn_drop=attn_drop, proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(attn_channels, conv_channels, stride=1)
        self.mhca = MHCA(conv_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)

        self.is_bn_merged = False

    def merge_bn(self):
        """Fold norm1 into the attention and norm2 into the MLP, once."""
        if self.is_bn_merged:
            return
        self.e_mhsa.merge_bn(self.norm1)
        self.mlp.merge_bn(self.norm2)
        self.is_bn_merged = True

    def forward(self, x):
        """Run the attention branch, the convolutional branch and the MLP."""
        x = self.patch_embed(x)
        height = x.shape[2]
        # The explicit BatchNorms are bypassed during ONNX export and after merging.
        apply_norm = not torch.onnx.is_in_onnx_export() and not self.is_bn_merged

        tokens = self.norm1(x) if apply_norm else x
        tokens = rearrange(tokens, "b c h w -> b (h w) c")  # b n c
        tokens = self.mhsa_path_dropout(self.e_mhsa(tokens))
        x = x + rearrange(tokens, "b (h w) c -> b c h w", h=height)

        conv_branch = self.projection(x)
        conv_branch = conv_branch + self.mhca_path_dropout(self.mhca(conv_branch))
        x = torch.cat([x, conv_branch], dim=1)

        mlp_input = self.norm2(x) if apply_norm else x
        return x + self.mlp_path_dropout(self.mlp(mlp_input))


class NextViT(nn.Module):
    """Four-stage hybrid backbone assembled from NCB and NTB blocks.

    A four-layer ConvBNReLU stem feeds a sequence of blocks whose types and
    output widths are fixed per stage (see ``stage_block_types`` and
    ``stage_out_channels``). The result is batch-normalized, globally
    average-pooled and passed through a linear head with ``num_classes``
    outputs. ``pretrained_cfg_overlay`` is accepted but not used.
    """

    def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
                 strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
                 use_checkpoint=False, pretrained_cfg_overlay=None):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        # Stage 3 is laid out as repeated groups of five blocks.
        stage3_groups = depths[2] // 5

        self.stage_out_channels = [
            [96] * (depths[0]),
            [192] * (depths[1] - 1) + [256],
            [384, 384, 384, 384, 512] * stage3_groups,
            [768] * (depths[3] - 1) + [1024],
        ]

        # Next Hybrid Strategy
        self.stage_block_types = [
            [NCB] * depths[0],
            [NCB] * (depths[1] - 1) + [NTB],
            [NCB, NCB, NCB, NCB, NTB] * stage3_groups,
            [NCB] * (depths[3] - 1) + [NTB],
        ]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )

        # stochastic depth decay rule: one drop-path rate per block, rising linearly
        dpr = [rate.item() for rate in torch.linspace(0, path_dropout, sum(depths))]

        input_channel = stem_chs[-1]
        features = []
        blocks_before_stage = 0
        for stage_id, numrepeat in enumerate(depths):
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                # Only the first block of a stride-2 stage downsamples.
                stride = 2 if (strides[stage_id] == 2 and block_id == 0) else 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                block_drop_path = dpr[blocks_before_stage + block_id]
                if block_type is NCB:
                    features.append(
                        NCB(input_channel, output_channel, stride=stride, path_dropout=block_drop_path,
                            drop=drop, head_dim=head_dim)
                    )
                elif block_type is NTB:
                    features.append(
                        NTB(input_channel, output_channel, path_dropout=block_drop_path, stride=stride,
                            sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
                            attn_drop=attn_drop, drop=drop)
                    )
                input_channel = output_channel
            blocks_before_stage += numrepeat
        self.features = nn.Sequential(*features)

        # `output_channel` is the width of the last block built above.
        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj_head = nn.Sequential(
            nn.Linear(output_channel, num_classes),
        )

        # Index (within self.features) of the last block of each stage.
        self.stage_out_idx = [sum(depths[:stage + 1]) - 1 for stage in range(len(depths))]
        self._initialize_weights()

    def merge_bn(self):
        """Switch to eval mode and fold BatchNorms in every NCB / NTB block."""
        self.eval()
        for _, module in self.named_modules():
            if isinstance(module, (NCB, NTB)):
                module.merge_bn()

    def _initialize_weights(self):
        """Norm layers get weight 1 / bias 0; Linear and Conv2d get trunc-normal weights and zero bias."""
        norm_types = (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)
        for _, m in self.named_modules():
            if isinstance(m, norm_types):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.Linear, nn.Conv2d)):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the ``proj_head`` output for image batch ``x``."""
        x = self.stem(x)
        for layer in self.features:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        x = self.avgpool(self.norm(x))
        x = torch.flatten(x, 1)
        return self.proj_head(x)


@register_model
def nextvit_small(pretrained=False, pretrained_cfg=None, **kwargs):
    """Build a NextViT with depths [3, 4, 10, 3] and path dropout 0.1.

    ``pretrained`` and ``pretrained_cfg`` are accepted but not used here; any
    other keyword arguments are forwarded to ``NextViT``.
    """
    return NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.1, **kwargs)


@register_model
def nextvit_base(pretrained=False, pretrained_cfg=None, **kwargs):
    """Build a NextViT with depths [3, 4, 20, 3] and path dropout 0.2.

    ``pretrained`` and ``pretrained_cfg`` are accepted but not used here; any
    other keyword arguments are forwarded to ``NextViT``.
    """
    return NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, **kwargs)


@register_model
def nextvit_large(pretrained=False, pretrained_cfg=None, **kwargs):
    """Build a NextViT with depths [3, 4, 30, 3] and path dropout 0.2.

    ``pretrained`` and ``pretrained_cfg`` are accepted but not used here; any
    other keyword arguments are forwarded to ``NextViT``.
    """
    return NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, **kwargs)

class NextVitNet(nn.Module):
    """Wrap a timm-created encoder and swap its ``proj_head`` for a fresh linear layer.

    The encoder is created with ``timm.create_model(model_name)`` and its
    ``proj_head`` is replaced by ``nn.Linear(1024, num_classes)``.
    ``pretrained`` and ``pretrained_cfg_overlay`` are accepted but not used.
    """

    def __init__(self, model_name, pretrained=True, num_classes=1, pretrained_cfg_overlay=None):
        super().__init__()
        encoder = timm.create_model(model_name)
        encoder.proj_head = nn.Linear(1024, num_classes)
        self.encoder = encoder

    def forward(self, x):
        """Return the encoder output for ``x`` without further processing."""
        return self.encoder(x)
    
    
from torch.utils.data import DataLoader, Dataset
import time
import albumentations as A
from albumentations.pytorch import ToTensorV2

class RSNADatasetAriyasu(Dataset):
    """Inference dataset that loads a PNG, crops away dark borders and applies transforms.

    Args:
        paths: Sequence of image file paths.
        cfg: Config object providing ``transforms`` (called as
            ``transforms(image=...)['image']``) and ``reduce_0_area_th``, the
            mean-intensity threshold below which a row or column is dropped.
    """

    def __init__(self, paths, cfg):
        self.paths = paths
        self.transforms = cfg.transforms
        self.reduce_0_area_th = cfg.reduce_0_area_th

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return the transformed 3-channel image for ``self.paths[idx]``.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the file.
        """
        path = self.paths[idx]
        loaded = cv2.imread(path)
        if loaded is None:
            # cv2.imread signals a missing or undecodable file by returning None;
            # fail with the offending path instead of an opaque TypeError.
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        # Keep the untouched first channel so the fallback below does not have
        # to read the file from disk a second time.
        original = loaded[:, :, 0]

        image = original
        threshold = self.reduce_0_area_th
        # Two passes: drop columns, then rows, whose mean is at or below the threshold.
        for _ in range(2):
            image = image[:, image.mean(0) > threshold]
            image = image[image.mean(1) > threshold]

        # Nothing survived the crop (or only zeros did): use the full image.
        if image.sum() == 0:
            image = original
        # Replicate the single channel three times in H x W x 3 layout.
        image = np.array([image, image, image]).transpose((1, 2, 0))

        image = self.transforms(image=image)['image']
        return image

# configs


class ariyasu_nano_config:
    """Inference settings for the four-fold distilled ConvNeXt-nano models."""

    def __init__(self):
        n_folds = 4
        self.model_names = ['convnext_nano.in12k_ft_in1k'] * n_folds
        self.model_paths = [
            f'/kaggle/input/distill-mse-temp2-mixup-heavy-aug-temp3/last_fold{fold}.ckpt'
            for fold in range(n_folds)
        ]
        self.file_name = 'distill_convnext_nano_oof'

        self.batch_size = 8
        self.num_classes = 3
        self.reduce_0_area_th = 18
        self.tta = 1
        # Resize to 1536 x 768 (height x width), normalize, convert to tensor.
        self.transforms = A.Compose([
            A.Resize(1536, 768),
            A.Normalize(),
            ToTensorV2(),
        ])

class ariyasu_large_config:
    """Inference settings for the four-fold ConvNeXt-large models.

    When ``DEBUG`` is truthy all four checkpoint paths point at the
    ``DEBUG_FOLD`` checkpoint.
    """

    def __init__(self):
        n_folds = 4
        self.model_names = ['convnext_large.fb_in22k_ft_in1k_384'] * n_folds
        self.batch_size = 2
        self.num_classes = 3
        self.reduce_0_area_th = 18
        self.tta = 1
        # Resize to 1536 x 768 (height x width), normalize, convert to tensor.
        self.transforms = A.Compose([
            A.Resize(1536, 768),
            A.Normalize(),
            ToTensorV2(),
        ])
        fold_ids = [DEBUG_FOLD] * n_folds if DEBUG else list(range(n_folds))
        self.model_paths = [
            f'/kaggle/input/large-kumadata-1536/last_fold{fold}.ckpt' for fold in fold_ids
        ]
        self.file_name = 'convnext_large_oof'

class ariyasu_xlarge_config:
    """Inference settings for the four-fold ConvNeXt-xlarge models.

    When ``DEBUG`` is truthy all four checkpoint paths point at the
    ``DEBUG_FOLD`` checkpoint.
    """

    def __init__(self):
        n_folds = 4
        self.model_names = ['convnext_xlarge.fb_in22k_ft_in1k_384'] * n_folds
        self.batch_size = 2
        self.num_classes = 3
        self.reduce_0_area_th = 18
        self.tta = 1
        # Resize to 1536 x 768 (height x width), normalize, convert to tensor.
        self.transforms = A.Compose([
            A.Resize(1536, 768),
            A.Normalize(),
            ToTensorV2(),
        ])
        fold_ids = [DEBUG_FOLD] * n_folds if DEBUG else list(range(n_folds))
        self.model_paths = [
            f'/kaggle/input/xlarge-kumadata-1536/last_fold{fold}.ckpt' for fold in fold_ids
        ]
        self.file_name = 'convnext_xlarge_oof'

class ariyasu_nextvit_base_config:
    """Inference settings for the four-fold NextViT-base models.

    When ``DEBUG`` is truthy all four checkpoint paths point at the
    ``DEBUG_FOLD`` checkpoint.
    """

    def __init__(self):
        n_folds = 4
        self.model_names = ['nextvit_base'] * n_folds
        self.batch_size = 4
        self.num_classes = 3
        self.reduce_0_area_th = 18
        self.tta = 1
        # Resize to 1536 x 768 (height x width), normalize, convert to tensor.
        self.transforms = A.Compose([
            A.Resize(1536, 768),
            A.Normalize(),
            ToTensorV2(),
        ])
        fold_ids = [DEBUG_FOLD] * n_folds if DEBUG else list(range(n_folds))
        self.model_paths = [
            f'/kaggle/input/nextvit-base-kumadata-1536/last_fold{fold}.ckpt' for fold in fold_ids
        ]
        self.file_name = 'nextvit_base_oof'

In [None]:
import sys
sys.path.append('/kaggle/input/rsna-mammo-charm-modules')
!pip install /kaggle/input/omegaconf/omegaconf-2.0.5-py3-none-any.whl
!ln -s ../input/rsna-mammo-2023/dataset/timm/ 

import yaml
from typing import Dict, Tuple
from pathlib import Path
from collections import defaultdict
from tqdm.auto import tqdm


import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from omegaconf import DictConfig, OmegaConf

from src.nn.backbone import load_backbone
from src.nn.backbones.base import BackboneBase
from src.nn.pool.pool import ChannelWiseGeM, GeM

def get_weights_to_load(model: nn.Module, ckpt: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Overlay compatible checkpoint tensors onto the model's own state dict.

    A checkpoint entry is used only when its key exists in the model and its
    size matches; every other entry is silently ignored, so the returned dict
    always has exactly the model's keys.
    """
    merged = model.state_dict()
    for name, tensor in ckpt.items():
        if name in merged and tensor.size() == merged[name].size():
            merged[name] = tensor
    return merged


def init_model_from_config(cfg: DictConfig, weight_path: str):
    """Assemble backbone + pooling + heads from ``cfg`` and load weights from ``weight_path``.

    The returned ``nn.Sequential`` has two children: ``forward_features``
    (backbone, optional pool + flatten, optional BatchNorm1d + PReLU) and
    ``head`` (a container of named linear heads).

    Raises:
        ValueError: if ``cfg.head.type`` is not ``"linear"``.
    """
    model = nn.Sequential()
    backbone = init_backbone(cfg, pretrained=False)

    forward_features = nn.Sequential()
    forward_features.add_module("backbone", backbone)

    # Pick the pooling layer; an unrecognized pool type adds neither pool nor flatten.
    pool_type = cfg.pool.type
    pool = None
    if pool_type == "adaptive":
        pool = nn.AdaptiveAvgPool2d((1, 1))
    elif pool_type == "gem":
        pool = GeM(p=cfg.pool.p, p_trainable=cfg.pool.p_trainable)
    elif pool_type == "gem_ch":
        pool = ChannelWiseGeM(
            dim=backbone.out_features,
            p=cfg.pool.p,
            requires_grad=cfg.pool.p_trainable,
        )
    if pool is not None:
        forward_features.add_module("pool", pool)
        forward_features.add_module("flatten", nn.Flatten())

    if cfg.use_bn:
        forward_features.add_module("normalize", nn.BatchNorm1d(backbone.out_features))
        forward_features.add_module("relu", torch.nn.PReLU())

    model.add_module("forward_features", forward_features)

    if cfg.head.type != "linear":
        raise ValueError(f"{cfg.head.type} is not implemented")

    # Each enabled multi-image mode doubles the head input width.
    out_features = backbone.out_features
    if cfg.use_multi_view:
        out_features *= 2
    if cfg.use_multi_lat:
        out_features *= 2

    # "cancer", "biopsy", "invasive", "age_scaled", "BIRADS_scaled", "machine_id_enc", "site_id"
    # (head name, output width), in registration order.
    head_specs = [
        ("head", 1),
        ("head_biopsy", 1),
        ("head_invasive", 1),
        ("head_birads", 1),
        ("head_difficult_negative_case", 1),
        ("head_age", 1),
        ("head_machine_id", 11),
        ("head_site_id", 1),
    ]
    if cfg.use_multi_lat:
        # LR model
        head_specs += [
            ("head_2", 1),
            ("head_biopsy_2", 1),
            ("head_invasive_2", 1),
            ("head_birads_2", 1),
            ("head_difficult_negative_case_2", 1),
        ]

    head_all = nn.Sequential()
    for head_name, head_width in head_specs:
        head_all.add_module(head_name, nn.Linear(out_features, head_width, bias=True))
    model.add_module("head", head_all)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(weight_path, map_location="cpu")
    # Entries with unknown keys or mismatching sizes are dropped before loading.
    model.load_state_dict(get_weights_to_load(model, ckpt), strict=True)

    return model


def init_backbone(cfg: DictConfig, pretrained: bool) -> BackboneBase:
    """Create the backbone named by ``cfg.base_model`` and apply config switches.

    Gradient checkpointing is enabled when ``cfg.grad_checkpointing`` is truthy,
    and every backbone parameter is frozen when ``cfg.freeze_backbone`` is truthy.
    """
    backbone = load_backbone(
        base_model=cfg.base_model,
        pretrained=pretrained,
        in_chans=cfg.in_chans,
    )
    if cfg.grad_checkpointing:
        backbone.set_grad_checkpointing()
    if cfg.freeze_backbone:
        for weight in backbone.parameters():
            weight.requires_grad = False
    return backbone


class Forwarder(nn.Module):
    """Run a model built as ``forward_features`` + ``head`` in single- or multi-image mode.

    The mode is read from ``cfg.forwarder.use_multi_view`` and
    ``cfg.forwarder.use_multi_lat``. The input must have 1 channel in
    single-image mode, 2 in multi-view mode and 4 in multi-lat mode.
    """

    def __init__(self, model: nn.Module, cfg: DictConfig) -> None:
        super().__init__()
        self.model = model
        self.cfg = cfg

    def forward(self, inputs: Tensor):
        """Return the ``head`` logits, plus the ``head_2`` logits in multi-lat mode."""
        forwarder_cfg = self.cfg.forwarder
        use_multi_view = forwarder_cfg.use_multi_view
        use_multi_lat = forwarder_cfg.use_multi_lat
        multi_image = use_multi_view or use_multi_lat

        # inputs: Input tensor.
        bs, ch, h, w = inputs.shape
        if multi_image:
            # Each channel is a separate image: move the channels into the batch axis.
            inputs = inputs.view(bs * ch, 1, h, w)

        if use_multi_lat:
            expected_channels = 4
        elif use_multi_view:
            expected_channels = 2
        else:
            expected_channels = 1
        assert ch == expected_channels

        # extract features
        embed_features = self.model.forward_features(inputs)
        if multi_image:
            # Flatten the per-image features of each sample into one row.
            embed_features = embed_features.view(bs, -1)

        # head
        heads = self.model.head
        logits = heads.head(embed_features)
        if use_multi_lat:
            return logits, heads.head_2(embed_features)

        return logits
    
    
class RSNADataset(Dataset):
    """Inference dataset over PNG images stored under ``ROOT_PATH``.

    Depending on the config, one sample is

    * a single image (neither flag set),
    * the CC and MLO image of one ``patient_id``/``laterality`` pair
      (``use_multi_view``), or
    * the CC and MLO images of the left and the right side of one patient
      (``use_multi_lat``; takes precedence when both flags are set).

    Each image is read, converted to grayscale, cropped to the region found by
    ``get_roi_crop`` and passed through the resize / normalize / to-tensor
    pipeline. The per-image tensors are concatenated along dimension 0.
    """

    ROOT_PATH = Path("/tmp/output/")

    def __init__(
        self,
        df: pd.DataFrame,
        cfg: DictConfig=None,
    ) -> None:
        """
        Args:
            df: Image table with ``patient_id``, ``laterality``, ``image_id``
                and ``view`` columns. It is copied, so the caller's frame is
                left untouched.
            cfg: Config providing ``dataset.roi_th``, ``dataset.roi_buffer``,
                ``dataset.use_multi_view``, ``dataset.use_multi_lat`` and
                ``preprocessing.{h_resize_to, w_resize_to, mean, std}``.

        Raises:
            ValueError: If ``cfg`` is omitted. The ``None`` default only
                exists for signature compatibility; the dataset cannot be
                built without a config.
        """
        if cfg is None:
            raise ValueError("RSNADataset requires a config (cfg must not be None)")

        self.df = df.copy()
        self.roi_th = cfg.dataset.roi_th
        self.roi_buffer = cfg.dataset.roi_buffer
        self.use_multi_view = cfg.dataset.use_multi_view
        self.use_multi_lat = cfg.dataset.use_multi_lat
        self.use_single_view = (not self.use_multi_view) and (not self.use_multi_lat)

        self.prediction_id_to_filename_dict = self.get_prediction_id_to_filename_map(self.df)
        self.df["prediction_id"] = self.df["patient_id"].astype(str) + "_" + self.df["laterality"]
        self.df["filename"] = self.df["patient_id"].astype(str) + "_" + self.df["image_id"].astype(str) + ".png"

        # ``self.index`` holds one entry per sample: a row position, a
        # patient id or a prediction id, depending on the mode.
        if self.use_single_view:
            self.index = np.arange(len(self.df))
        elif self.use_multi_lat:
            self.index = self.df["patient_id"].unique()
        else:
            self.index = self.df["prediction_id"].unique()

        transforms = [
            # Targets: image, mask, bboxes, keypoints
            A.Resize(cfg.preprocessing.h_resize_to, cfg.preprocessing.w_resize_to, p=1),
            # Targets: image
            A.Normalize(mean=cfg.preprocessing.mean, std=cfg.preprocessing.std),
            # Targets: image, mask
            ToTensorV2(transpose_mask=True),
        ]
        self.transform = A.Compose(transforms)

    def __len__(self) -> int:
        return len(self.index)

    def get_prediction_id_to_filename_map(self, df):
        """Map ``"{patient_id}_{laterality}"`` to ``{view: filename}``.

        When several images share the same prediction id and view, only the
        first one (in row order) is kept. The inner mappings are plain dicts,
        so asking for a view that does not exist raises ``KeyError`` at the
        lookup instead of silently yielding an empty list.
        """
        mapping = {}
        columns = (df["patient_id"], df["laterality"], df["image_id"], df["view"])
        for patient_id, laterality, image_id, view in zip(*columns):
            views = mapping.setdefault(f"{patient_id}_{laterality}", {})
            # setdefault keeps the first filename seen for this view.
            views.setdefault(view, f"{patient_id}_{image_id}.png")
        return mapping

    def get_prediction_ids_and_image_paths(self, index):
        """Return the prediction ids and image paths that make up sample ``index``.

        Returns:
            ``(prediction_ids, image_paths)``. Single view: one id / one path.
            Multi laterality: ids ``[<pid>_L, <pid>_R]`` and paths ordered
            CC-left, MLO-left, CC-right, MLO-right. Multi view: one id and
            paths ordered CC, MLO.

        Raises:
            KeyError: In the multi-image modes, if a required prediction id or
                view is missing from the filename map.
        """
        root = self.ROOT_PATH
        files = None if self.use_single_view else self.prediction_id_to_filename_dict

        if self.use_single_view:
            prediction_ids = [self.df["prediction_id"].values[index]]
            image_paths = [root / self.df["filename"].values[index]]
        elif self.use_multi_lat:
            patient_id = self.index[index]
            prediction_ids = [f"{patient_id}_L", f"{patient_id}_R"]
            image_paths = [
                root / files[prediction_id][view]
                for prediction_id in prediction_ids
                for view in ("CC", "MLO")
            ]
        else:
            prediction_id = self.index[index]
            prediction_ids = [prediction_id]
            image_paths = [root / files[prediction_id][view] for view in ("CC", "MLO")]
        return prediction_ids, image_paths

    def read_image(self, image_path):
        """Read ``image_path`` with OpenCV and return it as a grayscale array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
                reports failure by returning ``None`` rather than raising.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def get_roi_crop(self, image, threshold=0.1, buffer=30):
        """Locate the bright region of a 2-D image.

        The image is binarized at its mean intensity. For every row and column
        the fraction of above-mean pixels is computed, the five outermost
        rows/columns are ignored, and the profiles are min-max normalized.
        Rows/columns whose normalized value exceeds ``threshold`` span the
        region, which is then widened by ``buffer`` pixels and clamped to the
        image bounds.

        Returns:
            ``(x_start, y_start, x_end, y_end)``; the full extent of an axis is
            returned when no row/column on it passes the threshold.
        """
        y_max, x_max = image.shape
        bright = image > image.mean()
        y_mean = bright.mean(1)
        x_mean = bright.mean(0)
        # Ignore a 5-pixel frame so border artefacts cannot stretch the box.
        x_mean[:5] = 0
        x_mean[-5:] = 0
        y_mean[:5] = 0
        y_mean[-5:] = 0
        # The small epsilon avoids 0/0 for a constant profile.
        y_mean = (y_mean - y_mean.min() + 1e-4) / (y_mean.max() - y_mean.min() + 1e-4)
        x_mean = (x_mean - x_mean.min() + 1e-4) / (x_mean.max() - x_mean.min() + 1e-4)
        y_slice = np.where(y_mean > threshold)[0]
        x_slice = np.where(x_mean > threshold)[0]

        if len(x_slice) == 0:
            x_start, x_end = 0, x_max
        else:
            x_start = max(x_slice.min() - buffer, 0)
            x_end = min(x_slice.max() + buffer, x_max)

        if len(y_slice) == 0:
            y_start, y_end = 0, y_max
        else:
            y_start = max(y_slice.min() - buffer, 0)
            y_end = min(y_slice.max() + buffer, y_max)

        return x_start, y_start, x_end, y_end

    def get_bbox(self, image):
        """Run ``get_roi_crop`` with the configured threshold and buffer."""
        return self.get_roi_crop(image, threshold=self.roi_th, buffer=self.roi_buffer)

    def crop_image(self, image):
        """Return the part of ``image`` inside its bounding box."""
        x_min, y_min, x_max, y_max = self.get_bbox(image)
        return image[y_min:y_max, x_min:x_max]

    def apply_transform(self, image):
        """Apply the resize / normalize / to-tensor pipeline to one image."""
        return self.transform(image=image)["image"]

    def __getitem__(self, index: int):
        """Return ``(images, prediction_ids)`` for sample ``index``.

        ``images`` is the concatenation (along dimension 0) of the transformed
        tensors of all images belonging to the sample.
        """
        prediction_ids, image_paths = self.get_prediction_ids_and_image_paths(index)
        tensors = []
        for image_path in image_paths:
            image = self.crop_image(self.read_image(image_path))
            tensors.append(self.apply_transform(image))
        return torch.concat(tensors), prediction_ids

In [None]:
# Load the competition tables. In DEBUG mode the test table is replaced by the
# train rows of the patients listed in SAMPLE_ID; otherwise the real test table
# gets a constant 'cancer' column.
train = pd.read_csv('../input/rsna-breast-cancer-detection/train.csv')
test = pd.read_csv('../input/rsna-breast-cancer-detection/test.csv')
if not DEBUG:
    test['cancer'] = 0
else:
    test = train.loc[
        train['patient_id'].isin(SAMPLE_ID),
        ['site_id', 'patient_id', 'image_id', 'laterality', 'view', 'age', 'implant', 'machine_id', 'cancer'],
    ]

In [None]:
def _mean_fold_metric(metric, preds, labels, fold_ids, folds, use_threshold, show_fold):
    """Evaluate ``metric`` on each fold and average one of its outputs.

    ``metric(preds, labels)`` is expected to return
    ``(score, percentile, threshold)``. The per-fold result is printed (with
    the fold number in front when ``show_fold`` is set) and either the
    threshold (``use_threshold=True``) or the percentile is averaged.
    """
    picked = []
    for fold in folds:
        in_fold = fold_ids == fold
        score_f, percentile_f, threshold_f = metric(preds[in_fold], labels[in_fold])
        if show_fold:
            print(fold, score_f, percentile_f, threshold_f)
        else:
            print(score_f, percentile_f, threshold_f)
        picked.append(threshold_f if use_threshold else percentile_f)
    return np.mean(picked)


def get_threshold(prec_oof, predictions, test_sites, mode, threshold_params):
    """Derive decision thresholds from out-of-fold (OOF) results.

    Args:
        prec_oof: Dict of aligned arrays: ``oof_labels``, ``oof_preds`` and,
            depending on the mode, ``oof_sites``, ``oof_folds`` and
            ``oof_hardflags``.
        predictions: Test predictions whose first axis is averaged away
            (``predictions.mean(0)``) before any percentile is taken.
        test_sites: Site id per test sample; only used by the modes that take
            a percentile of the test predictions per site.
        mode: One of

            * ``'OOF'`` - best OOF threshold, shared by both sites.
            * ``'OOF_PER_SITE'`` - best OOF threshold per site.
            * ``'OOF_MEAN_PER_SITE'`` - per site, mean of the per-fold
              thresholds.
            * ``'OOF_MEAN_PER_SITE_EASY'`` - same, after dropping OOF rows
              whose hard flag is non-zero.
            * ``'PROPORTIONAL'`` - best OOF percentile applied to the test
              predictions.
            * ``'PROPORTIONAL_PER_SITE'`` - same, per site.
            * ``'PROPORTIONAL_MEAN'`` - mean of the per-fold percentiles
              applied to the test predictions.
            * ``'PROPORTIONAL_MEAN_PER_SITE'`` - same, per site.
            * ``'PROPORTIONAL_MEAN_PER_SITE_OFFSET'`` - same, with the
              percentile rescaled for samples that were removed before
              inference (needs ``threshold_params['total_size']``).
        threshold_params: Extra parameters; only ``'total_size'`` is read.

    Returns:
        ``[threshold_site_1, threshold_site_2]``. Modes that are not per-site
        return the same value twice.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    supported_modes = (
        'OOF', 'OOF_PER_SITE', 'OOF_MEAN_PER_SITE', 'OOF_MEAN_PER_SITE_EASY',
        'PROPORTIONAL', 'PROPORTIONAL_PER_SITE', 'PROPORTIONAL_MEAN',
        'PROPORTIONAL_MEAN_PER_SITE', 'PROPORTIONAL_MEAN_PER_SITE_OFFSET',
    )
    if mode not in supported_modes:
        # Previously an unknown mode fell through to an UnboundLocalError.
        raise ValueError(f'Unknown threshold mode {mode!r}; expected one of {supported_modes}')

    folds = list(range(4))
    predictions = predictions.mean(0)
    oof_labels, oof_preds = prec_oof['oof_labels'], prec_oof['oof_preds']

    if mode == 'OOF':  # concat oofs and search a new threshold (soft vote)
        score, threshold = Pfbeta().optimal_f1(oof_labels, oof_preds)
        print(score, threshold)
        return [threshold, threshold]

    if mode == 'OOF_PER_SITE':
        oof_sites = prec_oof['oof_sites']
        score_s1, threshold_s1 = Pfbeta().optimal_f1(oof_labels[oof_sites == 1], oof_preds[oof_sites == 1])
        score_s2, threshold_s2 = Pfbeta().optimal_f1(oof_labels[oof_sites == 2], oof_preds[oof_sites == 2])
        print(score_s1, threshold_s1)
        print(score_s2, threshold_s2)
        return [threshold_s1, threshold_s2]

    # Every remaining mode uses the same percentile-search metric.
    metric = PercentilePfbeta(percentile_range=[97.5, 99], n_trials=30)

    if mode == 'PROPORTIONAL':
        score, percentile, threshold = metric(oof_preds, oof_labels)
        print(score, percentile, threshold)
        threshold = np.percentile(predictions, percentile)
        print(threshold)
        return [threshold, threshold]

    if mode == 'PROPORTIONAL_PER_SITE':
        test_sites = np.array(test_sites)
        oof_sites = prec_oof['oof_sites']
        # NOTE(review): the result of this all-sites call is never used. It is
        # kept only because PercentilePfbeta is defined elsewhere and may hold
        # state between calls - confirm, then drop it.
        metric(oof_preds, oof_labels)
        score_s1, percentile_s1, threshold_s1 = metric(oof_preds[oof_sites == 1], oof_labels[oof_sites == 1])
        score_s2, percentile_s2, threshold_s2 = metric(oof_preds[oof_sites == 2], oof_labels[oof_sites == 2])
        print(score_s1, percentile_s1, threshold_s1)
        print(score_s2, percentile_s2, threshold_s2)
        threshold_s1 = np.percentile(predictions[test_sites == 1], percentile_s1)
        threshold_s2 = np.percentile(predictions[test_sites == 2], percentile_s2)
        print(threshold_s1, threshold_s2)
        return [threshold_s1, threshold_s2]

    oof_folds = prec_oof['oof_folds']

    if mode == 'PROPORTIONAL_MEAN':
        mean_percentile = _mean_fold_metric(
            metric, oof_preds, oof_labels, oof_folds, folds,
            use_threshold=False, show_fold=False)
        threshold = np.percentile(predictions, mean_percentile)
        print(mean_percentile, threshold)
        return [threshold, threshold]

    oof_sites = prec_oof['oof_sites']
    thresholds = []

    if mode in ('PROPORTIONAL_MEAN_PER_SITE', 'PROPORTIONAL_MEAN_PER_SITE_OFFSET'):
        test_sites = np.array(test_sites)
        use_offset = mode == 'PROPORTIONAL_MEAN_PER_SITE_OFFSET'
        if use_offset:
            # Assumption: every sample rejected before inference would have
            # scored below all samples that are still being predicted.
            # offset   = percentage of the total set that was rejected
            #          = 100 * (1 - len(inference set) / total_size)
            # adjusted = 100 * (percentile - offset) / (100 - offset)
            # NOTE(review): if the mean percentile is below the offset the
            # adjusted value is negative and np.percentile raises.
            offset = 100 * (1 - len(predictions) / threshold_params['total_size'])
        for site in [1, 2]:
            on_site = oof_sites == site
            mean_percentile = _mean_fold_metric(
                metric, oof_preds[on_site], oof_labels[on_site], oof_folds[on_site], folds,
                use_threshold=False, show_fold=False)
            if use_offset:
                mean_percentile = 100 * (mean_percentile - offset) / (100 - offset)
                print(f'1 - offset = {len(predictions)}/{threshold_params["total_size"]}')
            threshold = np.percentile(predictions[test_sites == site], mean_percentile)
            print(site, mean_percentile, threshold)
            thresholds.append(threshold)
        return thresholds

    # mode is 'OOF_MEAN_PER_SITE' or 'OOF_MEAN_PER_SITE_EASY'
    if mode == 'OOF_MEAN_PER_SITE_EASY':
        oof_hardflags = prec_oof['oof_hardflags']
        easy = oof_hardflags == 0
        oof_preds = oof_preds[easy]
        oof_labels = oof_labels[easy]
        oof_folds = oof_folds[easy]
        oof_sites = oof_sites[easy]
        print(f'{oof_hardflags.sum()} hard samples excluded from oof.')
    for site in [1, 2]:
        on_site = oof_sites == site
        mean_threshold = _mean_fold_metric(
            metric, oof_preds[on_site], oof_labels[on_site], oof_folds[on_site], folds,
            use_threshold=True, show_fold=True)
        print(site, mean_threshold)
        thresholds.append(mean_threshold)
    return thresholds

In [None]:
class KumaInference:
    """Inference pipeline for the "kuma" models.

    Builds the test dataloader, runs every fold model, optionally averages
    predictions of oversampled records and derives decision thresholds from
    pre-computed out-of-fold (OOF) predictions.
    """

    def get_oof_path(self, cfg):
        """Return the path of the pre-computed OOF csv for ``cfg.name``."""
        return Path('../input/rsna-oof/kuma_oofs_v3/')/f'kuma_{cfg.name}.csv'

    def get_dataloader(
        self,
        cfg,
        test_df,
        results_dir=Path('../input/rsna-mammo-2023/dataset/results/'),
        batch_size_ratio=1):
        '''
        Build the test dataset/dataloader for ``cfg``.

        Note that ``cfg.dataset_params`` and ``cfg.model_params`` are modified
        in place, and that a one-row ``dummy_bbox.csv`` is written to the
        working directory when the config carries a non-None ``bbox_path``.

        input parameters:
            cfg: kuma_utils config class
            test_df: test dataframe to predict
            results_dir: pathlib.Path results directory (not used here; kept
                for signature compatibility)
            batch_size_ratio: bs = bs // ratio

        output parameters:
            test_loader: pytorch dataloader
            test_indices: list of indices corresponding to test_loader
                e.g. [(57, L), (57, R), (10007, L), ..., ]
            test_sites: site_id of the first row of each sample
        '''
        cfg.dataset_params['sep'] = '_'
        cfg.dataset_params['aux_target_cols'] = []
        if 'bbox_path' in cfg.dataset_params and cfg.dataset_params['bbox_path'] is not None:
            # Replace the configured bbox file by a placeholder with one dummy row.
            pd.DataFrame([{'name': 'dummy', 'xmin': 0, 'ymin': 0, 'xmax':0, 'ymax':0}]).to_csv(
                'dummy_bbox.csv', index=False)
            cfg.dataset_params['bbox_path'] = 'dummy_bbox.csv'
        else:
            cfg.dataset_params['bbox_path'] = None
        cfg.model_params['pretrained'] = False
        test_data = cfg.dataset(
            df=test_df,
            image_dir=EXPORT_DIR,
            preprocess=cfg.preprocess['test'],
            transforms=cfg.transforms['test'],
            is_test=True,
            **cfg.dataset_params)
        test_loader = D.DataLoader(
            test_data, batch_size=cfg.batch_size//batch_size_ratio, shuffle=False,
            num_workers=2, pin_memory=False)

        if test_data.__class__.__name__ == 'PatientLevelDatasetLR':
            # This dataset yields one item per patient; expand to (pid, side).
            test_indices = []
            for pid in test_data.pids:
                test_indices.extend([(pid, 'L'), (pid, 'R')])
        else:
            test_indices = test_data.pids

        test_sites = [test_data.df_dict[test_data.pids[i]]['site_id'].values[0] for i in range(len(test_data))]
        return test_loader, test_indices, test_sites

    def infer_model(
        self,
        cfg,
        test_loader,
        folds=None,
        results_dir=Path('../input/rsna-mammo-2023/dataset/results/'),
        fp16=False):
        '''
        Load the fold checkpoints of ``cfg`` and predict the whole loader.

        input parameters:
            cfg: kuma_utils config class
            test_loader: test dataloader
            folds: list of folds to include (default: all ``cfg.cv`` folds)
            results_dir: pathlib.Path results directory
            fp16: run the forward pass under autocast

        output parameters:
            predictions: np.array [cv, test samples, 1]
        '''
        def _load_data(input_t):
            # A batch is either the input tensor itself or a sequence whose
            # first element is the input tensor.
            if isinstance(input_t, torch.Tensor):
                return input_t
            if isinstance(input_t, (list, tuple)):
                return input_t[0]
            raise TypeError(f'Unsupported batch type: {type(input_t).__name__}')

        def _predict(model, x):
            with torch.no_grad():
                if fp16:
                    with amp.autocast():
                        y = model(x)
                else:
                    y = model(x)
            if isinstance(y, (list, tuple)):
                y = y[0]
            if len(y.shape) > 1 and y.shape[1] > 1:
                # Multi-output model: keep only the first output column.
                y = y[:, 0]
            return y.float().cpu().numpy()

        proj_dir = results_dir / cfg.name
        if folds is None:
            folds = list(range(cfg.cv))
        models = []
        for fold in range(cfg.cv):
            if fold not in folds:
                continue
            model = cfg.model(**cfg.model_params)
            # In DEBUG mode every slot loads the DEBUG_FOLD checkpoint.
            checkpoint_path = proj_dir/f'fold{DEBUG_FOLD if DEBUG else fold}.pt'
            checkpoint = torch.load(checkpoint_path, 'cpu')
            print(str(checkpoint_path))
            model.load_state_dict(checkpoint['model'])
            if cfg.parallel == 'ddp':
                model = convert_sync_batchnorm(model)
            model.cuda()
            model.eval()
            models.append(model)
            del checkpoint; gc.collect()
        predictions = []
        for input_t in tqdm(test_loader):
            x = _load_data(input_t).cuda()
            ys = np.stack([_predict(m, x) for m in models], axis=0) # (cv, bs, 1)
            predictions.append(ys)
        predictions = np.concatenate(predictions, axis=1)
        del models; gc.collect()
        torch.cuda.empty_cache()
        return predictions

    def apply_threshold(
        self,
        cfg,
        raw_predictions,
        test_sites,
        mode='OOF',
        threshold_params=None):
        '''
        Compute decision thresholds from the pre-computed OOF predictions.

        input params:
            cfg: kuma_utils config class
            raw_predictions: np.array [cv, test samples, 1] (copied, not modified)
            test_sites: site id per test sample
            mode: str, threshold mode understood by get_threshold
            threshold_params: extra parameters forwarded to get_threshold

        output params:
            thresholds: list [threshold for site1, threshold for site2]
        '''
        if threshold_params is None:
            threshold_params = {}
        predictions = raw_predictions.copy()
        precalculated_path = self.get_oof_path(cfg)
        print(precalculated_path)

        # One label row per prediction_id (patient + laterality).
        train = pd.read_csv('../input/rsna-misc/train_with_fold.csv').rename({'cancer': 'target'}, axis=1)
        train['prediction_id'] = train['patient_id'].astype(str) + '_' + train['laterality']
        train = train.groupby('prediction_id').agg({'target': 'max', 'fold': 'max', 'difficult_negative_case': 'max'}).reset_index()

        oof = pd.read_csv(precalculated_path)[['prediction_id', 'pred', 'site_id', 'fold']].merge(
            train[['prediction_id', 'target', 'difficult_negative_case']], on='prediction_id', how='left')
        prec_oof = {
            'oof_labels': oof['target'].values,
            'oof_preds': oof['pred'].values,
            'oof_sites': oof['site_id'].values,
            'oof_folds': oof['fold'].values,
            'oof_hardflags': oof['difficult_negative_case'].astype(int).values
        }
        return get_threshold(prec_oof, predictions, test_sites, mode, threshold_params)

    def extend_df(self, df, max_record_per_patient=2, view_category=None):
        '''
        Extend the dataframe so that patients with more than one image per
        view contribute several records.

        Groups (patient_id, laterality) with exactly two rows are kept as they
        are with ``oversample_id`` 0. Every other group is turned into
        ``max_record_per_patient`` records; each record takes, per view
        category, a random sample of the images not used by earlier records
        (all remaining ones for the last record, and all images of the
        category when none is left). Rows whose view is in no category are
        dropped from such groups.

        input params:
            df: dataframe with patient_id, laterality, image_id, view
            max_record_per_patient: number of records per oversampled group
            view_category: list of view-name lists
                (default [['MLO', 'LMO', 'LM', 'ML'], ['CC', 'AT']])

        output params:
            new dataframe with an additional ``oversample_id`` column
        '''
        if view_category is None:
            view_category = [['MLO', 'LMO', 'LM', 'ML'], ['CC', 'AT']]

        def _sample_idx(group, used_ids, sample_all=False):
            picked = []
            for view_cat in view_category:
                in_category = group['view'].isin(view_cat)
                unused = group.loc[in_category & ~group['image_id'].isin(used_ids)]
                if len(unused) == 0:
                    # Nothing new left for this category: reuse all its images.
                    picked.append(group.loc[in_category])
                elif sample_all:
                    picked.append(unused)
                else:
                    picked.append(unused.sample(min(len(unused), max(1, len(unused)//max_record_per_patient))))
            return pd.concat(picked).reset_index(drop=True)

        new_df = []
        for _, pdf in df.groupby(['patient_id', 'laterality']):
            if len(pdf) == 2:
                # assign() returns a new frame instead of writing into the group slice.
                new_df.append(pdf.assign(oversample_id=0))
                continue
            used_ids = []
            for i in range(max_record_per_patient):
                idf = _sample_idx(pdf, used_ids, sample_all=i == max_record_per_patient-1)
                idf['oversample_id'] = i
                new_df.append(idf)
                used_ids.extend(idf['image_id'].values.tolist())
        return pd.concat(new_df).reset_index(drop=True)

    def aggregate_prediction(self, predictions, test_indices, test_sites, agg_func='mean'):
        '''
        Aggregate the predictions of records that share a prediction id.

        mean the LOGITS (not PROB)

        input params:
            predictions: np.array [cv, test samples, 1]
            test_indices: one index tuple per test sample; elements 1 and 2
                are joined into the prediction id
            test_sites: site id per test sample
            agg_func: pandas aggregation applied to every fold column

        output params:
            predictions: np.array [cv, prediction ids, 1]
            prediction_ids: list of prediction ids (sorted by groupby)
            test_sites: list with the max site_id per prediction id
        '''
        n_folds = predictions.shape[0]
        pred_cols = [f'pred{fold}' for fold in range(n_folds)]
        records = []
        for pred_id, site, pred in zip(test_indices, test_sites, predictions.transpose(1, 0, 2)):
            record = {'prediction_id': f'{pred_id[1]}_{pred_id[2]}', 'site_id': site}
            record.update({col: pred[fold][0] for fold, col in enumerate(pred_cols)})
            records.append(record)
        agg_spec = {col: agg_func for col in pred_cols}
        agg_spec['site_id'] = 'max'
        agg_predictions = pd.DataFrame(records).groupby('prediction_id').agg(agg_spec).reset_index()
        return (
            agg_predictions[pred_cols].values.transpose(1, 0).reshape(n_folds, -1, 1),
            agg_predictions['prediction_id'].values.tolist(),
            agg_predictions['site_id'].values.tolist())

    def inference_block(self, cfg, test_df, threshold_mode, batch_size_ratio=1, fp16=False, extend=0, threshold_params=None):
        '''
        Run the full pipeline for one config.

        output params:
            results: dataframe with prediction_id, the fold-averaged
                probability in a column named ``cfg.name``, and site_id
            thresholds: list [threshold for site1, threshold for site2]
        '''
        if threshold_params is None:
            threshold_params = {}
        if extend > 0:
            test_df = self.extend_df(test_df, max(2, 1+extend))
        test_loader, test_indices, test_sites = self.get_dataloader(cfg, test_df=test_df, batch_size_ratio=batch_size_ratio)
        predictions = self.infer_model(cfg, test_loader, fp16=fp16)
        if extend > 0:
            predictions, prediction_ids, test_sites = self.aggregate_prediction(
                predictions, test_indices, test_sites)
        else:
            prediction_ids = [f'{pid}_{lat}' for pid, lat in test_indices]
        predictions = sigmoid(predictions)
        thresholds = self.apply_threshold(
            cfg, predictions, test_sites, mode=threshold_mode, threshold_params=threshold_params)
        results = pd.DataFrame([
            {'prediction_id': pred_id, cfg.name: pred[0], 'site_id': site}
            for pred_id, pred, site in zip(prediction_ids, predictions.mean(0), test_sites)
        ])
        return results, thresholds

In [None]:
class AriyasuInference:
    """Inference pipeline for the "ariyasu" image-level models.

    Every model in ``cfg.model_names`` predicts every image; the image-level
    probabilities are averaged per ``prediction_id`` and thresholds are derived
    from pre-computed out-of-fold (OOF) predictions.
    """

    def _oof_candidates(self, cfg):
        """List the OOF csv files whose name starts with ``cfg.file_name``."""
        return list(Path('../input/rsna-oof/ariyasu/').glob(f'{cfg.file_name}*.csv'))

    def get_oof_path(self, cfg):
        """Return the first matching OOF csv (IndexError if there is none)."""
        return self._oof_candidates(cfg)[0]

    def get_models(self, cfg):
        """Instantiate and load every model listed in ``cfg``.

        Returns:
            List of models in eval mode on the GPU (CPU if CUDA is unavailable),
            in the order of ``cfg.model_names`` / ``cfg.model_paths``.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        models = []
        for model_name, model_path in zip(cfg.model_names, cfg.model_paths):
            print(model_path)
            if 'nextvit' in model_name:
                model = NextVitNet(model_name, pretrained=False, num_classes=cfg.num_classes)
            else:
                model = timm.create_model(model_name, pretrained=False, num_classes=cfg.num_classes)
            state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
            # Drop the first six characters of every key - presumably a
            # 'model.' prefix added by the training wrapper; TODO confirm.
            torch_state_dict = {k[6:]: v for k, v in state_dict.items()}
            model.load_state_dict(torch_state_dict)
            model.to(device)
            model.eval()
            # if 'nano' in model_name:
            #     model = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input(min_shape=[1, 3, image_size[1], image_size[0]],opt_shape=[cfg.batch_size, 3, image_size[1], image_size[0]],max_shape=[cfg.batch_size, 3, image_size[1], image_size[0]],dtype=torch.float32)],
            #         enabled_precisions = torch.float32, # Run with FP32
            #         workspace_size = 1 << 32
            #     )
            models.append(model)
        torch.cuda.empty_cache()
        return models

    def apply_threshold(
        self,
        cfg,
        raw_predictions,
        test_sites,
        mode='OOF',
        threshold_params=None):
        """Compute decision thresholds from the pre-computed OOF predictions.

        Args:
            cfg: Config; ``cfg.file_name`` selects the OOF csv.
            raw_predictions: Array ``[models, prediction ids, 1]`` (copied,
                not modified).
            test_sites: Site id per prediction id.
            mode: Threshold mode understood by ``get_threshold``.
            threshold_params: Extra parameters forwarded to ``get_threshold``.

        Returns:
            ``[threshold for site 1, threshold for site 2]``.
        """
        if threshold_params is None:
            threshold_params = {}
        predictions = raw_predictions.copy()
        precalculated_path = self._oof_candidates(cfg)
        print(precalculated_path)
        precalculated_path = precalculated_path[0]

        # One label row per prediction_id (patient + laterality).
        train = pd.read_csv('../input/rsna-misc/train_with_fold.csv').rename({'cancer': 'target'}, axis=1)
        train['prediction_id'] = train['patient_id'].astype(str) + '_' + train['laterality']
        train = train.groupby('prediction_id').agg({'target': 'max', 'fold': 'max', 'difficult_negative_case': 'max'}).reset_index()

        oof = pd.read_csv(precalculated_path)[['prediction_id', 'pred', 'site_id', 'fold']].merge(
            train[['prediction_id', 'target', 'difficult_negative_case']], on='prediction_id', how='left')
        prec_oof = {
            'oof_labels': oof['target'].values,
            'oof_preds': oof['pred'].values,
            'oof_sites': oof['site_id'].values,
            'oof_folds': oof['fold'].values,
            'oof_hardflags': oof['difficult_negative_case'].astype(int).values
        }
        return get_threshold(prec_oof, predictions, test_sites, mode, threshold_params)

    @staticmethod
    def _aggregate_by_prediction_id(test, preds):
        """Average image-level predictions per ``prediction_id``.

        Args:
            test: Image table with ``prediction_id`` and ``site_id``; row ``i``
                corresponds to ``preds[:, i]``. It is not modified.
            preds: Array ``[models, images, 1]``.

        Returns:
            ``(df, fold_cols)``: one row per prediction id (first occurrence,
            in the row order of ``test``) with ``prediction_id``, ``site_id``
            and one averaged column per model, plus the list of those column
            names.

        Raises:
            ValueError: If the number of predictions differs from ``len(test)``.
        """
        if preds.shape[1] != len(test):
            raise ValueError(
                f'Got {preds.shape[1]} predictions for {len(test)} test rows')
        fold_cols = [f'pred_fold_{fold}' for fold in range(preds.shape[0])]
        frame = pd.concat(
            [
                test[['prediction_id', 'site_id']].reset_index(drop=True),
                pd.DataFrame(preds[:, :, 0].T, columns=fold_cols),
            ],
            axis=1)
        frame[fold_cols] = frame.groupby('prediction_id')[fold_cols].transform('mean')
        return frame.drop_duplicates('prediction_id').copy(), fold_cols

    def inference_block(self, cfg, test, threshold_mode, threshold_params=None, fp16=False):
        """Run the full pipeline for one config.

        Args:
            cfg: Config (``tta`` must be 1).
            test: Image table with ``path``, ``prediction_id`` and ``site_id``.
                It is left unmodified.
            threshold_mode: Threshold mode understood by ``get_threshold``.
            threshold_params: Extra parameters forwarded to ``get_threshold``.
            fp16: Run the forward pass under autocast.

        Returns:
            ``(results, thresholds)`` where ``results`` has the columns
            ``prediction_id``, ``site_id`` and ``cfg.file_name`` (probability
            averaged over images and models).
        """
        assert cfg.tta == 1
        if threshold_params is None:
            threshold_params = {}
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        models = self.get_models(cfg)

        print('len(test):', len(test))
        print('tta:', cfg.tta)
        ds = RSNADatasetAriyasu(test.path.values, cfg)
        loader = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=2)
        preds = []
        with torch.no_grad():
            for images in tqdm(loader, smoothing=0):
                images = images.to(device)
                ys = []
                for model in models:
                    if fp16:
                        with amp.autocast():
                            logits = model(images)
                    else:
                        logits = model(images)
                    ys.append(logits.sigmoid().cpu().detach().numpy())
                preds.append(np.stack(ys)) # (cv, bs, 1)
        preds = np.concatenate(preds, axis=1)

        del models; torch.cuda.empty_cache()

        # Work on a separate frame so the caller's table keeps its columns.
        df, fold_cols = self._aggregate_by_prediction_id(test, preds)
        predictions = df[fold_cols].values.T[:, :, np.newaxis]
        thresholds = self.apply_threshold(
            cfg, predictions, df.site_id.values, mode=threshold_mode, threshold_params=threshold_params)

        df[cfg.file_name] = df[fold_cols].mean(1)
        return df[['prediction_id', 'site_id', cfg.file_name]], thresholds

In [None]:
class IshikeiInference:
    """Fold-ensemble inference and threshold lookup for the models stored
    under ``../input/ishikei-mammo``."""

    def _init_cfg(self, cfg):
        """Patch ``cfg`` in place for notebook inference and return it.

        Sets the dataset separator to ``'_'``, clears the auxiliary target
        columns and the bbox path, disables pretrained weights and, when the
        model has a ``with_cp`` option, turns it off.
        """
        cfg.dataset_params['sep'] = '_'
        cfg.dataset_params['aux_target_cols'] = []
        cfg.dataset_params['bbox_path'] = None
        cfg.model_params['pretrained'] = False
        if 'with_cp' in cfg.model_params:
            cfg.model_params['with_cp'] = False
        return cfg

    def _oof_candidates(self, cfg):
        """Return every OOF csv whose file name starts with ``cfg.name``."""
        return list(Path('../input/rsna-oof/ishikei_002/').glob(f'{cfg.name}*.csv'))

    def get_oof_path(self, cfg):
        """Return the first OOF csv matching ``cfg.name``.

        Raises:
            IndexError: if no file matches.
        """
        # NOTE: glob results are not sorted; with several matches the pick
        # follows directory listing order.
        return self._oof_candidates(cfg)[0]

    def get_dataloader(self, cfg, test_df, batch_size=None):
        """Build the test dataset and its loader.

        Args:
            cfg: config object; patched in place by ``_init_cfg``.
            test_df: DataFrame handed to ``cfg.dataset`` as ``df``.
            batch_size: loader batch size; defaults to ``cfg.batch_size``.

        Returns:
            ``(loader, dataset, dataset.pids, sites)`` where ``sites`` holds the
            first ``site_id`` value of each dataset item.
        """
        if batch_size is None:
            batch_size = cfg.batch_size

        # local - kaggle notebook diff
        cfg = self._init_cfg(cfg)

        test_data = cfg.dataset(
            df=test_df,
            image_dir=EXPORT_DIR,
            preprocess=cfg.preprocess['test'],
            transforms=cfg.transforms['test'],
            is_test=True,
            **cfg.dataset_params)

        test_loader = D.DataLoader(
            test_data, batch_size=batch_size, shuffle=False,
            num_workers=2, pin_memory=False)

        test_indices = test_data.pids
        test_sites = [
            test_data.df_dict[test_data.pids[i]]['site_id'].values[0] for i in range(len(test_data))
        ]

        return test_loader, test_data, test_indices, test_sites

    def infer_model(self, cfg, test_loader, test_data, batch_size=None, fp16=False):
        """Score ``test_loader`` with one model per fold.

        Checkpoints are read from ``../input/ishikei-mammo/{cfg.name}/fold{k}.pt``
        for ``k`` in ``range(cfg.cv)``. When the global ``DEBUG`` is truthy,
        every slot loads ``fold{DEBUG_FOLD}.pt`` instead.

        Args:
            cfg: config object providing ``name``, ``cv``, ``model``,
                ``model_params`` and ``parallel``.
            test_loader: iterable of batches; every element of a batch is moved
                to ``cuda:0`` and all but the last one are passed to the model.
            test_data: unused; kept for interface compatibility.
            batch_size: unused; kept for interface compatibility.
            fp16: when true, forward passes run under ``amp.autocast()``.

        Returns:
            numpy array of sigmoid scores, stacked over folds on axis 0 and
            concatenated over batches on axis 1.
        """
        proj_dir = Path(f'../input/ishikei-mammo/{cfg.name}')

        # get models.
        models = []
        for fold in range(cfg.cv):
            # init model.
            model = cfg.model(**cfg.model_params)

            # In DEBUG mode every ensemble slot loads the same checkpoint.
            ckpt_path = proj_dir/f'fold{DEBUG_FOLD if DEBUG else fold}.pt'
            print(f"load: {str(ckpt_path)}")
            checkpoint = torch.load(ckpt_path, 'cpu')
            model.load_state_dict(checkpoint['model'])
            del checkpoint; gc.collect()

            if cfg.parallel == 'ddp':
                model = convert_sync_batchnorm(model)
            model.to('cuda:0')
            model.eval()
            models.append(model)

        # inference.
        predictions = []
        with torch.no_grad():
            for data in tqdm(test_loader, total=len(test_loader)):
                data = [d.to('cuda:0') for d in data]
                ys = []
                for model in models:
                    if fp16:
                        with amp.autocast():
                            logits = model(*data[:-1])
                    else:
                        logits = model(*data[:-1])
                    ys.append(logits.sigmoid().cpu().detach().numpy())
                predictions.append(np.stack(ys))  # (cv, bs, 1)
        predictions = np.concatenate(predictions, axis=1)

        # The loop variables still reference the last model and the last
        # batch; clear them too so empty_cache() can actually release memory.
        model = data = logits = None
        del models
        gc.collect(); torch.cuda.empty_cache()
        return predictions

    def apply_threshold(
        self,
        cfg,
        raw_predictions,
        test_sites,
        mode='OOF',
        threshold_params=None,):
        """Compute thresholds from the pre-computed OOF predictions.

        Args:
            cfg: config object; ``cfg.name`` selects the OOF csv.
            raw_predictions: array of test predictions; a copy is handed to
                ``get_threshold``.
            test_sites: site id of each test prediction.
            mode: forwarded to ``get_threshold``.
            threshold_params: forwarded to ``get_threshold``; ``None`` is
                treated as an empty dict.

        Returns:
            Whatever ``get_threshold`` returns.

        Raises:
            IndexError: if no OOF csv matches ``cfg.name``.
        """
        threshold_params = {} if threshold_params is None else threshold_params
        predictions = raw_predictions.copy()

        oof_paths = self._oof_candidates(cfg)
        print(oof_paths)
        oof_path = oof_paths[0]

        # Collapse the training table to one row per prediction_id.
        train = pd.read_csv('../input/rsna-misc/train_with_fold.csv').rename({'cancer': 'target'}, axis=1)
        train['prediction_id'] = train['patient_id'].astype(str) + '_' + train['laterality']
        train = train.groupby('prediction_id').agg({'target': 'max', 'fold': 'max', 'difficult_negative_case': 'max'}).reset_index()

        # Attach labels and hard-negative flags to the stored OOF predictions.
        oof = pd.read_csv(oof_path)[['prediction_id', 'pred', 'site_id', 'fold']].merge(
            train[['prediction_id', 'target', 'difficult_negative_case']], on='prediction_id', how='left')
        prec_oof = {
            'oof_labels': oof['target'].values,
            'oof_preds': oof['pred'].values,
            'oof_sites': oof['site_id'].values,
            'oof_folds': oof['fold'].values,
            'oof_hardflags': oof['difficult_negative_case'].astype(int).values
        }
        return get_threshold(prec_oof, predictions, test_sites, mode, threshold_params)

    def inference_block(self, cfg, test_df, threshold_mode, batch_size=None, fp16=False, threshold_params=None):
        '''
        Run the whole pipeline: build the loader, score it and pick thresholds.

        input parameters:
            cfg: kuma_utils config class
            test_df: test dataframe to inference
            threshold_mode: passed to apply_threshold as mode
            batch_size: inference batch_size
            fp16: run forward passes under amp.autocast()
            threshold_params: passed to apply_threshold (None -> empty dict)

        output parameters:
            results: pd.DataFrame with prediction_id, site_id and a cfg.name
                column holding the mean prediction over folds
            thresholds: return value of apply_threshold
        '''
        threshold_params = {} if threshold_params is None else threshold_params
        _loader, _dataset, _indices, _sites = self.get_dataloader(cfg, test_df, batch_size)
        predictions = self.infer_model(cfg, _loader, _dataset, batch_size, fp16=fp16)
        thresholds = self.apply_threshold(
            cfg, predictions, _sites, mode=threshold_mode, threshold_params=threshold_params)

        results = pd.DataFrame([
            {'prediction_id': f'{pred_id[0]}_{pred_id[1]}',
             'site_id': site,
             cfg.name: pred[0]}
            for pred_id, site, pred in zip(_indices, _sites, predictions.mean(0))
        ])
        return results, thresholds

In [None]:
class CharmInterface:
    """Fold-ensemble inference and threshold lookup for the models whose
    per-fold ``config.yaml`` and EMA weights live under ``cfg.model_path``."""

    def get_oof_path(self, cfg):
        """Return the path of the OOF csv named after ``cfg.name``."""
        return Path('../input/rsna-oof/charm/')/f'{cfg.name}.csv'

    @staticmethod
    def _load_fold_config(model_path, fold):
        """Load ``{model_path}/fold_{fold}/config.yaml`` as an OmegaConf object."""
        # Context manager so the file handle is closed once parsing is done.
        with open(f"{model_path}/fold_{fold}/config.yaml") as f:
            return OmegaConf.create(yaml.safe_load(f))

    def get_dataloader(self, cfg, test_df, batch_size=2):
        """Build the test loader using the fold-0 config found under ``cfg.model_path``."""
        fold_cfg = self._load_fold_config(cfg.model_path, 0)
        dataset = RSNADataset(test_df, fold_cfg)
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)

    def infer_model(self, config, test_dataloader, folds=(0, 1, 2, 3)):
        """Score ``test_dataloader`` with one predictor per entry of ``folds``.

        Args:
            config: object providing ``model_path``.
            test_dataloader: yields ``(inputs, prediction_ids)`` batches.
            folds: fold indices whose config and EMA weights are loaded. When
                the global ``DEBUG`` is truthy, every slot loads the weights of
                ``DEBUG_FOLD`` (configs still come from ``folds``).

        Returns:
            One DataFrame per predictor, with columns ``prediction_id`` and
            ``pred{k}`` where ``k`` is the predictor's position in ``folds``.
            When the fold-0 config sets ``dataset.use_multi_lat``, each batch
            yields two outputs and the second set of rows is appended after
            the first.
        """
        model_path = config.model_path
        cfg_list = [self._load_fold_config(model_path, i) for i in folds]
        if DEBUG:
            weight_paths = [f"{model_path}/fold_{DEBUG_FOLD}/model_weights_ema.pth" for _ in folds]
        else:
            weight_paths = [f"{model_path}/fold_{i}/model_weights_ema.pth" for i in folds]

        predictor_list = []
        for weight_path, fold_cfg in zip(weight_paths, cfg_list):
            print("Loading model from", weight_path)
            predictor = Forwarder(init_model_from_config(fold_cfg.model, weight_path), fold_cfg)
            predictor.eval()
            predictor.to("cuda")
            predictor_list.append(predictor)
        n_models = len(predictor_list)

        use_multi_lat = cfg_list[0].dataset.use_multi_lat
        prediction_ids_lr_list = [() for _ in range(n_models)]
        preds_list = [[] for _ in range(n_models)]
        if use_multi_lat:
            prediction_ids_l_list = [() for _ in range(n_models)]
            prediction_ids_r_list = [() for _ in range(n_models)]
            preds_l_list = [[] for _ in range(n_models)]
            preds_r_list = [[] for _ in range(n_models)]

        print("inference started")
        for inputs, prediction_ids in tqdm(test_dataloader):
            with torch.cuda.amp.autocast(enabled=True):
                inputs = inputs.to("cuda")
                with torch.no_grad():
                    for i, predictor in enumerate(predictor_list):
                        if use_multi_lat:
                            logits_l, logits_r = predictor(inputs)
                            preds_l_list[i].append(logits_l.sigmoid().cpu().detach().numpy().reshape(-1))
                            preds_r_list[i].append(logits_r.sigmoid().cpu().detach().numpy().reshape(-1))
                            prediction_ids_l_list[i] += prediction_ids[0]
                            prediction_ids_r_list[i] += prediction_ids[1]
                        else:
                            logits = predictor(inputs)
                            preds_list[i].append(logits.sigmoid().cpu().detach().numpy().reshape(-1))
                            prediction_ids_lr_list[i] += prediction_ids[0]

        # Drop every reference to the models before emptying the cache.
        # Deleting only a loop variable leaves them alive inside the list, so
        # nothing would be returned to the allocator.
        predictor = None
        predictor_list.clear()
        gc.collect()
        torch.cuda.empty_cache()
        print("inference finished")

        if use_multi_lat:
            for i in range(n_models):
                preds_l = np.hstack(preds_l_list[i])
                preds_r = np.hstack(preds_r_list[i])
                preds_list[i] = np.hstack([preds_l, preds_r])
                prediction_ids_lr_list[i] = prediction_ids_l_list[i] + prediction_ids_r_list[i]
        else:
            preds_list = [np.hstack(chunks) for chunks in preds_list]

        return [
            pd.DataFrame(data={"prediction_id": prediction_ids_lr, f"pred{ifold}": preds})
            for ifold, (preds, prediction_ids_lr) in enumerate(zip(preds_list, prediction_ids_lr_list))
        ]

    def apply_threshold(
        self,
        cfg,
        raw_predictions,
        test_sites,
        mode='OOF',
        threshold_params=None):
        """Compute thresholds from the pre-computed OOF predictions.

        Args:
            cfg: object whose ``name`` selects the OOF csv.
            raw_predictions: array of test predictions; a copy is handed to
                ``get_threshold``.
            test_sites: site id of each test prediction.
            mode: forwarded to ``get_threshold``.
            threshold_params: forwarded to ``get_threshold``; ``None`` is
                treated as an empty dict.

        Returns:
            Whatever ``get_threshold`` returns.
        """
        threshold_params = {} if threshold_params is None else threshold_params
        predictions = raw_predictions.copy()
        precalculated_path = self.get_oof_path(cfg)
        print(precalculated_path)

        # Collapse the training table to one row per prediction_id.
        train = pd.read_csv('../input/rsna-misc/train_with_fold.csv').rename({'cancer': 'target'}, axis=1)
        train['prediction_id'] = train['patient_id'].astype(str) + '_' + train['laterality']
        train = train.groupby('prediction_id').agg({'target': 'max', 'fold': 'max', 'difficult_negative_case': 'max'}).reset_index()

        # Attach labels and hard-negative flags to the stored OOF predictions.
        oof = pd.read_csv(precalculated_path)[['prediction_id', 'pred', 'site_id', 'fold']].merge(
            train[['prediction_id', 'target', 'difficult_negative_case']], on='prediction_id', how='left')
        prec_oof = {
            'oof_labels': oof['target'].values,
            'oof_preds': oof['pred'].values,
            'oof_sites': oof['site_id'].values,
            'oof_folds': oof['fold'].values,
            'oof_hardflags': oof['difficult_negative_case'].astype(int).values
        }
        return get_threshold(prec_oof, predictions, test_sites, mode, threshold_params)

    def inference_block(self, cfg, test_df, threshold_mode, fp16=False, threshold_params=None, batch_size=2, folds=(0, 1, 2, 3)):
        """Run the whole pipeline: build the loader, score it and pick thresholds.

        Args:
            cfg: object providing ``name`` and ``model_path``.
            test_df: DataFrame with ``patient_id``, ``laterality`` and
                ``site_id`` columns; it is not modified here.
            threshold_mode: passed to ``apply_threshold`` as ``mode``.
            fp16: unused by this method; ``infer_model`` always runs under
                autocast.
            threshold_params: passed to ``apply_threshold``; ``None`` is
                treated as an empty dict.
            batch_size: loader batch size.
            folds: fold indices handed to ``infer_model``; any length works.

        Returns:
            ``(preds, thresholds)``: ``preds`` has the columns
            ``prediction_id``, ``site_id`` and ``cfg.name`` (mean over folds);
            ``thresholds`` is the return value of ``apply_threshold``.
        """
        threshold_params = {} if threshold_params is None else threshold_params

        # Hard-code fix: rebuild prediction_id -> site_id from the raw frame.
        test_master = test_df.copy()
        test_master['prediction_id'] = test_master['patient_id'].astype(str) + '_' + test_master['laterality']
        test_master = test_master.groupby('prediction_id').agg({'site_id': 'max'}).reset_index()

        test_dataloader = self.get_dataloader(cfg, test_df, batch_size)
        fold_frames = self.infer_model(cfg, test_dataloader, folds)
        torch.cuda.empty_cache()

        # infer_model names its columns pred0..pred{k-1}; derive the list from
        # what it returned so a non-default ``folds`` does not raise KeyError.
        pred_cols = [f'pred{i}' for i in range(len(fold_frames))]
        preds = pd.concat([p.set_index('prediction_id') for p in fold_frames], axis=1).reset_index()
        preds = preds.merge(test_master, on='prediction_id', how='left')
        preds_array = preds[pred_cols].T.values
        preds[cfg.name] = preds_array.mean(0)
        thresholds = self.apply_threshold(
            cfg, preds_array, preds['site_id'].values, threshold_mode, threshold_params=threshold_params)

        return preds[['prediction_id', 'site_id', cfg.name]], thresholds