# Data loading: check that we correctly understand the data / label format

- Load data from the confocal boundary example dataset.
- Track how the data's format and shape change throughout the training pipeline.

In [1]:
import os

import pytorch3dunet
import pytorch3dunet.unet3d as u3  # !
from pytorch3dunet import predict
from pytorch3dunet import train

from pytorch3dunet.unet3d import trainer as trainer
from pytorch3dunet.unet3d.config import _load_config_yaml
from pytorch3dunet.unet3d.model import get_model

from torchsummary import summary

In [2]:
# Example config directory (relative to the notebook) and the training config inside it.
config_dir, config_filename = './3DUnet_confocal_boundary', 'train_config.yml'
# Show what ships with the example so the right config file can be picked.
os.listdir(config_dir)

['test_config.yml', 'ovules_raw.png', 'ovules_pred.png', 'train_config.yml']

In [3]:
# Parse the training YAML into a plain nested dict and display it.
config_filepath = os.path.join(config_dir, config_filename)
config = _load_config_yaml(config_filepath)
config

{'model': {'name': 'UNet3D',
  'in_channels': 1,
  'out_channels': 1,
  'layer_order': 'gcr',
  'f_maps': 32,
  'num_groups': 8,
  'final_sigmoid': True},
 'loss': {'name': 'BCEDiceLoss',
  'ignore_index': None,
  'skip_last_target': True},
 'optimizer': {'learning_rate': 0.0002, 'weight_decay': 1e-05},
 'eval_metric': {'name': 'BoundaryAdaptedRandError',
  'threshold': 0.4,
  'use_last_target': True,
  'use_first_input': True},
 'lr_scheduler': {'name': 'ReduceLROnPlateau',
  'mode': 'min',
  'factor': 0.5,
  'patience': 30},
 'trainer': {'eval_score_higher_is_better': False,
  'checkpoint_dir': 'CHECKPOINT_DIR',
  'resume': None,
  'pre_trained': None,
  'validate_after_iters': 1000,
  'log_after_iters': 500,
  'max_num_epochs': 1000,
  'max_num_iterations': 150000},
 'loaders': {'num_workers': 8,
  'raw_internal_path': '/raw',
  'label_internal_path': '/label',
  'train': {'file_paths': ['/scratch/groups/jyeatman/samjohns-projects/unet3d/data/osfstorage-archive-train'],
   'slice_bu

In [None]:
# Build the training/validation loaders from the 'loaders' section of the config.
# NOTE(review): presumably a dict keyed by phase ('train' / 'val') — confirm before indexing.
loaders = trainer.get_train_loaders(config=config)

2023-05-16 23:43:31,585 [MainThread] INFO Dataset - Creating training and validation set loaders...
2023-05-16 23:43:32,477 [MainThread] INFO HDF5Dataset - Loading train set from: /scratch/groups/jyeatman/samjohns-projects/unet3d/data/osfstorage-archive-train/N_487_ds2x.h5...


In [None]:
# Create the model described by config['model'].
# `get_model` is imported from `pytorch3dunet.unet3d.model` directly rather than reached
# through `u3.model`: `import pytorch3dunet.unet3d as u3` may not load the `model`
# submodule itself, so `u3.model` would only resolve as a side effect of other imports.
model = get_model(config['model'])

In [6]:
# Display the module tree of the instantiated network (encoders, decoders, layer params).
model

UNet3D(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(1, 1, eps=1e-05, affine=True)
          (conv): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (groupnorm): GroupNorm(8, 16, eps=1e-05, affine=True)
          (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
     

In [None]:
# Per-layer summary of output shapes and parameter counts.
# torchsummary's `input_size` excludes the batch dimension and must follow PyTorch's
# channels-first layout, i.e. (C, D, H, W) for a 3D UNet. C has to equal the model's
# in_channels: the first layer is GroupNorm(1, 1) / Conv3d(1, ...), so a channels-last
# shape such as (256, 256, 256, 1) would be read as 256 input channels and fail.
# torchsummary also defaults to device="cuda"; pass the device the model actually lives on,
# otherwise a CPU-resident model fails with a cuda/cpu tensor-type mismatch.
# NOTE(review): a full 256^3 volume is memory-hungry (torchsummary forwards a batch of 2
# with autograd enabled); consider the training patch shape from config['loaders'] instead.
SUMMARY_SPATIAL_SHAPE = (256, 256, 256)  # (D, H, W)
summary_device = next(model.parameters()).device.type  # 'cpu' or 'cuda'
summary(
    model,
    (config['model']['in_channels'], *SUMMARY_SPATIAL_SHAPE),
    device=summary_device,
)