Skip to content

Commit

Permalink
Add blueprint to depth guided stable diffusion example (#5582)
Browse files Browse the repository at this point in the history
### What

This was a tough one! But I'm somewhat happy with how this looks now.


![image](https://github.com/rerun-io/rerun/assets/1220815/c8813994-0cbe-4b11-8c39-62e507828b81)

For comparision, this is a social media post we did with this demo at
some point. It had a different aspect ratio though which is why I
diverged from this quite a bit

https://www.linkedin.com/posts/nikolauswest_visualizing-depth-guided-stable-diffusion-activity-7011344379873234945-IFS7/?trk=public_profile_like_view


### Checklist
* [x] I have read and agree to [Contributor
Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and
the [Code of
Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* [x] I've included a screenshot or gif (if applicable)
* [x] I have tested the web demo (if applicable):
* Using newly built examples:
[app.rerun.io](https://app.rerun.io/pr/5582/index.html)
* Using examples from latest `main` build:
[app.rerun.io](https://app.rerun.io/pr/5582/index.html?manifest_url=https://app.rerun.io/version/main/examples_manifest.json)
* Using full set of examples from `nightly` build:
[app.rerun.io](https://app.rerun.io/pr/5582/index.html?manifest_url=https://app.rerun.io/version/nightly/examples_manifest.json)
* [x] The PR title and labels are set such as to maximize their
usefulness for the next release's CHANGELOG
* [x] If applicable, add a new check to the [release
checklist](https://github.com/rerun-io/rerun/blob/main/tests/python/release_checklist)!

- [PR Build Summary](https://build.rerun.io/pr/5582)
- [Docs
preview](https://rerun.io/preview/f3c0af0de4cf5cb54526428cf4aaaa6c23f2c91b/docs)
<!--DOCS-PREVIEW-->
- [Examples
preview](https://rerun.io/preview/f3c0af0de4cf5cb54526428cf4aaaa6c23f2c91b/examples)
<!--EXAMPLES-PREVIEW-->
- [Recent benchmark results](https://build.rerun.io/graphs/crates.html)
- [Wasm size tracking](https://build.rerun.io/graphs/sizes.html)

---------

Co-authored-by: Clement Rey <cr.rey.clement@gmail.com>
  • Loading branch information
Wumpf and teh-cmc committed Mar 20, 2024
1 parent 676b6ef commit ce7eab6
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
rr.log("prompt/text_input/ids", rr.Tensor(text_input_ids))
rr.log("prompt/text_input/ids", rr.BarChart(text_input_ids))
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
Expand All @@ -229,7 +229,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
)

if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
rr.log("prompt/text_input/attention_mask", rr.Tensor(text_inputs.attention_mask))
rr.log("prompt/text_input/attention_mask", rr.BarChart(text_inputs.attention_mask))
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
Expand Down
53 changes: 52 additions & 1 deletion examples/python/depth_guided_stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import requests
import rerun as rr # pip install rerun-sdk
import rerun.blueprint as rrb
import torch
from huggingface_pipeline import StableDiffusionDepth2ImgPipeline
from PIL import Image
Expand Down Expand Up @@ -112,7 +113,57 @@ def main() -> None:
rr.script_add_args(parser)
args = parser.parse_args()

rr.script_setup(args, "rerun_example_depth_guided_stable_diffusion")
rr.script_setup(
args,
"rerun_example_depth_guided_stable_diffusion",
# This example is very complex, making it too hard for the Viewer to infer a good layout.
# Therefore, we specify everything explicitly:
# We set up three columns using a `Horizontal` layout, one each for
# * inputs
# * depth & initializations
# * diffusion outputs
blueprint=rrb.Blueprint(
rrb.Horizontal(
rrb.Vertical(
rrb.Tabs(
rrb.Spatial2DView(name="Image original", origin="image/original"),
rrb.TensorView(name="Image preprocessed", origin="input_image/preprocessed"),
),
rrb.Vertical(
rrb.TextLogView(name="Prompt", contents=["prompt/text", "prompt/text_negative"]),
rrb.Tabs(
rrb.TensorView(name="Text embeddings", origin="prompt/text_embeddings"),
rrb.TensorView(name="Unconditional embeddings", origin="prompt/uncond_embeddings"),
),
rrb.BarChartView(name="Prompt ids", origin="prompt/text_input"),
),
),
rrb.Vertical(
rrb.Tabs(
rrb.Spatial2DView(name="Depth estimated", origin="depth/estimated"),
rrb.Spatial2DView(name="Depth interpolated", origin="depth/interpolated"),
rrb.Spatial2DView(name="Depth normalized", origin="depth/normalized"),
rrb.TensorView(name="Depth input pre-processed", origin="depth/input_preprocessed"),
active_tab="Depth interpolated",
),
rrb.Tabs(
rrb.TensorView(name="Encoded input", origin="encoded_input_image"),
rrb.TensorView(name="Decoded init latents", origin="decoded_init_latents"),
),
),
rrb.Vertical(
rrb.Spatial2DView(name="Image diffused", origin="image/diffused"),
rrb.Horizontal(
rrb.TensorView(name="Latent Model Input", origin="diffusion/latent_model_input"),
rrb.TensorView(name="Diffusion latents", origin="diffusion/latents"),
# rrb.TensorView(name="Noise Prediction", origin="diffusion/noise_pred"),
),
),
),
rrb.SelectionPanel(expanded=False),
rrb.TimePanel(expanded=False),
),
)

image_path = args.image_path # type: str
if not image_path:
Expand Down
2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def values__field_converter_override(data: TensorDataArrayLike) -> TensorDataBat
# once we coerce to a canonical non-arrow type.
shape_dims = tensor_data.as_arrow_array()[0].value["shape"].values.field(0).to_numpy()

if len(shape_dims) != 1:
if len([d for d in shape_dims if d != 1]) != 1:
_send_warning_or_raise(
f"Bar chart data should only be 1D. Got values with shape: {shape_dims}",
2,
Expand Down

0 comments on commit ce7eab6

Please sign in to comment.