MobileStyleGAN Checkpoint converted to ONNX generates grey images #44

Open
IvonaTau opened this issue Oct 5, 2022 · 1 comment

@IvonaTau
IvonaTau commented Oct 5, 2022

Hi!

Thank you for an amazing repository.
I successfully converted my rosinality StyleGAN2-ADA checkpoint by running the following command:
python convert_rosinality_ckpt.py --ckpt {path_to_rosinality_stylegan2_ckpt} --ckpt-mnet output/mnet.ckpt --ckpt-snet output/snet.ckpt --cfg-path output/config.json

I tested the checkpoint with demo.py and it produces images as expected.

I then converted it to ONNX by running
python train.py --cfg output/config.json --export-model onnx --export-dir onnx-2
and tried to use the converted model in the MobileStyleGAN web demo (https://github.com/cyrildiagne/mobilestylegan-web-demo).
It produces uniform grey images for every seed. The web demo works fine with the authors' FFHQ checkpoint, so the problem seems to lie in the converted model.
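One quick way to confirm that the exported generator itself emits a flat image (rather than the web demo mis-rendering it) is to run it outside the browser, for example with onnxruntime, and inspect the raw output tensor. The following is a minimal sketch, assuming the output is an NCHW float array roughly in [-1, 1], as StyleGAN2-style generators usually produce; the helper names are hypothetical and not part of MobileStyleGAN.

```python
import numpy as np


def to_uint8_image(output: np.ndarray) -> np.ndarray:
    """Map the first sample of an NCHW output in [-1, 1] to uint8 HWC pixels."""
    # Assumption: values are roughly in [-1, 1]; anything outside is clipped.
    img = np.clip((output + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
    return np.transpose(img[0], (1, 2, 0))  # CHW -> HWC


def looks_degenerate(output: np.ndarray, std_tol: float = 1e-3) -> bool:
    """Flag outputs that contain NaN/inf or barely vary (e.g. a flat grey image)."""
    if not np.isfinite(output).all():
        return True
    return float(np.std(output)) < std_tol
```

An all-zero output maps to pixel value 128 everywhere, which is exactly a uniform mid-grey image, so a generator that collapses to zeros (or to any constant) would be flagged by `looks_degenerate`.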

Do you have any thoughts on what might be causing this?

[Screenshot, 2022-10-05 17:57:19]

@johndpope
johndpope commented Oct 10, 2024

It seems the ONNX export has a problem with FusedLeakyReLU.

Related: rosinality/stylegan2-pytorch#322

I am using this same function in my own project:

# TODO: fixed ONNX conversion
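For reference, rosinality's `fused_leaky_relu` adds a per-channel bias, applies a leaky ReLU with negative slope 0.2 and multiplies the result by sqrt(2) (defaults `negative_slope=0.2`, `scale=2 ** 0.5`); the fused CUDA kernel computes the same formula as the plain-PyTorch CPU fallback. A NumPy reference of that formula can help check an exported model's activations numerically. This is a sketch only: the function name is hypothetical and it assumes the bias lies along axis 1 (NC or NCHW layout).

```python
import numpy as np


def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    """NumPy reference for leaky_relu(x + bias, negative_slope) * scale."""
    x = np.asarray(x)
    bias = np.asarray(bias)
    # Reshape the per-channel bias to (1, C, 1, ..., 1) so it broadcasts over axis 1.
    b = bias.reshape((1, -1) + (1,) * (x.ndim - 2))
    y = x + b
    # Leaky ReLU: keep non-negative values, scale negative ones by the slope.
    return np.where(y >= 0, y, y * negative_slope) * scale
```

If an exported graph reproduces this with `Add`, `LeakyRelu(alpha=0.2)` and `Mul` nodes, its intermediate outputs should match this reference up to float tolerance.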

I have this class that exports to ONNX. The opset version matters a lot, because lower versions lack support for some of the ops:
https://github.com/johndpope/IMF/blob/main/onnxconv.py#L153

But this hasn't resolved things for me yet.

Still digging...

Wait a second, maybe this recent upstream fix in PyTorch solves things:
pytorch/pytorch#125753
Need to test it.
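To test whether a re-export (for example after upgrading PyTorch) actually fixes the problem, the ONNX output can be compared numerically with the PyTorch output for the same input. A minimal sketch, assuming both outputs have already been obtained as arrays (e.g. from the PyTorch model and from onnxruntime's `InferenceSession.run`); the helper name and the default tolerances are illustrative assumptions.

```python
import numpy as np


def compare_outputs(torch_out, onnx_out, rtol=1e-3, atol=1e-4):
    """Summarise how far an ONNX Runtime output drifts from the PyTorch reference."""
    torch_out = np.asarray(torch_out, dtype=np.float64)
    onnx_out = np.asarray(onnx_out, dtype=np.float64)
    if torch_out.shape != onnx_out.shape:
        raise ValueError(f"shape mismatch: {torch_out.shape} vs {onnx_out.shape}")
    diff = np.abs(torch_out - onnx_out)
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        "allclose": bool(np.allclose(onnx_out, torch_out, rtol=rtol, atol=atol)),
    }
```

A large `max_abs_diff` on the final image, combined with a small one on earlier intermediate outputs, would help narrow down which part of the graph diverges.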

pip install onnxscript -U
pip install onnxruntime -U
pip install onnxconverter_common -U
pip install onnx -U
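Since the outcome may depend on exact package versions, it can help to record which versions actually ended up installed after the upgrade. A small stdlib-only sketch; the package list mirrors the pip commands above plus `torch`, and should be adjusted as needed.

```python
from importlib import metadata

PACKAGES = ["torch", "onnx", "onnxruntime", "onnxscript", "onnxconverter_common"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:>22}: {version or 'not installed'}")
```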

N.B.: this library provides diagnostics for ONNX models:
https://github.com/webonnx/wonnx

cargo install --git https://github.com/webonnx/wonnx.git wonnx-cli
nnx info ./data/models/opt-squeeze.onnx

[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/linear_layers.2_1/Gemm' input '/latent_token_encoder/activation_7/LeakyRelu_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/activation_8/LeakyRelu' input '/latent_token_encoder/linear_layers.2_1/Gemm_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/linear_layers.3_1/Gemm' input '/latent_token_encoder/activation_8/LeakyRelu_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/activation_9/LeakyRelu' input '/latent_token_encoder/linear_layers.3_1/Gemm_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/final_linear_1/Gemm' input '/latent_token_encoder/activation_9/LeakyRelu_output_0' has unknown shape
+------------------+-------------------------------------------------------------------+
| Model version    | 0                                                                 |
+------------------+-------------------------------------------------------------------+
| IR version       | 8                                                                 |
+------------------+-------------------------------------------------------------------+
| Producer name    | pytorch                                                           |
+------------------+-------------------------------------------------------------------+
| Producer version | 2.4.0                                                             |
+------------------+-------------------------------------------------------------------+
| Opsets           | 15                                                                |
+------------------+-------------------------------------------------------------------+
| Inputs           | +-------------+-------------+----------------------------+------+ |
|                  | | Name        | Description | Shape                      | Type | |
|                  | +-------------+-------------+----------------------------+------+ |
|                  | | x_current   |             | batch_size x 3 x 256 x 256 | f32  | |
|                  | +-------------+-------------+----------------------------+------+ |
|                  | | x_reference |             | batch_size x 3 x 256 x 256 | f32  | |
|                  | +-------------+-------------+----------------------------+------+ |
+------------------+-------------------------------------------------------------------+
| Outputs          | +------+-------------+----------------------------+------+        |
|                  | | Name | Description | Shape                      | Type |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | f_r  |             | batch_size x 128 x 64 x 64 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | t_r  |             | batch_size x 256 x 32 x 32 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | t_c  |             | batch_size x 512 x 16 x 16 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | x    |             | batch_size x 512 x 8 x 8   | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | 3032 |             | Gemm3032_dim_0 x 32        | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | 4799 |             | Gemm4799_dim_0 x 32        | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
+------------------+-------------------------------------------------------------------+
| Ops used         | +-----------------+---------------------+                         |
|                  | | Op              | Attributes          |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Add             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Cast            | to=7                |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Concat          | axis=0              |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Constant        | value=<TENSOR>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | ConstantOfShape | value=<TENSOR>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Conv            | dilations=<INTS>    |                         |
|                  | |                 | group=1             |                         |
|                  | |                 | kernel_shape=<INTS> |                         |
|                  | |                 | pads=<INTS>         |                         |
|                  | |                 | strides=<INTS>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Div             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Gather          | axis=0              |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Gemm            | alpha=1             |                         |
|                  | |                 | beta=1              |                         |
|                  | |                 | transB=1            |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Identity        |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | LeakyRelu       | alpha=0.2           |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Mul             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Pad             | mode=constant       |                         |
|                  | +-----------------+---------------------+                         |
|                  | | ReduceMean      | axes=<INTS>         |                         |
|                  | |                 | keepdims=0          |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Relu            |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Reshape         | allowzero=0         |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Shape           |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Slice           |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Sub             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Transpose       | perm=<INTS>         |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Unsqueeze       |                     |                         |
|                  | +-----------------+---------------------+                         |
+------------------+-------------------------------------------------------------------+
| Memory usage     | +--------------+-----------+                                      |
|                  | | Inputs       |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Outputs      |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Intermediate |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Weights      | 204.4 MiB |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Total        | 204.4 MiB |                                      |
|                  | +--------------+-----------+                                      |
+------------------+-------------------------------------------------------------------+
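The `nnx info` errors above all name a node and one of its inputs whose shape is unknown; here they alternate between the `Gemm` and `LeakyRelu` nodes of the `latent_token_encoder` chain. With longer logs, a small parser makes it easier to list the affected tensors. This is a sketch that assumes the exact message format shown above; the function name is hypothetical.

```python
import re

# Matches messages like:
# ... ERROR nnx::info] Node '<node>' input '<tensor>' has unknown shape
UNKNOWN_SHAPE = re.compile(
    r"Node '(?P<node>[^']+)' input '(?P<tensor>[^']+)' has unknown shape"
)


def unknown_shape_inputs(log_text):
    """Return (node, input tensor) pairs that nnx reports as having unknown shape."""
    return [(m.group("node"), m.group("tensor")) for m in UNKNOWN_SHAPE.finditer(log_text)]
```

If the shapes are missing only because they were never recorded in the graph, running `onnx.shape_inference.infer_shapes` on the loaded model may fill some of them in; that is a guess at the cause, not a confirmed fix.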
