In [2]:
# Show the GPU(s) visible to this machine (driver/CUDA version, memory usage) before training.
!nvidia-smi

Fri Oct 28 22:55:03 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.08    Driver Version: 510.73.08    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   27C    P0    23W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
import mlflow
import os

In [2]:
import os
# Read the Elasticsearch host from the environment when it is configured.
# os.environ is keyed by strings; indexing it directly is safe after the membership check.
if 'ELASTICSEARCH_HOST' in os.environ:
    host = os.environ['ELASTICSEARCH_HOST']
    print(host)
else:
    print("ELASTICSEARCH_HOST host does not configured")

ELASTICSEARCH_HOST host does not configured


In [None]:
# Prefer the host from the environment; fall back to the previously hard-coded address only
# when ELASTICSEARCH_HOST is unset, instead of always overwriting the configured value.
host = os.environ.get('ELASTICSEARCH_HOST', '3.93.144.226')

In [1]:
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.pipelines import ExtractiveQAPipeline

  from .autonotebook import tqdm as notebook_tqdm
  if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):


In [2]:
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

In [None]:
# Create the MLflow experiment on first use, with its artifacts stored under the S3 prefix below.
# On later runs, reuse it: mlflow.create_experiment raises if the name already exists.
experiment_name = "domain-adaption"  # experiment name (do not replace)
s3_bucket = "s3://domain-qa-system/mlruns"  # artifact location; replace this value
if mlflow.get_experiment_by_name(experiment_name) is None:
    mlflow.create_experiment(experiment_name, artifact_location=s3_bucket)
mlflow.set_experiment(experiment_name)

In [None]:
# Start an MLflow run, or reuse the one that is already active if this cell is re-executed.
# mlflow.start_run() raises when a run is already active in this process.
current_run = mlflow.active_run()
if current_run is None:
    current_run = mlflow.start_run()
current_run

In [None]:
# Elasticsearch index settings. They are recorded as MLflow params so each run documents its store config.
index = 'bioasq'
similarity = "cosine"
embedding_dim = 768  # presumably the TAS-B (DistilBERT) embedding size — confirm if the encoder changes

es_params = {
    'es_index': index,
    'es_similarity': similarity,
    'es_embedding_dim': embedding_dim,
}
mlflow.log_params(es_params)

In [6]:
# Connect the Haystack document store to Elasticsearch using the index settings above.
# Credentials are taken from the environment so they never have to be written into the notebook.
# They default to the empty strings this cell previously hard-coded.
document_store = ElasticsearchDocumentStore(
    host=host,
    username=os.environ.get('ELASTICSEARCH_USERNAME', ''),
    password=os.environ.get('ELASTICSEARCH_PASSWORD', ''),
    index=index,
    similarity=similarity,
    embedding_dim=embedding_dim
)

In [None]:

# Record the retriever configuration in the MLflow run.
# The Elasticsearch params were already logged together with the index settings.
# The EmbeddingRetriever itself is built later, after the corpus has been loaded.
# Building it here failed because max_seq_length was not yet defined.
retriever_model = "sentence-transformers/msmarco-distilbert-base-tas-b"
max_seq_length = 200  # keep equal to the max_seq_length used for the baseline TAS-B model

mlflow.log_params({
    'retriever_model': retriever_model,
    'retriever_max_seq_len': max_seq_length,
})

In [3]:
# Baseline: the original TAS-B bi-encoder trained on MS MARCO, kept un-adapted for comparison.
model_name = "msmarco-distilbert-base-tas-b"
max_seq_length = 200  # token limit, also passed to the EmbeddingRetriever as max_seq_len

org_model = SentenceTransformer(model_name)
org_model.max_seq_length = max_seq_length

For retriever adaptation, we don't need to specify the queries and documents up front:
- query: provided by the user at search time
- documents: first ingested into Elasticsearch, then retrieved by the adapted retriever
- QA: answered by an ExtractiveQAPipeline

