## hardware

In [None]:
!cat /proc/cpuinfo

In [None]:
!nvidia-smi

## kaggle

In [None]:
import os

# Fetch the competition archive (only if it is not on Drive yet) and unpack it locally.
if __name__ == '__main__':

    # Download only when the archive is not already stored on Google Drive.
    if not os.path.isfile('/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/data/czii-cryo-et-object-identification.zip'):
        # Copy the Kaggle API token from Drive into ~/.kaggle and restrict its permissions.
        !mkdir -p ~/.kaggle
        !cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json

        !kaggle competitions download -c czii-cryo-et-object-identification -p '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/data/'

    # NOTE(review): unzip is called without -d, so it extracts into the current working
    # directory; the '/content/train/' check presumably relies on that directory being
    # /content — confirm.
    if not os.path.isdir('/content/train/'):
        !unzip '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/data/czii-cryo-et-object-identification.zip'

## library

In [None]:
#!pip install zarr copick timm segmentation_models_pytorch connected-components-3d monai
################################################################################
!pip uninstall torch -y
!pip install mmengine mmcv zarr copick timm segmentation_models_pytorch connected-components-3d monai torch==2.4.0
################################################################################
!pip install git+https://github.com/copick/copick-utils.git

In [None]:
import numpy as np
import pandas as pd

import os
import random

import glob

import json

from tqdm import tqdm

import matplotlib.pyplot as plt

from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy

from transformers.optimization import get_cosine_schedule_with_warmup

import albumentations as A

import zarr

import copick

import timm

import segmentation_models_pytorch as smp

import cc3d

## config

In [None]:
def _sliding_window_offsets(volume_size, patch_size, stride_size):
    """Return the start index of every patch of a strided 3-D sliding window.

    Args:
        volume_size: length of the volume along each of the three axes.
        patch_size: length of a patch along each of the three axes.
        stride_size: step between consecutive patches along each axis.

    Returns:
        A list of ``[i, j, k]`` start indices, ordered with the first axis
        varying slowest and the last axis varying fastest.
    """
    return [
        [i, j, k]
        for i in range(0, volume_size[0] - patch_size[0] + 1, stride_size[0])
        for j in range(0, volume_size[1] - patch_size[1] + 1, stride_size[1])
        for k in range(0, volume_size[2] - patch_size[2] + 1, stride_size[2])
    ]


class CustomConfig:
    """Static configuration shared by every stage of the notebook.

    All settings are class attributes; ``args = CustomConfig()`` is passed
    around as a plain namespace. Model-specific attributes are selected at
    class-creation time from ``model_name``.
    """
    seed = 42
    device = 'cuda'
    root = '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/'

    # Multiplier applied to each particle radius when the training masks are painted.
    radius_scale = 0.5

    n_fold = 7

    # Tomogram variants loaded for every experiment.
    filters = [
        'denoised',
        'wbp',
        'ctfdeconvolved',
        'isonetcorrected'
        ]

    particle_types = [
        'apo-ferritin',
        'beta-amylase',
        'beta-galactosidase',
        'ribosome',
        'thyroglobulin',
        'virus-like-particle',
        ]

    # particle name -> class index, starting at 1 (values are numpy integers).
    particle2class = dict(zip(particle_types, np.arange(1, len(particle_types) + 1)))

    tomogram_path = '/content/train/static/ExperimentRuns/'
    segmentation_path = '/content/overlay/ExperimentRuns/'

    tomogram_name = 'VoxelSpacing10.000'
    segmentation_name = 'Segmentations/10.000_copickUtils_0_paintedPicks-multilabel'

    # Every volume is resampled to volume_size and then cut into patch_size windows.
    volume_size = [192, 640, 640]
    patch_size = [64, 256, 256]
    stride_size = [32, 128, 128]

    # NOTE(review): 184 x 630 x 630 is presumably the native tomogram shape, which
    # would make these the native/resampled ratios per axis — confirm.
    size_scale = [
        184 / volume_size[0],
        630 / volume_size[1],
        630 / volume_size[2],
    ]

    # Computed in a helper so that no loop indices leak into the class namespace.
    offset = _sliding_window_offsets(volume_size, patch_size, stride_size)

    # Lower / upper intensity percentiles used to clip each tomogram.
    clip_range = (1, 99)

    model_version = 1
    n_channel = 1
    # NOTE(review): presumably the six particle classes plus background — confirm.
    n_class = 7

    model_name = 'resnet18d.ra2_in1k'

    if model_name in [
        'resnet18d.ra2_in1k',
        'resnet34d.ra2_in1k',
        'tf_efficientnet_b1.in1k',
    ]:
        drop_path_rate = 0.1
        n_block = 5
        decoder_channels = [256, 128, 64, 32, 32]


    elif model_name == 'r50ir':
        n_block = 4
        encoder_channels = [256, 512, 1024, 2048]
        decoder_channels = [256, 128, 64, 32]
        pretrained_path = 'https://download.openmmlab.com/mmaction/recognition/csn/ircsn_from_scratch_r50_ig65m_20210617-ce545a37.pth'

    elif model_name == 'r152ir':
        n_block = 4
        encoder_channels = [256, 512, 1024, 2048]
        decoder_channels = [256, 128, 64, 32]
        pretrained_path = 'https://download.openmmlab.com/mmaction/recognition/csn/ircsn_from_scratch_r152_ig65m_20200807-771c4135.pth'

    else:
        raise NotImplementedError()

    n_epoch = 30
    batch_size = 4
    test_batch_size = 4
    iters_to_accumulate = 1

    n_worker = os.cpu_count()

    lr = 1e-3
    wd = 1e-2
    warmup_ratio = 0.1
    test_freq = 2

    voxel_space = 10.0
    thresholds = [0.10, 0.25, 0.50, 0.75, 0.90]

    tta = []

    mix_prob = 0.5
    mix_alpha = 1.0

if __name__ == '__main__':
    args = CustomConfig()

## seed

In [None]:
def seed_function(args):
    """Seed every random number generator used by the notebook.

    Args:
        args: configuration object exposing an integer ``seed`` attribute.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic, non-autotuned kernels.
    """
    random.seed(args.seed)
    # Only inherited by processes started afterwards; the hash seed of the
    # running interpreter was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Seed every visible GPU, not only the current one (a no-op without CUDA).
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    seed_function(args)

## official

In [None]:
# ref.: https://github.com/czimaginginstitute/2024_czii_mlchallenge_notebooks/blob/main/3d_unet_monai/train.ipynb
# ref.: https://github.com/czimaginginstitute/2024_czii_mlchallenge_notebooks/blob/main/DeepFindET/train.ipynb

In [None]:
# Creating a copick project

import os
import shutil

