Skip to content

Commit

Permalink
BUG: Make dict_from_transform more consistent with other dict represe…
Browse files Browse the repository at this point in the history
…ntations

Encapsulate the "transformParameterization", "parametersValueType",
"inputDimension", "outputDimension", which are static, into a
"transformType" member, similar to "imageType", "meshType" for Images,
Meshes. "transformParameterization" is the string name of the ITK transform
class, less the trailing "Transform".

Remove unused numpy import in __setstate__.

Support both a list of transforms and a single transform in
dict_from_transform.

Add more smoke tests in extras.py.
  • Loading branch information
thewtex committed May 1, 2024
1 parent 126e2ec commit 5bb32f4
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 42 deletions.
Expand Up @@ -35,16 +35,13 @@

keys_to_test1 = [
"name",
"parametersValueType",
"transformType",
"inputDimension",
"outputDimension",
"inputSpaceName",
"outputSpaceName",
"numberOfParameters",
"numberOfFixedParameters",
]
keys_to_test2 = ["parameters", "fixedParameters"]
keys_to_test3 = ["transformParameterization", "parametersValueType", "inputDimension", "outputDimension"]

transform_object_list = []
for i, transform_type in enumerate(transforms_to_test):
Expand All @@ -60,6 +57,8 @@
# Test all the parameters
for k in keys_to_test2:
assert np.array_equal(serialize_deserialize[k], transform[k])
for k in keys_to_test3:
assert serialize_deserialize["transformType"][k], transform["transformType"][k]
transform_object_list.append(transform)

print("Individual Transforms Test Done")
Expand Down Expand Up @@ -93,6 +92,9 @@
for k in keys_to_test2:
assert np.array_equal(transform_obj[k], transform_object_list[i][k])

for k in keys_to_test3:
assert serialize_deserialize["transformType"][k], transform["transformType"][k]


