
Regression: cannot load state.json.gz: "Loading error: Can't load Option<Tensor> from NamedState" #213

Closed
antimora opened this issue Mar 8, 2023 · 0 comments · Fixed by #218


antimora commented Mar 8, 2023

Describe the bug

thread 'web::tests::inference' panicked at 'State could not be loaded into model:
LoadingError { message: "Can't load module conv1: Loading error:
Can't load module conv: Loading error: Can't load module bias:
Loading error: Can't load Option<Tensor> from NamedState" }',
examples/mnist-inference-web/src/state.rs:37:10

To Reproduce
Here is the trained model:
model-6.json.gz

Here is the code:

State loading:

use crate::model::Model;

use burn::module::Module;
use burn::module::State;
use burn_ndarray::NdArrayBackend;

pub type Backend = NdArrayBackend<f32>;

const MODEL_STATE_FILE_NAME: &str = "model-6.json.gz";

pub fn build_and_load_model() -> Model<Backend> {
    let model: Model<Backend> = Model::new();

    let state: State<f32> =
        State::load(MODEL_STATE_FILE_NAME).expect("Model JSON file could not be loaded");

    model
        .load(&state)
        .expect("State could not be loaded into model")
}
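
As a first diagnostic (a sketch, not part of the report above), round-tripping a freshly initialized model's own state in memory separates a regression in state loading from a problem with the saved model-6.json.gz file. This assumes the Module trait in this burn version exposes state() as the counterpart to load():

use crate::model::Model;

use burn::module::Module;
use burn_ndarray::NdArrayBackend;

pub type Backend = NdArrayBackend<f32>;

pub fn roundtrip_in_memory() -> Model<Backend> {
    let model: Model<Backend> = Model::new();

    // Extract the untrained model's state and load it straight back.
    // If this also panics, the regression is in state loading itself,
    // not in the gzipped JSON file.
    let state = model.state();
    model
        .load(&state)
        .expect("in-memory state round-trip failed")
}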

And here is the model:

use alloc::{format, vec::Vec};

use burn::{
    module::{Module, Param},
    nn,
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Param<ConvBlock<B>>,
    conv2: Param<ConvBlock<B>>,
    conv3: Param<ConvBlock<B>>,
    dropout: nn::Dropout,
    fc1: Param<nn::Linear<B>>,
    fc2: Param<nn::Linear<B>>,
    activation: nn::GELU,
}

const NUM_CLASSES: usize = 10;

impl<B: Backend> Model<B> {
    pub fn new() -> Self {
        let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch, 8, 26, 26]
        let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch, 16, 24, 24]
        let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch, 24, 22, 22]

        let fc1 = nn::Linear::new(&nn::LinearConfig::new(24 * 22 * 22, 32).with_bias(false));
        let fc2 = nn::Linear::new(&nn::LinearConfig::new(32, NUM_CLASSES).with_bias(false));

        let dropout = nn::Dropout::new(&nn::DropoutConfig::new(0.3));

        Self {
            conv1: Param::from(conv1),
            conv2: Param::from(conv2),
            conv3: Param::from(conv3),
            fc1: Param::from(fc1),
            fc2: Param::from(fc2),
            dropout,
            activation: nn::GELU::new(),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = input.dims();

        let x = input.reshape([batch_size, 1, height, width]).detach();
        let x = self.conv1.forward(x);
        let x = self.conv2.forward(x);
        let x = self.conv3.forward(x);

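        // Three 3x3 valid convolutions take the 28x28 MNIST input to
        // 26x26 -> 24x24 -> 22x22, so the flattened size is 24 * 22 * 22.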
        let x = x.reshape([batch_size, 24 * 22 * 22]);

        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        let x = self.dropout.forward(x);

        self.fc2.forward(x)
    }
}

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Param<nn::conv::Conv2d<B>>,
    activation: nn::GELU,
}

impl<B: Backend> ConvBlock<B> {
    pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
        let conv = nn::conv::Conv2d::new(
            &nn::conv::Conv2dConfig::new(channels, kernel_size).with_bias(false),
        );

        Self {
            conv: Param::from(conv),
            activation: nn::GELU::new(),
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(input);
        self.activation.forward(x)
    }
}
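
Every Conv2d above is built with with_bias(false), so its bias field is None: an Option<Tensor> with no tensor behind it, which is exactly the leaf the error message points at. A minimal sketch (same assumptions as above) that isolates the failure to a single bias-less layer:

use burn::module::Module;
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn_ndarray::NdArrayBackend;

type Backend = NdArrayBackend<f32>;

fn roundtrip_biasless_conv() {
    // with_bias(false) leaves `bias` as None, so the state carries no
    // tensor for it; loading must tolerate that absence.
    let conv: Conv2d<Backend> =
        Conv2d::new(&Conv2dConfig::new([1, 8], [3, 3]).with_bias(false));

    let state = conv.state();
    let _conv = conv
        .load(&state)
        .expect("Can't load Option<Tensor> from NamedState");
}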

Expected behavior
No errors: the same state file loaded successfully before the recent changes.

Additional context
