In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np

from siamese.SIF import SIF_embedding
from siamese.prepare_dataset import MonthDataset, TripletBatchSampler
from siamese.siamese import TripletSiamese, TripletLoss

In [2]:
from importlib import reload
import siamese.SIF
import siamese.siamese
reload(siamese.SIF)
reload(siamese.siamese)
from siamese.SIF import SIF_embedding
from siamese.siamese import TripletSiamese, TripletLoss

In [3]:
from dataset.BinaryEmbeddings import BinaryEmbeddings

month = "202312"
BE_model = BinaryEmbeddings("dataset/vectors")
with open(f"dataset/corpora/corpus_{month}.txt", "r") as file:
    data = BE_model.parse_docs(file.readlines(), month)
print(BE_model.NE2id[month][BE_model.id2NE[month][5]])
print(data)
data.size()

5
tensor([[False, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


torch.Size([11444, 10640])

In [4]:
vectors = pd.read_csv(f'dataset/vectors/vectors_{month}_175.txt', sep=' ', header=None, index_col=0).rename(columns={0: 'word'})

art_entities_orig = [torch.nonzero(r, as_tuple=True)[0].numpy() for r in data]

small_article_indices = [i for i in range(len(art_entities_orig)) if art_entities_orig[i].shape[0] < 5]
# throw out articles with few entities
small_articles = [art_entities_orig[i].shape[0] < 5 for i in range(len(art_entities_orig))]
art_entities = [art_entities_orig[i] for i in range(len(art_entities_orig)) if not small_articles[i]]

ent_weight = [np.ones(ent.shape[0]) for i, ent in enumerate(art_entities)]
entity_embs = [data[i] for i in range(len(data)) if not small_articles[i]]
entity_embs = torch.stack(entity_embs, axis=0)

In [19]:
# extract SIF embeddings
art_emb_sif = SIF_embedding(vectors.values, art_entities, ent_weight, 0)
art_emb_sif = torch.from_numpy(art_emb_sif)
# create dummy data
# art_emb_sif = torch.rand(600, 175)*2
# entity_embs = (torch.rand(600, 3775) > 0.98).type(torch.int32)

In [22]:
## Initialize parameters
lr = 0.0015
threshold = 0.3
margin = 1
epochs = 1000

## Initialize network
model = TripletSiamese(175)

## Initialize optimizer
optim = torch.optim.Adam(model.parameters(),lr=lr)

## Initialize scheduler
#scheduler = torch.optim.lr_scheduler.StepLR(optim, 8)

## Initialize loss
criterion = torch.nn.TripletMarginLoss()
# criterion = torch.nn.BCEWithLogitsLoss()

In [23]:
# create test/val split
n_train = int(0.8*len(art_emb_sif))
n_val = int(0.1*len(art_emb_sif))
n_test = len(art_emb_sif) - n_train - n_val
train, test, val = torch.utils.data.random_split(art_emb_sif, [n_train, n_test, n_val])

In [24]:
# create dataset
train_ds= MonthDataset(train.dataset[train.indices], n_train, entity_embs[train.indices])
val_ds = MonthDataset(val.dataset[val.indices], n_val, entity_embs[val.indices])
test_ds = MonthDataset(test.dataset[test.indices], n_test, entity_embs[test.indices])

torch.Size([90000, 175, 3])
torch.Size([11250, 175, 3])
torch.Size([11253, 175, 3])


In [25]:
# customsamplers
train_sampler = TripletBatchSampler(train_ds.data)
val_sampler = TripletBatchSampler(val_ds.data)
test_sampler = TripletBatchSampler(test_ds.data)

# create dataloader
train_dl = torch.utils.data.DataLoader(train_ds, sampler=train_sampler)
test_dl = torch.utils.data.DataLoader(test_ds, sampler=test_sampler)
valid_dl = torch.utils.data.DataLoader(val_ds, sampler=val_sampler)

In [10]:
print(model.dnn1[6].weight)

Parameter containing:
tensor([[ 0.0882,  0.0307,  0.0155,  ..., -0.0801, -0.0824,  0.0675],
        [-0.0475,  0.0053, -0.0432,  ..., -0.0567,  0.0420,  0.0647],
        [ 0.0743,  0.0867, -0.0599,  ...,  0.0008,  0.0370, -0.0364],
        ...,
        [-0.0777,  0.0454,  0.0103,  ..., -0.0304,  0.0106,  0.0604],
        [-0.0533,  0.0506, -0.0007,  ..., -0.0703,  0.0213, -0.0226],
        [-0.0691, -0.0641, -0.0699,  ..., -0.0018,  0.0479, -0.0150]],
       requires_grad=True)


In [29]:
len(train_dl)/64

1406.25

In [None]:
train_loss = []
valid_loss = []
for epoch in range(500):
    train_epoch_loss = 0
    model.train()    
    if epoch % 50 == 0:
        torch.save(model.state_dict(), f'./model_{month}_epoch_{i}.pt')
    for i, x in enumerate(train_dl):
        optim.zero_grad()
        a, p, n = x[0].T.mT.type(torch.float)
        #print(a.size())
       # print(p)
       # print(n)
        out = model(a, p, n)
        if i == 0:
            print(a == p)
            print(out[0] == out[1])
        # print(a.type(torch.float))
        loss = criterion(*out)
        #print(a.type(torch.float), p.type(torch.float), n.type(torch.float))
        #print(loss)
        #loss =  torch.mean(loss)
        #print(loss)
        train_epoch_loss += torch.mean(loss)
        #loss.requires_grad = True
        loss.mean().backward()
        optim.step()
    
    train_epoch_loss /= len(train_ds)
    train_loss.append(train_epoch_loss)
    
    print("Epoch [{}/{}] ----> Training loss :{} \n".format(epoch+1,epochs,train_epoch_loss))


    valid_epoch_loss = 0
    val_pos_accuracy = 0
    val_neg_accuracy = 0
    num_pos = 0
    num_neg = 0
    model.eval()

    for i, x in enumerate(valid_dl):
        optim.zero_grad()
        a, p, n = x[0].T.mT
        out = model(a.type(torch.float), p.type(torch.float), n.type(torch.float))
        
        loss = criterion(*out)
        loss =  torch.mean(loss)
        
        valid_epoch_loss += loss.item()
       
    valid_epoch_loss /= len(val_ds)
    
    valid_loss.append(valid_epoch_loss)

    print("Validation loss :{}\n".format(valid_epoch_loss))

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Epoch [1/1000] ----> Training loss :0.006239420734345913 

Validation loss :0.13180310889350044

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
     