In [1]:
from sentence_transformers import SentenceTransformer, losses, InputExample, models
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
import torch
from torch import nn

import os

import re
import pandas as pd
import numpy as np

In [2]:
class AdapterModule(nn.Module):
    """Feed-forward adapter applied to the pooled sentence embedding.

    The module reads ``input_data['sentence_embedding']``, passes it through
    two hidden layers (1024 and 512 units, each followed by ReLU and dropout)
    and a final linear projection to ``output_dim``. When ``add_residual`` is
    enabled, the input embedding scaled by a learnable scalar is added to the
    projection. The result is written back under the same key.

    Args:
        input_dim: Size of the incoming sentence embedding.
        output_dim: Size of the embedding produced by the adapter.
        dropout_rate: Dropout probability applied after each hidden layer.
        add_residual: If True, add ``residual_weight * input`` to the output.

    Raises:
        ValueError: If ``add_residual`` is True and ``input_dim`` differs from
            ``output_dim``; the residual sum needs matching sizes and would
            otherwise only fail later, inside the first forward pass.
    """

    def __init__(self, input_dim, output_dim, dropout_rate=0.3, add_residual=True):
        super().__init__()
        if add_residual and input_dim != output_dim:
            raise ValueError(
                "add_residual=True requires input_dim == output_dim "
                f"(got input_dim={input_dim}, output_dim={output_dim})"
            )
        self.dense1 = nn.Linear(in_features=input_dim, out_features=1024, bias=True)
        self.dense2 = nn.Linear(in_features=1024, out_features=512, bias=True)
        self.output = nn.Linear(in_features=512, out_features=output_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.add_residual = add_residual
        if add_residual:
            # Learnable scalar, drawn from U(0, 0.1) so the residual path
            # starts with a small contribution.
            self.residual_weight = nn.Parameter(nn.init.uniform_(torch.empty(1), 0, 0.1))

    def forward(self, input_data):
        """Replace ``input_data['sentence_embedding']`` with its adapted version.

        Args:
            input_data: Mapping holding a ``'sentence_embedding'`` tensor whose
                last dimension is ``input_dim``.

        Returns:
            The same ``input_data`` object, with ``'sentence_embedding'``
            rebound to the adapter output.

        Raises:
            KeyError: If ``'sentence_embedding'`` is missing.
        """
        # Index instead of .get(): a missing key should fail here rather than
        # hand None to the first linear layer.
        embedding = input_data['sentence_embedding']

        hidden = self.dropout(self.activation(self.dense1(embedding)))
        hidden = self.dropout(self.activation(self.dense2(hidden)))
        adapted = self.output(hidden)

        if self.add_residual:
            # Out-of-place add: avoids mutating a tensor tracked by autograd.
            adapted = adapted + self.residual_weight * embedding

        input_data['sentence_embedding'] = adapted
        return input_data

    def save(self, output_path):
        """Write the module's state dict to ``<output_path>/adapter_module.pt``.

        The directory is created if it does not exist yet.
        """
        os.makedirs(output_path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(output_path, 'adapter_module.pt'))
    
# Pretrained transformer that produces the token embeddings.
word_embedding_model = models.Transformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L12-v2",
    max_seq_length=128,
    do_lower_case=False,
)

# Pooling configured with the default parameters of the base model (mean
# pooling only). The embedding width is read from the loaded transformer
# instead of being hard-coded, so it stays correct if the base model changes.
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
    pooling_mode_cls_token=False,
    pooling_mode_mean_tokens=True,
    pooling_mode_max_tokens=False,
    pooling_mode_mean_sqrt_len_tokens=False,
    pooling_mode_weightedmean_tokens=False,
    pooling_mode_lasttoken=False,
    include_prompt=True,
)
normalize = models.Normalize()

# Freeze the transformer: its parameters receive no gradient updates, which
# leaves the adapter below as the part being trained.
for param in word_embedding_model.parameters():
    param.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The adapter consumes the pooled embedding, so its input size is taken from
# the pooling layer; the output size is kept equal so the residual connection
# inside the adapter is valid.
adapter_dim = pooling_model.get_sentence_embedding_dimension()
adapter = AdapterModule(adapter_dim, adapter_dim).to(device)

