In [1]:
import json
import spacy
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from typing import List
from torch import nn
import torch
from einops import rearrange
import random
from tqdm import tqdm

In [2]:
# Load the corpus. JSON is UTF-8 by specification, so decode it explicitly
# instead of relying on the platform's default text encoding.
with open('docs.json', encoding='utf-8') as f:
    docs = json.load(f)

In [3]:
# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Window size: each training sample is a target word plus n_gram // 2 context
# words on each side (n_gram = 7 -> 3 left + 3 right = 6 context words).
n_gram = 7
bz = 256        # mini-batch size
epochs = 100
emb_dim = 64    # width of each word embedding
dims = [256]    # hidden-layer widths of the MLP applied to the averaged context embedding

# Seed every RNG the notebook uses (random.shuffle, torch weight init) so a
# Restart & Run All reproduces the same training run.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

In [4]:
def yield_tokens(docs, nlp=None):
    """Yield one list of lower-cased lemmas per sentence of every abstract in ``docs``.

    Args:
        docs: iterable of document mappings. A document contributes only if it
            has an ``"abstract"`` key; each abstract entry contributes only if
            it has a ``"text"`` key.
        nlp: an already-loaded spaCy pipeline. Loading ``en_core_web_sm`` is
            slow and this generator is consumed several times in the notebook,
            so pass a shared pipeline to avoid reloading it. When ``None``
            (default), ``en_core_web_sm`` is loaded when iteration starts.

    Yields:
        list[str]: the lower-cased lemmas of one sentence, in order.

    Raises:
        AssertionError: if the pipeline does not mark sentence boundaries.
    """
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")
    for doc in docs:
        if "abstract" not in doc:
            continue
        for abstract in doc["abstract"]:
            if "text" not in abstract:
                continue
            # Use a separate name so the outer loop variable `doc` is not shadowed.
            parsed = nlp(abstract["text"])
            # `.sents` needs a component that sets sentence starts (e.g. the parser).
            assert parsed.has_annotation("SENT_START"), "spaCy pipeline does not mark sentence boundaries"
            for sent in parsed.sents:
                yield [token.lemma_.lower() for token in sent]

# Vocabulary over all abstract lemmas: keep lemmas seen at least 10 times and
# reserve a "<unk>" special for everything else.
vocab = build_vocab_from_iterator(
    yield_tokens(docs),
    specials=["<unk>"],
    min_freq=10,
)
# Unknown lemmas map to "<unk>" instead of raising on lookup.
vocab.set_default_index(vocab["<unk>"])

len(vocab)

562

In [5]:
# Sanity-check the tokenization by printing only the first sentence produced.
first_sentence = next(yield_tokens(docs), None)
if first_sentence is not None:
    print(first_sentence)

['covid-19', ',', 'cause', 'by', 'sars', '-', 'cov-2', 'infection', ',', 'be', 'mild', 'to', 'moderate', 'in', 'the', 'majority', 'of', 'previously', 'healthy', 'individual', ',', 'but', 'can', 'cause', 'life', '-', 'threaten', 'disease', 'or', 'persistent', 'debilitate', 'symptom', 'in', 'some', 'case', '.']


In [6]:
# Build (context, target) pairs: for every position with a full window, the
# target is the centre word and the context is n_gram // 2 words on each side.
# Sentences shorter than the window contribute nothing.
half_window = n_gram // 2
datas = [
    ([*ws[i - half_window:i], *ws[i + 1:i + 1 + half_window]], ws[i])
    for ws in yield_tokens(docs)
    for i in range(half_window, len(ws) - half_window)
]

