# PubMed Article Analysis and Text Generation Model Training

In [None]:
!pip uninstall -y pyarrow datasets
!pip install pyarrow==12.0.0
!pip install datasets==2.10.1
!pip install pandas scikit-learn transformers torch accelerate biopython nltk rouge_score bert_score

Found existing installation: pyarrow 12.0.0
Uninstalling pyarrow-12.0.0:
  Successfully uninstalled pyarrow-12.0.0
Found existing installation: datasets 2.10.1
Uninstalling datasets-2.10.1:
  Successfully uninstalled datasets-2.10.1
Collecting pyarrow==12.0.0
  Using cached pyarrow-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Using cached pyarrow-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.9 MB)
Installing collected packages: pyarrow
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 12.0.0 which is incompatible.[0m[31m
[0mSuccessfully installed pyarrow-12.0.0
Collecting datasets==2.10.1
  Using cached datasets-2.10.1-py3-none-any.whl.metadata (20 kB)
Using cached datasets-2.10.1-py3-none-any.whl (469 kB)
Installing col

In [None]:
from Bio import Entrez
import pandas as pd
import re
import nltk
import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_metric

In [None]:
Entrez.email = "m.zandieh7878@gmail.com"

def fetch_pubmed_articles(query, max_results=100):
    handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
    record = Entrez.read(handle)
    handle.close()
    id_list = record["IdList"]

    articles = []
    for article_id in id_list:
        handle = Entrez.efetch(db="pubmed", id=article_id, rettype="abstract", retmode="text")
        article = handle.read()
        articles.append(article)
        handle.close()

    return articles

query = "disease"
articles = fetch_pubmed_articles(query)



In [None]:
def process_text(text):
    lines = text.split('\n')
    processed_lines = lines[2:]
    inside_author_info = False
    inside_text = False
    final_lines = []
    for line in processed_lines:
        if "author information:" in line:
            final_lines.pop()
            for item in final_lines[::-1]:
                if item.strip() != "":
                    final_lines.pop()
                else:
                    break
            inside_author_info = True
        if inside_author_info and line.strip() == "":
            inside_author_info = False
            inside_text = True
            continue
        if inside_text and line.strip() == "":
            break
        if not inside_author_info and not line.startswith(("doi", "pmid", "copyright")):
            final_lines.append(line)
    return '\n'.join(final_lines)

articles2 = [process_text(article.lower()).strip() for article in articles]


In [None]:

!wget https://github.com/gersteinlab/MedAgents/raw/main/datasets/MedQA/test.jsonl -O test.jsonl


--2024-07-27 08:10:46--  https://github.com/gersteinlab/MedAgents/raw/main/datasets/MedQA/test.jsonl
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/gersteinlab/MedAgents/main/datasets/MedQA/test.jsonl [following]
--2024-07-27 08:10:46--  https://raw.githubusercontent.com/gersteinlab/MedAgents/main/datasets/MedQA/test.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1325785 (1.3M) [text/plain]
Saving to: ‘test.jsonl’


2024-07-27 08:10:46 (106 MB/s) - ‘test.jsonl’ saved [1325785/1325785]



In [None]:
questions = []
answers = []

with open('test.jsonl', 'r') as file:
    for line in file:
        data = json.loads(line.strip())
        questions.append(data['question'])
        answers.append(data['answer'])


In [None]:

nltk.download('stopwords')
stop_words = set(nltk.corpus.stopwords.words('english'))

def preprocess_text(text, max_words=526):
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s]', '', text)
    tokens = text.split()
    tokens = [word for word in tokens if word not in stop_words]
    if len(tokens) > max_words:
        tokens = tokens[:max_words]
    return ' '.join(tokens)

articles2 = [preprocess_text(article) for article in articles2]


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:

vectorizer = TfidfVectorizer(stop_words='english')
X = vectorizer.fit_transform(articles2)

def get_related_docs(question, vectorizer, X, top_k=5):
    question_vec = vectorizer.transform([question])
    similarities = cosine_similarity(question_vec, X).flatten()
    related_docs_indices = similarities.argsort()[-top_k:][::-1]
    return related_docs_indices

data = []
for question, answer in zip(questions, answers):
    related_docs_indices = get_related_docs(question, vectorizer, X)
    related_docs = " ".join([articles2[i] for i in related_docs_indices])
    prompt = question + " " + related_docs
    processed_prompt = preprocess_text(prompt)
    data.append({'prompt': processed_prompt, 'answer': answer})

df = pd.DataFrame(data)

In [None]:
print(df.head(5))

                                              prompt  \
0  junior orthopaedic surgery resident completing...   
1  67yearold man transitional cell carcinoma blad...   
2  two weeks undergoing emergency cardiac catheri...   
3  39yearold woman brought emergency department f...   
4  35yearold man comes physician itchy watery eye...   

                                              answer  
0  Tell the attending that he cannot fail to disc...  
1                               Cross-linking of DNA  
2                           Cholesterol embolization  
3  Lactose-fermenting, gram-negative rods forming...  
4                                Ketotifen eye drops  


In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
import torch

tokenizer = T5Tokenizer.from_pretrained('google/t5-efficient-tiny')
model = T5ForConditionalGeneration.from_pretrained('google/t5-efficient-tiny')


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
class PromptAnswerDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, tokenizer, max_length):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        input_text = row['prompt']
        target_text = row['answer']
        inputs = self.tokenizer(input_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        targets = self.tokenizer(target_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        labels = targets['input_ids'].squeeze().tolist()
        # Ensure labels are within vocab size
        labels = [l if l < self.tokenizer.vocab_size else self.tokenizer.pad_token_id for l in labels]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(labels, dtype=torch.long)
        }

In [None]:
# Split dataset
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(df, test_size=0.05, random_state=42)
train_dataset = PromptAnswerDataset(train_df, tokenizer, max_length=526)
val_dataset = PromptAnswerDataset(val_df, tokenizer, max_length=526)


In [None]:
from datasets import load_metric
bleu = load_metric('bleu')
rouge = load_metric('rouge')
bertscore = load_metric('bertscore')

  bleu = load_metric('bleu')


In [None]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Check if predictions is a tuple and extract tensor
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Ensure predictions are in tensor format
    if isinstance(predictions, np.ndarray):
        predictions = torch.tensor(predictions)

    # Get predicted ids
    predicted_ids = torch.argmax(predictions, dim=-1)

    # Decode predictions and labels
    try:
        decoded_preds = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    except IndexError as e:
        print(f"Error during decoding: {e}")
        print(f"Predictions shape: {predictions.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Predictions (first 5): {predictions[:5]}")
        print(f"Labels (first 5): {labels[:5]}")
        raise e

    # Compute metrics
    accuracy = accuracy_score(decoded_labels, decoded_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(decoded_labels, decoded_preds, average='weighted')

    # Compute BLEU
    bleu_result = bleu.compute(predictions=[p.split() for p in decoded_preds], references=[[l.split()] for l in decoded_labels])

    # Compute ROUGE
    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)

    # Compute BERTscore
    bertscore_result = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'bleu': bleu_result['bleu'], # Changed key from 'score' to 'bleu'
        'rougeL': rouge_result['rougeL'].mid.fmeasure,
        'bertscore': np.mean(bertscore_result['f1'])
    }

In [None]:
from transformers import Trainer, TrainingArguments

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    gradient_accumulation_steps=8,
    fp16=True,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=100,
    evaluation_strategy="steps",
    eval_steps=10,
    save_total_limit=2,
)

