In [1]:
# Requires transformers>=4.51.0

import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Select one hidden-state vector per sequence: that of its last attended token.

    Args:
        last_hidden_states: Hidden states indexed as [batch, position, ...].
        attention_mask: Mask indexed as [batch, position]; summed along dim 1
            to count the attended tokens of each row.

    Returns:
        A tensor with one entry per batch row, taken from ``last_hidden_states``
        at that row's final attended position.
    """
    # If every row has a 1 in the final mask column, the batch is left-padded
    # (or unpadded), so the final position is the last real token for all rows.
    final_column_total = attention_mask[:, -1].sum()
    if final_column_total == attention_mask.shape[0]:
        return last_hidden_states[:, -1]

    # Otherwise treat the batch as right-padded: the last real token of each row
    # sits at index (number of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the two-line prompt that prefixes a query with its task instruction.

    Args:
        task_description: Instruction describing the retrieval task.
        query: The query text itself.

    Returns:
        ``'Instruct: <task_description>'`` and ``'Query:<query>'`` joined by a
        newline (note: no space after ``Query:``).
    """
    instruction_line = f'Instruct: {task_description}'
    query_line = f'Query:{query}'
    return '\n'.join((instruction_line, query_line))


# Local checkpoint directory, defined once so the tokenizer and the model are
# always loaded from the same place. This path is machine-specific: edit it
# here (and only here) for your environment.
MODEL_PATH = '/ytech_m2v5_hdd/workspace/kling_mm/yangsihan05/models/Qwen/Qwen3-Embedding-8B'

# padding_side='left' makes the tokenizer pad on the left, so the last position
# of every row is a real token (the fast path in last_token_pool).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')
model = AutoModel.from_pretrained(MODEL_PATH)

# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-8B', attn_implementation="flash_attention_2", torch_dtype=torch.float16).cuda()

# Upper bound on tokens per input; used as the tokenizer's truncation length.
max_length = 8192





  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.58s/it]


In [3]:
# Move the model to the GPU when one is available; fall back to CPU so the
# notebook still runs end-to-end on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Qwen3Model(
  (embed_tokens): Embedding(151665, 4096)
  (layers): ModuleList(
    (0-35): 36 x Qwen3DecoderLayer(
      (self_attn): Qwen3Attention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
        (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
      )
      (mlp): Qwen3MLP(
        (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)
        (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
        (down_proj): Linear(in_features=12288, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)
      (post_attention_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)
    )
  )
  (norm): Qwen3RMSNorm((

In [4]:
# Every query is paired with a one-sentence instruction describing the task.
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, question)
    for question in ('What is the capital of China?', 'Explain gravity')
]

# Retrieval documents are embedded as-is, without an instruction prefix.
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

# Queries first, then documents: later cells rely on this ordering when slicing.
input_texts = [*queries, *documents]





In [5]:
# Inspect the combined batch: the instructed queries come first, followed by the raw documents.
input_texts

['Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:What is the capital of China?',
 'Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:Explain gravity',
 'The capital of China is Beijing.',
 'Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.']

In [6]:
# Tokenize all texts as one batch: pad to the longest text, truncate anything
# longer than max_length, and return PyTorch tensors.
batch_dict = tokenizer(input_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
# Put the encoded batch on the model's device; as the cell's last expression
# this also displays the encoded batch.
batch_dict.to(model.device)



{'input_ids': tensor([[151643, 151643, 151643, 151643,    641,   1235,     25,  16246,    264,
           3482,   2711,   3239,     11,  17179,   9760,  46769,    429,   4226,
            279,   3239,    198,   2859,     25,   3838,    374,    279,   6722,
            315,   5616,     30, 151643],
        [151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,    641,
           1235,     25,  16246,    264,   3482,   2711,   3239,     11,  17179,
           9760,  46769,    429,   4226,    279,   3239,    198,   2859,     25,
            840,  20772,  23249, 151643],
        [151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643,    785,   6722,    315,   5616,
            374,  26549,     13, 151643],
        [ 38409,    374,    264,   5344,    429,  60091,   1378,  12866,   6974,
           1817,   1008,     13,   1084,   6696,  

In [18]:
# Inference only: disable autograd so no computation graph or intermediate
# activations are retained for a backward pass that never happens. This
# noticeably reduces memory use for an 8B-parameter model.
with torch.no_grad():
    outputs = model(**batch_dict, output_hidden_states=True)


In [22]:
# Number of hidden-state tensors returned by the forward pass
# (37 in the recorded run, for a model with 36 decoder layers).
len(outputs["hidden_states"])


37

In [23]:
# Reduce each sequence to the hidden state of its last attended token.
embeddings = last_token_pool(
    outputs.last_hidden_state,
    batch_dict['attention_mask'],
)



In [25]:
# Sanity check: one embedding vector per input text.
embeddings.size()

torch.Size([4, 4096])

In [26]:
# L2-normalize each embedding so the dot products below are cosine similarities.
embeddings = F.normalize(embeddings, p=2, dim=1)

# input_texts was built as queries + documents, so the first len(queries) rows
# are query embeddings and the remaining rows are document embeddings. Deriving
# the split from len(queries) keeps this correct if the query list changes.
num_queries = len(queries)
scores = embeddings[:num_queries] @ embeddings[num_queries:].T
print(scores.tolist())
# Expected output (the last few digits may differ between runs/hardware):
# [[0.7493016123771667, 0.0750647559762001], [0.08795969933271408, 0.6318399906158447]]

[[0.7493014931678772, 0.07506485283374786], [0.08795969188213348, 0.6318402886390686]]