In [7]:
class CBOW(torch.nn.Module):
    """Continuous-bag-of-words model: predict a target word from its context.

    The context words' embeddings are averaged, then passed through an MLP
    whose last layer scores every vocabulary entry. The output is
    log-probabilities (LogSoftmax), meant to be paired with ``nn.NLLLoss``.

    Args:
        n_gram: window size the contexts were built with (stored only).
        emb_dim: width of each word embedding.
        dims: hidden-layer widths; each hidden layer is Linear followed by ReLU.
            May be empty, giving a single Linear from embedding to vocabulary.
        vocab: vocabulary object; only ``len(vocab)`` is used here, to size the
            embedding table and the output layer.
    """

    def __init__(self, n_gram, emb_dim, dims, vocab,):
        super().__init__()
        self.vocab = vocab
        self.n_gram = n_gram

        vocab_size = len(vocab)
        self.emb_book = nn.Embedding(vocab_size, emb_dim)

        # Layer widths from the averaged embedding through each hidden layer.
        widths = [emb_dim, *dims]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_features, out_features), nn.ReLU()]
        # Output head: one score per vocabulary entry, normalised to log-probs.
        layers += [nn.Linear(widths[-1], vocab_size), nn.LogSoftmax(dim=-1)]

        self.net = nn.ModuleList(layers)

    def forward(self, ctx_idx):
        """Map context word indices to log-probabilities over the vocabulary.

        ``ctx_idx`` is indexed so that dim 1 holds the context words (the
        training loop passes a (batch, context_len) int64 tensor); their
        embeddings are averaged over that dimension before the MLP.
        """
        hidden = self.emb_book(ctx_idx).mean(dim=1)
        for layer in self.net:
            hidden = layer(hidden)
        return hidden

In [8]:
model = CBOW(n_gram, emb_dim, dims, vocab).to(device)
# The model already ends in LogSoftmax, so NLLLoss on its output is cross-entropy.
loss_fn = nn.NLLLoss()
optim = torch.optim.Adamax(model.parameters(), lr=0.01)

In [9]:
# Vocabulary lookups give the same indices every epoch, so convert all
# (context, target) pairs to index tensors once instead of re-converting
# every batch of every epoch in Python.
n_samples = len(datas)
if n_samples == 0:
    raise ValueError("no (context, target) pairs were extracted; nothing to train on")
ctx_ids = torch.tensor([vocab(ctx_words) for ctx_words, _ in datas], dtype=torch.int64, device=device)
tar_ids = torch.tensor(vocab([tar_word for _, tar_word in datas]), dtype=torch.int64, device=device)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    # Fresh random order each epoch; the last batch may be smaller than bz.
    order = torch.randperm(n_samples, device=device)
    for start in tqdm(range(0, n_samples, bz)):
        batch = order[start:start + bz]
        batch_tar = tar_ids[batch]
        log_probs = model(ctx_ids[batch])
        loss = loss_fn(log_probs, batch_tar)
        optim.zero_grad()
        loss.backward()
        optim.step()
        # loss_fn averages over the batch; re-weight by batch size so the
        # epoch figure is the mean loss per sample.
        total_loss += loss.item() * len(batch)

    print("epoch: {}, loss: {}".format(epoch, total_loss / n_samples))
    # Overwrite a single checkpoint with the latest epoch's state.
    torch.save({
        'epoch': epoch,
        'net_state_dict': model.state_dict(),
        'opt_state_dict': optim.state_dict(),
        }, "./cbow.pt")

100%|██████████| 114/114 [00:00<00:00, 290.94it/s]
 36%|███▌      | 41/114 [00:00<00:00, 408.84it/s]

epoch: 0, loss: 4.544087686358349


100%|██████████| 114/114 [00:00<00:00, 405.10it/s]
 35%|███▌      | 40/114 [00:00<00:00, 390.12it/s]

epoch: 1, loss: 4.041873703165209


100%|██████████| 114/114 [00:00<00:00, 394.26it/s]
 35%|███▌      | 40/114 [00:00<00:00, 391.37it/s]

epoch: 2, loss: 3.795629492892058


100%|██████████| 114/114 [00:00<00:00, 395.93it/s]
 36%|███▌      | 41/114 [00:00<00:00, 408.16it/s]

