Import libraries

In [1]:
from sentence_transformers import (SentenceTransformer, models, 
                                   SentenceTransformerTrainingArguments,
                                   SentenceTransformerTrainer)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import dot_score

from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_from_disk
import torch

import pandas as pd
import numpy as np

from utils import create_data_for_evaluator
from custom_adapter_module.AdapterModule import AdapterModule

  from tqdm.autonotebook import tqdm, trange


A custom SentenceTransformer model is configured by defining and configuring several components, including a word embedding model, a grouping model, a normalization layer, and an adapter module. The word embedding model is initialized with a pre-trained transformer model from Sentence Transformers, with specific settings for a maximum sequence length and case-sensitivity. The pooling model is configured to use token averaging for pooling, with other pooling modes disabled. The normalization layer is defined to standardize the embeddings.

In [2]:
# Carga del modelo de embeddings de palabras
word_embedding_model = models.Transformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # Modelo base de Sentence Transformers
    max_seq_length=512,  # Longitud máxima de la secuencia
    do_lower_case=False,  # No convertir a minúsculas
)

# Definición de los parámetros del modelo de pooling
pooling_model = models.Pooling(
    word_embedding_dimension=384,  # Dimensión de los embeddings de palabras
    pooling_mode_cls_token=False,  # No usar el token CLS para el pooling
    pooling_mode_mean_tokens=True,  # Usar el promedio de los tokens para el pooling
    pooling_mode_max_tokens=False,  # No usar el máximo de los tokens para el pooling
    pooling_mode_mean_sqrt_len_tokens=False,  # No usar el promedio de la raíz cuadrada de la longitud para el pooling
    pooling_mode_weightedmean_tokens=False,  # No usar el promedio ponderado de los tokens para el pooling
    pooling_mode_lasttoken=False,  # No usar el último token para el pooling
    include_prompt=True  # Incluir el prompt en el pooling
)

# Definición del modelo de normalización
normalize = models.Normalize()

# Congelar los pesos del modelo de embeddings de palabras para que no se entrenen
for param in word_embedding_model.parameters():
    param.requires_grad = False

# Configuración del dispositivo para usar GPU si está disponible, de lo contrario usar CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Definir el módulo adaptador con las dimensiones de entrada y salida
adapter = AdapterModule(384, 384).to(device)

# Definir el modelo base de Sentence Transformer con las capas de embedding, pooling y normalización
base_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalize], 
                                 device=device,
                                 model_kwargs={"torch_dtype": "float16"}
                                 )

# Definir el modelo personalizado de Sentence Transformer que incluye el adaptador
custom_domain_model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, adapter, normalize], device=device,
)

custom_domain_model  # Mostrar la arquitectura del modelo personalizado

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): AdapterModule(
    (dense1): Linear(in_features=384, out_features=768, bias=True)
    (dense2): Linear(in_features=768, out_features=512, bias=True)
    (output): Linear(in_features=512, out_features=384, bias=True)
    (activation): ReLU()
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (3): Normalize()
)

To prevent the weights of the word embedding model from being updated during training, all its parameters are frozen. Device configuration ensures that the model uses a GPU if available; otherwise it will default to CPU usage. An AdapterModule instance with defined input and output dimensions is created and moved to the specified device.

Two SentenceTransformer models are instantiated: the base model, which includes the word embedding, pooling, and normalization layers; and the custom domain model, which additionally incorporates the adapter module. This configuration allows flexible adaptation of embeddings tailored to specific tasks or domains. The final architecture of the custom domain model is shown for your review.

Load the training and evaluation datasets for question answering tasks from the respective pickled files stored in the 'data' directory.

In [3]:
qa_train = load_from_disk('./data/train_dataset')
qa_eval = load_from_disk('./data/eval_dataset')
qa_test = load_from_disk('./data/test_dataset')

Create training examples using the question-answer pairs from the dataset `qa`, where each example consists of a question (`qa[0]`) and its corresponding answer (`qa[1]`).

In [4]:
print("Training lenght: ", len(qa_train))
print("Validation lenght: ", len(qa_eval))
print("Test lenght: ", len(qa_test))

Training lenght:  29547
Validation lenght:  3677
Test lenght:  3666


Prepares and configures the training and evaluation process for a custom SentenceTransformer model. Initially, a training data set is created by generating a list of `InputExample` instances, where each instance consists of a pair of texts (question and answer). This data set is then loaded into a "DataLoader", which shuffles the data at each epoch and sets the batch size to 256.

