In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import random
import math

#import pandas as pd
#import scipy
#import sklearn
#from sklearn.metrics.pairwise import cosine_similarity

# Whether PyTorch reports a usable CUDA device on this machine.
USE_CUDA = torch.cuda.is_available()

# Seed the Python, NumPy and PyTorch random number generators with a fixed
# value so that repeated runs start from the same random state.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# Hyperparameters
# presumably the context-window radius around each center word -- TODO confirm against usage
C = 3
# presumably the number of negative samples drawn per positive word -- TODO confirm against usage
K = 100
NUM_EPOCHS = 2
# Upper bound on the vocabulary size; the vocabulary cell keeps
# MAX_VOCAB_SIZE - 1 real words and adds one "<unk>" entry.
MAX_VOCAB_SIZE = 30000
# NOTE(review): 129 is an unusual batch size -- confirm that 128 was not intended.
BATCH_SIZE = 129
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100

# Tokenizer: cuts a string into tokens wherever whitespace occurs.
def word_tokenize(text):
    """Return the whitespace-delimited tokens of ``text`` as a list of strings."""
    tokens = text.split()
    return tokens

In [9]:
# Load the corpus and build the vocabulary as a word -> occurrence-count dict.
# NOTE(review): this is an absolute, machine-specific path; point it at a
# location relative to a configurable data directory before sharing the notebook.
# The encoding is given explicitly so decoding does not depend on the platform default.
with open("/Users/mengfanhui/Documents/GitR/Code-Practice/python/Pytorch practice/text8/text8.dev.txt", "r", encoding="utf-8") as fin:
    text = fin.read()
# Tokenize with the notebook's own helper; from here on `text` is a list of words.
text = word_tokenize(text)
# Keep the MAX_VOCAB_SIZE - 1 most frequent words together with their counts ...
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
# ... and assign every remaining token occurrence to a single "<unk>" entry.
# The built-in sum keeps this count a plain int, like all the other counts.
vocab["<unk>"] = len(text) - sum(vocab.values())
vocab

{'the': 51066,
 'of': 28974,
 'one': 26357,
 'and': 20366,
 'in': 18273,
 'nine': 17217,
 'zero': 15792,
 'a': 15661,
 'to': 14692,
 'two': 11324,
 'is': 8534,
 'eight': 7856,
 'five': 7126,
 'three': 6728,
 'four': 6635,
 'six': 6523,
 'seven': 6092,
 'as': 6027,
 's': 5724,
 'for': 5537,
 'by': 5369,
 'was': 5155,
 'that': 4913,
 'with': 4574,
 'on': 4469,
 'from': 3501,
 'are': 3379,
 'it': 3189,
 'or': 2962,
 'an': 2777,
 'this': 2684,
 'be': 2526,
 'which': 2489,
 'at': 2473,
 'his': 2340,
 'also': 2066,
 'he': 1940,
 'not': 1903,
 'has': 1885,
 'were': 1779,
 'b': 1748,
 'american': 1693,
 'have': 1687,
 'd': 1644,
 'its': 1583,
 'other': 1556,
 'but': 1529,
 'first': 1482,
 'their': 1430,
 'some': 1323,
 'most': 1251,
 'all': 1247,
 'had': 1243,
 'new': 1187,
 'they': 1163,
 'more': 1152,
 'many': 1114,
 'been': 1108,
 'can': 1100,
 'such': 1064,
 'who': 1058,
 'into': 1026,
 'there': 1009,
 'world': 997,
 'm': 985,
 'after': 972,
 'used': 945,
 'may': 939,
 'when': 888,
 'only'

In [11]:
# Index -> word lookup: the vocabulary keys in their insertion order.
idx_to_word = list(vocab)
# Word -> index lookup: the inverse of the list above.
word_to_idx = dict(zip(idx_to_word, range(len(idx_to_word))))

In [14]:
# Preview the first 100 (word, index) pairs of the lookup table.
# Alternative view showing only the words: idx_to_word[:100]
list(word_to_idx.items())[:100]

[('the', 0),
 ('of', 1),
 ('one', 2),
 ('and', 3),
 ('in', 4),
 ('nine', 5),
 ('zero', 6),
 ('a', 7),
 ('to', 8),
 ('two', 9),
 ('is', 10),
 ('eight', 11),
 ('five', 12),
 ('three', 13),
 ('four', 14),
 ('six', 15),
 ('seven', 16),
 ('as', 17),
 ('s', 18),
 ('for', 19),
 ('by', 20),
 ('was', 21),
 ('that', 22),
 ('with', 23),
 ('on', 24),
 ('from', 25),
 ('are', 26),
 ('it', 27),
 ('or', 28),
 ('an', 29),
 ('this', 30),
 ('be', 31),
 ('which', 32),
 ('at', 33),
 ('his', 34),
 ('also', 35),
 ('he', 36),
 ('not', 37),
 ('has', 38),
 ('were', 39),
 ('b', 40),
 ('american', 41),
 ('have', 42),
 ('d', 43),
 ('its', 44),
 ('other', 45),
 ('but', 46),
 ('first', 47),
 ('their', 48),
 ('some', 49),
 ('most', 50),
 ('all', 51),
 ('had', 52),
 ('new', 53),
 ('they', 54),
 ('more', 55),
 ('many', 56),
 ('been', 57),
 ('can', 58),
 ('such', 59),
 ('who', 60),
 ('into', 61),
 ('there', 62),
 ('world', 63),
 ('m', 64),
 ('after', 65),
 ('used', 66),
 ('may', 67),
 ('when', 68),
 ('only', 69),
 ('yea

In [16]:
# Occurrence count of every vocabulary word, in the same order as idx_to_word.
word_counts = np.array(list(vocab.values()), dtype=np.float32)
# Relative frequency of each word (unigram distribution).
word_freqs = word_counts / np.sum(word_counts)
# Raise to the 3/4 power to flatten the distribution, so that rare words
# receive relatively more probability mass than their raw frequency gives them.
word_freqs = word_freqs ** (3. / 4.)
# Renormalize the *smoothed* values so they sum to 1. Dividing the raw counts
# here instead would overwrite, and thereby discard, the 3/4-power smoothing.
word_freqs = word_freqs / np.sum(word_freqs)
VOCAB_SIZE = len(idx_to_word)
VOCAB_SIZE

30000

In [19]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
        super(WordEmbeedingDataset, self).__init()__
        self.text_encoded = [word_to_idx.get(word, word_to_idx["<unk>"]) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_to_idx = word_to_idx
        self.idx