# Reference model without the adapter. It shares the same transformer and
# pooling module objects as the adapted model below.
base_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, normalize],
    device=device,
)

# Adapted model: the adapter sits between pooling and normalization.
custom_domain_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, adapter, normalize],
    device=device,
)

custom_domain_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): AdapterModule(
    (dense1): Linear(in_features=384, out_features=1024, bias=True)
    (dense2): Linear(in_features=1024, out_features=512, bias=True)
    (output): Linear(in_features=512, out_features=384, bias=True)
    (activation): ReLU()
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (3): Normalize()
)

In [3]:
# define evaluator
#from sentence_transformers.evaluation import InformationRetrievalEvaluator
# define over validation dataset

#evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)


In [4]:
# Load the publications table. Later cells read its paperId, title, source,
# con_type and authors columns.
data = pd.read_parquet("data/mlt_data_publications.parquet")

In [5]:
# Lookup table: paper id -> paper title.
titles_map = {
    record.paperId: record.title
    for record in data.itertuples(index=False)
}

# (title, 'relatedWith', title-of-source) triples for every row whose
# con_type is not 'base'.
related_pubs = [
    (titles_map[record.paperId], 'relatedWith', titles_map[record.source])
    for record in data.itertuples(index=False)
    if record.con_type != 'base'
]

# Turn each triple into a (question, answer) pair: the question quotes the
# first title, the answer is the related title.
question_answer_pairs = [
    (f'Which paper is cited or referenced in the paper titled "{subject_title}"?', object_title)
    for subject_title, _relation, object_title in related_pubs
]

In [6]:
def _clean_author_name(raw_name):
    """Collapse whitespace runs in a name and trim it, keeping the word order."""
    return re.sub(r'\s+', ' ', raw_name).strip()


def _author_name_variants(author):
    """Return the distinct cleaned names of one author: primary name, then aliases.

    Aliases may be missing (None), an array-like exposing ``.tolist()``, or a
    plain iterable such as a list; all three are accepted. Non-string and
    empty names are skipped. Duplicates are removed while keeping the order
    of first appearance, so the result is deterministic across runs.
    """
    aliases = author.get('aliases')
    if aliases is None:
        aliases = []
    elif hasattr(aliases, 'tolist'):
        aliases = aliases.tolist()
    else:
        aliases = list(aliases)

    cleaned = (
        _clean_author_name(candidate)
        for candidate in [author.get('name'), *aliases]
        if isinstance(candidate, str)
    )
    # dict.fromkeys de-duplicates in insertion order; a set would yield an
    # order that changes between interpreter runs (string hashing is salted).
    return list(dict.fromkeys(name for name in cleaned if name))


# One record per (paper, author name variant). Word order inside each name is
# preserved: joining a set of tokens would shuffle the words and drop repeated
# tokens, producing names that do not match how authors are actually written.
authors = []
for row in data.itertuples():
    if row.authors is None:
        # A paper without an author list contributes no records.
        continue
    for author in row.authors:
        for name in _author_name_variants(author):
            authors.append({'author': name, 'publication': row.title,
                            'paperId': row.paperId, 'authorId': author.get('authorId')})

In [7]:
# Convert the list of author records (keys: author, publication, paperId,
# authorId) into a DataFrame. Note that this rebinds `authors` from a list to
# a DataFrame.
authors = pd.DataFrame(authors)

In [8]:
# (question, answer) pairs asking who wrote each paper, built from the first
# 1000 rows of the authors table only.
question_answer_pairs_authors = [
    (
        f'Who wrote the paper titled "{record.publication}"?',
        record.author,
    )
    for record in authors.head(1000).itertuples(index=False)
]

In [9]:
# Wrap every (question, answer) pair as a two-text training example.
train_examples = [
    InputExample(texts=[question, answer])
    for question, answer in question_answer_pairs_authors
]

In [10]:
# Shuffled mini-batches of 64 training examples.
loader = DataLoader(train_examples, batch_size=64, shuffle=True)