The training loss is defined using "MultipleNegativesSymmetricRankingLoss", which is suitable for information retrieval tasks involving positive text pairs. An evaluator is configured using "InformationRetrievalEvaluator", which evaluates the performance of the model on a set of queries and corpora, with the main scoring function specified as "dot_score".

In [5]:
eval_dataset_evaluator = create_data_for_evaluator(qa_eval)
test_dataset_evaluator = create_data_for_evaluator(qa_test)

In [6]:
dev_evaluator = InformationRetrievalEvaluator(
        queries=eval_dataset_evaluator['queries'],
        corpus=eval_dataset_evaluator['corpus'],
        relevant_docs=eval_dataset_evaluator['relevant_docs'],
        name='qa_eval', 
        map_at_k=[10],
        accuracy_at_k = [10],
        precision_recall_at_k = [10],
        score_functions={'dot_score':dot_score}
    )

test_evaluator = InformationRetrievalEvaluator(
        queries=test_dataset_evaluator['queries'],
        corpus=test_dataset_evaluator['corpus'],
        relevant_docs=test_dataset_evaluator['relevant_docs'],
        name='qa_test', 
        map_at_k=[10],
        accuracy_at_k = [10],
        precision_recall_at_k = [10],
        score_functions={'dot_score':dot_score}
    )

In [7]:
## Base model evaluation

results = dev_evaluator(base_model)

results

{'qa_eval_dot_score_accuracy@10': 0.7819225251076041,
 'qa_eval_dot_score_precision@10': 0.08371592539454806,
 'qa_eval_dot_score_recall@10': 0.705266618842659,
 'qa_eval_dot_score_ndcg@10': 0.5763616901461027,
 'qa_eval_dot_score_mrr@10': 0.5725218339368267,
 'qa_eval_dot_score_map@10': 0.5162513355613247}

In [None]:
loss = MultipleNegativesRankingLoss(custom_domain_model,
                                    similarity_fct=dot_score)

args = SentenceTransformerTrainingArguments(
    output_dir="./results/domain_adaptation_model",  
    num_train_epochs=5,  # Entrenar por al menos 50 épocas
    per_device_train_batch_size=128,  # Ajustar según la memoria disponible
    gradient_accumulation_steps=2,  # 32 * 4 = 128
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    gradient_checkpointing=True,  # Reducir uso de memoria
    optim="adamw_torch_fused",  # Optimizer más eficiente
    lr_scheduler_type="cosine",  # Planificador de tasa de aprendizaje
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="epoch",  # Evaluar al final de cada época
    save_strategy="epoch",  # Guardar al final de cada época
    save_total_limit=1,  # Mantener los últimos 3 checkpoints
    logging_steps=100,  # Ajustar según la frecuencia deseada
    metric_for_best_model="qa_eval_dot_score_map@10",
    greater_is_better=True,  # Si un mayor MAP es mejor
    load_best_model_at_end=True,
)


# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=custom_domain_model,
    args=args,
    train_dataset=qa_train.select_columns(["anchor", "positive", "negative"]),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
trainer.save_model()

  0%|          | 0/575 [00:00<?, ?it/s]



The number of training epochs is set to 500 and the warm-up steps are calculated as 10% of the total training steps, determined by the length of the DataLoader and the number of epochs. This setup ensures that the model is properly prepared and evaluated during training.

## Evaluating the base model & the custom model

In [9]:
custom_domain_model = SentenceTransformer('./results/domain_adaptation_model',
                                          device=device,
                                          model_kwargs={"torch_dtype": "float16"}
                                          )

  config = torch.load(os.path.join(input_path, 'config.pt'))
  model.load_state_dict(torch.load(os.path.join(input_path, 'adapter_module.pt')))


Evaluate the Mean Average Precision (MAP) at k=10 for both the base and custom domain models using the evaluator, and print the results for comparison.

In [10]:
eva_base_model = test_evaluator(base_model, output_path='results/base_model/')
print("Base model: ", eva_base_model)

eva_custom_model = test_evaluator(custom_domain_model, output_path='results/custom_model/')
print("Custom model: ", eva_custom_model)

