Merged
29 changes: 11 additions & 18 deletions vllm/model_executor/models/qwen3_vl.py
@@ -467,8 +467,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
             dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
             h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
             h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
-            h_floor_grid_idx = h_floor_grid * num_grid_per_side
-            h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

             # original computation of weights
             # w00 = (1 - dh_grid) * (1 - dw_grid)
@@ -480,30 +478,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
             w11 = dh_grid * dw_grid
             w10 = dh_grid - w11
             w01 = dw_grid - w11
-            w00 = 1 - dh_grid - dw_grid + w11
+            w00 = 1 - dh_grid - w01

-            idx00 = h_floor_grid_idx + w_floor_grid
-            idx01 = h_floor_grid_idx + w_ceil_grid
-            idx10 = h_ceil_grid_idx + w_floor_grid
-            idx11 = h_ceil_grid_idx + w_ceil_grid
+            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
+            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
+            h_grid_idx = h_grid * num_grid_per_side

-            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
+            indices = (h_grid_idx + w_grid).reshape(4, -1)
             weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-            weights = weights.to(
-                dtype=self.dtype, device=self.device, non_blocking=True
-            )
+            weights = weights.to(dtype=self.dtype)
Comment on lines -492 to +489 (Contributor Author):

    weights will already be on self.device so no need to copy it again.
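For reference on the comment above: tensors created by torch.stack inherit the device of their inputs, and a dtype-only .to(...) does not move data between devices, so the explicit device=/non_blocking= arguments were redundant. A minimal sketch of that invariant, independent of the vLLM code (device choice and shapes are illustrative):

import torch

# Use a GPU when available; the invariant is the same on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for the four bilinear weight planes, built on `device`.
parts = [torch.rand(3, 3, device=device) for _ in range(4)]
weights = torch.stack(parts, dim=0).reshape(4, -1, 1)
assert weights.device.type == device.type  # stack inherits the inputs' device

weights = weights.to(dtype=torch.float16)  # dtype-only cast, no device copy
assert weights.device.type == device.type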
             embeds = self.pos_embed(indices)
             weighted_embeds = embeds * weights
-            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
-            combined = p0 + p1 + p2 + p3
+            combined = weighted_embeds.sum(dim=0)

-            combined = combined.view(h * w, hidden_dim)
-            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
-            repeated = repeated.view(
-                t, h // m_size, m_size, w // m_size, m_size, hidden_dim
+            combined = combined.reshape(
+                h // m_size, m_size, w // m_size, m_size, hidden_dim
             )
-            repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
+            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
+            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
             outputs.append(repeated)

         return torch.cat(outputs, dim=0)
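Taken together, the hunks reuse intermediate products in the bilinear weights, build the four gather-index planes with one batched stack instead of four separate tensors, and fold the four weighted embeddings with a single sum. Below is a standalone sketch checking that the refactored arithmetic matches the original; it is not vLLM code: num_grid_per_side, h, and w are made-up values, and only the expressions visible in the diff are exercised.

import torch

# Hypothetical sizes; the real values come from the vision encoder config.
num_grid_per_side = 32
h, w = 6, 9

# Fractional offsets and surrounding integer grid coordinates, as in the diff.
dh, dw = torch.rand(h), torch.rand(w)
h_floor = torch.randint(0, num_grid_per_side - 1, (h,))
w_floor = torch.randint(0, num_grid_per_side - 1, (w,))
h_ceil, w_ceil = h_floor + 1, w_floor + 1

dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

# Old weights: four full elementwise products.
ref = torch.stack([
    (1 - dh_grid) * (1 - dw_grid),
    (1 - dh_grid) * dw_grid,
    dh_grid * (1 - dw_grid),
    dh_grid * dw_grid,
])

# New weights: compute one product and derive the rest by subtraction.
w11 = dh_grid * dw_grid
w10 = dh_grid - w11      # == dh_grid * (1 - dw_grid)
w01 = dw_grid - w11      # == (1 - dh_grid) * dw_grid
w00 = 1 - dh_grid - w01  # == (1 - dh_grid) * (1 - dw_grid)
torch.testing.assert_close(torch.stack([w00, w01, w10, w11]), ref)

# Old indices: four separate row * side + col tensors, stacked afterwards.
h_floor_idx = h_floor_grid * num_grid_per_side
h_ceil_idx = h_ceil_grid * num_grid_per_side
indices_ref = torch.stack([
    h_floor_idx + w_floor_grid,
    h_floor_idx + w_ceil_grid,
    h_ceil_idx + w_floor_grid,
    h_ceil_idx + w_ceil_grid,
]).reshape(4, -1)

# New indices: one batched stack, one multiply, one add.
h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
indices = (h_grid * num_grid_per_side + w_grid).reshape(4, -1)
assert torch.equal(indices, indices_ref)

The tail of the diff follows the same pattern: weighted_embeds.sum(dim=0) replaces the unbind-and-add chain, and the window permute is now done once on a single h-by-w frame before expanding across t, rather than materializing the t-fold repeat with .contiguous() and permuting the full tensor.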