From b5b0d4153798861e8536807b477155a1877862be Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Thu, 30 Apr 2026 16:41:45 +0200 Subject: [PATCH 1/3] feat: add iterators for NetworkDefinition::{inputs,outputs} --- trtx/src/network.rs | 114 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/trtx/src/network.rs b/trtx/src/network.rs index dc0f6bb..0ab3864 100644 --- a/trtx/src/network.rs +++ b/trtx/src/network.rs @@ -1357,6 +1357,24 @@ impl<'network> NetworkDefinition<'network> { self.output(index) } + /// Returns an iterator over the network's input tensors. + pub fn inputs(&self) -> NetworkInputIter<'_, 'network> { + NetworkInputIter { + network: self, + index: 0, + count: self.nb_inputs(), + } + } + + /// Returns an iterator over the network's output tensors. + pub fn outputs(&self) -> NetworkOutputIter<'_, 'network> { + NetworkOutputIter { + network: self, + index: 0, + count: self.nb_outputs(), + } + } + /// Number of layers in the network (for introspection/dumping). /// See [INetworkDefinition::getNbLayers] pub fn nb_layers(&self) -> i32 { @@ -2386,6 +2404,60 @@ impl<'network> NetworkDefinition<'network> { } } +// --- Network input/output iterators --- + +/// Iterator over a [`NetworkDefinition`]'s input tensors. Created by [`NetworkDefinition::inputs`]. +pub struct NetworkInputIter<'a, 'network> { + network: &'a NetworkDefinition<'network>, + index: i32, + count: i32, +} + +impl<'network> Iterator for NetworkInputIter<'_, 'network> { + type Item = Tensor<'network>; + fn next(&mut self) -> Option { + if self.index >= self.count { + return None; + } + let tensor = self.network.input(self.index).expect("valid input index"); + self.index += 1; + Some(tensor) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = (self.count - self.index).max(0) as usize; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for NetworkInputIter<'_, '_> {} + +/// Iterator over a [`NetworkDefinition`]'s output tensors. Created by [`NetworkDefinition::outputs`]. +pub struct NetworkOutputIter<'a, 'network> { + network: &'a NetworkDefinition<'network>, + index: i32, + count: i32, +} + +impl<'network> Iterator for NetworkOutputIter<'_, 'network> { + type Item = Tensor<'network>; + fn next(&mut self) -> Option { + if self.index >= self.count { + return None; + } + let tensor = self.network.output(self.index).expect("valid output index"); + self.index += 1; + Some(tensor) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = (self.count - self.index).max(0) as usize; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for NetworkOutputIter<'_, '_> {} + // --- IAttention --- impl<'network> Attention<'network> { @@ -3074,4 +3146,46 @@ mod test { ); assert_eq!(&network.layer(1).unwrap().name(&network), "Eva"); } + + #[test] + #[cfg(not(feature = "mock"))] + fn test_inputs_outputs_iter() { + let logger = Logger::stderr().unwrap(); + let mut builder = Builder::new(&logger).unwrap(); + let mut network = builder.create_network(0).unwrap(); + let a = network + .add_input("input_a", trtx_sys::DataType::kFLOAT, &[1]) + .unwrap(); + let b = network + .add_input("input_b", trtx_sys::DataType::kFLOAT, &[1]) + .unwrap(); + let out = network + .add_elementwise(&a, &b, trtx_sys::ElementWiseOperation::kSUM) + .unwrap() + .output(&network, 0) + .unwrap(); + out.set_name(&mut network, "output_c").unwrap(); + network.mark_output(&out); + + let input_names: Vec<_> = network + .inputs() + .map(|t| t.name(&network).unwrap()) + .collect(); + assert_eq!(input_names, ["input_a", "input_b"]); + assert_eq!(network.inputs().len(), 2); + + let output_names: Vec<_> = network + .outputs() + .map(|t| t.name(&network).unwrap()) + .collect(); + assert_eq!(output_names, ["output_c"]); + assert_eq!(network.outputs().len(), 1); + + // equivalent to the old loop pattern + let mut old_style = Vec::new(); + for i in 0..network.nb_inputs() { + old_style.push(network.input(i).unwrap().name(&network).unwrap()); + } + assert_eq!(input_names, old_style); + } } From 018612a9b319586d076a079f90e07db4a39dbb9f Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Thu, 30 Apr 2026 16:44:09 +0200 Subject: [PATCH 2/3] feat: add iterator for NetworkDefinition::layers --- trtx/src/network.rs | 58 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/trtx/src/network.rs b/trtx/src/network.rs index 0ab3864..e431fda 100644 --- a/trtx/src/network.rs +++ b/trtx/src/network.rs @@ -1391,6 +1391,15 @@ impl<'network> NetworkDefinition<'network> { DynLayer::new_dyn(self.inner.as_ptr(), layer_ptr) } + /// Returns an iterator over the network's layers. + pub fn layers(&self) -> NetworkLayerIter<'_, 'network> { + NetworkLayerIter { + network: self, + index: 0, + count: self.nb_layers(), + } + } + #[deprecated = "use layer instead"] pub fn get_layer(&self, layer_index: i32) -> Result> { self.layer(layer_index) @@ -2458,6 +2467,32 @@ impl<'network> Iterator for NetworkOutputIter<'_, 'network> { impl ExactSizeIterator for NetworkOutputIter<'_, '_> {} +/// Iterator over a [`NetworkDefinition`]'s layers. Created by [`NetworkDefinition::layers`]. +pub struct NetworkLayerIter<'a, 'network> { + network: &'a NetworkDefinition<'network>, + index: i32, + count: i32, +} + +impl<'network> Iterator for NetworkLayerIter<'_, 'network> { + type Item = DynLayer<'network>; + fn next(&mut self) -> Option { + if self.index >= self.count { + return None; + } + let layer = self.network.layer(self.index).expect("valid layer index"); + self.index += 1; + Some(layer) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = (self.count - self.index).max(0) as usize; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for NetworkLayerIter<'_, '_> {} + // --- IAttention --- impl<'network> Attention<'network> { @@ -3188,4 +3223,27 @@ mod test { } assert_eq!(input_names, old_style); } + + #[test] + #[cfg(not(feature = "mock"))] + fn test_layers_iter() { + use trtx_sys::LayerType; + + let logger = Logger::stderr().unwrap(); + let mut builder = Builder::new(&logger).unwrap(); + let mut network = builder.create_network(0).unwrap(); + let input = network + .add_input("a", trtx_sys::DataType::kFLOAT, &[1]) + .unwrap(); + network + .add_activation(&input, trtx_sys::ActivationType::kRELU) + .unwrap(); + network + .add_activation(&input, trtx_sys::ActivationType::kSIGMOID) + .unwrap(); + + assert_eq!(network.layers().len(), 2); + let types: Vec<_> = network.layers().map(|l| l.layer_type_dynamic()).collect(); + assert_eq!(types, [LayerType::kACTIVATION, LayerType::kACTIVATION]); + } } From d9db95fa47bbe850f034369ff0c60093a3ac667a Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Thu, 30 Apr 2026 16:53:10 +0200 Subject: [PATCH 3/3] feat: iterators for `CudaEngine::{input,output,io}_tensor_names` --- trtx/src/cuda_engine.rs | 95 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/trtx/src/cuda_engine.rs b/trtx/src/cuda_engine.rs index d9a908e..ba7333f 100644 --- a/trtx/src/cuda_engine.rs +++ b/trtx/src/cuda_engine.rs @@ -378,8 +378,60 @@ impl<'engine> CudaEngine<'engine> { let name_cstr = std::ffi::CString::new(name)?; Ok(unsafe { self.inner.isShapeInferenceIO(name_cstr.as_ptr()) }) } + + /// Returns an iterator over all IO tensor names. + pub fn io_tensor_names(&self) -> Result> { + Ok(CudaEngineIoTensorNamesIter { + engine: self, + index: 0, + count: self.nb_io_tensors()?, + }) + } + + /// Returns an iterator over input tensor names. + pub fn input_tensor_names(&self) -> Result + '_> { + Ok(self + .io_tensor_names()? + .filter(|name| self.tensor_io_mode(name).ok() == Some(TensorIOMode::kINPUT))) + } + + /// Returns an iterator over output tensor names. + pub fn output_tensor_names(&self) -> Result + '_> { + Ok(self + .io_tensor_names()? + .filter(|name| self.tensor_io_mode(name).ok() == Some(TensorIOMode::kOUTPUT))) + } +} + +/// Iterator over [`CudaEngine`] IO tensor names. Created by [`CudaEngine::io_tensor_names`]. +pub struct CudaEngineIoTensorNamesIter<'a> { + engine: &'a CudaEngine<'a>, + index: i32, + count: i32, +} + +impl Iterator for CudaEngineIoTensorNamesIter<'_> { + type Item = String; + fn next(&mut self) -> Option { + if self.index >= self.count { + return None; + } + let name = self + .engine + .io_tensor_name(self.index) + .expect("valid tensor index"); + self.index += 1; + Some(name) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = (self.count - self.index).max(0) as usize; + (remaining, Some(remaining)) + } } +impl ExactSizeIterator for CudaEngineIoTensorNamesIter<'_> {} + #[cfg(test)] #[cfg(not(feature = "mock_runtime"))] mod tests { @@ -405,6 +457,7 @@ mod tests { .unwrap() .output(&network, 0) .unwrap(); + tensor.set_name(&mut network, "output").unwrap(); network.mark_output(&tensor); let mut config = builder.create_config()?; @@ -415,6 +468,48 @@ mod tests { Ok(engine_data.to_vec()) } + #[test] + fn input_output_tensor_names_iter() { + let logger = Logger::stderr().expect("logger"); + let engine_data = + build_minimal_engine_with_verbose_profiling(&logger).expect("build engine"); + let mut runtime = Runtime::new(&logger).expect("runtime"); + let engine = runtime + .deserialize_cuda_engine(&engine_data) + .expect("deserialize"); + + let inputs: Vec<_> = engine.input_tensor_names().unwrap().collect(); + let outputs: Vec<_> = engine.output_tensor_names().unwrap().collect(); + + assert_eq!(inputs, ["input"]); + assert_eq!(outputs, ["output"]); + assert_eq!( + inputs.len() + outputs.len(), + engine.nb_io_tensors().unwrap() as usize + ); + } + + #[test] + fn io_tensor_names_iter() { + let logger = Logger::stderr().expect("logger"); + let engine_data = + build_minimal_engine_with_verbose_profiling(&logger).expect("build engine"); + let mut runtime = Runtime::new(&logger).expect("runtime"); + let engine = runtime + .deserialize_cuda_engine(&engine_data) + .expect("deserialize"); + + let names: Vec<_> = engine.io_tensor_names().unwrap().collect(); + assert_eq!(engine.io_tensor_names().unwrap().len(), names.len()); + + // equivalent to the old loop pattern + let mut old_style = Vec::new(); + for i in 0..engine.nb_io_tensors().unwrap() { + old_style.push(engine.io_tensor_name(i).unwrap()); + } + assert_eq!(names, old_style); + } + #[test] fn engine_inspector_json_verbose_profiling() { let logger = Logger::stderr().expect("logger");