In [1]:
import json
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from nltk.stem import PorterStemmer
from typing import List
from torch import nn
import torch
from einops import rearrange
import random
from tqdm import tqdm
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag
from nltk.corpus import wordnet
import re

In [2]:
# Shared NLP helpers used by the tokenisation cells below.
tokenizer = get_tokenizer('spacy')  # torchtext wrapper around the spaCy tokenizer
wnl = WordNetLemmatizer()           # lemmatises a word given a WordNet POS
ps = PorterStemmer()                # not used by the active pipeline (lemmatisation is used instead)

In [3]:
def get_wordnet_pos(tag):
    """Translate a POS tag into the matching WordNet POS constant.

    Only the leading letter of ``tag`` matters: ``J*`` -> adjective,
    ``V*`` -> verb, ``N*`` -> noun, ``R*`` -> adverb. Tags are presumably the
    Penn-Treebank-style tags produced by ``nltk.pos_tag`` -- confirm if other
    taggers are used.

    Args:
        tag: POS tag string, e.g. ``'VBD'``.

    Returns:
        The corresponding ``wordnet`` POS constant, or ``None`` when the tag
        has no WordNet counterpart (``yield_tokens`` then falls back to NOUN).
    """
    # Constants are looked up by name only after a prefix matches, so `wordnet`
    # is touched lazily, just like a plain if/elif chain would.
    for prefix, pos_attr in (('J', 'ADJ'), ('V', 'VERB'), ('N', 'NOUN'), ('R', 'ADV')):
        if tag.startswith(prefix):
            return getattr(wordnet, pos_attr)
    return None

In [4]:
# Load the raw document collection (presumably a list of dicts with an optional
# "abstract" key -- see yield_tokens). JSON text is UTF-8 by specification, so
# decode explicitly instead of relying on the platform's default locale encoding.
with open('docs.json', encoding='utf-8') as f:
    contents = f.read()
    docs = json.loads(contents)

In [5]:
# Training configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gram = 9      # window size: n_gram // 2 context words on each side of the target
bz = 256        # mini-batch size
epochs = 100
emb_dim = 128   # word-embedding width
dims = [256,]   # hidden-layer widths of the MLP head applied to the averaged context

# Seed every RNG the notebook draws from (`random` for batch shuffling, torch for
# parameter initialisation) so that Restart & Run All reproduces the results.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [6]:
def yield_tokens(docs):
    """Yield one list of lemmatised tokens per abstract entry in ``docs``.

    For every doc that has an ``"abstract"`` key, each entry of that value that
    carries a ``"text"`` key is lower-cased, tokenised with ``tokenizer``,
    POS-tagged with ``pos_tag`` and lemmatised with ``wnl``. Tags without a
    WordNet counterpart are lemmatised as nouns.

    Args:
        docs: Iterable of document mappings (schema inferred only from the keys
            read here -- confirm against docs.json).

    Yields:
        list[str]: the lemmas of one abstract text, in token order.
    """
    for doc in docs:
        if "abstract" not in doc:
            continue
        for section in doc["abstract"]:
            if "text" not in section:
                continue
            tagged = pos_tag(tokenizer(section["text"].lower()))
            yield [
                wnl.lemmatize(word, pos=get_wordnet_pos(tag) or wordnet.NOUN)
                for word, tag in tagged
            ]

# Build the vocabulary from every lemmatised abstract. Lemmas below min_freq=10
# are left out; set_default_index makes any such out-of-vocabulary lemma map to
# the "<unk>" id instead of raising on lookup.
vocab = build_vocab_from_iterator(yield_tokens(docs), specials=["<unk>"], min_freq=10)
vocab.set_default_index(vocab["<unk>"])

len(vocab)

5436

In [7]:
# Slide a symmetric window over every lemmatised abstract: the centre word is the
# prediction target, and the n_gram // 2 words on each side form its context.
# Abstracts too short to fit a full window contribute no examples.
half = n_gram // 2
datas = [
    ([*ws[i - half:i], *ws[i + 1:i + 1 + half]], ws[i])
    for ws in yield_tokens(docs)
    for i in range(half, len(ws) - half)
]