'''
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": true,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153,  63,   0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "/kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "/kaggle/input/czii-cryo-et-object-identification/train/static"
}"""
'''
################################################################################
# Copick project configuration (JSON text) for the Colab layout: the pickable
# objects with their labels and radii, an overlay root at /content/overlay and
# a static root at /content/train/static.
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": true,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153,  63,   0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "/content/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "/content/train/static"
}"""
################################################################################

# Kaggle locations, kept for reference:
#copick_config_path = "/kaggle/working/copick.config"
#output_overlay = "/kaggle/working/overlay"
################################################################################
# Colab locations of the copick config file and of the writable overlay.
copick_config_path = "/content/copick.config"
output_overlay = "/content/overlay"
################################################################################

# Write the JSON configuration to disk; copick later loads it from this file.
with open(copick_config_path, "w") as f:
    f.write(config_blob)

# Now, setup new overlay directory

# Source and destination directories (Kaggle locations kept for reference):
#source_dir = '/kaggle/input/czii-cryo-et-object-identification/train/overlay'
#destination_dir = '/kaggle/working/overlay'
################################################################################
source_dir = '/content/train/overlay'
destination_dir = '/content/overlay'
################################################################################

# Mirror the source tree into the destination, prefixing file names on the way.
for current_dir, _subdirs, filenames in os.walk(source_dir):
    mirrored_dir = os.path.join(destination_dir, os.path.relpath(current_dir, source_dir))
    os.makedirs(mirrored_dir, exist_ok=True)

    for filename in filenames:
        # Each copied file carries the "curation_0_" prefix exactly once.
        prefixed_name = filename if filename.startswith("curation_0_") else f"curation_0_{filename}"

        source_file = os.path.join(current_dir, filename)
        destination_file = os.path.join(mirrored_dir, prefixed_name)

        # copy2 also preserves file metadata.
        shutil.copy2(source_file, destination_file)
        print(f"Copied {source_file} to {destination_file}")

In [None]:
# Load the copick project from the config file written to copick_config_path.
root = copick.from_file(copick_config_path)

# Identifiers under which the painted segmentation is written and later read back.
copick_user_name = "copickUtils"
copick_segmentation_name = "paintedPicks"
# Voxel spacing and tomogram type requested from every run.
voxel_size = 10
tomo_type = "denoised"

In [None]:
from copick_utils.segmentation import segmentation_from_picks
import copick_utils.writers.write as write
from collections import defaultdict

# Just do this once
generate_masks = True

if generate_masks:
    # Label and radius of every pickable object that is flagged as a particle.
    # The loop variable is deliberately not called `object`: at notebook level
    # that would shadow the builtin for the rest of the kernel session.
    target_objects = defaultdict(dict)
    for pickable_object in root.pickable_objects:
        if pickable_object.is_particle:
            target_objects[pickable_object.name]['label'] = pickable_object.label
            target_objects[pickable_object.name]['radius'] = pickable_object.radius


    for run in tqdm(root.runs):
        # The tomogram is only read to obtain the shape of the label volume.
        tomo = run.get_voxel_spacing(voxel_size)
        tomo = tomo.get_tomogram(tomo_type).numpy()
        target = np.zeros(tomo.shape, dtype=np.uint8)
        for pickable_object in root.pickable_objects:
            # Non-particle objects have no label/radius entry; painting them
            # would fail with a KeyError, so they are skipped explicitly.
            if not pickable_object.is_particle:
                continue
            pick = run.get_picks(object_name=pickable_object.name, user_id="curation")
            if len(pick):
                target = segmentation_from_picks.from_picks(pick[0],
                                                            target,
                                                            # the reference notebook used: radius * 0.8
                                                            ####################
                                                            target_objects[pickable_object.name]['radius'] * args.radius_scale,
                                                            ####################
                                                            target_objects[pickable_object.name]['label']
                                                            )
        write.segmentation(run, target, copick_user_name, name=copick_segmentation_name)

In [None]:
# Collect tomogram / painted-segmentation pairs as plain arrays.
data_dicts = []
for run in tqdm(root.runs):
    voxel_spacing = run.get_voxel_spacing(voxel_size)
    tomogram = voxel_spacing.get_tomogram(tomo_type).numpy()

    segmentations = run.get_segmentations(
        name=copick_segmentation_name,
        user_id=copick_user_name,
        voxel_size=voxel_size,
        is_multilabel=True,
    )
    segmentation = segmentations[0].numpy()

    data_dicts.append(dict(image=tomogram, label=segmentation))

    ############################################################################
    # Only the first run is loaded; it is enough for the sanity check below.
    break
    ############################################################################

# Label values present in the painted segmentation of the first run.
print(np.unique(data_dicts[0]['label']))

In [None]:
import matplotlib.pyplot as plt

# Show slice 100 of the first tomogram next to the matching painted labels.
fig, (ax_tomogram, ax_label) = plt.subplots(1, 2, figsize=(15, 5))

ax_tomogram.set_title('Tomogram')
ax_tomogram.imshow(data_dicts[0]['image'][100], cmap='gray')
ax_tomogram.axis('off')

ax_label.set_title('Painted Segmentation from Picks')
ax_label.imshow(data_dicts[0]['label'][100], cmap='viridis')
ax_label.axis('off')

fig.tight_layout()
plt.show()

## preprocess

In [None]:
def resize_function(args, x, mode):
    """Resample a 3-D array to ``args.volume_size`` with ``F.interpolate``.

    Args:
        args: configuration object exposing ``volume_size`` (three target lengths).
        x: 3-D array-like volume.
        mode: interpolation mode forwarded to ``F.interpolate``.

    Returns:
        The resampled volume as a float32 NumPy array.
    """
    volume = torch.tensor(x, dtype = torch.float)
    # F.interpolate expects leading batch and channel axes for volumetric input.
    resized = F.interpolate(volume[None, None], size = args.volume_size, mode = mode)
    return resized[0, 0].numpy()

def preprocess_function(args):
    """Load every experiment into memory and build the patch table and CV folds.

    Args:
        args: configuration object; reads ``tomogram_path``, ``tomogram_name``,
            ``segmentation_path``, ``segmentation_name``, ``filters``,
            ``clip_range``, ``volume_size``, ``offset``, ``n_fold`` and ``seed``.

    Returns:
        A tuple ``(train, folds, cache)``:
          - ``train``: DataFrame with one row per (experiment, patch offset),
            columns ``experiment`` and ``offset``.
          - ``folds``: list of ``[train_df, test_df]`` pairs, split by experiment.
          - ``cache``: ``cache[experiment][filter]`` holds the resized, percentile
            clipped tomogram and ``cache[experiment]['segmentation']`` the
            resized label volume.

    Raises:
        FileNotFoundError: if no experiment directory exists under
            ``args.tomogram_path``.
    """
    # Experiments are discovered under the configured tomogram root so the
    # listing and the zarr loading below always agree.
    experiment_dirs = sorted(glob.glob(os.path.join(args.tomogram_path, '*')))
    if len(experiment_dirs) == 0:
        raise FileNotFoundError(f'no experiments found under {args.tomogram_path}')
    experiments = np.array([os.path.basename(path) for path in experiment_dirs], dtype = str)

    cache = {}
    for experiment in experiments:
        cache[experiment] = {}
        for filter_name in args.filters:
            x = zarr.open(args.tomogram_path + f'{experiment}/{args.tomogram_name}/{filter_name}.zarr')[0][:]
            x = resize_function(args, x, mode = 'trilinear')

            # Clip intensities to the configured percentile range of this volume.
            lower, upper = np.percentile(x, args.clip_range)
            x = np.clip(x, lower, upper)

            cache[experiment][filter_name] = x

        # Labels are resampled with nearest-neighbour so no new label values appear.
        y = zarr.open(args.segmentation_path + f'{experiment}/{args.segmentation_name}.zarr')[0][:]
        y = resize_function(args, y, mode = 'nearest')
        cache[experiment]['segmentation'] = y

    # One row per (experiment, patch offset).
    _experiment = []
    _offset = []
    for experiment in experiments:
        _experiment.extend([experiment] * len(args.offset))
        _offset.extend(args.offset)

    train = pd.DataFrame()
    train['experiment'] = _experiment
    train['offset'] = _offset

    folds = []

    # Split on experiments, so patches of one tomogram never straddle train and test.
    kf = KFold(n_splits = args.n_fold, shuffle = True, random_state = args.seed)
    for train_indices, test_indices in kf.split(experiments):

        train_experiments = experiments[train_indices]
        test_experiments = experiments[test_indices]

        train_df = train[train.experiment.isin(train_experiments)].reset_index(drop = True)
        test_df = train[train.experiment.isin(test_experiments)].reset_index(drop = True)

        folds.append([train_df, test_df])

    return train, folds, cache

if __name__ == '__main__':
    train, folds, cache = preprocess_function(args)

## dataset

In [None]:
class CustomTransform(nn.Module):
    """Albumentations pipeline applied jointly to a patch and its label mask.

    Training uses random flips, transpose, shift/scale/rotate and
    brightness/contrast; evaluation uses an empty pipeline.
    """
    def __init__(self, args, is_training):
        super().__init__()

        augmentations = []
        if is_training:
            augmentations = [
                A.HorizontalFlip(p = 0.5),
                A.VerticalFlip(p = 0.5),
                A.Transpose(p = 0.5),

                A.ShiftScaleRotate(p = 0.8),

                A.RandomBrightnessContrast(p = 0.8),
            ]
        # A resize step (A.Resize to args.input_size) used to be appended here; it is disabled.
        self.transform = A.Compose(augmentations)

    def forward(self, x, y):
        """Augment ``x`` and ``y`` together and return them in the input axis order."""
        # Move the leading axis last for albumentations, which works on
        # channel-last arrays; the order is restored on the way out.
        image = x.transpose(1, 2, 0)
        mask = y.transpose(1, 2, 0)

        augmented = self.transform(image = image, mask = mask)

        return augmented['image'].transpose(2, 0, 1), augmented['mask'].transpose(2, 0, 1)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Patch dataset over the in-memory tomogram cache.

    Each row of ``df`` names an experiment and the start offset of a patch;
    ``__getitem__`` returns the normalised patch, its label patch and the offset.
    """
    def __init__(self, args, df, cache, is_training):
        """Store the configuration, patch table and cache, and build the transform.

        Args:
            args: configuration object; reads ``filters`` and ``patch_size``.
            df: DataFrame with ``experiment`` and ``offset`` columns.
            cache: ``cache[experiment][name]`` -> 3-D array, as built by
                ``preprocess_function``.
            is_training: enables the random filter choice and augmentations.
        """
        self.args = args

        self.df = df

        self.cache = cache

        self.is_training = is_training

        self.transform = CustomTransform(args, is_training = is_training)

    def __len__(self):
        return len(self.df)

    def _crop(self, volume, offset):
        """Cut the ``args.patch_size`` window that starts at ``offset`` out of ``volume``."""
        size = self.args.patch_size
        return volume[
            offset[0]:offset[0] + size[0],
            offset[1]:offset[1] + size[1],
            offset[2]:offset[2] + size[2],
        ]

    def get_inputs(self, row):
        """Return the image patch for ``row``, min-max scaled to [0, 1]."""
        # Training samples a random tomogram variant; evaluation always uses 'denoised'.
        if self.is_training:
            filter_name = random.choice(self.args.filters)
        else:
            filter_name = 'denoised'

        inputs = self._crop(self.cache[row.experiment][filter_name], row.offset)

        inputs = inputs - np.min(inputs)
        # A constant patch would otherwise give 0 / 0 = NaN; it stays all zeros.
        peak = np.max(inputs)
        if peak > 0:
            inputs = inputs / peak
        return inputs

    def get_targets(self, row):
        """Return the label patch for ``row``."""
        return self._crop(self.cache[row.experiment]['segmentation'], row.offset)

    def __getitem__(self, index):
        row = self.df.loc[index]

        inputs = self.get_inputs(row)
        targets = self.get_targets(row)

        inputs, targets = self.transform(inputs, targets)

        inputs = torch.tensor(inputs, dtype = torch.float)
        targets = torch.tensor(targets, dtype = torch.long)
        offsets = torch.tensor(row.offset, dtype = torch.long)
        return inputs, targets, offsets

