### Transformer aggregation inference example: the scientific-papers dataset combined with the fine-tuned MPNet model

In [None]:

from loading_paper_utils import load_data_from_pkl, embed_col, embed_text_list_col
from sentence_transformers import SentenceTransformer

import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
import torch
from transformer_agg import transformer_agg_net, PaperVecPairsDataset
import warnings
warnings.filterwarnings("ignore", message=".*To copy construct from a tensor.*")

def run_trans_agg_embed(df_with_embedded_text_list_col, embedded_text_list_col_name, agg_model, max_len, agg_model_name, vec_shape=768):
    """Pad every row's embedding sequence to ``max_len`` and aggregate it with ``agg_model``.

    The frame is modified in place; nothing is returned. After the call it holds:

    * ``embedded_text_list_col_name`` -- overwritten with a tensor zero-padded
      along dim 0 to ``max_len`` rows.
    * ``'padding_indicator'`` -- a length-``max_len`` tensor with 0 at real
      positions and 1 at padded positions.
    * ``f'{agg_model_name}_agg_outs_tuple'`` -- whatever ``agg_model`` returns for
      the row (indexed with ``[0]`` below, so it must be a sequence).
    * ``f'{agg_model_name}_agg_rep'`` -- element 0 of that output, squeezed and
      converted to a NumPy array.

    Args:
        df_with_embedded_text_list_col: DataFrame holding one embedding sequence per row.
        embedded_text_list_col_name: Name of the column with those sequences. Each
            value must be convertible to a tensor whose first dimension is the
            sequence length and whose second is ``vec_shape``.
        agg_model: Callable invoked once per row with the tuple
            ``(padded_sequence.unsqueeze(0), padding_indicator.unsqueeze(0))``.
        max_len: Length every sequence is padded to.
        agg_model_name: Prefix for the two output column names.
        vec_shape: Width of a single embedding vector (used for the zero padding).

    Raises:
        ValueError: If any sequence is longer than ``max_len``.

    Note:
        Because the input column is overwritten with the padded tensors, calling
        this a second time on the same frame sees every sequence as already
        ``max_len`` long and produces an all-zero padding indicator. Reload the
        frame before re-running.
    """
    frame = df_with_embedded_text_list_col
    seq_col = embedded_text_list_col_name

    # as_tensor converts lists/arrays and passes existing tensors through
    # without the copy-construct warning torch.tensor() emits.
    frame[seq_col] = frame[seq_col].apply(lambda seq: torch.as_tensor(seq))

    # Fail early with a readable message instead of a negative-dimension error
    # from torch.ones()/torch.zeros() further down.
    seq_lens = frame[seq_col].apply(lambda seq: seq.shape[0])
    if len(seq_lens) and seq_lens.max() > max_len:
        raise ValueError(
            f"Column '{seq_col}' contains a sequence of length {seq_lens.max()}, "
            f"which exceeds max_len={max_len}"
        )

    # The indicator must be built before the sequences are padded: it relies on
    # the original (unpadded) length of each sequence.
    frame['padding_indicator'] = frame[seq_col].apply(
        lambda seq: torch.cat((torch.zeros(seq.shape[0]), torch.ones(max_len - seq.shape[0])))
    )
    frame[seq_col] = frame[seq_col].apply(
        lambda seq: torch.cat((seq, torch.zeros((max_len - seq.shape[0], vec_shape))), dim=0)
    )

    # Inference only: without no_grad every stored output would keep its
    # autograd graph alive for the lifetime of the frame.
    with torch.no_grad():
        frame[f'{agg_model_name}_agg_outs_tuple'] = frame.apply(
            lambda row: agg_model((row[seq_col].unsqueeze(0), row['padding_indicator'].unsqueeze(0))),
            axis=1,
        )

    frame[f'{agg_model_name}_agg_rep'] = frame[f'{agg_model_name}_agg_outs_tuple'].apply(
        lambda outs: outs[0].squeeze().cpu().detach().numpy()
    )

# Build the aggregator with the hyper-parameters below, then restore the trained weights into it.
sci_papers_mpnet_agg_model = transformer_agg_net(d_model=768, nhead=12, n_transformer_layers=2, dropout=0.1)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies the values into the parameters created above.
# 'TRAINED/MODEL/SAVE/PATH' is a placeholder -- replace it with the real checkpoint path.
sci_papers_mpnet_agg_model.load_state_dict(torch.load('TRAINED/MODEL/SAVE/PATH', map_location='cpu', weights_only=True))
# Switch to evaluation mode (disables the Dropout layers); kept as the last expression so the
# notebook displays the module structure.
sci_papers_mpnet_agg_model.eval()

