diff --git a/models/adapter.py b/models/adapter.py index 64d64d07bf77..388915e7c02d 100644 --- a/models/adapter.py +++ b/models/adapter.py @@ -20,7 +20,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import logging from .modeling_utils import ModelMixin -from .resnet import Downsample2D logger = logging.get_logger(__name__) @@ -51,24 +50,28 @@ def __init__(self, adapters: List["T2IAdapter"]): if len(adapters) == 1: raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`") - # The outputs from each adapter are added together with a weight - # This means that the change in dimenstions from downsampling must - # be the same for all adapters. Inductively, it also means the total - # downscale factor must also be the same for all adapters. - + # The outputs from each adapter are added together with a weight. + # This means that the change in dimensions from downsampling must + # be the same for all adapters. Inductively, it also means the + # downscale_factor and total_downscale_factor must be the same for all + # adapters. first_adapter_total_downscale_factor = adapters[0].total_downscale_factor - + first_adapter_downscale_factor = adapters[0].downscale_factor for idx in range(1, len(adapters)): - adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor - - if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor: + if ( + adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor + or adapters[idx].downscale_factor != first_adapter_downscale_factor + ): raise ValueError( - f"Expecting all adapters to have the same total_downscale_factor, " - f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and " - f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}" + f"Expecting all adapters to have the same downscaling behavior, but got:\n" + f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n" + f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n" + f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n" + f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}" ) - self.total_downscale_factor = adapters[0].total_downscale_factor + self.total_downscale_factor = first_adapter_total_downscale_factor + self.downscale_factor = first_adapter_downscale_factor def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]: r""" @@ -274,6 +277,13 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def total_downscale_factor(self): return self.adapter.total_downscale_factor + @property + def downscale_factor(self): + """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are + not evenly divisible by the downscale_factor then an exception will be raised. + """ + return self.adapter.unshuffle.downscale_factor + # full adapter @@ -399,7 +409,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow self.downsample = None if down: - self.downsample = Downsample2D(in_channels) + self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) self.in_conv = None if in_channels != out_channels: @@ -526,7 +536,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow self.downsample = None if down: - self.downsample = Downsample2D(in_channels) + self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1) self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)]) diff --git a/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index b0f20199b48b..d2b9bfb00d6c 100644 --- a/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -568,8 +568,8 @@ def _default_height_width(self, height, width, image): elif isinstance(image, torch.Tensor): height = image.shape[-2] - # round down to nearest multiple of `self.adapter.total_downscale_factor` - height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor if width is None: if isinstance(image, PIL.Image.Image): @@ -577,8 +577,8 @@ def _default_height_width(self, height, width, image): elif isinstance(image, torch.Tensor): width = image.shape[-1] - # round down to nearest multiple of `self.adapter.total_downscale_factor` - width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor return height, width diff --git a/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 4e8f6a9d834b..2a3fca7f4603 100644 --- a/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -622,8 +622,8 @@ def _default_height_width(self, height, width, image): elif isinstance(image, torch.Tensor): height = image.shape[-2] - # round down to nearest multiple of `self.adapter.total_downscale_factor` - height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor if width is None: if isinstance(image, PIL.Image.Image): @@ -631,8 +631,8 @@ def _default_height_width(self, height, width, image): elif isinstance(image, torch.Tensor): width = image.shape[-1] - # round down to nearest multiple of `self.adapter.total_downscale_factor` - width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor return height, width