if __name__ == '__main__':
    # Smoke test: draw one random training patch from fold 0 and show every slice.
    train_df, test_df = folds[0]
    dataset = CustomDataset(args, train_df, cache, is_training = True)

    index = random.randint(0, len(dataset) - 1)
    inputs, targets, offsets = dataset[index]

    print('inputs : ', inputs.shape)
    print('targets : ', targets.shape)
    print('offsets : ', offsets)

    # Top row: image slices, bottom row: the matching label slices.
    n_slice = args.patch_size[0]
    _, axes = plt.subplots(2, n_slice, figsize = (n_slice, 2))
    for i in range(n_slice):
        image_axis, label_axis = axes[0, i], axes[1, i]
        image_axis.imshow(inputs[i], cmap = 'gray')
        label_axis.imshow(targets[i], cmap = 'viridis')
    plt.show()

## model

In [None]:
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/unet.py'
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/3d.py'

In [None]:
class CustomModel(nn.Module):
    """Encoder/decoder segmentation network built from a timm backbone.

    The timm feature extractor, the ``UnetDecoder`` and the
    ``SegmentationHead`` are each passed through ``convert_3d``
    (defined in the external ``3d.py``; presumably it swaps 2-D layers
    for 3-D ones — confirm).
    """
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = timm.create_model(
            model_name = args.model_name,
            pretrained = True,
            features_only = True,
            in_chans = args.n_channel,
            drop_path_rate = args.drop_path_rate,
        )

        # The raw input is the first skip connection, followed by every encoder stage.
        feature_info = self.encoder.feature_info
        encoder_channels = [args.n_channel]
        for i in range(len(feature_info)):
            encoder_channels.append(feature_info[i]['num_chs'])
        decoder_channels = args.decoder_channels

        self.decoder = UnetDecoder(
            encoder_channels = encoder_channels,
            decoder_channels = decoder_channels,
            n_blocks = args.n_block,
            use_batchnorm = True,
            center = False,
            attention_type = None,
        )

        self.head = SegmentationHead(
            in_channels = decoder_channels[-1],
            out_channels = args.n_class,
            activation = None,
            kernel_size = 3,
        )

        for module in (self.encoder, self.decoder, self.head):
            convert_3d(module)

    def forward(self, x):
        """Insert a channel axis at dim 1 and return the head output."""
        x = x.unsqueeze(1)

        features = self.encoder(x)
        decoded = self.decoder(x, *features)
        return self.head(decoded)

if __name__ == "__main__":
    # Smoke test: one batch through the model under autocast, printing the output shape.
    loader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, num_workers = args.n_worker)
    sample = next(iter(loader))
    sample = [x.to(args.device) for x in sample]

    model = CustomModel(args)
    model = model.to(args.device)

    with torch.no_grad(), torch.amp.autocast(args.device):
        outputs = model(sample[0])
        print(outputs.shape)

## model-v2

In [None]:
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/unet.py'
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/3d.py'

In [None]:
class CustomModelV2(nn.Module):
    """Variant that runs the timm encoder slice by slice.

    Each slice (with ``n_channel`` neighbouring slices as channels) goes through
    the unconverted encoder; the per-slice features are then re-stacked along
    a depth axis and fed to the decoder and head, which are passed through
    ``convert_3d``.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = timm.create_model(
            model_name = args.model_name,
            pretrained = True,
            features_only = True,
            in_chans = args.n_channel,
            drop_path_rate = args.drop_path_rate,
        )

        # The raw input is the first skip connection, followed by every encoder stage.
        feature_info = self.encoder.feature_info
        encoder_channels = [args.n_channel]
        for i in range(len(feature_info)):
            encoder_channels.append(feature_info[i]['num_chs'])
        decoder_channels = args.decoder_channels

        self.decoder = UnetDecoder(
            encoder_channels = encoder_channels,
            decoder_channels = decoder_channels,
            n_blocks = args.n_block,
            use_batchnorm = True,
            center = False,
            attention_type = None,
        )

        self.head = SegmentationHead(
            in_channels = decoder_channels[-1],
            out_channels = args.n_class,
            activation = None,
            kernel_size = 3,
        )

        # Only the decoder and head are converted; the encoder is left as created.
        for module in (self.decoder, self.head):
            convert_3d(module)

    def get_inputs(self, x):
        """Build, for every slice, a window of ``n_channel`` consecutive slices.

        The slice axis (dim 1) is replicate-padded by ``(n_channel - 1) // 2``
        on both sides, then ``patch_size[0]`` windows are stacked along a new dim 1.
        """
        half_window = (self.args.n_channel - 1)//2
        x = F.pad(input = x, pad = (0, 0, 0, 0, half_window, half_window), mode = 'replicate')

        windows = []
        for start in range(self.args.patch_size[0]):
            windows.append(x[:, start:start + self.args.n_channel])
        return torch.stack(windows, dim = 1)

    def forward(self, x):
        depth = self.args.patch_size[0]

        def to_volume(t):
            # (batch * depth, C, H, W) -> (batch, C, depth, H, W)
            return t.reshape(-1, depth, *t.shape[1:]).permute(0, 2, 1, 3, 4)

        x = self.get_inputs(x)
        # Fold the slice axis into the batch so the encoder sees one window per sample.
        x = x.reshape(-1, *x.shape[2:])

        features = self.encoder(x)

        x = to_volume(x)
        features = [to_volume(feature) for feature in features]

        decoded = self.decoder(x, *features)
        return self.head(decoded)

if __name__ == "__main__":
    # Smoke test: one batch through the model under autocast, printing the output shape.
    loader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, num_workers = args.n_worker)
    sample = next(iter(loader))
    sample = [x.to(args.device) for x in sample]

    model = CustomModelV2(args)
    model = model.to(args.device)

    with torch.no_grad(), torch.amp.autocast(args.device):
        outputs = model(sample[0])
        print(outputs.shape)

## model-v3

In [None]:
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/unet.py'
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/3d.py'
%run '/content/drive/MyDrive/Kaggle/CZII - CryoET Object Identification/csn.py'

In [None]:
class CustomModelV3(nn.Module):
    """Variant with a ``ResNet3dCSN`` encoder.

    Depth and bottleneck mode are parsed from ``args.model_name``
    (e.g. ``'r50ir'`` -> depth 50, mode ``'ir'``). The decoder and head are
    passed through ``convert_3d`` with ``pretrained = False``.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args

        model_name = args.model_name
        self.encoder = ResNet3dCSN(
            pretrained2d = False,
            pretrained = None,
            depth = int(model_name[1:-2]),
            bottleneck_mode = model_name[-2:],
            norm_eval = False,
            zero_init_residual = False,
            in_channels = args.n_channel,
            out_indices = np.arange(args.n_block),
        )
        # Weights are loaded separately, and only when a checkpoint is configured.
        if args.pretrained_path is not None:
            self.encoder.init_weights(pretrained = args.pretrained_path)

        # The raw input is the first skip connection, followed by the configured stage widths.
        encoder_channels = [args.n_channel] + args.encoder_channels
        decoder_channels = args.decoder_channels

        self.decoder = UnetDecoder(
            encoder_channels = encoder_channels,
            decoder_channels = decoder_channels,
            n_blocks = args.n_block,
            use_batchnorm = True,
            center = False,
            attention_type = None,
        )

        self.head = SegmentationHead(
            in_channels = decoder_channels[-1],
            out_channels = args.n_class,
            activation = None,
            kernel_size = 3,
        )

        for module in (self.decoder, self.head):
            convert_3d(module, pretrained = False)

    def forward(self, x):
        """Insert a channel axis at dim 1 and return the head output."""
        x = x.unsqueeze(1)

        features = list(self.encoder(x))
        decoded = self.decoder(x, *features)
        return self.head(decoded)

if __name__ == "__main__":
    # Smoke test on a copy of the config switched to the 'r50ir' encoder.
    loader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, num_workers = args.n_worker)
    sample = next(iter(loader))
    sample = [x.to(args.device) for x in sample]

    _args = deepcopy(args)

    _args.model_name = 'r50ir'
    _args.n_block = 4
    _args.encoder_channels = [256, 512, 1024, 2048]
    _args.decoder_channels = [256, 128, 64, 32]
    _args.pretrained_path = 'https://download.openmmlab.com/mmaction/recognition/csn/ircsn_from_scratch_r50_ig65m_20210617-ce545a37.pth'

    model = CustomModelV3(_args)
    model = model.to(_args.device)

    with torch.no_grad(), torch.amp.autocast(_args.device):
        outputs = model(sample[0])
        print(outputs.shape)