In [8]:
class CBOW(torch.nn.Module):
    """Continuous bag-of-words model: predict a word from its averaged context.

    Context word embeddings are averaged. The result goes through an MLP whose
    hidden widths are given by ``dims`` (each hidden layer is followed by ReLU),
    then through a final projection to log-probabilities over the vocabulary.

    Args:
        n_gram: Window size used to build the training data. It is stored on
            ``self.n_gram`` for reference only; the forward pass averages however
            many context words it receives.
        emb_dim: Width of the word embeddings.
        dims: Hidden-layer widths of the MLP head. May be empty, which gives a
            single linear projection.
        vocab: Vocabulary object. Only ``len(vocab)`` is used here, to size the
            embedding table and the output layer; the object is kept on
            ``self.vocab``.
    """

    def __init__(self, n_gram, emb_dim, dims, vocab,):
        super().__init__()
        self.vocab = vocab
        self.n_gram = n_gram
        vocab_size = len(vocab)

        self.emb_book = nn.Embedding(vocab_size, emb_dim)

        widths = [emb_dim, *dims]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_dim, out_dim), nn.ReLU()]
        layers += [nn.Linear(widths[-1], vocab_size), nn.LogSoftmax(dim=-1)]

        # nn.Sequential names its children 0, 1, 2, ... exactly as nn.ModuleList
        # does, so state_dict keys (and existing checkpoints) are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, ctx_idx):
        """Return log-probabilities for the centre word.

        Args:
            ctx_idx: Integer tensor of context word ids, expected shape
                (batch, n_ctx) as built by the training loop.

        Returns:
            Tensor of shape (batch, len(vocab)) holding log-softmax scores.
        """
        # (batch, n_ctx, emb_dim) -> (batch, emb_dim): order-insensitive average.
        ctx_emb = self.emb_book(ctx_idx).mean(dim=1)
        return self.net(ctx_emb)

In [9]:
# Build the model on the training device. The model already emits log-softmax
# scores, so it pairs with a negative log-likelihood loss; parameters are
# optimised with Adamax.
model = CBOW(n_gram, emb_dim, dims, vocab)
model = model.to(device)
optim = torch.optim.Adamax(model.parameters(), lr=0.01)
loss_fn = nn.NLLLoss()

In [10]:
# Convert every (context, target) window to vocabulary ids once, up front, instead
# of repeating the string -> id lookups for every batch of every epoch.
if not datas:
    raise ValueError("datas is empty: no (context, target) windows to train on")
ctx_ids = torch.tensor([vocab(_ctx) for _ctx, _ in datas], dtype=torch.int64, device=device)
tar_ids = torch.tensor(vocab([_tar for _, _tar in datas]), dtype=torch.int64, device=device)

# Batch order comes from a persistent index list shuffled each epoch. It consumes
# `random` exactly as shuffling `datas` itself did, so batch contents are the same,
# but `datas` is no longer reordered in place.
order = list(range(len(datas)))
for epoch in range(epochs):
    total_loss = 0
    random.shuffle(order)
    for start in tqdm(range(0, len(order), bz)):
        batch = torch.tensor(order[start:start + bz], dtype=torch.int64, device=device)
        ctx = ctx_ids[batch]
        tar = tar_ids[batch]

        out = model(ctx)
        loss = loss_fn(out, tar)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # loss is a per-batch mean; weight it by batch size so the epoch figure
        # is an exact per-example average (the last batch may be smaller than bz).
        total_loss += loss.item() * len(tar)

    print("epoch: {}, loss: {}".format(epoch, total_loss / len(datas)))
    # Overwrite a single checkpoint with the latest model and optimiser state.
    torch.save({
        'epoch': epoch,
        'net_state_dict': model.state_dict(),
        'opt_state_dict': optim.state_dict(),
        }, "./cbow-ex7.pt")

100%|██████████| 5721/5721 [00:16<00:00, 353.06it/s]


epoch: 0, loss: 5.348953632015232


100%|██████████| 5721/5721 [00:17<00:00, 325.56it/s]


epoch: 1, loss: 4.814663781334909


100%|██████████| 5721/5721 [00:15<00:00, 359.93it/s]


epoch: 2, loss: 4.576898402275445


100%|██████████| 5721/5721 [00:15<00:00, 362.19it/s]


epoch: 3, loss: 4.4377324101586515


100%|██████████| 5721/5721 [00:15<00:00, 361.85it/s]


epoch: 4, loss: 4.343841675562327


100%|██████████| 5721/5721 [00:15<00:00, 361.15it/s]


epoch: 5, loss: 4.275128244890957


100%|██████████| 5721/5721 [00:15<00:00, 360.57it/s]