epoch: 3, loss: 3.595419913334275


100%|██████████| 114/114 [00:00<00:00, 404.05it/s]
 36%|███▌      | 41/114 [00:00<00:00, 404.97it/s]

epoch: 4, loss: 3.420881489602645


100%|██████████| 114/114 [00:00<00:00, 405.54it/s]
 36%|███▌      | 41/114 [00:00<00:00, 407.42it/s]

epoch: 5, loss: 3.26740734923995


100%|██████████| 114/114 [00:00<00:00, 398.67it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.65it/s]

epoch: 6, loss: 3.1304025680903607


100%|██████████| 114/114 [00:00<00:00, 383.24it/s]
 36%|███▌      | 41/114 [00:00<00:00, 408.13it/s]

epoch: 7, loss: 3.0030337451250086


100%|██████████| 114/114 [00:00<00:00, 401.69it/s]
 36%|███▌      | 41/114 [00:00<00:00, 407.90it/s]

epoch: 8, loss: 2.885787609011055


100%|██████████| 114/114 [00:00<00:00, 405.43it/s]
 34%|███▍      | 39/114 [00:00<00:00, 388.02it/s]

epoch: 9, loss: 2.7786259388464565


100%|██████████| 114/114 [00:00<00:00, 395.04it/s]
 36%|███▌      | 41/114 [00:00<00:00, 405.12it/s]

epoch: 10, loss: 2.676075854675452


100%|██████████| 114/114 [00:00<00:00, 382.06it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.32it/s]

epoch: 11, loss: 2.579493502367254


100%|██████████| 114/114 [00:00<00:00, 403.03it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.75it/s]

epoch: 12, loss: 2.4888816453879117


100%|██████████| 114/114 [00:00<00:00, 400.89it/s]
 35%|███▌      | 40/114 [00:00<00:00, 394.14it/s]

epoch: 13, loss: 2.4004221909677965


100%|██████████| 114/114 [00:00<00:00, 396.25it/s]
 36%|███▌      | 41/114 [00:00<00:00, 401.12it/s]

epoch: 14, loss: 2.3187749439543826


100%|██████████| 114/114 [00:00<00:00, 403.89it/s]
 36%|███▌      | 41/114 [00:00<00:00, 406.78it/s]

epoch: 15, loss: 2.2369396727206086


100%|██████████| 114/114 [00:00<00:00, 402.56it/s]
 37%|███▋      | 42/114 [00:00<00:00, 410.16it/s]

epoch: 16, loss: 2.1622982975420806


100%|██████████| 114/114 [00:00<00:00, 400.36it/s]
 35%|███▌      | 40/114 [00:00<00:00, 394.38it/s]

epoch: 17, loss: 2.0862851915684386


100%|██████████| 114/114 [00:00<00:00, 402.43it/s]
 36%|███▌      | 41/114 [00:00<00:00, 401.65it/s]

epoch: 18, loss: 2.0165004193534375


100%|██████████| 114/114 [00:00<00:00, 391.88it/s]
 36%|███▌      | 41/114 [00:00<00:00, 404.82it/s]

epoch: 19, loss: 1.9491170697854276


100%|██████████| 114/114 [00:00<00:00, 405.50it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.08it/s]

epoch: 20, loss: 1.8814393453209404


100%|██████████| 114/114 [00:00<00:00, 397.06it/s]
 36%|███▌      | 41/114 [00:00<00:00, 400.12it/s]

epoch: 21, loss: 1.8237640772327062


100%|██████████| 114/114 [00:00<00:00, 397.41it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.72it/s]

epoch: 22, loss: 1.764983522223086


100%|██████████| 114/114 [00:00<00:00, 402.06it/s]
 36%|███▌      | 41/114 [00:00<00:00, 403.94it/s]

epoch: 23, loss: 1.7066384937928205