## loss

In [None]:
class CustomLoss(nn.Module):
    """Mean of the multiclass Dice loss and the multiclass Focal loss."""
    def __init__(self, args, is_training):
        super().__init__()
        self.args = args

        # Stored for callers; the same loss is computed in training and evaluation.
        self.is_training = is_training

        self.dice_loss = smp.losses.DiceLoss(mode = 'multiclass')

        self.focal_loss = smp.losses.FocalLoss(mode = 'multiclass')

    def forward(self, inputs, targets):
        """Return ``(dice + focal) / 2`` for the given model outputs and targets."""
        dice = self.dice_loss(inputs, targets)
        focal = self.focal_loss(inputs, targets)
        return (dice + focal) / 2

if __name__ == '__main__':
    # Smoke test: build the configured model version and evaluate the loss on one batch.
    loader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, num_workers = args.n_worker)
    sample = next(iter(loader))
    sample = [x.to(args.device) for x in sample]

    if args.model_version == 1:
        model = CustomModel(args)
    elif args.model_version == 2:
        model = CustomModelV2(args)
    elif args.model_version == 3:
        model = CustomModelV3(args)
    else:
        raise NotImplementedError()

    model = model.to(args.device)

    loss_function = CustomLoss(args, is_training = True)

    with torch.no_grad():
        outputs = model(sample[0])
        loss = loss_function(outputs, sample[1])
        print(loss)

## utils

In [None]:
"""
Derived from:
https://github.com/cellcanvas/album-catalog/blob/main/solutions/copick/compare-picks/solution.py
"""

import numpy as np
import pandas as pd

from scipy.spatial import KDTree


class ParticipantVisibleError(Exception):
    """Raised by the scoring code when a submission is invalid."""


def compute_metrics(reference_points, reference_radius, candidate_points):
    """Count true positives, false positives and false negatives for one point set.

    Args:
        reference_points: array of ground-truth coordinates (may be empty).
        reference_radius: maximum distance at which a candidate matches a reference.
        candidate_points: array of predicted coordinates (may be empty).

    Returns:
        ``(tp, fp, fn)`` where ``tp`` is the number of distinct reference points
        with at least one candidate within ``reference_radius``,
        ``fp = len(candidate_points) - tp`` and ``fn = len(reference_points) - tp``.
    """
    n_reference = len(reference_points)
    n_candidate = len(candidate_points)

    # Without references every candidate is a false positive.
    if n_reference == 0:
        return 0, n_candidate, 0

    # Without candidates every reference is missed.
    if n_candidate == 0:
        return 0, 0, n_reference

    reference_tree = KDTree(reference_points)
    candidate_tree = KDTree(candidate_points)
    # One list of reference indices per candidate.
    neighbours = candidate_tree.query_ball_tree(reference_tree, r=reference_radius)

    # Prevent submitting multiple matches per particle.
    # This won't be strictly correct in the (extremely rare) case where true particles
    # are very close to each other.
    matched_references = {ref_index for hits in neighbours for ref_index in hits}

    tp = len(matched_references)
    return tp, n_candidate - tp, n_reference - tp


#def score(
#        solution: pd.DataFrame,
#        submission: pd.DataFrame,
#        row_id_column_name: str,
#        distance_multiplier: float,
#        beta: int) -> float:
################################################################################
def score_function(
        solution: pd.DataFrame,
        submission: pd.DataFrame,
        row_id_column_name: str = 'row_id',
        distance_multiplier: float = 0.5,
        beta: int = 4) -> dict:
################################################################################
    '''
    F_beta
      - a true positive occurs when
         - (a) the predicted location is within a threshold of the particle radius, and
         - (b) the correct `particle_type` is specified
      - raw results (TP, FP, FN) are aggregated across all experiments for each particle type
      - f_beta is calculated for each particle type
      - individual f_beta scores are weighted by particle type for final score

    Args:
        solution: ground truth with columns `experiment`, `particle_type`, `x`, `y`, `z`.
        submission: predictions with the same columns.
        row_id_column_name: not used in the body; kept from the original signature.
        distance_multiplier: factor applied to every particle radius to get the match distance.
        beta: beta of the F-beta score.

    Returns:
        A dict with one formatted summary string per particle type found in
        `solution`, plus `f-beta` (weighted aggregate score), `n-true`
        (rows in `solution`) and `n-pred` (rows in the filtered `submission`).

    Raises:
        ParticipantVisibleError: if `submission` contains an unknown `particle_type`.
    '''

    particle_radius = {
        'apo-ferritin': 60,
        'beta-amylase': 65,
        'beta-galactosidase': 90,
        'ribosome': 150,
        'thyroglobulin': 130,
        'virus-like-particle': 135,
    }

    # Per-type weight in the aggregate score (beta-amylase carries weight 0).
    weights = {
        'apo-ferritin': 1,
        'beta-amylase': 0,
        'beta-galactosidase': 2,
        'ribosome': 1,
        'thyroglobulin': 2,
        'virus-like-particle': 1,
    }

    particle_radius = {k: v * distance_multiplier for k, v in particle_radius.items()}

    # Filter submission to only contain experiments found in the solution split
    split_experiments = set(solution['experiment'].unique())
    submission = submission.loc[submission['experiment'].isin(split_experiments)]

    # Only allow known particle types
    if not set(submission['particle_type'].unique()).issubset(set(weights.keys())):
        raise ParticipantVisibleError('Unrecognized `particle_type`.')

    assert solution.duplicated(subset=['experiment', 'x', 'y', 'z']).sum() == 0
    assert particle_radius.keys() == weights.keys()

    # Running TP / FP / FN totals per particle type present in the solution.
    results = {}
    for particle_type in solution['particle_type'].unique():
        results[particle_type] = {
            'total_tp': 0,
            'total_fp': 0,
            'total_fn': 0,
        }

    for experiment in split_experiments:
        for particle_type in solution['particle_type'].unique():
            reference_radius = particle_radius[particle_type]
            select = (solution['experiment'] == experiment) & (solution['particle_type'] == particle_type)
            reference_points = solution.loc[select, ['x', 'y', 'z']].values

            select = (submission['experiment'] == experiment) & (submission['particle_type'] == particle_type)
            candidate_points = submission.loc[select, ['x', 'y', 'z']].values

            if len(reference_points) == 0:
                reference_points = np.array([])
                reference_radius = 1

            if len(candidate_points) == 0:
                candidate_points = np.array([])

            tp, fp, fn = compute_metrics(reference_points, reference_radius, candidate_points)

            results[particle_type]['total_tp'] += tp
            results[particle_type]['total_fp'] += fp
            results[particle_type]['total_fn'] += fn

    aggregate_fbeta = 0.0
    ############################################################################
    # Notebook addition: per-type summary strings returned next to the aggregate.
    fbetas = {}
    ############################################################################
    for particle_type, totals in results.items():
        tp = totals['total_tp']
        fp = totals['total_fp']
        fn = totals['total_fn']

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        fbeta = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) if (precision + recall) > 0 else 0.0
        ########################################################################
        fbetas[particle_type] = f'fbeta:{fbeta:.3f}-recall:{recall:.3f}-precision:{precision:.3f}-tp:{tp}-fp:{fp}-fn:{fn}'
        ########################################################################
        aggregate_fbeta += fbeta * weights.get(particle_type, 1.0)

    # The denominator is the sum of ALL weights, including types absent from `results`.
    if weights:
        aggregate_fbeta = aggregate_fbeta / sum(weights.values())
    else:
        aggregate_fbeta = aggregate_fbeta / len(results)
    #return aggregate_fbeta
    ############################################################################
    fbetas['f-beta'] = aggregate_fbeta
    fbetas['n-true'] = len(solution)
    fbetas['n-pred'] = len(submission)
    return fbetas
    ############################################################################

In [None]:
def get_solution(args, df):
    """Collect the ground-truth particle picks for the experiments in `df`.

    For every distinct value of `df['experiment']` and every name in
    `args.particle_types`, the pick file
    `/content/train/overlay/ExperimentRuns/<experiment>/Picks/<particle_type>.json`
    is read; each entry of its `points` list contributes one row.

    Returns:
        DataFrame with columns `experiment`, `particle_type`, `x`, `y`, `z`,
        the coordinates being copied from each point's `location`.
    """
    columns = {'experiment': [], 'particle_type': [], 'x': [], 'y': [], 'z': []}

    for experiment in list(set(df['experiment'])):
        for particle_type in args.particle_types:
            pick_path = f'/content/train/overlay/ExperimentRuns/{experiment}/Picks/{particle_type}.json'
            with open(pick_path, "r") as f:
                points = json.load(f)['points']

            for axis in ('x', 'y', 'z'):
                columns[axis].extend([point['location'][axis] for point in points])

            columns['experiment'].extend([experiment] * len(points))
            columns['particle_type'].extend([particle_type] * len(points))

    # Columns are attached one by one, in the order the scoring code expects.
    solution = pd.DataFrame()
    for column_name, values in columns.items():
        solution[column_name] = values
    return solution