epoch: 6, loss: 4.220163804706619


100%|██████████| 5721/5721 [00:16<00:00, 353.79it/s]


epoch: 7, loss: 4.174401545160704


100%|██████████| 5721/5721 [00:16<00:00, 350.29it/s]


epoch: 8, loss: 4.135870290434208


100%|██████████| 5721/5721 [00:15<00:00, 365.40it/s]


epoch: 9, loss: 4.102689714281856


100%|██████████| 5721/5721 [00:15<00:00, 364.89it/s]


epoch: 10, loss: 4.072431338695374


100%|██████████| 5721/5721 [00:15<00:00, 362.79it/s]


epoch: 11, loss: 4.046421098595867


100%|██████████| 5721/5721 [00:15<00:00, 360.87it/s]


epoch: 12, loss: 4.023006759180928


100%|██████████| 5721/5721 [00:15<00:00, 363.09it/s]


epoch: 13, loss: 4.001461083573043


100%|██████████| 5721/5721 [00:15<00:00, 367.42it/s]


epoch: 14, loss: 3.981839149533297


100%|██████████| 5721/5721 [00:15<00:00, 360.19it/s]


epoch: 15, loss: 3.963749521711268


100%|██████████| 5721/5721 [00:15<00:00, 364.24it/s]


epoch: 16, loss: 3.9470831931903936


100%|██████████| 5721/5721 [00:15<00:00, 359.47it/s]


epoch: 17, loss: 3.9323333183859077


100%|██████████| 5721/5721 [00:15<00:00, 362.77it/s]


epoch: 18, loss: 3.9178892173667967


100%|██████████| 5721/5721 [00:16<00:00, 355.44it/s]


epoch: 19, loss: 3.9043280507658302


100%|██████████| 5721/5721 [00:16<00:00, 341.92it/s]


epoch: 20, loss: 3.892010215413391


100%|██████████| 5721/5721 [00:16<00:00, 348.50it/s]


epoch: 21, loss: 3.8799874689432783


100%|██████████| 5721/5721 [00:16<00:00, 340.25it/s]


epoch: 22, loss: 3.8696276128215206


100%|██████████| 5721/5721 [00:17<00:00, 333.04it/s]


epoch: 23, loss: 3.8594869110546446


100%|██████████| 5721/5721 [00:17<00:00, 318.91it/s]


epoch: 24, loss: 3.8497627250972455


100%|██████████| 5721/5721 [00:17<00:00, 319.10it/s]


epoch: 25, loss: 3.840381315954353


100%|██████████| 5721/5721 [00:18<00:00, 316.47it/s]


epoch: 26, loss: 3.8313350547255163


100%|██████████| 5721/5721 [00:18<00:00, 311.43it/s]


epoch: 27, loss: 3.823241015197463


100%|██████████| 5721/5721 [00:16<00:00, 344.40it/s]


epoch: 28, loss: 3.8154992051588774


100%|██████████| 5721/5721 [00:16<00:00, 349.31it/s]


epoch: 29, loss: 3.807963698285075


100%|██████████| 5721/5721 [00:16<00:00, 340.59it/s]


epoch: 30, loss: 3.8014579019084804


100%|██████████| 5721/5721 [00:16<00:00, 357.12it/s]


epoch: 31, loss: 3.795020915285168


100%|██████████| 5721/5721 [00:15<00:00, 363.39it/s]


epoch: 32, loss: 3.7880459284578643


100%|██████████| 5721/5721 [00:16<00:00, 356.21it/s]


epoch: 33, loss: 3.781581576854865


100%|██████████| 5721/5721 [00:15<00:00, 360.64it/s]


epoch: 34, loss: 3.7765495696848346


100%|██████████| 5721/5721 [00:14<00:00, 384.50it/s]


epoch: 35, loss: 3.7709128931944575


100%|██████████| 5721/5721 [00:15<00:00, 359.34it/s]


epoch: 36, loss: 3.765351944519325


100%|██████████| 5721/5721 [00:18<00:00, 303.47it/s]


epoch: 37, loss: 3.7601921204088637


100%|██████████| 5721/5721 [00:19<00:00, 292.30it/s]


epoch: 38, loss: 3.755871976642444


100%|██████████| 5721/5721 [00:18<00:00, 312.06it/s]


epoch: 39, loss: 3.7505404363613906


100%|██████████| 5721/5721 [00:18<00:00, 311.12it/s]


