# SETUP

Install Libraries

In [None]:
!pip install frontend gdown pandas tqdm
!pip install tqdm
# !pip install torchvision
# !pip install torchaudio

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Download Scripts for processing ADs and NTSB Reports

In [None]:
!pip install PyPDF2

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Reading PDFs

In [None]:
import os

# new_reports directory should contain all the pdf files of ADs/reports directly
# exist_ok=True keeps the cell re-runnable: a plain `mkdir` errors out when the
# directory is already there.
os.makedirs("new_reports", exist_ok=True)

mkdir: cannot create directory ‘new_reports’: File exists


In [None]:
from PyPDF2 import PdfReader

import os

# Folder that holds the report PDFs (flat, no sub-directories are scanned).
REPORTS_DIR = "new_reports"

# Each passage is one ".\n"-delimited chunk of a page, flattened to a single line
# and suffixed with ". doc_id: <file name without extension>".
passages=[]
# Sort the listing: os.listdir returns entries in arbitrary order, and the
# embeddings saved later are only valid while `passages` keeps the same order.
for pdf_name in sorted(os.listdir(REPORTS_DIR)):
    if not pdf_name.lower().endswith(".pdf"):
        continue
    doc_name = os.path.splitext(pdf_name)[0]
    reader = PdfReader(os.path.join(REPORTS_DIR, pdf_name))
    for page in reader.pages:
        # Guard against pages that yield no extractable text.
        page_text = page.extract_text() or ""
        for passage in page_text.split(".\n"):
            # Skip blank chunks; they would become passages holding only the doc_id.
            if not passage.strip():
                continue
            passages.append(passage.replace("\n"," ")+". "+ 'doc_id: '+doc_name)

In [None]:
import os
import pandas as pd
from tqdm import tqdm

In [None]:
# LIMIT_DOCS=1000
# Copy the passages (with a progress bar) and give each one a sequential
# integer id equal to its position; `pid` ends up as the number of passages.
doc_content=[para for para in tqdm(passages)]
doc_id=list(range(len(doc_content)))
pid=len(doc_content)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:00<00:00, 425659.56it/s]


Download the bi-encoder and index (embed) the passages with Sentence-Transformers

In [None]:
!pip install sentence_transformers

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer, util

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Bi-encoder checkpoint used to turn text into dense embeddings for retrieval.
bi_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')


In [None]:
# Embed every extracted passage with the bi-encoder; row i corresponds to passages[i].
doc_all_embs = bi_model.encode(passages)


In [None]:
from numpy import save
# Persist the passage embeddings as a .npy file so inference can reload them
# instead of re-encoding the corpus.
embeddings_path = 'passage_embeddings_multi-qa-mpnet-base-dot-v1.npy'
save(embeddings_path, doc_all_embs)


# INFERENCE

Retrieve indexed documents

In [None]:
from numpy import load
# Reload the passage embeddings written during setup.
embeddings_path = 'passage_embeddings_multi-qa-mpnet-base-dot-v1.npy'
doc_all_embs = load(embeddings_path)

Load BERT (Fine-tuned on MRC) from Huggingface

In [None]:
import torch
from transformers import BertForQuestionAnswering, AutoTokenizer, pipeline

modelname = 'deepset/bert-base-cased-squad2'
# Announce the step before the (potentially slow) checkpoint load starts.
print("Getting fine-tuned BERT on MRC...")
qa_model = BertForQuestionAnswering.from_pretrained(modelname)
qa_tokenizer = AutoTokenizer.from_pretrained(modelname)
# Run on the first GPU when one is available; device=-1 selects the CPU.
qa_device = 0 if torch.cuda.is_available() else -1
qa_pipeline = pipeline('question-answering', model=qa_model, tokenizer=qa_tokenizer, device=qa_device)

Getting fine-tuned BERT on MRC...


