Ability to print tensor keys (before and after remap) for PyTorchFileRecorder #1416

Closed
antimora opened this issue Mar 5, 2024 · 2 comments · Fixed by #1425
Labels
enhancement Enhance existing features

Comments

@antimora
Collaborator

antimora commented Mar 5, 2024

Feature description

Ability to print tensor keys (before and after remapping) for PyTorchFileRecorder, similar to what is shown in #1413 (comment).

Feature motivation

Remapping the keys is currently a blind exercise. It would be helpful to see the keys before and after remapping.

Suggest a Solution

Pass a debug flag to PyTorchFileRecorder.
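A minimal sketch of how such a flag could sit next to the key-remap rules, assuming a builder-style options struct. `LoadOptions`, `with_key_remap`, `with_debug`, and `remap_key` are hypothetical names, not the burn-import API, and plain prefix matching stands in for whatever pattern syntax the recorder actually uses.

```rust
/// Hypothetical load options (illustrative names, not the burn-import API).
#[derive(Debug, Default, Clone)]
pub struct LoadOptions {
    /// Ordered (old_prefix, new_prefix) rules. Assumption: plain prefix
    /// matching stands in for the recorder's real remapping patterns.
    pub key_remap: Vec<(String, String)>,
    /// When true, a loader would print the keys before and after remapping.
    pub debug: bool,
}

impl LoadOptions {
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a remap rule; rules are tried in insertion order.
    pub fn with_key_remap(mut self, from: &str, to: &str) -> Self {
        self.key_remap.push((from.to_string(), to.to_string()));
        self
    }

    /// Turns on debug printing.
    pub fn with_debug(mut self) -> Self {
        self.debug = true;
        self
    }

    /// Applies the first rule whose prefix matches; unmatched keys pass through.
    pub fn remap_key(&self, key: &str) -> String {
        for (from, to) in &self.key_remap {
            if let Some(rest) = key.strip_prefix(from.as_str()) {
                return format!("{to}{rest}");
            }
        }
        key.to_string()
    }
}
```

Keeping the flag on the same options value as the remap rules means turning on debugging needs no change to the loading call itself.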

@antimora
Collaborator Author

antimora commented Mar 5, 2024

Here is a debugging hack:

```diff
[burn]% git diff
diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs
index 61e62e8a..6e575575 100644
--- a/crates/burn-import/src/pytorch/reader.rs
+++ b/crates/burn-import/src/pytorch/reader.rs
@@ -47,8 +47,18 @@ where
         .map(|(key, tensor)| (key, CandleTensor(tensor)))
         .collect();

+    println!("SOURCE KEYS");
+    let mut keys: Vec<String> = tensors.keys().cloned().collect();
+    keys.sort();
+    keys.iter().for_each(|k| println!("{}", k));
+
     // Remap the keys (replace the keys in the map with the new keys)
     let tensors = remap(tensors, key_remap);
+    println!("==================================================================================================");
+    println!("REMAPPED KEYS");
+    let mut keys: Vec<String> = tensors.keys().cloned().collect();
+    keys.sort();
+    keys.iter().for_each(|k| println!("{}", k));

     // Convert the vector of Candle tensors to a nested value data structure
     let nested_value = unflatten::<PS, _>(tensors)?;
[burn]%
```
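The hack above repeats the sort-and-print loop twice. A minimal sketch using only the standard library that factors the listing into one helper, gates it behind a `debug` flag, and writes to any `io::Write` so the output can be captured instead of always going to stdout. `write_sorted_keys` and `remap_with_debug` are hypothetical helpers, and the `remap` closure is an assumed stand-in for the reader's own `remap` function.

```rust
use std::collections::HashMap;
use std::io::{self, Write};

/// Writes a title followed by the map's keys in sorted order,
/// mirroring the `println!` loop in the diff.
fn write_sorted_keys<V, W: Write>(
    out: &mut W,
    title: &str,
    map: &HashMap<String, V>,
) -> io::Result<()> {
    writeln!(out, "{title}")?;
    let mut keys: Vec<&String> = map.keys().collect();
    keys.sort();
    for k in keys {
        writeln!(out, "{k}")?;
    }
    Ok(())
}

/// Renames every key with `remap` (a stand-in for the reader's remap step)
/// and, when `debug` is set, reports the keys before and after.
/// Note: if two keys remap to the same name, `collect` keeps only one value.
fn remap_with_debug<V, W, F>(
    tensors: HashMap<String, V>,
    remap: F,
    debug: bool,
    out: &mut W,
) -> io::Result<HashMap<String, V>>
where
    W: Write,
    F: Fn(&str) -> String,
{
    if debug {
        write_sorted_keys(out, "SOURCE KEYS", &tensors)?;
    }
    let remapped: HashMap<String, V> = tensors
        .into_iter()
        .map(|(k, v)| (remap(&k), v))
        .collect();
    if debug {
        writeln!(out, "{}", "=".repeat(40))?;
        write_sorted_keys(out, "REMAPPED KEYS", &remapped)?;
    }
    Ok(remapped)
}
```

Printing both lists also makes key collisions visible: if two source keys map to the same name, the remapped list comes out one entry shorter.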

@antimora antimora added the enhancement Enhance existing features label Mar 5, 2024
@antimora
Collaborator Author

antimora commented Mar 6, 2024

I submitted a PR (#1425) to improve usability and debugging. You can now print the original keys, the remapped keys, and each tensor's shape and dtype.
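A sketch of the kind of per-tensor report described above: one line per tensor with the original key, the remapped key, the shape, and the dtype, sorted by remapped key. `TensorInfo` and `format_report` are hypothetical and do not reproduce the types or output format actually added in #1425.

```rust
/// Hypothetical per-tensor metadata gathered while loading.
struct TensorInfo {
    original_key: String,
    remapped_key: String,
    shape: Vec<usize>,
    dtype: &'static str,
}

/// Formats one line per tensor, sorted by remapped key, for example
/// `conv.weight -> conv1.weight [8, 3, 3, 3] f32`.
fn format_report(infos: &[TensorInfo]) -> String {
    // Sort references so the caller's slice is left untouched.
    let mut sorted: Vec<&TensorInfo> = infos.iter().collect();
    sorted.sort_by(|a, b| a.remapped_key.cmp(&b.remapped_key));
    sorted
        .iter()
        .map(|t| {
            format!(
                "{} -> {} {:?} {}",
                t.original_key, t.remapped_key, t.shape, t.dtype
            )
        })
        .collect::<Vec<_>>()
        .join("\n")
}
```

Showing shape and dtype next to each key pair helps confirm that a remapped key lines up with the module field it is meant to fill.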

@antimora antimora self-assigned this Mar 6, 2024