
[tfjs-react-native] tf.loadGraphModel with bundled files slow  #5475

@Caundy

Description


Packages installed

  • "react-native": "0.63.4",
  • "@tensorflow/tfjs": "^3.8.0",
  • "@tensorflow/tfjs-automl": "^1.2.0",
  • "@tensorflow/tfjs-react-native": "^0.6.0",
  • "expo-gl": "^10.4.2",
  • "expo-gl-cpp": "^10.4.1",
  • Tensorflow.js Converter Version: v3.6.0 (as reported by the model file)

Describe the current behavior
Loading a graph model (13MB) bundled in a react-native application (bare) takes an excessive amount of time on Android devices, upwards of 60s.

Model
We have trained an image classification model using Google Vision and exported it as a Tensorflow.js package using their dashboard. After downloading it from the Google Vision dashboard we haven't modified the model in any way.
The model.json file weighs 167KB. The weights are sharded into 4 files (named group1-shardxof4.bin), three of which weigh 4.2MB and the last 400KB, totalling 13MB.

The top three lines of the model.json file read:

"format": "graph-model",
"generatedBy": "2.7.0",
"convertedBy": "TensorFlow.js Converter v3.6.0",

Init code
As for the model initialization in the application, we make sure that tensorflow is ready by running and awaiting tf.ready() early on in the application's lifecycle and making sure it resolves successfully before loading the model.
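The early-init pattern described above can be sketched as follows. Note that tf is mocked here so the snippet runs standalone; in the app it is the import from '@tensorflow/tfjs', and the ensureTfReady helper name is purely illustrative:

```javascript
// Sketch of awaiting tf.ready() once, early in the app lifecycle.
const tf = { ready: async () => {} }; // assumption: stand-in for the real '@tensorflow/tfjs' import

let tfIsReady = false;

// Called (and awaited) before any model loading happens.
async function ensureTfReady() {
  if (!tfIsReady) {
    await tf.ready(); // resolves once a tfjs backend has initialized
    tfIsReady = true;
  }
  return tfIsReady;
}
```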

In the file where classification happens, we then import the necessary libraries, require the bundled model files and loadGraphModel, such as:

import * as tf from '@tensorflow/tfjs';
import { bundleResourceIO } from '@tensorflow/tfjs-react-native';

const modelJson = require('../../../assets/model/model.json');
const modelWeights1 = require('../../../assets/model/group1-shard1of4.bin');
const modelWeights2 = require('../../../assets/model/group1-shard2of4.bin');
const modelWeights3 = require('../../../assets/model/group1-shard3of4.bin');
const modelWeights4 = require('../../../assets/model/group1-shard4of4.bin');

// Basic model setup, code wrapped in a functional component which I'm not including here
const setupModel = async () => {
  try {
    // Create the GraphModel from the bundled resources
    const ioHandler = bundleResourceIO(modelJson, [
      modelWeights1,
      modelWeights2,
      modelWeights3,
      modelWeights4,
    ]);
    const graphModel = await tf.loadGraphModel(ioHandler);

    // Save model for further use
    // ...
  } catch (modelSetupError) {
    // ...
  }
};

useEffect(() => {
  setupModel();
}, []);

graphModel is later used to create and store an automl.ImageClassificationModel, which isn't relevant here.
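The load times reported below could be captured with a small timing wrapper like this one (a sketch; loadFn stands in for () => tf.loadGraphModel(ioHandler), which is mocked here so the snippet runs on its own):

```javascript
// Hypothetical helper for measuring how long an async loader takes to resolve.
async function timeLoad(loadFn) {
  const start = Date.now();
  const model = await loadFn();
  const elapsedMs = Date.now() - start;
  return { model, elapsedMs };
}

// Usage with a stand-in loader (in the app: () => tf.loadGraphModel(ioHandler)):
timeLoad(async () => ({ predict: () => null }))
  .then(({ elapsedMs }) => console.log(`model loaded in ${elapsedMs}ms`));
```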

The issue
The issue is the tf.loadGraphModel method, which takes upwards of 60s to resolve on slightly older Android devices such as the Nexus 5x, making the app interface completely unresponsive in the meantime.

Running the application as a built release APK resulted in:

  • Xiaomi Redmi 7: ~24s to load the model,
  • Nexus 5x: ~60s to load the model,
  • Samsung A10: ~78s to load the model.

For comparison:

  • run locally on an iPhone 7 or 8 Plus: ~3s to load the model,
  • archived, downloaded from TestFlight and run on an iPhone 7 or 8 Plus: ~9s,
  • serving the model files from a locally run node server and loading the model over HTTP via automl.loadImageClassification(modelUrl): ~17s to load the model on a Nexus 5x.

Describe the expected behavior
I would expect the loadGraphModel method to resolve faster than the reported times, given that the model is fairly simple (only ~3k images were used to train it). That said, it's hard for me to judge whether the observed load times are reasonable, and I'd love for someone to let me know what kind of load performance could be expected from a 13MB model.

Also, if anyone could point me to how the model load time could be optimized in react-native, any steps we may have missed that would affect the load time, or a better approach to using the Google Vision-trained (AutoML) model in react-native, I'd greatly appreciate it 🙂

Additionally, is there any proven way of loading the model without completely blocking the js thread while it happens?

Let me know if any additional information would be helpful to resolving the issue 🙂
