-
Notifications
You must be signed in to change notification settings - Fork 383
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
92 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
burn-import/pytorch-tests/tests/missing_module_field/export_weights.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
class Model(nn.Module): | ||
def __init__(self): | ||
super(Model, self).__init__() | ||
self.conv1 = nn.Conv2d(2, 2, (2,2)) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
return x | ||
|
||
|
||
def main(): | ||
torch.set_printoptions(precision=8) | ||
torch.manual_seed(1) | ||
model = Model().to(torch.device("cpu")) | ||
torch.save(model.state_dict(), "missing_module_field.pt") | ||
|
||
if __name__ == '__main__': | ||
main() |
Binary file added
BIN
+1.66 KB
burn-import/pytorch-tests/tests/missing_module_field/missing_module_field.pt
Binary file not shown.
30 changes: 30 additions & 0 deletions
30
burn-import/pytorch-tests/tests/missing_module_field/mod.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend}; | ||
|
||
#[derive(Module, Debug)] | ||
pub struct Net<B: Backend> { | ||
do_not_exist_in_pt: Conv2d<B>, | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
type Backend = burn_ndarray::NdArray<f32>; | ||
|
||
use burn::record::{FullPrecisionSettings, Recorder}; | ||
use burn_import::pytorch::PyTorchFileRecorder; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
#[should_panic( | ||
expected = "Missing source values for the 'do_not_exist_in_pt' field of type 'Conv2dRecordItem'. Please verify the source data and ensure the field name is correct" | ||
)] | ||
fn should_fail_if_struct_field_is_missing() { | ||
let device = Default::default(); | ||
let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default() | ||
.load( | ||
"tests/missing_module_field/missing_module_field.pt".into(), | ||
&device, | ||
) | ||
.expect("Should decode state successfully"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,5 +17,6 @@ cfg_if::cfg_if! { | |
mod key_remap_chained; | ||
mod layer_norm; | ||
mod linear; | ||
mod missing_module_field; | ||
} | ||
} |