Base model:  {'qa_test_dot_score_accuracy@10': 0.7692031586503948, 'qa_test_dot_score_precision@10': 0.08155061019382627, 'qa_test_dot_score_recall@10': 0.6938023450586264, 'qa_test_dot_score_ndcg@10': 0.5686011104188272, 'qa_test_dot_score_mrr@10': 0.5632566175548943, 'qa_test_dot_score_map@10': 0.5110383265534019}
Custom model:  {'qa_test_dot_score_accuracy@10': 0.7114142139267767, 'qa_test_dot_score_precision@10': 0.0754845656855707, 'qa_test_dot_score_recall@10': 0.6434912658530749, 'qa_test_dot_score_ndcg@10': 0.5122504366271634, 'qa_test_dot_score_mrr@10': 0.5012685308629311, 'qa_test_dot_score_map@10': 0.45437743327370034}


Load evaluation results from CSV files for both the base and custom domain models, add a column to indicate the model type, and concatenate the results into a single DataFrame for comparison.

In [14]:
base_model_eval = pd.read_csv('results/base_model/Information-Retrieval_evaluation_qa_eval_results.csv')
base_model_eval['tipo'] = 'base_model'
custom_model_eval = pd.read_csv('results/custom_model/Information-Retrieval_evaluation_qa_eval_results.csv')
custom_model_eval['tipo'] = 'custom_model'

pd.concat([base_model_eval, custom_model_eval]).to_csv('results/eval_comparation.csv', index=False)

pd.concat([base_model_eval, custom_model_eval])


Unnamed: 0,epoch,steps,cosine-Accuracy@1,cosine-Accuracy@3,cosine-Accuracy@5,cosine-Accuracy@10,cosine-Precision@1,cosine-Recall@1,cosine-Precision@3,cosine-Recall@3,...,dot-Precision@3,dot-Recall@3,dot-Precision@5,dot-Recall@5,dot-Precision@10,dot-Recall@10,dot-MRR@10,dot-NDCG@10,dot-MAP@100,tipo
0,-1,-1,0.33988,0.476541,0.52946,0.596017,0.33988,0.009441,0.158847,0.013237,...,0.158847,0.013237,0.105892,0.014707,0.059602,0.016556,0.421933,0.102059,0.01193,base_model
1,-1,-1,0.33988,0.476541,0.52946,0.596017,0.33988,0.009441,0.158847,0.013237,...,0.158847,0.013237,0.105892,0.014707,0.059602,0.016556,0.421933,0.102059,0.01193,base_model
2,-1,-1,0.33988,0.476541,0.52946,0.596017,0.33988,0.009441,0.158847,0.013237,...,0.158847,0.013237,0.105892,0.014707,0.059602,0.016556,0.421933,0.102059,0.01193,base_model
3,-1,-1,0.33988,0.476541,0.52946,0.596017,0.33988,0.009441,0.158847,0.013237,...,0.158847,0.013237,0.105892,0.014707,0.059602,0.016556,0.421933,0.102059,0.01193,base_model
4,-1,-1,0.340698,0.476541,0.52946,0.596017,0.340698,0.009464,0.158847,0.013237,...,0.158847,0.013237,0.105892,0.014707,0.059602,0.016556,0.422303,0.102119,0.04223,base_model
5,-1,-1,0.596017,0.059602,0.016556,0.422303,0.102119,0.04223,0.596017,0.059602,...,,,,,,,,,,base_model
6,-1,-1,0.607474,0.060747,0.016874,0.428965,0.103823,0.042897,0.607474,0.060747,...,,,,,,,,,,base_model
7,-1,-1,0.597109,0.059711,0.016586,0.428617,0.10323,0.042862,0.597109,0.059711,...,,,,,,,,,,base_model
8,-1,-1,0.596017,0.059602,0.016556,0.422303,0.102119,0.04223,0.596017,0.059602,...,,,,,,,,,,base_model
9,-1,-1,0.691673,0.069167,0.019213,0.503166,0.120688,0.050317,0.691673,0.069167,...,,,,,,,,,,base_model


### Comparing QA

In [15]:
# Asumiendo que los embeddings están normalizados
question1 = "How does the FSRA define and evaluate 'principal risks and uncertainties' for a Petroleum Reporting Entity, particularly for the remaining six months of the financial year?"
answer1 =  "A Reporting Entity must: (a) prepare such report: (i) for the first six months of each financial year or period, and if there is a change to the accounting reference date, prepare such report in respect of the period up to the old accounting reference date; and (ii) in accordance with the applicable IFRS standards or other standards acceptable to the Regulator; (b) ensure the financial statements have either been audited or reviewed by auditors, and the audit or review by the auditor is included within the report; and (c) ensure that the report includes: (i) except in the case of a Mining Exploration Reporting Entity or a Petroleum Exploration Reporting Entity, an indication of important events that have occurred during the first six months of the financial year, and their impact on the financial statements; (ii) except in the case of a Mining Exploration Reporting Entity or a Petroleum Exploration Reporting Entity, a description of the principal risks and uncertainties for the remaining six months of the financial year; and (iii) a condensed set of financial statements, an interim management report and associated responsibility statements."