def get_submission(args, df, preds, threshold):
    """Turn per-class probability volumes into a table of predicted particle centres.

    `preds` is consumed in consecutive chunks of 184 slices along its first
    axis, one chunk per experiment in order of first appearance in `df`.
    Within a chunk, the channel `args.particle2class[particle_type]` (axis 1)
    is thresholded with `> threshold`, labelled with `cc3d.connected_components`,
    and every centroid except the first row of `cc3d.statistics(...)['centroids']`
    (label 0 - presumably background, confirm against cc3d) becomes one row.

    Centroids are multiplied by `args.voxel_space` and then per axis by
    `args.size_scale`; their columns are read as (z, y, x).

    Returns:
        DataFrame with columns `experiment`, `particle_type`, `x`, `y`, `z`.
    """
    # The slice count per experiment is fixed at 184; the number of rows each
    # experiment has in `df` is deliberately not used.
    n_slice = 184
    experiment_names = [name for name, _ in df.groupby('experiment', sort = False)]

    columns = {'experiment': [], 'particle_type': [], 'x': [], 'y': [], 'z': []}

    start = 0
    for experiment in experiment_names:
        volume = preds[start:start + n_slice]

        for particle_type in args.particle_types:
            channel = volume[:, args.particle2class[particle_type]]

            labels = cc3d.connected_components(channel > threshold)
            centroids = cc3d.statistics(labels)['centroids'][1:] * args.voxel_space
            n_found = centroids.shape[0]

            columns['experiment'].extend([experiment] * n_found)
            columns['particle_type'].extend([particle_type] * n_found)
            columns['x'].extend((centroids[:, 2] * args.size_scale[2]).tolist())
            columns['y'].extend((centroids[:, 1] * args.size_scale[1]).tolist())
            columns['z'].extend((centroids[:, 0] * args.size_scale[0]).tolist())

        start += n_slice

    submission = pd.DataFrame()
    for column_name, values in columns.items():
        submission[column_name] = values
    return submission

if __name__ == '__main__':
    # Sanity check of the label -> centroid -> score path: the cached
    # segmentation is one-hot encoded and passed through get_submission as if
    # it were a model prediction. Only the first experiment is evaluated.
    experiments = [name for name, _group in train.groupby('experiment', sort = False)]
    for experiment in experiments:
        # One-hot puts the class axis last; move it to axis 1 for get_submission.
        preds = F.one_hot(
            torch.tensor(cache[experiment]['segmentation'], dtype = torch.long),
            num_classes = args.n_class,
        ).permute(0, 3, 1, 2).numpy()

        df = train.loc[train.experiment == experiment].reset_index(drop = True)

        solution = get_solution(args, df)
        submission = get_submission(args, df, preds, args.thresholds[0])
        print(f'{experiment} : {score_function(solution, submission)}')

        # Voxel count per non-zero class label, reported next to the score.
        n_label = [(cache[experiment]["segmentation"] == i).sum() for i in range(1, args.n_class)]
        print(f'{experiment} : {dict(zip(args.particle_types, n_label))}\n')

        break

'''
radius_scale = 0.6
TS_5_4 : {'apo-ferritin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:46-fp:0-fn:0', 'beta-amylase': 'fbeta:0.900-recall:0.900-precision:0.900-tp:9-fp:1-fn:1', 'beta-galactosidase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:12-fp:0-fn:0', 'ribosome': 'fbeta:1.000-recall:1.000-precision:1.000-tp:31-fp:0-fn:0', 'thyroglobulin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:30-fp:0-fn:0', 'virus-like-particle': 'fbeta:1.000-recall:1.000-precision:1.000-tp:11-fp:0-fn:0', 'f-beta': 1.0, 'n-true': 140, 'n-pred': 140}
TS_5_4 : {'apo-ferritin': 9744, 'beta-amylase': 2175, 'beta-galactosidase': 8477, 'ribosome': 102495, 'thyroglobulin': 64065, 'virus-like-particle': 26991}

radius_scale = 0.5
TS_5_4 : {'apo-ferritin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:46-fp:0-fn:0', 'beta-amylase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:10-fp:0-fn:0', 'beta-galactosidase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:12-fp:0-fn:0', 'ribosome': 'fbeta:1.000-recall:1.000-precision:1.000-tp:31-fp:0-fn:0', 'thyroglobulin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:30-fp:0-fn:0', 'virus-like-particle': 'fbeta:1.000-recall:1.000-precision:1.000-tp:11-fp:0-fn:0', 'f-beta': 1.0, 'n-true': 140, 'n-pred': 140}
TS_5_4 : {'apo-ferritin': 5650, 'beta-amylase': 1330, 'beta-galactosidase': 4921, 'ribosome': 59466, 'thyroglobulin': 37126, 'virus-like-particle': 15754}

radius_scale = 0.4
TS_5_4 : {'apo-ferritin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:46-fp:0-fn:0', 'beta-amylase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:10-fp:0-fn:0', 'beta-galactosidase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:12-fp:0-fn:0', 'ribosome': 'fbeta:1.000-recall:1.000-precision:1.000-tp:31-fp:0-fn:0', 'thyroglobulin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:30-fp:0-fn:0', 'virus-like-particle': 'fbeta:1.000-recall:1.000-precision:1.000-tp:11-fp:0-fn:0', 'f-beta': 1.0, 'n-true': 140, 'n-pred': 140}
TS_5_4 : {'apo-ferritin': 2865, 'beta-amylase': 701, 'beta-galactosidase': 2505, 'ribosome': 30490, 'thyroglobulin': 18941, 'virus-like-particle': 8071}

radius_scale = 0.3
TS_5_4 : {'apo-ferritin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:46-fp:0-fn:0', 'beta-amylase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:10-fp:0-fn:0', 'beta-galactosidase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:12-fp:0-fn:0', 'ribosome': 'fbeta:1.000-recall:1.000-precision:1.000-tp:31-fp:0-fn:0', 'thyroglobulin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:30-fp:0-fn:0', 'virus-like-particle': 'fbeta:1.000-recall:1.000-precision:1.000-tp:11-fp:0-fn:0', 'f-beta': 1.0, 'n-true': 140, 'n-pred': 140}
TS_5_4 : {'apo-ferritin': 1198, 'beta-amylase': 297, 'beta-galactosidase': 1068, 'ribosome': 12821, 'thyroglobulin': 7996, 'virus-like-particle': 3425}

radius_scale = 0.2
TS_5_4 : {'apo-ferritin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:46-fp:0-fn:0', 'beta-amylase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:10-fp:0-fn:0', 'beta-galactosidase': 'fbeta:1.000-recall:1.000-precision:1.000-tp:12-fp:0-fn:0', 'ribosome': 'fbeta:1.000-recall:1.000-precision:1.000-tp:31-fp:0-fn:0', 'thyroglobulin': 'fbeta:1.000-recall:1.000-precision:1.000-tp:30-fp:0-fn:0', 'virus-like-particle': 'fbeta:1.000-recall:1.000-precision:1.000-tp:11-fp:0-fn:0', 'f-beta': 1.0, 'n-true': 140, 'n-pred': 140}
TS_5_4 : {'apo-ferritin': 359, 'beta-amylase': 99, 'beta-galactosidase': 311, 'ribosome': 3738, 'thyroglobulin': 2378, 'virus-like-particle': 1018}

'''
pass

