
Shape mismatch in a DenseNet121 converted from PyTorch through Keras #824

@kidzik

TensorFlow.js version

0.13.2

Browser version

Firefox 62.0.3 64-bit
Chrome Version 68.0.3440.84 (Official Build) (64-bit)
Ubuntu 18.04

Describe the problem or feature request

A DenseNet121 model converted from PyTorch to Keras, and then from Keras to TensorFlow.js, throws an error when it is loaded in TensorFlow.js:
Error: Shape mismatch: [null,1000] vs. [1024,1000]
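Both shapes in the error match DenseNet121's final classifier. torchvision defines it as nn.Linear(1024, 1000), so its weight matrix covers 1024 inputs and 1000 classes, and the layer's output is (batch, 1000), which Keras reports with an unknown batch dimension as [null, 1000]. The NumPy sketch below only illustrates these shapes with zero arrays; it does not touch the real model. It assumes the usual layout conventions: PyTorch Linear stores (out_features, in_features), and a Keras Dense kernel is (input_dim, units).

```python
import numpy as np

# Shapes of DenseNet121's final classifier (torchvision: nn.Linear(1024, 1000)).
IN_FEATURES, NUM_CLASSES = 1024, 1000

# PyTorch nn.Linear stores its weight as (out_features, in_features).
torch_weight = np.zeros((NUM_CLASSES, IN_FEATURES), dtype=np.float32)
torch_bias = np.zeros((NUM_CLASSES,), dtype=np.float32)

# A Keras Dense kernel is laid out as (input_dim, units), i.e. the transpose.
keras_kernel = torch_weight.T

# A batch of 4 pooled feature vectors; Keras shows this batch dimension as None/null.
batch = np.zeros((4, IN_FEATURES), dtype=np.float32)
logits = batch @ keras_kernel + torch_bias

print("kernel shape:", keras_kernel.shape)  # kernel shape: (1024, 1000)
print("output shape:", logits.shape)        # output shape: (4, 1000)
```

One possible reading of the error is that an output-shaped value ([null,1000]) is being compared with the kernel-shaped weight ([1024,1000]). This is only a hypothesis that the shapes are consistent with, not a confirmed diagnosis.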

Code to reproduce the bug / link to feature request

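Loading the model on the JS side:
import * as tf from '@tensorflow/tfjs';
async function loadDenseNet() {  // tf.loadModel returns a Promise, so await it inside an async function
  return await tf.loadModel('https://s3-eu-west-1.amazonaws.com/kidzinski/densenet121/model.json');
}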
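To find out which stored weight has the [1024,1000] shape, you can read the weightsManifest section of a local copy of model.json. The minimal Python sketch below assumes the layout the tfjs converter writes: a top-level "weightsManifest" list of groups, where each group has a "weights" list of {name, shape, dtype} entries. The helper names list_weight_shapes and find_weights_with_shape, and the local path DenseNet121/model.json, are hypothetical.

```python
import json
import os


def list_weight_shapes(model_json_path):
    """Return (name, shape) pairs for every weight listed in a tfjs model.json manifest."""
    with open(model_json_path) as f:
        model = json.load(f)
    shapes = []
    # Each manifest group names its shard files ("paths") and the weights stored in them.
    for group in model.get("weightsManifest", []):
        for weight in group.get("weights", []):
            shapes.append((weight["name"], list(weight["shape"])))
    return shapes


def find_weights_with_shape(model_json_path, shape):
    """Return the names of weights whose stored shape equals `shape` exactly."""
    target = list(shape)
    return [name for name, s in list_weight_shapes(model_json_path) if s == target]


if __name__ == "__main__":
    # Assumed local path: the directory written by save_keras_model(k_model, "DenseNet121").
    path = os.path.join("DenseNet121", "model.json")
    if os.path.exists(path):
        for name, shape in list_weight_shapes(path):
            print(name, shape)
        print("weights shaped [1024, 1000]:", find_weights_with_shape(path, [1024, 1000]))
```

Comparing the printed names with the layer names in the model topology should show which layer the loader is checking when the mismatch is raised.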

With package.json:

{
  "name": "tfjs-examples-mobilenet",
  "version": "0.1.0",
  "description": "",
  "main": "index.js",
  "license": "Apache-2.0",
  "private": true,
  "engines": {
    "node": ">=8.9.0"
  },
  "dependencies": {
    "@tensorflow/tfjs": "^0.13.2",
    "express": "^4.16.4"
  },
  "scripts": {
    "watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open",
    "build": "cross-env NODE_ENV=production parcel build index.html  --no-minify --public-url ./",
    "link-local": "yalc link",
    "assets": "cp -R public/ dist/public/",
    "postinstall": "yarn upgrade --pattern @tensorflow"
  },
  "devDependencies": {
    "babel-plugin-transform-runtime": "~6.23.0",
    "babel-polyfill": "~6.26.0",
    "babel-preset-env": "~1.6.1",
    "clang-format": "~1.2.2",
    "cross-env": "^5.1.6",
    "parcel-bundler": "~1.8.1",
    "yalc": "~1.0.0-pre.22"
  }
}
Conversion from PyTorch to Keras (this is how I created the model at the link above):
import torchvision
import numpy as np
import tensorflowjs as tfjs
import torch
from torch.autograd import Variable
from pytorch2keras.converter import pytorch_to_keras

# Load the ImageNet-pretrained DenseNet121 from torchvision, in float32.
pretrained_model = torchvision.models.densenet121(pretrained=True)
pretrained_model = pretrained_model.float()
# pytorch2keras runs the model on a dummy input to build its graph, so use the expected NCHW size.
img_size = 224
input_np = np.random.uniform(0, 1, (1, 3, img_size, img_size))
input_dummy = Variable(torch.FloatTensor(input_np))
k_model = pytorch_to_keras(pretrained_model, input_dummy, [(3, img_size, img_size,)], verbose=True, names="short")
# Write model.json and the binary weight shards into the DenseNet121/ directory.
tfjs.converters.save_keras_model(k_model, "DenseNet121")
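Before saving, you can run the same shape check on the Keras side. The sketch below relies only on standard Keras Layer attributes (name, output_shape, get_weights()). The helpers summarize_layers and layers_with_weight_shape are hypothetical, and output_shape assumes that each layer has a single inbound node, which is normally true for a freshly converted functional model.

```python
def summarize_layers(model):
    """Collect (layer name, output shape, weight shapes) for each layer of a Keras model."""
    rows = []
    for layer in model.layers:
        # get_weights() returns NumPy arrays; keep only their shapes as tuples.
        weight_shapes = [tuple(w.shape) for w in layer.get_weights()]
        rows.append((layer.name, layer.output_shape, weight_shapes))
    return rows


def layers_with_weight_shape(model, shape):
    """Return the names of layers that hold at least one weight of exactly `shape`."""
    target = tuple(shape)
    return [name for name, _, shapes in summarize_layers(model) if target in shapes]
```

Usage, right after the conversion step: `for row in summarize_layers(k_model): print(row)` and `layers_with_weight_shape(k_model, (1024, 1000))`. If the Keras model already pairs these shapes consistently, the problem is more likely in the Keras-to-tfjs step than in pytorch2keras.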

With the following conda environment:

# packages in environment at /home/kidzik/anaconda3/envs/tensorflow:
#
# Name                    Version                   Build  Channel
blas                      1.0                         mkl  
ca-certificates           2018.03.07                    0  
certifi                   2018.10.15               py36_0  
cffi                      1.11.5           py36he75722e_1  
cuda92                    1.0                           0    pytorch
cycler                    0.10.0                    <pip>
dill                      0.2.8.2                   <pip>
freetype                  2.9.1                h8a8886c_1  
h5py                      2.8.0                     <pip>
intel-openmp              2019.0                      118  
jpeg                      9b                   h024ee3a_2  
jupyter                   1.0.0                     <pip>
Keras                     2.2.2                     <pip>
Keras-Applications        1.0.4                     <pip>
Keras-Preprocessing       1.0.2                     <pip>
kiwisolver                1.0.1                     <pip>
libedit                   3.1.20170329         h6b74fdf_2  
libffi                    3.2.1                hd88cf55_4  
libgcc-ng                 8.2.0                hdf63c60_1  
libgfortran-ng            7.3.0                hdf63c60_0  
libpng                    1.6.35               hbc83047_0  
libstdcxx-ng              8.2.0                hdf63c60_1  
libtiff                   4.0.9                he85c1e1_2  
matplotlib                3.0.0                     <pip>
mkl                       2019.0                      118  
mkl_fft                   1.0.6            py36h7dd41cf_0  
mkl_random                1.0.1            py36h4414c95_1  
ncurses                   6.1                  hf484d3e_0  
ninja                     1.8.2            py36h6bb024c_1  
numpy                     1.15.3           py36h1d66e8a_0  
numpy                     1.14.1                    <pip>
numpy-base                1.15.3           py36h81de0dd_0  
olefile                   0.46                     py36_0  
onnx                      1.3.0                     <pip>
onnx-tf                   1.2.0                     <pip>
opencv-python             3.4.3.18                  <pip>
openssl                   1.1.1                h7b6447c_0  
pillow                    5.3.0            py36h34e0f95_0  
pip                       10.0.1                   py36_0  
pycparser                 2.19                     py36_0  
pyparsing                 2.2.2                     <pip>
python                    3.6.7                h0371630_0  
pytorch                   0.4.1           py36_cuda9.2.148_cudnn7.1.4_1  [cuda92]  pytorch
pytorch2keras             0.1.5                     <pip>
PyYAML                    3.13                      <pip>
readline                  7.0                  h7b6447c_5  
scikit-learn              0.20.0                    <pip>
scipy                     1.1.0                     <pip>
setuptools                39.1.0                    <pip>
setuptools                40.4.3                   py36_0  
six                       1.11.0                   py36_1  
sqlite                    3.25.2               h7b6447c_0  
tensorboard               1.9.0                     <pip>
tensorflow                1.9.0                     <pip>
tensorflow-hub            0.1.1                     <pip>
tensorflowjs              0.6.4                     <pip>
tk                        8.6.8                hbc83047_0  
torch                     0.4.0                     <pip>
torchvision               0.2.1                    py36_1    pytorch
typing                    3.6.6                     <pip>
typing-extensions         3.6.6                     <pip>
wheel                     0.32.2                   py36_0  
xz                        5.2.4                h14c3975_4  
zlib                      1.2.11               ha838bed_2  

Labels: type:bug (Something isn't working)