# NOTE(review): this loss presumably uses the other pairs of a batch as
# negatives; the same paper title can occur with several author names, so a
# batch may contain false negatives - confirm whether that is acceptable.
train_loss = losses.MultipleNegativesSymmetricRankingLoss(custom_domain_model)

epochs = 100
# Warm-up length: 10% of the total number of batches over all epochs.
warmup_steps = int(0.1 * (len(loader) * epochs))

# Train the adapted model and write it to 'domain_adaptation_model'.
# Mixed precision (use_amp=True) and evaluator-based validation
# (evaluator=evaluator, evaluation_steps=50) are currently not enabled.
custom_domain_model.fit(
    train_objectives=[(loader, train_loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    show_progress_bar=True,
    save_best_model=True,
    output_path='domain_adaptation_model',
)

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

Iteration:   0%|          | 0/16 [00:00<?, ?it/s]

In [15]:
# Probe texts: one paper title and the names of two of its co-authors.
paper, author, author2 = (
    "RADAR: an in-building RF-based user location and tracking system",
    "Paramvir Bahl",
    "V. N. Padmanabhan",
)

# Generic concepts used as a sanity check on unrelated vocabulary.
concept1, concept2, concept3 = "shark", "ocean", "strawberry"

In [16]:
# Encode every probe text with the adapted model, one encode call per text.
custom_paper, custom_author, custom_author2 = (
    custom_domain_model.encode(text) for text in (paper, author, author2)
)
custom_concept1, custom_concept2, custom_concept3 = (
    custom_domain_model.encode(text) for text in (concept1, concept2, concept3)
)

# Print the dot product for each pair of interest. The model ends with a
# Normalize module, so these are presumably cosine similarities.
print(
    f"Producto punto entre el titulo del paper y el autor: {np.dot(custom_paper, custom_author)}",
    f"Producto punto entre dos co-autores del mismo paper: {np.dot(custom_author, custom_author2)}",
    f"Producto punto entre dos conceptos (shark y ocean): {np.dot(custom_concept1, custom_concept2)}",
    f"Producto punto entre dos conceptos (shark y strawberry): {np.dot(custom_concept1, custom_concept3)}",
    f"Producto punto entre el documento y un concepto (ocean): {np.dot(custom_paper, custom_concept2)}",
    sep="\n",
)

Producto punto entre el titulo del paper y el autor: 0.37796667218208313
Producto punto entre dos co-autores del mismo paper: 0.38921666145324707
Producto punto entre dos conceptos (shark y ocean): 0.45532286167144775
Producto punto entre dos conceptos (shark y strawberry): 0.20252814888954163
Producto punto entre el documento y un concepto (ocean): 0.16327065229415894


In [17]:
# Encode the same probe texts with the base model (no adapter), one encode
# call per text, to compare against the adapted model.
base_paper, base_author, base_author2 = (
    base_model.encode(text) for text in (paper, author, author2)
)
base_concept1, base_concept2, base_concept3 = (
    base_model.encode(text) for text in (concept1, concept2, concept3)
)

# Print the dot product for each pair of interest. The model ends with a
# Normalize module, so these are presumably cosine similarities.
print(
    f"Producto punto entre el titulo del paper y el autor: {np.dot(base_paper, base_author)}",
    f"Producto punto entre dos co-autores del mismo paper: {np.dot(base_author, base_author2)}",
    f"Producto punto entre dos conceptos (shark y ocean): {np.dot(base_concept1, base_concept2)}",
    f"Producto punto entre dos conceptos (shark y strawberry): {np.dot(base_concept1, base_concept3)}",
    f"Producto punto entre el documento y un concepto (ocean): {np.dot(base_paper, base_concept2)}",
    sep="\n",
)

Producto punto entre el titulo del paper y el autor: 0.03417949751019478
Producto punto entre dos co-autores del mismo paper: 0.23409880697727203
Producto punto entre dos conceptos (shark y ocean): 0.5232283473014832
Producto punto entre dos conceptos (shark y strawberry): 0.23324593901634216
Producto punto entre el documento y un concepto (ocean): 0.03188034147024155
