In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Directory that receives everything this notebook writes: the generated
# copick config file and the writable copy of the overlay.
FOLDER_OUTPUT_PREFIX = "output/"
# Root of the challenge dataset; its train/static and train/overlay
# sub-folders are read when the copick project is set up below.
FOLDER_DATA_CHALLENGE = "../../data/czii/czii-cryo-et-object-identification"

In [None]:
!ls {FOLDER_DATA_CHALLENGE}

In [None]:
# Make a copick project
import os
import shutil

# The copick configuration is assembled as JSON text in three pieces: the fixed
# object definitions, the two folder entries that depend on the notebook's
# path constants, and the closing brace.
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],
    "overlay_fs_args": {
        "auto_mkdir": true
    },
"""

config_blob_folders = f"""
    "overlay_root": "{FOLDER_OUTPUT_PREFIX}/overlay",
    "static_root": "{FOLDER_DATA_CHALLENGE}/train/static"
"""

config_blob_suffix = """
}
"""

config_blob = config_blob + config_blob_folders + config_blob_suffix

copick_config_path = f"{FOLDER_OUTPUT_PREFIX}/copick.config"
output_overlay = f"{FOLDER_OUTPUT_PREFIX}/overlay"

# The output folder does not exist on a fresh checkout; create it first so the
# open() below does not fail with FileNotFoundError.
os.makedirs(FOLDER_OUTPUT_PREFIX, exist_ok=True)

with open(copick_config_path, "w") as f:
    f.write(config_blob)

# Update the overlay: mirror the challenge's overlay tree into the writable
# output overlay, prefixing every file name with "curation_0_".
source_dir = f"{FOLDER_DATA_CHALLENGE}/train/overlay"
destination_dir = f"{FOLDER_OUTPUT_PREFIX}/overlay"

# The walk variable is deliberately not called `root`: that name holds the
# copick project handle in later cells, and re-running this cell would
# silently replace it with a path string.
for current_dir, _subdirs, files in os.walk(source_dir):
    # Create the corresponding subdirectory in the destination
    relative_path = os.path.relpath(current_dir, source_dir)
    target_dir = os.path.join(destination_dir, relative_path)
    os.makedirs(target_dir, exist_ok=True)

    # Copy each file, adding the prefix only when it is not already present
    for file in files:
        if file.startswith("curation_0_"):
            new_filename = file
        else:
            new_filename = f"curation_0_{file}"

        source_file = os.path.join(current_dir, file)
        destination_file = os.path.join(target_dir, new_filename)

        # copy2 also carries over file metadata such as timestamps
        shutil.copy2(source_file, destination_file)
        print(f"Copied {source_file} to {destination_file}")

In [None]:
import os
import numpy as np
from pathlib import Path
import torch
import torchinfo
import zarr, copick
from tqdm import tqdm
from typing import Optional, Union, Tuple, List
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    Orientationd,
    AsDiscrete,
    RandFlipd,
    RandRotate90d,
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss, FocalLoss, TverskyLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric
import pytorch_lightning as pl
import torch.distributed as dist
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

### Generate the painted mask

#### Get copick root

In [None]:
!ls $copick_config_path

In [None]:
root = copick.from_file(copick_config_path)

copick_user_name = "copickUtils"
copick_segmentation_name = "paintedPicks"
voxel_size = 10
tomo_type = "denoised"

#### Generate multi-class segmentation masks from picks, and save them to the copick overlay directory

In [None]:
root

In [None]:
from copick_utils.segmentation import segmentation_from_picks
import copick_utils.writers.write as write
from collections import defaultdict

# Just do this once: flip to False after the masks have been written to the
# overlay so that re-running the notebook skips this step.
generate_masks = True

if generate_masks:
    # Only particle objects carry a radius, so only they can be painted.
    # Non-particle entries (membrane, background) are left out here, which
    # avoids a KeyError on the missing "radius" if such picks ever exist.
    target_objects = {
        pickable.name: {"label": pickable.label, "radius": pickable.radius}
        for pickable in root.pickable_objects
        if pickable.is_particle
    }

    for run in tqdm(root.runs):
        # The tomogram is loaded only to size the mask volume.
        tomo = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy()
        target = np.zeros(tomo.shape, dtype=np.uint8)

        for object_name, paint in target_objects.items():
            pick = run.get_picks(object_name=object_name, user_id="curation")
            if len(pick):
                # Positional arguments are kept exactly as before; the painted
                # radius is 0.8x the configured particle radius.
                target = segmentation_from_picks.from_picks(
                    pick[0],
                    target,
                    paint["radius"] * 0.8,
                    paint["label"],
                )

        write.segmentation(
            run,
            target,
            copick_user_name,
            name=copick_segmentation_name,
            voxel_size=voxel_size,
        )

In [None]:
class Model(pl.LightningModule):
    """Lightning wrapper around a MONAI ``UNet`` for multi-class segmentation.

    Training minimises a softmax Tversky loss. Validation reports the Dice
    score over the non-background classes, aggregated over one validation
    epoch at a time.

    Args:
        spatial_dims: Forwarded to ``UNet``.
        in_channels: Forwarded to ``UNet``.
        out_channels: Forwarded to ``UNet``; also the one-hot depth used when
            computing the validation metric.
        channels: Forwarded to ``UNet``.
        strides: Forwarded to ``UNet``.
        num_res_units: Forwarded to ``UNet``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float = 1e-3,
    ):

        super().__init__()
        # Stores all constructor arguments under self.hparams.
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(
            include_background=True, to_onehot_y=True, softmax=True
        )  # softmax=True for multiclass
        self.metric_fn = DiceMetric(
            include_background=False, reduction="mean", ignore_empty=True
        )

        # Running sums used for the per-epoch console/log summaries.
        self.train_loss = 0
        self.val_metric = 0
        self.num_train_batch = 0
        self.num_val_batch = 0

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its raw output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the Tversky loss for one batch and return it for backprop."""
        x, y = batch["image"], batch["label"]
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # Accumulate a detached copy: summing the live loss tensor would keep
        # every step's autograd graph reachable until the end of the epoch.
        self.train_loss += loss.detach()
        self.num_train_batch += 1
        torch.cuda.empty_cache()
        return loss

    def on_train_epoch_end(self):
        """Print and log the mean training loss of the epoch, then reset."""
        # max(..., 1) keeps an epoch without batches from dividing by zero.
        loss_per_epoch = self.train_loss / max(self.num_train_batch, 1)
        print(f"Epoch {self.current_epoch} - Average Train Loss: {loss_per_epoch:.4f}")
        self.log("train_loss", loss_per_epoch, prog_bar=True)
        self.train_loss = 0
        self.num_train_batch = 0

    def validation_step(self, batch, batch_idx):
        """Add one batch to the Dice metric and return the running score.

        Returns:
            ``{"val_metric": tensor}`` where the tensor is the Dice score,
            averaged over classes, of all validation batches seen so far in
            the current validation epoch.
        """
        with torch.no_grad():  # This ensures that gradients are not stored in memory
            x, y = batch["image"], batch["label"]
            y_hat = self(x)

            # Build the post-processing transforms once per step instead of
            # once per sample.
            to_onehot_pred = AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)
            to_onehot_label = AsDiscrete(to_onehot=self.hparams.out_channels)
            metric_val_outputs = [to_onehot_pred(i) for i in decollate_batch(y_hat)]
            metric_val_labels = [to_onehot_label(i) for i in decollate_batch(y)]

            # The metric object buffers the per-sample scores; aggregate()
            # therefore reflects every batch since the last reset().
            self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
            metrics = self.metric_fn.aggregate(reduction="mean_batch")
            val_metric = torch.mean(
                metrics
            )  # I used mean over all particle species as the metric. This can be explored.
            self.val_metric = val_metric
            self.num_val_batch += 1
        torch.cuda.empty_cache()
        return {"val_metric": val_metric}

    def on_validation_epoch_end(self):
        """Log the epoch's Dice score and clear the metric buffer.

        The buffer must be reset here; otherwise scores from earlier epochs
        (and from the sanity-check run) keep contributing to later epochs.
        """
        if self.num_val_batch == 0:
            # Nothing was accumulated, so there is nothing to aggregate.
            return

        per_class = self.metric_fn.aggregate(reduction="mean_batch")
        metric_per_epoch = torch.mean(per_class)
        self.metric_fn.reset()

        print(
            f"Epoch {self.current_epoch} - Average Val Metric: {metric_per_epoch:.4f}"
        )
        self.log(
            "val_metric", metric_per_epoch, prog_bar=True, sync_dist=False
        )  # sync_dist=True for distributed training
        self.val_metric = 0
        self.num_val_batch = 0

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured lr."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class CopickDataModule(pl.LightningDataModule):
    """Lightning data module serving tomogram/segmentation pairs from copick.

    All runs of the copick project are loaded into memory on construction.
    The first ``num_training_dataset`` runs are used for training and the
    remaining runs for validation.

    Args:
        copick_config_path: Path of the copick config file to open.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_random_samples_per_batch: Number of random crops drawn per volume
            (``num_samples`` of ``RandCropByLabelClassesd``).
        num_training_dataset: Number of runs assigned to the training split.
    """

    def __init__(
        self,
        copick_config_path: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int,
        num_training_dataset: int = 5,  # the rest of the dataset is used for validation
    ):

        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path)
        self.train_files = self.data_dicts[:num_training_dataset]
        self.val_files = self.data_dicts[num_training_dataset:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose(
            [
                EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
                NormalizeIntensityd(keys="image"),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
            ]
        )

        # Random transforms to be applied during training
        self.random_transforms = Compose(
            [
                RandCropByLabelClassesd(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=[96, 96, 96],
                    num_classes=self.nclasses,
                    num_samples=num_random_samples_per_batch,
                ),
                RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
                RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            ]
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the training and validation datasets.

        The deterministic transforms are applied once and cached; the random
        transforms are layered on top and re-drawn on every access.
        """
        self.train_ds = CacheDataset(
            data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0
        )
        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
        self.val_ds = CacheDataset(
            data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0
        )
        # NOTE(review): validation also goes through the random crop/rotate/flip
        # pipeline, so validation scores vary between runs — confirm this is intended.
        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=1,
            persistent_workers=False,
            # pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            shuffle=False,  # Ensure the data order remains consistent
            num_workers=1,
            persistent_workers=False,
            # pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def data_from_copick(
        copick_config_path,
        copick_user_name="copickUtils",
        copick_segmentation_name="paintedPicks",
        tomo_type="denoised",
        voxel_size=10,
    ):
        """Load every run's tomogram and painted segmentation as numpy arrays.

        Args:
            copick_config_path: Path of the copick config file to open.
            copick_user_name: ``user_id`` of the segmentation to load.
            copick_segmentation_name: ``name`` of the segmentation to load.
            tomo_type: Tomogram variant requested via ``get_tomogram``.
            voxel_size: Voxel spacing used for both the tomogram and the
                segmentation lookup.

        Returns:
            ``(data_dicts, nclasses)``: a list with one
            ``{"image": ..., "label": ...}`` dict per run, and the number of
            pickable objects in the project plus one.

        Raises:
            IndexError: If a run has no matching segmentation, e.g. because
                the mask-generation step has not been run yet.
        """
        root = copick.from_file(copick_config_path)
        nclasses = len(root.pickable_objects) + 1

        data_dicts = []
        for run in tqdm(root.runs):
            tomogram = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy()
            segmentations = run.get_segmentations(
                user_id=copick_user_name,
                name=copick_segmentation_name,
                voxel_size=voxel_size,
                is_multilabel=True,
            )
            if not segmentations:
                # Same exception type the bare [0] lookup used to raise, but
                # with enough context to act on.
                raise IndexError(
                    f"Run {run.name!r} has no segmentation named "
                    f"{copick_segmentation_name!r} from user {copick_user_name!r} "
                    f"at voxel size {voxel_size}; generate the masks first."
                )
            data_dicts.append({"image": tomogram, "label": segmentations[0].numpy()})

        return data_dicts, nclasses

