Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions trtx/src/cuda_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,60 @@ impl<'engine> CudaEngine<'engine> {
let name_cstr = std::ffi::CString::new(name)?;
Ok(unsafe { self.inner.isShapeInferenceIO(name_cstr.as_ptr()) })
}

/// Returns an iterator over the names of every IO (input and output) tensor
/// of this engine, in engine index order.
///
/// # Errors
/// Fails if the underlying tensor count cannot be queried.
pub fn io_tensor_names(&self) -> Result<CudaEngineIoTensorNamesIter<'_>> {
    // Query the total up front so iteration itself is infallible.
    let count = self.nb_io_tensors()?;
    Ok(CudaEngineIoTensorNamesIter {
        engine: self,
        index: 0,
        count,
    })
}

/// Returns an iterator over input tensor names.
///
/// Names whose IO mode cannot be queried are silently skipped, matching
/// the behavior of filtering on a successful `kINPUT` lookup only.
pub fn input_tensor_names(&self) -> Result<impl Iterator<Item = String> + '_> {
    let names = self.io_tensor_names()?;
    Ok(names.filter(|name| matches!(self.tensor_io_mode(name), Ok(TensorIOMode::kINPUT))))
}

/// Returns an iterator over output tensor names.
///
/// Names whose IO mode cannot be queried are silently skipped, matching
/// the behavior of filtering on a successful `kOUTPUT` lookup only.
pub fn output_tensor_names(&self) -> Result<impl Iterator<Item = String> + '_> {
    let names = self.io_tensor_names()?;
    Ok(names.filter(|name| matches!(self.tensor_io_mode(name), Ok(TensorIOMode::kOUTPUT))))
}
}

/// Iterator over [`CudaEngine`] IO tensor names. Created by [`CudaEngine::io_tensor_names`].
pub struct CudaEngineIoTensorNamesIter<'a> {
    // Engine whose tensor names are being enumerated.
    engine: &'a CudaEngine<'a>,
    // Next tensor index to yield; advances from 0 toward `count`.
    index: i32,
    // Total number of IO tensors, captured once at iterator creation.
    count: i32,
}

