In [41]:
import torch
from transformers import AutoTokenizer, AutoModel

In [44]:
# 1. Load the pre-trained tokenizer AND model from the same checkpoint,
#    so the tokenizer's token ids index the model's embedding table.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# .eval() disables dropout. from_pretrained already returns the model in eval mode,
# but stating it makes clear this notebook only runs inference.
# .eval() returns the module itself, so chaining keeps `model` bound to the model
# (and, unlike a bare `model.eval()` line, does not echo the repr in Jupyter).
model = AutoModel.from_pretrained(model_name).eval()

In [62]:
# 2. Tokenize a sample sentence into PyTorch tensors
text = "DistilBERT is a fast and light Transformer model."
inputs = tokenizer(text, return_tensors="pt")

# Display the id tensor's shape (batch, sequence_length) next to the ids themselves
input_ids = inputs["input_ids"]
input_ids.shape, input_ids

(torch.Size([1, 14]),
 tensor([[  101,  4487, 16643, 23373,  2003,  1037,  3435,  1998,  2422, 10938,
           2121,  2944,  1012,   102]]))

In [49]:
# Show the module hierarchy (embeddings, the stack of transformer blocks, layer sizes)
# as plain text; print() would call str() on the model anyway.
architecture = str(model)
print(architecture)

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [52]:
# 3. Extract features (perform the forward pass)
# no_grad: inference only, so autograd does not record the graph (saves memory/time).
with torch.no_grad():
    outputs = model(**inputs)

# Invariant: one hidden vector per input token, each of the model's hidden size.
# (torch.Size is a tuple subclass, so it compares equal to a plain tuple.)
assert outputs.last_hidden_state.shape == (*inputs["input_ids"].shape, model.config.hidden_size)

# Last expression is rendered by Jupyter; no print() needed.
outputs.last_hidden_state.shape

torch.Size([1, 14, 768])


In [53]:
# 4. Keep the token embeddings produced by the final encoder layer
# Dimensions: [batch_size, sequence_length, hidden_size] -- [1, 14, 768] for the sample sentence
last_hidden_states = outputs["last_hidden_state"]

In [56]:
# 5. Extract the [CLS] token representation (the "feature")
# The tokenizer prepends [CLS], so it should sit at index 0 of every sequence;
# verify that instead of silently assuming it.
assert (inputs["input_ids"][:, 0] == tokenizer.cls_token_id).all(), "expected [CLS] at position 0"

# Take the first token (index 0) from each sequence -> shape [batch_size, hidden_size]
cls_feature = last_hidden_states[:, 0, :]
# NOTE(review): DistilBERT has no pooler / next-sentence-prediction head, so this raw
# [CLS] vector was not trained specifically as a sentence summary. Averaging token
# vectors weighted by inputs["attention_mask"] is a common alternative; compare for your task.
cls_feature.shape

torch.Size([1, 768])