question2 = 'Under Rules 7.3.2 and 7.3.3, what are the two specific conditions related to the maturity of a financial instrument that would trigger a disclosure requirement?'
answer2 =  'Events that trigger a disclosure. For the purposes of Rules 7.3.2 and 7.3.3, a Person is taken to hold Financial Instruments in or relating to a Reporting Entity, if the Person holds a Financial Instrument that on its maturity will confer on him: (1) an unconditional right to acquire the Financial Instrument; or (2) the discretion as to his right to acquire the Financial Instrument.',


emb_q1 = custom_domain_model.encode(question1)  # el embedding está normalizado
emb_q2 = custom_domain_model.encode(question2)  # el embedding está normalizado
ans_1 = custom_domain_model.encode(answer1)
ans_2 = custom_domain_model.encode(answer2)


print("q1", ans_1 @ emb_q1,"(answer1) --", ans_2 @ emb_q1, "(answer2)")
print("q2", ans_1 @ emb_q2, "(answer1) --", ans_2 @ emb_q2, "(answer2)")


print("------ Base Model ------")

emb_q1 = base_model.encode(question1)  # el embedding está normalizado
emb_q2 = base_model.encode(question2)  # el embedding está normalizado
ans_1 = base_model.encode(answer1)
ans_2 = base_model.encode(answer2)


print("q1", ans_1 @ emb_q1,"(answer1) --", ans_2 @ emb_q1, "(answer2)")
print("q2", ans_1 @ emb_q2, "(answer1) --", ans_2 @ emb_q2, "(answer2)")


q1 0.8037261 (answer1) -- [0.761284] (answer2)
q2 0.7965807 (answer1) -- [0.8650476] (answer2)
------ Base Model ------
q1 0.6166147 (answer1) -- [0.4293761] (answer2)
q2 0.55304617 (answer1) -- [0.7028685] (answer2)


### The custom model mantain original capabilities

Encodes sample text inputs, including the title of an article, author names, and various concepts, using both the custom domain model and the base model. Also, the dot product between the coded vectors is calculated to measure the similarity between different pairs of concepts and between the paper and a concept. Print the similarity scores for each comparison to see the differences. 

In [16]:
paper = "Composable Lightweight Processors"

concept1 = "shark"
concept2 = "ocean"
concept3 = "strawberry"

In [17]:
custom_paper = custom_domain_model.encode(paper)

custom_concept1 = custom_domain_model.encode(concept1)
custom_concept2 = custom_domain_model.encode(concept2)
custom_concept3 = custom_domain_model.encode(concept3)

# Imprimir los resultados y explicaciones
print(f"Producto punto entre dos conceptos (shark y ocean): {np.dot(custom_concept1, custom_concept2)}")
print(f"Producto punto entre dos conceptos (shark y strawberry): {np.dot(custom_concept1, custom_concept3)}")
print(f"Producto punto entre el documento y un concepto (ocean): {np.dot(custom_paper, custom_concept2)}")

Producto punto entre dos conceptos (shark y ocean): 0.7008720636367798
Producto punto entre dos conceptos (shark y strawberry): 0.6190972328186035
Producto punto entre el documento y un concepto (ocean): 0.4101462662220001


In [18]:
base_paper = base_model.encode(paper)

base_concept1 = base_model.encode(concept1)
base_concept2 = base_model.encode(concept2)
base_concept3 = base_model.encode(concept3)  

# Imprimir los resultados y explicaciones
print(f"Producto punto entre dos conceptos (shark y ocean): {np.dot(base_concept1, base_concept2)}")
print(f"Producto punto entre dos conceptos (shark y strawberry): {np.dot(base_concept1, base_concept3)}")
print(f"Producto punto entre el documento y un concepto (ocean): {np.dot(base_paper, base_concept2)}")

Producto punto entre dos conceptos (shark y ocean): 0.5527569055557251
Producto punto entre dos conceptos (shark y strawberry): 0.27426061034202576
Producto punto entre el documento y un concepto (ocean): -0.05138666182756424