100%|██████████| 114/114 [00:00<00:00, 402.61it/s]
 35%|███▌      | 40/114 [00:00<00:00, 392.18it/s]

epoch: 24, loss: 1.6509211499266174


100%|██████████| 114/114 [00:00<00:00, 398.44it/s]
 35%|███▌      | 40/114 [00:00<00:00, 398.41it/s]

epoch: 25, loss: 1.6021339170759732


100%|██████████| 114/114 [00:00<00:00, 399.27it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.86it/s]

epoch: 26, loss: 1.5530603676813666


100%|██████████| 114/114 [00:00<00:00, 385.72it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.92it/s]

epoch: 27, loss: 1.5031314147113706


100%|██████████| 114/114 [00:00<00:00, 396.01it/s]
 35%|███▌      | 40/114 [00:00<00:00, 391.24it/s]

epoch: 28, loss: 1.4564759402667482


100%|██████████| 114/114 [00:00<00:00, 393.91it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.44it/s]

epoch: 29, loss: 1.4163498937937875


100%|██████████| 114/114 [00:00<00:00, 398.48it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.01it/s]

epoch: 30, loss: 1.3735962229794039


100%|██████████| 114/114 [00:00<00:00, 395.80it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.05it/s]

epoch: 31, loss: 1.3377111350700384


100%|██████████| 114/114 [00:00<00:00, 390.79it/s]
 37%|███▋      | 42/114 [00:00<00:00, 411.00it/s]

epoch: 32, loss: 1.296961115212319


100%|██████████| 114/114 [00:00<00:00, 360.18it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.28it/s]

epoch: 33, loss: 1.2617170135750928


100%|██████████| 114/114 [00:00<00:00, 389.09it/s]
 35%|███▌      | 40/114 [00:00<00:00, 398.75it/s]

epoch: 34, loss: 1.223242855791441


100%|██████████| 114/114 [00:00<00:00, 386.65it/s]
 34%|███▍      | 39/114 [00:00<00:00, 389.04it/s]

epoch: 35, loss: 1.1969334763009887


100%|██████████| 114/114 [00:00<00:00, 384.59it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.53it/s]

epoch: 36, loss: 1.1614529800392044


100%|██████████| 114/114 [00:00<00:00, 390.74it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.04it/s]

epoch: 37, loss: 1.1363795178563418


100%|██████████| 114/114 [00:00<00:00, 371.44it/s]
 27%|██▋       | 31/114 [00:00<00:00, 309.17it/s]

epoch: 38, loss: 1.104283966439471


100%|██████████| 114/114 [00:00<00:00, 352.75it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.57it/s]

epoch: 39, loss: 1.0760371574281822


100%|██████████| 114/114 [00:00<00:00, 391.56it/s]
 33%|███▎      | 38/114 [00:00<00:00, 374.59it/s]

epoch: 40, loss: 1.0506739661357762


100%|██████████| 114/114 [00:00<00:00, 374.07it/s]
 36%|███▌      | 41/114 [00:00<00:00, 408.37it/s]

epoch: 41, loss: 1.0220311430638758


100%|██████████| 114/114 [00:00<00:00, 394.86it/s]
 35%|███▌      | 40/114 [00:00<00:00, 395.20it/s]

epoch: 42, loss: 0.9965362751575487


100%|██████████| 114/114 [00:00<00:00, 393.57it/s]
 36%|███▌      | 41/114 [00:00<00:00, 404.24it/s]

epoch: 43, loss: 0.9698903217271655


100%|██████████| 114/114 [00:00<00:00, 389.23it/s]
 36%|███▌      | 41/114 [00:00<00:00, 407.83it/s]

epoch: 44, loss: 0.9501978450434786


100%|██████████| 114/114 [00:00<00:00, 397.68it/s]
 36%|███▌      | 41/114 [00:00<00:00, 405.00it/s]

epoch: 45, loss: 0.9245525796448489


