In [41]:
import json
import os
import re

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, losses, InputExample, models
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch import nn
from torch.utils.data import DataLoader

In [57]:
class AdapterModule(nn.Module):
    """Feed-forward adapter applied to the ``'sentence_embedding'`` feature.

    The module reads ``input_data['sentence_embedding']``, passes it through
    ``input_dim -> 1024 -> 512 -> output_dim`` linear layers (Tanh + dropout
    after the first two), optionally adds a learnable-scaled copy of the input
    and writes the result back under the same key.

    Args:
        input_dim: Width of the incoming sentence embedding.
        output_dim: Width of the embedding produced by the adapter.
        dropout_rate: Dropout probability used after each hidden activation.
        add_residual: If True, ``residual_weight * input`` is added to the
            output. ``residual_weight`` starts at zero, so the residual branch
            contributes nothing until it is trained.

    Raises:
        ValueError: If ``add_residual`` is True while ``input_dim`` and
            ``output_dim`` differ (the residual sum would not be shape-compatible).
    """

    # File names written by ``save`` and read back by ``load``.
    WEIGHTS_NAME = 'adapter_module.pt'
    CONFIG_NAME = 'adapter_config.json'

    def __init__(self, input_dim, output_dim, dropout_rate=0.2, add_residual=True):
        super(AdapterModule, self).__init__()
        if add_residual and input_dim != output_dim:
            # Fail at construction time instead of on the first forward pass.
            raise ValueError(
                f"add_residual=True requires input_dim == output_dim, "
                f"got {input_dim} and {output_dim}"
            )
        # Kept so that ``save`` can persist everything ``load`` needs to rebuild the module.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate
        self.add_residual = add_residual

        self.dense1 = nn.Linear(in_features=input_dim, out_features=1024, bias=True)
        self.dense2 = nn.Linear(in_features=1024, out_features=512, bias=True)
        self.output = nn.Linear(in_features=512, out_features=output_dim)
        self.activation = nn.Tanh()
        self.dropout = nn.Dropout(dropout_rate)
        if add_residual:
            self.residual_weight = nn.Parameter(torch.zeros(1))

    def get_config_dict(self):
        """Return the constructor arguments as a JSON-serialisable dict."""
        return {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'dropout_rate': self.dropout_rate,
            'add_residual': self.add_residual,
        }

    def forward(self, input_data):
        """Transform ``input_data['sentence_embedding']`` and return the same dict.

        The dict is updated in place; the input tensor itself is not modified.
        The output is not re-normalised by this module.

        Raises:
            KeyError: If ``'sentence_embedding'`` is missing from ``input_data``.
        """
        # Index directly so a missing key is reported as such rather than as a
        # failure inside the first linear layer.
        original_x = input_data['sentence_embedding']
        x = self.dropout(self.activation(self.dense1(original_x)))
        x = self.dropout(self.activation(self.dense2(x)))
        x = self.output(x)
        if self.add_residual:
            # Out-of-place add: avoids mutating a tensor tracked by autograd.
            x = x + self.residual_weight * original_x

        input_data['sentence_embedding'] = x

        return input_data

    def save(self, output_path):
        """Write the weights and the constructor config into ``output_path``.

        ``output_path`` must be an existing directory.
        """
        torch.save(self.state_dict(), os.path.join(output_path, self.WEIGHTS_NAME))
        with open(os.path.join(output_path, self.CONFIG_NAME), 'w') as config_file:
            json.dump(self.get_config_dict(), config_file, indent=2)

    @classmethod
    def load(cls, input_path, **kwargs):
        """Rebuild an adapter from a directory previously written by ``save``.

        Directories that only contain the weights file (written before the
        config file existed) are still supported: the layer sizes and the
        residual flag are then inferred from the stored tensors and
        ``dropout_rate`` falls back to its default. Extra keyword arguments
        are accepted and ignored.
        """
        state = torch.load(
            os.path.join(input_path, cls.WEIGHTS_NAME),
            map_location='cpu',
            weights_only=True,  # the file only holds tensors; avoid unpickling arbitrary objects
        )
        config_path = os.path.join(input_path, cls.CONFIG_NAME)
        if os.path.exists(config_path):
            with open(config_path) as config_file:
                config = json.load(config_file)
        else:
            config = {
                'input_dim': state['dense1.weight'].shape[1],
                'output_dim': state['output.weight'].shape[0],
                'add_residual': 'residual_weight' in state,
            }
        module = cls(**config)
        module.load_state_dict(state)
        return module
    
# Base encoder: pretrained transformer -> mean pooling -> normalisation.
word_embedding_model = models.Transformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L12-v2",
    max_seq_length=128,
    do_lower_case=False,
)

# Read the token-embedding width from the loaded transformer instead of
# hard-coding it, so pooling stays consistent if the base checkpoint changes.
word_embedding_dim = word_embedding_model.get_word_embedding_dimension()

# Default pooling parameters of the base model (mean pooling only).
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_dim,
    pooling_mode_cls_token=False,
    pooling_mode_mean_tokens=True,
    pooling_mode_max_tokens=False,
    pooling_mode_mean_sqrt_len_tokens=False,
    pooling_mode_weightedmean_tokens=False,
    pooling_mode_lasttoken=False,
    include_prompt=True,
)
normalize = models.Normalize()

