# 2.5D Inference

Model weights and cross validation [rsna-2-5dmodelcheckpoints-cross-validation](https://www.kaggle.com/code/samu2505/rsna-2-5dmodelcheckpoints-cross-validation) is here

Training notebook [rsna2024-training2-5dmodel](https://www.kaggle.com/code/samu2505/rsna2024-training2-5dmodel)

In [1]:
import os, gc, sys, copy, pickle
from pathlib import Path
import glob
from tqdm.auto import tqdm
tqdm.pandas()

import math
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
import multiprocessing as mp

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from torch.utils.data import WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight
from sklearn import model_selection

from transformers import get_cosine_schedule_with_warmup

import timm

import cv2
cv2.setNumThreads(0)
import PIL
import pydicom
from IPython import display
import warnings
warnings.filterwarnings("ignore")
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
def seeding(SEED):
    np.random.seed(SEED)
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print('seeding done!!!')

def flush():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

In [3]:
CONFIG = dict(
    project_name = "RSNA-2024-25DModel",
    artifact_name = "rsnaEffNetModel",
    load_kernel = None,
    load_last = True,
    n_folds = 5,
    backbone = "convnext_nano.in12k_ft_in1k", # convnext_pico.d1_in1k, tf_efficientnet_b0.ns_jft_in1k
    img_size = 224, # 224, 384
    n_slice_per_c = 15,
    in_chans = 3,
    axial_chans = 10,
    axial_labels = 10,
    axial_classes = 3 * 10,
    
    sagT1_chans = 10,
    sagT1_labels = 10,
    sagT1_classes = 3 * 10,
    
    sagT2_chans = 10,
    sagT2_labels = 5,
    sagT2_classes = 3 * 5,
    
    n_classes = 3 * 25,

    drop_rate = 0.,
    drop_rate_last = 0.,
    drop_path_rate = 0.,
    p_mixup = 0.5,
    p_rand_order_v1 = 0.2,
    lr = 1e-3,
    wd = 1e-4,

    epochs = 5,
    batch_size = 1,
    warmup = 1,
    num_cycles = 0.375,
    device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu",
    seed = 42,
    log_wandb = False,
    with_clip = False,
    use_lstm = False,
)

if CONFIG['log_wandb']:
    import wandb
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
    wandb.login(key=secret_value_0)

seeding(CONFIG['seed'])

seeding done!!!


In [4]:
DATA_PATH = Path("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification")
train_main = pd.read_csv(DATA_PATH/"train.csv")
test_desc = pd.read_csv(DATA_PATH/"test_series_descriptions.csv")
sample_df = pd.read_csv(DATA_PATH/"sample_submission.csv")
study_ids = test_desc['study_id'].unique().tolist()

In [5]:
AXIAL_COLS = {col:i for i, col in enumerate(train_main.columns[1:]) if 'subarticular_stenosis' in col}
SAGT1_COLS = {col:i for i, col in enumerate(train_main.columns[1:]) if 'neural_foraminal_narrowing' in col}
SAGT2_COLS = {col:i for i, col in enumerate(train_main.columns[1:]) if 'spinal_canal_stenosis' in col}

In [6]:
def load_dicom(path):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    if dicom.PhotometricInterpretation == "MONOCHROME1":
        data = data - np.min(data)
        
    if np.max(data) != 0:
        data = data / (np.max(data) + 1e-4)
    data = (data * 255).astype(np.uint8)
    return data

# Cropped dataset

In [7]:
seg_model = timm.create_model('resnet18', pretrained=False, num_classes=10)
# path = "/kaggle/input/k/samu2505/lumbar-coordinate-dataset-code/resnet18_2024.pt"
path = "/kaggle/input/lumbar-coordinate-segmentation/pytorch/default/1/resnet18_2024.pt"
weights = torch.load(path, map_location=torch.device("cpu"))
seg_model.load_state_dict(weights)
seg_model.to(CONFIG['device'])
print("seg model weights loaded succefully ...")
# seg_model.eval()

@torch.no_grad()
def crop_single_image(net, image, cfg):
    # Predict segmentation mask for the single image
    size = image.shape[0]
    tensor_image = torch.from_numpy(image).permute(2, 0, 1).float()
    pred = net(torch.unsqueeze(tensor_image, dim=0).to(cfg['device'])).sigmoid().detach().cpu().numpy() * size
    pred = pred.squeeze()
    
    # Extract x and y coordinates from the prediction
    x_ = pred[0::2]
    y_ = pred[1::2]
    
    # Calculate bounding box coordinates
    xmin = np.min(x_).astype(int)
    xmax = np.max(x_).astype(int)
    ymin = np.min(y_).astype(int)
    ymax = np.max(y_).astype(int)
    
    # Ensure bounding box is within image dimensions
    xmin = max(0, xmin - 30)
    xmax = min(size, xmax + 30)
    ymin = max(0, ymin - 20)
    ymax = min(size, ymax + 20)
    
    # Crop the image using the bounding box coordinates
    cropped_img = image[ymin:ymax, xmin:xmax]
    cropped_img = cv2.resize(cropped_img, (cfg['crop_width'], cfg['crop_height']))
    return cropped_img


class CropDataset(Dataset):
    def __init__(self, data, st_ids, cfg, model, mode='train', transform=None):
        self.data = data
        self.mode = mode
        self.transform = transform
        self.cfg = cfg
        self.st_ids = st_ids
        self.view = "sagittal"
        self.seg_model = model
    
    def __len__(self):
        return len(self.st_ids)
    
    def get_img_paths(self, study_id, series_desc):
        pdf = self.data[self.data['study_id'] == study_id]
        pdf_ = pdf[pdf['series_description'] == series_desc]
        allimgs = []
        for i, row in pdf_.iterrows():
            pimgs = glob.glob(f"{str(DATA_PATH)}/test_images/{study_id}/{row['series_id']}/*.dcm")
            pimgs = sorted(pimgs, key=lambda p: int(os.path.basename(p).split('.')[0]))
            allimgs.extend(pimgs)
        return allimgs
    
    def read_dcm(self, src_path):
        img = load_dicom(src_path)
        return img
    
    def get_images(self, nslides, image_paths):
        H, W = self.cfg['img_size'], self.cfg['img_size']
        IMAGES = np.zeros((nslides, H, W, 3), dtype=np.uint8)
        for i in range(nslides):
            try:
                img = self.read_dcm(image_paths[i])
                img = cv2.resize(img, (H,W)).astype(np.uint8)
                img = img[..., None].repeat(3, -1)
                if self.view == "sagittal":
                    img = crop_single_image(self.seg_model, img, cfg=self.cfg)
                img = self.transform(image=img)['image']
                IMAGES[i, ...] = img
            except:
                pass
            
        return IMAGES
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        study_id = self.st_ids[idx]
        
        H, W = self.cfg['img_size'], self.cfg['img_size']
        n_slides_per_c = self.cfg["n_slice_per_c"]
        sagittal = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        coronal = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        axial = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        
        # Sagittal
        allimgs_sag = self.get_img_paths(study_id, 'Sagittal T2/STIR')
        
        if len(allimgs_sag)==0:
            pass
        
        else:
            sagT2_scans = len(allimgs_sag)
            sagT2_indices = np.quantile(list(range(sagT2_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_sag = [allimgs_sag[i] for i in sagT2_indices]
            self.view = "sagittal"
            sagittal = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_sag)
                
        # coronal
        allimgs_cor = self.get_img_paths(study_id, 'Sagittal T1')
    
        if len(allimgs_cor)==0:
            pass
        
        else:
            sagT1_scans = len(allimgs_cor)
            sagT1_indices = np.quantile(list(range(sagT1_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_cor = [allimgs_cor[i] for i in sagT1_indices]
            self.view = "sagittal"
            coronal = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_cor)
                
        # Axial
        allimgs_ax = self.get_img_paths(study_id, 'Axial T2')
        
        if len(allimgs_ax)==0:
            pass
        
        else:
            ax_scans = len(allimgs_ax)
            ax_indices = np.quantile(list(range(ax_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_ax = [allimgs_ax[i] for i in ax_indices]
            self.view = "axial"
            axial = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_ax)
        
        axial = axial.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        coronal = coronal.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        sagittal = sagittal.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        
        return {"axial": axial, "coronal": coronal, "sagittal": sagittal, "study_id": str(study_id)}

seg model weights loaded succefully ...


# Full Dataset

In [8]:
class Spine25DDataset(Dataset):
    def __init__(self, data, st_ids, transform=None):
        self.data = data
        self.st_ids = st_ids
        self.transform = transform
    
    def __len__(self):
        return len(self.st_ids)
    
    def get_img_paths(self, study_id, series_desc):
        pdf = self.data[self.data['study_id'] == study_id]
        pdf_ = pdf[pdf['series_description'] == series_desc]
        allimgs = []
        for i, row in pdf_.iterrows():
            pimgs = glob.glob(f"{str(DATA_PATH)}/test_images/{study_id}/{row['series_id']}/*.dcm")
            pimgs = sorted(pimgs, key=lambda p: int(os.path.basename(p).split('.')[0]))
            allimgs.extend(pimgs)
        return allimgs
    
    def read_dcm(self, src_path):
        img = load_dicom(src_path)
        return img
    
    def get_images(self, nslides, image_paths):
        H, W = CONFIG['img_size'], CONFIG['img_size']
        IMAGES = np.zeros((nslides, H, W, 3), dtype=np.uint8)
        for i in range(nslides):
            try:
                img = self.read_dcm(image_paths[i])
                img = cv2.resize(img, (H,W)).astype(np.uint8)
#                 img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
                img = img[..., None].repeat(3, -1)
                img = self.transform(image=img)['image']
                IMAGES[i, ...] = img
            except:
                pass
            
        return IMAGES
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        study_id = self.st_ids[idx]
        
        H, W = CONFIG['img_size'], CONFIG['img_size']
        n_slides_per_c = CONFIG["n_slice_per_c"]
        sagittal = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        coronal = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        axial = np.zeros((n_slides_per_c, H, W, 3), dtype=np.uint8)
        
        # Sagittal
        allimgs_sag = self.get_img_paths(study_id, 'Sagittal T2/STIR')
        if len(allimgs_sag)==0:
            pass
        
        else:
            sagT2_scans = len(allimgs_sag)
            sagT2_indices = np.quantile(list(range(sagT2_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_sag = [allimgs_sag[i] for i in sagT2_indices]
            sagittal = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_sag)
                
        # coronal
        allimgs_cor = self.get_img_paths(study_id, 'Sagittal T1')
        if len(allimgs_cor)==0:
            pass
        
        else:
            sagT1_scans = len(allimgs_cor)
            sagT1_indices = np.quantile(list(range(sagT1_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_cor = [allimgs_cor[i] for i in sagT1_indices]
            coronal = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_cor)
                
        # Axial
        allimgs_ax = self.get_img_paths(study_id, 'Axial T2')
        if len(allimgs_ax)==0:
            pass
        
        else:
            ax_scans = len(allimgs_ax)
            ax_indices = np.quantile(list(range(ax_scans)), np.linspace(0., 1., n_slides_per_c)).round().astype(int)
            allimgs_ax = [allimgs_ax[i] for i in ax_indices]
            axial = self.get_images(nslides=n_slides_per_c, image_paths=allimgs_ax)
        
        axial = axial.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        coronal = coronal.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        sagittal = sagittal.transpose(0, 3, 1, 2).astype(np.float32) / 255.0 
        
        return {"axial": axial, "coronal": coronal, "sagittal": sagittal, "study_id": str(study_id)}

# Dataloaders

In [9]:
def get_transforms(height, width):
    train_tsfm = A.Compose([
        A.Resize(height=height, width=height, interpolation=cv2.INTER_AREA),
        A.Perspective(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(-25, 25, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),
        
        A.OneOf([
            A.MotionBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
            A.GaussianBlur(blur_limit=3),
            A.GaussNoise(var_limit=(3.0, 9.0)),
        ], p=0.5),
        
        A.OneOf([
            A.OpticalDistortion(distort_limit=1.),
            A.GridDistortion(num_steps=5, distort_limit=1.),
        ], p=0.5),
        
        A.CoarseDropout(max_holes=2, max_height=int(height * 0.25), max_width=int(width * 0.25), p=0.3),
    ])
    
    valid_tsfm = A.Compose([
        A.Resize(height=height, width=width, interpolation=cv2.INTER_AREA),
    ])
    return {"train": train_tsfm, "eval": valid_tsfm}

def get_dataloaders(data, ids, cfg, split="test"):
    img_size = cfg['img_size']
    height, width = img_size, img_size
    tsfm = get_transforms(height=height, width=width)
    if split == 'train':
        tr_tsfm = tsfm['train']
        ds = Spine25DDataset(data=data, st_ids=ids, transform=tr_tsfm)
        dls = DataLoader(ds, 
                         batch_size=cfg['batch_size'], 
                         shuffle=True,
                         num_workers=os.cpu_count(), 
                         drop_last=True, 
                         pin_memory=True)
        
    elif split == 'valid' or split == 'test':
        eval_tsfm = tsfm['eval']
        ds = Spine25DDataset(data=data, st_ids=ids, transform=eval_tsfm)
        dls = DataLoader(ds, 
                         batch_size=cfg['batch_size'], 
                         shuffle=False, 
                         num_workers=os.cpu_count(), 
                         drop_last=False, 
                         pin_memory=True)
    else:
        raise Exception("Split should be 'train' or 'valid' or 'test'!!!")
    return dls


def get_crop_dataloaders(data, ids, cfg, model, split="test"):
    img_size = cfg['img_size']
    height, width = img_size, img_size
    tsfm = get_transforms(height=height, width=width)
    study_ids = data.study_id.unique().tolist()
    if split == 'train':
        tr_tsfm = tsfm['train']
        ds = CropDataset(data=data, st_ids=ids, cfg=cfg, model=model, mode='train', transform=tr_tsfm)
        dls = DataLoader(ds, 
                         batch_size=cfg['batch_size'], 
                         shuffle=True,
                         num_workers=os.cpu_count(), 
                         drop_last=True, 
                         pin_memory=True)
        
    elif split == 'valid' or split == 'test':
        eval_tsfm = tsfm['eval']
        ds = CropDataset(data=data, st_ids=ids, cfg=cfg, model=model, mode='valid', transform=eval_tsfm)
        dls = DataLoader(ds, 
                         batch_size=2*cfg['batch_size'], 
                         shuffle=False, 
                         num_workers=os.cpu_count(),
                         drop_last=False, 
                         pin_memory=True,
                         persistent_workers=True,
                        )
    else:
        raise Exception("Split should be 'train' or 'valid' or 'test'!!!")
    return dls

# Models

In [10]:
class BaseModel(nn.Module):
    def __init__(self, backbone, in_chans=3, pretrained=False, increase_stride=False):
        super(BaseModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=in_chans,
            num_classes=0,
            features_only=False,
            drop_rate=CONFIG["drop_rate"],
            drop_path_rate=CONFIG["drop_path_rate"],
            pretrained=pretrained
        )
        self.encoder.name = backbone
        self.nb_fts = self.encoder.num_features
        self.gap = nn.AdaptiveAvgPool2d(1)
#         self.gap = GeM(p_trainable=False)
        
        if increase_stride:
            self.increase_stride()
        
    def increase_stride(self):
        """
        Increase the stride of the first layer of the encoder
        """
        if "efficientnet" in self.encoder.name:
            self.encoder.conv_stem.stride = (4, 4)
        elif "nfnet" in self.encoder.name:
            self.encoder.stem.conv1.stride = (4, 4)
        else:
            raise NotImplementedError
            
    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.gap(x)[:,:,0,0]
        return x
    
class CroppedClf(nn.Module):
    def __init__(self, backbone, pretrained=False, increase_stride=False):
        super(CroppedClf, self).__init__()
        self.axial_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                       pretrained=pretrained, increase_stride=increase_stride)
        self.coronal_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                         pretrained=pretrained, increase_stride=increase_stride)
        self.sagittal_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                          pretrained=pretrained, increase_stride=increase_stride)
        
        self.in_chans = 3
        self.nb_fts = 3*self.axial_encoder.nb_fts
        self.lstm = nn.LSTM(self.nb_fts, 256, num_layers=2, dropout=CONFIG["drop_rate"], bidirectional=True, batch_first=True)
        self.axial_head = self.get_head(CONFIG['axial_classes'])
        self.coronal_head = self.get_head(CONFIG['sagT1_classes'])
        self.sagittal_head = self.get_head(CONFIG['sagT2_classes'])
    
    
    def get_head(self, n_classes):
        head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            nn.LeakyReLU(0.1),
            nn.Linear(256, n_classes),
        )
        return head
    
    def extract_features(self, x, view='axial'):
        bs = x.shape[0]
        x = x.view(bs * CONFIG["n_slice_per_c"], self.in_chans, CONFIG["img_size"], CONFIG["img_size"])
        if view == 'axial':
            feat = self.axial_encoder(x)
        elif view == 'coronal':
            feat = self.coronal_encoder(x)
        elif view == 'sagittal':
            feat = self.sagittal_encoder(x)
        else:
            raise NotImplementedError
            
        feat = feat.view(bs, CONFIG["n_slice_per_c"], -1)
        return feat
        
        
    def forward(self, ax, cor, sag):
        bs = ax.shape[0]
        ax_fts = self.extract_features(ax, view='axial')
        cor_fts = self.extract_features(cor, view='coronal')
        sag_fts = self.extract_features(sag, view='sagittal')
        fts = torch.concatenate([ax_fts, cor_fts, sag_fts], dim=-1)
        fts, _ = self.lstm(fts)
        fts = fts.contiguous().view(bs * CONFIG["n_slice_per_c"], -1)
        ax_fts = self.axial_head(fts)
        y_ax = ax_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["axial_classes"]).contiguous()
        cor_fts = self.coronal_head(fts)
        y_cor = cor_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["sagT1_classes"]).contiguous()
        sag_fts = self.sagittal_head(fts)
        y_sag = sag_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["sagT2_classes"]).contiguous()
        return y_ax, y_cor, y_sag
    
    
class Clf(nn.Module):
    def __init__(self, backbone, pretrained=False, increase_stride=False):
        super(Clf, self).__init__()
        self.axial_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                       pretrained=pretrained, increase_stride=increase_stride)
        self.coronal_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                         pretrained=pretrained, increase_stride=increase_stride)
        self.sagittal_encoder = BaseModel(backbone=backbone, in_chans=3, 
                                          pretrained=pretrained, increase_stride=increase_stride)
        
        self.in_chans = 3
        self.nb_fts = 3*self.axial_encoder.nb_fts
        self.lstm = nn.LSTM(self.nb_fts, 256, num_layers=2, dropout=CONFIG["drop_rate"], bidirectional=True, batch_first=True)
        self.axial_head = self.get_head(CONFIG['axial_classes'])
        self.coronal_head = self.get_head(CONFIG['sagT1_classes'])
        self.sagittal_head = self.get_head(CONFIG['sagT2_classes'])
    
    
    def get_head(self, n_classes):
        head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            nn.LeakyReLU(0.1),
            nn.Linear(256, n_classes),
        )
        return head
    
    def extract_features(self, x, view='axial'):
        bs = x.shape[0]
        x = x.view(bs * CONFIG["n_slice_per_c"], self.in_chans, CONFIG["img_size"], CONFIG["img_size"])
        if view == 'axial':
            feat = self.axial_encoder(x)
        elif view == 'coronal':
            feat = self.coronal_encoder(x)
        elif view == 'sagittal':
            feat = self.sagittal_encoder(x)
        else:
            raise NotImplementedError
            
        feat = feat.view(bs, CONFIG["n_slice_per_c"], -1)
        return feat
        
        
    def forward(self, ax, cor, sag):
        bs = ax.shape[0]
        ax_fts = self.extract_features(ax, view='axial')
        cor_fts = self.extract_features(cor, view='coronal')
        sag_fts = self.extract_features(sag, view='sagittal')
        fts = torch.concatenate([ax_fts, cor_fts, sag_fts], dim=-1)
        fts, _ = self.lstm(fts)
        fts = fts.contiguous().view(bs * CONFIG["n_slice_per_c"], -1)
        ax_fts = self.axial_head(fts)
        y_ax = ax_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["axial_classes"]).contiguous()
        cor_fts = self.coronal_head(fts)
        y_cor = cor_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["sagT1_classes"]).contiguous()
        sag_fts = self.sagittal_head(fts)
        y_sag = sag_fts.view(bs, CONFIG["n_slice_per_c"], CONFIG["sagT2_classes"]).contiguous()
        return y_ax, y_cor, y_sag

In [11]:
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

dls = get_dataloaders(test_desc, study_ids, CONFIG, split='test')
cropped_dls = get_crop_dataloaders(test_desc, study_ids, cfg=CONFIG, model=seg_model, split='test')

In [12]:
def inference(model, dataloader):
    model.to(CONFIG["device"])
    model.eval()
    y_preds = []
    row_names = []

    axial_indices = list(AXIAL_COLS.values())
    coronal_indices = list(SAGT1_COLS.values())
    sagittal_indices = list(SAGT2_COLS.values())
    
    pbar = tqdm(dls, leave=True)
    
    with torch.no_grad():
        for idx, batch in enumerate(pbar):
            axial = batch['axial'].to(CONFIG["device"], non_blocking=True)
            coronal = batch['coronal'].to(CONFIG["device"], non_blocking=True)
            sagittal = batch['sagittal'].to(CONFIG["device"], non_blocking=True)
            si = batch['study_id']
            pred_per_study = np.ones((25, 3)) * (1/3)
            for cond in CONDITIONS:
                for level in LEVELS:
                    row_names.append(si[0] + '_' + cond + '_' + level)
                
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                y_axial, y_coronal, y_sagittal = model(axial, coronal, sagittal)
                
                y_axial_mean = y_axial.squeeze().mean(dim=0) 
                y_coronal_mean = y_coronal.squeeze().mean(dim=0) 
                y_sagittal_mean = y_sagittal.squeeze().mean(dim=0)
                
                y_axial_max = y_axial.squeeze().amax(dim=0) 
                y_coronal_max = y_coronal.squeeze().amax(dim=0) 
                y_sagittal_max = y_sagittal.squeeze().amax(dim=0)
                
                a = 0.5; b = 0.5
                y_axial = a * y_axial_max + b * y_axial_mean
                y_coronal = a * y_coronal_max + b * y_coronal_mean
                y_sagittal = a * y_sagittal_max + b * y_sagittal_mean
                
                # axial
                for col in range(CONFIG['axial_labels']):
                    pred = y_axial[col*3:col*3+3]
                    y_pred = pred.float().softmax(dim=-1).cpu().numpy()
                    pred_per_study[axial_indices[col]] = y_pred
                    
                # coronal
                for col in range(CONFIG['sagT1_labels']):
                    pred = y_coronal[col*3:col*3+3]
                    y_pred = pred.float().softmax(dim=-1).cpu().numpy()
                    pred_per_study[coronal_indices[col]] = y_pred
                    
                # sagittal
                for col in range(CONFIG['sagT2_labels']):
                    pred = y_sagittal[col*3:col*3+3]
                    y_pred = pred.float().softmax(dim=-1).cpu().numpy()
                    pred_per_study[sagittal_indices[col]] = y_pred
            y_preds.append(pred_per_study)
            
    y_preds = np.concatenate(y_preds, axis=0)
    return y_preds, row_names

# Load model weights

In [13]:
CKPTS = {
    "cropped_convnext_pico": {
        "backbone": "convnext_pico.d1_in1k",
        "path": sorted([
            f"/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/cropped_convnext_pico_{i}.pth" for i in [2, 3]
        ]),
    },
    "full_convnext_nano": {
        "backbone": "convnext_nano.in12k_ft_in1k",
        "path": sorted([
            f"/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_{i}.pth" for i in [0, 3, 4]
        ])
    },
    "full_convnext_pico": {
        "backbone": "convnext_pico.d1_in1k",
        "path": sorted(glob.glob("/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_pico_*.pth")),
    },
    "full_efficientnet": {
        "backbone": "tf_efficientnet_b0.ns_jft_in1k",
        "path": sorted(glob.glob("/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_tf_efficientnetb0_*.pth")),
    },
}

CKPTS

{'cropped_convnext_pico': {'backbone': 'convnext_pico.d1_in1k',
  'path': ['/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/cropped_convnext_pico_2.pth',
   '/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/cropped_convnext_pico_3.pth']},
 'full_convnext_nano': {'backbone': 'convnext_nano.in12k_ft_in1k',
  'path': ['/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_0.pth',
   '/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_3.pth',
   '/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_4.pth']},
 'full_convnext_pico': {'backbone': 'convnext_pico.d1_in1k',
  'path': ['/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_pico_0.pth']},
 'full_efficientnet': {'backbone': 'tf_efficientnet_b0.ns_jft_in1k',
  'path': ['/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_tf_efficientnetb0_2.pth',
   '/kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_tf_effici

In [14]:
CONVNEXT_MODELS = []
EFFICIENT_MODELS = []
CROPPED_MODELS = []
for model_name in CKPTS.keys():
    backbone = CKPTS[model_name]['backbone']
    weights_path = CKPTS[model_name]['path']
    for path in weights_path:
        if "cropped" in model_name:
            model = CroppedClf(backbone=backbone)
        else:
            model = Clf(backbone=backbone)
        weights = torch.load(path, map_location=torch.device("cpu"))
        model.load_state_dict(weights)
        print(f"\n{model_name} weights {path} loaded successfully ...")
        if 'full_convnext' in model_name:
            CONVNEXT_MODELS.append(model)
        elif "full_eff" in model_name:
            EFFICIENT_MODELS.append(model)
        elif "cropped_convnext" in model_name:
            CROPPED_MODELS.append(model)
        else:
            print("Not implemented")
            continue
            
gc.collect()


cropped_convnext_pico weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/cropped_convnext_pico_2.pth loaded successfully ...

cropped_convnext_pico weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/cropped_convnext_pico_3.pth loaded successfully ...

full_convnext_nano weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_0.pth loaded successfully ...

full_convnext_nano weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_3.pth loaded successfully ...

full_convnext_nano weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_nano_4.pth loaded successfully ...

full_convnext_pico weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_convnext_pico_0.pth loaded successfully ...

full_efficientnet weights /kaggle/input/rsna-2-5dmodelcheckpoints-cross-validation/full_tf_efficientnetb0_2.pth loaded successfully ...

full_efficientnet weights /kaggle/input/rsna-2

45

In [15]:
convnext_outputs = [inference(net, dls) for net in CONVNEXT_MODELS]
gc.collect()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/pytho

734

In [16]:
eff_outputs = [inference(net, dls) for net in EFFICIENT_MODELS]
gc.collect()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times
Exception in thread Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "

546

    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 271, in _feed
    queue_sem.release()
ValueError: semaphore or lock released too many times


In [17]:
crop_convnext_outputs = [inference(net, cropped_dls) for net in CROPPED_MODELS]
gc.collect()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 240, in _feed
    reader_close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    reader_close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another except

376

In [18]:
# pred_convnext = convnext_outputs[0][0]
# pred_efficient = eff_outputs[0][0]
# pred_crop_convnext = crop_convnext_outputs[0][0]
row_ids = eff_outputs[0][1]

convnext_preds = np.array([convnext_outputs[i][0] for i in range(len(convnext_outputs))]).mean(axis=0)
eff_preds = np.array([eff_outputs[i][0] for i in range(len(eff_outputs))]).mean(axis=0)
crop_preds = np.array([crop_convnext_outputs[i][0] for i in range(len(crop_convnext_outputs))]).mean(axis=0)

In [19]:
preds = 0.15*eff_preds + 0.15*crop_preds + 0.7*convnext_preds

In [20]:
TARGET_COLS = sample_df.columns.tolist()
df = pd.DataFrame()
df['row_id'] = row_ids
df[['normal', 'mild', 'severe']] = preds
df.columns = TARGET_COLS
# df = df.sort_values("row_id").reset_index(drop=True)
df.to_csv('submission.csv', index=False, float_format='%.7f')

In [21]:
pd.read_csv("submission.csv")

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.223355,0.341723,0.434923
1,44036939_spinal_canal_stenosis_l2_l3,0.148317,0.368259,0.483424
2,44036939_spinal_canal_stenosis_l3_l4,0.14884,0.347878,0.503283
3,44036939_spinal_canal_stenosis_l4_l5,0.257308,0.183269,0.559423
4,44036939_spinal_canal_stenosis_l5_s1,0.746215,0.14534,0.108445
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.438979,0.536824,0.024198
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.230093,0.628808,0.141099
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.160885,0.477462,0.361653
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.081625,0.352821,0.565554
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.076223,0.281847,0.64193


In [22]:
# # eff_outputs[0][0]
# pd.DataFrame(preds, columns=['normal', 'mild', 'severe'])