```diff
[burn]% git diff
diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs
index 61e62e8a..6e575575 100644
--- a/crates/burn-import/src/pytorch/reader.rs
+++ b/crates/burn-import/src/pytorch/reader.rs
@@ -47,8 +47,18 @@ where
         .map(|(key, tensor)| (key, CandleTensor(tensor)))
         .collect();

+    println!("SOURCE KEYS");
+    let mut keys: Vec<String> = tensors.keys().cloned().collect();
+    keys.sort();
+    keys.iter().for_each(|k| println!("{}", k));
+
     // Remap the keys (replace the keys in the map with the new keys)
     let tensors = remap(tensors, key_remap);

+    println!("==================================================================================================");
+    println!("REMAPPED KEYS");
+    let mut keys: Vec<String> = tensors.keys().cloned().collect();
+    keys.sort();
+    keys.iter().for_each(|k| println!("{}", k));
     // Convert the vector of Candle tensors to a nested value data structure
     let nested_value = unflatten::<PS, _>(tensors)?;
```
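The two key-printing blocks in the diff differ only in their title, so they could share one helper. Below is a minimal, self-contained sketch of that refactor. `sorted_keys` and `print_keys` are hypothetical names, and the `u8` values stand in for the `CandleTensor` wrappers used in `reader.rs`. Printing only needs the keys, so the value type is left generic.

```rust
use std::collections::HashMap;

/// Returns the keys of `tensors` in sorted order. The value type is generic
/// because listing keys never touches the tensors themselves.
fn sorted_keys<T>(tensors: &HashMap<String, T>) -> Vec<String> {
    let mut keys: Vec<String> = tensors.keys().cloned().collect();
    keys.sort();
    keys
}

/// Prints `title`, then one sorted key per line (same output shape as the diff).
fn print_keys<T>(title: &str, tensors: &HashMap<String, T>) {
    println!("{}", title);
    sorted_keys(tensors).iter().for_each(|k| println!("{}", k));
}

fn main() {
    // Stand-in values; the real map holds tensors, not bytes.
    let mut tensors: HashMap<String, u8> = HashMap::new();
    tensors.insert("fc.weight".to_string(), 0);
    tensors.insert("conv.bias".to_string(), 0);
    print_keys("SOURCE KEYS", &tensors);
}
```

Running this prints `SOURCE KEYS`, then `conv.bias`, then `fc.weight`. Sorting matters because `HashMap` iteration order is unspecified, so unsorted output would differ from run to run.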
Feature description
Add the ability to print the tensor keys before and after remapping when loading with PyTorchFileRecorder, similar to this:
#1413 (comment)
Feature motivation
Remapping the keys is currently a blind exercise. It would be helpful to see the keys both before and after the remap.
Suggest a Solution
Pass a debug flag to PyTorchFileRecorder that enables printing the keys before and after the remap.
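A minimal sketch of how such a flag could gate the printing follows. It uses only the standard library. Every name here is hypothetical and not existing burn API: `LoadOptions`, its `debug` field, `with_key_remap`, `with_debug_print` and `load_keys`. Simple prefix replacement stands in for the recorder's real pattern-based remapping.

```rust
use std::collections::HashMap;

/// Hypothetical load options; `debug` models the proposed flag.
#[derive(Default)]
struct LoadOptions {
    /// (from_prefix, to_prefix) rules; a stand-in for pattern-based remapping.
    key_remap: Vec<(String, String)>,
    debug: bool,
}

impl LoadOptions {
    /// Adds a rule that replaces the key prefix `from` with `to`.
    fn with_key_remap(mut self, from: &str, to: &str) -> Self {
        self.key_remap.push((from.to_string(), to.to_string()));
        self
    }

    /// Enables printing of the keys before and after remapping.
    fn with_debug_print(mut self) -> Self {
        self.debug = true;
        self
    }
}

/// Prints `title`, then the keys of `tensors` in sorted order.
fn print_keys<T>(title: &str, tensors: &HashMap<String, T>) {
    let mut keys: Vec<&String> = tensors.keys().collect();
    keys.sort();
    println!("{}", title);
    keys.iter().for_each(|k| println!("{}", k));
}

/// Applies the first matching prefix rule to each key; unmatched keys are kept.
fn remap<T>(tensors: HashMap<String, T>, rules: &[(String, String)]) -> HashMap<String, T> {
    tensors
        .into_iter()
        .map(|(key, value)| {
            let new_key = match rules.iter().find(|(from, _)| key.starts_with(from.as_str())) {
                Some((from, to)) => format!("{}{}", to, &key[from.len()..]),
                None => key,
            };
            (new_key, value)
        })
        .collect()
}

/// Remaps the keys, printing them before and after only when `debug` is set.
fn load_keys<T>(tensors: HashMap<String, T>, options: &LoadOptions) -> HashMap<String, T> {
    if options.debug {
        print_keys("SOURCE KEYS", &tensors);
    }
    let tensors = remap(tensors, &options.key_remap);
    if options.debug {
        print_keys("REMAPPED KEYS", &tensors);
    }
    tensors
}

fn main() {
    let mut tensors: HashMap<String, u8> = HashMap::new();
    tensors.insert("conv.0.weight".to_string(), 0);
    tensors.insert("conv.0.bias".to_string(), 1);
    let options = LoadOptions::default()
        .with_key_remap("conv.0.", "conv1.")
        .with_debug_print();
    let remapped = load_keys(tensors, &options);
    assert!(remapped.contains_key("conv1.weight"));
}
```

Keeping the flag off by default means normal loads stay silent, and users opt in only while working out their remap rules.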