100%|██████████| 114/114 [00:00<00:00, 389.95it/s]
 35%|███▌      | 40/114 [00:00<00:00, 398.02it/s]

epoch: 46, loss: 0.8996769480394663


100%|██████████| 114/114 [00:00<00:00, 393.48it/s]
 31%|███       | 35/114 [00:00<00:00, 342.71it/s]

epoch: 47, loss: 0.8838046036940668


100%|██████████| 114/114 [00:00<00:00, 376.09it/s]
 35%|███▌      | 40/114 [00:00<00:00, 399.18it/s]

epoch: 48, loss: 0.8614261997484782


100%|██████████| 114/114 [00:00<00:00, 389.68it/s]
 35%|███▌      | 40/114 [00:00<00:00, 393.40it/s]

epoch: 49, loss: 0.8446740892426934


100%|██████████| 114/114 [00:00<00:00, 388.24it/s]
 35%|███▌      | 40/114 [00:00<00:00, 397.86it/s]

epoch: 50, loss: 0.8227870030661175


100%|██████████| 114/114 [00:00<00:00, 388.61it/s]
 36%|███▌      | 41/114 [00:00<00:00, 409.49it/s]

epoch: 51, loss: 0.8062937216270738


100%|██████████| 114/114 [00:00<00:00, 395.14it/s]
 36%|███▌      | 41/114 [00:00<00:00, 399.97it/s]

epoch: 52, loss: 0.7854082065429981


100%|██████████| 114/114 [00:00<00:00, 391.37it/s]
 34%|███▍      | 39/114 [00:00<00:00, 389.38it/s]

epoch: 53, loss: 0.7674620170039079


100%|██████████| 114/114 [00:00<00:00, 392.66it/s]
 36%|███▌      | 41/114 [00:00<00:00, 404.48it/s]

epoch: 54, loss: 0.7525552575355237


100%|██████████| 114/114 [00:00<00:00, 395.05it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.15it/s]

epoch: 55, loss: 0.7362828363351762


100%|██████████| 114/114 [00:00<00:00, 393.49it/s]
 36%|███▌      | 41/114 [00:00<00:00, 402.32it/s]

epoch: 56, loss: 0.7230854772916979


100%|██████████| 114/114 [00:00<00:00, 387.53it/s]
 35%|███▌      | 40/114 [00:00<00:00, 398.96it/s]

epoch: 57, loss: 0.7060646013943205


100%|██████████| 114/114 [00:00<00:00, 356.46it/s]
 36%|███▌      | 41/114 [00:00<00:00, 401.78it/s]

epoch: 58, loss: 0.6887535703325529


100%|██████████| 114/114 [00:00<00:00, 359.42it/s]
 36%|███▌      | 41/114 [00:00<00:00, 403.16it/s]

epoch: 59, loss: 0.6772300252731781


100%|██████████| 114/114 [00:00<00:00, 389.51it/s]
 28%|██▊       | 32/114 [00:00<00:00, 319.35it/s]

epoch: 60, loss: 0.6633447496427259


100%|██████████| 114/114 [00:00<00:00, 328.35it/s]
 32%|███▏      | 37/114 [00:00<00:00, 367.70it/s]

epoch: 61, loss: 0.64884153403998


100%|██████████| 114/114 [00:00<00:00, 364.62it/s]
 32%|███▏      | 37/114 [00:00<00:00, 366.77it/s]

epoch: 62, loss: 0.6323014854181688


100%|██████████| 114/114 [00:00<00:00, 364.53it/s]
 32%|███▏      | 36/114 [00:00<00:00, 358.15it/s]

epoch: 63, loss: 0.6195278958121085


100%|██████████| 114/114 [00:00<00:00, 356.36it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.16it/s]

epoch: 64, loss: 0.6103760754013771


100%|██████████| 114/114 [00:00<00:00, 367.62it/s]
 32%|███▏      | 37/114 [00:00<00:00, 366.70it/s]

