In [2]:
import os
import sys
from pathlib import Path

# Make the project directories importable from this notebook.
# NOTE(review): assumes the notebook is launched from a sub-directory of the
# repository so that `from data.bids_dataset import ...` resolves — confirm.
current_dir = os.getcwd()
file = Path(current_dir).resolve()  # the cwd as an absolute path (a directory, despite the name)

# Append the cwd's parent, grandparent and great-grandparent to sys.path.
# Entries already present are skipped so re-running this cell does not keep
# growing sys.path; slicing the materialised list tolerates a cwd that sits
# fewer than three levels below the filesystem root.
for ancestor in list(file.parents)[:3]:
    ancestor_str = str(ancestor)
    if ancestor_str not in sys.path:
        sys.path.append(ancestor_str)

base_dir = file.parent  # reuse the already-resolved path instead of resolving again
src_path = base_dir / 'src'  # computed for later use; not added to sys.path here

In [3]:
from data.bids_dataset import BidsDataset, BidsDataModule, contrasts
import torch

In [4]:
# Parameters
# Dataset root. Set the BRATS_DATA_DIR environment variable to point elsewhere
# instead of editing this absolute, machine-specific default.
base_data_dir = os.environ.get(
    'BRATS_DATA_DIR',
    '/home/student/farid_ma/dev/multiclass_softseg/MulticlassSoftSeg/data/external/ASNR-MICCAI-BraTS2023-GLI-Challenge/Sample-Subset',
)
train_data_dir = base_data_dir + '/train'
resize = (200, 200, 152)  # resize the input images to this size
brats_keys = ['img', 'seg']  # sample-dict keys for the image and the segmentation mask
n_classes = 4
out_channels = n_classes  # no intermediate feature maps, so the outputs are the final class predictions
img_key = brats_keys[0]
# Passed as `suffix=` to BidsDataset. NOTE: this name shadows the built-in
# format(); it is kept because later cells refer to it.
format = 'fnio'
do2D = False
batch_size = 1

# Binary segmentation mode is used exactly when there are two classes.
binary = n_classes == 2

### Testing stacking of torch tensors from separate per-contrast `BidsDataset` instances

In [5]:
# Show the contrast keys exported by data.bids_dataset; as displayed below, the last entry is 'seg'.
contrasts

['t1c', 't1n', 't2f', 't2w', 'seg']

In [34]:
def build_contrast_dataset(contrast: str) -> BidsDataset:
    """Build a single-contrast BraTS `BidsDataset` from the notebook parameters.

    All per-contrast datasets share the data directory, prefix, suffix,
    2D/3D mode, binary flag and resize target; only `contrast` differs.
    """
    return BidsDataset(
        data_dir=train_data_dir,
        prefix='BraTS-GLI',
        contrast=contrast,
        suffix=format,
        do2D=do2D,
        binary=binary,
        transform=None,
        resize=resize,
    )


# One dataset per MRI contrast, built in the order of `contrasts`.
# The first four entries are the image contrasts ('seg' is excluded).
t1c_ds, t1n_ds, t2f_ds, t2w_ds = (build_contrast_dataset(c) for c in contrasts[:4])

In [35]:
# Fetch the sample at one arbitrary index from each per-contrast dataset.
# NOTE(review): assumes all four datasets enumerate subjects in the same order,
# so this index is the same subject everywhere — confirm.
test_idx = 23

t1c_batch, t1n_batch, t2f_batch, t2w_batch = (
    ds[test_idx] for ds in (t1c_ds, t1n_ds, t2f_ds, t2w_ds)
)
t1c_img, t1n_img, t2f_img, t2w_img = (
    batch['img'] for batch in (t1c_batch, t1n_batch, t2f_batch, t2w_batch)
)

# The ground-truth mask is taken from the T1c sample only.
mask = t1c_batch['seg']

In [36]:
# Report shape and dtype of one contrast image next to the mask.
print(
    f"image shape: {t1c_img.shape}; mask shape: {mask.shape}",
    f"image.dtype: {t1c_img.dtype}; mask.dtype: {mask.dtype}",
    sep="\n",
)

image shape: torch.Size([1, 200, 200, 152]); mask shape: torch.Size([1, 200, 200, 152])
image.dtype: torch.float32; mask.dtype: torch.int64


In [37]:
# Concatenate the four per-contrast volumes along dim 0 (size 1 each, per the
# shapes printed above), producing a single 4 x D1 x D2 x D3 tensor.
stacked_img_tensor = torch.cat(
    (t1c_img, t1n_img, t2f_img, t2w_img),
    dim=0,
)

In [38]:
# Report shape and dtype of the manually stacked tensor.
print(
    f"stacked_img_tensor shape: {stacked_img_tensor.shape}",
    f"stacked_img_tensor dtype: {stacked_img_tensor.dtype}",
    sep="\n",
)

stacked_img_tensor shape: torch.Size([4, 200, 200, 152])
stacked_img_tensor dtype: torch.float32


### Testing the stacked multimodal MRI tensor returned directly by `BidsDataset`

In [6]:
# Single dataset configured with contrast='multimodal'. Per the output printed
# below, its 'img' entry carries four channels along dim 0.
multimodal_ds_kwargs = dict(
    data_dir=train_data_dir,
    prefix='BraTS-GLI',
    contrast='multimodal',
    suffix=format,
    do2D=do2D,
    binary=binary,
    transform=None,
    resize=resize,
)
test_multimodal_ds = BidsDataset(**multimodal_ds_kwargs)

In [7]:
# Same index as the per-contrast test above, drawn from the multimodal dataset.
test_idx = 23
sample_batch = test_multimodal_ds[test_idx]
sample_img, sample_mask = sample_batch['img'], sample_batch['seg']

In [8]:
# Report the multimodal image and mask shapes.
print(
    f"sample_img shape: {sample_img.shape}; "
    f"sample_mask shape: {sample_mask.shape}"
)

sample_img shape: torch.Size([4, 200, 200, 152]); sample_mask shape: torch.Size([1, 200, 200, 152])


### Testing Output of DataModule