epoch: 40, loss: 3.7466218509974123


100%|██████████| 5721/5721 [00:17<00:00, 325.51it/s]


epoch: 41, loss: 3.7421603315508976


100%|██████████| 5721/5721 [00:19<00:00, 299.61it/s]


epoch: 42, loss: 3.737739383801427


100%|██████████| 5721/5721 [00:18<00:00, 304.97it/s]


epoch: 43, loss: 3.733703624447117


100%|██████████| 5721/5721 [00:19<00:00, 295.99it/s]


epoch: 44, loss: 3.7299892450828107


100%|██████████| 5721/5721 [00:18<00:00, 301.33it/s]


epoch: 45, loss: 3.726507551095121


100%|██████████| 5721/5721 [00:19<00:00, 297.32it/s]


epoch: 46, loss: 3.723501945380433


100%|██████████| 5721/5721 [00:19<00:00, 295.75it/s]


epoch: 47, loss: 3.720066530812107


100%|██████████| 5721/5721 [00:18<00:00, 307.54it/s]


epoch: 48, loss: 3.7168070926751224


100%|██████████| 5721/5721 [00:18<00:00, 303.52it/s]


epoch: 49, loss: 3.713564472317365


100%|██████████| 5721/5721 [00:19<00:00, 291.19it/s]


epoch: 50, loss: 3.7097837373820357


100%|██████████| 5721/5721 [00:19<00:00, 297.16it/s]


epoch: 51, loss: 3.7073317192885833


100%|██████████| 5721/5721 [00:20<00:00, 285.40it/s]


epoch: 52, loss: 3.7042382791968573


100%|██████████| 5721/5721 [00:19<00:00, 299.54it/s]


epoch: 53, loss: 3.701914160431845


100%|██████████| 5721/5721 [00:19<00:00, 290.96it/s]


epoch: 54, loss: 3.6986967163171283


100%|██████████| 5721/5721 [00:19<00:00, 291.61it/s]


epoch: 55, loss: 3.696828457063187


100%|██████████| 5721/5721 [00:18<00:00, 302.58it/s]


epoch: 56, loss: 3.694504180558595


100%|██████████| 5721/5721 [00:18<00:00, 313.08it/s]


epoch: 57, loss: 3.692340897544747


100%|██████████| 5721/5721 [00:18<00:00, 306.30it/s]


epoch: 58, loss: 3.6900582866486413


100%|██████████| 5721/5721 [00:18<00:00, 308.52it/s]


epoch: 59, loss: 3.6880674603475647


100%|██████████| 5721/5721 [00:18<00:00, 301.63it/s]


epoch: 60, loss: 3.685937945709804


100%|██████████| 5721/5721 [00:18<00:00, 310.55it/s]


epoch: 61, loss: 3.6836826688809503


100%|██████████| 5721/5721 [00:15<00:00, 358.15it/s]


epoch: 62, loss: 3.682346861808462


100%|██████████| 5721/5721 [00:16<00:00, 348.54it/s]


epoch: 63, loss: 3.679861130511872


100%|██████████| 5721/5721 [00:16<00:00, 349.16it/s]


epoch: 64, loss: 3.6786120517737455


100%|██████████| 5721/5721 [00:15<00:00, 360.59it/s]


epoch: 65, loss: 3.6762323112012134


100%|██████████| 5721/5721 [00:18<00:00, 315.73it/s]


epoch: 66, loss: 3.6747506681022033


100%|██████████| 5721/5721 [00:16<00:00, 355.91it/s]


epoch: 67, loss: 3.6734959956593074


100%|██████████| 5721/5721 [00:17<00:00, 323.59it/s]


epoch: 68, loss: 3.671547212316652


100%|██████████| 5721/5721 [00:18<00:00, 317.47it/s]


epoch: 69, loss: 3.670130497201945


100%|██████████| 5721/5721 [00:17<00:00, 326.66it/s]


epoch: 70, loss: 3.6684637745692172


100%|██████████| 5721/5721 [00:17<00:00, 318.26it/s]


epoch: 71, loss: 3.6673024248040025


100%|██████████| 5721/5721 [00:19<00:00, 292.50it/s]


epoch: 72, loss: 3.665493994462236


100%|██████████| 5721/5721 [00:20<00:00, 277.39it/s]


epoch: 73, loss: 3.6641359910116007