epoch: 65, loss: 0.592734210363311


100%|██████████| 114/114 [00:00<00:00, 364.77it/s]
 33%|███▎      | 38/114 [00:00<00:00, 370.41it/s]

epoch: 66, loss: 0.5846668904034049


100%|██████████| 114/114 [00:00<00:00, 342.27it/s]
 32%|███▏      | 37/114 [00:00<00:00, 362.49it/s]

epoch: 67, loss: 0.5696833448936295


100%|██████████| 114/114 [00:00<00:00, 361.52it/s]
 32%|███▏      | 37/114 [00:00<00:00, 367.44it/s]

epoch: 68, loss: 0.5600081634900904


100%|██████████| 114/114 [00:00<00:00, 366.59it/s]
 33%|███▎      | 38/114 [00:00<00:00, 371.40it/s]

epoch: 69, loss: 0.5489050003960033


100%|██████████| 114/114 [00:00<00:00, 321.45it/s]
 32%|███▏      | 37/114 [00:00<00:00, 362.93it/s]

epoch: 70, loss: 0.5383706786055935


100%|██████████| 114/114 [00:00<00:00, 362.90it/s]
 33%|███▎      | 38/114 [00:00<00:00, 374.52it/s]

epoch: 71, loss: 0.5270210590301694


100%|██████████| 114/114 [00:00<00:00, 366.35it/s]
 33%|███▎      | 38/114 [00:00<00:00, 376.16it/s]

epoch: 72, loss: 0.5175807101366576


100%|██████████| 114/114 [00:00<00:00, 369.32it/s]
 32%|███▏      | 36/114 [00:00<00:00, 356.47it/s]

epoch: 73, loss: 0.5060111530642789


100%|██████████| 114/114 [00:00<00:00, 352.33it/s]
 29%|██▉       | 33/114 [00:00<00:00, 320.15it/s]

epoch: 74, loss: 0.49440483645075367


100%|██████████| 114/114 [00:00<00:00, 343.62it/s]
 32%|███▏      | 36/114 [00:00<00:00, 357.75it/s]

epoch: 75, loss: 0.4865475107572216


100%|██████████| 114/114 [00:00<00:00, 359.24it/s]
 32%|███▏      | 37/114 [00:00<00:00, 369.88it/s]

epoch: 76, loss: 0.4790313203378089


100%|██████████| 114/114 [00:00<00:00, 353.76it/s]
 32%|███▏      | 37/114 [00:00<00:00, 368.23it/s]

epoch: 77, loss: 0.46879508585997376


100%|██████████| 114/114 [00:00<00:00, 366.85it/s]
 32%|███▏      | 37/114 [00:00<00:00, 369.58it/s]

epoch: 78, loss: 0.45765439053349827


100%|██████████| 114/114 [00:00<00:00, 363.29it/s]
 33%|███▎      | 38/114 [00:00<00:00, 376.43it/s]

epoch: 79, loss: 0.45165459913882405


100%|██████████| 114/114 [00:00<00:00, 360.50it/s]
 32%|███▏      | 36/114 [00:00<00:00, 352.59it/s]

epoch: 80, loss: 0.44174769091551114


100%|██████████| 114/114 [00:00<00:00, 337.04it/s]
 33%|███▎      | 38/114 [00:00<00:00, 369.95it/s]

epoch: 81, loss: 0.4327101595151856


100%|██████████| 114/114 [00:00<00:00, 362.26it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.27it/s]

epoch: 82, loss: 0.4228615710429486


100%|██████████| 114/114 [00:00<00:00, 339.85it/s]
 32%|███▏      | 36/114 [00:00<00:00, 357.71it/s]

epoch: 83, loss: 0.416811615847674


100%|██████████| 114/114 [00:00<00:00, 357.45it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.33it/s]

epoch: 84, loss: 0.40775731208198696


