In [1]:
import collections
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import torch.nn as nn

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


# load file

In [2]:
# data: Penn Tree Bank, sampled from Wall Street Journal articles.
# each line = a sentence

# Read the whole file once, then tokenize: the outer list holds one entry per
# '\n'-separated line, the inner lists hold that line's whitespace-separated words.
with open('/home/tian/Projects/d2l/data/ptb/ptb.train.txt', 'r') as f:
    raw_train_text = [sentence.split() for sentence in f.read().split('\n')]

In [3]:
print(raw_train_text[:5])

[['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter'], ['pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N'], ['mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group'], ['rudolph', '<unk>', 'N', 'years', 'old', 'and', 'former', 'chairman', 'of', 'consolidated', 'gold', 'fields', 'plc', 'was', 'named', 'a', 'nonexecutive', 'director', 'of', 'this', 'british', 'industrial', 'conglomerate'], ['a', 'form', 'of', 'asbestos', 'once', 'used', 'to', 'make', 'kent', 'cigarette', 'filters', 'has', 'caused', 'a', 'high', 'percentage', 'of', 'cancer', 'deaths', 'among', 'a', 'group', 'of', 'workers', 'exposed', 'to', 'it', 'more', 'than', 'N', 'years', 'ago', 'researchers', 'reported']]


# Subsampling, getting subsampled vocab and corpus

In [4]:
class ptbVocab:
    """Frequency-filtered and randomly subsampled view of a tokenized corpus.

    Args:
        raw_text: list of sentences, each a list of string tokens.
        min_freq: tokens whose raw count is below this value are replaced
            by '<unk>'.

    Attributes set by the constructor:
        n_sentences: number of sentences in `raw_text`.
        counter: token counts after the '<unk>' replacement and before
            subsampling.
        text_subsampled: one token list per input sentence, after
            replacement and subsampling (a sentence may end up empty).
        text_subsampled_flatten: `text_subsampled` as a single flat list.
        n_words_subsampled: length of `text_subsampled_flatten`.

    The constructor prints vocabulary and corpus sizes at each stage and
    draws from the global `random` state, so results vary unless seeded.
    """
    def __init__(self, raw_text, min_freq):
        self.n_sentences = len(raw_text)
        raw_text_flatten = [tok for sentence in raw_text for tok in sentence]
        raw_counts = collections.Counter(raw_text_flatten)
        print(f'The original #vocab: {len(raw_counts)}')

        # Tokens seen fewer than `min_freq` times collapse into '<unk>'.
        min_freq_text = [
            [tok if raw_counts[tok] >= min_freq else '<unk>' for tok in sentence]
            for sentence in raw_text
        ]
        min_freq_text_flatten = [tok for sentence in min_freq_text for tok in sentence]

        print(f'After discarded freq < {min_freq}, #vocab: {len(set(min_freq_text_flatten))}')

        self.counter = collections.Counter(min_freq_text_flatten)
        n_tokens = len(min_freq_text_flatten)

        # Subsampling: a token survives when a uniform draw falls below
        # sqrt(1e-4 * n_tokens / count), so very frequent words are dropped often
        # while rare words (ratio >= 1) are always kept.
        self.text_subsampled = []
        for sentence in min_freq_text:
            survivors = []
            for tok in sentence:
                keep_threshold = math.sqrt(1e-4 / self.counter[tok] * n_tokens)
                if random.uniform(0, 1) < keep_threshold:
                    survivors.append(tok)
            self.text_subsampled.append(survivors)

        self.text_subsampled_flatten = [
            tok for sentence in self.text_subsampled for tok in sentence
        ]
        self.n_words_subsampled = len(self.text_subsampled_flatten)
        print(f'After subsampled, the #vocab: {len(set(self.text_subsampled_flatten))}')
        print(f'The original #words: {len(raw_text_flatten)}\nThe new #words: {len(self.text_subsampled_flatten)}')


In [5]:
# Build the filtered + subsampled corpus; words seen fewer than 10 times become '<unk>'.
train_vocab = ptbVocab(raw_train_text, 10)