100%|██████████| 5721/5721 [00:19<00:00, 290.23it/s]


epoch: 74, loss: 3.663228160795364


100%|██████████| 5721/5721 [00:19<00:00, 289.51it/s]


epoch: 75, loss: 3.6617797338789906


100%|██████████| 5721/5721 [00:20<00:00, 282.64it/s]


epoch: 76, loss: 3.6594135938955348


100%|██████████| 5721/5721 [00:21<00:00, 270.17it/s]


epoch: 77, loss: 3.658961443718951


100%|██████████| 5721/5721 [00:20<00:00, 280.85it/s]


epoch: 78, loss: 3.658246751303358


100%|██████████| 5721/5721 [00:18<00:00, 305.14it/s]


epoch: 79, loss: 3.6572782190933486


100%|██████████| 5721/5721 [00:19<00:00, 292.71it/s]


epoch: 80, loss: 3.6557021189548533


100%|██████████| 5721/5721 [00:20<00:00, 282.99it/s]


epoch: 81, loss: 3.6542675999146037


100%|██████████| 5721/5721 [00:19<00:00, 293.86it/s]


epoch: 82, loss: 3.6533647130020235


100%|██████████| 5721/5721 [00:19<00:00, 297.38it/s]


epoch: 83, loss: 3.6531384977827477


100%|██████████| 5721/5721 [00:19<00:00, 297.48it/s]


epoch: 84, loss: 3.6516797967291383


100%|██████████| 5721/5721 [00:18<00:00, 315.59it/s]


epoch: 85, loss: 3.6510847999256


100%|██████████| 5721/5721 [00:17<00:00, 318.96it/s]


epoch: 86, loss: 3.6502814730500814


100%|██████████| 5721/5721 [00:15<00:00, 363.65it/s]


epoch: 87, loss: 3.649917668119259


100%|██████████| 5721/5721 [00:16<00:00, 356.98it/s]


epoch: 88, loss: 3.649205558323101


100%|██████████| 5721/5721 [00:15<00:00, 362.35it/s]


epoch: 89, loss: 3.6478470688674096


100%|██████████| 5721/5721 [00:15<00:00, 364.39it/s]


epoch: 90, loss: 3.647121413969154


100%|██████████| 5721/5721 [00:15<00:00, 367.25it/s]


epoch: 91, loss: 3.646779708064133


100%|██████████| 5721/5721 [00:15<00:00, 366.27it/s]


epoch: 92, loss: 3.6461130573559917


100%|██████████| 5721/5721 [00:15<00:00, 364.44it/s]


epoch: 93, loss: 3.645498676483153


100%|██████████| 5721/5721 [00:16<00:00, 352.71it/s]


epoch: 94, loss: 3.6441335691072


100%|██████████| 5721/5721 [00:17<00:00, 334.03it/s]


epoch: 95, loss: 3.6440565353333967


100%|██████████| 5721/5721 [00:18<00:00, 313.38it/s]


epoch: 96, loss: 3.6437468829217776


100%|██████████| 5721/5721 [00:17<00:00, 328.33it/s]


epoch: 97, loss: 3.643131605773126


100%|██████████| 5721/5721 [00:16<00:00, 357.32it/s]


epoch: 98, loss: 3.6424955940593136


100%|██████████| 5721/5721 [00:17<00:00, 330.37it/s]


epoch: 99, loss: 3.6431357230354418


In [11]:
# Export the learned embeddings as two line-aligned TSV files: one vector per line
# in vectors-ex7.tsv and the matching token per line in metadata-ex7.tsv.
emb_book = model.emb_book.weight.detach().cpu().numpy()

# Tokens are written one per line, so tab/newline/carriage-return characters
# (spaCy can emit whitespace tokens such as "\n") are escaped. Otherwise a single
# vocabulary entry could span several metadata lines and misalign the two files.
_tsv_escapes = str.maketrans({'\t': '\\t', '\n': '\\n', '\r': '\\r'})

# `with` guarantees both files are flushed and closed even if a write fails.
with open('vectors-ex7.tsv', 'w', encoding='utf-8') as out_v, \
        open('metadata-ex7.tsv', 'w', encoding='utf-8') as out_m:
    for idx, w in enumerate(vocab.get_itos()):
        emb = emb_book[idx]
        out_v.write('\t'.join([str(x) for x in emb]) + "\n")
        out_m.write(w.translate(_tsv_escapes) + "\n")