## Preparations

### Install Dependencies

First, install the dependencies, including towhee, towhee.models, gradio, and datasets.

In [3]:
!python -m pip install -q towhee towhee.models  gradio "pydantic>=1.4.0,<2" datasets ipywidgets 

### Prepare the Data

In [4]:
from datasets import load_dataset
from IPython.display import clear_output
dataset = load_dataset("ruslanmv/ai-medical-chatbot")
clear_output()

In [5]:
train_data = dataset["train"]
# Inspect the first record
print(train_data[0])

{'Description': 'Q. What does abutment of the nerve root mean?', 'Patient': 'Hi doctor,I am just wondering what is abutting and abutment of the nerve root means in a back issue. Please explain. What treatment is required for\xa0annular bulging and tear?', 'Doctor': 'Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->'}


For this demo, let's use the first 1,000 dialogues.

In [6]:
import pandas as pd
df = pd.DataFrame(train_data[:1000])

In [7]:
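# Use the short 'Description' field as the question (the commented line uses the full 'Patient' text instead)
#df = df[["Patient", "Doctor"]].rename(columns={"Patient": "question", "Doctor": "answer"})
df = df[["Description", "Doctor"]].rename(columns={"Description": "question", "Doctor": "answer"})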

In [8]:
# Add an 'id' column (copied from the row index) as the first column
df.insert(0, 'id', df.index)
# Reset to a clean 0..n-1 index without keeping the old index as a column
df = df.reset_index(drop=True)

In [9]:
import re
# Collapse runs of whitespace into single spaces and trim both columns
df['question'] = df['question'].apply(lambda x: re.sub(r'\s+', ' ', x.strip()))
df['answer'] = df['answer'].apply(lambda x: re.sub(r'\s+', ' ', x.strip()))
# Remove the leading "Q." prefix (escaped dot) and the space after it
df['question'] = df['question'].str.replace(r'^Q\.\s*', '', regex=True)
# Truncate long questions, because the embedding model limits input length
max_length = 500
df['question'] = df['question'].str.slice(0, max_length)
#df['answer'] = df['answer'].str.slice(0, max_length)
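To see what the cleaning cell does to a single record, here is a minimal sketch that applies the same three steps with only the standard `re` module: collapse whitespace, strip the `Q.` prefix, and truncate. The helper name `clean_question` and the sample strings are hypothetical, chosen for illustration.

```python
import re

MAX_LENGTH = 500  # same limit as max_length in the cell above


def clean_question(text: str, max_length: int = MAX_LENGTH) -> str:
    """Hypothetical helper mirroring the DataFrame cleaning steps for one string."""
    # Collapse every run of whitespace (tabs, newlines, non-breaking spaces) into one space
    text = re.sub(r'\s+', ' ', text.strip())
    # Drop a leading "Q." prefix plus the space that follows it
    text = re.sub(r'^Q\.\s*', '', text)
    # Keep at most max_length characters
    return text[:max_length]


print(clean_question("Q.  What does abutment of\nthe nerve root   mean?"))
# What does abutment of the nerve root mean?
```

Escaping the dot matters: with the unescaped pattern `^Q.`, a question starting with "Quick" would lose its first two characters, while `^Q\.` only removes a literal "Q.".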

In [10]:
df.to_csv("question_answer.csv", sep='\t', encoding='utf-8', index=False)

**question_answer.csv**: a tab-separated file containing each question's ID, the question, and its answer.

Let's take a quick look:

In [11]:
import pandas as pd
# Load the Pandas DataFrame
df = pd.read_csv('question_answer.csv', sep='\t', encoding='utf-8')
df.head()

| | id | question | answer |
|---|---|---|---|
| 0 | 0 | What does abutment of the nerve root mean? | Hi. I have gone through your query with dilige... |
| 1 | 1 | What should I do to reduce my weight gained d... | Hi. You have really done well with the hypothy... |
| 2 | 2 | I have started to get lots of acne on my face... | Hi there Acne has multifactorial etiology. Onl... |
| 3 | 3 | Why do I have uncomfortable feeling between t... | Hello. The popping and discomfort what you fel... |
| 4 | 4 | My symptoms after intercourse threatns me eve... | Hello. The HIV test uses a finger prick blood ... |


To map search results back to answers, let's first define a dictionary:

- `id_answer`: maps each question ID to its corresponding answer

In [12]:
id_answer = df.set_index('id')['answer'].to_dict()
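Because the `id` field of the Milvus collection (defined below) is a `VARCHAR`, the IDs that come back from a search are strings, while `id_answer` uses integer keys read from the CSV; this is why the lookup later in the notebook calls `int(...)`. The sketch below builds an equivalent dictionary with the standard `csv` module from a small in-memory TSV. The two rows are placeholders, not real data.

```python
import csv
import io

# Placeholder two-row sample in the same tab-separated layout as question_answer.csv
tsv_text = (
    "id\tquestion\tanswer\n"
    "0\tWhat does abutment of the nerve root mean?\tAnswer for row 0\n"
    "1\tA second sample question?\tAnswer for row 1\n"
)

reader = csv.DictReader(io.StringIO(tsv_text), delimiter='\t')
# csv yields strings, so convert each id to int to match the pandas-built dictionary
id_answer = {int(row['id']): row['answer'] for row in reader}

# A search on a VARCHAR primary key returns string IDs such as '1'
milvus_id = '1'
print(id_answer[int(milvus_id)])  # Answer for row 1
```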

### Create Milvus Collection

Before getting started, make sure that a [Milvus service](https://milvus.io/docs/install_standalone-docker.md) is running. This notebook uses [Milvus 2.2.10](https://milvus.io/docs/v2.2.x/install_standalone-docker.md) and [pymilvus 2.2.11](https://milvus.io/docs/release_notes.md#2210).

In [15]:
!python -m pip install -q pymilvus==2.2.11 python-dotenv


Next, we define the function `create_milvus_collection`, which creates a Milvus collection that uses the [L2 distance metric](https://milvus.io/docs/metric.md#Euclidean-distance-L2) and an [IVF_FLAT index](https://milvus.io/docs/index.md#IVF_FLAT).

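### Set Up the Remote Server
Define the environment variable `REMOTE_SERVER` with the address of the Milvus server created by following [these instructions](https://github.com/ruslanmv/Watsonx-Assistant-with-Milvus-as-Vector-Database/blob/master/README.md). It can also be stored in a `.env` file, which `load_dotenv()` reads; if the variable is not set, the notebook connects to `localhost`.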

In [16]:
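from dotenv import load_dotenv
import os
# Load variables from a .env file (if present) into the environment
load_dotenv()
# Milvus host: REMOTE_SERVER if set, otherwise a local instance
host_milvus = os.environ.get("REMOTE_SERVER", "localhost")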

In [17]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
connections.connect(host=host_milvus, port='19530')

In [18]:
collection_name="qa_medical"
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.VARCHAR, description='ids', max_length=500, is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='medical question answering')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
dim_collection=768
collection = create_milvus_collection(collection_name,dim_collection )
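IVF_FLAT partitions the vectors into `nlist` clusters and, at query time, compares the query against the raw (uncompressed) vectors of only the `nprobe` clusters whose centroids are closest. The NumPy sketch below is a toy illustration of that idea, not Milvus's implementation: the functions `build_ivf` and `search_ivf`, the plain k-means loop, and the random data are all assumptions made for illustration.

```python
import numpy as np


def build_ivf(vectors: np.ndarray, nlist: int, iters: int = 10, seed: int = 0):
    """Toy inverted-file index: k-means centroids plus per-cluster member lists."""
    rng = np.random.default_rng(seed)
    # Initialize centroids with nlist distinct data points
    centroids = vectors[rng.choice(len(vectors), size=nlist, replace=False)].copy()
    for _ in range(iters):
        # Assign each vector to its nearest centroid (squared L2 distance)
        d2 = ((vectors[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        assign = d2.argmin(axis=1)
        # Move each non-empty cluster's centroid to the mean of its members
        for c in range(nlist):
            members = vectors[assign == c]
            if len(members):
                centroids[c] = members.mean(axis=0)
    # Final assignment with the updated centroids
    d2 = ((vectors[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    assign = d2.argmin(axis=1)
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists


def search_ivf(query, vectors, centroids, lists, nprobe: int, limit: int = 1):
    """Scan only the nprobe closest clusters and rank their vectors by L2 distance."""
    probe = np.argsort(((centroids - query) ** 2).sum(axis=1))[:nprobe]
    candidates = np.concatenate([lists[c] for c in probe])
    dists = np.sqrt(((vectors[candidates] - query) ** 2).sum(axis=1))
    order = np.argsort(dists)[:limit]
    return [(int(candidates[i]), float(dists[i])) for i in order]


rng = np.random.default_rng(42)
data = rng.normal(size=(1000, 16)).astype(np.float32)
centroids, lists = build_ivf(data, nlist=8)
# Querying with a stored vector returns that vector's own index at distance 0
print(search_ivf(data[123], data, centroids, lists, nprobe=2))
```

A larger `nprobe` scans more clusters and trades speed for recall; in this toy, setting `nprobe` equal to `nlist` scans every cluster, so the result matches an exhaustive search.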

### Load question embedding into Milvus

We first generate embeddings from the question text with the [dpr](https://towhee.io/text-embedding/dpr) operator and insert them into Milvus. Towhee provides a [method-chaining style API](https://towhee.readthedocs.io/en/main/index.html) so that users can assemble a data processing pipeline from operators.
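To make the method-chaining idea concrete, here is a minimal pure-Python sketch in which each `map` stage reads named input fields, applies a function, and writes named output fields. The class `ToyPipe` is hypothetical and only loosely modeled on the `pipe.input(...).map(...).output(...)` style used below; it is not Towhee's implementation.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

Names = Union[str, Tuple[str, ...]]


def _as_tuple(names: Names) -> Tuple[str, ...]:
    return (names,) if isinstance(names, str) else tuple(names)


class ToyPipe:
    """Hypothetical mini pipeline: named inputs -> map stages -> selected outputs."""

    def __init__(self, *inputs: str):
        self.inputs = inputs
        self.stages: List[Tuple[Tuple[str, ...], Tuple[str, ...], Callable[..., Any]]] = []
        self.outputs: Tuple[str, ...] = ()

    def map(self, src: Names, dst: Names, fn: Callable[..., Any]) -> "ToyPipe":
        self.stages.append((_as_tuple(src), _as_tuple(dst), fn))
        return self  # returning self is what enables method chaining

    def output(self, *names: str) -> "ToyPipe":
        self.outputs = names
        return self

    def __call__(self, *values: Any) -> Tuple[Any, ...]:
        row: Dict[str, Any] = dict(zip(self.inputs, values))
        for src, dst, fn in self.stages:
            result = fn(*(row[name] for name in src))
            # One destination stores the result as-is; several unpack a tuple
            results = (result,) if len(dst) == 1 else result
            row.update(zip(dst, results))
        return tuple(row[name] for name in self.outputs)


word_pipe = (
    ToyPipe('id', 'question')
        .map('question', 'words', str.split)
        .map(('id', 'words'), 'summary', lambda i, w: f'{i}:{len(w)}')
        .output('question', 'summary')
)
print(word_pipe('0', 'What does abutment mean?'))  # ('What does abutment mean?', '0:4')
```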

In [19]:
from IPython.display import clear_output

In [20]:
%%time
from towhee import pipe, ops
import numpy as np
from towhee.datacollection import DataCollection
from IPython.display import clear_output
insert_pipe = (
    pipe.input('id', 'question', 'answer')
        .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host=host_milvus, port='19530', collection_name=collection_name))
        .output()
)
import csv
with open('question_answer.csv', encoding='utf-8') as f:
    reader = csv.reader(f, delimiter='\t')
    next(reader)  # skip header
    for row in reader:
        insert_pipe(*row)
clear_output()

CPU times: user 14.3 s, sys: 1.83 s, total: 16.2 s
Wall time: 3min 20s


In [26]:
print('Total number of inserted data is {}.'.format(collection.num_entities))

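Total number of inserted data is 0.


The count most likely reads 0 because `num_entities` only reflects data that has been flushed to storage; calling `collection.flush()` before reading `collection.num_entities` reports the inserted rows.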


#### Explanation of Data Processing Pipeline

Here is a detailed explanation of each line of the code:

`pipe.input('id', 'question', 'answer')`: takes three inputs, namely the question's ID, the question's text, and the question's answer;

`map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))`: uses the `facebook/dpr-ctx_encoder-single-nq-base` model to generate the question embedding vector with the [dpr operator](https://towhee.io/text-embedding/dpr) from the Towhee hub;

`map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))`: normalizes the embedding vector to unit L2 norm;

`map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host=host_milvus, port='19530', collection_name=collection_name))`: inserts the question ID and embedding vector into Milvus.
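Normalizing matters because the collection uses the L2 metric: for unit vectors `a` and `b`, `||a - b||^2 = 2 - 2 * (a · b)`, so ranking by L2 distance gives the same order as ranking by cosine similarity. The NumPy sketch below checks this identity on random vectors; the 768 dimension matches `dim_collection`, and the data is random rather than real embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 768  # same dimension as dim_collection
a, b, c = rng.normal(size=(3, dim))


def normalize(x: np.ndarray) -> np.ndarray:
    # Same operation as the pipeline's lambda: divide by the vector's L2 norm
    return x / np.linalg.norm(x, axis=0)


a, b, c = normalize(a), normalize(b), normalize(c)

l2_sq = np.sum((a - b) ** 2)
cosine = float(np.dot(a, b))
print(np.isclose(l2_sq, 2 - 2 * cosine))  # True

# Whichever of b, c is closer to a by L2 is also more similar to a by cosine
closer_by_l2 = 'b' if np.linalg.norm(a - b) < np.linalg.norm(a - c) else 'c'
closer_by_cos = 'b' if np.dot(a, b) > np.dot(a, c) else 'c'
print(closer_by_l2 == closer_by_cos)  # True
```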

### Ask Question with Milvus and Towhee

Now that the question embeddings have been inserted into Milvus, we can ask questions with Milvus and Towhee. Again, we use Towhee to load the input question, compute an embedding, and use it as a query in Milvus. Because Milvus only returns IDs and distances, we use the `id_answer` dictionary to look up and display the answer for each ID.
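Conceptually, the query path embeds the question, finds the stored vector with the smallest L2 distance, and maps its ID back to an answer. The sketch below reproduces that flow with an exact (brute-force) NumPy search standing in for Milvus and random vectors standing in for DPR embeddings; `search`, `stored_ids`, and the placeholder answers are hypothetical. Each result is an `[id, distance]` pair with a string ID, the shape the notebook's lambda expects when it reads `i[0]` and converts it with `int(...)`.

```python
import numpy as np

rng = np.random.default_rng(1)
# Random stand-ins for normalized question embeddings, keyed by string IDs (as in the VARCHAR field)
stored_ids = ['0', '1', '2']
stored = rng.normal(size=(3, 768))
stored /= np.linalg.norm(stored, axis=1, keepdims=True)
id_answer = {0: 'answer 0', 1: 'answer 1', 2: 'answer 2'}


def search(vec: np.ndarray, limit: int = 1):
    """Exact L2 search returning [id, distance] pairs, closest first."""
    dists = np.linalg.norm(stored - vec, axis=1)
    order = np.argsort(dists)[:limit]
    return [[stored_ids[i], float(dists[i])] for i in order]


# A query close to stored vector 2 (small random perturbation, then re-normalized)
query = stored[2] + 0.01 * rng.normal(size=768)
query /= np.linalg.norm(query)

res = search(query, limit=1)
answers = [id_answer[int(i[0])] for i in res]  # same lookup as the pipeline's lambda
print(answers)  # ['answer 2']
```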

In [27]:
%%time
collection.load()
ans_pipe = (
    pipe.input('question')
        .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map('vec', 'res', ops.ann_search.milvus_client(host=host_milvus, port='19530', collection_name=collection_name, limit=1))
        .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
        .output('question', 'answer')
)
ans = ans_pipe('What does abutment of the nerve root mean?')
ans = DataCollection(ans)
clear_output()

CPU times: user 15.6 ms, sys: 188 ms, total: 203 ms
Wall time: 32.9 s


In [28]:
ans.show()

| question | answer |
|---|---|
| What does abutment of the nerve root mean? | Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information... |


Then we can get the answer to 'What does abutment of the nerve root mean?':

In [29]:
ans[0]['answer']

['Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->']

## Release a Showcase

The core functionality of our question answering engine is now in place, so it's time to build a showcase with an interface. [Gradio](https://gradio.app/) is a great tool for building demos. With Gradio, we simply need to wrap the data processing pipeline in a `chat` function:

In [30]:
# Build the question-answering pipeline once and reuse it for every chat message
chat_pipe = (
    pipe.input('question')
        .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map('vec', 'res', ops.ann_search.milvus_client(host=host_milvus, port='19530', collection_name=collection_name, limit=1))
        .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
        .output('question', 'answer')
)

def chat(message, history):
    # Gradio passes the conversation history in through the 'state' input
    history = history or []
    # get() returns [question, answer]; answer is a list holding the top match
    response = chat_pipe(message).get()[1][0]
    history.append((message, response))
    # Return the history twice: once for the Chatbot display, once for the state
    return history, history

In [31]:
chat('What does abutment of the nerve root mean?',[])

([('What does abutment of the nerve root mean?',
   'Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->')],
 [('What does abutment of the nerve root mean?',
   'Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->')])
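The `chat` function keeps the conversation as a list of `(message, response)` tuples that Gradio feeds back in through the `state` input. To exercise that bookkeeping without Milvus or the DPR model, the sketch below swaps the Towhee pipeline for a hypothetical `fake_answer` lookup with a placeholder answer; the history handling follows the notebook's `chat` function. It assumes the state may arrive as `None` on the first call, which is why `history or []` is there.

```python
from typing import Callable, List, Optional, Tuple

History = List[Tuple[str, str]]


def fake_answer(message: str) -> str:
    # Hypothetical stand-in for the pipeline call that returns the top answer
    canned = {'What does abutment of the nerve root mean?': 'answer 0'}
    return canned.get(message, 'No matching answer found.')


def chat(message: str, history: Optional[History],
         answer_fn: Callable[[str], str] = fake_answer) -> Tuple[History, History]:
    history = history or []  # start a new conversation if no state was passed in
    history.append((message, answer_fn(message)))
    # Returned twice: once for the Chatbot display, once to update the state
    return history, history


display, state = chat('What does abutment of the nerve root mean?', None)
display, state = chat('A follow-up question', state)
print(len(state))  # 2
```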

In [None]:
import gradio
collection.load()
chatbot = gradio.Chatbot()
interface = gradio.Interface(
    chat,
    ["text", "state"],
    [chatbot, "state"],
    allow_flagging="never",
)
# Display the demo inline; calling clear_output() after launch would erase the embedded interface
interface.launch(inline=True, share=False)