In [1]:
import os
import collections
import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import math

  from .autonotebook import tqdm as notebook_tqdm


# load file

In [2]:
with open('/home/tian/Projects/d2l/data/ptb/ptb.train.txt', 'r') as f:
    raw_train_text = f.read()

raw_train_text = raw_train_text.split('\n') # num_line, num_words_per_line
raw_train_text = [line.split() for line in raw_train_text]

In [3]:
print(raw_train_text[:5])

[['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter'], ['pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N'], ['mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group'], ['rudolph', '<unk>', 'N', 'years', 'old', 'and', 'former', 'chairman', 'of', 'consolidated', 'gold', 'fields', 'plc', 'was', 'named', 'a', 'nonexecutive', 'director', 'of', 'this', 'british', 'industrial', 'conglomerate'], ['a', 'form', 'of', 'asbestos', 'once', 'used', 'to', 'make', 'kent', 'cigarette', 'filters', 'has', 'caused', 'a', 'high', 'percentage', 'of', 'cancer', 'deaths', 'among', 'a', 'group', 'of', 'workers', 'exposed', 'to', 'it', 'more', 'than', 'N', 'years', 'ago', 'researchers', 'reported']]


# Subsampling, getting subsampled vocab and corpus

In [26]:
class ptbVocab:
    def __init__(self, raw_text):
        self.sentence_length = len(raw_text)
        raw_text_flatten = [w for line in raw_text for w in line]
        
        self.counter = collections.Counter(raw_text_flatten)

        self.n_words_original= sum(self.counter.values())
        
        #subsampling to discard the words with high freq like "a", "the"..
        def keep(token):
            return(random.uniform(0, 1) < math.sqrt(1e-4 / self.counter[token] * self.n_words_original))
        self.text_subsampled = [[token for token in line if keep(token)] for line in raw_text]
        self.text_subsampled_flatten = [w for line in self.text_subsampled for w in line]
        self.n_words_subsampled = len(self.text_subsampled_flatten)

In [27]:
train_vocab = ptbVocab(raw_train_text)

In [30]:
# compare original and subsampled text
train_vocab.n_words_original, train_vocab.n_words_subsampled

(887521, 376236)

In [31]:
vocab_c = collections.Counter(train_vocab.text_subsampled_flatten)
vocab_c = vocab_c.keys()
vocab = {b:a for a, b in enumerate(vocab_c)}

In [33]:
corpus = []
for line in train_vocab.text_subsampled:
    l = []
    for word in line:
        l.append(vocab[word])

    corpus.append(l)

# Extracting center words and context words

In [42]:
def get_centers_and_contexts(corpus, max_window_size):
    """Return center words and context words in skip-gram."""
    centers, contexts = [], []
    for line in corpus:
        # To form a "center word--context word" pair, each sentence needs to
        # have at least 2 words
        if len(line) < 2:
            continue
        centers += line
        for i in range(len(line)):  # Context window centered at `i`
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, i - window_size),
                                 min(len(line), i + 1 + window_size)))
            # Exclude the center word from the context words
            indices.remove(i)
            contexts.append([line[idx] for idx in indices])
    return centers, contexts

In [45]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 4)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [0, 1, 3, 4, 5, 6]
center 3 has contexts [0, 1, 2, 4, 5, 6]
center 4 has contexts [0, 1, 2, 3, 5, 6]
center 5 has contexts [2, 3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]


In [46]:
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'

'# center-context pairs: 1689571'

In [50]:
len(all_centers) == len(all_contexts)

True

# Negative sampling

In [None]:
class RandomGenerator:
    """Randomly draw among {1, ..., n} according to n sampling weights."""
    def __init__(self, sampling_weights):
        # Exclude
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.candidates = []
        self.i = 0

    def draw(self):
        if self.i == len(self.candidates):
            # Cache `k` random sampling results
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=10000)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]