In [4]:
# Sanity check on a toy example: rank a few sentences about how diseases are transmitted
# against a COVID-19 query. TAS-B was trained on rather outdated data (2018 and older), so it
# has no idea about COVID-19. In the example below it fails to connect "COVID-19" with "Corona".

def show_examples(model, query=None, docs=None):
    """Print `docs` ranked by dot-product similarity to `query` under `model`.

    Args:
        model: encoder exposing an ``encode`` method (a SentenceTransformer here).
        query: query string; defaults to "How is COVID-19 transmitted".
        docs: candidate sentences; defaults to four disease-transmission sentences,
            of which only the "Corona" one is about COVID-19.

    Prints the query, then one ``score<TAB>doc`` line per document, highest score first.
    """
    if query is None:
        query = "How is COVID-19 transmitted"
    if docs is None:
        docs = [
            "Corona is transmitted via the air",
            "Ebola is transmitted via direct contact with blood",
            "HIV is transmitted via sex or sharing needles",
            "Polio is transmitted via contaminated water or food",
        ]

    query_emb = model.encode(query)
    docs_emb = model.encode(docs)
    # dot_score yields one row of scores per query; take the single row as plain floats.
    scores = util.dot_score(query_emb, docs_emb)[0].tolist()
    doc_scores = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)

    print("Query:", query)
    for doc, score in doc_scores:
        print(f"{score:0.02f}\t{doc}")


print("Original Model")
show_examples(org_model)

Original Model
Query: How is COVID-19 transmitted
94.84	Ebola is transmitted via direct contact with blood
92.87	HIV is transmitted via sex or sharing needles
92.31	Corona is transmitted via the air
91.54	Polio is transmitted via contaminated water or food


In [5]:
dataset = load_dataset("nreimers/trec-covid", split="train")
num_documents = 100

# Collect up to `num_documents` COVID-related papers that have a non-trivial title and text.
# The dataset also contains many papers on other diseases. Restricting to COVID keeps the
# training in this demo small and on-topic.
corpus = []
for row in dataset:
    long_enough = len(row["title"]) > 20 and len(row["text"]) > 100
    if not long_enough:
        continue

    text = row["title"] + " " + row["text"]
    text_lower = text.lower()
    if any(keyword in text_lower for keyword in ("covid", "corona", "sars-cov-2")):
        corpus.append(text)

    if len(corpus) >= num_documents:
        break

print("Len Corpus:", len(corpus))



Len Corpus: 100


In [7]:
# Alternative document store: FAISS with a "Flat" index and cosine similarity.
# Left disabled; this notebook writes the corpus to the Elasticsearch store instead.
# NOTE(review): the "Writing Documents: 10000it" output below appears to be from an earlier run
# when this cell was active — confirm and clear it.
# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", similarity="cosine")
# document_store.write_documents([{'content': t} for t in corpus])        

Writing Documents: 10000it [00:00, 20979.69it/s]                                                                                                                                                                            


In [7]:
# Bi-encoder retriever over the Elasticsearch store. It uses the TAS-B checkpoint, presumably the
# same weights `org_model` loads, and is the model that gets adapted to the COVID corpus.
retriever = EmbeddingRetriever(
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
    document_store=document_store,
    max_seq_len=max_seq_length,
    progress_bar=False,
)

In [9]:
# Write the selected corpus into the store, then compute document embeddings with the retriever.
corpus_docs = [{'content': passage} for passage in corpus]
document_store.write_documents(corpus_docs)
document_store.update_embeddings(retriever)

Updating embeddings: 10000 Docs [00:02, 4769.77 Docs/s]                                                                                                                                                                     


In [10]:
# Question generator backed by the doc2query T5 model. It produces synthetic questions
# from document passages, num_queries_per_doc per passage.
question_producer = QuestionGenerator(
    model_name_or_path="doc2query/msmarco-t5-base-v1",
    num_queries_per_doc=3,
    batch_size=32,
    split_length=128,
    max_length=64,
)