In [None]:
# Experiment logger; the argument is the MLflow experiment name.
mlf_logger = MLFlowLogger("training-3D-UNet-model-for-the-cryoET-ML-Challenge")
# Trainer callbacks: keep the single best checkpoint by validation metric and
# record the learning rate once per epoch.
checkpoint_callback = ModelCheckpoint(monitor="val_metric", save_top_k=1, mode="max")
lr_monitor = LearningRateMonitor(logging_interval="epoch")

# Count the usable GPUs. num_gpus is always defined (0 without CUDA) so the
# device list below can be built on CPU-only machines as well.
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
if num_gpus:
    print(f"Number of GPUs available: {num_gpus}")
else:
    print("No GPU available. Running on CPU.")
devices = list(range(num_gpus))
print(devices)

In [None]:
import monai

# Register MONAI's MetaTensor as an allowed global for torch.serialization.
# NOTE(review): presumably needed so that objects containing MetaTensor can be
# loaded when torch.load runs with weights_only=True — confirm.
torch.serialization.add_safe_globals([monai.data.meta_tensor.MetaTensor])

In [None]:
# Network and optimisation settings for this run.
channels = (48, 64, 80, 80)
strides_pattern = (2, 2, 1)
num_res_units = 1
learning_rate = 1e-4
num_epochs = 10

model = Model(
    channels=channels,
    strides=strides_pattern,
    num_res_units=num_res_units,
    lr=learning_rate,
)
datamodule = CopickDataModule(
    copick_config_path,
    train_batch_size=1,
    val_batch_size=1,
    num_random_samples_per_batch=16,
)


# Prioritize performance over precision
torch.set_float32_matmul_precision(
    "medium"
)  # or torch.set_float32_matmul_precision('high')

# Use notebook-compatible DDP across all detected GPUs; when the device list
# is empty, fall back to a single CPU process instead of failing on a
# hard-coded "gpu" accelerator.
use_gpu = len(devices) > 0
trainer = Trainer(
    max_epochs=num_epochs,
    logger=mlf_logger,
    callbacks=[checkpoint_callback, lr_monitor],
    strategy="ddp_notebook" if use_gpu else "auto",
    accelerator="gpu" if use_gpu else "cpu",
    devices=devices if use_gpu else 1,
    num_nodes=1,
    log_every_n_steps=1,
    enable_progress_bar=True,
)

trainer.fit(model, datamodule=datamodule)