## PyTorch: nn.embedding

In [11]:
import torch
import torch.nn as nn

## Without nn.Embedding()

In [6]:
train_data = 'you need to know how to code'
word_set = set(train_data.split()) # 중복을 제거한 단어들의 집합인 단어 집합 생성.
vocab = {word: i+2 for i, word in enumerate(word_set)}  # 단어 집합의 각 단어에 고유한 정수 맵핑.
vocab['<unk>'] = 0
vocab['<pad>'] = 1
print(vocab)

{'to': 2, 'need': 3, 'you': 4, 'know': 5, 'code': 6, 'how': 7, '<unk>': 0, '<pad>': 1}


In [7]:
embedding_table = torch.FloatTensor([
                               [ 0.0,  0.0,  0.0],
                               [ 0.0,  0.0,  0.0],
                               [ 0.2,  0.9,  0.3],
                               [ 0.1,  0.5,  0.7],
                               [ 0.2,  0.1,  0.8],
                               [ 0.4,  0.1,  0.1],
                               [ 0.1,  0.8,  0.9],
                               [ 0.6,  0.1,  0.1]])

In [9]:
sample = 'you need to run'.split()
idxes=[]
for word in sample:
  try:
    idxes.append(vocab[word])
  except KeyError: # Replace to <unk> if the word does not exist
    idxes.append(vocab['<unk>'])
idxes = torch.LongTensor(idxes)

lookup_result = embedding_table[idxes, :] # bring the index from the embedding table
print(lookup_result)

tensor([[0.2000, 0.1000, 0.8000],
        [0.1000, 0.5000, 0.7000],
        [0.2000, 0.9000, 0.3000],
        [0.0000, 0.0000, 0.0000]])


## Using nn.Embedding()

In [10]:
train_data = 'you need to know how to code'
word_set = set(train_data.split()) # Remove the overlaps
vocab = {tkn: i+2 for i, tkn in enumerate(word_set)}  # 단어 집합의 각 단어에 고유한 정수 맵핑.
vocab['<unk>'] = 0
vocab['<pad>'] = 1

In [12]:
embedding_layer = nn.Embedding(num_embeddings = len(vocab), 
                               embedding_dim = 3,
                               padding_idx = 1)

In [13]:
print(embedding_layer.weight)

Parameter containing:
tensor([[ 1.0043,  0.0257,  0.3820],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.9507,  0.7400, -0.5485],
        [ 0.7646, -2.1381,  0.2586],
        [-0.0106, -0.0385,  1.8724],
        [-0.1384, -0.6268, -0.5356],
        [-1.4032,  0.0541, -2.5304],
        [-0.4924, -0.1660, -0.2354]], requires_grad=True)