# Freeze the transformer: the adapter defined below is the only part of the
# pipeline whose weights remain trainable.
for param in word_embedding_model.parameters():
    param.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size the adapter from the output of the last base-model layer (the pooled
# sentence embedding); input and output widths match so the residual applies.
sentence_embedding_dim = pooling_model.get_sentence_embedding_dimension()
adapter = AdapterModule(sentence_embedding_dim, sentence_embedding_dim).to(device)

# Reference model without the adapter, used for before/after comparisons.
# NOTE: both models share the same transformer/pooling/normalize module objects.
base_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, normalize],
    device=device,
)

custom_domain_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, normalize, adapter],
    device=device,
)

custom_domain_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
  (3): AdapterModule(
    (dense1): Linear(in_features=384, out_features=1024, bias=True)
    (dense2): Linear(in_features=1024, out_features=512, bias=True)
    (output): Linear(in_features=512, out_features=384, bias=True)
    (activation): Tanh()
    (dropout): Dropout(p=0.2, inplace=False)
  )
)

In [3]:
# define evaluator
#from sentence_transformers.evaluation import InformationRetrievalEvaluator
# define over validation dataset

#evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)


In [28]:
# Publications table; later cells read its paperId, title, source, con_type and authors columns.
data = pd.read_parquet(path="data/mlt_data_publications.parquet")

In [None]:
# paperId -> title lookup (for duplicated ids the last row wins).
titles_map = dict(zip(data["paperId"], data["title"]))

# One (paper title, 'relatedWith', source title) triple per row that is not a 'base' row.
related_pubs = [
    (titles_map[paper_id], 'relatedWith', titles_map[source_id])
    for paper_id, source_id, con_type in zip(data["paperId"], data["source"], data["con_type"])
    if con_type != 'base'
]

# Turn each triple into a (question, answer) pair: the question mentions the
# first title, the answer is the third element of the triple.
question_answer_pairs = [
    (f'Which paper is cited or referenced in the paper titled "{subject_title}"?', related_title)
    for subject_title, _relation, related_title in related_pubs
]

In [29]:
def _clean_author_name(raw_name):
    """Collapse whitespace and drop repeated tokens, keeping first-seen token order.

    An unordered set here would scramble the word order of the name and make
    it vary between interpreter runs (string hashing is randomised).
    """
    return ' '.join(dict.fromkeys(raw_name.split()))


def _author_name_variants(author):
    """Return the distinct cleaned names (primary name, then aliases) of one author record."""
    raw_aliases = author.get('aliases')
    if raw_aliases is None:
        aliases = []
    elif hasattr(raw_aliases, 'tolist'):
        # Array-like aliases (the original code relied on ``.tolist()``).
        aliases = raw_aliases.tolist()
    else:
        # Plain sequences were previously discarded by the AttributeError handler.
        aliases = list(raw_aliases)

    candidates = [author.get('name'), *aliases]
    # Skip missing (non-string) names instead of failing on them.
    cleaned = (_clean_author_name(name) for name in candidates if isinstance(name, str))
    # Ordered de-duplication keeps the row order reproducible across runs,
    # which the seeded ``sample`` call downstream depends on.
    return list(dict.fromkeys(name for name in cleaned if name))


# One record per (publication, author name variant).
authors = []
for row in data.itertuples():
    for author in (row.authors if row.authors is not None else []):
        for name in _author_name_variants(author):
            authors.append({'author': name, 'publication': row.title,
                            'paperId': row.paperId, 'authorId': author.get('authorId')})

In [30]:
# Rebinds `authors` from the list of per-name records to a DataFrame
# (columns come from the record keys).
authors = pd.DataFrame(data=authors)

In [55]:
# Fixed-seed sample of 1000 author rows, turned into (question, answer) pairs.
sampled_authors = authors.sample(n=1000, random_state=42)
question_answer_pairs_authors = [
    (f'Who wrote the paper titled "{publication_title}"?', author_name)
    for publication_title, author_name in zip(sampled_authors["publication"], sampled_authors["author"])
]


In [63]:
# Wrap every (question, answer) pair as a two-text training example.
train_examples = [
    InputExample(texts=[question, answer])
    for question, answer in question_answer_pairs_authors
]

In [66]:
# Ranking loss computed with the adapter-augmented model over the paired examples.
train_loss = losses.MultipleNegativesRankingLoss(model=custom_domain_model)
# Shuffled mini-batches of 32 examples.
train_dataloader = DataLoader(dataset=train_examples, batch_size=32, shuffle=True)

# Train for 10 epochs; the fitted model is written to 'domain_adaptation_model'.
custom_domain_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=10,
    warmup_steps=100,
    output_path='domain_adaptation_model',
    show_progress_bar=True,
    # Evaluation is currently disabled:
    #evaluator=evaluator,
    #evaluation_steps=50,
)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

Iteration:   0%|          | 0/32 [00:00<?, ?it/s]

In [77]:
# Probe texts for the similarity check below: a paper title and an author name.
# NOTE: `author` rebinds the loop variable of the same name used while building the authors list.
paper = "Computational complexity versus statistical performance on sparse recovery problems"
author = "Nicolas Boumal"

In [79]:
# Similarity of the two probe texts under the adapted model: both embeddings are
# requested normalised, and the cell displays their dot product.
custom_author = custom_domain_model.encode(author, normalize_embeddings=True)
custom_paper = custom_domain_model.encode(paper, normalize_embeddings=True)

custom_paper.dot(custom_author)

0.947363

In [80]:
# Same probe with the base model (no adapter), as the reference value for comparison.
base_author = base_model.encode(author, normalize_embeddings=True)
base_paper = base_model.encode(paper, normalize_embeddings=True)

base_paper.dot(base_author)

0.07937869