Inconsistent Results on torchvision pretrained models using Python script vs C++ API #29194

Closed
Con-Mi opened this issue Nov 5, 2019 · 1 comment
Labels
module: cpp Related to C++ API


Con-Mi commented Nov 5, 2019

I am comparing a Python script that uses a pretrained classification model from torchvision against an equivalent implementation using the C++ API, to make sure the model will work in my application.
I used the following script in Python:

import torch
from torchvision import transforms
from PIL import Image


_IMAGE_FILENAME = "../images/goldfish.jpg"
_MODEL_JIT_FILENAME = "../jit_models_bin/traced_mnasnet0_5.pt"

model = torch.jit.load(_MODEL_JIT_FILENAME)

tfm = transforms.Compose([ 
    transforms.Resize([ 224, 224 ]), 
    transforms.ToTensor(), 
    transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) 
    ])

image = Image.open(_IMAGE_FILENAME)
image = tfm(image)
image = image.unsqueeze(dim=0)


output = model(image)
output = torch.softmax(output, 1)

prob_value, index = torch.max(output, 1)

print("Probability Value: ")
print(prob_value)
print("ImageNet Index: ")
print(index)

And I used the following code in C++:

#include <iostream>
#include <vector>
#include <string>

#include <torch/torch.h>
#include <torch/script.h>

#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc.hpp>


int main() {

    const cv::String _IMAGE_FILENAME = "../images/goldfish.jpg";
    const std::string _MODEL_JIT_FILENAME= "../jit_models_bin/traced_mnasnet0_5.pt";

    cv::Mat img = cv::imread( _IMAGE_FILENAME, cv::IMREAD_UNCHANGED );
    cv::Size rsz = { 224, 224 };

    cv::resize( img, img, rsz, 0, 0, cv::INTER_LINEAR );
    img.convertTo( img, CV_32FC3, 1/255.0 );

    at::Tensor tensorImage = torch::from_blob(img.data, { 1, img.rows, img.cols, 3 }, at::kFloat);
    tensorImage = tensorImage.permute({0, 3, 1, 2});

    //  Normalize data
    tensorImage[0][0] = tensorImage[0][0].sub(0.485).div(0.229);
    tensorImage[0][1] = tensorImage[0][1].sub(0.456).div(0.224);
    tensorImage[0][2] = tensorImage[0][2].sub(0.406).div(0.225);

    std::vector<torch::jit::IValue> input;
    input.push_back(tensorImage);

    torch::jit::script::Module model = torch::jit::load( _MODEL_JIT_FILENAME );

    at::Tensor output = torch::softmax(model.forward(input).toTensor(), 1);

    std::tuple<at::Tensor, at::Tensor> result = torch::max(output, 1);

    std::cout << "Probability Value: " << std::endl;
    std::cout << std::get<0>(result) << std::endl;
    std::cout << "ImageNet Index" << std::endl;
    std::cout << std::get<1>(result) << std::endl;

    return 0;
}

The results I get in the Python script:
Probability: 0.99, Index: 1, Label: GoldFish

The results I get in C++:
Probability: 0.4839, Index: 584

I used PyTorch 1.3 with Python 3.7 and the latest corresponding libtorch.

Any ideas?

cc @yf225

@VitalyFedyunin VitalyFedyunin added the module: cpp Related to C++ API label Nov 5, 2019

Con-Mi commented Nov 6, 2019

OK, after some tests on how OpenCV reads an image, I have identified the issue.
The problem is that OpenCV's cv::imread returns the channels in BGR order (also with cv::IMREAD_UNCHANGED), while the normalization values and the pretrained model expect RGB. The image still needs to be converted to RGB.
I have also replaced the PIL library with OpenCV in the Python script, since you can't use PIL in C++, so both versions now load the image the same way.
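To illustrate why the channel order matters, here is a minimal pure-Python sketch (not part of the original scripts) that applies the same RGB mean/std values to a single pixel, once in BGR order and once after reordering to RGB. The pixel value and the helper names `normalize` and `bgr_to_rgb` are hypothetical and chosen only for illustration.

```python
# Per-channel statistics in RGB order, as used in the scripts in this issue.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Scale an 8-bit 3-channel pixel to [0, 1], then apply (x - mean) / std per channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(pixel, MEAN, STD))


def bgr_to_rgb(pixel):
    """Reverse the channel order of a single pixel."""
    return pixel[::-1]


# Hypothetical orange pixel in the order OpenCV stores it: (B, G, R).
bgr_pixel = (0, 128, 255)

wrong = normalize(bgr_pixel)              # BGR values paired with RGB statistics
right = normalize(bgr_to_rgb(bgr_pixel))  # channels reordered before normalizing

print("without conversion:", [round(v, 3) for v in wrong])
print("with conversion:   ", [round(v, 3) for v in right])
```

The two results differ in the first and last channels, so the model receives a differently colored input when the conversion is skipped, which would be consistent with the mismatched prediction reported above.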

The updated code that gives consistent results:

Python script:

import torch
from torchvision import transforms
import cv2


_IMAGE_FILENAME = "../images/goldfish.jpg"
_MODEL_JIT_FILENAME = "../jit_models_bin/traced_mnasnet0_5.pt"

model = torch.jit.load(_MODEL_JIT_FILENAME)

trsfm = transforms.Compose([ 
    transforms.ToTensor(),
    transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) 
    ])

image = cv2.imread(_IMAGE_FILENAME)
# cv2.imread returns BGR; convert to RGB before normalizing
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (224, 224))

image = trsfm(image)
image = image.unsqueeze(dim=0)


output = model(image)
output = torch.softmax(output, 1)

prob_value, index = torch.max(output, 1)

print("Probability Value: ")
print(prob_value)
print("ImageNet Index: ")
print(index)

C++ Code:

#include <iostream>
#include <vector>
#include <string>

#include <torch/torch.h>
#include <torch/script.h>

#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc.hpp>


int main() {

    const cv::String _IMAGE_FILENAME = "../images/goldfish.jpg";
    const std::string _MODEL_JIT_FILENAME= "../jit_models_bin/traced_mnasnet0_5.pt";

    cv::Mat img = cv::imread( _IMAGE_FILENAME );
    //  cv::imread returns BGR; convert to RGB before normalizing
    cv::cvtColor( img, img, cv::COLOR_BGR2RGB );
    cv::Size rsz = { 224, 224 };

    cv::resize( img, img, rsz, 0, 0, cv::INTER_LINEAR );
    img.convertTo( img, CV_32FC3, 1/255.0 );

    at::Tensor tensorImage = torch::from_blob(img.data, { 1, img.rows, img.cols, 3 }, at::kFloat);
    tensorImage = tensorImage.permute({0, 3, 1, 2});

    //  Normalize data
    tensorImage[0][0] = tensorImage[0][0].sub(0.485).div(0.229);
    tensorImage[0][1] = tensorImage[0][1].sub(0.456).div(0.224);
    tensorImage[0][2] = tensorImage[0][2].sub(0.406).div(0.225);

    std::vector<torch::jit::IValue> input;
    input.push_back(tensorImage);

    torch::jit::script::Module model = torch::jit::load( _MODEL_JIT_FILENAME );

    at::Tensor output = torch::softmax(model.forward(input).toTensor(), 1);

    std::tuple<at::Tensor, at::Tensor> result = torch::max(output, 1);

    std::cout << "Probability Value: " << std::endl;
    std::cout << std::get<0>(result) << std::endl;
    std::cout << "ImageNet Index" << std::endl;
    std::cout << std::get<1>(result) << std::endl;

    return 0;
}

Results Expected:
Python: Probability Value: 0.991, ImageNet Index: 1
C++: Probability Value: 0.991, ImageNet Index: 1
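For anyone repeating this kind of cross-check, a small helper like the one below could compare the top-1 results from the two implementations. It is only a sketch: the function name `results_match` and the tolerance are assumptions, and the numbers passed in are the ones reported in this issue.

```python
import math


def results_match(py_result, cpp_result, rel_tol=1e-3):
    """Return True if two (probability, index) pairs agree.

    The class index must be identical; the probability is compared with a
    relative tolerance (an assumed value) to allow small floating-point differences.
    """
    py_prob, py_index = py_result
    cpp_prob, cpp_index = cpp_result
    return py_index == cpp_index and math.isclose(py_prob, cpp_prob, rel_tol=rel_tol)


# Values after the BGR-to-RGB fix.
print(results_match((0.991, 1), (0.991, 1)))
# Values from the original report, before the fix.
print(results_match((0.99, 1), (0.4839, 584)))
```

The first call reports a match and the second does not, mirroring the before/after behaviour described in this thread.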

@Con-Mi Con-Mi closed this as completed Nov 6, 2019