135 changes: 60 additions & 75 deletions vllm/model_executor/models/qwen3_vl.py
@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
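Note: this hunk caches the side length of the square position-embedding grid at construction time, so fast_pos_embed_interpolate (further down) no longer recomputes it on every call. A toy check of the arithmetic, using an illustrative table size rather than the model's actual config:

# Hypothetical numbers: a table of N learned position embeddings laid out
# as a square grid has side length int(N ** 0.5).
num_position_embeddings = 2304
num_grid_per_side = int(num_position_embeddings**0.5)
assert num_grid_per_side == 48
assert num_grid_per_side**2 == num_position_embeddings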
@@ -377,82 +378,68 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def fast_pos_embed_interpolate(self, grid_thw):
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
+    def fast_pos_embed_interpolate(self,
+                                   grid_thw: list[list[int]]) -> torch.Tensor:
 
-        idx_list = [[] for _ in range(4)]
-        weight_list = [[] for _ in range(4)]
+        num_grid_per_side = self.num_grid_per_side
+        m_size = self.spatial_merge_size
+        hidden_dim = self.pos_embed.embedding_dim
+
+        outputs = []
         for t, h, w in grid_thw:
             h_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     h,
-                                    dtype=torch.float32)
+                                    dtype=torch.float32,
+                                    device=self.device)
             w_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     w,
-                                    dtype=torch.float32)
-
-            h_idxs_floor = h_idxs.to(torch.long)
-            w_idxs_floor = w_idxs.to(torch.long)
-            h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-            w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_idxs_floor
-            dw = w_idxs - w_idxs_floor
-
-            idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-            idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-
-            weight_list[0].extend(
-                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[1].extend(
-                ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
-            weight_list[2].extend(
-                (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[3].extend(
-                (dh[None].T * dw[None]).flatten().tolist() * t)
-
-        device = self.pos_embed.weight.device
-        dtype = self.pos_embed.weight.dtype
-
-        p0 = self.pos_embed(
-            torch.tensor(
-                idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[0], dtype=dtype, device=device)[:, None]
-        p1 = self.pos_embed(
-            torch.tensor(
-                idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[1], dtype=dtype, device=device)[:, None]
-        p2 = self.pos_embed(
-            torch.tensor(
-                idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[2], dtype=dtype, device=device)[:, None]
-        p3 = self.pos_embed(
-            torch.tensor(
-                idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[3], dtype=dtype, device=device)[:, None]
-
-        patch_pos_embeds = p0 + p1 + p2 + p3
-        patch_pos_embeds = patch_pos_embeds.split(
-            [t * h * w for t, h, w in grid_thw])
-        patch_pos_embeds_permute = []
-        m_size = self.spatial_merge_size
-        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
-            pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
-                                       m_size, -1).permute(0, 1, 3, 2, 4,
-                                                           5).flatten(0, 4)
-            patch_pos_embeds_permute.append(pos_embed)
-        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
-        return patch_pos_embeds
+                                    dtype=torch.float32,
+                                    device=self.device)
+
+            h_floor = h_idxs.to(torch.long)
+            w_floor = w_idxs.to(torch.long)
+            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_floor
+            dw = w_idxs - w_floor
+
+            w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            idx00 = (h_floor[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx01 = (h_floor[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+            idx10 = (h_ceil[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx11 = (h_ceil[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+
+            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+            weights = torch.stack([w00, w01, w10, w11],
+                                  dim=0).to(dtype=self.dtype,
+                                            device=self.device)
+            weights = weights.unsqueeze(-1)
+
+            embeds = self.pos_embed(indices)
+            weighted_embeds = embeds * weights
+            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+            combined = p0 + p1 + p2 + p3
+
+            combined = combined.view(h * w, hidden_dim)
+            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
+            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
+                                     m_size, hidden_dim)
+            repeated = repeated.permute(0, 1, 3, 2, 4,
+                                        5).reshape(-1, hidden_dim)
+            outputs.append(repeated)
+
+        return torch.cat(outputs, dim=0)

     def compute_attn_mask_seqlen(
         self,
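An aside on the rewritten fast_pos_embed_interpolate above: it bilinearly resamples the square table of learned position embeddings to each item's h x w patch grid, batching all four corner lookups into one embedding call instead of accumulating Python lists of indices and weights on the host. A minimal self-contained sketch of that technique, with hypothetical names (interpolate_pos_embed and the 48x48, 64-dim table are illustrative, not the model's config):

import torch
import torch.nn as nn


def interpolate_pos_embed(pos_embed: nn.Embedding, side: int, h: int,
                          w: int) -> torch.Tensor:
    """Resample a side x side learned grid of embeddings to an h x w grid."""
    h_idxs = torch.linspace(0, side - 1, h)
    w_idxs = torch.linspace(0, side - 1, w)
    h_floor, w_floor = h_idxs.long(), w_idxs.long()
    h_ceil = (h_floor + 1).clamp(max=side - 1)
    w_ceil = (w_floor + 1).clamp(max=side - 1)
    dh, dw = h_idxs - h_floor, w_idxs - w_floor

    def idx(hh, ww):  # flattened index of one corner, shape (h * w,)
        return (hh[:, None] * side + ww[None, :]).reshape(-1)

    def wgt(a, b):  # matching bilinear weight, shape (h * w,)
        return (a[:, None] * b[None, :]).reshape(-1)

    corners = torch.stack([
        idx(h_floor, w_floor), idx(h_floor, w_ceil),
        idx(h_ceil, w_floor), idx(h_ceil, w_ceil)
    ])
    weights = torch.stack([
        wgt(1 - dh, 1 - dw), wgt(1 - dh, dw),
        wgt(dh, 1 - dw), wgt(dh, dw)
    ]).unsqueeze(-1)

    # One lookup for all four corners, then a weighted sum over them.
    return (pos_embed(corners) * weights).sum(dim=0)  # (h * w, dim)


table = nn.Embedding(48 * 48, 64)  # stand-in for self.pos_embed
out = interpolate_pos_embed(table, side=48, h=30, w=40)
assert out.shape == (30 * 40, 64)

Because the four bilinear weights at each target location sum to 1, every resampled row is a convex combination of neighboring table entries, which is why the method handles arbitrary grid sizes without retraining the table.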
@@ -477,12 +464,9 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
-        if isinstance(grid_thw, list):
-            grid_thw_tensor = torch.tensor(grid_thw,
-                                           device=hidden_states.device,
-                                           dtype=torch.int32)
-        else:
-            grid_thw_tensor = grid_thw
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
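The cu_seqlens statement above is cut off by the diff view. For orientation, a toy illustration of the pattern it follows: each image or video contributes one sequence per frame of length h * w, and the lengths become cumulative offsets for varlen attention. The cumsum-and-pad tail below is an assumption based on that common pattern, not code shown in this hunk:

import torch

grid_thw = torch.tensor([[1, 4, 6],   # one image: a single 4x6 frame
                         [2, 4, 4]])  # one video: two 4x4 frames
# h * w per item, repeated t times -> one sequence length per frame.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                  grid_thw[:, 0])
print(seqlens)  # tensor([24, 16, 16])
# Assumed tail: prepend 0 and accumulate into attention offsets.
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))
print(cu_seqlens)  # tensor([ 0, 24, 40, 56])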
@@ -1224,7 +1208,8 @@ def _process_image_input(
             grid_thw_list,
             rope_type="rope_3d")
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
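The comment above is the point of this hunk: computing per-item sizes with math.prod over the plain Python grid_thw_list keeps them as host-side ints, whereas reducing a CUDA tensor with .prod() and reading the result back would block on a device-to-host sync. A small sketch of the split under that scheme (shapes are made up for illustration):

import math

import torch

grid_thw_list = [[1, 4, 6], [2, 4, 4]]
sizes = [math.prod(thw) for thw in grid_thw_list]  # [24, 32], plain ints
embeds = torch.randn(sum(sizes), 64)  # concatenated embeddings, made-up dims
per_item = embeds.split(sizes)  # no sync: the sizes never touch the device
print([e.shape for e in per_item])  # [torch.Size([24, 64]), torch.Size([32, 64])]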
@@ -1526,4 +1511,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
-        )
+        )