Skip to content

Commit

Permalink
Make T2I-Adapter downscale padding match the UNet (huggingface#5435)
Browse files Browse the repository at this point in the history
* Update get_dummy_inputs(...) in T2I-Adapter tests to take image height and width as params.

* Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised.

* Update the T2I-Adapter down blocks to better match the padding behavior of the UNet.

* Revert "Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised."

This reverts commit 6d4a060.

* Create  utility functions for testing the T2I-Adapter downscaling bahevior.

* (minor) Improve readability with an intermediate named variable.

* Statically parameterize  T2I-Adapter test dimensions rather than generating them dynamically.

* Fix static checks.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
RyanJDick and sayakpaul committed Oct 23, 2023
1 parent 9dbed6a commit 96f42a4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
42 changes: 26 additions & 16 deletions models/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
from .resnet import Downsample2D


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -51,24 +50,28 @@ def __init__(self, adapters: List["T2IAdapter"]):
if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

# The outputs from each adapter are added together with a weight
# This means that the change in dimenstions from downsampling must
# be the same for all adapters. Inductively, it also means the total
# downscale factor must also be the same for all adapters.

# The outputs from each adapter are added together with a weight.
# This means that the change in dimensions from downsampling must
# be the same for all adapters. Inductively, it also means the
# downscale_factor and total_downscale_factor must be the same for all
# adapters.
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor

first_adapter_downscale_factor = adapters[0].downscale_factor
for idx in range(1, len(adapters)):
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor

if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
if (
adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
or adapters[idx].downscale_factor != first_adapter_downscale_factor
):
raise ValueError(
f"Expecting all adapters to have the same total_downscale_factor, "
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
f"Expecting all adapters to have the same downscaling behavior, but got:\n"
f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
)

self.total_downscale_factor = adapters[0].total_downscale_factor
self.total_downscale_factor = first_adapter_total_downscale_factor
self.downscale_factor = first_adapter_downscale_factor

def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Expand Down Expand Up @@ -274,6 +277,13 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
def total_downscale_factor(self):
return self.adapter.total_downscale_factor

@property
def downscale_factor(self):
"""The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
not evenly divisible by the downscale_factor then an exception will be raised.
"""
return self.adapter.unshuffle.downscale_factor


# full adapter

Expand Down Expand Up @@ -399,7 +409,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

self.in_conv = None
if in_channels != out_channels:
Expand Down Expand Up @@ -526,7 +536,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
Expand Down
8 changes: 4 additions & 4 deletions pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,17 +568,17 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[-1]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

return height, width

Expand Down
8 changes: 4 additions & 4 deletions pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,17 +622,17 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[-1]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

return height, width

Expand Down

0 comments on commit 96f42a4

Please sign in to comment.