Resurrect JIT functionality and loading of models #146

Closed
dsyme opened this issue Jun 2, 2020 · 16 comments

Comments

dsyme (Contributor) commented Jun 2, 2020

The PyTorch C++ guide shows that loading modules means using JIT loading: https://pytorch.org/tutorials/advanced/cpp_export.html

The JIT functionality was removed after the extensive churn in the PyTorch C++ API between v1.0.1 and v1.5.0.
It needs to be resurrected.

The more direct Load/Save that was present on modules should probably not be resurrected directly.

We also need to add ONNX load/save support.
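For reference, a minimal Python sketch of the export side of that tutorial (the model choice and file names below are just illustrative; the tutorial itself works with a torchvision ResNet). The C++ side -- or a resurrected TorchSharp JIT loader -- would then load the resulting .pt file:

import torch
import torchvision

# Any eager-mode model will do; resnet18 is just an illustrative choice.
model = torchvision.models.resnet18()
model.eval()

# Trace with a dummy input to produce a ScriptModule...
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)

# ...and serialize it; this is the file the JIT loader on the other side consumes.
traced.save("traced_resnet18.pt")

# ONNX export is a separate path from TorchScript serialization.
torch.onnx.export(model, example, "resnet18.onnx")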

dsyme (Contributor, Author) commented Nov 4, 2021

@NiklasGustafsson What's your thinking on this for the roadmap for TorchSharp? It feels like it should be there?

NiklasGustafsson (Contributor) commented

I think this is one area that requires very careful consideration. I'm not convinced that TorchSharp needs to follow the PyTorch recipe for externalizing models exactly.

While we have spent a lot of effort making the eager evaluation of models as close as possible to what PyTorch developers are doing, we do not have to make the externalized model experience the same. For example, there may be a role for Roslyn analyzers for the C# solution, rather than a clone of the Python solution. Likewise for F# -- there may be an F#-specific solution.

I think this should come post-v1.0, and we also need to consider how we will support ONNX export from TorchSharp. To me, they are two sides of the same coin.

GeorgeS2019 commented Nov 4, 2021

@dsyme @NiklasGustafsson

Currently (I could be wrong), it seems the continuity from a saved PyTorch model to TorchSharp is through ONNX.

There is a need for a proof of concept that a state dictionary exported from PyTorch (using the TorchSharp exportsd.py script) can be loaded back into TorchSharp, with the requirement that the TorchSharp and PyTorch models are compatible.

As Niklas has implied on a few occasions, one possible option is TorchSharp support for importing ONNX models exported from PyTorch.

Once the model has been exported to ONNX, we can (theoretically) use that to recreate the model code in C# (or F#), which can then be used to load the model weights.
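As a rough illustration (not TorchSharp code; the model, shapes, and file names here are placeholders, and the onnx Python package is assumed to be installed), the exported ONNX graph exposes both the ops and the weight initializers that such a C#/F# code generator would need to translate:

import torch
import torch.nn as nn
import onnx

# Placeholder model; any nn.Module works here.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

# Export to ONNX using a dummy input to fix the graph.
dummy = torch.rand(1, 10)
torch.onnx.export(model, dummy, "model.onnx")

# Inspect the exported graph: the ops describe the architecture,
# the initializers hold the weights.
m = onnx.load("model.onnx")
for node in m.graph.node:
    print(node.op_type, list(node.input), list(node.output))
for init in m.graph.initializer:
    print(init.name, list(init.dims))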

It seems TorchSharp ONNX import is equally important as, and perhaps more important than, TorchSharp ONNX export.

Again, I hope this issue is evaluated post-v1.0.

Congratulations on the rapid progress towards v1.0!

dsyme (Contributor, Author) commented Nov 4, 2021

@NiklasGustafsson I see, for loading (though I'd still think we need to support loading the PyTorch format; if the Java, Rust, or Scala bindings do, we should check).

What about JIT? That's really what I was asking about; I forgot this issue referred to loading too.

dsyme (Contributor, Author) commented Nov 4, 2021

Here are the relevant links from the Rust PyTorch bindings:

https://github.com/LaurentMazare/tch-rs#using-some-pre-trained-model

https://github.com/LaurentMazare/tch-rs/tree/main/examples/jit

I haven't checked how the ".ot" files are extracted. At a minimum, we should be matching this.

dsyme (Contributor, Author) commented Nov 4, 2021

Looks like they have load:

let model = tch::CModule::load(model_file)?;

GeorgeS2019 commented Nov 4, 2021

@dsyme @NiklasGustafsson

Based on the new initiative from @dsyme, I recall the challenges Netron faced dealing with PyTorch support:

PyTorch support is experimental. The recommended approach is to save the model to ONNX.

A .pth file might contain only tensor data or a partial module hierarchy without graph connections.

Rust binding approach

Then I looked into the Rust binding approach suggested by @dsyme.

Rust seems to work only with TorchScript rather than the PyTorch model format.

Example (tracing an existing module):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

The trained PyTorch model can then be serialized into a TorchScript file. The TorchScript file contains a description of the model architecture as well as the trained weights. This file can be loaded from Rust to run inference with the saved model.
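Continuing the sketch above (the file name is a placeholder), the traced module is saved to a single file and can be loaded back -- by torch.jit.load in Python, or by the corresponding JIT loader in Rust or TorchSharp:

# Serialize the traced ScriptModule; the file carries both the graph and the weights.
module.save("net_traced.pt")

# Sanity check: the same file loads back as a ScriptModule and can run inference.
reloaded = torch.jit.load("net_traced.pt")
print(reloaded(example_forward_input).shape)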

Given a "PyTorch" trained and save model, (e.g. pth or pt) the extension itself does not reveal if it is a PyTorch model or a TorchScript serialized model. (lots of discussion on that)

To address that in Rust.

  • Rust controls that ONLY saved TorchScript model is used and converts that into a format (with extension "ot"), loadable by Rust.

The Rust CModule is described here.

However, it seems that in practice Rust goes from a saved JIT model (e.g. ".pt"), through an intermediate format ".nz", before ending up with the ".ot" format.

This is what I could gather and I am not 100% sure if my understanding is correct.

NiklasGustafsson (Contributor) commented

@dsyme, @GeorgeS2019 -- we can and should look at JIT post-v1.0. In my opinion, we should decide on what is right in the context of .NET and the tooling options we have at our disposal. Tracing is simple, it's the scripting option of TorchScript that is going to take some thought and consideration.

I think ONNX is a good solution for exchange between languages (that's what it's for), but the reason for JIT support is primarily about performance and being able to deploy models for inference on a wide variety of platforms, so the JIT question is distinct from loading, IMO.

I'm working on an example (for the examples repo) with a simple ONNX -> C# converter that relies on the .NET model binary format. It should be ready in a week or so.

GeorgeS2019 commented Nov 4, 2021

@NiklasGustafsson @dsyme

In some ways, I am thinking along the lines of what Niklas has said.

We may need to learn from the PyTorch experience faced by both Netron and Rust.

To avoid not knowing the origin of a PyTorch "model", and the confusion caused by that, Netron recommends ONNX and does not officially support PyTorch.

Rust requires a controlled workflow of migrating a TorchScript model to a proprietary Rust ".ot" format that is non-standard and used only within Rust, or perhaps within a subset of Rust projects dealing with PyTorch.

Recommendation

  • Deal ONLY with ONNX import and export, so that there is no longer a proprietary TorchSharp format (which is currently the case). If so, TorchSharp immediately gains access to ONNX models coming not just from PyTorch but also from other ML frameworks.

If there is a need to deal with TorchScript, provide a tool to detect whether a given file is a PyTorch model or a TorchScript model, along with a brief explanation of why only TorchScript is handled (see the sketch below). Instead of following the path of Rust, TorchScript would be converted to ONNX within TorchSharp. There is no need for a proprietary, non-standard TorchSharp format, which is currently the case.
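A minimal sketch of such a detection tool, assuming plain PyTorch on the Python side (the file name is a placeholder): a TorchScript archive opens with torch.jit.load, while an eager-mode checkpoint only opens with torch.load.

import torch

def classify_file(path):
    # TorchScript archives are zip files containing code + weights;
    # torch.jit.load only accepts those.
    try:
        torch.jit.load(path, map_location="cpu")
        return "TorchScript module"
    except Exception:
        pass
    # Anything else torch.load can read is an eager-mode pickle,
    # typically a state_dict or a fully pickled nn.Module.
    try:
        obj = torch.load(path, map_location="cpu")
        return "state_dict" if isinstance(obj, dict) else type(obj).__name__
    except Exception:
        return "not a PyTorch file"

print(classify_file("model.pt"))  # placeholder path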

dsyme (Contributor, Author) commented Nov 4, 2021

The core design philosophy for TorchSharp seems to be "if it's in LibTorch, and stable, then TorchSharp should expose it". That is from the perspective of the coherence of the TorchSharp design: TorchSharp is the one and only wrapper for LibTorch in the .NET ecosystem. As a .NET Foundation project I think that's the right approach.

So for any LibTorch functionality (incl. JIT and loading) I'd say "if it's there, if it's stable, if it's feasible, if other languages bind to loading, then TorchSharp should in theory do so too". That means I believe we should say we would accept community PRs for these. It's not saying we need to put it into v1.0 ourselves, but it's hard to argue against accepting these things if someone contributes them.

I guess there's a chance that @gbaydin and I may like to use the JIT functionality from DiffSharp to allow quotation-based compilation. Or I might try to build a different toolchain that compiles models written in F# statically using FSharp.Compiler.Service and JIT.

GeorgeS2019 commented Feb 19, 2022

What I learned from this recent use case suggests the following challenges when using a saved PyTorch model (more precisely, saved state_dicts) in TorchSharp.

import exportsd  # export script shipped with the TorchSharp repository

if __name__ == '__main__':
    # Create the model (BasicConv1d is defined in the referenced use case)
    model = BasicConv1d(1, 32)

    # Export the model's state_dict to a .dat file for ingestion into TorchSharp
    with open("bug510.dat", "wb") as f:
        exportsd.save_state_dict(model.to("cpu").state_dict(), f)

One has to examine whether all of the saved state_dict entries have been implemented internally in TorchSharp for the TorchSharp model.load to work.
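One simple way to do that check is to enumerate what the Python side actually exports; a small sketch (reusing the model from the snippet above) that lists every state_dict entry so the names, shapes, and dtypes can be compared against what the corresponding TorchSharp module registers:

# List every entry the export will contain, so that names, shapes and dtypes
# can be compared against the parameters and buffers of the TorchSharp module.
for name, tensor in model.to("cpu").state_dict().items():
    print(name, tuple(tensor.shape), tensor.dtype)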

I raise this here so that more users can check for (and, where possible, submit PRs to cover) any missing internal implementation, to make model.load more reliable when loading saved PyTorch state_dicts.

Once we have a reliable model.load for almost all scenarios, it will be less challenging to implement TorchSharp code to load ONNX in TorchSharp.

Any comments?

How can we coordinate with the community to write unit tests that increase code coverage, so that model.load is reliable when models are exported using the provided TorchSharp Python code?

NiklasGustafsson (Contributor) commented

Once we have a reliable model.load for almost all scenarios, it will be less challenging to implement TorchSharp code to load ONNX in TorchSharp.

Model.load() has nothing to do with loading ONNX models in TorchSharp; it relates only to loading model state_dicts() generated from either TorchSharp or PyTorch.

GeorgeS2019 commented

@NiklasGustafsson

Model.load() has nothing to do with loading ONNX models in TorchSharp; it relates only to loading model state_dicts() generated from either TorchSharp or PyTorch.

I got the part that ONNX has nothing to do with Model.Load() and state_dict().

Model.Load() ONLY works if:

  • a TorchSharp C# module corresponding to the PyTorch one is available, and
  • the TorchSharp Python script was used to export the PyTorch state_dict.

GeorgeS2019 commented Feb 22, 2022

@NiklasGustafsson

We have discussed the limitations of the PyTorch .pth model format.

Here @fwaris addresses that by creating a parser that goes through a TensorFlow saved model:

  • Instead of exporting ONNX from PyTorch, use the provided saved TensorFlow checkpoint model.
  • Instead of going through a saved PyTorch state_dict and having to manually create a corresponding TorchSharp C# module, perhaps use the graph structure provided by TfCheckpoint to half-manually create the corresponding TorchSharp C# module.

Since @michaelgsharp's ML.NET can integrate TensorFlow saved models, I wonder whether a parser able to extract the graph structure, similar to that of TfCheckpoint, is already available?

fwaris (Contributor) commented Feb 23, 2022

@GeorgeS2019 at this time, TfCheckpoint only reads the tensor data (variables) - not the model structure. TfCheckpoint has minimal dependencies (really just IronSnappy and the required protobuf definitions).

However, Tensorflow.NET seems to have the required functionality, but the documentation is sparse. The "Restoring the Model" part is not written yet.

NiklasGustafsson (Contributor) commented

This has now been implemented -- TorchSharp can load and save TorchScript files, but not create them from scratch. See PR #644.

dsyme closed this as completed Jul 13, 2022