
Support for Apple Silicon / MPS #15

Closed · j-f1 opened this issue Dec 16, 2022 · 5 comments

Comments

j-f1 commented Dec 16, 2022

I’m not sure if current-gen Apple Silicon GPUs are capable of doing the computation fast enough (probably not, tbh), but it would be great to get it working so folks can at least try it out. I tried changing all the mentions of cuda in the project to mps, but I’m getting an error in the TorchScript interpreter which suggests the model needs changes so it doesn’t assume CUDA. Is there a way to fix or patch this?

NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/diffusers/models/unet_2d_condition/___torch_mangle_4939.py", line 44, in forward
    _4 = ops.prim.NumToTensor(torch.size(sample, 0))
    timesteps = torch.expand(timestep, [int(_4)])
    input0 = torch.to((time_proj).forward(timesteps, ), 5)
                       ~~~~~~~~~~~~~~~~~~ <--- HERE
    _5 = (time_embedding).forward(input0, )
    _6 = (conv_in).forward(sample, )
  File "code/__torch__/diffusers/models/embeddings/___torch_mangle_4232.py", line 8, in forward
  def forward(self: __torch__.diffusers.models.embeddings.___torch_mangle_4232.Timesteps,
    timesteps: Tensor) -> Tensor:
    _0 = torch.arange(0, 160, dtype=6, layout=None, device=torch.device("cuda:0"), pin_memory=False)
         ~~~~~~~~~~~~ <--- HERE
    exponent = torch.mul(_0, CONSTANTS.c0)
    exponent0 = torch.div(exponent, CONSTANTS.c1)
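
For context, a minimal sketch of why the traced graph fails: torch.jit.trace records concrete values, so the device that torch.arange saw at trace time is frozen into the serialized graph as torch.device("cuda:0"). The toy module below is only an illustrative stand-in for diffusers' Timesteps embedding; re-tracing on the target device avoids the hardcoded constant.

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

class TinyTimesteps(torch.nn.Module):
    # Illustrative stand-in for diffusers.models.embeddings.Timesteps.
    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # During tracing, `timesteps.device` is captured as a constant;
        # a CUDA trace serializes device=torch.device("cuda:0") here.
        exponent = torch.arange(0, 160, dtype=torch.float32, device=timesteps.device)
        return timesteps[:, None].float() * torch.exp(-exponent)[None, :]

traced = torch.jit.trace(TinyTimesteps().to(device), torch.arange(4, device=device))
print(traced(torch.arange(4, device=device)).shape)  # torch.Size([4, 160])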
hmartiro (Member) commented

You could potentially try not tracing the UNet, if that is what's causing the issue: https://github.com/riffusion/riffusion-inference/blob/main/riffusion/server.py#L99

I do expect it to be too slow for real-time use on MPS, and not tracing will slow it down further.
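
A hedged sketch of what skipping the trace looks like: load the eager diffusers pipeline straight onto MPS instead of the pre-traced CUDA UNet (the model id and prompt below are assumptions for illustration):

import torch
from diffusers import StableDiffusionPipeline

device = "mps" if torch.backends.mps.is_available() else "cpu"
# Eager modules carry no baked-in device, unlike the serialized trace above.
pipe = StableDiffusionPipeline.from_pretrained("riffusion/riffusion-model-v1").to(device)
image = pipe(prompt="funk bassline with a jazzy saxophone solo", num_inference_steps=50).images[0]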

ashwal commented Dec 19, 2022

Can confirm that removing the trace works. I also did my own trace, but saw a roughly negligible difference, perhaps because aten::repeat_interleave.self_int is not supported on MPS.

I get ~2 it/s on an Apple M1 Max (~18 s per generated image).

The real bottleneck appears to be wav_bytes_from_spectrogram_image, at ~100 s. It doesn't look like MPS supports GriffinLim; it segfaults for me, but I only gave it a brief look.
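
One possible workaround, sketched under the assumption that the spectrogram tensor lives on MPS: keep torchaudio's GriffinLim transform on CPU, since the MPS path segfaults (the n_fft and shape values are illustrative):

import torch
import torchaudio

griffin_lim = torchaudio.transforms.GriffinLim(n_fft=2048, n_iter=32)  # stays on CPU

def spectrogram_to_waveform(spec: torch.Tensor) -> torch.Tensor:
    # Move the magnitude spectrogram off MPS before reconstruction.
    return griffin_lim(spec.detach().to("cpu"))

spec = torch.rand(1, 1025, 400)  # (channels, n_fft // 2 + 1, frames)
waveform = spectrogram_to_waveform(spec)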

hmartiro (Member) commented Dec 19, 2022

Interesting about GriffinLim. There's been some talk of finding a neural vocoder with better quality; that could be something to track.

hmartiro (Member) commented
Follow-up: a lot of Fourier operations are not supported on MPS yet, in particular anything involving the ComplexFloat data type: pytorch/pytorch#78044

Until that is resolved, image generation itself could work on MPS, but the entire audio stack will not.
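
A sketch of how one might work around that in the meantime (assuming a recent torch build): try the FFT op on the current device and rerun on CPU when MPS rejects the ComplexFloat result.

import torch

def stft_with_cpu_fallback(signal: torch.Tensor, n_fft: int = 1024) -> torch.Tensor:
    try:
        # FFT ops return ComplexFloat, which MPS cannot represent yet.
        return torch.stft(signal, n_fft=n_fft, return_complex=True)
    except (RuntimeError, NotImplementedError):
        return torch.stft(signal.cpu(), n_fft=n_fft, return_complex=True)

x = torch.randn(4096, device="mps" if torch.backends.mps.is_available() else "cpu")
print(stft_with_cpu_fallback(x).dtype)  # torch.complex64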

hmartiro (Member) commented
MPS is now supported as a device with CPU fallback for some operations. See the README for a description!
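
For readers skimming the thread, a hedged sketch of the generic PyTorch mechanism behind that kind of CPU fallback (the exact flags in the riffusion README may differ): set PYTORCH_ENABLE_MPS_FALLBACK so that ops missing on MPS transparently run on CPU.

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # conventionally set before importing torch

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(8, device=device)
# The aten::repeat_interleave.self_int overload was among the ops missing
# on MPS; with the fallback enabled it runs on CPU instead of erroring.
y = torch.repeat_interleave(x, 2)
print(y.shape)  # torch.Size([16])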