Using sep_token, but it is not set yet.


In [11]:
# GPL pseudo-label generator. It generates questions with `question_producer`, and the
# retriever presumably supplies the top_k candidates from which negatives are mined.
psg = PseudoLabelGenerator(
    retriever=retriever,
    question_producer=question_producer,
    top_k=10,
    batch_size=32,
    max_questions_per_document=10,
)

In [12]:
# Generate GPL pseudo-labels for every document currently in the index.
# NOTE(review): get_all_documents() returns the whole index, which may include documents from
# earlier runs (the outputs above show 10000 docs while num_documents is 100) — confirm intended.
all_indexed_docs = document_store.get_all_documents()
output, pipe_id = psg.run(documents=all_indexed_docs)

Generating questions: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 508/508 [00:21<00:00, 23.16it/s]
Mine negatives: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:07<00:00,  3.08it/s]
Score margin: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:03<00:00,  5.92it/s]


Each pseudo-label pairs a generated question with a positive document and a mined negative document.

In [13]:
# Inspect the first generated pseudo-label (question plus its documents).
output["gpl_labels"][0]

{'question': 'is kawasaki disease polymorphic',
 'pos_doc': "Insertion/Deletion Polymorphism of Angiotensin Converting Enzyme Gene in Kawasaki Disease Polymorphism of angiotensin converting enzyme (ACE) gene is reported to be associated with ischemic heart disease, hypertrophic cardiomyopathy, and idiopathic dilated cardiomyopathy. In this study, we investigated the relationship between Kawasaki disease and insertion/deletion polymorphism of ACE gene. Fifty five Kawasaki disease patients and 43 healthy children were enrolled. ACE genotype was evaluated from each of the subjects' DNA fragments through polymerase chain reaction (PCR). Frequencies of ACE genotypes (DD, ID, II) were 12.7%, 60.0%, 27.3% in Kawasaki group, and 41.9%, 30.2%, 27.9% in control group respectively, indicating low rate of DD and high rate of ID genotype among Kawasaki patients (p<0.01). Comparing allelic (I, D) frequencies, I allele was more prevalent in Kawasaki group than in control group (57.3% vs. 43.0%, p<0.0

In [15]:
# Number of pseudo-labels available for fine-tuning the retriever.
len(output["gpl_labels"])

713

In [16]:
# Fine-tune the retriever's embedding model on the GPL pseudo-labels.
# NOTE(review): no training hyperparameters are passed (library defaults) and none are logged to MLflow.
retriever.train(output["gpl_labels"])

Epoch:   0%|                                                                                                                                                                                          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|                                                                                                                                                                                     | 0/44 [00:00<?, ?it/s][A
Iteration:   2%|███▉                                                                                                                                                                         | 1/44 [00:00<00:09,  4.60it/s][A
Iteration:   5%|███████▊                                                                                                                                                                     | 2/44 [00:00<00:09,  4.63it/s][A
Iteration:   7%|███████████▊                                                                               

In [17]:
# Compare rankings on the toy COVID-19 query before and after domain adaptation.
for heading, model in (
    ("Original Model", org_model),
    ("\n\nAdapted Model", retriever.embedding_encoder.embedding_model),
):
    print(heading)
    show_examples(model)

Original Model
Query: How is COVID-19 transmitted
94.84	Ebola is transmitted via direct contact with blood
92.87	HIV is transmitted via sex or sharing needles
92.31	Corona is transmitted via the air
91.54	Polio is transmitted via contaminated water or food


Adapted Model
Query: How is COVID-19 transmitted
89.12	Ebola is transmitted via direct contact with blood
85.26	HIV is transmitted via sex or sharing needles
82.79	Corona is transmitted via the air
81.51	Polio is transmitted via contaminated water or food


In [None]:
retriever.save()

In [None]:
mlflow.log_artifacts()

In [None]:
# Close the active MLflow run and mark it as finished.
mlflow.end_run(status="FINISHED")