## iaaa-data
Output dir: iaaa-nii.
Reads each DICOM series, saves it as a NIfTI (.nii) file, and appends the DICOM tags of each series to train.csv.

In [None]:
from pathlib import Path


# Root of the Kaggle competition input for this notebook.
ROOT_DATA_DIR = Path('/kaggle/input/iaaa-mri-challenge').expanduser().absolute()
# prepare_data() reads each series from DATA_DIR / <SeriesInstanceUID>.
DATA_DIR = ROOT_DATA_DIR / 'data'
# CSV with one row per series; must contain a 'SeriesInstanceUID' column.
LABELS_PATH = ROOT_DATA_DIR / 'train.csv'
# Everything this notebook produces is written below this directory.
PREPARED_DATA_DIR = Path('/kaggle/working/').expanduser().absolute()

In [None]:
from typing import Optional

import numpy as np
import pydicom
import SimpleITK as sitk
from PIL import Image
from torchvision.transforms import v2
import torch

def read_dicom_series(study_path: Path, series_instance_uid: Optional[str] = None) -> dict:
    """Read one DICOM series from a directory.

    Args:
        study_path: directory that contains the DICOM files.
        series_instance_uid: id of the series to read. When ``None`` the first
            series id reported by GDCM for ``study_path`` is used.

    Returns:
        A dict with three entries:

        - ``'array'``: the series as a ``SimpleITK.Image`` read with pixel type
          ``sitkInt32`` (not a numpy array).
        - ``'headers'``: one pydicom dataset per file, read without pixel data,
          sorted by slice number when every header carries one.
        - ``'dcm_paths'``: the file names, in the same order as ``'headers'``.

    Raises:
        IndexError: if no series id is given and none is found in ``study_path``.

    Notes:
        - No windowing is applied here; the caller has to take care of it.
        - ``'array'`` keeps the file order returned by GDCM, while ``'headers'``
          and ``'dcm_paths'`` are re-sorted by slice number, so the two orders
          are not guaranteed to match.  NOTE(review): confirm for your data
          before pairing ``headers[i]`` with slice ``i`` of the volume.
    """

    if series_instance_uid is None:
        series_ids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(str(study_path))
        if not series_ids:
            # Same exception type as the bare ``[0]`` lookup, but with a usable message.
            raise IndexError(f'no DICOM series found in {study_path}')
        series_id = series_ids[0]
    else:
        series_id = series_instance_uid

    series_file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(study_path), series_id)

    # Headers only: the pixel data is read once, by SimpleITK, below.
    headers = [pydicom.dcmread(str(fn), stop_before_pixels=True) for fn in series_file_names]

    volume = sitk.ReadImage(series_file_names, sitk.sitkInt32)

    # Pick the first slice-number tag that every header provides.
    # 'ImageNumber' is the name used in earlier versions of DICOM.
    slice_number_tag = None
    for candidate in ('InstanceNumber', 'ImageNumber'):
        if all(header.get(candidate) is not None for header in headers):
            slice_number_tag = candidate
            break

    if slice_number_tag is not None:
        # Sort indices once so headers and file names stay paired.
        order = sorted(
            range(len(headers)),
            key=lambda idx: int(headers[idx].get(slice_number_tag)),
        )
        sorted_headers = [headers[idx] for idx in order]
        sorted_file_names = [series_file_names[idx] for idx in order]
    else:
        sorted_headers = headers
        sorted_file_names = series_file_names

    return {
        'array': volume,
        'headers': sorted_headers,
        'dcm_paths': sorted_file_names,
    }


def apply_windowing(series: np.ndarray,
                    window_center: int,
                    window_width: int) -> np.ndarray:
    """Clip ``series`` to an intensity window and rescale it to the 0..1 range.

    Args:
        series: numpy array of shape (n_slices, h, w) or (h, w) holding
            intensity values.
        window_center: centre of the window (e.g. 40 for a brain window).
        window_width: full width of the window (e.g. 80 for a brain window).

    Returns:
        numpy array with the same shape as ``series``; values at or below the
        lower window bound become 0 and values at or above the upper bound
        become 1.

    """

    half_width = window_width / 2
    # Both bounds are truncated to integers before use.
    lower = int(window_center - half_width)
    upper = int(window_center + half_width)

    span = upper - lower
    return (np.clip(series, lower, upper) - lower) / span


def apply_windowing_using_header(arr: np.ndarray, header: pydicom.FileDataset) -> np.ndarray:
    """Window a single slice with the window stored in its DICOM header.

    Args:
        arr: numpy array of shape (h, w) with intensity values.
        header: dicom header containing ``WindowCenter`` and ``WindowWidth``.

    Returns:
        numpy array of shape (h, w) in range(0, 1)
    """

    # Look both tags up and hand them straight to apply_windowing.
    center, width = (header.get(tag) for tag in ('WindowCenter', 'WindowWidth'))
    return apply_windowing(arr, center, width)


def apply_windowing_using_header_on_series(series: np.ndarray, headers: list[pydicom.FileDataset]) -> np.ndarray:
    """Window every slice of a series with the window from its own header.

    Args:
        series: numpy array of shape (num_slices, h, w) with intensity values.
        headers: one dicom header per slice, each containing ``WindowCenter``
            and ``WindowWidth``; ``headers[i]`` is applied to ``series[i]``.

    Returns:
        numpy array of shape (len(headers), h, w) in range(0, 1)
    """

    return np.array([
        apply_windowing(series[idx], hdr.get('WindowCenter'), hdr.get('WindowWidth'))
        for idx, hdr in enumerate(headers)
    ])

In [None]:
from tqdm import tqdm
import pandas as pd


def prepare_data(df: pd.DataFrame, split: str):
    """Convert every series listed in ``df`` and write an enriched ``train.csv``.

    For each row the DICOM series ``DATA_DIR / <SeriesInstanceUID>`` is read and:

    - the volume is written to ``PREPARED_DATA_DIR / split / <uid>.nii``;
    - the per-slice ``[WindowCenter, WindowWidth]`` pairs are written with
      ``np.save`` to ``PREPARED_DATA_DIR / 'wcw' / <uid>.nii`` (the content is
      numpy ``.npy`` data although the suffix is ``.nii``);
    - the keyword/value pairs of the first header of the series are collected.

    Finally the collected header values are appended as extra columns to ``df``
    and the result is saved to ``PREPARED_DATA_DIR / 'train.csv'``.

    Args:
        df: table with a ``SeriesInstanceUID`` column, one row per series.
        split: name of the sub-directory that receives the ``.nii`` volumes.
    """
    prepared_data_dir_for_split = PREPARED_DATA_DIR / split
    prepared_data_dir_for_split.mkdir(parents=True, exist_ok=True)

    prepared_data_dir_for_wcw = PREPARED_DATA_DIR / 'wcw'
    prepared_data_dir_for_wcw.mkdir(parents=True, exist_ok=True)

    rows_list = []
    # total= lets tqdm show real progress; iterrows() has no length of its own.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        siuid = row['SeriesInstanceUID']
        series = read_dicom_series(DATA_DIR / siuid)

        prepared_path = prepared_data_dir_for_split / f'{siuid}.nii'
        # str() keeps this working with SimpleITK builds that only accept string paths.
        sitk.WriteImage(series['array'], str(prepared_path))

        window_centers_widths = [
            [header.get('WindowCenter'), header.get('WindowWidth')]
            for header in series['headers']
        ]
        prepared_path_wcw = prepared_data_dir_for_wcw / f'{siuid}.nii'
        with open(prepared_path_wcw, 'wb') as f:
            np.save(f, np.array(window_centers_widths))

        # Tags of the first header only; a repeated keyword keeps its last value.
        rows_list.append({elem.keyword: elem.value for elem in series['headers'][0]})

    # Reuse df's index so the header columns line up row by row even when df
    # does not carry a default 0..n-1 index (concat aligns on index labels).
    header_df = pd.DataFrame(rows_list, index=df.index)
    df = pd.concat([df, header_df], axis=1)
    df.to_csv(PREPARED_DATA_DIR / "train.csv", index=False)


annotations = pd.read_csv(LABELS_PATH)
prepare_data(annotations, 'data')

## segnpy
Output dir: seg_npy_data. Reads the segmentation label volumes, crops or resizes them to 256×256, and saves them as .npy files.

In [None]:
import numpy as np
from pathlib import Path
import SimpleITK as sitk
import torch
from torchvision.transforms import v2


# Convert every segmentation volume in seg_dir to a float32 .npy file whose
# slices are 256 pixels on the shorter side.
from torchvision.transforms import InterpolationMode

seg_dir = Path('/kaggle/input/iaaa-seg/seg')
npy_dir = Path('seg')
npy_dir.mkdir(exist_ok=True)

crop_labels = v2.CenterCrop(size=256)
# Label maps hold class ids, so they must be resampled with nearest-neighbour
# interpolation: bilinear/antialiased resizing blends neighbouring ids into
# values that belong to no class (e.g. 0 and 5 -> 2.5).
resize_labels = v2.Resize(size=256, interpolation=InterpolationMode.NEAREST)

spaths = [i for i in seg_dir.iterdir()]
for i, spath in enumerate(spaths):
    if i % 50 == 0:
        print(i)

    npy_seg = np.array(sitk.GetArrayFromImage(sitk.ReadImage(spath)), dtype=np.float32)

    seg = torch.from_numpy(npy_seg)
    if seg.shape[1] == 288:
        # 288-pixel volumes are centre-cropped instead of rescaled.
        seg = crop_labels(seg)
    else:
        seg = resize_labels(seg)

    # name[:-4] drops a four-character suffix — presumably '.nii'; verify against the input files.
    prepared_path = npy_dir / f'{spath.name[:-4]}.npy'
    with open(prepared_path, 'wb') as f:
        np.save(f, np.float32(seg.numpy()))

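## iaaa-data-prep
Output dir: npy_256_fixed. Creates npy_data.npy with shape (1035, 16, 4, 256, 256), i.e. (patient, slice, channel, height, width). Channels 0–2 hold the T1, FLAIR and T2 series; channel 3 (the 4th element of dim 2) holds the segmentation labels.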

In [None]:
from pathlib import Path

# Input dataset for this notebook: the .nii volumes are read from DATA_DIR,
# the per-slice window files from WCW_DIR and the labels from LABELS_PATH.
ROOT_DATA_DIR = Path('/kaggle/input/iaaa-nii').expanduser().absolute()
DATA_DIR = ROOT_DATA_DIR / 'data'
# WCW_DIR and SEG_DIR are plain strings with a trailing slash because the
# code below builds file paths from them by string concatenation.
WCW_DIR = '/kaggle/input/iaaa-nii/wcw/'
SEG_DIR = '/kaggle/input/seg-npy-data/seg/'
LABELS_PATH = ROOT_DATA_DIR / 'train.csv'
# npy_data.npy is written below this directory.
PREPARED_DATA_DIR = Path('/kaggle/working/').expanduser().absolute()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from tqdm import tqdm
import os
from torchvision.transforms import v2
import torch

# One row per series in the CSV -> one row per patient in df.
annotations = pd.read_csv(LABELS_PATH)

# Pivot to one column per (value, SeriesDescription) pair.
df = annotations.pivot(index='PatientID', columns='SeriesDescription', values=['prediction', 'SeriesInstanceUID','Rows'])
# Patient-level label: 1 if any of the patient's series has a truthy prediction.
# Inserted at position 3, i.e. after the prediction columns (assumes three
# series descriptions per patient — matches the 11 names assigned below).
df.insert(3, 'label', df.prediction.any(axis=1).astype(np.int32))
df = df.reset_index()
# Flat names, relied on by position later: 0 PatientID, 1-3 predictions,
# 4 label, 5-7 series UIDs, 8-10 row counts.
# NOTE(review): the T1/FL/T2 order assumes the SeriesDescription values sort in that order — confirm.
df.columns = ['PatientID','T1_pred', 'FL_pred', 'T2_pred', 'label', 'T1_siuid','FL_siuid','T2_siuid','T1_rows','FL_rows','T2_rows']
df.head(2)

In [None]:
# Patients to leave out of the dataset (presumably sagittal acquisitions,
# going by the name — confirm). The index is renumbered 0..n-1 afterwards.
sag_list = [146624, 378524, 591219, 331150, 595574, 681963, 684731, 675668, 679925] #331150
keep_mask = ~df.PatientID.isin(sag_list)
df = df.loc[keep_mask].reset_index(drop=True)

In [None]:
import numpy as np
# import SimpleITK as sitk
# All-zero float32 volume of shape (16, 256, 256), written to 'zeros.npy' in
# the current working directory.
data = np.zeros(shape=(16, 256, 256), dtype='float32')
with open('zeros.npy', mode='wb') as handle:
    np.save(handle, data)

In [None]:
# For every patient pick the segmentation file to use and record which label
# values (0..6) occur in it.
df['seg'] = ''
# A set gives O(1) membership tests for the look-ups below.
seg_fnames = {f.name for f in Path(SEG_DIR).iterdir() if f.is_file()}

# A segmentation is looked up under the FLAIR series first, then T2, then T1.
# (The original T1 branch tested the T1 file name but then loaded the T2 path,
# which is known not to exist at that point.)
siuid_priority = ['FL_siuid', 'T2_siuid', 'T1_siuid']
zeros_path = '/kaggle/working/zeros.npy'

seg_paths = []
seg_labels = []
for _, row in df.iterrows():
    seg_path = None
    for siuid_col in siuid_priority:
        fname = row[siuid_col] + ".npy"
        if fname in seg_fnames:
            seg_path = SEG_DIR + fname
            break

    tmp = np.zeros(7)
    if seg_path is None:
        # No segmentation for this patient: point at the all-zero placeholder
        # volume and leave every presence flag at 0.
        seg_path = zeros_path
    else:
        seg_npy = np.load(seg_path)
        labels = np.unique(seg_npy.reshape(-1)).astype(np.int32)
        tmp[labels] = 1

    seg_paths.append(seg_path)
    seg_labels.append(tmp)

df['seg'] = seg_paths

# Presence flags L0..L6; sharing df's index keeps concat row-aligned.
df_seg_label = pd.DataFrame(np.array(seg_labels), index=df.index)
df_seg_label.columns = [f'L{i}' for i in range(7)]

df = pd.concat([df,df_seg_label], axis=1)

In [None]:
def read_npy_file(item):
    """Load the ``.npy`` file at ``item`` and return its contents as float32."""
    return np.load(item).astype(np.float32)

def apply_windowing(series: np.ndarray,
                    window_center: int,
                    window_width: int) -> np.ndarray:
    """Clip ``series`` to the given intensity window and rescale it to 0..1.

    The window bounds ``window_center -/+ window_width / 2`` are truncated to
    integers before they are applied.
    """
    half_width = window_width / 2
    lower = int(window_center - half_width)
    upper = int(window_center + half_width)

    span = upper - lower
    return (np.clip(series, lower, upper) - lower) / span

In [None]:
# 256,288,256 -> crop center 288 to 256

# 240,288,288 -> resize 240 to 288 -> crop center all to 256
# 256,288,288 -> resize 256 to 288 -> crop center all to 256
# 320,288,288 -> resize 320 to 288 -> crop center all to 256
# 288,288,288 -> crop to 256

# 256,256,384 -> just resize 384 to 256
# 256,256,288 -> just resize 288 to 256

In [None]:
# Build one array holding, per patient, `nslices` slices of four 256x256
# channels: windowed T1, FLAIR and T2 plus the segmentation labels.
prepared_data_dir_for_split = PREPARED_DATA_DIR / 'data'
prepared_data_dir_for_split.mkdir(parents=True, exist_ok=True)

prepared_seg_dir_for_split = PREPARED_DATA_DIR / 'seg'
prepared_seg_dir_for_split.mkdir(parents=True, exist_ok=True)

nslices = 16


def _load_windowed_series(siuid):
    """Read ``<siuid>.nii`` and window each slice with its stored centre/width."""
    image = sitk.ReadImage(str(DATA_DIR / f'{siuid}.nii'), sitk.sitkInt32)
    volume = sitk.GetArrayFromImage(image)
    # Rows of (window_center, window_width), one per slice; stored with np.save
    # under a '.nii' name by the preparation notebook.
    windows = read_npy_file(WCW_DIR + siuid + '.nii')
    return np.array([
        apply_windowing(volume[j], windows[j, 0], windows[j, 1])
        for j in range(len(windows))
    ])


def _select_slices(volume, nslices):
    """Return slices of ``volume`` picked along dim 0 to give a fixed depth.

    An 8-slice volume has every slice repeated twice (16 slices, which equals
    ``nslices`` only for ``nslices == 16``). Any other depth is centre-cropped
    to ``nslices``; when it is shallower, the clamp repeats the edge slices.
    """
    depth = volume.shape[0]
    if depth != 8:
        sd = depth - nslices
        selected_indices = torch.clamp(torch.arange(sd // 2, nslices + sd // 2), 0, depth - 1)
    else:
        selected_indices = torch.arange(8).repeat_interleave(2)
    return volume[selected_indices]


# The transforms are stateless, so build them once instead of once per patient.
crop_256 = v2.CenterCrop(size=256)
resize_256 = v2.Resize(size=256, antialias=True)
resize_288_crop_256 = v2.Compose([v2.Resize(size=288, antialias=True), v2.CenterCrop(size=256)])

print(df.shape)
# Sized from df rather than a hard-coded patient count, so that a change in
# the exclusion list cannot cause an out-of-range write or leave empty rows.
npy_data = np.zeros((len(df), nslices, 4, 256, 256), dtype=np.float32)
for i in tqdm(range(len(df))):
    if i % 50 == 0:
        print(i)

    row = df.iloc[i]
    npy_t1 = _load_windowed_series(row['T1_siuid'])
    npy_fl = _load_windowed_series(row['FL_siuid'])
    npy_t2 = _load_windowed_series(row['T2_siuid'])
    npy_seg = np.load(row['seg'])

    d_t1 = npy_t1.shape[1]
    d_fl = npy_fl.shape[1]
    d_t2 = npy_t2.shape[1]

    t_t1 = torch.from_numpy(npy_t1)
    t_fl = torch.from_numpy(npy_fl)
    t_t2 = torch.from_numpy(npy_t2)
    t_seg = torch.from_numpy(npy_seg)

    # Bring all three series to 256 rows; the branch depends on the
    # combination of row counts (T1, FLAIR, T2):
    #   256,288,256 -> crop FLAIR
    #   *,288,288   -> resize T1 to 288, then centre-crop all to 256
    #   256,256,*   -> resize T2
    #   otherwise   -> resize all three
    if d_t1 == 256 and d_fl == 288 and d_t2 == 256:
        t_fl = crop_256(t_fl)
    elif d_fl == 288 and d_t2 == 288:
        t_t1 = resize_288_crop_256(t_t1)
        t_fl = crop_256(t_fl)
        t_t2 = crop_256(t_t2)
    elif d_t1 == 256 and d_fl == 256:
        t_t2 = resize_256(t_t2)
    else:
        t_t1 = resize_256(t_t1)
        t_fl = resize_256(t_fl)
        t_t2 = resize_256(t_t2)

    assert ((t_t1.shape[1]==256) and (t_fl.shape[1]==256) and (t_t2.shape[1]==256)), "shapes must be 256"

    # The segmentation is not resized here; it is only depth-matched.
    npy_data[i, :, 0, :, :] = np.float32(_select_slices(t_t1, nslices).numpy())
    npy_data[i, :, 1, :, :] = np.float32(_select_slices(t_fl, nslices).numpy())
    npy_data[i, :, 2, :, :] = np.float32(_select_slices(t_t2, nslices).numpy())
    npy_data[i, :, 3, :, :] = np.float32(_select_slices(t_seg, nslices).numpy())

npy_data_dir = PREPARED_DATA_DIR / 'npy_data.npy'
with open(npy_data_dir, 'wb') as f:
    np.save(f, npy_data)

## iaaa-model-3channel (training script)

In [None]:
!pip install iterative-stratification

In [None]:
from typing import Optional
# import pydicom
# import SimpleITK as sitk
from PIL import Image
import click
import gc
import os
import numpy as np
import pandas as pd
from torchvision.io import read_image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision.transforms import InterpolationMode
import torch.optim as optim
from sklearn.metrics import roc_auc_score, precision_score, recall_score, average_precision_score
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
from sklearn.metrics import RocCurveDisplay, roc_curve
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold, GroupShuffleSplit, StratifiedGroupKFold
from sklearn.metrics import auc, roc_curve, precision_recall_curve
from torch.utils.data import  WeightedRandomSampler
from iterstrat.ml_stratifiers import RepeatedMultilabelStratifiedKFold, MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit


import time

def read_npy_file(item):
    """Return the array stored in the ``.npy`` file ``item``, cast to float32."""
    loaded = np.load(item)
    as_float = loaded.astype(np.float32)
    return as_float

In [None]:
# Memory-map the prepared array read-only instead of loading it into RAM;
# MRIDataset copies out one patient at a time.
npy_data = np.load('/kaggle/input/npy-256-fixed/npy_data.npy', mmap_mode='r')
# npy_data = np.load('/kaggle/input/npy-256-fixed/npy_data.npy')

# Later cells index this as npy_data[patient, slice, channel, row, col],
# with channel 3 read as the segmentation labels.
print(npy_data.shape)

In [None]:
from pathlib import Path

# CSV read below; it must provide the prediction, SeriesInstanceUID,
# PatientBirthDate, PatientWeight and PatientSex columns used by the pivot.
LABELS_PATH = '/kaggle/input/iaaa-nii/train.csv'

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Build one row per patient with series-level targets and patient metadata.
annotations = pd.read_csv(LABELS_PATH)
# Bin edges for the scaled AGE column, used only for the 'strat' key below.
bins = [0,.6,1.2]
# One column per (value, SeriesDescription) pair.
df = annotations.pivot(index='PatientID', columns='SeriesDescription', values=['prediction', 'SeriesInstanceUID', 'PatientBirthDate','PatientWeight', 'PatientSex'])
df = df.reset_index()
# Patient-level label: 1 if any series of the patient has a truthy prediction.
df.insert(4, 'label', df.prediction.any(axis=1).astype(np.int32))
# Column 8 is the first PatientBirthDate column (assuming three series per
# patient). birthdate // 10000 presumably extracts the year from a YYYYMMDD
# integer — TODO confirm; the reference year 2024 is hard-coded.
df.iloc[:,8] = (2024-df.iloc[:,8]//10000)
# Keep ID, predictions, label, series UIDs and the first of the three
# birth-date / weight / sex columns.
df = df.iloc[:,[0,1,2,3,4,5,6,7,8,11,14]]
df.columns = ['PatientID','T1_pred','FLAIR_pred','T2_pred','label','T1_SIUID','FLAIR_SIUID','T2_SIUID','AGE','WEIGHT', 'SEX']
# Ages of 140 or more are reduced by 100 (the reason is not visible here —
# possibly a century error in the birth year; verify).
df.loc[df.AGE>=140,'AGE'] -= 100
# Round to the nearest decade and express in units of 100 years (47 -> 0.5).
df['AGE'] = df['AGE'].apply(lambda x: round(x/10)/10)
# Encode sex as integer category codes.
df['SEX'] = df['SEX'].astype('category')
df['SEX'] = df['SEX'].cat.codes
# Stratification key: label followed by the AGE bin index, as strings.
df['strat'] = df.label.astype(str)+pd.cut(df.AGE,bins=bins,labels=False).astype(str)# + df.SEX.astype(str)

In [None]:
# Same exclusion list as in the data-preparation notebook, so that row i of df
# lines up with npy_data[i] (presumably sagittal acquisitions — confirm).
sag_list = [146624, 378524, 591219, 331150, 595574, 681963, 684731, 675668, 679925] #331150
keep_mask = ~df.PatientID.isin(sag_list)
df = df.loc[keep_mask].reset_index(drop=True)

In [None]:
# Per patient, flag which label values 0..6 occur anywhere in the
# segmentation channel (channel 3) and append the flags as columns L0..L6.
df['seg'] = ''
# npy_data layout: 1035,16,4,256,256
seg_labels = []
for patient_idx in range(len(df)):
    present = np.zeros(7)
    found = np.unique(npy_data[patient_idx, :, 3, :, :].reshape(-1)).astype(np.int32)
    present[found] = 1
    seg_labels.append(present)

df_seg_label = pd.DataFrame(np.array(seg_labels))
df_seg_label.columns = ['L' + str(k) for k in range(7)]

df = pd.concat([df, df_seg_label], axis=1)

In [None]:
class MRIDataset(Dataset):
    """Per-patient dataset of image slices with series- and slice-level targets.

    Args:
        annotations: DataFrame indexed 0..n-1 with the columns
            ``original_index`` (row of ``data`` for that patient), ``AGE``,
            ``WEIGHT``, ``SEX``, ``T1_pred``, ``FLAIR_pred`` and ``T2_pred``.
        data: array indexed as ``data[original_index]`` giving
            ``(slices, 4, h, w)``; channel 3 is the segmentation label map.
        transform: callable applied to each slice (training) or to the whole
            stack (evaluation).
        training: enables per-slice transforms and slice shuffling.
    """

    def __init__(self, annotations, data, transform, training):
        self.labels = annotations
        self.data = data
        self.transform = transform
        self.training = training

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(images, meta, label, slice_labels)`` for patient ``idx``.

        - ``images``: float32 tensor ``(slices, 3, h, w)`` (image channels only).
        - ``meta``: float32 tensor ``[AGE, WEIGHT, SEX]``.
        - ``label``: float32 tensor ``[T1_pred, FLAIR_pred, T2_pred]``.
        - ``slice_labels``: float32 tensor ``(slices, 6)``; column ``k`` is 1
          when segmentation value ``k + 1`` occurs in that slice.
        """
        ix = self.labels.loc[idx, 'original_index']
        meta = self.labels.loc[idx, ['AGE', 'WEIGHT', 'SEX']].values.astype(np.float32)
        label = torch.from_numpy(
            self.labels.loc[idx, ['T1_pred', 'FLAIR_pred', 'T2_pred']].values.astype(np.float32)
        )
        # Copy out of the (possibly memory-mapped, read-only) source array.
        images = torch.from_numpy(np.copy(self.data[ix])).to(torch.float32)

        if self.training:
            # Each slice is transformed independently, then the slice order is shuffled.
            images = torch.stack([self.transform(img) for img in images])
            images = images[torch.randperm(len(images))]
        else:
            images = self.transform(images)

        # Slice-level targets are derived AFTER the transform/shuffle so they
        # stay consistent with what the model sees. Rounding snaps
        # interpolated label values back to integers.
        segmentation_labels = torch.round(images[:, 3, :, :])
        n_slices = segmentation_labels.shape[0]
        # Sized from the data instead of a hard-coded 16 slices.
        one_hot_encoded_labels = torch.zeros((n_slices, 7), dtype=torch.float32)
        for slice_idx in range(n_slices):
            present = segmentation_labels[slice_idx].unique()
            one_hot_encoded_labels[slice_idx, present.to(torch.int64)] = 1

        images = images[:, :3]
        # Column 0 (background) is dropped from the slice-level targets.
        return images, torch.tensor(meta, dtype=torch.float32), label, one_hot_encoded_labels[:, 1:]

In [None]:
# Training-time augmentation. MRIDataset applies it to one slice at a time,
# i.e. to all four channels of a slice (three image channels plus the
# segmentation channel) with the same random parameters.
transforms = v2.Compose([
    v2.RandomAffine(degrees=180,
                    translate=(0.1, 0.1),
                    scale=(0.95, 1.05),
                    shear=(-10,10,-10,10)
                   ),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32),
])
# Validation: no augmentation, only the dtype conversion.
transformsv = v2.Compose([
    v2.ToDtype(torch.float32),    
])

In [None]:
class effb0_meta(nn.Module):
    """Two EfficientNet-B0 branches: per-slice label heads and a per-study head.

    - ``model`` encodes every slice on its own; ``fc1``/``fc2``/``fc3`` each map
      the 1280-d slice feature to 6 logits.
    - ``model2`` encodes a 4x4 montage of the 16 slices; ``fc4`` maps its
      1280-d feature to 3 logits.
    """

    def __init__(self):
        super().__init__()
        self.model = models.efficientnet_b0(weights='DEFAULT')
        self.model2 = models.efficientnet_b0(weights='DEFAULT')
        # Replace the final classification layer so both backbones output the
        # 1280-d pooled feature.
        self.model.classifier[1] = nn.Identity()
        self.model2.classifier[1] = nn.Identity()
        self.fc1 = nn.Linear(1280,6)
        self.fc2 = nn.Linear(1280,6)
        self.fc3 = nn.Linear(1280,6)
        self.fc4 = nn.Linear(1280,3)

    def forward(self, x, y):
        """Run both branches.

        Args:
            x: tensor of shape ``(batch, slices, channels, h, w)``; the montage
                branch requires ``slices == 16`` (a 4x4 grid).
            y: accepted for interface compatibility; not used.

        Returns:
            ``(o1, o2)`` where ``o1`` has shape ``(batch, slices, 3, 6)`` (the
            three slice heads stacked on dim 2) and ``o2`` has shape
            ``(batch, 3)``.
        """
        bs, sl, ch, h, w = x.size()
        x = x.contiguous()

        # Slice branch: fold the slices into the batch dimension.
        feats = self.model(x.view(bs * sl, ch, h, w)).view(bs, sl, 1280)
        o1 = torch.stack((self.fc1(feats), self.fc2(feats), self.fc3(feats)), dim=2)

        # Montage branch: tile the 16 slices into one (4*h, 4*w) image per
        # channel. Sizes come from the input rather than a fixed 256x256; the
        # final reshape copies, so no explicit clone of x is needed.
        grid = 4
        montage = (
            x.permute(0, 2, 1, 3, 4)
            .view(bs, ch, grid, grid, h, w)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(bs, ch, grid * h, grid * w)
        )
        o2 = self.fc4(self.model2(montage))

        return o1, o2

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, filename='checkpoint.pth'):
    """Write a training checkpoint to ``filename``.

    The file holds a dict with the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, filename)

# Function to load a checkpoint
def load_checkpoint(model, optimizer, filename='checkpoint.pth', map_location=None):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        model: module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``.
        filename: path of a file written by ``save_checkpoint``.
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to load
            a checkpoint that was saved from a GPU on a machine without one.
            ``None`` (the default) keeps the previous behaviour.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# 5-fold cross-validated training. For every fold a fresh model is trained for
# num_epochs epochs; after each epoch validation metrics are printed, ROC/PR
# curves are saved under figs/ and a checkpoint is written.
figs_dir = Path('figs')
figs_dir.mkdir(exist_ok=True)
num_epochs = 30
bs = 4


def _batch_losses(model, criterion, inputs, ages, targets, auxt):
    """Forward one batch and return ``(outputs, outputs2, aux_loss, series_loss)``.

    ``outputs`` are the slice-level logits (batch, slices, 3, labels) and
    ``outputs2`` the series-level logits (batch, 3).
    """
    outputs, outputs2 = model(inputs, ages)
    n_slices = outputs.size(1)
    # A slice-level label only counts for a series whose own target is
    # positive: broadcast the (batch, 3) targets over the slices and use them
    # to mask the (batch, slices, labels) auxiliary targets.
    targets_expanded = targets.unsqueeze(1).expand(-1, n_slices, -1)
    mask = targets_expanded == 1
    new_targets = auxt.unsqueeze(2) * mask.unsqueeze(-1).to(auxt.dtype)  # same shape as outputs
    aux_loss = (
        criterion(outputs[:, :, 0], new_targets[:, :, 0].float())
        + criterion(outputs[:, :, 1], new_targets[:, :, 1].float())
        + criterion(outputs[:, :, 2], new_targets[:, :, 2].float())
    )
    series_loss = criterion(outputs2, targets)
    return outputs, outputs2, aux_loss, series_loss


def _series_logits(outputs, outputs2):
    """Return flattened per-series logits from both heads and their mean.

    The slice head is reduced to one logit per series by taking the maximum
    over all slices and over label columns 1..4 (why only these four columns
    are used is not visible here — confirm).
    """
    n_slices = outputs.size(1)
    pro1 = (
        outputs.detach()
        .permute(0, 1, 3, 2)[:, :, 1:5, :]
        .reshape(-1, n_slices * 4, 3)
        .max(-2).values.view(-1)
    )
    pro2 = outputs2.detach().view(-1)
    pro3 = (pro1 + pro2) / 2.0
    return pro1, pro2, pro3


sss = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=13)
for i, (train_index, val_index) in enumerate(sss.split(df.PatientID, df.iloc[:,-6:])):
    # 'original_index' keeps the row position in df, which MRIDataset uses to
    # index npy_data.
    dftr = df.iloc[train_index]
    dftr = dftr.reset_index().rename(columns={'index': 'original_index'})

    dfte = df.iloc[val_index]
    dfte = dfte.reset_index().rename(columns={'index': 'original_index'})

    training_data = MRIDataset(dftr, npy_data, transform = transforms, training = True)
    validation_data = MRIDataset(dfte, npy_data, transform = transformsv, training = False)
    # `bs` must stay the configured batch size for every fold; the loops below
    # therefore never rebind it.
    train_loader = DataLoader(training_data, batch_size=bs, shuffle=True, num_workers=4,pin_memory=True)
    val_loader = DataLoader(validation_data, batch_size=bs, shuffle=False, num_workers=4,pin_memory=True)

    model = effb0_meta()
    model.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.0003,
                                                  step_size_up=int(len(train_loader)*0.42),
                                                  cycle_momentum=False)
    best_ap = 0.0
    # Training loop
    for epoch in range(num_epochs):

        start_time = time.time()
        # Training
        model.train()
        train_loss = 0.0
        train_loss2 = 0.0
        train_preds = []
        train_preds2 = []
        train_preds3 = []
        train_targets = []
        for inputs, ages, targets, auxt in train_loader:
            # Move the inputs and targets to the GPU
            inputs = inputs.to(device, non_blocking=True)     # (batch, slices, 3, h, w)
            ages = ages.to(device, non_blocking=True)         # (batch, 3)
            targets = targets.to(device, non_blocking=True)   # (batch, 3)
            auxt = auxt.to(device, non_blocking=True)         # (batch, slices, 6)

            optimizer.zero_grad()
            outputs, outputs2, losses, losses2 = _batch_losses(
                model, criterion, inputs, ages, targets, auxt)

            # One backward pass over the summed loss gives the same gradients
            # as backpropagating the two losses one after the other.
            (losses + losses2).backward()
            optimizer.step()

            train_loss += losses.item()
            train_loss2 += losses2.item()
            pro1, pro2, pro3 = _series_logits(outputs, outputs2)
            train_preds.extend(torch.sigmoid(pro1).tolist())
            train_preds2.extend(torch.sigmoid(pro2).tolist())
            train_preds3.extend(torch.sigmoid(pro3).tolist())

            train_targets.extend(targets.view(-1).tolist())
        # Report both losses as per-batch means.
        train_loss /= len(train_loader)
        train_loss2 /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        val_loss2 = 0.0
        val_preds = []
        val_preds2 = []
        val_preds3 = []
        val_targets = []
        with torch.no_grad():
            for inputs, ages, targets, auxt in val_loader:
                # Move the inputs and targets to the GPU
                inputs = inputs.to(device, non_blocking=True)
                ages = ages.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                auxt = auxt.to(device, non_blocking=True)

                outputs, outputs2, losses, losses2 = _batch_losses(
                    model, criterion, inputs, ages, targets, auxt)

                val_loss += losses.item()
                val_loss2 += losses2.item()

                pro1, pro2, pro3 = _series_logits(outputs, outputs2)
                val_preds.extend(torch.sigmoid(pro1).tolist())
                val_preds2.extend(torch.sigmoid(pro2).tolist())
                val_preds3.extend(torch.sigmoid(pro3).tolist())
                val_targets.extend(targets.view(-1).tolist())

        # NOTE(review): CyclicLR is documented to be stepped after every batch
        # and step_size_up is expressed in batches, but it is stepped once per
        # epoch here, so the learning rate barely leaves base_lr. Left as is
        # because moving it changes the training schedule — confirm the intent.
        scheduler.step()
        val_loss /= len(val_loader)
        val_loss2 /= len(val_loader)
        val_auc = roc_auc_score(val_targets, val_preds)
        val_ap = average_precision_score(val_targets, val_preds)
        val_ap2 = average_precision_score(val_targets, val_preds2)
        val_ap3 = average_precision_score(val_targets, val_preds3)
        val_precision = precision_score(val_targets, [1 if p > 0.5 else 0 for p in val_preds])
        val_recall = recall_score(val_targets, [1 if p > 0.5 else 0 for p in val_preds])

        fpr, tpr, thresholds_roc = roc_curve(val_targets, val_preds)
        val_roc_auc = auc(fpr, tpr)
        precision, recall, thresholds_pr = precision_recall_curve(val_targets, val_preds)
        val_pr_auc = auc(recall, precision)

        train_ap = average_precision_score(train_targets, train_preds)
        train_ap2 = average_precision_score(train_targets, train_preds2)
        train_ap3 = average_precision_score(train_targets, train_preds3)

        end_time = time.time()  # End time for the epoch
        epoch_duration = end_time - start_time  # Calculate duration
        print(f"Fold [{i+1}], ({epoch_duration:.2f}s) Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f},    Train Loss2: {train_loss2:.4f},    Train AP: {train_ap:.4f},    Train AP2: {train_ap2:.4f},    Train AP3: {train_ap3:.4f},    Val Loss: {val_loss:.4f},    Val Loss2: {val_loss2:.4f},    Val AUC: {val_auc:.4f},    Val AP: {val_ap:.4f},    Val AP2: {val_ap2:.4f},    Val AP3: {val_ap3:.4f},    Val Precision: {val_precision:.4f},    Val Recall: {val_recall:.4f},    Val roc_auc: {val_roc_auc:.4f},    Val pr_auc: {val_pr_auc:.4f}")

        # Create a figure with subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        RocCurveDisplay.from_predictions(val_targets, val_preds, ax=ax1)
        ax1.set_title('ROC Curve')
        ax1.set_xlabel('False Positive Rate')
        ax1.set_ylabel('True Positive Rate')
        PrecisionRecallDisplay.from_predictions(val_targets, val_preds, ax=ax2)
        ax2.set_title('Precision-Recall Curve')
        ax2.set_xlabel('Recall')
        ax2.set_ylabel('Precision')
        plt.tight_layout()
        plt.savefig(f"figs/fold{i+1}_epoch{epoch+1}", dpi=300, bbox_inches='tight')
        plt.close()

        if val_pr_auc > best_ap:
            best_ap = val_pr_auc
        # A checkpoint is written after every epoch, not only for the best one.
        save_checkpoint(model, optimizer, epoch, val_loss, filename= f"chechpoint_fold{i+1}_epoch{epoch+1}.pth")