# Create the Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)



In [16]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Bleu,Rougel,Bertscore
10,20.6902,18.235081,0.0,0.0,0.0,0.0,0.0,0.0,0.754146
20,20.2787,17.271236,0.0,0.0,0.0,0.0,0.0,0.0,0.776841
30,19.1778,15.790132,0.0,0.0,0.0,0.0,0.0,0.0,0.777923
40,17.7213,14.054646,0.0,0.0,0.0,0.0,0.0,0.0,0.789973
50,16.2048,12.176167,0.0,0.0,0.0,0.0,0.0,0.0,0.791233
60,14.3236,10.122239,0.0,0.0,0.0,0.0,0.0,0.0,0.790585
70,12.2917,7.754481,0.0,0.0,0.0,0.0,0.0,0.0,0.775295
80,9.9781,5.013665,0.0,0.0,0.0,0.0,0.0,0.005208,0.784704
90,7.6649,2.05741,0.0,0.0,0.0,0.0,0.0,0.005357,0.783128
100,5.3707,0.33065,0.0,0.0,0.0,0.0,0.0,0.00477,0.781078


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(re

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Bleu,Rougel,Bertscore
10,20.6902,18.235081,0.0,0.0,0.0,0.0,0.0,0.0,0.754146
20,20.2787,17.271236,0.0,0.0,0.0,0.0,0.0,0.0,0.776841
30,19.1778,15.790132,0.0,0.0,0.0,0.0,0.0,0.0,0.777923
40,17.7213,14.054646,0.0,0.0,0.0,0.0,0.0,0.0,0.789973
50,16.2048,12.176167,0.0,0.0,0.0,0.0,0.0,0.0,0.791233
60,14.3236,10.122239,0.0,0.0,0.0,0.0,0.0,0.0,0.790585
70,12.2917,7.754481,0.0,0.0,0.0,0.0,0.0,0.0,0.775295
80,9.9781,5.013665,0.0,0.0,0.0,0.0,0.0,0.005208,0.784704
90,7.6649,2.05741,0.0,0.0,0.0,0.0,0.0,0.005357,0.783128
100,5.3707,0.33065,0.0,0.0,0.0,0.0,0.0,0.00477,0.781078


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

TrainOutput(global_step=250, training_loss=6.063409430503845, metrics={'train_runtime': 12878.7849, 'train_samples_per_second': 0.469, 'train_steps_per_second': 0.019, 'total_flos': 135395278848000.0, 'train_loss': 6.063409430503845, 'epoch': 4.962779156327543})

In [17]:
trainer.evaluate(eval_dataset=val_dataset)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.16548490524291992,
 'eval_accuracy': 0.0,
 'eval_precision': 0.0,
 'eval_recall': 0.0,
 'eval_f1': 0.0,
 'eval_bleu': 0.0,
 'eval_rougeL': 0.016666666666666666,
 'eval_bertscore': 0.5625052284449339,
 'eval_runtime': 93.3447,
 'eval_samples_per_second': 0.686,
 'eval_steps_per_second': 0.236,
 'epoch': 4.962779156327543}