# Test for transformation using de-serialized BSpline Transform
ImageDimension = 2
Expand Down
3 changes: 1 addition & 2 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Expand Up @@ -430,7 +430,7 @@ str = str
Return keys related to the transform's metadata.
These keys are used in the dictionary resulting from dict(transform).
"""
result = ['name', 'inputDimension', 'outputDimension', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
result = ['transformType', 'name', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
return result
def __getitem__(self, key):
Expand Down Expand Up @@ -474,7 +474,6 @@ str = str
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
import itk
import numpy as np
deserialized = itk.transform_from_dict(state)
self.__dict__['this'] = deserialized
%}
Expand Down
5 changes: 5 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Expand Up @@ -343,6 +343,11 @@ def custom_callback(name, progress):
parameters = np.asarray(transforms[0].GetParameters())
assert np.allclose(parameters, np.array(baseline_additional_transform_params))

transform_dict = itk.dict_from_transform(transforms[0])
transform_back = itk.transform_from_dict(transform_dict)
transform_dict = itk.dict_from_transform(transforms)
transform_back = itk.transform_from_dict(transform_dict)

# pipeline, auto_pipeline and templated class are tested in other files

# BridgeNumPy
Expand Down
84 changes: 48 additions & 36 deletions Wrapping/Generators/Python/itk/support/extras.py
Expand Up @@ -981,57 +981,66 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
)


def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
def dict_from_transform(transform: Union["itkt.TransformBase", List["itkt.TransformBase"]]) -> List[Dict]:
import itk
datatype_dict = {"double": itk.D, "float": itk.F}

def update_transform_dict(current_transform):
current_transform_type = current_transform.GetTransformTypeAsString()
current_transform_type_split = current_transform_type.split("_")
component = itk.template(current_transform)

in_transform_dict = dict()
in_transform_dict["name"] = current_transform.GetObjectName()
transform_type = dict()
transform_parameterization = current_transform_type_split[0].replace("Transform", "")
transform_type["transformParameterization"] = transform_parameterization

datatype_dict = {"double": itk.D, "float": itk.F}
in_transform_dict["parametersValueType"] = python_to_js(
transform_type["parametersValueType"] = python_to_js(
datatype_dict[current_transform_type_split[1]]
)
in_transform_dict["inputDimension"] = int(current_transform_type_split[2])
in_transform_dict["outputDimension"] = int(current_transform_type_split[3])
in_transform_dict["transformType"] = current_transform_type_split[0]
transform_type["inputDimension"] = int(current_transform_type_split[2])
transform_type["outputDimension"] = int(current_transform_type_split[3])

in_transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
in_transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()
transform_dict = dict()
transform_dict['transformType'] = transform_type
transform_dict["name"] = current_transform.GetObjectName()

transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()

# To avoid copying the parameters for the Composite Transform
# as it is a copy of child transforms.
if "Composite" not in current_transform_type_split[0]:
p = np.array(current_transform.GetParameters())
in_transform_dict["parameters"] = p
transform_dict["parameters"] = p

fp = np.array(current_transform.GetFixedParameters())
in_transform_dict["fixedParameters"] = fp
transform_dict["fixedParameters"] = fp

in_transform_dict["numberOfParameters"] = p.shape[0]
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
transform_dict["numberOfParameters"] = p.shape[0]
transform_dict["numberOfFixedParameters"] = fp.shape[0]

return in_transform_dict
return transform_dict

dict_array = []
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
def add_transform_dict(transform):
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
else:
dict_array.append(update_transform_dict(transform))
if isinstance(transform, list):
for t in transform:
add_transform_dict(t)
else:
dict_array.append(update_transform_dict(transform))
add_transform_dict(transform)

return dict_array


def transform_from_dict(transform_dict: Dict) -> "itkt.TransformBase":
def transform_from_dict(transform_dict: List[Dict]) -> "itkt.TransformBase":
import itk

def set_parameters(transform, transform_parameters, transform_fixed_parameters, data_type):
Expand All @@ -1058,32 +1067,35 @@ def special_transform_check(transform_name):
# Loop over all the transforms in the dictionary
transforms_list = []
for i, _ in enumerate(transform_dict):
data_type = parametersValueType_dict[transform_dict[i]["parametersValueType"]]
transform_type = transform_dict[i]["transformType"]
data_type = parametersValueType_dict[transform_type["parametersValueType"]]

transform_parameterization = transform_type["transformParameterization"] + 'Transform'

# No template parameter needed for transforms having 2D or 3D name
# Also for some selected transforms
if special_transform_check(transform_dict[i]["transformType"]):
transform_template = getattr(itk, transform_dict[i]["transformType"])
if special_transform_check(transform_parameterization):
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[data_type].New()
# Currently only BSpline Transform has 3 template parameters
# For future extensions the information will have to be encoded in
# the transformType variable. The transform object once added in a
# composite transform lose the information for other template parameters ex. BSpline.
# The Spline order is fixed as 3 here.
elif transform_dict[i]["transformType"] == "BSplineTransform":
transform_template = getattr(itk, transform_dict[i]["transformType"])
elif transform_parameterization == "BSplineTransform":
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], 3
data_type, transform_type["inputDimension"], 3
].New()
else:
transform_template = getattr(itk, transform_dict[i]["transformType"])
transform_template = getattr(itk, transform_parameterization)
if len(transform_template.items()[0][0]) > 2:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], transform_dict[i]["outputDimension"]
data_type, transform_type["inputDimension"], transform_type["outputDimension"]
].New()
else:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"]
data_type, transform_type["inputDimension"]
].New()

transform.SetObjectName(transform_dict[i]["name"])
Expand All @@ -1102,8 +1114,8 @@ def special_transform_check(transform_name):
if len(transforms_list) > 1:
# Create a Composite Transform object
# and add all the transforms in it.
data_type = parametersValueType_dict[transform_dict[0]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]['inputDimension']].New()
data_type = parametersValueType_dict[transform_dict[0]["transformType"]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]["transformType"]['inputDimension']].New()
for current_transform in transforms_list:
transform.AddTransform(current_transform)
else:
Expand Down

0 comments on commit 5bb32f4

Please sign in to comment.