## Train a Pixel-level Classifier

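In this notebook, we show a quick way to train a downstream pixel-level classifier on representations extracted from a trained DDPM (denoising diffusion probabilistic model), and then evaluate its exported predictions against ground-truth masks.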

In [None]:
import os
import sys

# Make the braintumor_ddpm package importable by adding the parent of the
# current working directory to sys.path. This assumes the notebook is launched
# from a direct subdirectory of the repository root. It must run before the
# braintumor_ddpm imports below.
sys.path.append(os.path.dirname(os.getcwd()))

# braintumor_ddpm imports for training and evaluating a pixel-level classifier
from braintumor_ddpm.insights.evaluator import NiftiEvaluator
from braintumor_ddpm.data.datasets import SegmentationDataset
from braintumor_ddpm.core.networks.pixel_classifier import PixelClassifier
from braintumor_ddpm.core.training.PixelRepresentationsTrainer import PixelRepresentationsTrainer

In [None]:
# Experiment configuration -- replace every placeholder path before running.
train_size = 10  # number of training samples; also used in the output folder name
images_dir = r"PATH TO SCANS"
labels_dir = r"PATH TO LABELS/MASKS"
# Class name -> integer value used in the label masks.
# NOTE(review): in the standard BraTS convention label 2 is peritumoral edema,
# and "tumor core" is a composite region rather than a single label value.
# Confirm these names against the masks in `labels_dir`.
labels = {
    'Background': 0,
    'Non-Enhancing Tumor': 1,
    'Tumor Core': 2,
    'Enhancing Tumor': 3,
}
output_folder = r"PATH POINTING TO OUTPUT OR CACHE FOLDER"
trained_ddpm_path = r"PATH TO THE TRAINED DIFFUSION MODEL"

# Validate the inputs before anything is created on disk. An unedited template
# then fails here with a clear message, instead of creating a stray placeholder
# folder and failing later inside the trainer.
for setting_name, setting_path in {"images_dir": images_dir, "labels_dir": labels_dir}.items():
    if not os.path.isdir(setting_path):
        raise FileNotFoundError(f"{setting_name} does not point to an existing directory: {setting_path!r}")
if not os.path.exists(trained_ddpm_path):
    raise FileNotFoundError(f"trained_ddpm_path does not exist: {trained_ddpm_path!r}")

# Each experiment gets its own sub-folder, named after the training-set size.
output_folder = os.path.join(output_folder, f"DenoiseNetwork Experiment - {train_size} Samples")
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Pixel-level classifier fed by features from the trained DDPM: layers 16, 17
# and 18 of the diffusion network, sampled at diffusion time step 200.
pixel_classifier = PixelClassifier(
    ddpm_model_path=trained_ddpm_path,
    layers=[16, 17, 18],
    time_steps=[200],
)

# Trainer that fits the classifier on `train_size` labelled samples, using a
# fixed seed, and writes into the experiment output folder.
trainer = PixelRepresentationsTrainer(
    network=pixel_classifier,
    output_folder=output_folder,
    images_dir=images_dir,
    labels_dir=labels_dir,
    train_size=train_size,
    seed=16,
    labels=labels,
)

In [None]:
# Configure the trainer (epoch cap and initial learning rate), then initialize it.
# Training is started in the following "start training" cell only. Calling
# run_training() here as well would make Run All train the model twice.
MAX_EPOCHS = 8
INITIAL_LR = 1e-4

trainer.set_maximum_epochs(MAX_EPOCHS)
trainer.set_initial_lr(INITIAL_LR)
trainer.initialize()

In [None]:
# Start training the pixel classifier with the settings applied in the
# previous cell (the trainer must already be initialized).
trainer.run_training()

In [None]:
# Locations of the exported NIfTI predictions and the matching ground-truth files.
pred_dir = r"PATH TO EXPORTED NIFTI PREDICTIONS"
gt_dir = r"PATH TO GT FILES"

# Compare predictions with references using the same label mapping as training.
# evaluate_folders receives the experiment output folder, presumably as the
# destination for its results -- confirm against NiftiEvaluator.
evaluator = NiftiEvaluator(
    predictions=pred_dir,
    references=gt_dir,
    labels=labels,
)
evaluator.evaluate_folders(output_folder)