In [None]:
'''
radius_scale = 0.8
TS_5_4 : {'apo-ferritin': 0.8303341902313625, 'beta-amylase': 0.8000000000000002, 'beta-galactosidase': 1.0, 'ribosome': 0.9372623574144486, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9667995068065444}
TS_69_2 : {'apo-ferritin': 0.8887015177065767, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 0.9428076256499134, 'virus-like-particle': 1.0, 'f-beta': 0.9677595384294861}
TS_6_4 : {'apo-ferritin': 0.9329268292682926, 'beta-amylase': 0.8888888888888888, 'beta-galactosidase': 1.0, 'ribosome': 0.8160000000000001, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9641324041811846}
TS_6_6 : {'apo-ferritin': 0.8573487031700289, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 0.9153846153846152, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9675333312220921}
TS_73_6 : {'apo-ferritin': 0.90863890615289, 'beta-amylase': 1.0, 'beta-galactosidase': 0.9285714285714285, 'ribosome': 0.9153846153846152, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9544523398114804}
TS_86_3 : {'apo-ferritin': 0.8177613320999075, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 0.8426724137931033, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9397689234059613}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.8814814814814815, 'ribosome': 0.91016333938294, 'thyroglobulin': 0.9603365384615384, 'virus-like-particle': 1.0, 'f-beta': 0.9419713398955685}

radius_scale = 0.7
TS_5_4 : {'apo-ferritin': 0.9577464788732394, 'beta-amylase': 0.9, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9939637826961771}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 0.9664974619289342, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 0.9474522292993631, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9877070987468997}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 0.9801611903285802, 'beta-amylase': 1.0, 'beta-galactosidase': 0.9285714285714285, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9767577210673482}
TS_86_3 : {'apo-ferritin': 0.9696412143514258, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 0.9646680942184155, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9788940037883656}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.8814814814814815, 'ribosome': 0.970108695652174, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9618673798021623}

radius_scale = 0.6
TS_5_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 1.000619578686493, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0000885112409275}
TS_86_3 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9882783882783883}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.875, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9642857142857143}

radius_scale = 0.5
TS_5_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 1.000619578686493, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0000885112409275}
TS_86_3 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9882783882783883}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.9583333333333335, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9880952380952381}

radius_scale = 0.4
TS_5_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 1.000619578686493, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0000885112409275}
TS_86_3 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9882783882783883}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}

radius_scale = 0.3
TS_5_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 1.000619578686493, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0000885112409275}
TS_86_3 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 0.958974358974359, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 0.9882783882783883}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}

radius_scale = 0.2
TS_5_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_69_2 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_4 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_6_6 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_73_6 : {'apo-ferritin': 1.000619578686493, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0000885112409275}
TS_86_3 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
TS_99_9 : {'apo-ferritin': 1.0, 'beta-amylase': 1.0, 'beta-galactosidase': 1.0, 'ribosome': 1.0, 'thyroglobulin': 1.0, 'virus-like-particle': 1.0, 'f-beta': 1.0}
'''
pass

## train

In [None]:
def mixup_function(args, x, y):
    """Blend every sample of a batch with a randomly chosen partner sample.

    A random permutation of the batch indices picks the partners and a single
    coefficient is drawn from Beta(args.mix_alpha, args.mix_alpha).

    Returns:
        (mixed_x, (y, y_of_partners), coefficient) where
        mixed_x = coefficient * x + (1 - coefficient) * x[partners].
    """
    partner = torch.randperm(x.shape[0]).to(args.device)
    mix_weight = np.random.beta(args.mix_alpha, args.mix_alpha)

    mixed = mix_weight * x + (1 - mix_weight) * x[partner, :]
    return mixed, (y, y[partner]), mix_weight