100%|██████████| 114/114 [00:00<00:00, 366.31it/s]
 32%|███▏      | 37/114 [00:00<00:00, 365.18it/s]

epoch: 85, loss: 0.400695760297478


100%|██████████| 114/114 [00:00<00:00, 361.95it/s]
 26%|██▋       | 30/114 [00:00<00:00, 296.17it/s]

epoch: 86, loss: 0.39606104133698333


100%|██████████| 114/114 [00:00<00:00, 334.50it/s]
 33%|███▎      | 38/114 [00:00<00:00, 370.58it/s]

epoch: 87, loss: 0.3881157607518696


100%|██████████| 114/114 [00:00<00:00, 366.29it/s]
 32%|███▏      | 37/114 [00:00<00:00, 365.96it/s]

epoch: 88, loss: 0.3800561007505312


100%|██████████| 114/114 [00:00<00:00, 365.12it/s]
 33%|███▎      | 38/114 [00:00<00:00, 371.50it/s]

epoch: 89, loss: 0.37390766323530517


100%|██████████| 114/114 [00:00<00:00, 392.41it/s]
 32%|███▏      | 36/114 [00:00<00:00, 353.82it/s]

epoch: 90, loss: 0.3660736814901448


100%|██████████| 114/114 [00:00<00:00, 337.85it/s]
 32%|███▏      | 37/114 [00:00<00:00, 361.80it/s]

epoch: 91, loss: 0.35959374110581505


100%|██████████| 114/114 [00:00<00:00, 337.88it/s]
 32%|███▏      | 36/114 [00:00<00:00, 358.66it/s]

epoch: 92, loss: 0.35294812181351465


100%|██████████| 114/114 [00:00<00:00, 362.43it/s]
 32%|███▏      | 36/114 [00:00<00:00, 354.89it/s]

epoch: 93, loss: 0.3485283073124993


100%|██████████| 114/114 [00:00<00:00, 356.77it/s]
 33%|███▎      | 38/114 [00:00<00:00, 370.34it/s]

epoch: 94, loss: 0.3423374393477326


100%|██████████| 114/114 [00:00<00:00, 365.79it/s]
 33%|███▎      | 38/114 [00:00<00:00, 374.75it/s]

epoch: 95, loss: 0.33655368840344163


100%|██████████| 114/114 [00:00<00:00, 369.94it/s]
 32%|███▏      | 37/114 [00:00<00:00, 367.52it/s]

epoch: 96, loss: 0.32558059385438703


100%|██████████| 114/114 [00:00<00:00, 354.03it/s]
 33%|███▎      | 38/114 [00:00<00:00, 371.29it/s]

epoch: 97, loss: 0.3194265909784772


100%|██████████| 114/114 [00:00<00:00, 368.12it/s]
 33%|███▎      | 38/114 [00:00<00:00, 372.16it/s]

epoch: 98, loss: 0.31357605352965445


100%|██████████| 114/114 [00:00<00:00, 341.63it/s]

epoch: 99, loss: 0.31065826108701955





In [10]:
import io

# Export the learned embeddings as two line-aligned TSV files (e.g. for the
# TensorBoard Embedding Projector): row k of vectors.tsv is the tab-separated
# embedding of the word on row k of metadata.tsv.
emb_book = model.emb_book.weight.detach().cpu().numpy()
with io.open('vectors.tsv', 'w', encoding='utf-8') as out_v:
    with io.open('metadata.tsv', 'w', encoding='utf-8') as out_m:
        for idx, w in enumerate(vocab.get_itos()):
            out_v.write('\t'.join(str(x) for x in emb_book[idx]) + "\n")
            # A token containing a tab or line break (spaCy can emit whitespace
            # tokens) would add extra rows/columns and misalign the two files,
            # so escape those characters.
            safe_w = w.replace("\t", "\\t").replace("\r", "\\r").replace("\n", "\\n")
            out_m.write(safe_w + "\n")