Load BLOOM 7b1 from Hugging Face (used when the API doesn't work)

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
# Make CUDA float tensors the default tensor type for everything created after
# this line, so it must run before the model is loaded.
# NOTE(review): this is a process-wide setting that needs a CUDA device; newer
# PyTorch releases appear to deprecate it in favour of torch.set_default_device — confirm.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
def load_bloom(model_name="bigscience/bloom-3b"):
    """Load a causal language model and its tokenizer via ``from_pretrained``.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple — the tokenizer comes first.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer,model
# The default argument is overridden here: the bloom-7b1 checkpoint is the one loaded.
bloom_tokenizer,bloom_model=load_bloom("bigscience/bloom-7b1")

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:14<00:00,  7.01s/it]


### Set Question and BLOOM Key (key obtained from Hugging Face)

In [None]:
# Questions to ask; `query` is the first entry of the list.
queries=["List the accident numbers of accidents happened in atlanta"]
query=queries[0]
# Candidate counts intended for the bi-encoder (first-stage retrieval) and the
# cross-encoder (re-ranking) steps.
top_k_bi_encoder=15
top_k_cross_encoder=3

#### Global Query

In [None]:
def get_top_passages(queries, passages,bi_model, cross_model, top_k_bi_encoder=20,top_k_cross_encoder=6, doc_embs=None):
    """Retrieve and re-rank the best passages for the first query in ``queries``.

    Stage 1 scores every passage embedding against the query embedding with a
    dot product and keeps the ``top_k_bi_encoder`` best. Stage 2 scores each
    surviving ``[query, passage]`` pair with the cross-encoder and keeps the
    ``top_k_cross_encoder`` best.

    Args:
        queries: List of query strings; only ``queries[0]`` is answered.
        passages: List of passage strings, in the same order as the rows of
            the passage embeddings.
        bi_model: Object with an ``encode(list_of_str)`` method.
        cross_model: Object with a ``predict(list_of_pairs)`` method returning
            an array of scores.
        top_k_bi_encoder: Number of candidates kept after stage 1; clamped to
            the number of available passages.
        top_k_cross_encoder: Number of passages returned after stage 2.
        doc_embs: Passage embeddings. When omitted, the notebook-level
            ``doc_all_embs`` is used, which is what the original code did.

    Returns:
        List of passage strings, best first. Empty when there is no query, no
        passage, or a non-positive cut-off.
    """
    if len(queries) == 0 or len(passages) == 0:
        return []
    if top_k_bi_encoder <= 0 or top_k_cross_encoder <= 0:
        # A zero cut-off would otherwise turn into the slice [-0:], i.e. "everything".
        return []
    if doc_embs is None:
        doc_embs = doc_all_embs
    # Use the function argument, not the notebook-level `query` variable.
    question = queries[0]

    # bi encoder
    query_embs = bi_model.encode(queries)
    scores = util.dot_score(query_embs, doc_embs).cpu()
    # torch.topk raises when k is larger than the number of scored passages.
    k = min(top_k_bi_encoder, len(passages), scores.shape[-1])
    top_k_indices = torch.topk(scores, k=k, dim=-1).indices
    candidates = [passages[i] for i in top_k_indices[0].tolist()]

    # cross encoder
    model_inputs = [[question, passage] for passage in candidates]
    rerank_scores = cross_model.predict(model_inputs)
    # argsort is ascending: take the tail and reverse it to get best-first order.
    best_first = rerank_scores.argsort()[-top_k_cross_encoder:][::-1]
    return [candidates[i] for i in best_first]

In [None]:
from sentence_transformers import CrossEncoder

In [None]:
# Bi-encoder for first-stage retrieval and cross-encoder for re-ranking.
bi_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
cross_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Retrieve with the bi-encoder, then score the [query, passage] pairs with the
# cross-encoder. The configured cut-offs are passed explicitly; without them the
# function would fall back to its own defaults (20 and 6).
top_k_passages = get_top_passages(
    queries,
    passages,
    bi_model,
    cross_model,
    top_k_bi_encoder=top_k_bi_encoder,
    top_k_cross_encoder=top_k_cross_encoder,
)

NameError: ignored

Query BLOOM

In [None]:
# Few-shot prompt: a worked example (context + Q/A pair), then the retrieved
# passages as context and the real question, ending with "A:" for the model to complete.
context_first= 'Probable Cause and Findings,The National Transportation Safety Board determines the probable cause(s) of this accident to be,doc_id:75279,Location:La Veta, Colorado,Date & Time:January 17, 2010, 15:06 Local ,Aircraft:Mooney M20R'
context_second=' '.join(top_k_passages)
prompt_context=context_first+" "+context_second
question_first='Q: Who determined the cause of the accident?\nA: The National Transportation Safety Board'
question_second='Q: '+query
prompt_question=question_first+"\n"+question_second
prompt=prompt_context+"\n"+prompt_question+"\n A:"

set_seed(424242)
# Put the inputs on whatever device the model lives on instead of hard-coding GPU 0.
input_ids = bloom_tokenizer(prompt, return_tensors="pt").to(bloom_model.device)
prompt_length = input_ids["input_ids"].shape[1]
# Greedy decoding, stated explicitly: temperature/top_k only apply when sampling is on.
sample = bloom_model.generate(**input_ids, max_new_tokens=100, do_sample=False)
# generate() returns the prompt followed by the continuation; decode only the new
# tokens. Searching the decoded text for the query and slicing by a fixed offset
# left a stray ":" in front of the answer.
generated_text = bloom_tokenizer.decode(sample[0][prompt_length:], skip_special_tokens=True)
# Keep the first answer only: drop any follow-up "Q:" the model invents.
bloom_ans = generated_text.split("Q:")[0].strip()
print(query,bloom_ans)


List the accident numbers of accidents happened in atlanta : ERA22LA121, ERA22LA175, ERA21FA195, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA010, ERA22FA


Query BERT

In [None]:
# Run extractive QA on each retrieved passage; `index` is the passage's rank.
for index, rel_passage in enumerate(top_k_passages, start=1):
    try:
        ans_span = qa_pipeline(question=query, context=rel_passage)
    except Exception as err:
        # Report the failure instead of silently dropping the passage.
        print("Skipping passage", index, "- QA pipeline failed:", repr(err))
        continue
    # Passages end with "... doc_id: <name>"; guard against one without the marker.
    if "doc_id:" in rel_passage:
        doc_name = rel_passage.split("doc_id:")[1].split(",")[0]
    else:
        doc_name = "unknown"
    print("Answer from passage: ",index)
    print("Answer: ",ans_span['answer'])
    print("Passage: ",rel_passage.split("doc_id")[0])
    print("Doc Name: ",doc_name)

Answer from passage:  1
Answer:  104636
Passage:  Page 2 of 2 ERA22LA121 This is preliminary information, subject to change, and may contain errors. Any errors in this report will be corrected when  the final report has been completed.Meteorological Information and Flight Plan Conditions at Accident Site: VMC Condition of Light: Day Observation Facility, Elevation: CGC,9 ft msl Observation Time: 12:35 Local Distance from Accident Site: 0 Nautical Miles Temperature/Dew Point: 20°C /4°C Lowest Cloud Condition: Clear Wind Speed/Gusts, Direction: 7 knots / 14 knots, 130° Lowest Ceiling: None Visibility: 10 miles Altimeter Setting: 30.23 inches Hg Type of Flight Plan Filed: None Departure Point: Atlanta, GA (ATL) Destination: Crystal River, FL  Wreckage and Impact Information  Crew Injuries: 1 None Aircraft Damage: Substantial Passenger Injuries: N/A Aircraft Fire: None Ground Injuries: N/A Aircraft Explosion: None Total Injuries: 1 None Latitude,  Longitude:28.867611,-82.574111 (est) Admin

#### Doc Specific Query

In [None]:
# NOTE: this rebinds `doc_id`, which earlier in the notebook held the list of passage ids.
doc_id="104874" # Enter the DOC ID

Query BLOOM

In [None]:
# question = "cause of accident" # You can update your question here
# Attach the document id to the question: after it when the query already ends a
# sentence, in front of it otherwise. endswith() also copes with an empty query.
if query.endswith((".", "?")):
    question=query+" doc_id: "+doc_id
else:
    question=" doc_id: "+doc_id+", "+query

# Keep only the retrieved passages that mention the requested document id.
doc_relevant_passages=[passage for passage in top_k_passages if doc_id in passage]
if not doc_relevant_passages:
    print("Warning: no retrieved passage mentions doc_id", doc_id, "- the answer below is not grounded in that document.")
context_first= 'Probable Cause and Findings,The National Transportation Safety Board determines the probable cause(s) of this accident to be:A fuel leak and subsequent fire due to a mechanical defect.,Accident Number:,doc_id:,Location:, Colorado,Date & Time: ,Aircraft:'
context_second=' '.join(doc_relevant_passages)
prompt_context=context_first+" "+context_second
question_first='Q: What is probable causes of fire?\nA: A fuel leak and subsequent fire due to a mechanical defect\nQ: Who determined the cause of the accident?\nA: The National Transportation Safety Board'
question_second='Q: '+question
prompt_question=question_first+"\n"+question_second
prompt=prompt_context+"\n"+prompt_question+"\n A:"


set_seed(424242)
# Put the inputs on whatever device the model lives on instead of hard-coding GPU 0.
input_ids = bloom_tokenizer(prompt, return_tensors="pt").to(bloom_model.device)
prompt_length = input_ids["input_ids"].shape[1]
# Greedy decoding, stated explicitly: temperature/top_k only apply when sampling is on.
sample = bloom_model.generate(**input_ids, max_new_tokens=100, do_sample=False)
# Decode only the newly generated tokens. Splitting the full text on `query` breaks
# here because the prompt contains `question` (query plus doc_id), not the bare query.
generated_text = bloom_tokenizer.decode(sample[0][prompt_length:], skip_special_tokens=True)
# Keep the first answer only: drop any follow-up "Q:" the model invents.
bloom_ans = generated_text.split("Q:")[0].strip()
print(query,bloom_ans)

List the accident numbers of accidents happened in atlanta : ERA22LA175, ERA22LA176, ERA22LA177, ERA22LA178, ERA22LA179, ERA22LA180, ERA22LA181, ERA22LA182, ERA22LA183, ERA22LA184, ERA22LA185, ERA22LA186, ERA22LA187, ERA22LA188, ERA22LA189, ERA22LA190, ERA22LA


In [None]:
# Run extractive QA on each passage of the selected document; `index` is its rank.
for index, rel_passage in enumerate(doc_relevant_passages, start=1):
    try:
        ans_span = qa_pipeline(question=question, context=rel_passage)
    except Exception as err:
        # Report the failure instead of silently dropping the passage.
        print("Skipping passage", index, "- QA pipeline failed:", repr(err))
        continue
    # Passages end with "... doc_id: <name>"; guard against one without the marker.
    if "doc_id:" in rel_passage:
        doc_name = rel_passage.split("doc_id:")[1].split(",")[0]
    else:
        doc_name = "unknown"
    print("Answer from passage: ",index)
    print("Answer: ",ans_span['answer'])
    print("Passage: ",rel_passage.split("doc_id")[0])
    print("Doc Name: ",doc_name)

Answer from passage:  1
Answer:  Morristown, NJ
Passage:  The cockpit voice recorder (CVR) was retained and forwarded to the NTSB Recorders Laboratory in  Washington, DC. The wreckage was recovered for further examination.  Aircraft and Owner/Operator Information  Aircraft Make: LEARJET INC Registration: N877W Model/Series: 45 Aircraft Category: Airplane Amateur Built: Operator: Operating Certificate(s)  Held:None Operator Designator Code: Meteorological Information and Flight Plan Conditions at Accident Site: VMC Condition of Light: Day Observation Facility, Elevation: MMU,187 ft msl Observation Time: 11:25 Local Distance from Accident Site: 0 Nautical Miles Temperature/Dew Point: 7°C /-5°C Lowest Cloud Condition: Few / 25000 ft AGL Wind Speed/Gusts, Direction: 6 knots / 14 knots, 320° Lowest Ceiling: None Visibility: 10 miles Altimeter Setting: 30.11 inches Hg Type of Flight Plan Filed: IFR Departure Point: Atlanta, GA (FTY) Destination: Morristown, NJ . 
Doc Name:   104874
Answer fr