In [1]:
import numpy as np
import torch
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import random

In [2]:
class DataPipeline(Dataset):
    """Word2vec-style training data built from a whitespace-tokenised text file.

    Each line of the file is treated as one sentence. ``__getitem__`` turns a
    sentence into a list of ``(target, context, negative_samples)`` index
    triples. Targets are frequency sub-sampled and negatives are drawn from a
    unigram**power table. Index 0 is reserved for ``<unk>``.
    """

    # Index used for out-of-vocabulary tokens (build_vocab always assigns "<unk>" -> 0).
    UNK_INDEX = 0

    def __init__(self, filename, window_size=4, min_freq=1, vocab=None, neg_words=5,
                 neg_table_size=1e8, sub_sample_threshold=1e-5):
        """
        Args:
            filename: corpus path; one sentence per line, tokens separated by whitespace.
            window_size: context width; ``window_size // 2`` tokens are taken on each side.
            min_freq: tokens seen fewer times are folded into ``<unk>``. Only used when
                ``vocab`` is None.
            vocab: optional pre-built ``word -> index`` mapping; unknown tokens map to 0.
            neg_words: number of negative samples drawn per target.
            neg_table_size: approximate number of entries in the negative-sampling table.
            sub_sample_threshold: threshold ``t`` of the sub-sampling keep probability.
        """
        self.data = self.read_data(filename)
        self.neg_words = neg_words
        self.window_size = window_size
        self.min_freq = min_freq
        if vocab is None:
            # min_freq must be forwarded, otherwise rare words are never folded into <unk>.
            self.vocab, self.ind2vocab, self.word_count = self.build_vocab(self.data, min_freq=min_freq)
        else:
            self.vocab = vocab
            self.ind2vocab = {v: k for k, v in vocab.items()}
            self.word_count = self.get_word_count(vocab, self.data)
        self.neg_sampling_table = self.__create_neg_sampling_table(table_size=neg_table_size)
        self.sub_sampling_table = self.__create_sub_sampling_table(threshold=sub_sample_threshold)

    def get_vocab(self):
        """Return the ``word -> index`` mapping."""
        return self.vocab

    def read_data(self, filename, encoding="utf-8"):
        """Read ``filename`` and return one list of whitespace-separated tokens per line.

        Blank lines yield empty lists (they are kept so indices match file lines).
        """
        with open(filename, "r", encoding=encoding) as f:
            # str.split() with no argument already discards surrounding whitespace.
            return [line.split() for line in f]

    def get_word_count(self, vocab, data):
        """Count token occurrences per vocabulary index; OOV tokens count towards <unk>.

        Every index in ``vocab`` (plus ``UNK_INDEX``) gets an entry, in ascending
        index order, because the sampling tables read ``word_count.values()``
        positionally. Indices are assumed to be contiguous from 0.
        """
        word_count = dict.fromkeys(sorted(set(vocab.values()) | {self.UNK_INDEX}), 0)
        for line in data:
            for word in line:
                word_count[vocab.get(word, self.UNK_INDEX)] += 1
        return word_count

    def most_common(self, n):
        """Return ``{word: count}`` for the ``n`` most frequent vocabulary indices."""
        counter = Counter(self.word_count)
        return {self.ind2vocab[ind]: freq for ind, freq in counter.most_common(n)}

    def build_vocab(self, data, min_freq=1):
        """Build ``(word->index, index->word, index->count)`` from tokenised ``data``.

        Words are indexed in sorted order starting at 1. Words with a count below
        ``min_freq`` get no index; their occurrences are added to ``<unk>`` (index 0).
        """
        token_counts = Counter(word for line in data for word in line)
        vocab_dict = {"<unk>": self.UNK_INDEX}
        # Seeded with 1, not 0, so <unk>'s frequency is never zero
        # (the sub-sampling formula divides by frequency).
        word_count = {self.UNK_INDEX: 1}
        next_index = 1
        for word in sorted(token_counts):
            freq = token_counts[word]
            if freq >= min_freq:
                vocab_dict[word] = next_index
                word_count[next_index] = freq
                next_index += 1
            else:
                word_count[self.UNK_INDEX] += freq
        ind2word = {v: k for k, v in vocab_dict.items()}
        return vocab_dict, ind2word, word_count

    def total_count(self):
        """Total number of counted tokens (including the <unk> bucket)."""
        return sum(self.word_count.values())

    def __create_sub_sampling_table(self, threshold=1e-5):
        """Per-index keep probability ``(sqrt(f / t) + 1) * (t / f)``, ``f`` = relative frequency.

        Values >= 1 mean "always keep". Indices with zero frequency get 1.0
        instead of an inf produced by dividing by zero.
        """
        counts = np.array(list(self.word_count.values()), dtype=np.float64)
        freq = counts / np.sum(counts)
        with np.errstate(divide="ignore", invalid="ignore"):
            keep = (np.sqrt(freq / threshold) + 1) * (threshold / freq)
        return np.where(freq > 0, keep, 1.0)

    def is_sample_selected(self, idx):
        """Randomly decide whether the token with index ``idx`` is kept as a target."""
        return random.random() < self.sub_sampling_table[idx]

    def __create_neg_sampling_table(self, power=0.75, table_size=1e8):
        """Build a shuffled list in which each index appears ~``table_size * p_i`` times.

        ``p_i`` is proportional to ``count_i ** power``. Uniform draws from this
        list therefore sample indices from the smoothed unigram distribution.
        """
        indices = np.array(list(self.word_count.keys()), dtype=np.int64)
        weights = np.array(list(self.word_count.values())) ** power
        probs = weights / np.sum(weights)
        counts = np.round(probs * table_size).astype(np.int64)
        # np.repeat avoids growing a Python list to ~table_size elements.
        neg_sampling_table = np.repeat(indices, counts)
        np.random.shuffle(neg_sampling_table)
        # random.sample (used below) requires a real sequence, hence tolist().
        return neg_sampling_table.tolist()

    def get_negative_samples(self, target, k):
        """Draw ``k`` indices from the negative table, redrawing until ``target`` is absent."""
        delta = random.sample(self.neg_sampling_table, k)
        while target in delta:
            delta = random.sample(self.neg_sampling_table, k)
        return delta

    def _word_to_index(self, word):
        """Map a token to its vocabulary index, falling back to <unk> for OOV tokens."""
        return self.vocab.get(word, self.UNK_INDEX)

    def __len__(self):
        """Number of sentences (lines) in the corpus."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``[(target, context, negatives), ...]`` for sentence ``idx``.

        Only positions with a full window on both sides are used. Targets may be
        dropped by sub-sampling, so the list length varies.

        Raises:
            Exception: if the sentence has fewer than ``window_size`` tokens.
        """
        words = self.data[idx]
        if len(words) < self.window_size:
            raise Exception("Sentence length is less than window size")
        data = []
        start = self.window_size // 2
        for i in range(start, len(words) - start):
            target = self._word_to_index(words[i])
            if not self.is_sample_selected(target):
                continue
            context = words[i - start: i] + words[i + 1: i + start + 1]
            context = [self._word_to_index(word) for word in context]
            neg_samples = self.get_negative_samples(target, self.neg_words)
            data.append((target, context, neg_samples))
        return data

    def __collate_fn(self, batches):
        """Flatten per-sentence triple lists into three LongTensors (targets, contexts, negatives)."""
        target = []
        context = []
        neg_samples = []
        for sentence in batches:
            for t, c, n in sentence:
                target.append(t)
                context.append(c)
                neg_samples.append(n)
        return torch.LongTensor(target), torch.LongTensor(context), torch.LongTensor(neg_samples)

    def get_batches(self, batch_size, shuffle=False):
        """Return a DataLoader yielding ``batch_size`` sentences' worth of triples per batch.

        Incomplete final batches are dropped.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          collate_fn=self.__collate_fn, drop_last=True)


In [3]:
# Seed every RNG the pipeline touches (`random` for sub-/negative sampling,
# NumPy for the negative-table shuffle) so Restart & Run All reproduces the outputs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

CORPUS_PATH = "../data/corpus_cleaned.txt"
dataset = DataPipeline(CORPUS_PATH)

In [7]:
k_index = dataset.vocab['k']
dataset.word_count[k_index]

742

In [4]:
top_20 = dataset.most_common(20)
top_20

{'the': 568205,
 'and': 298569,
 'a': 268778,
 'of': 257422,
 'to': 225717,
 'is': 203966,
 'in': 155513,
 'this': 150864,
 'it': 150259,
 'i': 146230,
 'that': 113076,
 'movie': 94233,
 'for': 82555,
 'as': 79279,
 'was': 78510,
 'with': 75268,
 'but': 65319,
 'you': 64244,
 'film': 60895,
 'on': 60665}

In [5]:
neg_table = dataset.neg_sampling_table
len(neg_table)

100008668

In [7]:
noodles_index = dataset.vocab["noodles"]
dataset.word_count[noodles_index]

5

In [8]:
of_index = dataset.vocab["of"]
dataset.sub_sampling_table[of_index]

0.020088256634534823

In [9]:
# Draw 6 negatives for "the" and show each one's corpus frequency.
neg_indices = dataset.get_negative_samples(dataset.vocab["the"], 6)
words = list(map(dataset.ind2vocab.__getitem__, neg_indices))
print('Negative samples = freq :')
for word in words:
    freq = dataset.word_count[dataset.vocab[word]]
    print('{} = {}'.format(word, freq))

Negative samples = freq :
get = 15939
plays = 4213
actually = 5119
book = 6687
not = 52474
master = 1127


In [14]:
# Toy alphabet "sentence" to visualise which context each target gets.
window_size = 5
start = window_size // 2
words = list("abcdefghijklmnopqrstuvwxyz")

In [15]:
# Walk only the positions with a full half-window on both sides.
for i, target in enumerate(words[start:len(words) - start], start=start):
    context = words[i - start:i] + words[i + 1:i + start + 1]
    print('Target: {} \t Context: {}'.format(target, context))

Target: c 	 Context: ['a', 'b', 'd', 'e']
Target: d 	 Context: ['b', 'c', 'e', 'f']
Target: e 	 Context: ['c', 'd', 'f', 'g']
Target: f 	 Context: ['d', 'e', 'g', 'h']
Target: g 	 Context: ['e', 'f', 'h', 'i']
Target: h 	 Context: ['f', 'g', 'i', 'j']
Target: i 	 Context: ['g', 'h', 'j', 'k']
Target: j 	 Context: ['h', 'i', 'k', 'l']
Target: k 	 Context: ['i', 'j', 'l', 'm']
Target: l 	 Context: ['j', 'k', 'm', 'n']
Target: m 	 Context: ['k', 'l', 'n', 'o']
Target: n 	 Context: ['l', 'm', 'o', 'p']
Target: o 	 Context: ['m', 'n', 'p', 'q']
Target: p 	 Context: ['n', 'o', 'q', 'r']
Target: q 	 Context: ['o', 'p', 'r', 's']
Target: r 	 Context: ['p', 'q', 's', 't']
Target: s 	 Context: ['q', 'r', 't', 'u']
Target: t 	 Context: ['r', 's', 'u', 'v']
Target: u 	 Context: ['s', 't', 'v', 'w']
Target: v 	 Context: ['t', 'u', 'w', 'x']
Target: w 	 Context: ['u', 'v', 'x', 'y']
Target: x 	 Context: ['v', 'w', 'y', 'z']


In [5]:
# Decode the index triples of the first sentence back into words.
sample0 = dataset[0]
to_word = dataset.ind2vocab
for var in sample0:
    t_idx, c_idx, n_idx = var
    t = to_word[t_idx]
    c = [to_word[i] for i in c_idx]
    n = [to_word[i] for i in n_idx]
    print('Target: {} \t Context: {} \t Negative Samples: {}'.format(t, c, n))

Target: tips 	 Context: ['some', 'great', 'as', 'always'] 	 Negative Samples: ['others', 'fun', 'life', 'tableaux', 'album']
Target: helping 	 Context: ['and', 'is', 'me', 'to'] 	 Negative Samples: ['bruce', 'insert', 'or', 'effective', 'day']
Target: eats 	 Context: ['my', 'good', 'collection', 'i'] 	 Negative Samples: ['but', 'are', 'gorgeousness', 'take', 'least']
Target: collection 	 Context: ['good', 'eats', 'i', 'havent'] 	 Negative Samples: ['for', 'im', 'will', 'of', 'drunk']
Target: any 	 Context: ['havent', 'tried', 'of', 'the'] 	 Negative Samples: ['kimer', 'i', 'fi', 'myself', 'sad']
Target: recipes 	 Context: ['of', 'the', 'yet', 'but'] 	 Negative Samples: ['portrayed', 'recalls', 'cd', 'and', 'bissetdirector']
Target: i 	 Context: ['yet', 'but', 'will', 'soon'] 	 Negative Samples: ['base', 'new', 'the', 'scene', 'ie']
Target: lovely 	 Context: ['its', 'just', 'to', 'let'] 	 Negative Samples: ['this', 'hasnt', 'fiction', 'on', 'italthough']
Target: alton 	 Context: ['to', 

In [6]:
batches = dataset.get_batches(batch_size=128)

In [7]:
# Inspect only the first batch: one row per kept target across 128 sentences.
target, context, neg_samples = next(iter(batches))
print(target.shape, context.shape, neg_samples.shape)

torch.Size([2088]) torch.Size([2088, 4]) torch.Size([2088, 5])
