Skip to content

Commit

Permalink
Make traced unet optional and support custom checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
hmartiro committed Dec 23, 2022
1 parent 40e1e51 commit 8349ccf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ numpy
pillow
pydub
scipy
soundfile
torch
torchaudio
transformers
46 changes: 23 additions & 23 deletions riffusion/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import base64
import dataclasses
import functools
import logging
import io
import json
Expand Down Expand Up @@ -45,6 +44,7 @@
def run_app(
*,
checkpoint: str = "riffusion/riffusion-model-v1",
no_traced_unet: bool = False,
host: str = "127.0.0.1",
port: int = 3000,
debug: bool = False,
Expand All @@ -56,7 +56,7 @@ def run_app(
"""
# Initialize the model
global MODEL
MODEL = load_model(checkpoint=checkpoint)
MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)

args = dict(
debug=debug,
Expand All @@ -72,7 +72,7 @@ def run_app(
app.run(**args)


def load_model(checkpoint: str):
def load_model(checkpoint: str, traced_unet: bool = True):
"""
Load the riffusion model pipeline.
"""
Expand All @@ -86,28 +86,30 @@ def load_model(checkpoint: str):
safety_checker=lambda images, **kwargs: (images, False),
).to("cuda")

@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Set the traced unet if desired
if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor

# Using traced unet from hf hub
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
# Using traced unet from hf hub
unet_file = hf_hub_download(
checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)

class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = model.unet.in_channels
self.device = model.unet.device
self.dtype = torch.float16
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = model.unet.in_channels
self.device = model.unet.device
self.dtype = torch.float16

def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)

model.unet = TracedUNet()
model.unet = TracedUNet()

model = model.to("cuda")

Expand Down Expand Up @@ -151,8 +153,6 @@ def run_inference():
return response


# TODO(hayk): Enable cache here.
# @functools.lru_cache()
def compute(inputs: InferenceInput) -> str:
"""
Does all the heavy lifting of the request.
Expand Down

0 comments on commit 8349ccf

Please sign in to comment.