In [15]:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import IMDB

In [6]:
# Toy corpus: three short sentences used throughout the notebook.
corpus = [
    "I love deep learning",
    "I love NLP",
    "PyTorch is amazing for NLP",
]

# torchtext's "basic_english" tokenizer; the output below shows it
# lower-cases the text and splits it into word tokens.
tokenizer = get_tokenizer("basic_english")

# Apply the tokenizer to every sentence -> one list of tokens per sentence.
tokenized_corpus = list(map(tokenizer, corpus))
print(tokenized_corpus)

[['i', 'love', 'deep', 'learning'], ['i', 'love', 'nlp'], ['pytorch', 'is', 'amazing', 'for', 'nlp']]


In [8]:
# Build the token -> index vocabulary from the tokenized sentences.
# "<unk>" is registered as a special token (it ends up at index 0 in the
# mapping printed below).
vocab = build_vocab_from_iterator(
    tokenized_corpus,
    specials=["<unk>"],
)

# Make look-ups of tokens that are not in the vocabulary fall back to the
# "<unk>" index.
vocab.set_default_index(vocab["<unk>"])

In [9]:
# Inspect the vocabulary: its size and the full token -> index table.
vocab_size = len(vocab)
token_to_index = vocab.get_stoi()

print("Vocab size:", vocab_size)
print("Word -> Index mapping:", token_to_index)

Vocab size: 10
Word -> Index mapping: {'pytorch': 9, 'learning': 8, 'is': 7, 'for': 6, 'deep': 5, 'amazing': 4, 'nlp': 3, 'i': 1, 'love': 2, '<unk>': 0}


In [10]:
# Replace every token by its vocabulary index, sentence by sentence.
# Calling `vocab` on a list of tokens returns the matching list of indices.
numericalized = [
    vocab(sentence_tokens)
    for sentence_tokens in tokenized_corpus
]
print("Numericalized corpus:", numericalized)

Numericalized corpus: [[1, 2, 5, 8], [1, 2, 3], [9, 7, 4, 6, 3]]


In [14]:
# Seed PyTorch's global RNG before creating the layer: nn.Embedding weights
# are randomly initialised, so without a seed the tensor printed below is
# different on every fresh kernel run.
torch.manual_seed(42)

embedding_dim = 50
# One learnable vector of size `embedding_dim` per vocabulary entry.
embedding = torch.nn.Embedding(num_embeddings=len(vocab), embedding_dim=embedding_dim)

# Convert one sentence to a tensor of token indices. The dtype is set
# explicitly because embedding look-ups need integer indices, and dtype
# inference would yield a float tensor for an empty sentence.
sentence_tensor = torch.tensor(numericalized[0], dtype=torch.long)  # "i love deep learning"
embedded_sentence = embedding(sentence_tensor)

print("Embedded shape:", embedded_sentence.shape)  # [sentence_length, embedding_dim]
print(embedded_sentence)


Embedded shape: torch.Size([4, 50])
tensor([[-0.0500,  0.6587,  1.1648, -1.0693,  0.3712, -0.2256, -1.1257,  0.3748,
          0.4174,  0.1633,  1.6521,  0.6422,  0.3396,  0.6409, -0.6120, -1.5655,
          0.1120,  0.8527, -0.8310,  0.2124, -0.8584,  0.7755,  1.0168, -0.5937,
         -1.2423, -1.2420,  0.6951, -0.1970, -1.3777,  0.1003, -0.0777, -1.9960,
          0.2769,  0.0621,  0.2846,  1.0130, -1.5226, -0.0774,  0.5942,  0.7252,
         -0.1557, -1.1899,  0.8966,  0.5493, -0.0288,  1.4236, -0.3089, -2.8745,
          0.4513, -0.2730],
        [ 1.8075, -0.2100,  1.6299, -1.9432, -1.4332, -0.5559, -0.1015, -0.0674,
          0.5600,  0.0114,  0.4599,  0.0794,  0.4144,  0.7365,  2.0062,  0.4038,
          0.5327, -0.3421, -0.1681,  0.7718, -0.8312, -0.3992,  0.5420, -1.0114,
          1.4819, -0.3148, -0.1770,  1.0229, -0.3627, -0.8726, -1.0074, -0.9971,
         -0.7400, -0.0162,  0.3356,  0.9134,  0.3517, -1.8829,  1.1414,  0.7288,
         -0.9678, -0.1095,  0.4044, -1.9440, 