diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 6a7d2eaeab3b..9f7a77338452 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -467,8 +467,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
             dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
             h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
             h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
-            h_floor_grid_idx = h_floor_grid * num_grid_per_side
-            h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

             # original computation of weights
             # w00 = (1 - dh_grid) * (1 - dw_grid)
@@ -480,30 +478,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
             w11 = dh_grid * dw_grid
             w10 = dh_grid - w11
             w01 = dw_grid - w11
-            w00 = 1 - dh_grid - dw_grid + w11
+            w00 = 1 - dh_grid - w01

-            idx00 = h_floor_grid_idx + w_floor_grid
-            idx01 = h_floor_grid_idx + w_ceil_grid
-            idx10 = h_ceil_grid_idx + w_floor_grid
-            idx11 = h_ceil_grid_idx + w_ceil_grid
+            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
+            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
+            h_grid_idx = h_grid * num_grid_per_side

-            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
+            indices = (h_grid_idx + w_grid).reshape(4, -1)
             weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-            weights = weights.to(
-                dtype=self.dtype, device=self.device, non_blocking=True
-            )
+            weights = weights.to(dtype=self.dtype)

             embeds = self.pos_embed(indices)
             weighted_embeds = embeds * weights
-            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
-            combined = p0 + p1 + p2 + p3
+            combined = weighted_embeds.sum(dim=0)

-            combined = combined.view(h * w, hidden_dim)
-            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
-            repeated = repeated.view(
-                t, h // m_size, m_size, w // m_size, m_size, hidden_dim
+            combined = combined.reshape(
+                h // m_size, m_size, w // m_size, m_size, hidden_dim
             )
-            repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
+            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
+            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
             outputs.append(repeated)

         return torch.cat(outputs, dim=0)
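
For reference, below is a minimal standalone sketch (not part of the patch) that checks the equivalences the diff relies on: the refactored corner weights (`w00 = 1 - dh_grid - w01`, etc.) reproduce the original bilinear products, the stacked `h_grid`/`w_grid` computation yields the same flat indices as the four separate `idx**` tensors, and the permute-then-broadcast tail matches the old repeat-then-permute path. The grid size and the `t`/`m_size`/`hidden_dim` values are made up for illustration; variable names follow the diff.

```python
import torch

torch.manual_seed(0)
num_grid_per_side = 16
h, w = 8, 6

# Fractional sample coordinates into the position-embedding grid, mirroring
# the h_floor/h_ceil/dh setup in fast_pos_embed_interpolate.
h_idx = torch.rand(h) * (num_grid_per_side - 1)
w_idx = torch.rand(w) * (num_grid_per_side - 1)
h_floor, w_floor = h_idx.floor().long(), w_idx.floor().long()
h_ceil = (h_floor + 1).clamp(max=num_grid_per_side - 1)
w_ceil = (w_floor + 1).clamp(max=num_grid_per_side - 1)
dh, dw = h_idx - h_floor, w_idx - w_floor

dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

# Original bilinear weights: one product per corner.
w00_ref = (1 - dh_grid) * (1 - dw_grid)
w01_ref = (1 - dh_grid) * dw_grid
w10_ref = dh_grid * (1 - dw_grid)
w11_ref = dh_grid * dw_grid

# Refactored weights from the diff: reuse w11 and w01 to cut multiplies.
w11 = dh_grid * dw_grid
w10 = dh_grid - w11      # == dh * (1 - dw)
w01 = dw_grid - w11      # == (1 - dh) * dw
w00 = 1 - dh_grid - w01  # == (1 - dh) * (1 - dw)
for got, ref in [(w00, w00_ref), (w01, w01_ref), (w10, w10_ref), (w11, w11_ref)]:
    torch.testing.assert_close(got, ref)

# Original indices: four separate flat-index maps, stacked afterwards.
h_floor_idx = h_floor_grid * num_grid_per_side
h_ceil_idx = h_ceil_grid * num_grid_per_side
idx_ref = torch.stack(
    [
        h_floor_idx + w_floor_grid,
        h_floor_idx + w_ceil_grid,
        h_ceil_idx + w_floor_grid,
        h_ceil_idx + w_ceil_grid,
    ]
).reshape(4, -1)

# Refactored indices: stack the corner grids first, then one multiply-add.
h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
indices = (h_grid * num_grid_per_side + w_grid).reshape(4, -1)
torch.testing.assert_close(indices, idx_ref)

# Tail restructure: the old path materializes the t-repeated tensor and
# permutes six dims; the new path permutes one frame and broadcasts over t.
t, m_size, hidden_dim = 3, 2, 4  # h and w must be divisible by m_size
combined = torch.randn(h * w, hidden_dim)
old = (
    combined.unsqueeze(0)
    .expand(t, -1, -1)
    .contiguous()
    .view(t, h // m_size, m_size, w // m_size, m_size, hidden_dim)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, hidden_dim)
)
new = (
    combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
    .permute(0, 2, 1, 3, 4)
    .reshape(1, -1, hidden_dim)
    .expand(t, -1, -1)
    .reshape(-1, hidden_dim)
)
torch.testing.assert_close(new, old)
print("all equivalences hold")
```

Beyond the arithmetic savings on the weights, stacking the corner grids lets all four index maps come out of a single multiply-add instead of two multiplies and four adds, and expanding after the permute avoids the `.contiguous()` copy of the t-repeated tensor that the old path needed before its six-dimensional view.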