# Installation
Install PyTorch for your hardware (CUDA GPU / Apple MPS / CPU) first, following https://pytorch.org/get-started/locally/



In [None]:
try:
    import deepali
    import TPTBox
except Exception:
    %pip install TPTBox ruamel.yaml configargparse
    %pip install hf-deepali
    %pip install nnunetv2
from pathlib import Path
from typing import Literal

import torch
from TPTBox import NII, POI_Global, to_nii
from TPTBox.core.vert_constants import Full_Body_Instance


In [None]:
#############################################
################# Input #####################
#############################################

# New image
target_ct_file:str|Path|None = None
target_seg_file:str|Path = "data/sub-right_seg-TotalVibe-12_msk.nii.gz"
target_out_poi:str|Path = "data/seg-right_desc-atlas_poi.json"# output
target_out_subdivided:str|Path = "data/seg-right_desc-atlas_msk.nii.gz" #output
mirror = "right" in str(target_seg_file).lower() #False # Set True for mapping a left to a right structure. Also update mapping_mirror, if you have costume segmentation


# Atlas
atlas_seg_file:str|Path = "data/sub-atlas_seg-TotalVibe-12_msk.nii.gz" # default is a left leg
atlas_file:str|Path = "data/sub-atlas_seg-poi_poi.json"
atlas_seg_subdivided_file:str|Path|None =None

# Device settings

ddevice: Literal["cpu", "cuda", "mps"] = "cuda"
gpu=0 # Only used for cuda


In [None]:
def resolve_device(
    ddevice: Literal["cpu", "cuda", "mps"],
    gpu: int = 0,
) -> torch.device:
    """Map a requested backend name to a usable ``torch.device``.

    Falls back to the CPU, printing a warning, when the requested backend is
    not available or the backend name is not recognised.

    Args:
        ddevice: Requested backend: ``"cpu"``, ``"cuda"`` or ``"mps"``.
        gpu: CUDA device index; ignored for the other backends.

    Returns:
        The resolved device (``cuda:<gpu>``, ``mps`` or ``cpu``).

    Raises:
        ValueError: If CUDA is available but ``gpu`` is not a visible device index.
    """
    if ddevice == "cuda":
        if torch.cuda.is_available():
            n_gpus = torch.cuda.device_count()
            # Fail with a clear message instead of an opaque CUDA ordinal error.
            if not 0 <= gpu < n_gpus:
                raise ValueError(
                    f"gpu={gpu} is out of range: {n_gpus} CUDA device(s) visible"
                )
            torch.cuda.set_device(gpu)
            return torch.device(f"cuda:{gpu}")
        print("⚠️ CUDA requested but not available → falling back to CPU")
    elif ddevice == "mps":
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return torch.device("mps")
        print("⚠️ MPS requested but not available → falling back to CPU")
    elif ddevice != "cpu":
        # Typos such as "CUDA" or "gpu" would otherwise silently run on the CPU.
        print(f"⚠️ Unknown device {ddevice!r} → falling back to CPU")

    return torch.device("cpu")


# Resolve the configured backend (with CPU fallback) and report the choice.
device = resolve_device(gpu=gpu, ddevice=ddevice)
print("✅ Using device: " + str(device))

In [None]:
# Registration hyper-parameters, passed straight to Template_Registration.
lr: float = 1e-3
max_steps: int = 1500
min_delta: float = 1e-6
pyramid_levels: int = 4
coarsest_level: int = 3
finest_level: int = 0
# Loss-term weights keyed by loss name.
weights: dict[str, float] = {"be": 1e-5, "seg": 1, "Dice": 0.01, "Tether": 0.001}