transformer_agg_net(
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
    )
    (linear1): Linear(in_features=768, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=768, bias=True)
    (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (

In [None]:
# Placeholder locations of the pickled train/test frames prepared for transformer aggregation.
# Replace both with real paths before running.
SCI_PAPERS_TRAIN_PKL = '/PATH/TO/TRAIN/PKL/OF/DATA/PREPPED/FOR/TRANSFORMER/AGGREGATION'
# NOTE(review): this placeholder starts with '//' unlike the train path -- confirm when substituting the real path.
SCI_PAPERS_TEST_PKL = '//PATH/TO/TEST/PKL/OF/DATA/PREPPED/FOR/TRANSFORMER/AGGREGATION'

sci_papers_train_set = load_data_from_pkl(SCI_PAPERS_TRAIN_PKL)
sci_papers_test_set = load_data_from_pkl(SCI_PAPERS_TEST_PKL)


In [3]:
# Padding length for the aggregator: the longest embedding sequence found in either split,
# measured on the mpnet column that is passed to run_trans_agg_embed (not the e5 column),
# so that no test-set sequence can exceed the padding length.
sci_papers_max_len = int(max(
    split['pars_abstract_ft_all_mpnet_embedding'].map(len).max()
    for split in (sci_papers_train_set, sci_papers_test_set)
))

In [4]:
# Show the column index of each split under its own banner line.
for banner, split_frame in (
    ('SCI PAPERS TRAIN SET', sci_papers_train_set),
    ('SCI PAPERS TEST SET', sci_papers_test_set),
):
    print(banner)
    print(split_frame.columns)

SCI PAPERS TRAIN SET
Index(['id', 'title', 'abstract', 'category', 'html_path', 'pars', 'sections',
       'num_pars', 'pars_abstract_ft_e5_embedding',
       'pars_abstract_ft_all_mpnet_embedding'],
      dtype='object')
SCI PAPERS TEST SET
Index(['id', 'title', 'abstract', 'category', 'html_path', 'pars', 'sections',
       'num_pars', 'pars_abstract_ft_e5_embedding',
       'pars_abstract_ft_all_mpnet_embedding'],
      dtype='object')


In [5]:
# Run the trained aggregator over both splits; each frame is modified in place
# and gains the 'mpnet_trained_aggregator_agg_rep' column used by the cells below.
sci_splits = {'sci_train': sci_papers_train_set, 'sci_test': sci_papers_test_set}
for name, df in sci_splits.items():
    run_trans_agg_embed(
        df_with_embedded_text_list_col=df,
        embedded_text_list_col_name='pars_abstract_ft_all_mpnet_embedding',
        agg_model=sci_papers_mpnet_agg_model,
        max_len=sci_papers_max_len,
        agg_model_name='mpnet_trained_aggregator',
    )


In [6]:
from sklearn.metrics import silhouette_score


def _print_silhouette_scores(df, name):
    """Print the cosine silhouette score of every aggregated-representation column in ``df``.

    A column is scored when its name contains both 'aggregator' and 'rep'. The
    cluster labels are the values of ``df['category']`` mapped to integer ids in
    order of first appearance.

    Args:
        df: DataFrame with a 'category' column and one or more representation columns.
        name: Label printed at the start of every result line.
    """
    # The label mapping depends only on 'category', so build it once rather than once per column.
    cat_conv = {cat: i for i, cat in enumerate(df['category'].unique())}
    labels = df['category'].map(cat_conv).tolist()
    for c in [x for x in df.columns if 'aggregator' in x and 'rep' in x]:
        score = silhouette_score(df[c].tolist(), labels, metric="cosine")
        print(f'{name} columns {c} sil: {score}')


for split_title, split_df in (("TRAIN", sci_papers_train_set), ("TEST", sci_papers_test_set)):
    print(split_title)
    _print_silhouette_scores(split_df, 'sci_papers')


TRAIN
sci_papers columns mpnet_trained_aggregator_agg_rep sil: 0.9944777488708496
TEST
sci_papers columns mpnet_trained_aggregator_agg_rep sil: 0.5468906164169312


In [7]:
from data_refining_pipeline import train_df_to_centroids_v2, alg_for_par_selection, alg_for_par_selection_with_initial_cents

# Centroids are derived from the aggregated representations of the TRAIN split only.
sci_papers_cents_mpnet, sci_paper_cent_d_mpnet = train_df_to_centroids_v2(sci_papers_train_set, 'mpnet_trained_aggregator_agg_rep')

# Assign every paper (train and test) to the centroid with the highest cosine similarity.
# The similarities are computed in one batched call per split instead of one call per row.
for split_df in (sci_papers_train_set, sci_papers_test_set):
    split_reps = np.vstack(split_df['mpnet_trained_aggregator_agg_rep'].tolist())
    split_df['mpnet_predicted_cent'] = np.argmax(cosine_similarity(split_reps, sci_papers_cents_mpnet), axis=1)



Centroids shape: (16, 768)


In [8]:
from matching_metrics import get_macro_f1

df = sci_papers_test_set
name = 'sci_papers'
pred_col = 'mpnet_predicted_cent'
cat_d = {cat:i for i, cat in enumerate(df['category'].unique())}
df['cat_id'] = df['category'].apply(lambda x: cat_d[x])

# Every ordered pair of test papers, via a cross join of the frame with itself.
# NOTE(review): this includes each paper paired with itself and both orderings of
# every pair -- confirm that is what the pairwise F1 is meant to be computed over.
pair_cols = df[[pred_col, 'category', 'id']]
pairs = pd.merge(pair_cols, pair_cols, how='cross', suffixes=('_1', '_2'))

# label: 1 when both papers share a category; pred: 1 when both were assigned the same centroid.
# Column-wise comparison replaces the row-by-row apply over the n^2 pair table.
pairs['label'] = (pairs['category_1'] == pairs['category_2']).astype(int)
pairs['pred'] = (pairs[f'{pred_col}_1'] == pairs[f'{pred_col}_2']).astype(int)
print(f"F1 result for {name} df, {pred_col} col, {get_macro_f1(pairs)[0]}")

F1 result for sci_papers df, mpnet_predicted_cent col, 0.5806770708609359
