From aa6bd8ddba4ae96bc7720da3c48b6f6589970fcb Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sun, 21 Sep 2025 06:03:15 +0000
Subject: [PATCH 1/4] add

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 147 +++++++++++++------------
 1 file changed, 75 insertions(+), 72 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 17375ff0959d..bba1281deddf 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -276,6 +276,10 @@ def __init__(
         self.out_hidden_size = (vision_config.out_hidden_size *
                                 (1 + len(self.deepstack_visual_indexes)))
 
+        # Cache per-resolution positional embeddings to avoid recomputing when
+        # the same spatial size appears across frames/batches.
+        self._cached_hw_pos_embeds: dict[tuple[int, int], torch.Tensor] = {}
+
         self.patch_embed = Qwen3_VisionPatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -378,81 +382,80 @@ def rot_pos_emb(self, grid_thw):
         return rotary_pos_emb
 
     def fast_pos_embed_interpolate(self, grid_thw):
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
-
-        idx_list = [[] for _ in range(4)]
-        weight_list = [[] for _ in range(4)]
-
-        for t, h, w in grid_thw:
-            h_idxs = torch.linspace(0,
-                                    num_grid_per_side - 1,
-                                    h,
-                                    dtype=torch.float32)
-            w_idxs = torch.linspace(0,
-                                    num_grid_per_side - 1,
-                                    w,
-                                    dtype=torch.float32)
-
-            h_idxs_floor = h_idxs.to(torch.long)
-            w_idxs_floor = w_idxs.to(torch.long)
-            h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-            w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_idxs_floor
-            dw = w_idxs - w_idxs_floor
-
-            idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-            idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-
-            weight_list[0].extend(
-                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[1].extend(
-                ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
-            weight_list[2].extend(
-                (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[3].extend(
-                (dh[None].T * dw[None]).flatten().tolist() * t)
+        if isinstance(grid_thw, torch.Tensor):
+            grid_list = [
+                tuple(int(v) for v in grid) for grid in grid_thw.tolist()
+            ]
+        else:
+            grid_list = [tuple(int(v) for v in grid) for grid in grid_thw]
 
+        num_grid_per_side = int(self.num_position_embeddings**0.5)
         device = self.pos_embed.weight.device
         dtype = self.pos_embed.weight.dtype
-
-        p0 = self.pos_embed(
-            torch.tensor(
-                idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[0], dtype=dtype, device=device)[:, None]
-        p1 = self.pos_embed(
-            torch.tensor(
-                idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[1], dtype=dtype, device=device)[:, None]
-        p2 = self.pos_embed(
-            torch.tensor(
-                idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[2], dtype=dtype, device=device)[:, None]
-        p3 = self.pos_embed(
-            torch.tensor(
-                idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[3], dtype=dtype, device=device)[:, None]
-
-        patch_pos_embeds = p0 + p1 + p2 + p3
-        patch_pos_embeds = patch_pos_embeds.split(
-            [t * h * w for t, h, w in grid_thw])
-        patch_pos_embeds_permute = []
         m_size = self.spatial_merge_size
-        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
-            pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
-                                       m_size, -1).permute(0, 1, 3, 2, 4,
-                                                           5).flatten(0, 4)
-            patch_pos_embeds_permute.append(pos_embed)
-        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
-        return patch_pos_embeds
+        hidden_dim = self.pos_embed.embedding_dim
+
+        outputs = []
+        for t, h, w in grid_list:
+            key = (h, w)
+            cached = self._cached_hw_pos_embeds.get(key)
+            if (cached is None or cached.device != device
+                    or cached.dtype != dtype):
+                h_idxs = torch.linspace(0,
+                                        num_grid_per_side - 1,
+                                        h,
+                                        dtype=torch.float32,
+                                        device=device)
+                w_idxs = torch.linspace(0,
+                                        num_grid_per_side - 1,
+                                        w,
+                                        dtype=torch.float32,
+                                        device=device)
+
+                h_floor = h_idxs.to(torch.long)
+                w_floor = w_idxs.to(torch.long)
+                h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+                w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+                dh = h_idxs - h_floor
+                dw = w_idxs - w_floor
+
+                w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+                w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+                w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+                w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
+                idx00 = (h_floor[:, None] * num_grid_per_side +
+                         w_floor[None, :]).reshape(-1)
+                idx01 = (h_floor[:, None] * num_grid_per_side +
+                         w_ceil[None, :]).reshape(-1)
+                idx10 = (h_ceil[:, None] * num_grid_per_side +
+                         w_floor[None, :]).reshape(-1)
+                idx11 = (h_ceil[:, None] * num_grid_per_side +
+                         w_ceil[None, :]).reshape(-1)
+
+                indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+                weights = torch.stack([w00, w01, w10, w11],
+                                      dim=0).to(dtype=dtype, device=device)
+                weights = weights.unsqueeze(-1)
+
+                embeds = F.embedding(indices, self.pos_embed.weight)
+                weighted_embeds = embeds * weights
+                p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+                combined = ((p0 + p1) + p2) + p3
+
+                frame = combined.view(1, h // m_size, m_size, w // m_size,
+                                      m_size, hidden_dim)
+                frame = frame.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
+                cached = frame.contiguous()
+                self._cached_hw_pos_embeds[key] = cached
+
+            if t == 1:
+                outputs.append(cached)
+            else:
+                outputs.append(cached.repeat(t, 1))
+
+        return torch.cat(outputs, dim=0)
 
     def compute_attn_mask_seqlen(
         self,
@@ -1526,4 +1529,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
-    )
\ No newline at end of file
+    )

From 51ae2dcade52a5c665b2b9b1da406484f175b568 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sun, 21 Sep 2025 06:16:17 +0000
Subject: [PATCH 2/4] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 111 +++++++++++--------------
 1 file changed, 50 insertions(+), 61 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index bba1281deddf..1524b96ff121 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -276,10 +276,6 @@ def __init__(
         self.out_hidden_size = (vision_config.out_hidden_size *
                                 (1 + len(self.deepstack_visual_indexes)))
 
-        # Cache per-resolution positional embeddings to avoid recomputing when
-        # the same spatial size appears across frames/batches.
-        self._cached_hw_pos_embeds: dict[tuple[int, int], torch.Tensor] = {}
-
         self.patch_embed = Qwen3_VisionPatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -397,63 +393,56 @@ def fast_pos_embed_interpolate(self, grid_thw):
 
         outputs = []
         for t, h, w in grid_list:
-            key = (h, w)
-            cached = self._cached_hw_pos_embeds.get(key)
-            if (cached is None or cached.device != device
-                    or cached.dtype != dtype):
-                h_idxs = torch.linspace(0,
-                                        num_grid_per_side - 1,
-                                        h,
-                                        dtype=torch.float32,
-                                        device=device)
-                w_idxs = torch.linspace(0,
-                                        num_grid_per_side - 1,
-                                        w,
-                                        dtype=torch.float32,
-                                        device=device)
-
-                h_floor = h_idxs.to(torch.long)
-                w_floor = w_idxs.to(torch.long)
-                h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
-                w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
-
-                dh = h_idxs - h_floor
-                dw = w_idxs - w_floor
-
-                w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
-                w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
-                w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
-                w11 = (dh[:, None] * dw[None, :]).reshape(-1)
-
-                idx00 = (h_floor[:, None] * num_grid_per_side +
-                         w_floor[None, :]).reshape(-1)
-                idx01 = (h_floor[:, None] * num_grid_per_side +
-                         w_ceil[None, :]).reshape(-1)
-                idx10 = (h_ceil[:, None] * num_grid_per_side +
-                         w_floor[None, :]).reshape(-1)
-                idx11 = (h_ceil[:, None] * num_grid_per_side +
-                         w_ceil[None, :]).reshape(-1)
-
-                indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
-                weights = torch.stack([w00, w01, w10, w11],
-                                      dim=0).to(dtype=dtype, device=device)
-                weights = weights.unsqueeze(-1)
-
-                embeds = F.embedding(indices, self.pos_embed.weight)
-                weighted_embeds = embeds * weights
-                p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
-                combined = ((p0 + p1) + p2) + p3
-
-                frame = combined.view(1, h // m_size, m_size, w // m_size,
-                                      m_size, hidden_dim)
-                frame = frame.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
-                cached = frame.contiguous()
-                self._cached_hw_pos_embeds[key] = cached
-
-            if t == 1:
-                outputs.append(cached)
-            else:
-                outputs.append(cached.repeat(t, 1))
+            h_idxs = torch.linspace(0,
+                                    num_grid_per_side - 1,
+                                    h,
+                                    dtype=torch.float32,
+                                    device=device)
+            w_idxs = torch.linspace(0,
+                                    num_grid_per_side - 1,
+                                    w,
+                                    dtype=torch.float32,
+                                    device=device)
+
+            h_floor = h_idxs.to(torch.long)
+            w_floor = w_idxs.to(torch.long)
+            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_floor
+            dw = w_idxs - w_floor
+
+            w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            idx00 = (h_floor[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx01 = (h_floor[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+            idx10 = (h_ceil[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx11 = (h_ceil[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+
+            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+            weights = torch.stack([w00, w01, w10, w11],
+                                  dim=0).to(dtype=dtype, device=device)
+            weights = weights.unsqueeze(-1)
+
+            embeds = F.embedding(indices, self.pos_embed.weight)
+            weighted_embeds = embeds * weights
+            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+            combined = ((p0 + p1) + p2) + p3
+
+            combined = combined.view(h * w, hidden_dim)
+            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
+            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
+                                     m_size, hidden_dim)
+            repeated = repeated.permute(0, 1, 3, 2, 4,
+                                        5).reshape(-1, hidden_dim)
+            outputs.append(repeated)
 
         return torch.cat(outputs, dim=0)
 

From cfb879e7337c8a59ca4237a2da4fabe53b86c4e5 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sun, 21 Sep 2025 06:45:00 +0000
Subject: [PATCH 3/4] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 1524b96ff121..745bf5178c77 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
@@ -385,9 +386,7 @@ def fast_pos_embed_interpolate(self, grid_thw):
         else:
             grid_list = [tuple(int(v) for v in grid) for grid in grid_thw]
 
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
-        device = self.pos_embed.weight.device
-        dtype = self.pos_embed.weight.dtype
+        num_grid_per_side = self.num_grid_per_side
         m_size = self.spatial_merge_size
         hidden_dim = self.pos_embed.embedding_dim
 
@@ -397,12 +396,12 @@ def fast_pos_embed_interpolate(self, grid_thw):
                                     num_grid_per_side - 1,
                                     h,
                                     dtype=torch.float32,
-                                    device=device)
+                                    device=self.device)
             w_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     w,
                                     dtype=torch.float32,
-                                    device=device)
+                                    device=self.device)
 
             h_floor = h_idxs.to(torch.long)
             w_floor = w_idxs.to(torch.long)
@@ -428,13 +427,14 @@ def fast_pos_embed_interpolate(self, grid_thw):
 
             indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
             weights = torch.stack([w00, w01, w10, w11],
-                                  dim=0).to(dtype=dtype, device=device)
+                                  dim=0).to(dtype=self.dtype,
+                                            device=self.device)
             weights = weights.unsqueeze(-1)
 
-            embeds = F.embedding(indices, self.pos_embed.weight)
+            embeds = self.pos_embed(indices)
             weighted_embeds = embeds * weights
             p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
-            combined = ((p0 + p1) + p2) + p3
+            combined = p0 + p1 + p2 + p3
 
             combined = combined.view(h * w, hidden_dim)
             repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()

From dde249f36aefbb59561a762add21d0d9d0f31d2c Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sun, 21 Sep 2025 08:39:19 +0000
Subject: [PATCH 4/4] cleanup

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 745bf5178c77..ca232e03767b 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -378,20 +378,15 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def fast_pos_embed_interpolate(self, grid_thw):
-        if isinstance(grid_thw, torch.Tensor):
-            grid_list = [
-                tuple(int(v) for v in grid) for grid in grid_thw.tolist()
-            ]
-        else:
-            grid_list = [tuple(int(v) for v in grid) for grid in grid_thw]
+    def fast_pos_embed_interpolate(self,
+                                   grid_thw: list[list[int]]) -> torch.Tensor:
 
         num_grid_per_side = self.num_grid_per_side
         m_size = self.spatial_merge_size
         hidden_dim = self.pos_embed.embedding_dim
 
         outputs = []
-        for t, h, w in grid_list:
+        for t, h, w in grid_thw:
             h_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     h,
@@ -469,12 +464,9 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
-        if isinstance(grid_thw, list):
-            grid_thw_tensor = torch.tensor(grid_thw,
-                                           device=hidden_states.device,
-                                           dtype=torch.int32)
-        else:
-            grid_thw_tensor = grid_thw
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
@@ -1216,7 +1208,8 @@ def _process_image_input(
                 grid_thw_list,
                 rope_type="rope_3d")
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
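
Note (not part of the patch series): the following standalone sketch illustrates the bilinear interpolation of learned 2D position embeddings onto an arbitrary (h, w) patch grid that these patches vectorize. The nn.Embedding, grid size, and dimensions below are hypothetical stand-ins for self.pos_embed and num_grid_per_side; it omits the spatial-merge permutation and temporal repeat handled in the model code.

import torch
import torch.nn as nn


def interpolate_pos_embed(pos_embed: nn.Embedding, num_grid_per_side: int,
                          h: int, w: int) -> torch.Tensor:
    """Bilinearly resample a square grid of learned embeddings to (h, w)."""
    device = pos_embed.weight.device
    dtype = pos_embed.weight.dtype

    # Fractional source coordinates for each target row/column.
    h_idxs = torch.linspace(0, num_grid_per_side - 1, h, device=device)
    w_idxs = torch.linspace(0, num_grid_per_side - 1, w, device=device)

    h_floor, w_floor = h_idxs.long(), w_idxs.long()
    h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
    w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
    dh, dw = h_idxs - h_floor, w_idxs - w_floor

    # Four corner indices into the flattened grid and their bilinear weights.
    idx = torch.stack([
        (h_floor[:, None] * num_grid_per_side + w_floor[None, :]).reshape(-1),
        (h_floor[:, None] * num_grid_per_side + w_ceil[None, :]).reshape(-1),
        (h_ceil[:, None] * num_grid_per_side + w_floor[None, :]).reshape(-1),
        (h_ceil[:, None] * num_grid_per_side + w_ceil[None, :]).reshape(-1),
    ])
    wts = torch.stack([
        ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1),
        ((1 - dh)[:, None] * dw[None, :]).reshape(-1),
        (dh[:, None] * (1 - dw)[None, :]).reshape(-1),
        (dh[:, None] * dw[None, :]).reshape(-1),
    ]).to(dtype)

    # Weighted sum of the four neighbouring embeddings -> (h * w, hidden_dim).
    return (pos_embed(idx) * wts[..., None]).sum(dim=0)


if __name__ == "__main__":
    embed = nn.Embedding(48 * 48, 64)  # hypothetical 48x48 grid, dim 64
    out = interpolate_pos_embed(embed, 48, h=30, w=52)
    print(out.shape)  # torch.Size([1560, 64])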