# Label IDs (from the segmentation model) that take part in the registration.
# For your own segmentation, replace them with the label IDs you want to use;
# plain integers work here as well.
leg_ids = [
    Full_Body_Instance.femur_left,
    Full_Body_Instance.patella_left,
    Full_Body_Instance.tibia_left,
    Full_Body_Instance.fibula_left,
]
# Label mapping applied to the target when mirroring: swaps the left and right
# leg IDs in both directions. Adapt it to your own labels, or set it to an empty
# dict if mirroring is not needed.
mapping_mirror = {
    source.value: destination.value
    for source, destination in (
        (Full_Body_Instance.femur_right, Full_Body_Instance.femur_left),
        (Full_Body_Instance.patella_right, Full_Body_Instance.patella_left),
        (Full_Body_Instance.tibia_right, Full_Body_Instance.tibia_left),
        (Full_Body_Instance.fibula_right, Full_Body_Instance.fibula_left),
        (Full_Body_Instance.femur_left, Full_Body_Instance.femur_right),
        (Full_Body_Instance.patella_left, Full_Body_Instance.patella_right),
        (Full_Body_Instance.tibia_left, Full_Body_Instance.tibia_right),
        (Full_Body_Instance.fibula_left, Full_Body_Instance.fibula_right),
    )
}

## Generate Segmentation


In [None]:
# Only run the segmentation model when the target segmentation does not exist yet.
# Explicit raises instead of `assert`, so validation also runs under `python -O`.
if not Path(target_seg_file).exists():
    if target_ct_file is None:
        raise ValueError("Provide either a CT or a segmentation.")
    target_ct_file = Path(target_ct_file)
    if not target_ct_file.exists():
        raise FileNotFoundError(
            f"Provide either a CT or a segmentation. CT not found: {target_ct_file}"
        )
    from TPTBox.segmentation import run_vibeseg

    # FIXME Check if model 12 is now available
    # Use the device settings from the config cell instead of hard-coding GPU 0 / CUDA.
    run_vibeseg(target_ct_file, target_seg_file, gpu=gpu, ddevice=ddevice, dataset_id=11)

In [None]:
from TPTBox.registration import Template_Registration

# Load the atlas (moving) segmentation and the target segmentation.
moving_img = to_nii(atlas_seg_file, True)
target = to_nii(target_seg_file, True)

# Swap left/right label IDs so a mirrored target uses the atlas's label IDs.
if mirror:
    target = target.map_labels(mapping_mirror)
# Restrict the target to the labels used for registration.
print(target.unique(), moving_img.unique())
seg = target.extract_label(leg_ids, True)
print("unique", seg.unique(), moving_img.unique())

# Run Template_Registration
reg = Template_Registration(
    seg,  # Target segmentation
    moving_img.extract_label(leg_ids, True),  # Starting atlas segmentation (not the subdivided one)
    same_side=not mirror,
    lr=lr,
    max_steps=max_steps,
    min_delta=min_delta,
    pyramid_levels=pyramid_levels,
    coarsest_level=coarsest_level,
    finest_level=finest_level,
    # loss_terms=loss_terms,
    # poi_target_cms=None,
    # poi_cms=poi_atlas_cms,  # Can be None, then it will be computed automatically
    weights=weights,
    gpu=gpu,  # honour the configured GPU index instead of always using GPU 0
    ddevice=ddevice,
)

# Transfer the atlas points of interest to the target.
print("Transfer atlas to target")
poi_in = POI_Global.load(atlas_file)
atlas_reg = reg.transform_poi(poi_in)
atlas_reg.info = poi_in.info
# Write the markup export (presumably a 3D Slicer ".mrk.json" file) to its own
# path. Writing it to `target_out_poi` would be overwritten by the POI save below.
target_out_mrk = str(Path(target_out_poi).with_suffix(".mrk.json"))
atlas_reg.to_global().save_mrk(target_out_mrk)
atlas_reg.save(target_out_poi)
if atlas_seg_subdivided_file is not None:
    # Transfer the atlas subdivisions as well.
    n = reg.transform_nii(to_nii(atlas_seg_subdivided_file, True))
    n.save(target_out_subdivided)


In [None]:
from treg.angle import compute_angles

# Derive the angle output path from `target_out_poi` instead of hard-coding a
# "seg-right_..." name, so runs on other targets do not overwrite each other.
# For the default config this gives the same path:
# "data/seg-right_desc-atlas_angle.json".
# The second argument is presumably where compute_angles writes its result.
_poi_path = Path(target_out_poi)
_angle_stem = _poi_path.name.removesuffix(".json").removesuffix("_poi")
target_out_angle = _poi_path.with_name(f"{_angle_stem}_angle.json")
print(compute_angles(atlas_reg.to_global(), str(target_out_angle)))