BERT support #331

Closed · kali opened this issue Jul 22, 2020 · 22 comments


@kali (Collaborator) commented Jul 22, 2020

Not too sure what specific operators the BERT architecture will require but:

@bminixhofer (Contributor) commented:

This looks really exciting. Running models from https://github.com/huggingface/transformers without the TF hassle would be neat. Happy to help with a test case if needed.

@kali (Collaborator, Author) commented Aug 22, 2020

@bminixhofer I tried making a test case for this one https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad, but I have little patience for debugging Python notebooks that got broken months ago... If your offer of helping with test cases still stands, please do :)

@kali (Collaborator, Author) commented Aug 23, 2020

[x] NonZero: we now have a working test case in incorporated mode. Performance is of course awful.
[ ] some version of #313 is needed to run in optimised mode
[ ] ConstantOfShape: I do not have an example network that crashes with this operator. @joverwey, any chance you can provide a test case featuring this op?

@joverwey commented:

ConstantOfShape

You can use the following script to generate a BERT model in ONNX format that uses the ConstantOfShape operator.
Before running the script, make sure to install the dependencies: the first is PyTorch, the second is the pytorch-pretrained-bert package.
On my machine I did the following:

conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
pip3 install pytorch-pretrained-bert

export_bert.py.txt

Running the script will produce 2 files:

  • bertmodel.onnx
  • bertmodel_lm.onnx

Both use the ConstantOfShape operator. Let me know if you have trouble with this. I could upload the ONNX models somewhere, but they are > 400 MB.
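
For readers without the attachment: a minimal sketch of what such an export script might look like, assuming the legacy pytorch-pretrained-bert API (the attached export_bert.py.txt remains the authoritative version):

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Hypothetical dummy input; the legacy API wants explicit token-to-id conversion.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = ["[CLS]"] + tokenizer.tokenize("hello world") + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)], dtype=torch.long)

# Trace and export both the plain encoder and the masked-LM variant,
# matching the two files the script produces.
for cls, path in [(BertModel, "bertmodel.onnx"), (BertForMaskedLM, "bertmodel_lm.onnx")]:
    model = cls.from_pretrained("bert-base-uncased")
    model.eval()
    torch.onnx.export(model, (input_ids,), path, input_names=["input_ids"])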

@bminixhofer (Contributor) commented Aug 29, 2020

Hi, just a heads up: pytorch-pretrained-bert has long been deprecated in favor of huggingface/transformers. Maybe you can try updating to huggingface/transformers? The networks I exported with it apparently did not have this operator.

@kali You could also try changing the model_name in the test from my PR to bert-base-uncased, then it should be the exact same model.

@Hendler commented Oct 28, 2020

Wondering if there is any progress here. Perhaps it is limited by how well (or how easily) huggingface exports to ONNX, e.g.
huggingface/transformers#5948

@kali (Collaborator, Author) commented Oct 29, 2020

There has been progress.

Then it should be a matter of just fixing a handful of operators. We'll get there.

@llenotre (Contributor) commented:

Hello,

I was trying to load an ONNX file of a FlauBERT model (which is similar to BERT), and I ran into the following error:

Translating node #28 "SkipLayerNorm1" Unimplemented(SkipLayerNormalization) ToTypedTranslator

Caused by:
    0: translating op UnimplementedOp { outputs: 1, name: "SkipLayerNormalization", message: "NodeProto { input: [\"128\", \"131\", \"transformer.layer_norm_emb.weight\", \"transformer.layer_norm_emb.bias\"], output: [\"143\"], name: \"SkipLayerNorm1\", op_type: \"SkipLayerNormalization\", domain: \"com.microsoft\", attribute: [AttributeProto { name: \"epsilon\", ref_attr_name: \"\", doc_string: \"\", r#type: Float, f: 1e-6, i: 0, s: [], t: None, g: None, floats: [], ints: [], strings: [], tensors: [], graphs: [] }], doc_string: \"\" }" }
    1: Operator can not be made a TypedOp.

@Yevgnen (Contributor) commented Feb 18, 2022

Polite ping...

@kali (Collaborator, Author) commented Feb 18, 2022

@Yevgnen thanks for the ping. Let's try to revive this, and see where we are.

Keeping up with huggingface model releases is a bit of a challenge for me :) Does anybody know the current state of affairs? What is their latest release format? Does it take the same form as what @bminixhofer put in the tests in 2020 in onnx/transformer-mlm? What are we ultimately trying to achieve here? My thinking right now is that we want a working Rust example to be put in examples alongside the others. I'm just not sure which model, and what kind of workflow (fully known input sizes before optim? partially known input sizes?). Could anyone help me define this?

@llenotre can you provide me with, or tell me where to find, an instance of FlauBERT so I can have a look at the missing op?

@Yevgnen (Contributor) commented Feb 19, 2022

@kali

Hi, thanks for your quick response. I'm quite new to Rust and ONNX. It would be great to have examples related to the BERT models.

At least the related early model architectures have not changed in the last two years. What are we missing to get this to work? I got a similar error to @llenotre's when trying to load an ONNX file which uses ALBERT as an encoder. Are we missing ops only, or are there more difficulties like #533 in getting it to work?

I think models like BERT or ALBERT would be simple and good to start with. The typical inputs are input_ids, attention_mask and token_type_ids, and all of them have shape [batch_size, max_length] (see #533).
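
For illustration, a tiny Python snippet showing those three inputs and their [batch_size, max_length] shapes (the model name and max_length here are arbitrary choices, not from the thread):

from transformers import AutoTokenizer

# Arbitrary model and padding length, purely to show the input layout.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer(["hey there!"], padding="max_length", max_length=16, return_tensors="np")
for name in ("input_ids", "attention_mask", "token_type_ids"):
    print(name, encoded[name].shape)  # (1, 16), i.e. [batch_size, max_length]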

@bminixhofer (Contributor) commented:

I'm a bit surprised that this is not covered by the transformer-mlm test. I'll check what exactly is missing for the models you mentioned here.

@bminixhofer (Contributor) commented Feb 24, 2022

I tried slightly adapted code from the transformer-mlm test:

Code
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM
import onnxruntime

model_name = "distilbert-base-uncased"
model_path = "model.onnx"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

text = "tract is a machine " + tokenizer.mask_token + " library."
print("MODEL:", model_name)
print("TEXT:", text)

encoded = tokenizer.encode_plus(text)
mask_idx = encoded["input_ids"].index(tokenizer.mask_token_id)

input_ids = torch.tensor([encoded["input_ids"]], dtype=torch.long)
attention_mask = torch.tensor([encoded["attention_mask"]], dtype=torch.long)

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    model_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "output": {0: "batch", 1: "seq"},
    },
    opset_version=13,
)

sess = onnxruntime.InferenceSession(model_path)

outputs = sess.run(
    None, {"input_ids": input_ids.numpy(), "attention_mask": attention_mask.numpy()}
)[0]

np.savez_compressed(
    open("io.npz", "wb"),
    input_ids=input_ids.numpy(),
    attention_mask=attention_mask.numpy(),
    output=outputs,
)

(note that I had to add opset_version=13 to make it work with recent versions of PyTorch & Transformers)

and: cargo run -p tract --release -- model.onnx --input-bundle io.npz run --assert-output-bundle io.npz

This works with model_name = "distilbert-base-uncased" (as already tested in CI), but it also works with albert-base-v2 and flaubert/flaubert_base_cased (although I get Mismatch at [0, 1, 14426] -0.645134 != -0.64454913 from FlauBERT but I suppose that's just tract being rather strict about precision).

I hope this helps the discussion; from my POV, the models mentioned here work.


Maybe it would be good to add an example with code similar to the one above + inference in Rust.
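
As a sanity check on the bundle itself, the kind of comparison --assert-output-bundle performs can be sketched in Python against the onnxruntime reference (the tolerances here are guesses; tract's actual threshold may differ):

import numpy as np
import onnxruntime

# Re-run the reference session and check it against the saved bundle.
bundle = np.load("io.npz")
sess = onnxruntime.InferenceSession("model.onnx")
out = sess.run(None, {"input_ids": bundle["input_ids"],
                      "attention_mask": bundle["attention_mask"]})[0]
np.testing.assert_allclose(out, bundle["output"], rtol=1e-3, atol=1e-5)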

@Yevgnen (Contributor) commented Feb 25, 2022

@bminixhofer

Thanks for your example. I tried to get the model to run from Rust code, but the result did not seem to match the Python output. Do you have any suggestions?

Python code

# -*- coding: utf-8 -*-

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import onnxruntime

model_name = "voidful/albert_chinese_tiny"
model_path = "model.onnx"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

text = "hey there!"
print("MODEL:", model_name)
print("TEXT:", text)

encoded = tokenizer(text, return_tensors="pt")
print(encoded)

input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
token_type_ids = encoded["token_type_ids"]

torch.onnx.export(
    model,
    (input_ids, attention_mask, token_type_ids),
    model_path,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "token_type_ids": {0: "batch", 1: "seq"},
        "output": {0: "batch", 1: "seq"},
    },
    opset_version=13,
)

sess = onnxruntime.InferenceSession(model_path)

outputs = sess.run(
    None,
    {
        "input_ids": input_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
        "token_type_ids": token_type_ids.numpy(),
    },
)
print(outputs[0])

np.savez_compressed(
    open("io.npz", "wb"),
    input_ids=input_ids.numpy(),
    attention_mask=attention_mask.numpy(),
    token_type_ids=token_type_ids.numpy(),
    output=outputs[0],
)

Rust code

use std::{
    path::{Path, PathBuf},
    str::FromStr,
};
use tokenizers::tokenizer::{Result, Tokenizer};
use tract_onnx::prelude::*;

fn main() -> Result<()> {
    // model here: https://huggingface.co/voidful/albert_chinese_tiny/tree/main
    // the model does not have a tokenizer.json, so I saved the downloaded tokenizer manually from Python
    let model_dir = PathBuf::from_str("./albert_chinese_tiny/")?;
    let tokenizer = Tokenizer::from_file(Path::join(&model_dir, "tokenizer.json"))?;

    let text = "hey there!";

    let tokenizer_output = tokenizer.encode(text, true)?;
    println!("{:?}", tokenizer_output);

    let input_ids = tokenizer_output.get_ids();
    let attention_mask = tokenizer_output.get_attention_mask();
    let token_type_ids = tokenizer_output.get_type_ids();
    let length = input_ids.len();

    let input_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
        (1, length),
        input_ids.iter().map(|&x| x as i64).collect(),
    )?
    .into();
    let attention_mask: Tensor = tract_ndarray::Array2::from_shape_vec(
        (1, length),
        attention_mask.iter().map(|&x| x as i64).collect(),
    )?
    .into();
    let token_type_ids: Tensor = tract_ndarray::Array2::from_shape_vec(
        (1, length),
        token_type_ids.iter().map(|&x| x as i64).collect(),
    )?
    .into();

    let model = tract_onnx::onnx()
        .model_for_path(Path::join(&model_dir, "model.onnx"))?
        .with_input_fact(
            0,
            InferenceFact::dt_shape(i64::datum_type(), tvec!(1, length)),
        )?
        .with_input_fact(
            1,
            InferenceFact::dt_shape(i64::datum_type(), tvec!(1, length)),
        )?
        .with_input_fact(
            2,
            InferenceFact::dt_shape(i64::datum_type(), tvec!(1, length)),
        )?
        .into_optimized()?
        .into_runnable()?;

    let result = model.run(tvec!(input_ids, attention_mask, token_type_ids))?;
    println!("{:?}", result);

    Ok(())
}

Python input:

{'input_ids': tensor([[  101, 13153, 11136,   106,   102]]), 
'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 
'attention_mask': tensor([[1, 1, 1, 1, 1]])}

Python output:

array([[[ 0.3897,  0.4885,  0.4034, ..., -0.4318,  0.414 , -1.3471],
        [-0.44  , -0.0851, -1.3209, ...,  0.6185,  0.1637, -0.8835],
        [-0.3721, -0.5651,  0.4615, ...,  0.6273, -0.235 , -1.6752],
        [ 0.3228, -0.2446,  0.4693, ...,  0.8158, -1.6703, -1.3914],
        [ 0.6923, -0.4026, -0.1817, ..., -1.2022, -0.1027, -1.1065]]],
      dtype=float32)

Rust input:

Encoding { ids: [101, 13153, 11136, 106, 102], 
type_ids: [0, 0, 0, 0, 0], 
tokens: ["[CLS]", "hey", "there", "!", "[SEP]"], 
words: [None, Some(0), Some(1), Some(2), None], 
offsets: [(0, 0), (0, 3), (4, 9), (9, 10), (0, 0)], 
special_tokens_mask: [1, 0, 0, 0, 1], 
attention_mask: [1, 1, 1, 1, 1], 
overflowing: [], 
sequence_ranges: {0: 1..4} }

Rust output:

[1,5,312,F32 -0.0455257, -0.48773792, -0.23926008, 0.63138044, 0.5035107, -0.23176816, 1.3015757, -0.9572776, -1.2335656, 1.1044103, -0.16797636, -0.6322866..., 1,312,F32 -0.83731806, -0.38727227, 0.098052114, 0.99925756, 0.185268, -0.93537277, -0.041956298, -0.14867271, 0.98756856, -0.9998903, -0.99749595, 0.16180465...]

@kali (Collaborator, Author) commented Feb 25, 2022

tract output looks really wrong: it does not even have the correct shape. As you do not set the model output in your tract code, tract makes a guess, and it is probably just wrong. You can set up the output node for your model with setOutputName().

@Yevgnen (Contributor) commented Feb 25, 2022

@kali Thanks for the hint. I'm new to Rust and tried to follow tract/examples to make the snippet. I thought the [1,5,312,F32... part was correct, as it's the same shape as the Python output (1, 5, 312). I'll try to fix this!

@kali (Collaborator, Author) commented Feb 25, 2022

Ha sorry, no, you were right, forget my comment. I misread the Python output, and did not see it was truncated. So I don't know what's wrong. Can you bundle the model.onnx and io.npz you are using in one single tidy piece I can download? I don't want to sound like a diva, but the less I have to touch any Python, the more likely I am to get motivated to look into the problem.

@Yevgnen (Contributor) commented Feb 25, 2022

@kali

Sure, here you go:

@kali (Collaborator, Author) commented Feb 27, 2022

@Yevgnen Good news, this is what I get with tract top of tree:

Encoding { ids: [101, 13153, 11136, 106, 102], type_ids: [0, 0, 0, 0, 0], tokens: ["[CLS]", "hey", "there", "!", "[SEP]"], words: [None, Some(0), Some(1), Some(2), None], offsets: [(0, 0), (0, 3), (4, 9), (9, 10), (0, 0)], special_tokens_mask: [1, 0, 0, 0, 1], attention_mask: [1, 1, 1, 1, 1], overflowing: [], sequence_ranges: {0: 1..4} }
[1,5,312,F32 0.3896918, 0.48845768, 0.40343055, -0.5043659, -0.74386686, -0.04551193, 1.7171062, -0.5664656, -0.40217492, 0.40226883, 0.419717, -0.09059957..., 1,312,F32 0.99874693, 0.13921316, -0.10930423, 0.9993504, -0.22610563, -0.84864765, -0.15000838, -0.069201134, 0.9946854, -0.9996393, -0.96849227, 0.066608444...]

So I guess you can just pull the main branch for now. Anyway, I'm gonna do a release very soon (I was actually waiting for Rust 1.59, which was released a few days ago).

@kali (Collaborator, Author) commented Feb 27, 2022

I'm also going to close this issue (as soon as the release is done). It would be nice to have an example, maybe something one of you folks would be kind enough to contribute (I created #634 to track this).

But as there is nothing structural in either tract or the BERT models that is incompatible, we can safely say tract supports BERT. If we have remaining issues with one or another BERT incarnation, we will deal with them as bugs.

@Yevgnen (Contributor) commented Feb 28, 2022

@kali Hi, that's really good news. Thanks for your work! I confirm that using the main branch I get the correct output. I also made a small example and hope that helps!

@kali (Collaborator, Author) commented Feb 28, 2022

@Yevgnen It does help! Thank you.

kali closed this as completed Feb 28, 2022