The original #vocab: 9999
After discarded freq < 10, #vocab: 6719
After subsampled, the #vocab: 6719
The original #words: 887521
The new #words: 354506


In [6]:
# Vocabulary in first-seen order of the post-filter token counts.
vocab_c = list(train_vocab.counter.keys())

# Reserve id 0 for '<unk>' (padding and negative sampling both rely on it).
# If '<unk>' is present, swap it with whatever currently sits at position 0;
# if the corpus produced no '<unk>' at all, insert it rather than overwriting
# a real word.
if '<unk>' in vocab_c:
    unk_pos = vocab_c.index('<unk>')
    vocab_c[0], vocab_c[unk_pos] = vocab_c[unk_pos], vocab_c[0]
else:
    vocab_c.insert(0, '<unk>')

# Two-way lookup between tokens and their integer ids.
vocab_word_to_id = {word: idx for idx, word in enumerate(vocab_c)}
vocab_id_to_word = {idx: word for idx, word in enumerate(vocab_c)}

In [7]:
vocab_id_to_word[0]

'<unk>'

In [8]:
# Encode the subsampled text as integer ids, one inner list per sentence
# (sentences emptied by subsampling stay as empty lists).
corpus = [[vocab_word_to_id[word] for word in line]
          for line in train_vocab.text_subsampled]

In [9]:
len(corpus) == train_vocab.n_sentences

True

# Extracting center words and context words

In [10]:
def get_centers_and_contexts(corpus, max_window_size):
    """Build skip-gram training pairs from an id-encoded corpus.

    Args:
        corpus: list of sentences, each a list of token ids.
        max_window_size: upper bound (inclusive) of the per-position window
            radius, which is drawn uniformly from 1..max_window_size.

    Returns:
        (centers, contexts): `centers` is a flat list of every token from the
        sentences that have at least two tokens; `contexts[k]` is the list of
        tokens within the sampled radius of `centers[k]` in its sentence,
        excluding the center position itself.
    """
    centers, contexts = [], []
    for sentence in corpus:
        n_tokens = len(sentence)
        # A sentence with fewer than two tokens cannot form any pair.
        if n_tokens < 2:
            continue
        centers.extend(sentence)
        for pos in range(n_tokens):
            radius = random.randint(1, max_window_size)
            lo = max(0, pos - radius)
            hi = min(n_tokens, pos + 1 + radius)
            # Everything in [lo, hi) except the center position.
            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [11]:
