From 26574f5853e4b5cbb98ea719a3ca2a992d437638 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 24 Apr 2026 12:52:58 -0500 Subject: [PATCH 01/11] Fix SeedVR2 native VAE execution --- comfy/ldm/seedvr/vae.py | 73 +++++++++++++++++++++++++++++++++++- comfy/model_detection.py | 11 ++++++ comfy/sd.py | 60 +++++++++++++++++++++++++++-- comfy_extras/nodes_seedvr.py | 61 +++++++++++++++++++++++++++--- nodes.py | 20 ++++++++++ 5 files changed, 215 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 9eae4bc52b67..a79d9a06aa42 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1,6 +1,8 @@ from contextlib import nullcontext from typing import Literal, Optional, Tuple import gc +import os +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -21,6 +23,11 @@ ops = comfy.ops.disable_weight_init +def issue130_trace(message): + if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": + print(f"ISSUE130_TRACE {time.monotonic():.6f} seedvr_vae {message}", flush=True) + + @torch.inference_mode() def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True, **kwargs): @@ -2120,23 +2127,41 @@ def decode_( def _encode( self, x, memory_state = MemoryState.DISABLED ) -> torch.Tensor: + issue130_trace( + "VideoAutoencoderKL._encode.start " + f"x_shape={tuple(x.shape)} x_device={x.device} target_device={self.device} memory_state={memory_state}" + ) + start_time = time.monotonic() _x = x.to(self.device) h = self.encoder(_x, memory_state=memory_state) if self.quant_conv is not None: output = self.quant_conv(h, memory_state=memory_state) else: output = h + issue130_trace( + "VideoAutoencoderKL._encode.end " + f"output_shape={tuple(output.shape)} output_device={output.device} elapsed={time.monotonic() - start_time:.3f}" + ) return output.to(x.device) def _decode( self, z, memory_state = MemoryState.DISABLED ) -> torch.Tensor: + issue130_trace( + "VideoAutoencoderKL._decode.start " + f"z_shape={tuple(z.shape)} z_device={z.device} target_device={self.device} memory_state={memory_state}" + ) + start_time = time.monotonic() _z = z.to(self.device) if self.post_quant_conv is not None: _z = self.post_quant_conv(_z, memory_state=memory_state) output = self.decoder(_z, memory_state=memory_state) + issue130_trace( + "VideoAutoencoderKL._decode.end " + f"output_shape={tuple(output.shape)} output_device={output.device} elapsed={time.monotonic() - start_time:.3f}" + ) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: @@ -2230,19 +2255,42 @@ def forward(self, x: torch.FloatTensor): return x, z, p def encode(self, x, orig_dims=None): + start_time = time.monotonic() + issue130_trace( + "VideoAutoencoderKLWrapper.encode.start " + f"x_shape={tuple(x.shape)} x_device={x.device} x_dtype={x.dtype} orig_dims={orig_dims}" + ) # we need to keep a reference to the image/video so we later can do a colour fix later #self.original_image_video = x if orig_dims is not None: self.img_dims = orig_dims if x.ndim == 4: x = x.unsqueeze(2) - x = x.to(next(self.parameters()).dtype) - x = x.to(next(self.parameters()).device) + issue130_trace(f"VideoAutoencoderKLWrapper.encode.after_unsqueeze x_shape={tuple(x.shape)}") + x = x.to(dtype=next(self.parameters()).dtype) + self.device = x.device + issue130_trace( + "VideoAutoencoderKLWrapper.encode.before_super " + f"x_shape={tuple(x.shape)} x_device={x.device} x_dtype={x.dtype}" + ) p = super().encode(x) + issue130_trace( + "VideoAutoencoderKLWrapper.encode.after_super " + f"latent_dist_shape={tuple(p.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) z = p.squeeze(2) + issue130_trace( + "VideoAutoencoderKLWrapper.encode.end " + f"z_shape={tuple(z.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) return z, p def decode(self, z): + start_time = time.monotonic() + issue130_trace( + "VideoAutoencoderKLWrapper.decode.start " + f"z_shape={tuple(z.shape)} z_device={z.device} z_dtype={z.dtype}" + ) b, tc, h, w = z.shape latent = z.view(b, 16, -1, h, w) scale = 0.9152 @@ -2252,12 +2300,25 @@ def decode(self, z): if latent.ndim == 4: latent = latent.unsqueeze(2) + self.device = latent.device self.enable_tiling = self.tiled_args.get("enable_tiling", False) if self.enable_tiling: + issue130_trace( + "VideoAutoencoderKLWrapper.decode.before_tiled_decode " + f"latent_shape={tuple(latent.shape)} tiled_args={self.tiled_args}" + ) x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) else: + issue130_trace( + "VideoAutoencoderKLWrapper.decode.before_super_decode " + f"latent_shape={tuple(latent.shape)}" + ) x = super().decode_(latent).squeeze(2) + issue130_trace( + "VideoAutoencoderKLWrapper.decode.after_model_decode " + f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w") if x.ndim == 4: @@ -2281,6 +2342,10 @@ def decode(self, z): x = x[..., :o_h, :o_w] input = input[..., :o_h, :o_w ] x = lab_color_transfer(x, input) + issue130_trace( + "VideoAutoencoderKLWrapper.decode.after_lab_color_transfer " + f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) x = x.unsqueeze(0) x = rearrange(x, "b t c h w -> b c t h w") @@ -2291,6 +2356,10 @@ def decode(self, z): h2 = h - (h % 2) x = x[..., :h2, :w2] + issue130_trace( + "VideoAutoencoderKLWrapper.decode.end " + f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 499fc71eebda..bd2a2d3944e6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + elif "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = None + dit_config["rope_type"] = None + dit_config["mlp_type"] = "normal" + return dit_config elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" diff --git a/comfy/sd.py b/comfy/sd.py index bf58a8cc5ff7..dffcc08fd696 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -26,9 +26,15 @@ import yaml import math import os +import time import comfy.utils + +def issue130_trace(message): + if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": + print(f"ISSUE130_TRACE {time.monotonic():.6f} sd {message}", flush=True) + from . import clip_vision from . import gligen from . import diffusers_convert @@ -439,7 +445,8 @@ def decode(self, token_ids, skip_special_tokens=True): class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd + if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if metadata is None or metadata.get("keep_diffusers_format") != "true": sd = diffusers_convert.convert_vae_state_dict(sd) @@ -512,8 +519,11 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.latent_channels = 16 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.latent_channels = 16 + self.latent_dim = 3 + self.disable_offload = True + self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[-2] * shape[-1] * (4 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) @@ -976,16 +986,24 @@ def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap= def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() + trace_start = time.monotonic() + issue130_trace( + "VAE.decode.start " + f"samples_shape={tuple(samples_in.shape)} samples_device={samples_in.device} " + f"samples_dtype={samples_in.dtype} vae_dtype={self.vae_dtype}" + ) pixel_samples = None do_tile = False if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) + issue130_trace(f"VAE.decode.memory_required={memory_used}") model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) + issue130_trace(f"VAE.decode.batch_number={batch_number} free_memory={free_memory}") # Pre-allocate output for VAEs that support direct buffer writes preallocated = False @@ -994,6 +1012,7 @@ def decode(self, samples_in, vae_options={}): preallocated = True for x in range(0, samples_in.shape[0], batch_number): + issue130_trace(f"VAE.decode.batch_start start={x} stop={x + batch_number}") samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) if preallocated: self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) @@ -1004,6 +1023,7 @@ def decode(self, samples_in, vae_options={}): pixel_samples[x:x+batch_number].copy_(out) del out self.process_output(pixel_samples[x:x+batch_number]) + issue130_trace(f"VAE.decode.batch_end start={x} elapsed={time.monotonic() - trace_start:.3f}") except Exception as e: model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") @@ -1014,6 +1034,7 @@ def decode(self, samples_in, vae_options={}): do_tile = True if do_tile: + issue130_trace("VAE.decode.tiled_fallback_start") comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: @@ -1024,12 +1045,23 @@ def decode(self, samples_in, vae_options={}): tile = 256 // self.spacial_compression_decode() overlap = tile // 4 pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + issue130_trace(f"VAE.decode.tiled_fallback_end elapsed={time.monotonic() - trace_start:.3f}") pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) + issue130_trace( + "VAE.decode.end " + f"pixel_shape={tuple(pixel_samples.shape)} elapsed={time.monotonic() - trace_start:.3f}" + ) return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() + trace_start = time.monotonic() + issue130_trace( + "VAE.decode_tiled.start " + f"samples_shape={tuple(samples.shape)} tile_x={tile_x} tile_y={tile_y} " + f"overlap={overlap} tile_t={tile_t} overlap_t={overlap_t}" + ) memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 @@ -1055,10 +1087,20 @@ def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=N args["tile_t"] = max(2, tile_t) output = self.decode_tiled_3d(samples, **args) + issue130_trace( + "VAE.decode_tiled.end " + f"output_shape={tuple(output.shape)} elapsed={time.monotonic() - trace_start:.3f}" + ) return output.movedim(1, -1) def encode(self, pixel_samples): self.throw_exception_if_invalid() + trace_start = time.monotonic() + issue130_trace( + "VAE.encode.start " + f"pixel_shape={tuple(pixel_samples.shape)} pixel_device={pixel_samples.device} " + f"pixel_dtype={pixel_samples.dtype} latent_dim={self.latent_dim}" + ) pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) do_tile = False @@ -1069,22 +1111,28 @@ def encode(self, pixel_samples): pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) + issue130_trace(f"VAE.encode.memory_required={memory_used}") model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) + issue130_trace(f"VAE.encode.batch_number={batch_number} free_memory={free_memory}") samples = None for x in range(0, pixel_samples.shape[0], batch_number): + issue130_trace(f"VAE.encode.batch_start start={x} stop={x + batch_number}") pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): out = self.first_stage_model.encode(pixels_in, device=self.device) else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) + if isinstance(out, tuple): + out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out + issue130_trace(f"VAE.encode.batch_end start={x} elapsed={time.monotonic() - trace_start:.3f}") except Exception as e: model_management.raise_non_oom(e) @@ -1096,6 +1144,7 @@ def encode(self, pixel_samples): do_tile = True if do_tile: + issue130_trace("VAE.encode.tiled_fallback_start") comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 @@ -1105,7 +1154,12 @@ def encode(self, pixel_samples): samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) + issue130_trace(f"VAE.encode.tiled_fallback_end elapsed={time.monotonic() - trace_start:.3f}") + issue130_trace( + "VAE.encode.end " + f"samples_shape={tuple(samples.shape)} elapsed={time.monotonic() - trace_start:.3f}" + ) return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 8aa166c48d4e..ae6545865a37 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -2,6 +2,8 @@ from comfy_api.latest import ComfyExtension, io import torch import math +import os +import time from einops import rearrange import gc @@ -13,6 +15,11 @@ from torchvision.transforms.functional import InterpolationMode from comfy.ldm.seedvr.vae import tiled_vae + +def issue130_trace(message): + if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": + print(f"ISSUE130_TRACE {time.monotonic():.6f} nodes_seedvr {message}", flush=True) + def clear_vae_memory(vae_model): for module in vae_model.modules(): if hasattr(module, "memory"): @@ -143,7 +150,15 @@ def define_schema(cls): @classmethod def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, temporal_tile_size, enable_tiling): + issue130_trace( + "SeedVR2InputProcessing.start " + f"images_shape={tuple(images.shape)} resolution={resolution} " + f"spatial_tile_size={spatial_tile_size} spatial_overlap={spatial_overlap} " + f"temporal_tile_size={temporal_tile_size} enable_tiling={enable_tiling}" + ) + start_time = time.monotonic() comfy.model_management.load_models_gpu([vae.patcher]) + issue130_trace("SeedVR2InputProcessing.after_load_models_gpu") vae_model = vae.first_stage_model scale = 0.9152 shift = 0 @@ -166,8 +181,8 @@ def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, te images = images.reshape(b, t, c, new_h, new_w) images = cut_videos(images) - - images = rearrange(images, "b t c h w -> b c t h w") + images_bcthw = rearrange(images, "b t c h w -> b c t h w") + images_bthwc = rearrange(images, "b t c h w -> b t h w c") # in case users a non-compatiable number for tiling def make_divisible(val, divisor): @@ -182,24 +197,51 @@ def make_divisible(val, divisor): args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size} if enable_tiling: - latent = tiled_vae(images, vae_model, encode=True, **args) + issue130_trace( + "SeedVR2InputProcessing.before_tiled_encode " + f"images_shape={tuple(images_bthwc.shape)} args={args}" + ) + vae_model.img_dims = [o_h, o_w] + vae_model.original_image_video = images_bcthw + latent = vae.encode_tiled( + images_bthwc, + tile_x=spatial_tile_size, + tile_y=spatial_tile_size, + overlap=spatial_overlap, + tile_t=temporal_tile_size, + ) else: - latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0] + issue130_trace( + "SeedVR2InputProcessing.before_direct_encode " + f"images_shape={tuple(images_bthwc.shape)} orig_dims={[o_h, o_w]}" + ) + vae_model.img_dims = [o_h, o_w] + vae_model.original_image_video = images_bcthw + latent = vae.encode(images_bthwc) + issue130_trace( + "SeedVR2InputProcessing.after_encode " + f"latent_shape={tuple(latent.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) clear_vae_memory(vae_model) + issue130_trace("SeedVR2InputProcessing.after_clear_vae_memory") #images = images.to(offload_device) #vae_model = vae_model.to(offload_device) vae_model.img_dims = [o_h, o_w] args["enable_tiling"] = enable_tiling vae_model.tiled_args = args - vae_model.original_image_video = images + vae_model.original_image_video = images_bcthw latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... c") latent = (latent - shift) * scale + issue130_trace( + "SeedVR2InputProcessing.end " + f"output_shape={tuple(latent.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) return io.NodeOutput({"samples": latent}) class SeedVR2Conditioning(io.ComfyNode): @@ -221,7 +263,12 @@ def define_schema(cls): @classmethod def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput: + start_time = time.monotonic() vae_conditioning = vae_conditioning["samples"] + issue130_trace( + "SeedVR2Conditioning.start " + f"vae_conditioning_shape={tuple(vae_conditioning.shape)} latent_noise_scale={latent_noise_scale}" + ) device = vae_conditioning.device model = model.model.diffusion_model pos_cond = model.positive_conditioning @@ -260,6 +307,10 @@ def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput: negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] + issue130_trace( + "SeedVR2Conditioning.end " + f"latent_shape={tuple(noises.shape)} elapsed={time.monotonic() - start_time:.3f}" + ) return io.NodeOutput(positive, negative, {"samples": noises}) class SeedVRExtension(ComfyExtension): diff --git a/nodes.py b/nodes.py index 9059ed639d92..20ec946e151d 100644 --- a/nodes.py +++ b/nodes.py @@ -313,9 +313,29 @@ def decode(self, vae, samples): if latent.is_nested: latent = latent.unbind()[0] + trace_enabled = os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1" + start_time = time.monotonic() + if trace_enabled: + print( + f"ISSUE130_TRACE {start_time:.6f} nodes VAEDecode.start " + f"latent_shape={tuple(latent.shape)}", + flush=True, + ) images = vae.decode(latent) + if trace_enabled: + print( + f"ISSUE130_TRACE {time.monotonic():.6f} nodes VAEDecode.after_decode " + f"images_shape={tuple(images.shape)} elapsed={time.monotonic() - start_time:.3f}", + flush=True, + ) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + if trace_enabled: + print( + f"ISSUE130_TRACE {time.monotonic():.6f} nodes VAEDecode.end " + f"images_shape={tuple(images.shape)} elapsed={time.monotonic() - start_time:.3f}", + flush=True, + ) return (images, ) class VAEDecodeTiled: From 1558a81e937fb07ba05dbcf5ca85f12cebdf08e1 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 24 Apr 2026 14:08:02 -0500 Subject: [PATCH 02/11] Remove SeedVR2 diagnostic tracing --- comfy/ldm/seedvr/vae.py | 68 ------------------------------------ comfy/sd.py | 48 ------------------------- comfy_extras/nodes_seedvr.py | 40 --------------------- nodes.py | 20 ----------- 4 files changed, 176 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index a79d9a06aa42..d211558f6ae3 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1,8 +1,6 @@ from contextlib import nullcontext from typing import Literal, Optional, Tuple import gc -import os -import time import torch import torch.nn as nn import torch.nn.functional as F @@ -23,11 +21,6 @@ ops = comfy.ops.disable_weight_init -def issue130_trace(message): - if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": - print(f"ISSUE130_TRACE {time.monotonic():.6f} seedvr_vae {message}", flush=True) - - @torch.inference_mode() def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True, **kwargs): @@ -2127,41 +2120,23 @@ def decode_( def _encode( self, x, memory_state = MemoryState.DISABLED ) -> torch.Tensor: - issue130_trace( - "VideoAutoencoderKL._encode.start " - f"x_shape={tuple(x.shape)} x_device={x.device} target_device={self.device} memory_state={memory_state}" - ) - start_time = time.monotonic() _x = x.to(self.device) h = self.encoder(_x, memory_state=memory_state) if self.quant_conv is not None: output = self.quant_conv(h, memory_state=memory_state) else: output = h - issue130_trace( - "VideoAutoencoderKL._encode.end " - f"output_shape={tuple(output.shape)} output_device={output.device} elapsed={time.monotonic() - start_time:.3f}" - ) return output.to(x.device) def _decode( self, z, memory_state = MemoryState.DISABLED ) -> torch.Tensor: - issue130_trace( - "VideoAutoencoderKL._decode.start " - f"z_shape={tuple(z.shape)} z_device={z.device} target_device={self.device} memory_state={memory_state}" - ) - start_time = time.monotonic() _z = z.to(self.device) if self.post_quant_conv is not None: _z = self.post_quant_conv(_z, memory_state=memory_state) output = self.decoder(_z, memory_state=memory_state) - issue130_trace( - "VideoAutoencoderKL._decode.end " - f"output_shape={tuple(output.shape)} output_device={output.device} elapsed={time.monotonic() - start_time:.3f}" - ) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: @@ -2255,42 +2230,19 @@ def forward(self, x: torch.FloatTensor): return x, z, p def encode(self, x, orig_dims=None): - start_time = time.monotonic() - issue130_trace( - "VideoAutoencoderKLWrapper.encode.start " - f"x_shape={tuple(x.shape)} x_device={x.device} x_dtype={x.dtype} orig_dims={orig_dims}" - ) # we need to keep a reference to the image/video so we later can do a colour fix later #self.original_image_video = x if orig_dims is not None: self.img_dims = orig_dims if x.ndim == 4: x = x.unsqueeze(2) - issue130_trace(f"VideoAutoencoderKLWrapper.encode.after_unsqueeze x_shape={tuple(x.shape)}") x = x.to(dtype=next(self.parameters()).dtype) self.device = x.device - issue130_trace( - "VideoAutoencoderKLWrapper.encode.before_super " - f"x_shape={tuple(x.shape)} x_device={x.device} x_dtype={x.dtype}" - ) p = super().encode(x) - issue130_trace( - "VideoAutoencoderKLWrapper.encode.after_super " - f"latent_dist_shape={tuple(p.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) z = p.squeeze(2) - issue130_trace( - "VideoAutoencoderKLWrapper.encode.end " - f"z_shape={tuple(z.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) return z, p def decode(self, z): - start_time = time.monotonic() - issue130_trace( - "VideoAutoencoderKLWrapper.decode.start " - f"z_shape={tuple(z.shape)} z_device={z.device} z_dtype={z.dtype}" - ) b, tc, h, w = z.shape latent = z.view(b, 16, -1, h, w) scale = 0.9152 @@ -2304,21 +2256,9 @@ def decode(self, z): self.enable_tiling = self.tiled_args.get("enable_tiling", False) if self.enable_tiling: - issue130_trace( - "VideoAutoencoderKLWrapper.decode.before_tiled_decode " - f"latent_shape={tuple(latent.shape)} tiled_args={self.tiled_args}" - ) x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) else: - issue130_trace( - "VideoAutoencoderKLWrapper.decode.before_super_decode " - f"latent_shape={tuple(latent.shape)}" - ) x = super().decode_(latent).squeeze(2) - issue130_trace( - "VideoAutoencoderKLWrapper.decode.after_model_decode " - f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w") if x.ndim == 4: @@ -2342,10 +2282,6 @@ def decode(self, z): x = x[..., :o_h, :o_w] input = input[..., :o_h, :o_w ] x = lab_color_transfer(x, input) - issue130_trace( - "VideoAutoencoderKLWrapper.decode.after_lab_color_transfer " - f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) x = x.unsqueeze(0) x = rearrange(x, "b t c h w -> b c t h w") @@ -2356,10 +2292,6 @@ def decode(self, z): h2 = h - (h % 2) x = x[..., :h2, :w2] - issue130_trace( - "VideoAutoencoderKLWrapper.decode.end " - f"x_shape={tuple(x.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): diff --git a/comfy/sd.py b/comfy/sd.py index dffcc08fd696..4f2b7cf75341 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -26,15 +26,9 @@ import yaml import math import os -import time import comfy.utils - -def issue130_trace(message): - if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": - print(f"ISSUE130_TRACE {time.monotonic():.6f} sd {message}", flush=True) - from . import clip_vision from . import gligen from . import diffusers_convert @@ -986,24 +980,16 @@ def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap= def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() - trace_start = time.monotonic() - issue130_trace( - "VAE.decode.start " - f"samples_shape={tuple(samples_in.shape)} samples_device={samples_in.device} " - f"samples_dtype={samples_in.dtype} vae_dtype={self.vae_dtype}" - ) pixel_samples = None do_tile = False if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - issue130_trace(f"VAE.decode.memory_required={memory_used}") model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - issue130_trace(f"VAE.decode.batch_number={batch_number} free_memory={free_memory}") # Pre-allocate output for VAEs that support direct buffer writes preallocated = False @@ -1012,7 +998,6 @@ def decode(self, samples_in, vae_options={}): preallocated = True for x in range(0, samples_in.shape[0], batch_number): - issue130_trace(f"VAE.decode.batch_start start={x} stop={x + batch_number}") samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) if preallocated: self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) @@ -1023,7 +1008,6 @@ def decode(self, samples_in, vae_options={}): pixel_samples[x:x+batch_number].copy_(out) del out self.process_output(pixel_samples[x:x+batch_number]) - issue130_trace(f"VAE.decode.batch_end start={x} elapsed={time.monotonic() - trace_start:.3f}") except Exception as e: model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") @@ -1034,7 +1018,6 @@ def decode(self, samples_in, vae_options={}): do_tile = True if do_tile: - issue130_trace("VAE.decode.tiled_fallback_start") comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: @@ -1045,23 +1028,12 @@ def decode(self, samples_in, vae_options={}): tile = 256 // self.spacial_compression_decode() overlap = tile // 4 pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - issue130_trace(f"VAE.decode.tiled_fallback_end elapsed={time.monotonic() - trace_start:.3f}") pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) - issue130_trace( - "VAE.decode.end " - f"pixel_shape={tuple(pixel_samples.shape)} elapsed={time.monotonic() - trace_start:.3f}" - ) return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() - trace_start = time.monotonic() - issue130_trace( - "VAE.decode_tiled.start " - f"samples_shape={tuple(samples.shape)} tile_x={tile_x} tile_y={tile_y} " - f"overlap={overlap} tile_t={tile_t} overlap_t={overlap_t}" - ) memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 @@ -1087,20 +1059,10 @@ def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=N args["tile_t"] = max(2, tile_t) output = self.decode_tiled_3d(samples, **args) - issue130_trace( - "VAE.decode_tiled.end " - f"output_shape={tuple(output.shape)} elapsed={time.monotonic() - trace_start:.3f}" - ) return output.movedim(1, -1) def encode(self, pixel_samples): self.throw_exception_if_invalid() - trace_start = time.monotonic() - issue130_trace( - "VAE.encode.start " - f"pixel_shape={tuple(pixel_samples.shape)} pixel_device={pixel_samples.device} " - f"pixel_dtype={pixel_samples.dtype} latent_dim={self.latent_dim}" - ) pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) do_tile = False @@ -1111,15 +1073,12 @@ def encode(self, pixel_samples): pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - issue130_trace(f"VAE.encode.memory_required={memory_used}") model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) - issue130_trace(f"VAE.encode.batch_number={batch_number} free_memory={free_memory}") samples = None for x in range(0, pixel_samples.shape[0], batch_number): - issue130_trace(f"VAE.encode.batch_start start={x} stop={x + batch_number}") pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): out = self.first_stage_model.encode(pixels_in, device=self.device) @@ -1132,7 +1091,6 @@ def encode(self, pixel_samples): if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out - issue130_trace(f"VAE.encode.batch_end start={x} elapsed={time.monotonic() - trace_start:.3f}") except Exception as e: model_management.raise_non_oom(e) @@ -1144,7 +1102,6 @@ def encode(self, pixel_samples): do_tile = True if do_tile: - issue130_trace("VAE.encode.tiled_fallback_start") comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 @@ -1154,12 +1111,7 @@ def encode(self, pixel_samples): samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - issue130_trace(f"VAE.encode.tiled_fallback_end elapsed={time.monotonic() - trace_start:.3f}") - issue130_trace( - "VAE.encode.end " - f"samples_shape={tuple(samples.shape)} elapsed={time.monotonic() - trace_start:.3f}" - ) return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index ae6545865a37..0125cac60f52 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -2,8 +2,6 @@ from comfy_api.latest import ComfyExtension, io import torch import math -import os -import time from einops import rearrange import gc @@ -16,10 +14,6 @@ from comfy.ldm.seedvr.vae import tiled_vae -def issue130_trace(message): - if os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1": - print(f"ISSUE130_TRACE {time.monotonic():.6f} nodes_seedvr {message}", flush=True) - def clear_vae_memory(vae_model): for module in vae_model.modules(): if hasattr(module, "memory"): @@ -150,15 +144,7 @@ def define_schema(cls): @classmethod def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, temporal_tile_size, enable_tiling): - issue130_trace( - "SeedVR2InputProcessing.start " - f"images_shape={tuple(images.shape)} resolution={resolution} " - f"spatial_tile_size={spatial_tile_size} spatial_overlap={spatial_overlap} " - f"temporal_tile_size={temporal_tile_size} enable_tiling={enable_tiling}" - ) - start_time = time.monotonic() comfy.model_management.load_models_gpu([vae.patcher]) - issue130_trace("SeedVR2InputProcessing.after_load_models_gpu") vae_model = vae.first_stage_model scale = 0.9152 shift = 0 @@ -197,10 +183,6 @@ def make_divisible(val, divisor): args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size} if enable_tiling: - issue130_trace( - "SeedVR2InputProcessing.before_tiled_encode " - f"images_shape={tuple(images_bthwc.shape)} args={args}" - ) vae_model.img_dims = [o_h, o_w] vae_model.original_image_video = images_bcthw latent = vae.encode_tiled( @@ -211,20 +193,11 @@ def make_divisible(val, divisor): tile_t=temporal_tile_size, ) else: - issue130_trace( - "SeedVR2InputProcessing.before_direct_encode " - f"images_shape={tuple(images_bthwc.shape)} orig_dims={[o_h, o_w]}" - ) vae_model.img_dims = [o_h, o_w] vae_model.original_image_video = images_bcthw latent = vae.encode(images_bthwc) - issue130_trace( - "SeedVR2InputProcessing.after_encode " - f"latent_shape={tuple(latent.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) clear_vae_memory(vae_model) - issue130_trace("SeedVR2InputProcessing.after_clear_vae_memory") #images = images.to(offload_device) #vae_model = vae_model.to(offload_device) @@ -238,10 +211,6 @@ def make_divisible(val, divisor): latent = (latent - shift) * scale - issue130_trace( - "SeedVR2InputProcessing.end " - f"output_shape={tuple(latent.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) return io.NodeOutput({"samples": latent}) class SeedVR2Conditioning(io.ComfyNode): @@ -263,12 +232,7 @@ def define_schema(cls): @classmethod def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput: - start_time = time.monotonic() vae_conditioning = vae_conditioning["samples"] - issue130_trace( - "SeedVR2Conditioning.start " - f"vae_conditioning_shape={tuple(vae_conditioning.shape)} latent_noise_scale={latent_noise_scale}" - ) device = vae_conditioning.device model = model.model.diffusion_model pos_cond = model.positive_conditioning @@ -307,10 +271,6 @@ def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput: negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - issue130_trace( - "SeedVR2Conditioning.end " - f"latent_shape={tuple(noises.shape)} elapsed={time.monotonic() - start_time:.3f}" - ) return io.NodeOutput(positive, negative, {"samples": noises}) class SeedVRExtension(ComfyExtension): diff --git a/nodes.py b/nodes.py index 20ec946e151d..9059ed639d92 100644 --- a/nodes.py +++ b/nodes.py @@ -313,29 +313,9 @@ def decode(self, vae, samples): if latent.is_nested: latent = latent.unbind()[0] - trace_enabled = os.environ.get("ISSUE130_SEEDVR2_TRACE") == "1" - start_time = time.monotonic() - if trace_enabled: - print( - f"ISSUE130_TRACE {start_time:.6f} nodes VAEDecode.start " - f"latent_shape={tuple(latent.shape)}", - flush=True, - ) images = vae.decode(latent) - if trace_enabled: - print( - f"ISSUE130_TRACE {time.monotonic():.6f} nodes VAEDecode.after_decode " - f"images_shape={tuple(images.shape)} elapsed={time.monotonic() - start_time:.3f}", - flush=True, - ) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - if trace_enabled: - print( - f"ISSUE130_TRACE {time.monotonic():.6f} nodes VAEDecode.end " - f"images_shape={tuple(images.shape)} elapsed={time.monotonic() - start_time:.3f}", - flush=True, - ) return (images, ) class VAEDecodeTiled: From 8ab6eb074f24898dd0530fac71e62f562481040f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 24 Apr 2026 14:18:28 -0500 Subject: [PATCH 03/11] Fix SeedVR2 Ruff unused import --- comfy_extras/nodes_seedvr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 0125cac60f52..ecde9a69e415 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -11,7 +11,6 @@ from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode -from comfy.ldm.seedvr.vae import tiled_vae def clear_vae_memory(vae_model): From 302060442726283dcf192841b958bfcf33ea2045 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 14:28:08 -0500 Subject: [PATCH 04/11] Add SeedVR2 apply_rotary_emb -> apply_rope1 delegation regression test (issue #120) --- .../comfy_test/test_seedvr_rope_delegation.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests-unit/comfy_test/test_seedvr_rope_delegation.py diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py new file mode 100644 index 000000000000..97b0ee9cde5f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -0,0 +1,64 @@ +import inspect +import json + +import pytest +import torch + +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.seedvr.model import apply_rotary_emb + + +@pytest.fixture(scope="session", autouse=True) +def _print_apply_rotary_emb_params(pytestconfig): + names = list(inspect.signature(apply_rotary_emb).parameters) + line = f"params: {json.dumps(names)}" + tr = pytestconfig.pluginmanager.get_plugin("terminalreporter") + if tr is not None: + tr.write_line(line) + else: + print(line) + yield + + +_TOL = { + torch.float32: 1e-6, + torch.float16: 1e-3, + torch.bfloat16: 1e-2, +} + + +_CASES = [ + pytest.param("cpu", torch.float32, (1, 8, 16), id="cpu-float32-1x8x16"), + pytest.param("cpu", torch.float16, (1, 8, 16), id="cpu-float16-1x8x16"), + pytest.param("cpu", torch.bfloat16, (1, 8, 16), id="cpu-bfloat16-1x8x16"), + pytest.param("cpu", torch.float32, (2, 16, 32), id="cpu-float32-2x16x32"), + pytest.param( + "cuda", + torch.float16, + (1, 8, 16), + id="cuda-float16-1x8x16", + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda"), + ), +] + + +@pytest.mark.parametrize("device,dtype,shape", _CASES) +def test_apply_rotary_emb_delegates_to_apply_rope1(device, dtype, shape): + torch.manual_seed(0) + t = torch.randn(*shape, dtype=dtype, device=device) + freqs = torch.randn(shape[-2], shape[-1], dtype=dtype, device=device) + + wrapper_out = apply_rotary_emb(freqs, t) + + rot_feats = freqs.shape[-1] + t_middle = t[..., 0:rot_feats] + angles = freqs.to(t_middle.device)[..., ::2] + cos = torch.cos(angles) * 1.0 + sin = torch.sin(angles) * 1.0 + col0 = torch.stack([cos, sin], dim=-1) + col1 = torch.stack([-sin, cos], dim=-1) + freqs_mat = torch.stack([col0, col1], dim=-1) + direct_out = apply_rope1(t_middle, freqs_mat) + + tol = _TOL[dtype] + assert torch.allclose(wrapper_out, direct_out, atol=tol) From c51764a148b479737a7d39acece5765a9e651860 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 15:39:17 -0500 Subject: [PATCH 05/11] Address Copilot review on PR #21 (issue #120): assert_close rtol=0/atol=0, local Generator, non-default start_index/scale coverage, drop autouse params print --- .../comfy_test/test_seedvr_rope_delegation.py | 125 +++++++++++------- 1 file changed, 80 insertions(+), 45 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 97b0ee9cde5f..31fd500a5982 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -1,64 +1,99 @@ -import inspect -import json +"""Regression test: comfy.ldm.seedvr.model.apply_rotary_emb must delegate to +comfy.ldm.flux.math.apply_rope1 with byte-exact equality across the wrapper's +slicing, scaling, and concatenation logic. Drift between the wrapper and the +delegate would silently corrupt SeedVR2's RoPE; this test fails loudly on any +future drift. + +Imports are taken at module level. Heavy-import stubbing of +``comfy.model_management`` was attempted but is insufficient on the live import +chain (``comfy.ldm.seedvr.model`` pulls +``comfy.ldm.modules.diffusionmodules.model -> comfy.ops -> +comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor -> +torch._dynamo``), so every layer would have to be stubbed in lock-step; +running the test against the real modules instead is the fail-loud-from-real- +state approach this repo's tests follow. + +The test uses a local ``torch.Generator`` so global RNG state is not mutated +(Copilot review on PR #21, finding #4) and ``torch.testing.assert_close`` with +``rtol=0, atol=0`` so any future kernel-precision drift is caught (PR #21, +finding #2; live ``max_abs_delta`` on ``issue_101`` HEAD is 0.0 across every +case). Parametrization covers non-default ``start_index`` and ``scale`` so the +wrapper's slicing/concatenation and scale-propagation logic are exercised, not +just the trivial ``rot_feats == t.shape[-1]`` happy path (PR #21, finding #3). +The previous version's session-scoped ``params: [...]`` print fixture was +removed (PR #21, finding #1). +""" import pytest import torch +import torch.testing from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.seedvr.model import apply_rotary_emb -@pytest.fixture(scope="session", autouse=True) -def _print_apply_rotary_emb_params(pytestconfig): - names = list(inspect.signature(apply_rotary_emb).parameters) - line = f"params: {json.dumps(names)}" - tr = pytestconfig.pluginmanager.get_plugin("terminalreporter") - if tr is not None: - tr.write_line(line) - else: - print(line) - yield - - -_TOL = { - torch.float32: 1e-6, - torch.float16: 1e-3, - torch.bfloat16: 1e-2, -} +def _direct_reproduction(freqs, t, start_index=0, scale=1.0): + """Byte-for-byte reproduction of comfy/ldm/seedvr/model.py:471-505 + apply_rotary_emb body, calling apply_rope1 directly on the middle slice. + """ + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + angles = freqs.to(t_middle.device)[..., ::2] + cos = torch.cos(angles) * scale + sin = torch.sin(angles) * scale + col0 = torch.stack([cos, sin], dim=-1) + col1 = torch.stack([-sin, cos], dim=-1) + freqs_mat = torch.stack([col0, col1], dim=-1) + t_middle_out = apply_rope1(t_middle, freqs_mat) + return torch.cat((t_left, t_middle_out, t_right), dim=-1).type(t.dtype) +# (device, dtype, t_shape, freqs_shape, start_index, scale) _CASES = [ - pytest.param("cpu", torch.float32, (1, 8, 16), id="cpu-float32-1x8x16"), - pytest.param("cpu", torch.float16, (1, 8, 16), id="cpu-float16-1x8x16"), - pytest.param("cpu", torch.bfloat16, (1, 8, 16), id="cpu-bfloat16-1x8x16"), - pytest.param("cpu", torch.float32, (2, 16, 32), id="cpu-float32-2x16x32"), + pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-float32-base"), + pytest.param("cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-float16-base"), + pytest.param("cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-bfloat16-base"), + pytest.param("cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0, + id="cpu-float32-larger"), + pytest.param("cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0, + id="cpu-float32-non-empty-left-and-right-slices"), + pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5, + id="cpu-float32-non-default-scale"), pytest.param( - "cuda", - torch.float16, - (1, 8, 16), - id="cuda-float16-1x8x16", + "cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, + id="cuda-float16-base", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda"), ), ] -@pytest.mark.parametrize("device,dtype,shape", _CASES) -def test_apply_rotary_emb_delegates_to_apply_rope1(device, dtype, shape): - torch.manual_seed(0) - t = torch.randn(*shape, dtype=dtype, device=device) - freqs = torch.randn(shape[-2], shape[-1], dtype=dtype, device=device) - - wrapper_out = apply_rotary_emb(freqs, t) +@pytest.mark.parametrize("device,dtype,t_shape,freqs_shape,start_index,scale", _CASES) +def test_apply_rotary_emb_delegates_to_apply_rope1( + device, dtype, t_shape, freqs_shape, start_index, scale +): + generator = torch.Generator(device=device).manual_seed(0) + t = torch.randn(*t_shape, dtype=dtype, device=device, generator=generator) + freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=generator) - rot_feats = freqs.shape[-1] - t_middle = t[..., 0:rot_feats] - angles = freqs.to(t_middle.device)[..., ::2] - cos = torch.cos(angles) * 1.0 - sin = torch.sin(angles) * 1.0 - col0 = torch.stack([cos, sin], dim=-1) - col1 = torch.stack([-sin, cos], dim=-1) - freqs_mat = torch.stack([col0, col1], dim=-1) - direct_out = apply_rope1(t_middle, freqs_mat) + wrapper_out = apply_rotary_emb(freqs, t, start_index=start_index, scale=scale) + direct_out = _direct_reproduction( + freqs, t, start_index=start_index, scale=scale + ) - tol = _TOL[dtype] - assert torch.allclose(wrapper_out, direct_out, atol=tol) + torch.testing.assert_close( + wrapper_out, + direct_out, + rtol=0, + atol=0, + msg=lambda m: ( + f"apply_rotary_emb does not byte-match direct apply_rope1 reproduction " + f"(device={device}, dtype={dtype}, t_shape={t_shape}, " + f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale}): {m}" + ), + ) From 5be4c76db7e35a1452a2d27a513691c8c7dcf8ae Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 19:27:58 -0500 Subject: [PATCH 06/11] =?UTF-8?q?address=20codex=20P2=20+=20Copilot=203150?= =?UTF-8?q?100205/3150100217/3150100225/3150100175/3150100250=20=E2=80=94?= =?UTF-8?q?=20apply=5Frope1=20spy,=20freqs-longer-than-seq=20case,=20exact?= =?UTF-8?q?-equality=20wording,=20eager=20msg=20string?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../comfy_test/test_seedvr_rope_delegation.py | 97 +++++++++++++------ 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 31fd500a5982..1f87d13f4f3c 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -1,41 +1,63 @@ -"""Regression test: comfy.ldm.seedvr.model.apply_rotary_emb must delegate to -comfy.ldm.flux.math.apply_rope1 with byte-exact equality across the wrapper's -slicing, scaling, and concatenation logic. Drift between the wrapper and the -delegate would silently corrupt SeedVR2's RoPE; this test fails loudly on any -future drift. +"""Regression test: ``comfy.ldm.seedvr.model.apply_rotary_emb`` must delegate +to ``comfy.ldm.flux.math.apply_rope1`` and produce exact-equality output +across the wrapper's slicing, scaling, and concatenation logic. Drift between +the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test +fails loudly on any future drift. + +Each parametrized case both: + +1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy + and asserts ``spy.call_count >= 1`` so a future change that inlines the + math and stops calling ``apply_rope1`` fails the test (Copilot review on + PR #21 comment 3150100205; codex P2). +2. Compares the wrapper's output against a hand-rolled reproduction using + ``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality, + not bit-equality (``+0.0`` vs ``-0.0`` and NaN payloads can still match); + the assertion catches any future kernel-precision drift in the + ``apply_rope1`` dispatch (Copilot review on PR #21 comments 3149914528 and + 3150100175). + +The test uses a local ``torch.Generator`` so global RNG state is not mutated +(Copilot review on PR #21 comment 3149914599). Parametrization covers +non-default ``start_index`` and ``scale`` and a case where +``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's +``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised +(Copilot review on PR #21 comments 3149914553, 3150100217, 3150100225). Imports are taken at module level. Heavy-import stubbing of -``comfy.model_management`` was attempted but is insufficient on the live import -chain (``comfy.ldm.seedvr.model`` pulls +``comfy.model_management`` was attempted but is insufficient on the live +import chain (``comfy.ldm.seedvr.model`` pulls ``comfy.ldm.modules.diffusionmodules.model -> comfy.ops -> comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor -> -torch._dynamo``), so every layer would have to be stubbed in lock-step; -running the test against the real modules instead is the fail-loud-from-real- -state approach this repo's tests follow. - -The test uses a local ``torch.Generator`` so global RNG state is not mutated -(Copilot review on PR #21, finding #4) and ``torch.testing.assert_close`` with -``rtol=0, atol=0`` so any future kernel-precision drift is caught (PR #21, -finding #2; live ``max_abs_delta`` on ``issue_101`` HEAD is 0.0 across every -case). Parametrization covers non-default ``start_index`` and ``scale`` so the -wrapper's slicing/concatenation and scale-propagation logic are exercised, not -just the trivial ``rot_feats == t.shape[-1]`` happy path (PR #21, finding #3). -The previous version's session-scoped ``params: [...]`` print fixture was -removed (PR #21, finding #1). +torch._dynamo``), so every layer would have to be stubbed in lock-step. +Running the test against the real modules is the fail-loud-from-real-state +approach this repo's tests follow. """ +from unittest.mock import patch + import pytest import torch import torch.testing +import comfy.ldm.seedvr.model as seedvr_model from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.seedvr.model import apply_rotary_emb -def _direct_reproduction(freqs, t, start_index=0, scale=1.0): - """Byte-for-byte reproduction of comfy/ldm/seedvr/model.py:471-505 - apply_rotary_emb body, calling apply_rope1 directly on the middle slice. +def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + """Reproduce the body of ``apply_rotary_emb`` for the default case where + ``freqs.ndim == 2`` and ``t.ndim == 3`` (implicit ``freqs_seq_dim=0``). + Mirrors the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` + step when freqs is longer than ``t`` along ``seq_dim``. Calls the real + ``apply_rope1`` via the test module's import (the test patches the + ``seedvr_model.apply_rope1`` attribute; this call uses the unpatched + ``flux.math`` symbol). """ + if freqs.ndim == 2 and t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + rot_feats = freqs.shape[-1] end_index = start_index + rot_feats t_left = t[..., :start_index] @@ -65,6 +87,8 @@ def _direct_reproduction(freqs, t, start_index=0, scale=1.0): id="cpu-float32-non-empty-left-and-right-slices"), pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5, id="cpu-float32-non-default-scale"), + pytest.param("cpu", torch.float32, (1, 8, 16), (12, 16), 0, 1.0, + id="cpu-float32-freqs-longer-than-seq"), pytest.param( "cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, id="cuda-float16-base", @@ -81,19 +105,34 @@ def test_apply_rotary_emb_delegates_to_apply_rope1( t = torch.randn(*t_shape, dtype=dtype, device=device, generator=generator) freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=generator) - wrapper_out = apply_rotary_emb(freqs, t, start_index=start_index, scale=scale) + # Patch the apply_rope1 symbol as imported into seedvr.model with a wraps + # spy: a future change that inlines the math and stops calling the + # imported apply_rope1 makes spy.call_count == 0 and fails the test. + with patch.object( + seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1 + ) as spy: + wrapper_out = apply_rotary_emb( + freqs, t, start_index=start_index, scale=scale + ) + + assert spy.call_count >= 1, ( + "apply_rotary_emb did not call comfy.ldm.seedvr.model.apply_rope1; " + "the delegation invariant is broken" + ) + direct_out = _direct_reproduction( freqs, t, start_index=start_index, scale=scale ) + msg = ( + f"apply_rotary_emb output does not match direct apply_rope1 " + f"reproduction (device={device}, dtype={dtype}, t_shape={t_shape}, " + f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale})" + ) torch.testing.assert_close( wrapper_out, direct_out, rtol=0, atol=0, - msg=lambda m: ( - f"apply_rotary_emb does not byte-match direct apply_rope1 reproduction " - f"(device={device}, dtype={dtype}, t_shape={t_shape}, " - f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale}): {m}" - ), + msg=msg, ) From ef0d3166290293147a8b6f35b7ea912f6a73477f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 19:43:56 -0500 Subject: [PATCH 07/11] =?UTF-8?q?address=20Copilot=203150913301/3150913316?= =?UTF-8?q?=20=E2=80=94=20drop=20redundant=20'import=20torch.testing',=20r?= =?UTF-8?q?eplace=20per-comment-id=20docstring=20refs=20with=20stable=20is?= =?UTF-8?q?sue=20link?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../comfy_test/test_seedvr_rope_delegation.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 1f87d13f4f3c..9933e299f08f 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -8,37 +8,33 @@ 1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy and asserts ``spy.call_count >= 1`` so a future change that inlines the - math and stops calling ``apply_rope1`` fails the test (Copilot review on - PR #21 comment 3150100205; codex P2). + math and stops calling ``apply_rope1`` fails the test. 2. Compares the wrapper's output against a hand-rolled reproduction using ``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality, not bit-equality (``+0.0`` vs ``-0.0`` and NaN payloads can still match); the assertion catches any future kernel-precision drift in the - ``apply_rope1`` dispatch (Copilot review on PR #21 comments 3149914528 and - 3150100175). + ``apply_rope1`` dispatch. -The test uses a local ``torch.Generator`` so global RNG state is not mutated -(Copilot review on PR #21 comment 3149914599). Parametrization covers -non-default ``start_index`` and ``scale`` and a case where -``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's -``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised -(Copilot review on PR #21 comments 3149914553, 3150100217, 3150100225). - -Imports are taken at module level. Heavy-import stubbing of +The test uses a local ``torch.Generator`` so global RNG state is not mutated. +Parametrization covers non-default ``start_index`` and ``scale`` and a case +where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's +``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised. +Imports are taken at module level; heavy-import stubbing of ``comfy.model_management`` was attempted but is insufficient on the live import chain (``comfy.ldm.seedvr.model`` pulls ``comfy.ldm.modules.diffusionmodules.model -> comfy.ops -> comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor -> -torch._dynamo``), so every layer would have to be stubbed in lock-step. -Running the test against the real modules is the fail-loud-from-real-state -approach this repo's tests follow. +torch._dynamo``), so running the test against the real modules is the +fail-loud-from-real-state approach this repo's tests follow. + +Test design rationale and per-decision review trail are recorded on the +tracking issue: https://github.com/pollockjj/mydevelopment/issues/120 """ from unittest.mock import patch import pytest import torch -import torch.testing import comfy.ldm.seedvr.model as seedvr_model from comfy.ldm.flux.math import apply_rope1 From 04ab2a8dc1dc25a60dd0a51b9ab818302263d7a2 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 19:59:03 -0500 Subject: [PATCH 08/11] =?UTF-8?q?address=20Copilot=203150961536=20?= =?UTF-8?q?=E2=80=94=20flip=20comfy.cli=5Fargs.args.cpu=3DTrue=20before=20?= =?UTF-8?q?importing=20comfy.ldm.*=20on=20CPU-only=20hosts=20so=20import-t?= =?UTF-8?q?ime=20get=5Ftorch=5Fdevice()=20does=20not=20probe=20a=20missing?= =?UTF-8?q?=20CUDA=20device=20(matches=20tests-unit/comfy=5Fquant/test=5Fm?= =?UTF-8?q?ixed=5Fprecision.py=20pattern)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../comfy_test/test_seedvr_rope_delegation.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 9933e299f08f..5a1c45058206 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -36,9 +36,21 @@ import pytest import torch -import comfy.ldm.seedvr.model as seedvr_model -from comfy.ldm.flux.math import apply_rope1 -from comfy.ldm.seedvr.model import apply_rotary_emb +# CPU-only CI fix: ``comfy.ldm.seedvr.model`` transitively imports +# ``comfy.model_management``, whose import-time ``get_torch_device()`` call +# probes ``torch.cuda.current_device()`` unless ``comfy.cli_args.args.cpu`` is +# set. On a CPU-only build that probe can raise during test collection before +# the ``cuda`` case has had a chance to be skipped. Match the pattern used by +# ``tests-unit/comfy_quant/test_mixed_precision.py``: flip ``args.cpu`` before +# importing any ``comfy.ldm.*`` symbol. +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +from comfy.ldm.flux.math import apply_rope1 # noqa: E402 +from comfy.ldm.seedvr.model import apply_rotary_emb # noqa: E402 def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2): From acfeb68fbbcb5a5504f7ef7bf0590c731dfac144 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 20:11:02 -0500 Subject: [PATCH 09/11] =?UTF-8?q?address=20Copilot=203151003974=20?= =?UTF-8?q?=E2=80=94=20narrow=20docstring=20claim=20from=20repo-wide=20con?= =?UTF-8?q?vention=20to=20test-local=20choice;=20cite=20image=5Fstitch=5Ft?= =?UTF-8?q?est.py=20as=20a=20counter-example=20of=20stub=20usage=20in=20th?= =?UTF-8?q?is=20repo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests-unit/comfy_test/test_seedvr_rope_delegation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 5a1c45058206..65484164413d 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -19,13 +19,16 @@ Parametrization covers non-default ``start_index`` and ``scale`` and a case where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised. -Imports are taken at module level; heavy-import stubbing of -``comfy.model_management`` was attempted but is insufficient on the live +Imports are taken at module level. Heavy-import stubbing of +``comfy.model_management`` was attempted but is insufficient on this live import chain (``comfy.ldm.seedvr.model`` pulls ``comfy.ldm.modules.diffusionmodules.model -> comfy.ops -> comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor -> -torch._dynamo``), so running the test against the real modules is the -fail-loud-from-real-state approach this repo's tests follow. +torch._dynamo``), so this test intentionally runs against the real modules +to fail loudly if that import path or runtime state drifts. Other tests in +this repo (e.g. ``tests-unit/comfy_extras_test/image_stitch_test.py``) do +stub via ``patch.dict(sys.modules, ...)`` for narrower targets; the choice +here is local to this regression and not a repo-wide convention. Test design rationale and per-decision review trail are recorded on the tracking issue: https://github.com/pollockjj/mydevelopment/issues/120 From f80a8727fde96ce8ed80fc959b7afb63271dcf47 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 20:18:38 -0500 Subject: [PATCH 10/11] =?UTF-8?q?address=20Copilot=203151024280=20?= =?UTF-8?q?=E2=80=94=20guard=20cpu-float16/cpu-bfloat16=20cases=20with=20p?= =?UTF-8?q?ytest.mark.skipif(not=20=5Fcpu=5Ftrig=5Fsupported(dtype))=20so?= =?UTF-8?q?=20PyTorch=20CPU=20wheels=20that=20don't=20implement=20torch.co?= =?UTF-8?q?s/sin=20for=20those=20dtypes=20skip=20cleanly=20instead=20of=20?= =?UTF-8?q?failing=20CI;=20cases=20preserved=20per=20plan=20#120=20Slice?= =?UTF-8?q?=202=20AC-2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../comfy_test/test_seedvr_rope_delegation.py | 38 +++++++++++++++++-- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index 65484164413d..f2dfafca98af 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -84,14 +84,44 @@ def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2): return torch.cat((t_left, t_middle_out, t_right), dim=-1).type(t.dtype) +def _cpu_trig_supported(dtype): + """Return whether ``torch.cos`` (and by symmetry ``torch.sin``) is + implemented for the given dtype on CPU on the current runtime. Some + PyTorch CPU wheels don't implement trig ops for ``float16`` / ``bfloat16`` + and raise at runtime; the parametrized cases for those dtypes are skipped + when that's the case so CI remains stable across PyTorch builds. + """ + try: + torch.cos(torch.zeros(1, dtype=dtype)) + except (RuntimeError, TypeError): + return False + return True + + +_CPU_FP16_TRIG_OK = _cpu_trig_supported(torch.float16) +_CPU_BF16_TRIG_OK = _cpu_trig_supported(torch.bfloat16) + + # (device, dtype, t_shape, freqs_shape, start_index, scale) _CASES = [ pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0, id="cpu-float32-base"), - pytest.param("cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, - id="cpu-float16-base"), - pytest.param("cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0, - id="cpu-bfloat16-base"), + pytest.param( + "cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-float16-base", + marks=pytest.mark.skipif( + not _CPU_FP16_TRIG_OK, + reason="torch.cos/torch.sin unsupported for float16 tensors on CPU", + ), + ), + pytest.param( + "cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-bfloat16-base", + marks=pytest.mark.skipif( + not _CPU_BF16_TRIG_OK, + reason="torch.cos/torch.sin unsupported for bfloat16 tensors on CPU", + ), + ), pytest.param("cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0, id="cpu-float32-larger"), pytest.param("cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0, From 960d0ce5111b3f944b16acd5d17d4ca1fa5a0eb6 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 27 Apr 2026 20:32:01 -0500 Subject: [PATCH 11/11] =?UTF-8?q?address=20Copilot=203151043880=20?= =?UTF-8?q?=E2=80=94=20docstring=20grammar:=20'each=20parametrized=20case?= =?UTF-8?q?=20both:'=20->=20'each=20parametrized=20case=20does=20both:'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests-unit/comfy_test/test_seedvr_rope_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py index f2dfafca98af..89ad5fa2b59c 100644 --- a/tests-unit/comfy_test/test_seedvr_rope_delegation.py +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -4,7 +4,7 @@ the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test fails loudly on any future drift. -Each parametrized case both: +Each parametrized case does both: 1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy and asserts ``spy.call_count >= 1`` so a future change that inlines the