You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
usecrate::model::Model;use burn::module::Module;use burn::module::State;use burn_ndarray::NdArrayBackend;pubtypeBackend = NdArrayBackend<f32>;constMODEL_STATE_FILE_NAME:&str = "model-6.json.gz";pubfnbuild_and_load_model() -> Model<Backend>{let model:Model<Backend> = Model::new();let state:State<f32> =
State::load(MODEL_STATE_FILE_NAME).expect(concat!("Model JSON file could not be loaded"));
model
.load(&state).expect("State could not be loaded into model")}
And here is the model:
use alloc::{format, vec::Vec};use burn::{
module::{Module,Param},
nn,
tensor::{backend::Backend,Tensor},};#[derive(Module,Debug)]pubstructModel<B:Backend>{conv1:Param<ConvBlock<B>>,conv2:Param<ConvBlock<B>>,conv3:Param<ConvBlock<B>>,dropout: nn::Dropout,fc1:Param<nn::Linear<B>>,fc2:Param<nn::Linear<B>>,activation: nn::GELU,}constNUM_CLASSES:usize = 10;impl<B:Backend>Model<B>{pubfnnew() -> Self{let conv1 = ConvBlock::new([1,8],[3,3]);// out: [Batch,1,26,26]let conv2 = ConvBlock::new([8,16],[3,3]);// out: [Batch,1,24x24]let conv3 = ConvBlock::new([16,24],[3,3]);// out: [Batch,1,22x22]let fc1 = nn::Linear::new(&nn::LinearConfig::new(24*22*22,32).with_bias(false));let fc2 = nn::Linear::new(&nn::LinearConfig::new(32,NUM_CLASSES).with_bias(false));let dropout = nn::Dropout::new(&nn::DropoutConfig::new(0.3));Self{conv1:Param::from(conv1),conv2:Param::from(conv2),conv3:Param::from(conv3),fc1:Param::from(fc1),fc2:Param::from(fc2),
dropout,activation: nn::GELU::new(),}}pubfnforward(&self,input:Tensor<B,3>) -> Tensor<B,2>{let[batch_size, heigth, width] = input.dims();let x = input.reshape([batch_size,1, heigth, width]).detach();let x = self.conv1.forward(x);let x = self.conv2.forward(x);let x = self.conv3.forward(x);let x = x.reshape([batch_size,24*22*22]);let x = self.fc1.forward(x);let x = self.activation.forward(x);let x = self.dropout.forward(x);let x = self.fc2.forward(x);
x
}}#[derive(Module,Debug)]pubstructConvBlock<B:Backend>{conv:Param<nn::conv::Conv2d<B>>,activation: nn::GELU,}impl<B:Backend>ConvBlock<B>{pubfnnew(channels:[usize;2],kernel_size:[usize;2]) -> Self{let conv = nn::conv::Conv2d::new(&nn::conv::Conv2dConfig::new(channels, kernel_size).with_bias(false),);Self{conv:Param::from(conv),activation: nn::GELU::new(),}}pubfnforward(&self,input:Tensor<B,4>) -> Tensor<B,4>{let x = self.conv.forward(input);self.activation.forward(x)}}
Expected behavior
No errors. It worked before recent changes.
Additional context
The text was updated successfully, but these errors were encountered:
Describe the bug
To Reproduce
Here is the trained model:
model-6.json.gz
Here is the code:
State loading:
And here is the model:
Expected behavior
No errors. It worked before recent changes.
Additional context
The text was updated successfully, but these errors were encountered: