In [1]:
import functools
import json
import random
from typing import List

import spacy
import torch
from einops import rearrange
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab, build_vocab_from_iterator
from tqdm import tqdm

In [2]:
# Load the raw document records. JSON is UTF-8 by specification (RFC 8259), so
# state the encoding explicitly instead of relying on the platform locale default.
with open('docs.json', encoding='utf-8') as f:
    docs = json.load(f)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gram = 9      # window size: the target token plus n_gram // 2 tokens on each side
bz = 256        # mini-batch size
epochs = 100
emb_dim = 64    # width of each token embedding
dims = [256,]   # hidden-layer widths of the CBOW MLP head
# The pair-building cell takes n_gram // 2 tokens on each side of the target, so an
# odd window is needed for the context to have exactly n_gram - 1 tokens.
assert n_gram % 2 == 1, "n_gram must be odd"

# Seed both RNGs the notebook draws from (Python's `random` and torch) so that
# Restart & Run All reproduces the same training run.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [4]:
@functools.lru_cache(maxsize=None)
def _load_spacy(model_name):
    """Load a spaCy pipeline once per kernel session and reuse it on later calls."""
    return spacy.load(model_name)


def yield_tokens(docs, nlp=None):
    """Yield one list of lower-cased lemmas per sentence in the documents' abstracts.

    Args:
        docs: iterable of dict-like records. A record contributes only if it has an
            "abstract" key holding an iterable of dict-like sections, and a section
            contributes only if it has a "text" key.
        nlp: optional callable mapping a text to a spaCy-style Doc (an object with
            ``has_annotation`` and ``sents``). Defaults to the "en_core_web_sm"
            pipeline, which is loaded once and cached across calls.

    Yields:
        list[str]: ``token.lemma_.lower()`` for every token of one sentence.

    Raises:
        RuntimeError: if the pipeline did not set sentence boundaries.
    """
    if nlp is None:
        nlp = _load_spacy("en_core_web_sm")
    for record in docs:
        if "abstract" not in record:
            continue
        for section in record["abstract"]:
            if "text" not in section:
                continue
            # A distinct name keeps the outer `record` loop variable from being shadowed.
            parsed = nlp(section["text"])
            # Iterating `sents` needs sentence boundaries; unlike `assert`, this check
            # is not stripped when Python runs with -O.
            if not parsed.has_annotation("SENT_START"):
                raise RuntimeError("spaCy pipeline did not set sentence boundaries (SENT_START)")
            for sent in parsed.sents:
                yield [token.lemma_.lower() for token in sent]

# Build the vocabulary from sentence lemmas. Only tokens seen at least 3 times are
# kept; any other token looks up to "<unk>" through the default index.
vocab = build_vocab_from_iterator(yield_tokens(docs), specials=["<unk>"], min_freq=3)
vocab.set_default_index(vocab["<unk>"])

len(vocab)

4848

In [5]:
# Sanity-check the lemmatisation on the first sentence the tokenizer produces.
ws = next(yield_tokens(docs), None)
if ws is not None:
    print(ws)

['covid-19', ',', 'cause', 'by', 'sars', '-', 'cov-2', 'infection', ',', 'be', 'mild', 'to', 'moderate', 'in', 'the', 'majority', 'of', 'previously', 'healthy', 'individual', ',', 'but', 'can', 'cause', 'life', '-', 'threaten', 'disease', 'or', 'persistent', 'debilitate', 'symptom', 'in', 'some', 'case', '.']


In [6]:
# One (context, target) pair per position that has a full window on both sides:
# the context is the n_gram // 2 tokens before and after the target token.
half = n_gram // 2
datas = [
    (ws[i - half:i] + ws[i + 1:i + 1 + half], ws[i])
    for ws in yield_tokens(docs)
    for i in range(half, len(ws) - half)
]

In [7]:
class CBOW(torch.nn.Module):
    """Continuous-bag-of-words model: predict a target token from its context tokens.

    The context embeddings are averaged, passed through an MLP with ReLU hidden
    layers, and projected to log-probabilities over the vocabulary.

    Args:
        n_gram: window size the training pairs were built with. It is stored as
            ``self.n_gram`` for reference only; ``forward`` averages over however
            many context positions it receives.
        emb_dim: width of each token embedding.
        dims: widths of the hidden layers (may be empty for a purely linear head).
        vocab: vocabulary object. Only ``len(vocab)`` is used to size the model.
    """

    def __init__(self, n_gram, emb_dim, dims, vocab,):
        super().__init__()
        self.vocab = vocab
        self.n_gram = n_gram

        vocab_size = len(vocab)
        self.emb_book = nn.Embedding(vocab_size, emb_dim)

        widths = [emb_dim, *dims]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], vocab_size))
        # Log-probabilities, to be paired with nn.NLLLoss.
        layers.append(nn.LogSoftmax(dim=-1))

        # nn.Sequential yields the same "net.<i>.*" parameter names as the former
        # ModuleList, so previously saved checkpoints still load.
        self.net = nn.Sequential(*layers)

    def forward(self, ctx_idx):
        """Return log-probabilities over the vocabulary.

        Args:
            ctx_idx: int64 tensor of token ids; the training loop passes shape
                (batch, n_context), and dim 1 is the one averaged over.

        Returns:
            Tensor whose last dimension is ``len(vocab)``; (batch, len(vocab)) for
            a (batch, n_context) input.
        """
        ctx_emb = self.emb_book(ctx_idx)
        return self.net(ctx_emb.mean(dim=1))

In [8]:
# The CBOW head ends in LogSoftmax, so NLLLoss is the matching criterion.
model = CBOW(n_gram, emb_dim, dims, vocab).to(device)
loss_fn = nn.NLLLoss()
optim = torch.optim.Adamax(model.parameters(), lr=0.01)

In [9]:
# Encode every (context, target) pair once, up front. The vocab lookup does not
# change between epochs, so redoing it for every batch of every epoch is wasted work.
if not datas:
    raise ValueError("no training pairs: every sentence is shorter than n_gram tokens")
ctx_ids = torch.tensor([vocab(ctx) for ctx, _ in datas], dtype=torch.int64)  # (N, context_len)
tar_ids = torch.tensor(vocab([tar for _, tar in datas]), dtype=torch.int64)  # (N,)
n_examples = len(tar_ids)

model.train()
for epoch in range(epochs):
    total_loss = 0.0
    # Shuffle through an index permutation rather than reordering `datas` in place,
    # so the list built by the pair-building cell is left untouched.
    perm = torch.randperm(n_examples)
    for start in tqdm(range(0, n_examples, bz)):
        batch_idx = perm[start:start + bz]
        ctx = ctx_ids[batch_idx].to(device)
        tar = tar_ids[batch_idx].to(device)

        out = model(ctx)
        loss = loss_fn(out, tar)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # `loss` is a batch mean; weighting by batch size makes the epoch figure a
        # per-example mean even when the last batch is smaller.
        total_loss += loss.item() * len(tar)

    print("epoch: {}, loss: {}".format(epoch, total_loss / n_examples))
    # Overwrites the same checkpoint every epoch, so the file holds the latest state.
    torch.save({
        'epoch': epoch,
        'net_state_dict': model.state_dict(),
        'opt_state_dict': optim.state_dict(),
        }, "./cbow.pt")

100%|██████████| 816/816 [00:02<00:00, 343.88it/s]


epoch: 0, loss: 5.827592667229189


100%|██████████| 816/816 [00:02<00:00, 335.55it/s]


epoch: 1, loss: 5.323447640862354


100%|██████████| 816/816 [00:02<00:00, 344.43it/s]


epoch: 2, loss: 5.025668031443371


100%|██████████| 816/816 [00:02<00:00, 308.89it/s]


epoch: 3, loss: 4.755579183264168


100%|██████████| 816/816 [00:02<00:00, 283.40it/s]


epoch: 4, loss: 4.502237911959563


100%|██████████| 816/816 [00:02<00:00, 287.31it/s]


epoch: 5, loss: 4.2671306392252095


100%|██████████| 816/816 [00:02<00:00, 285.97it/s]


epoch: 6, loss: 4.057831577827357


100%|██████████| 816/816 [00:02<00:00, 287.54it/s]


epoch: 7, loss: 3.8769285182002093


100%|██████████| 816/816 [00:02<00:00, 314.13it/s]


epoch: 8, loss: 3.72375105770115


100%|██████████| 816/816 [00:02<00:00, 366.80it/s]


epoch: 9, loss: 3.5928600634408228


100%|██████████| 816/816 [00:02<00:00, 369.02it/s]


epoch: 10, loss: 3.4798559638804285


100%|██████████| 816/816 [00:02<00:00, 379.74it/s]


epoch: 11, loss: 3.377505536298468


100%|██████████| 816/816 [00:02<00:00, 382.56it/s]


epoch: 12, loss: 3.287358525128304


100%|██████████| 816/816 [00:02<00:00, 375.53it/s]


epoch: 13, loss: 3.2046168937852655


100%|██████████| 816/816 [00:02<00:00, 382.54it/s]


epoch: 14, loss: 3.129196349320875


100%|██████████| 816/816 [00:02<00:00, 382.81it/s]


epoch: 15, loss: 3.0598404838863544


100%|██████████| 816/816 [00:02<00:00, 373.32it/s]


epoch: 16, loss: 2.9951357025630427


100%|██████████| 816/816 [00:02<00:00, 372.30it/s]


epoch: 17, loss: 2.9342348236951605


100%|██████████| 816/816 [00:02<00:00, 360.42it/s]


epoch: 18, loss: 2.878367961633753


100%|██████████| 816/816 [00:02<00:00, 370.19it/s]


epoch: 19, loss: 2.825759217585543


100%|██████████| 816/816 [00:02<00:00, 367.79it/s]


epoch: 20, loss: 2.7761906132149403


100%|██████████| 816/816 [00:02<00:00, 371.09it/s]


epoch: 21, loss: 2.728673900305126


100%|██████████| 816/816 [00:02<00:00, 372.99it/s]


epoch: 22, loss: 2.6841060304702014


100%|██████████| 816/816 [00:02<00:00, 368.63it/s]


epoch: 23, loss: 2.6407324772026537


100%|██████████| 816/816 [00:02<00:00, 361.37it/s]


epoch: 24, loss: 2.600771309822527


100%|██████████| 816/816 [00:02<00:00, 375.73it/s]


epoch: 25, loss: 2.5625154906851364


100%|██████████| 816/816 [00:02<00:00, 362.51it/s]


epoch: 26, loss: 2.5254615474795057


100%|██████████| 816/816 [00:02<00:00, 372.87it/s]


epoch: 27, loss: 2.491844566047701


100%|██████████| 816/816 [00:02<00:00, 371.92it/s]


epoch: 28, loss: 2.457064169988563


100%|██████████| 816/816 [00:02<00:00, 366.77it/s]


epoch: 29, loss: 2.424814478593465


100%|██████████| 816/816 [00:02<00:00, 376.20it/s]


epoch: 30, loss: 2.3926763287713837


100%|██████████| 816/816 [00:02<00:00, 365.50it/s]


epoch: 31, loss: 2.3640933502788783


100%|██████████| 816/816 [00:02<00:00, 368.15it/s]


epoch: 32, loss: 2.3342688631575665


100%|██████████| 816/816 [00:02<00:00, 359.91it/s]


epoch: 33, loss: 2.307416696786707


100%|██████████| 816/816 [00:02<00:00, 370.20it/s]


epoch: 34, loss: 2.2801465541541135


100%|██████████| 816/816 [00:02<00:00, 372.43it/s]


epoch: 35, loss: 2.2541805552570264


100%|██████████| 816/816 [00:02<00:00, 371.72it/s]


epoch: 36, loss: 2.229188758873191


100%|██████████| 816/816 [00:02<00:00, 367.06it/s]


epoch: 37, loss: 2.2052004253358657


100%|██████████| 816/816 [00:02<00:00, 371.94it/s]


epoch: 38, loss: 2.181312711966294


100%|██████████| 816/816 [00:02<00:00, 372.64it/s]


epoch: 39, loss: 2.15851300243418


100%|██████████| 816/816 [00:02<00:00, 371.37it/s]


epoch: 40, loss: 2.136362629862342


100%|██████████| 816/816 [00:02<00:00, 374.28it/s]


epoch: 41, loss: 2.1146451205630665


100%|██████████| 816/816 [00:02<00:00, 372.40it/s]


epoch: 42, loss: 2.093320611372885


100%|██████████| 816/816 [00:02<00:00, 374.22it/s]


epoch: 43, loss: 2.0746389452746703


100%|██████████| 816/816 [00:02<00:00, 356.85it/s]


epoch: 44, loss: 2.0542100746926963


100%|██████████| 816/816 [00:02<00:00, 358.47it/s]


epoch: 45, loss: 2.0355851691346265


100%|██████████| 816/816 [00:02<00:00, 369.26it/s]


epoch: 46, loss: 2.0165106718181818


100%|██████████| 816/816 [00:02<00:00, 368.58it/s]


epoch: 47, loss: 1.9979390335510534


100%|██████████| 816/816 [00:02<00:00, 372.94it/s]


epoch: 48, loss: 1.97984710254086


100%|██████████| 816/816 [00:02<00:00, 370.67it/s]


epoch: 49, loss: 1.9639126008815466


100%|██████████| 816/816 [00:02<00:00, 372.10it/s]


epoch: 50, loss: 1.9474529955112057


100%|██████████| 816/816 [00:02<00:00, 369.45it/s]


epoch: 51, loss: 1.9296561538535948


100%|██████████| 816/816 [00:02<00:00, 372.17it/s]


epoch: 52, loss: 1.9145998262693116


100%|██████████| 816/816 [00:02<00:00, 374.57it/s]


epoch: 53, loss: 1.899029837243333


100%|██████████| 816/816 [00:02<00:00, 374.58it/s]


epoch: 54, loss: 1.8842506258600704


100%|██████████| 816/816 [00:02<00:00, 325.86it/s]


epoch: 55, loss: 1.8680669850894576


100%|██████████| 816/816 [00:02<00:00, 282.79it/s]


epoch: 56, loss: 1.8535157769679786


100%|██████████| 816/816 [00:02<00:00, 285.65it/s]


epoch: 57, loss: 1.839645441612835


100%|██████████| 816/816 [00:02<00:00, 317.10it/s]


epoch: 58, loss: 1.8259333451546318


100%|██████████| 816/816 [00:02<00:00, 344.60it/s]


epoch: 59, loss: 1.8115968331052288


100%|██████████| 816/816 [00:02<00:00, 345.19it/s]


epoch: 60, loss: 1.798922135019728


100%|██████████| 816/816 [00:02<00:00, 342.61it/s]


epoch: 61, loss: 1.7853688123130338


100%|██████████| 816/816 [00:02<00:00, 330.33it/s]


epoch: 62, loss: 1.7721332927431483


100%|██████████| 816/816 [00:02<00:00, 366.09it/s]


epoch: 63, loss: 1.7596750598679312


100%|██████████| 816/816 [00:02<00:00, 377.32it/s]


epoch: 64, loss: 1.747546869459963


100%|██████████| 816/816 [00:02<00:00, 367.25it/s]


epoch: 65, loss: 1.7367769696665996


100%|██████████| 816/816 [00:02<00:00, 361.82it/s]


epoch: 66, loss: 1.7237234590309252


100%|██████████| 816/816 [00:02<00:00, 369.62it/s]


epoch: 67, loss: 1.7118070980767974


100%|██████████| 816/816 [00:02<00:00, 369.64it/s]


epoch: 68, loss: 1.7005358406435809


100%|██████████| 816/816 [00:02<00:00, 354.21it/s]


epoch: 69, loss: 1.6897696260245962


100%|██████████| 816/816 [00:02<00:00, 362.71it/s]


epoch: 70, loss: 1.6783096882989934


100%|██████████| 816/816 [00:02<00:00, 367.22it/s]


epoch: 71, loss: 1.668250072276009


100%|██████████| 816/816 [00:02<00:00, 355.89it/s]


epoch: 72, loss: 1.6568632398904468


100%|██████████| 816/816 [00:02<00:00, 369.05it/s]


epoch: 73, loss: 1.6470110833934908


100%|██████████| 816/816 [00:02<00:00, 367.35it/s]


epoch: 74, loss: 1.6365157776698513


100%|██████████| 816/816 [00:02<00:00, 363.20it/s]


epoch: 75, loss: 1.6257413183621023


100%|██████████| 816/816 [00:02<00:00, 375.05it/s]


epoch: 76, loss: 1.6162574779732009


100%|██████████| 816/816 [00:02<00:00, 371.46it/s]


epoch: 77, loss: 1.6068982021430154


100%|██████████| 816/816 [00:02<00:00, 367.75it/s]


epoch: 78, loss: 1.5980551422334293


100%|██████████| 816/816 [00:02<00:00, 375.77it/s]


epoch: 79, loss: 1.5877825870582978


100%|██████████| 816/816 [00:02<00:00, 377.77it/s]


epoch: 80, loss: 1.5784799019867142


100%|██████████| 816/816 [00:02<00:00, 371.84it/s]


epoch: 81, loss: 1.569288124672375


100%|██████████| 816/816 [00:02<00:00, 369.57it/s]


epoch: 82, loss: 1.561446163493457


100%|██████████| 816/816 [00:02<00:00, 373.64it/s]


epoch: 83, loss: 1.5515378521358314


100%|██████████| 816/816 [00:02<00:00, 371.94it/s]


epoch: 84, loss: 1.5439552047151346


100%|██████████| 816/816 [00:02<00:00, 370.73it/s]


epoch: 85, loss: 1.5343008439583525


100%|██████████| 816/816 [00:02<00:00, 372.47it/s]


epoch: 86, loss: 1.5265585534293704


100%|██████████| 816/816 [00:02<00:00, 375.08it/s]


epoch: 87, loss: 1.5190281381761988


100%|██████████| 816/816 [00:02<00:00, 372.29it/s]


epoch: 88, loss: 1.5099740741791783


100%|██████████| 816/816 [00:02<00:00, 366.65it/s]


epoch: 89, loss: 1.5023814674821263


100%|██████████| 816/816 [00:02<00:00, 373.57it/s]


epoch: 90, loss: 1.4941553634182922


100%|██████████| 816/816 [00:02<00:00, 367.83it/s]


epoch: 91, loss: 1.4868992211156487


100%|██████████| 816/816 [00:02<00:00, 375.92it/s]


epoch: 92, loss: 1.479607174147947


100%|██████████| 816/816 [00:02<00:00, 375.78it/s]


epoch: 93, loss: 1.4706883463595917


100%|██████████| 816/816 [00:02<00:00, 374.27it/s]


epoch: 94, loss: 1.4645511812074663


100%|██████████| 816/816 [00:02<00:00, 364.40it/s]


epoch: 95, loss: 1.4565233034605065


100%|██████████| 816/816 [00:02<00:00, 369.12it/s]


epoch: 96, loss: 1.449912134123614


100%|██████████| 816/816 [00:02<00:00, 375.99it/s]


epoch: 97, loss: 1.4420739932920592


100%|██████████| 816/816 [00:02<00:00, 356.59it/s]


epoch: 98, loss: 1.4354226082703854


100%|██████████| 816/816 [00:02<00:00, 366.00it/s]


epoch: 99, loss: 1.4295192252101057


In [10]:
import io

# Export in the two-file TSV layout used by embedding projectors: row i of
# vectors.tsv is the embedding of the token written on row i of metadata.tsv.
emb_book = model.emb_book.weight.detach().cpu().numpy()
itos = vocab.get_itos()
assert len(itos) == emb_book.shape[0], "vocab size and embedding table disagree"

# A token containing a tab or newline would split its row when written raw and
# misalign metadata.tsv against vectors.tsv, so escape those characters.
tsv_escapes = str.maketrans({"\t": "\\t", "\n": "\\n", "\r": "\\r"})

# `with` closes both files even if a write fails part-way through.
with io.open('vectors.tsv', 'w', encoding='utf-8') as out_v, \
        io.open('metadata.tsv', 'w', encoding='utf-8') as out_m:
    for emb, w in zip(emb_book, itos):
        out_v.write('\t'.join(str(x) for x in emb) + "\n")
        out_m.write(w.translate(tsv_escapes) + "\n")