In [17]:
import sqlite3
from peewee import SqliteDatabase
import msgspec
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from datasets import Dataset

In [15]:
from database import get_model_class
from data_schema import SentencePair
from load_spacy import load_spacy, translate_sequence
from sentence_embedding_model_v7 import SentenceEmbeddingV7
from train_sentence_embedding import CFG, dataset_map

In [19]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Batch size the encoder is constructed with; inputs fed to the encoder must
# carry this many sequences along dim 1.
BATCH_SIZE = 2

In [4]:
# spaCy pipeline from the project helper. It is handed to the encoder
# (nlp=nlp) and to translate_sequence, which presumably maps token-id samples
# back to text via the pipeline's vocab — confirm in load_spacy.
nlp = load_spacy()

In [33]:
checkpoint = 'tmp/checkpoints/v9/epoch100_encoder1'

# Propagate the chosen device into the shared training config (presumably
# read by SentenceEmbeddingV7 — confirm).
CFG['device'] = device

encoder1 = SentenceEmbeddingV7(config=CFG, nlp=nlp, batch_size=BATCH_SIZE).to(device)
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
encoder1.load_state_dict(torch.load(checkpoint, map_location=device))
# eval() switches the BatchNorm/Dropout layers to inference behaviour; it
# returns the module, so the cell displays the model summary.
encoder1.eval()

SentenceEmbeddingV7(
  (compress1): Linear(in_features=300, out_features=256, bias=False)
  (compress2): Linear(in_features=256, out_features=128, bias=False)
  (compress3): Linear(in_features=128, out_features=64, bias=False)
  (compress_norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder1): GRU(64, 16, num_layers=5)
  (decode2): GRU(16, 16, num_layers=4, dropout=0.89)
  (linear_out1): Linear(in_features=16, out_features=128, bias=False)
  (linear_out2): Linear(in_features=16, out_features=128, bias=False)
  (linear_out3): Linear(in_features=16, out_features=128, bias=False)
  (linear_out4): Linear(in_features=16, out_features=128, bias=False)
  (norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.89, inplace=False)
  (cos): CosineSimilarity()
)

In [34]:
dataset = 'guardian_headlines'

# One SQLite file per dataset; each row holds a JSON-encoded SentencePair
# (decoded in the evaluation loop below).
db = SqliteDatabase('sentence_embedding_training_data/{}.txt.db'.format(dataset))
ModelClass = get_model_class(db)
db.connect()

True

In [35]:
word_embedding = encoder1.get_word_embedding()

counter = 0
iteration = 2  # 1-based index of the last record to inspect; record 1 is skipped


def embed_sample(sample, index):
    """Embed one token-id sequence with ``encoder1`` and print debug info.

    The ids are looked up in ``word_embedding`` and the resulting sequence is
    replicated ``BATCH_SIZE`` times along dim 1, because ``encoder1`` was
    constructed with ``batch_size=BATCH_SIZE``.

    Args:
        sample: list of token ids taken from a ``SentencePair``.
        index: 1 or 2; only used to label the printed debug lines.

    Returns:
        (text, embedded): the sequence translated to text by
        ``translate_sequence`` and the encoder output for the replicated batch.
    """
    print(f'sample{index}: ', sample)
    text = translate_sequence(sample, nlp)
    tokens = torch.tensor(sample, device=device)
    vectors = word_embedding(tokens)
    print(f'sample{index}_ts.shape: ', vectors.shape)
    # First 5 dims of the first (up to) 3 token vectors; slicing instead of
    # indexing avoids an IndexError on sequences shorter than 3 tokens.
    for row in vectors[:3]:
        print(f'sample{index}_ts: ', row[:5])
    # Tie the batch to BATCH_SIZE rather than hardcoding two copies, so the
    # input stays in sync with the batch size the encoder was built with.
    batch = torch.stack([vectors] * BATCH_SIZE, dim=1)
    print(f'input{index}.shape: ', batch.shape)
    return text, encoder1(batch)


# Inference only: no_grad avoids building an autograd graph for every record.
with torch.no_grad():
    for record in ModelClass.select():
        counter += 1

        if counter == 1:
            continue  # skip the first record

        pair = msgspec.json.decode(record.json_data, type=SentencePair)
        sample1 = pair.sample1
        sample2 = pair.sample2
        text1, embedded1 = embed_sample(sample1, 1)
        text2, embedded2 = embed_sample(sample2, 2)

        score = encoder1.cos(embedded1, embedded2)

        # Every batch entry is the same sentence, so all scores should agree.
        for batch_index in range(BATCH_SIZE):
            print(f'{text1} - {text2} - score: {score[batch_index].item()}')

        if counter >= iteration:
            break

sample1:  [141171, 289, 18, 487, 186, 539, 5, 6076, 105, 931, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157, 514157]
sample1_ts.shape:  torch.Size([40, 300])
sample1_ts:  tensor([-0.1264, -2.5252,  2.0018,  0.7056,  2.4124], device='cuda:0')
sample1_ts:  tensor([ 1.6989,  5.3603, -0.2236,  1.3875,  4.8625], device='cuda:0')
sample1_ts:  tensor([ 6.5268,  9.8299, -4.6809, -1.7518,  1.8730], device='cuda:0')
input1.shape:  torch.Size([40, 2, 300])
input.shape:  torch.Size([40, 2, 300])
compressed:  [1.4100443124771118, -1.8930031061172485, -0.7354149222373962, 0.7657529711723328, -0.6652577519416809]
compressed:  [0.7302178144454956, -0.10998714715242386, -0.36964771151542664, 0.4085789918899536, -0.44284600019454956]
compressed:  [-12.233009338378906, 8.465375900268555, 9.533697128295898, -9.5712356567382