impl Iterator for CudaEngineIoTensorNamesIter<'_> {
    type Item = String;

    /// Yields the name at the current cursor position, then advances.
    fn next(&mut self) -> Option<Self::Item> {
        (self.index < self.count).then(|| {
            let current = self.index;
            self.index += 1;
            // Indices below `count` were reported valid by the engine.
            self.engine
                .io_tensor_name(current)
                .expect("valid tensor index")
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `count`, so this conversion cannot underflow.
        let remaining = usize::try_from(self.count - self.index).unwrap_or(0);
        (remaining, Some(remaining))
    }
}

// The exact remaining length is always known, so `len()` is available.
impl ExactSizeIterator for CudaEngineIoTensorNamesIter<'_> {}

#[cfg(test)]
#[cfg(not(feature = "mock_runtime"))]
mod tests {
Expand All @@ -405,6 +457,7 @@ mod tests {
.unwrap()
.output(&network, 0)
.unwrap();
tensor.set_name(&mut network, "output").unwrap();
network.mark_output(&tensor);

let mut config = builder.create_config()?;
Expand All @@ -415,6 +468,48 @@ mod tests {
Ok(engine_data.to_vec())
}

#[test]
fn input_output_tensor_names_iter() {
    // Build and deserialize a minimal engine with one input and one output.
    let logger = Logger::stderr().expect("logger");
    let data = build_minimal_engine_with_verbose_profiling(&logger).expect("build engine");
    let mut runtime = Runtime::new(&logger).expect("runtime");
    let engine = runtime
        .deserialize_cuda_engine(&data)
        .expect("deserialize");

    let input_names: Vec<String> = engine.input_tensor_names().unwrap().collect();
    let output_names: Vec<String> = engine.output_tensor_names().unwrap().collect();

    assert_eq!(input_names, ["input"]);
    assert_eq!(output_names, ["output"]);
    // Inputs plus outputs must account for every IO tensor.
    let total = engine.nb_io_tensors().unwrap() as usize;
    assert_eq!(total, input_names.len() + output_names.len());
}

#[test]
fn io_tensor_names_iter() {
    // Build and deserialize a minimal engine to enumerate its IO tensors.
    let logger = Logger::stderr().expect("logger");
    let data = build_minimal_engine_with_verbose_profiling(&logger).expect("build engine");
    let mut runtime = Runtime::new(&logger).expect("runtime");
    let engine = runtime
        .deserialize_cuda_engine(&data)
        .expect("deserialize");

    let collected: Vec<String> = engine.io_tensor_names().unwrap().collect();
    // ExactSizeIterator's len() must agree with the collected count.
    assert_eq!(engine.io_tensor_names().unwrap().len(), collected.len());

    // equivalent to the old loop pattern
    let manual: Vec<String> = (0..engine.nb_io_tensors().unwrap())
        .map(|i| engine.io_tensor_name(i).unwrap())
        .collect();
    assert_eq!(collected, manual);
}

#[test]
fn engine_inspector_json_verbose_profiling() {
let logger = Logger::stderr().expect("logger");
Expand Down
172 changes: 172 additions & 0 deletions trtx/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,24 @@ impl<'network> NetworkDefinition<'network> {
self.output(index)
}

/// Returns an iterator over the network's input tensors, in index order.
pub fn inputs(&self) -> NetworkInputIter<'_, 'network> {
    // Capture the count once so the iterator is self-contained.
    let count = self.nb_inputs();
    NetworkInputIter {
        network: self,
        index: 0,
        count,
    }
}

/// Returns an iterator over the network's output tensors, in index order.
pub fn outputs(&self) -> NetworkOutputIter<'_, 'network> {
    // Capture the count once so the iterator is self-contained.
    let count = self.nb_outputs();
    NetworkOutputIter {
        network: self,
        index: 0,
        count,
    }
}

/// Number of layers in the network (for introspection/dumping).
/// See [INetworkDefinition::getNbLayers]
pub fn nb_layers(&self) -> i32 {
Expand All @@ -1373,6 +1391,15 @@ impl<'network> NetworkDefinition<'network> {
DynLayer::new_dyn(self.inner.as_ptr(), layer_ptr)
}

/// Returns an iterator over the network's layers, in index order.
pub fn layers(&self) -> NetworkLayerIter<'_, 'network> {
    // Capture the count once so the iterator is self-contained.
    let count = self.nb_layers();
    NetworkLayerIter {
        network: self,
        index: 0,
        count,
    }
}

#[deprecated = "use layer instead"]
pub fn get_layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
self.layer(layer_index)
Expand Down Expand Up @@ -2386,6 +2413,86 @@ impl<'network> NetworkDefinition<'network> {
}
}

// --- Network input/output iterators ---

/// Iterator over a [`NetworkDefinition`]'s input tensors. Created by [`NetworkDefinition::inputs`].
pub struct NetworkInputIter<'a, 'network> {
    // Network whose inputs are being enumerated.
    network: &'a NetworkDefinition<'network>,
    // Next input index to yield; advances from 0 toward `count`.
    index: i32,
    // Total number of inputs, captured once at iterator creation.
    count: i32,
}

impl<'network> Iterator for NetworkInputIter<'_, 'network> {
    type Item = Tensor<'network>;

    /// Yields the input tensor at the current cursor position, then advances.
    fn next(&mut self) -> Option<Self::Item> {
        (self.index < self.count).then(|| {
            let current = self.index;
            self.index += 1;
            // Indices below `count` were reported valid by the network.
            self.network.input(current).expect("valid input index")
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `count`, so this conversion cannot underflow.
        let remaining = usize::try_from(self.count - self.index).unwrap_or(0);
        (remaining, Some(remaining))
    }
}

// The exact remaining length is always known, so `len()` is available.
impl ExactSizeIterator for NetworkInputIter<'_, '_> {}

/// Iterator over a [`NetworkDefinition`]'s output tensors. Created by [`NetworkDefinition::outputs`].
pub struct NetworkOutputIter<'a, 'network> {
    // Network whose outputs are being enumerated.
    network: &'a NetworkDefinition<'network>,
    // Next output index to yield; advances from 0 toward `count`.
    index: i32,
    // Total number of outputs, captured once at iterator creation.
    count: i32,
}

impl<'network> Iterator for NetworkOutputIter<'_, 'network> {
    type Item = Tensor<'network>;

    /// Yields the output tensor at the current cursor position, then advances.
    fn next(&mut self) -> Option<Self::Item> {
        (self.index < self.count).then(|| {
            let current = self.index;
            self.index += 1;
            // Indices below `count` were reported valid by the network.
            self.network.output(current).expect("valid output index")
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `count`, so this conversion cannot underflow.
        let remaining = usize::try_from(self.count - self.index).unwrap_or(0);
        (remaining, Some(remaining))
    }
}

// The exact remaining length is always known, so `len()` is available.
impl ExactSizeIterator for NetworkOutputIter<'_, '_> {}

/// Iterator over a [`NetworkDefinition`]'s layers. Created by [`NetworkDefinition::layers`].
pub struct NetworkLayerIter<'a, 'network> {
    // Network whose layers are being enumerated.
    network: &'a NetworkDefinition<'network>,
    // Next layer index to yield; advances from 0 toward `count`.
    index: i32,
    // Total number of layers, captured once at iterator creation.
    count: i32,
}

impl<'network> Iterator for NetworkLayerIter<'_, 'network> {
    type Item = DynLayer<'network>;

    /// Yields the layer at the current cursor position, then advances.
    fn next(&mut self) -> Option<Self::Item> {
        (self.index < self.count).then(|| {
            let current = self.index;
            self.index += 1;
            // Indices below `count` were reported valid by the network.
            self.network.layer(current).expect("valid layer index")
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `count`, so this conversion cannot underflow.
        let remaining = usize::try_from(self.count - self.index).unwrap_or(0);
        (remaining, Some(remaining))
    }
}

// The exact remaining length is always known, so `len()` is available.
impl ExactSizeIterator for NetworkLayerIter<'_, '_> {}

// --- IAttention ---

impl<'network> Attention<'network> {
Expand Down Expand Up @@ -3074,4 +3181,69 @@ mod test {
);
assert_eq!(&network.layer(1).unwrap().name(&network), "Eva");
}

#[test]
#[cfg(not(feature = "mock"))]
fn test_inputs_outputs_iter() {
    // Build a tiny network: two inputs summed into one marked output.
    let logger = Logger::stderr().unwrap();
    let mut builder = Builder::new(&logger).unwrap();
    let mut network = builder.create_network(0).unwrap();
    let lhs = network
        .add_input("input_a", trtx_sys::DataType::kFLOAT, &[1])
        .unwrap();
    let rhs = network
        .add_input("input_b", trtx_sys::DataType::kFLOAT, &[1])
        .unwrap();
    let sum = network
        .add_elementwise(&lhs, &rhs, trtx_sys::ElementWiseOperation::kSUM)
        .unwrap()
        .output(&network, 0)
        .unwrap();
    sum.set_name(&mut network, "output_c").unwrap();
    network.mark_output(&sum);

    let input_names: Vec<_> = network
        .inputs()
        .map(|t| t.name(&network).unwrap())
        .collect();
    assert_eq!(network.inputs().len(), 2);
    assert_eq!(input_names, ["input_a", "input_b"]);

    let output_names: Vec<_> = network
        .outputs()
        .map(|t| t.name(&network).unwrap())
        .collect();
    assert_eq!(network.outputs().len(), 1);
    assert_eq!(output_names, ["output_c"]);

    // equivalent to the old loop pattern
    let manual: Vec<_> = (0..network.nb_inputs())
        .map(|i| network.input(i).unwrap().name(&network).unwrap())
        .collect();
    assert_eq!(input_names, manual);
}

#[test]
#[cfg(not(feature = "mock"))]
fn test_layers_iter() {
    use trtx_sys::LayerType;

    // Build a network with one input feeding two activation layers.
    let logger = Logger::stderr().unwrap();
    let mut builder = Builder::new(&logger).unwrap();
    let mut network = builder.create_network(0).unwrap();
    let input = network
        .add_input("a", trtx_sys::DataType::kFLOAT, &[1])
        .unwrap();
    for activation in [
        trtx_sys::ActivationType::kRELU,
        trtx_sys::ActivationType::kSIGMOID,
    ] {
        network.add_activation(&input, activation).unwrap();
    }

    assert_eq!(network.layers().len(), 2);
    let kinds: Vec<_> = network.layers().map(|l| l.layer_type_dynamic()).collect();
    assert_eq!(kinds, [LayerType::kACTIVATION, LayerType::kACTIVATION]);
}
}
Loading