### Install the necessary libraries (if not already installed)

In [1]:
# install Hugging Face Datasets
!pip install datasets

# optional installs: NiBabel and PyTorch
!pip install nibabel
!pip install torch torchvision



Collecting nibabel
  Downloading nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Downloading nibabel-5.3.2-py3-none-any.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: nibabel
Successfully installed nibabel-5.3.2
Collecting torch
  Downloading torch-2.6.0-cp313-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.6.0-cp313-none-macosx_11_0_arm64.whl (66.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.

### Import the libraries

In [2]:
import nibabel as nib
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk

### Preprocess each brain scan (crop, normalize intensities, resample to a 96³ volume)

In [3]:
def _minmax_normalize(vol):
    """Rescale ``vol`` linearly so its minimum becomes 0 and its maximum 1.

    A constant volume has zero intensity range; it is returned as all zeros
    instead of dividing 0 by 0 (which would fill it with NaNs).
    """
    # Subtract the minimum rather than adding its absolute value: when the
    # minimum is positive, ``vol + abs(vol.min())`` leaves the floor above 0
    # and, after dividing by ``vol.max()``, pushes the maximum above 1.
    shifted = vol - vol.min()
    value_range = shifted.max()
    if value_range > 0:
        return shifted / value_range
    return shifted


def _resample_and_pad(vol, target_size=96, reference_extent=124):
    """Downsample a 3-D array by ``target_size / reference_extent`` and zero-pad it to a cube.

    Returns a float32 tensor of shape (1, target_size, target_size, target_size).
    """
    # F.interpolate works on batched, channelled input: (N=1, C=1, D, H, W).
    vol_t = torch.from_numpy(vol).float().unsqueeze(0).unsqueeze(0)
    # One isotropic factor for all axes, chosen so that an axis of length
    # reference_extent maps onto target_size.
    scale = target_size / reference_extent
    resampled = F.interpolate(
        vol_t,
        scale_factor=(scale, scale, scale),
        mode="trilinear",
        align_corners=False,
    )
    # F.pad takes (before, after) pairs starting from the LAST axis: W, H, then D.
    # Any odd remainder goes on the "after" side, so each axis ends up exactly
    # target_size long.
    padding = []
    for size in reversed(resampled.shape[2:]):
        missing = target_size - size
        padding.extend((missing // 2, missing - missing // 2))
    # Drop the batch axis, keep the channel axis.
    return F.pad(resampled, tuple(padding)).squeeze(0)


def preprocess_nifti(example):
    """
    Load a NIfTI volume, crop it, min-max normalize it and resample it to a 96^3 cube.

    Parameters
    ----------
    example : dict
        A dataset row; only ``example["nii_filepath"]`` (a path that
        ``nibabel.load`` can read) is used.

    Returns
    -------
    dict
        The same ``example`` with ``example["img"]`` set to a float32 numpy
        array of shape (1, 96, 96, 96) (channel axis first) with values in
        [0, 1]. A volume with constant intensity becomes all zeros.
    """
    vol = nib.load(example["nii_filepath"]).get_fdata()

    # Fixed crop; yields shape (98, 124, 108) when the input is at least (105, 132, 108).
    vol = vol[7:105, 8:132, :108]

    vol = _minmax_normalize(vol)

    # 124 is the cropped y-extent (132 - 8), the largest cropped axis, so it is
    # the one that lands on 96; the other axes come out smaller and get padded.
    example["img"] = _resample_and_pad(vol, target_size=96, reference_extent=124).numpy()
    return example

### Load the dataset from [Hugging Face](https://huggingface.co/datasets/radiata-ai/brain-structure)

In [4]:
# Fetch the train / validation / test splits of the brain-structure dataset,
# one load_dataset call per split, in that order.
# trust_remote_code=True allows code shipped with the Hub repository to run;
# keep it only for sources you trust.
ds_train, ds_val, ds_test = (
    load_dataset("radiata-ai/brain-structure", split=split_name, trust_remote_code=True)
    for split_name in ("train", "validation", "test")
)

### Apply the preprocessing to every split

In [5]:
# Apply the preprocessing to each split
ds_train = ds_train.map(preprocess_nifti)
ds_val   = ds_val.map(preprocess_nifti)
ds_test  = ds_test.map(preprocess_nifti)

# Set the dataset format to return PyTorch tensors for the 'img' column
ds_train.set_format(type='torch', columns=['img'])
ds_val.set_format(type='torch', columns=['img'])
ds_test.set_format(type='torch', columns=['img'])

Map:   0%|          | 0/3066 [00:00<?, ? examples/s]

Map:   0%|          | 0/364 [00:00<?, ? examples/s]

Map:   0%|          | 0/364 [00:00<?, ? examples/s]

### Save data to disk for uploading

In [6]:
# Persist each preprocessed split to its own directory under exported_brain_images/.
for split_ds, out_dir in (
    (ds_train, 'exported_brain_images/train'),
    (ds_val, 'exported_brain_images/val'),
    (ds_test, 'exported_brain_images/test'),
):
    split_ds.save_to_disk(out_dir)
del split_ds, out_dir  # keep the notebook namespace free of the loop variables

Saving the dataset (0/22 shards):   0%|          | 0/3066 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/364 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/364 [00:00<?, ? examples/s]

### Load data from disk (if needed)

In [7]:
# Optional: instead of re-running the download and preprocessing cells above,
# uncomment these lines to reload the splits written to exported_brain_images/.
# ds_train = load_from_disk('exported_brain_images/train')
# ds_val = load_from_disk('exported_brain_images/val')
# ds_test = load_from_disk('exported_brain_images/test')

In [8]:
# Set up data loaders for model training.
# The training loader shuffles; a seeded generator makes that shuffle order
# reproducible across kernel restarts (42 is an arbitrary fixed seed).
# Validation and test keep their original order.
train_loader = DataLoader(
    ds_train,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader   = DataLoader(ds_val, batch_size=16, shuffle=False)
test_loader  = DataLoader(ds_test, batch_size=16, shuffle=False)