In [None]:
class CustomTrainer:
    """Mixed-precision training and periodic validation driver for one model.

    The constructor builds an AdamW optimizer, a cosine schedule with warmup
    and the train/test loss functions from `args`. `run` drives the epoch
    loop, appends messages to `<args.save_dir>/log.txt` and saves the model's
    state dict after every validation pass.
    """

    def __init__(self, args, model):
        """Create scaler, output directory, optimizer, scheduler and losses.

        Reads `args.device`, `args.save_dir`, `args.lr`, `args.wd`,
        `args.total_steps` and `args.warmup_ratio`.
        """
        self.model = model

        self.scaler = torch.amp.GradScaler(args.device)

        # exist_ok replaces the separate exists() check and its race window.
        os.makedirs(args.save_dir, exist_ok = True)

        self.log_path = f'{args.save_dir}/log.txt'

        self.optimizer = torch.optim.AdamW(model.parameters(), lr = args.lr, weight_decay = args.wd)

        total_steps = args.total_steps
        warmup_steps = int(total_steps * args.warmup_ratio)
        print('total_steps: ', total_steps)
        print('warmup_steps: ', warmup_steps)

        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps = warmup_steps,
            num_training_steps = total_steps,
            )

        self.train_loss_fn = CustomLoss(args, is_training = True)
        self.test_loss_fn = CustomLoss(args, is_training = False)

    def run(self, args, train_loader, test_loader):
        """Train for `args.n_epoch` epochs, validating every `args.test_freq` epochs.

        After each validation the per-threshold score dictionaries are logged
        and the model weights are saved under `args.save_dir`, with the epoch,
        losses and best score encoded in the file name.
        """
        for epoch in range(args.n_epoch):
            # Learning rate as it is at the start of the epoch (for logging only).
            lr = self.optimizer.param_groups[0]['lr']
            train_loss = self.train_function(args, train_loader)

            train_log = f'epoch:{epoch + 1}, lr:{lr}, train_loss:{train_loss:.6f}'
            self.log(args, train_log)

            if ((epoch + 1) % args.test_freq) == 0:
                test_loss, score_dicts = self.test_function(args, test_loader)

                # Best aggregate f-beta over all evaluated thresholds.
                test_score = np.max([scores['f-beta'] for scores in score_dicts.values()])

                test_log = f'epoch:{epoch + 1}, lr:{lr}, test_loss:{test_loss:.6f}, test_score:{test_score:.6f}\n'
                self.log(args, test_log)

                score_log = json.dumps(score_dicts, indent = 4)
                self.log(args, score_log)

                save_path = args.save_dir + '/epoch:' + f'{epoch + 1}'.zfill(3) + \
                            f'-train_loss:{round(train_loss, 6)}' + \
                            f'-test_loss:{round(test_loss, 6)}' + \
                            f'-test_score:{round(test_score, 6)}' + '.bin'
                torch.save(self.model.state_dict(), save_path)

    def train_function(self, args, loader):
        """Run one training epoch and return the mean (unscaled) batch loss.

        With probability `args.mix_prob` a batch is passed through
        `mixup_function` and the loss becomes the lambda-weighted sum of the
        losses against both target sets. Gradients are accumulated over
        `args.iters_to_accumulate` batches before each optimizer/scheduler step.
        """
        self.model.train()

        total_loss = 0.0
        for bi, sample in enumerate(tqdm(loader)):
            sample = [x.to(args.device) for x in sample]

            inputs = sample[0]
            targets = sample[1]

            mix = (np.random.rand() < args.mix_prob)

            if mix:
                inputs, targets, _lambda = mixup_function(args, inputs, targets)

            with torch.amp.autocast(args.device):
                outputs = self.model(inputs)

                if mix:
                    loss = self.train_loss_fn(outputs, targets[0]) * _lambda + self.train_loss_fn(outputs, targets[1]) * (1 - _lambda)
                else:
                    loss = self.train_loss_fn(outputs, targets)

            # Scale down so the accumulated gradient matches one large batch.
            loss = loss / args.iters_to_accumulate

            self.scaler.scale(loss).backward()
            # NOTE: when len(loader) is not a multiple of iters_to_accumulate the
            # trailing batches' gradients stay accumulated until the next call.
            if (bi + 1) % args.iters_to_accumulate == 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                self.scheduler.step()

            # Undo the accumulation scaling so the reported loss is per batch.
            total_loss += loss.item() * args.iters_to_accumulate

        return total_loss/len(loader)

    def test_function(self, args, loader):
        """Validate on `loader` and score the stitched full-volume prediction.

        Patch-wise softmax outputs are summed into a volume at the offsets the
        loader yields (third element of each sample) and divided by the number
        of patches covering each voxel. The stitched targets must equal the
        cached segmentation of the loader's experiment. The prediction is then
        scored once per entry of `args.thresholds`, and 16 evenly spaced
        slices are plotted (prediction argmax on top, target below).

        Returns:
            (mean batch loss, {f'f-beta-{threshold:.2f}': score dict}).
        """
        self.model.eval()

        total_loss = 0.0

        preds = torch.zeros([args.volume_size[0], args.n_class, args.volume_size[1], args.volume_size[2]], dtype = torch.float16)
        trues = torch.zeros([args.volume_size[0], args.volume_size[1], args.volume_size[2]], dtype = torch.long)
        counts = torch.zeros([args.volume_size[0], args.volume_size[1], args.volume_size[2]], dtype = torch.long)
        for sample in tqdm(loader):
            sample = [x.to(args.device) for x in sample]

            inputs = sample[0]
            targets = sample[1]
            offsets = sample[2]

            with torch.no_grad():
                with torch.amp.autocast(args.device):
                    outputs = self.model(inputs)
                    loss = self.test_loss_fn(outputs, targets)

            total_loss += loss.item()

            outputs = torch.softmax(outputs, dim = 1).cpu()
            targets = targets.cpu()
            offsets = offsets.cpu()

            for i in range(offsets.shape[0]):
                o1, o2, o3 = offsets[i]

                # The first two axes of outputs[i] are swapped to match the
                # layout of `preds` (axis 1 is the class axis).
                preds[o1:o1 + args.patch_size[0], :, o2:o2 + args.patch_size[1], o3:o3 + args.patch_size[2]] += outputs[i].permute(1, 0, 2, 3)
                trues[o1:o1 + args.patch_size[0], o2:o2 + args.patch_size[1], o3:o3 + args.patch_size[2]] += targets[i]
                counts[o1:o1 + args.patch_size[0], o2:o2 + args.patch_size[1], o3:o3 + args.patch_size[2]] += 1

        # Average over overlapping patches (voxels no patch covered become NaN).
        preds = preds / counts.unsqueeze(1)
        trues = trues / counts

        # Element-wise comparison on arrays; avoids materialising both volumes
        # as nested Python lists.
        experiment = list(set(loader.dataset.df.experiment))[0]
        assert np.array_equal(np.asarray(cache[experiment]['segmentation']), trues.numpy())

        solution = get_solution(args, loader.dataset.df)

        score_dicts = {}
        for threshold in args.thresholds:
            submission = get_submission(args, loader.dataset.df, preds.numpy(), threshold)

            score_dicts[f'f-beta-{threshold:.2f}'] = score_function(solution, submission)

        n_plot = 16
        assert args.volume_size[0] % n_plot == 0

        _, axes = plt.subplots(2, n_plot, figsize = (3 * n_plot, 3 * 2))
        for i in range(n_plot):
            axes[0, i].imshow(preds[i*(args.volume_size[0]//n_plot)].argmax(0), cmap = 'viridis')
            axes[1, i].imshow(trues[i*(args.volume_size[0]//n_plot)], cmap = 'viridis')
        plt.show()

        return total_loss/len(loader), score_dicts

    def log(self, args, message):
        """Print `message` and append it as a line to `<args.save_dir>/log.txt`."""
        print(message)
        with open(f'{args.save_dir}/log.txt', 'a+') as logger:
            logger.write(f'{message}\n')

## run

In [None]:
if __name__ == '__main__':
    args = CustomConfig()

    # Reuse `cache` when it already exists in the kernel. An undefined name is
    # the only failure this probe is meant to handle, so catch NameError
    # instead of swallowing every exception with a bare `except`.
    try:
        cache
    except NameError:
        train, folds, cache = preprocess_function(args)

    # Folds to train; the trailing comment names the experiment noted for each index.
    for i in [
        0, #TS_5_4
        #2, #TS_86_3
        #4, #TS_73_6
        #6, #TS_99_9
        ]:
        seed_function(args)

        train_df, test_df = folds[i]
        print('test : ', list(set(test_df['experiment'])))

        ########################################################################
        #train_df = pd.concat([train_df, test_df], axis = 0)
        #train_df = train_df.reset_index(drop = True)
        ########################################################################

        train_dataset = CustomDataset(args, train_df, cache, is_training = True)
        test_dataset = CustomDataset(args, test_df, cache, is_training = False)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size = args.batch_size,
            num_workers = args.n_worker,
            shuffle = True,
            drop_last = True,
            )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size = args.test_batch_size,
            num_workers = args.n_worker,
            shuffle = False,
            drop_last = False,
            )

        if args.model_version == 1:
            model = CustomModel(args)
        elif args.model_version == 2:
            model = CustomModelV2(args)
        elif args.model_version == 3:
            model = CustomModelV3(args)
        else:
            raise NotImplementedError()

        model = model.to(args.device)

        name = 'weights/experiment[3d]'
        args.save_dir = args.root + name + f'/fold{i + 1}'
        # Optimizer steps per epoch (full batches divided by the accumulation
        # factor) times the number of epochs.
        # NOTE(review): assumes len(train_dataset) == len(train_df) - confirm.
        args.total_steps = int((len(train_df) // (args.batch_size * args.iters_to_accumulate)) * args.n_epoch)

        trainer = CustomTrainer(args, model)
        trainer.run(args, train_loader, test_loader)

## test

In [None]:
class CustomWrapper(nn.Module):
    """Wraps a model so that its forward pass returns softmax over dim 1."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return `softmax(model(x), dim=1)`."""
        logits = self.model(x)
        return torch.softmax(logits, dim = 1)

class CustomEnsemble(nn.Module):
    """Averages the outputs of several checkpointed models.

    One member is built per entry of `args.model_weights`. Each member is the
    architecture selected by `config['model_version']`, loaded from
    `config['path']`, put in eval mode and wrapped in `CustomWrapper` (softmax)
    and `CustomTTA` (test-time augmentation).

    Note: the constructor overwrites `model_name`, `n_block`, `n_channel`,
    `decoder_channels`, `patch_size` and `tta` (plus `encoder_channels` and
    `pretrained_path` for version 3) on the `args` object it is given; the
    models and TTA wrappers themselves receive deep copies.
    """

    def __init__(self, args):
        super(CustomEnsemble, self).__init__()
        self.models = nn.ModuleList()
        for config in args.model_weights:
            args.model_name = config['model_name']
            args.n_block = config['n_block']
            args.n_channel = config['n_channel']
            # One decoder width per block, taken from the front of the list.
            args.decoder_channels = [256, 128, 64, 32, 32][:config['n_block']]
            args.patch_size = config['patch_size']
            args.tta = config['tta']

            if config['model_version'] == 1:
                model = CustomModel(deepcopy(args))
            elif config['model_version'] == 2:
                model = CustomModelV2(deepcopy(args))
            elif config['model_version'] == 3:
                args.encoder_channels = config['encoder_channels']
                args.pretrained_path = None
                model = CustomModelV3(deepcopy(args))
            else:
                raise NotImplementedError()

            model = model.to(args.device)

            # map_location makes the load independent of the device the
            # checkpoint was saved from (e.g. GPU weights on a CPU-only host).
            checkpoint = torch.load(config['path'], map_location = args.device, weights_only = True)
            model.load_state_dict(checkpoint)
            model.eval()

            model = CustomWrapper(model)
            model = CustomTTA(deepcopy(args), model)

            self.models.append(model)

    def forward(self, x):
        """Return the element-wise mean of every member's output for `x`."""
        x = [model(x) for model in self.models]
        x = torch.stack(x, dim = 0)
        x = torch.mean(x, dim = 0)
        return x

class CustomTTA(nn.Module):
    """Test-time augmentation: averages de-augmented predictions of one model.

    Every augmentation named in `args.tta` is applied to the input, the model
    is run, and the matching inverse is applied to the output before it is
    summed. The sum is divided by `len(args.tta)`.

    Input transforms act on dims 2 and 3 of the input, their inverses on dims
    3 and 4 of the output, i.e. the model output has one more leading
    dimension than its input.
    """

    # (name, transform for the input, inverse for the output), listed in the
    # order in which the results are accumulated.
    _AUGMENTATIONS = (
        ('original', lambda t: t, lambda t: t),
        ('horizontal', lambda t: t.flip(3), lambda t: t.flip(4)),
        ('vertical', lambda t: t.flip(2), lambda t: t.flip(3)),
        ('transpose', lambda t: t.permute(0, 1, 3, 2), lambda t: t.permute(0, 1, 2, 4, 3)),
        ('rotate90', lambda t: t.rot90(k = 1, dims = (2, 3)), lambda t: t.rot90(k = -1, dims = (3, 4))),
        ('rotate180', lambda t: t.rot90(k = 2, dims = (2, 3)), lambda t: t.rot90(k = -2, dims = (3, 4))),
        ('rotate270', lambda t: t.rot90(k = 3, dims = (2, 3)), lambda t: t.rot90(k = -3, dims = (3, 4))),
    )

    def __init__(self, args, model):
        super().__init__()
        self.args = args

        self.model = model

    def forward(self, x):
        """Return the averaged prediction over the augmentations in `args.tta`."""
        selected = self.args.tta

        total = 0
        for name, augment, restore in self._AUGMENTATIONS:
            if name in selected:
                total += restore(self.model(augment(x)))

        return total / len(selected)

def get_outputs(args, model, loader):
    """Run `model` over every patch of `loader` and stitch full volumes.

    Each sample supplies (inputs, targets, offsets). Model outputs are summed
    into a float16 volume on `args.device` at the given offsets, targets are
    summed likewise, and both are divided by the number of patches that
    covered each voxel.

    Returns:
        (preds, trues) on the CPU; `preds` has the class axis at position 1.
    """
    model.eval()

    volume_shape = [args.volume_size[0], args.volume_size[1], args.volume_size[2]]
    preds = torch.zeros(
        [volume_shape[0], args.n_class, volume_shape[1], volume_shape[2]],
        dtype = torch.float16,
        device = args.device,
    )
    trues = torch.zeros(volume_shape, dtype = torch.long, device = args.device)
    counts = torch.zeros(volume_shape, dtype = torch.long, device = args.device)

    for sample in tqdm(loader):
        moved = [x.to(args.device) for x in sample]
        inputs, targets, offsets = moved[0], moved[1], moved[2]

        with torch.no_grad(), torch.amp.autocast(args.device):
            outputs = model(inputs)

        for i in range(offsets.shape[0]):
            o1, o2, o3 = offsets[i]
            window = (
                slice(o1, o1 + args.patch_size[0]),
                slice(o2, o2 + args.patch_size[1]),
                slice(o3, o3 + args.patch_size[2]),
            )

            # Swap the first two axes of the patch output to match `preds`.
            preds[window[0], :, window[1], window[2]] += outputs[i].permute(1, 0, 2, 3)
            trues[window] += targets[i]
            counts[window] += 1

    # Average over overlapping patches.
    preds = (preds / counts.unsqueeze(1)).cpu()
    trues = (trues / counts).cpu()
    return preds, trues

def inference_function(args, models, experiment, threshold):
    """Ensemble the predictions of every config for one experiment.

    For each entry of ``args.configs`` the patch size and offsets are written
    onto ``args``, a non-training ``CustomDataset`` is built from them (using
    the module-level ``cache``), and the model stored under the config's
    ``'data_version'`` is run through ``get_outputs``. Each config's prediction
    is weighted by how many entries its ``'model_weights'`` list has.

    Args:
        args: config object; reads ``configs`` and ``n_worker`` and gets
            ``patch_size`` / ``offset`` overwritten per config.
        models: mapping from ``config['data_version']`` to the model to run.
        experiment: experiment name placed in every row of the dataset frame.
        threshold: forwarded unchanged to ``get_submission``.

    Returns:
        ``(solution, submission, preds)`` where ``preds`` is the weighted mean
        prediction tensor.
    """
    total_members = 0
    preds = 0
    for config in args.configs:
        # Downstream code takes the patch geometry from args, not from the config.
        args.patch_size = config['patch_size']
        args.offset = config['offset']

        # One row per patch origin, all belonging to the same experiment.
        df = pd.DataFrame({
            'experiment' : [experiment] * len(args.offset),
            'offset' : args.offset,
            })

        loader = torch.utils.data.DataLoader(
            CustomDataset(args, df, cache, is_training = False),
            batch_size = config['batch_size'],
            num_workers = args.n_worker,
            shuffle = False,
            drop_last = False
            )

        config_preds, _ = get_outputs(args, models[config['data_version']], loader)

        # Weight each config by the number of checkpoints behind its model.
        n_members = len(config['model_weights'])
        total_members += n_members
        preds += n_members * config_preds

    preds = preds / total_members

    # The frame of the last config's dataset is reused for both tables.
    last_df = loader.dataset.df
    solution = get_solution(args, last_df)
    submission = get_submission(args, last_df, preds.numpy(), threshold)
    return solution, submission, preds

# Evaluation cell: declares the ensemble (one config per data_version), builds
# one CustomEnsemble per config, runs inference on a single experiment and
# prints the score for every active threshold.
if __name__ == '__main__':
    args = CustomConfig()

    # Catalogue of test-time-augmentation names. Each checkpoint below selects a
    # subset with fancy indexing, so every 'tta' value is a numpy array of names.
    tta = np.array([
        'original',
        'vertical',
        'transpose',
        'rotate90',
        'rotate270',
    ])

    # One entry per data_version. inference_function copies 'patch_size' and
    # 'offset' onto args, passes 'batch_size' to the DataLoader, looks the model
    # up by 'data_version', and weights the config's prediction by
    # len('model_weights').
    # In every config the largest offset plus the patch size equals 192 on the
    # first axis and 640 on the other two — presumably args.volume_size; confirm.
    # All checkpoint paths point into a fold1 directory.
    args.configs = [
        {
            'data_version' : 1,
            'patch_size' : [64, 256, 256],
            # Patch origins: every combination of the three per-axis start lists.
            'offset' : [[i, j, k] for i in [0, 42, 84, 128] for j in [0, 192, 384] for k in [0, 192, 384]],
            'batch_size' : 1 * 3,
            'model_weights' : [
                {
                    'model_name' : 'resnet34d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 1,
                    'n_channel' : 1,
                    'path' : args.root + 'weights/main/model[resnet34d]-volume_size[64,256,256]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:030-train_loss:0.247588-test_loss:0.226862-test_score:0.775869.bin',
                    'patch_size' : [64, 256, 256],
                    'tta' : tta[[0, 1]],
                },
                {
                    'model_name' : 'tf_efficientnet_b1.in1k',
                    'n_block' : 5,
                    'model_version' : 2,
                    'n_channel' : 3,
                    'path' : args.root + 'weights/main/model[tf_efficientnet_b1]-volume_size[64,3,256,256]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:026-train_loss:0.247768-test_loss:0.218416-test_score:0.775452.bin',
                    'patch_size' : [64, 256, 256],
                    'tta' : tta[[0, 2]],
                },
                {
                    'model_name' : 'r152ir',
                    'n_block' : 4,
                    'model_version' : 3,
                    'n_channel' : 1,
                    # Only the 'r152ir' / 'r50ir' entries carry explicit encoder channels.
                    'encoder_channels' : [256, 512, 1024, 2048],
                    'path' : args.root + 'weights/main/model[r152ir]-volume_size[64,256,256]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:030-train_loss:0.239091-test_loss:0.219818-test_score:0.759626.bin',
                    'patch_size' : [64, 256, 256],
                    'tta' : tta[[0, 3]],
                },
            ]
        },
        {
            'data_version' : 2,
            'patch_size' : [32, 352, 352],
            'offset' : [[i, j, k] for i in [0, 20, 40, 60, 80, 100, 120, 140, 160] for j in [0, 288] for k in [0, 288]],
            'batch_size' : 1 * 3,
            'model_weights' : [
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 1,
                    'n_channel' : 1,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,352,352]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:024-train_loss:0.259471-test_loss:0.215104-test_score:0.772145.bin',
                    'patch_size' : [32, 352, 352],
                    'tta' : tta[[0, 4]],
                },
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 2,
                    'n_channel' : 3,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,3,352,352]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:028-train_loss:0.244331-test_loss:0.200117-test_score:0.726298.bin',
                    'patch_size' : [32, 352, 352],
                    # From here on several checkpoints use a TTA pair without 'original'.
                    'tta' : tta[[1, 2]],
                },
            ]
        },
        {
            'data_version' : 3,
            'patch_size' : [32, 224, 224],
            'offset' : [[i, j, k] for i in [0, 23, 46, 69, 92, 115, 138, 160] for j in [0, 208, 416] for k in [0, 208, 416]],
            'batch_size' : 1 * 4,
            'model_weights' : [
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 1,
                    'n_channel' : 1,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,224,224]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:028-train_loss:0.26338-test_loss:0.188437-test_score:0.759627.bin',
                    'patch_size' : [32, 224, 224],
                    'tta' : tta[[1, 3]],
                },
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 2,
                    'n_channel' : 3,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,3,224,224]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:030-train_loss:0.255669-test_loss:0.183632-test_score:0.762387.bin',
                    'patch_size' : [32, 224, 224],
                    'tta' : tta[[1, 4]],
                },
                {
                    'model_name' : 'r50ir',
                    'n_block' : 4,
                    'model_version' : 3,
                    'n_channel' : 1,
                    'encoder_channels' : [256, 512, 1024, 2048],
                    'path' : args.root + 'weights/main/model[r50ir]-volume_size[32,224,224]-clip[1]-radius_scale[0.5]-mixup[0.50]/fold1/epoch:030-train_loss:0.249857-test_loss:0.184979-test_score:0.756809.bin',
                    'patch_size' : [32, 224, 224],
                    'tta' : tta[[2, 3]],
                },
            ]
        },
        {
            'data_version' : 4,
            'patch_size' : [32, 128, 128],
            'offset' : [[i, j, k] for i in [0, 27, 54, 81, 108, 135, 160] for j in [0, 102, 204, 306, 408, 512] for k in [0, 102, 204, 306, 408, 512]],
            'batch_size' : 1 * 3,
            'model_weights' : [
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 1,
                    'n_channel' : 1,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,128,128]-clip[1]-radius_scale[0.5]-mixup[0.50]-n_epoch[40]/fold1/epoch:034-train_loss:0.256083-test_loss:0.192153-test_score:0.786046.bin',
                    'patch_size' : [32, 128, 128],
                    'tta' : tta[[2, 4]],
                },
                {
                    'model_name' : 'resnet18d.ra2_in1k',
                    'n_block' : 5,
                    'model_version' : 2,
                    'n_channel' : 3,
                    'path' : args.root + 'weights/main/model[resnet18d]-volume_size[32,3,128,128]-clip[1]-radius_scale[0.5]-mixup[0.50]-n_epoch[40]/fold1/epoch:032-train_loss:0.257508-test_loss:0.187082-test_score:0.793733.bin',
                    'patch_size' : [32, 128, 128],
                    'tta' : tta[[3, 4]],
                },
            ]
        },
    ]

    # Thresholds to score; only 0.25 is active, the other candidates are disabled.
    args.thresholds = [
        #0.90,
        #0.75,
        #0.50,
        0.25,
        #0.10,
        ]

    seed_function(args)

    # Single experiment that is predicted and scored below.
    # NOTE(review): presumably part of fold1's held-out split, since every
    # checkpoint above comes from fold1 — confirm.
    experiment = 'TS_5_4'

    # Build one ensemble per config, keyed by data_version.
    models = {}
    for config in args.configs:
        # args.model_weights is overwritten before each construction.
        # NOTE(review): presumably CustomEnsemble reads it inside __init__;
        # if it reads args lazily, every ensemble would see the last config — confirm.
        args.model_weights = config['model_weights']

        model = CustomEnsemble(args)

        models[config['data_version']] = model

    # NOTE: inference_function re-runs every model for each threshold, so each
    # additional threshold costs a full inference pass.
    for threshold in args.thresholds:
        solution, submission, preds = inference_function(args, models, experiment, threshold)

        score = score_function(solution, submission)
        print(f'threshold : {threshold}, score : {score}')