# Sanity check on a toy corpus of two "sentences" (7 and 3 tokens) with a
# maximum window radius of 4; the printed contexts change from run to run
# because the radius is drawn at random for each position.
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 4)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [0, 1, 3, 4, 5]
center 3 has contexts [0, 1, 2, 4, 5, 6]
center 4 has contexts [0, 1, 2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [2, 3, 4, 5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [12]:
# Extract (center, contexts) examples from the real corpus with a maximum
# window radius of 5, then report the total number of context words.
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'

'# center-context pairs: 1562333'

In [13]:
len(all_centers) == len(all_contexts)

True

In [14]:
vocab_word_to_id['<unk>']

0

# Negative sampling

In [15]:
class RandomGenerator:
    """Weighted sampler over the integers 1..n that serves draws from a cache.

    `sampling_weights[j]` is the (unnormalized) weight of the value `j + 1`.
    Draws are produced in batches of 10000 via `random.choices` and handed
    out one at a time; a new batch is generated when the cache runs out.
    """
    def __init__(self, sampling_weights):
        self.sampling_weights = sampling_weights
        # Values are 1-based, so the population is 1..len(sampling_weights).
        self.population = [position + 1 for position in range(len(sampling_weights))]
        self.candidates = []
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when exhausted."""
        cache_exhausted = self.i == len(self.candidates)
        if cache_exhausted:
            self.candidates = random.choices(
                self.population, weights=self.sampling_weights, k=10000)
            self.i = 0
        sample = self.candidates[self.i]
        self.i += 1
        return sample

In [16]:
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]

[3, 3, 1, 3, 1, 2, 1, 1, 1, 2]

In [17]:
def get_negatives(all_contexts, id_to_word, counter, K):
    """Draw noise words for negative sampling.

    Args:
        all_contexts: list of context-id lists, one per training example.
        id_to_word: mapping from token id to token string; id 0 is skipped.
        counter: mapping from token string to its count.
        K: number of noise words to draw per context word.

    Returns:
        A list parallel to `all_contexts`; entry k holds
        `len(all_contexts[k]) * K` token ids, none of which appears in
        `all_contexts[k]`.
    """
    # Ids 1, 2, ... are weighted by count ** 0.75; id 0 (the unknown token)
    # gets no weight and therefore can never be drawn.
    sampling_weights = []
    for token_id in range(1, len(id_to_word)):
        sampling_weights.append(counter[id_to_word[token_id]]**0.75)
    generator = RandomGenerator(sampling_weights)

    all_negatives = []
    for contexts in all_contexts:
        n_wanted = len(contexts) * K
        negatives = []
        while len(negatives) < n_wanted:
            candidate = generator.draw()
            # Reject draws that are actual context words of this example.
            if candidate in contexts:
                continue
            negatives.append(candidate)
        all_negatives.append(negatives)
    return all_negatives

# Draw 5 noise words per context word for every training example.
all_negatives = get_negatives(all_contexts, vocab_id_to_word, train_vocab.counter, 5)

# Dataset and DataLoader

In [18]:
class PTBDataset(Dataset):
    """Map-style dataset over three parallel lists: centers, contexts, negatives.

    Item `k` is the tuple `(centers[k], contexts[k], negatives[k])`, returned
    as stored; no padding or tensor conversion happens here.
    """
    def __init__(self, centers, contexts, negatives):
        n_examples = len(centers)
        # The three lists must describe the same examples position by position.
        assert n_examples == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        example = (self.centers[index], self.contexts[index], self.negatives[index])
        return example

    def __len__(self):
        return len(self.centers)
    

In [19]:
# for every batch, do these

def batchify(data):
    """Collate (center, context, negative) examples into padded tensors.

    Args:
        data: sequence of `(center_id, context_ids, negative_ids)` tuples.

    Returns:
        A 4-tuple of tensors, where `max_len` is the largest
        `len(context) + len(negative)` in the batch:
          - centers, reshaped to (batch, 1)
          - contexts followed by negatives, right-padded with 0 to max_len
          - masks: 1 for real entries, 0 for padding
          - labels: 1 for context positions, 0 for negatives and padding
    """
    max_len = max(len(context) + len(negative) for _, context, negative in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        n_context = len(context)
        n_real = n_context + len(negative)
        n_pad = max_len - n_real
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * n_pad)
        masks.append([1] * n_real + [0] * n_pad)
        labels.append([1] * n_context + [0] * (max_len - n_context))
    return (
        torch.tensor(centers).reshape((-1, 1)),
        torch.tensor(contexts_negatives),
        torch.tensor(masks),
        torch.tensor(labels),
    )

In [20]:
# Two hand-made examples of different lengths (6 and 9 context+negative ids)
# to show how batchify pads the shorter one and builds masks / labels.
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3, 3, 3, 3, 3])
batch = batchify((x_1, x_2))

# Print each returned tensor next to its role in the batch.
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)

centers = tensor([[1],
        [1]])
contexts_negatives = tensor([[2, 2, 3, 3, 3, 3, 0, 0, 0],
        [2, 2, 2, 3, 3, 3, 3, 3, 3]])
masks = tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])
labels = tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0]])


In [21]:
# from raw_text to training_dataset

# Re-run the full preprocessing pipeline in one cell. Subsampling, window sizes
# and negative draws are all random, so the word counts printed here differ
# slightly from the step-by-step cells.
train_vocab = ptbVocab(raw_train_text, 10)

# Vocabulary in first-seen order of the post-filter token counts.
vocab_c = list(train_vocab.counter.keys())

# Reserve id 0 for '<unk>' (padding and negative sampling both rely on it).
# If '<unk>' is present, swap it with whatever currently sits at position 0;
# if the corpus produced no '<unk>' at all, insert it rather than overwriting
# a real word.
if '<unk>' in vocab_c:
    unk_pos = vocab_c.index('<unk>')
    vocab_c[0], vocab_c[unk_pos] = vocab_c[unk_pos], vocab_c[0]
else:
    vocab_c.insert(0, '<unk>')

# Two-way lookup between tokens and their integer ids.
vocab_word_to_id = {word: idx for idx, word in enumerate(vocab_c)}
vocab_id_to_word = {idx: word for idx, word in enumerate(vocab_c)}

# Encode the subsampled text as integer ids, one inner list per sentence.
corpus = [[vocab_word_to_id[word] for word in line]
          for line in train_vocab.text_subsampled]

# Center/context examples with a maximum window radius of 5.
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)

# 5 noise words per context word.
all_negatives = get_negatives(all_contexts, vocab_id_to_word, train_vocab.counter, 5)

The original #vocab: 9999
After discarded freq < 10, #vocab: 6719
After subsampled, the #vocab: 6719
The original #words: 887521
The new #words: 353932


In [22]:
# Number of (center, contexts, negatives) examples per minibatch.
batch_size = 512

# Wrap the parallel lists in a Dataset and let the DataLoader shuffle them;
# batchify pads every minibatch to its own longest context+negative list.
train_data = PTBDataset(centers=all_centers, contexts=all_contexts, negatives=all_negatives)
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, collate_fn=batchify, num_workers=4)

# Each batch is a 4-tuple: (0: center words, 1: contexts + negative samples, 2: mask, 3: label)
# label is 1 where the entry is a true context word, 0 for negatives and padding

In [23]:
a = iter(train_dataloader)
b = next(a)
print(f'{b[1][0]}\n{b[2][0]}\n{b[3][0]}')

tensor([3050,   63,  288,  533, 2915, 2967, 2791,  448,  764,   57,  337,  997,
         477,  290,    9,  127,  642,  637,  684, 2882,   78,  135,   99, 6419,
         640, 3418,  244,  669, 6605, 4468, 1045,  167, 4385, 3531, 2677, 4183,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


# Model

Skip-gram: input: center word, output: context words

In [24]:
# forward() can have only one input

class skip_gram(nn.Module):
    """A single embedding table that maps token ids to dense vectors.

    Args:
        vocab_size: number of rows in the embedding table.
        hidden_size: dimensionality of each embedding vector (default 100).
    """
    def __init__(self, vocab_size, hidden_size=100):
        super().__init__()
        self.embed1 = nn.Embedding(vocab_size, hidden_size)
        # Overwrite the default embedding initialization with Xavier-uniform.
        nn.init.xavier_uniform_(self.embed1.weight)

    def forward(self, x):
        """Look up the embeddings of the ids in `x`; output adds a trailing hidden_size axis."""
        return self.embed1(x)

In [25]:
# Two independent embedding tables over the same vocabulary: one used for
# center words, the other for context / noise words.
center_model = skip_gram(vocab_size=len(vocab_id_to_word)).to(device)
context_model = skip_gram(vocab_size=len(vocab_id_to_word)).to(device)

In [26]:
a = iter(train_dataloader)
center, context_negaive, mask, label = next(a)
center = center.to(device)
context_negaive = context_negaive.to(device)

In [27]:
context_negaive.shape, center.shape

(torch.Size([512, 60]), torch.Size([512, 1]))

In [28]:
center_model(center).shape

torch.Size([512, 1, 100])

In [29]:
context_model(context_negaive).permute(0,2,1).shape

torch.Size([512, 100, 60])

In [30]:
torch.bmm(center_model(center), context_model(context_negaive).permute(0,2,1))[0][0]
# 512, 60

tensor([ 6.4093e-04, -4.9274e-03,  5.8413e-03, -4.5686e-03, -5.5685e-03,
        -1.2810e-04, -3.8573e-03, -4.5488e-04,  6.0015e-03, -4.9525e-04,
         2.5522e-03, -5.7410e-03,  1.6082e-03, -6.8115e-03,  4.3481e-03,
        -4.8595e-03, -3.6823e-03,  4.7506e-03, -4.6159e-03,  7.0750e-04,
        -1.3938e-04, -1.0319e-03,  1.5741e-03, -1.5019e-03,  2.9298e-04,
        -1.6903e-05, -6.2466e-04, -8.9285e-05,  2.5837e-03,  7.9055e-04,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03,
        -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03, -2.1683e-03],
       device='cuda:0', grad_fn=<SelectBackward0>)

# Training

In [31]:
# hyperparameter (batch_size in Dataset and DataLoader)
# Number of full passes over train_dataloader.
num_epochs = 50
# Initial learning rate handed to the optimizer.
learning_rate = 0.005

In [32]:
#loss function
class SigmoidBCELoss(nn.Module):
    """Binary cross-entropy on logits with optional per-position masking.

    The loss is computed element-wise, masked positions contribute zero, and
    each row is averaged over its unmasked positions only.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask=None):
        """Return one loss value per row.

        Args:
            inputs: logits of shape (batch, length).
            target: 0/1 labels with the same shape as `inputs`.
            mask: optional tensor of the same shape; non-zero marks positions
                that count. When omitted, every position counts.

        Returns:
            Tensor of shape (batch,): the mean loss over each row's unmasked
            positions (0 for a row whose mask is entirely zero).
        """
        # binary_cross_entropy_with_logits expects floating-point operands.
        inputs = inputs.float()
        target = target.float()
        if mask is None:
            # No mask means every position is valid.
            mask = torch.ones_like(inputs)
        else:
            mask = mask.float()
        out = nn.functional.binary_cross_entropy_with_logits(
            inputs, target, weight=mask, reduction="none")
        out = out.mean(dim=1)
        # Rescale the plain mean (over all positions) to a mean over unmasked
        # positions; the clamp avoids 0/0 -> NaN for a fully masked row.
        valid_per_row = mask.sum(dim=1).clamp(min=1)
        out = out * mask.shape[1] / valid_per_row
        return out

criterion = SigmoidBCELoss()    # Applies the sigmoid and the binary cross-entropy in one step, with masking.
# A single Adam optimizer updates both embedding tables.
optimizer = torch.optim.Adam(list(center_model.parameters()) + list(context_model.parameters()), lr=learning_rate)
# StepLR multiplies the learning rate by gamma every step_size scheduler steps.
# NOTE(review): if scheduler.step() runs once per epoch, the rate shrinks 10x every
# 2 epochs and is effectively zero long before epoch 50 -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

In [33]:
# Worked example: two rows with identical logits; the second row's mask hides
# its last two positions, so only the first two contribute to its loss.
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
criterion(pred, label, mask=mask)

tensor([0.9352, 1.8462])

In [34]:
# an inefficient, scalar way to reproduce the BCE-with-logits terms by hand
def sigmd(x):
    """Return -log(sigmoid(x)) for a scalar x.

    Evaluated as log(1 + exp(-x)), choosing the branch whose exponent is
    non-positive so that large-magnitude inputs cannot overflow math.exp.
    """
    if x >= 0:
        return math.log1p(math.exp(-x))
    # For x < 0: log(1 + exp(-x)) == -x + log(1 + exp(x)).
    return -x + math.log1p(math.exp(x))

print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')

0.9352
1.8462


In [35]:
# Number of minibatches the DataLoader yields per epoch.
n_iter = len(train_dataloader)

In [36]:
# Print a progress line every `log_every` minibatches and on the last one.
# The cadence is its own setting rather than something derived from
# batch_size, which tied logging to an unrelated value and divided by zero
# for batch sizes below 50.
log_every = 10

for epoch in range(num_epochs):
    for i, (center, context_negative, mask, label) in enumerate(train_dataloader):
        center = center.to(device)
        context_negative = context_negative.to(device)
        mask = mask.to(device)
        label = label.to(device)

        v = center_model(center)                 # (batch, 1, hidden)
        u = context_model(context_negative)      # (batch, max_len, hidden)
        # Dot product of the center vector with every context/noise vector.
        pred = torch.bmm(v, u.permute(0, 2, 1))  # (batch, 1, max_len)

        # Per-example masked loss; the sum over the batch is what gets optimized.
        loss = criterion(pred.reshape(label.shape), label, mask)
        batch_loss = loss.sum()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if (i + 1) % log_every == 0 or (i + 1) == n_iter:
            print(f'Epoch {epoch+1} / {num_epochs}: {i+1} / {n_iter} loss = {batch_loss.item()}')

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()


Epoch 1 / 50: 10 / 690 loss = 354.70025634765625
Epoch 1 / 50: 20 / 690 loss = 352.1670837402344
Epoch 1 / 50: 30 / 690 loss = 338.9978332519531
Epoch 1 / 50: 40 / 690 loss = 312.15032958984375
Epoch 1 / 50: 50 / 690 loss = 277.7599792480469
Epoch 1 / 50: 60 / 690 loss = 262.46710205078125
Epoch 1 / 50: 70 / 690 loss = 252.3597412109375
Epoch 1 / 50: 80 / 690 loss = 244.75291442871094
Epoch 1 / 50: 90 / 690 loss = 239.13592529296875
Epoch 1 / 50: 100 / 690 loss = 236.80908203125
Epoch 1 / 50: 110 / 690 loss = 236.852783203125
Epoch 1 / 50: 120 / 690 loss = 231.2360076904297
Epoch 1 / 50: 130 / 690 loss = 234.6048126220703
Epoch 1 / 50: 140 / 690 loss = 232.0992431640625
Epoch 1 / 50: 150 / 690 loss = 231.26754760742188
Epoch 1 / 50: 160 / 690 loss = 229.1293182373047
Epoch 1 / 50: 170 / 690 loss = 229.7235870361328
Epoch 1 / 50: 180 / 690 loss = 228.3643035888672
Epoch 1 / 50: 190 / 690 loss = 227.9276885986328
Epoch 1 / 50: 200 / 690 loss = 228.53836059570312
Epoch 1 / 50: 210 / 690 l

# Inference

In [37]:
# Put both embedding modules into evaluation mode before inspecting them.
center_model.eval()
context_model.eval()


skip_gram(
  (embed1): Embedding(6719, 100)
)

In [38]:
def get_similar_tokens(query_token, k, model):
    """Print the `k` vocabulary tokens closest to `query_token` by cosine similarity.

    Args:
        query_token: a token present in the module-level `vocab_word_to_id`.
        k: how many neighbours to print.
        model: module whose first state-dict tensor is the embedding matrix.

    Prints the query's id first, then one line per neighbour. The top-ranked
    entry is skipped on the assumption that it is the query itself.
    """
    # The first tensor in the state dict serves as the embedding matrix.
    W = next(iter(model.state_dict().values()))
    query_id = vocab_word_to_id[query_token]
    print(query_id)
    x = W[query_id]
    # Cosine similarity of every row with the query row; 1e-9 under the
    # square root keeps the denominator away from zero.
    squared_norm_product = torch.sum(W * W, dim=1) * torch.sum(x * x)
    cos = torch.mv(W, x) / torch.sqrt(squared_norm_product + 1e-9)
    _, best_ids = torch.topk(cos, k=k+1)
    best_ids = best_ids.cpu().numpy().astype('int32')
    for i in best_ids[1:]:  # drop the top hit (expected to be the query)
        print(f'cosine sim={float(cos[i]):.3f}: {vocab_id_to_word[i]}')


get_similar_tokens('computer', 3, center_model)

994
cosine sim=0.681: microprocessor
cosine sim=0.662: hardware
cosine sim=0.637: computers


In [39]:
get_similar_tokens('computer', 3, context_model)

994
cosine sim=0.772: microprocessor
cosine sim=0.728: bugs
cosine sim=0.728: hardware
