add requirements.txt, annotate the script and add reference to index.rst
bowang007 committed May 7, 2024
1 parent 62d8773 commit 4bc05b7
Showing 4 changed files with 56 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docsrc/index.rst
@@ -111,6 +111,9 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion


Python API Documentation
------------------------
25 changes: 25 additions & 0 deletions examples/distributed_inference/data_parallel_gpt2.py
@@ -1,3 +1,19 @@
"""
.. _data_parallel_gpt2:
Torch-TensorRT Distributed Inference
======================================================
This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on the GPT2 model.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel
@@ -6,6 +22,7 @@

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set input prompts for different devices
prompt1 = "GPT2 is a model developed by."
prompt2 = "Llama is a model developed by "

@@ -14,8 +31,11 @@

# PartialState detects the distributed launch environment and exposes the per-process device and index
distributed_state = PartialState()

# Import GPT2 model and load to distributed devices
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)


# Instantiate model with Torch-TensorRT backend
model.forward = torch.compile(
model.forward,
backend="torch_tensorrt",
@@ -27,6 +47,11 @@
dynamic=False,
)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
cur_input = torch.clone(prompt[0]).to(distributed_state.device)

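For illustration only: a minimal, hypothetical sketch of how the generation loop inside the with block above is commonly completed (a simple greedy-decoding loop over the variables defined earlier); it is not the code elided from this diff.

    # Hypothetical greedy-decoding continuation inside the with block (illustrative only)
    for _ in range(30):
        logits = model(cur_input).logits  # compiled forward pass
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        cur_input = torch.cat([cur_input, next_token], dim=-1)  # append the new token
    print(tokenizer.decode(cur_input[0], skip_special_tokens=True))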
26 changes: 25 additions & 1 deletion examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -1,3 +1,18 @@
"""
.. _data_parallel_stable_diffusion:
Torch-TensorRT Distributed Inference
======================================================
This interactive script is intended as a sample of distributed inference using data
parallelism with the Accelerate library and the Torch-TensorRT workflow on the Stable Diffusion model.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline
@@ -17,7 +32,10 @@
backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
pipe.unet,
backend=backend,
options={
@@ -30,6 +48,12 @@
)
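# Enable runtime device checks so each TensorRT engine executes on the GPU it was built for when multiple devices are in use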
torch_tensorrt.runtime.set_multi_device_safe_mode(True)


# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
print("before \n")
result = pipe(prompt).images[0]
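For illustration only, a hypothetical way such a data-parallel script might persist the per-process result, using Accelerate's process_index attribute; this is not the code elided from this diff.

    # Hypothetical: save each process's generated image under its own name (illustrative only)
    result.save(f"result_{distributed_state.process_index}.png")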
3 changes: 3 additions & 0 deletions examples/distributed_inference/requirement.txt
@@ -0,0 +1,3 @@
accelerate
transformers
diffusers
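With these dependencies installed (for example via pip install -r requirement.txt, alongside torch and torch_tensorrt themselves), such scripts are typically launched across the available GPUs with Accelerate's launcher, e.g. accelerate launch data_parallel_gpt2.py.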
