In [1]:
# Dec 2023
# Author: SBN
# inspired by https://www.kaggle.com/code/tahmidmir/fine-tuning-gpt-2-on-skin-cancer-articles

In [10]:
from Bio import Entrez
import pandas as pd
import time
import warnings
warnings.filterwarnings('ignore')


In [11]:
# uncomment for insalling:
# !pip install biopython
# !pip install pytz

In [15]:


# NCBI API settings
Entrez.email = "mirtahmid@gmail.com"  # Required by NCBI to use the API



<a id="1"></a>
# <div style="text-align:center; border-radius:3px 50px; padding:7px; color:white; margin:0; font-size:100%; font-family:Pacifico; background-color:#a0b1b0; "><b>⭐ Fetching data⭐</b></div>

In [16]:
num_article = 1000
handle = Entrez.esearch(db="pubmed", term="skin cancer", retmax=num_article)
record = Entrez.read(handle)
ids = record['IdList']
handle.close()


In [17]:
ids[:-3:-1]

['39303343', '39303804']

In [18]:
# Fetch articles with the specified PubMed IDs
handle = Entrez.efetch(db="pubmed", id=",".join(ids), rettype="xml", retmode="text")
records = Entrez.read(handle)
handle.close()

In [19]:
records.keys()

dict_keys(['PubmedBookArticle', 'PubmedArticle'])

<a id="1"></a>
# <div style="text-align:center; border-radius:3px 50px; padding:7px; color:white; margin:0; font-size:100%; font-family:Pacifico; background-color:#a0b1b0; "><b>⭐ Observe the data parts⭐</b></div>

In [20]:
# pick a sample to see the output of fetching:
article_num = 0
records["PubmedArticle"][article_num]

{'MedlineCitation': DictElement({'OtherID': [], 'CitationSubset': ['IM'], 'KeywordList': [ListElement([StringElement('accelerated failure time model', attributes={'MajorTopicYN': 'N'}), StringElement('cured model', attributes={'MajorTopicYN': 'N'}), StringElement('partly interval censoring', attributes={'MajorTopicYN': 'N'}), StringElement('penalized likelihood', attributes={'MajorTopicYN': 'N'}), StringElement('semiparametric estimation', attributes={'MajorTopicYN': 'N'})], attributes={'Owner': 'NOTNLM'})], 'OtherAbstract': [], 'InvestigatorList': [], 'GeneralNote': [], 'SpaceFlightMission': [], 'PMID': StringElement('39508209', attributes={'Version': '1'}), 'DateCompleted': {'Year': '2024', 'Month': '11', 'Day': '07'}, 'DateRevised': {'Year': '2024', 'Month': '11', 'Day': '07'}, 'Article': DictElement({'ELocationID': [StringElement('10.1002/bimj.202300203', attributes={'EIdType': 'doi', 'ValidYN': 'Y'})], 'ArticleDate': [], 'Language': ['eng'], 'Journal': {'ISSN': StringElement('1521

In [21]:
records["PubmedArticle"][article_num].keys()

dict_keys(['MedlineCitation', 'PubmedData'])

In [22]:
records["PubmedArticle"][article_num]['MedlineCitation'].keys()

dict_keys(['OtherID', 'CitationSubset', 'KeywordList', 'OtherAbstract', 'InvestigatorList', 'GeneralNote', 'SpaceFlightMission', 'PMID', 'DateCompleted', 'DateRevised', 'Article', 'MedlineJournalInfo', 'MeshHeadingList'])

In [23]:
records["PubmedArticle"][article_num]['PubmedData'].keys()

dict_keys(['ReferenceList', 'History', 'PublicationStatus', 'ArticleIdList'])

In [24]:
records["PubmedArticle"][article_num]['MedlineCitation']["Article"].keys()

dict_keys(['ELocationID', 'ArticleDate', 'Language', 'Journal', 'ArticleTitle', 'Pagination', 'Abstract', 'AuthorList', 'PublicationTypeList'])

In [25]:
records["PubmedArticle"][article_num]['MedlineCitation']["Article"]["Abstract"].keys()

dict_keys(['AbstractText', 'CopyrightInformation'])

In [26]:
records["PubmedArticle"][article_num]['MedlineCitation']["Article"]["ArticleTitle"]

'Mixture Cure Semiparametric Accelerated Failure Time Models With Partly Interval-Censored Data.'

In [27]:
records["PubmedArticle"][article_num]['MedlineCitation']["Article"]["Abstract"]["AbstractText"]

['In practical survival analysis, the situation of no event for a patient can arise even after a long period of waiting time, which means a portion of the population may never experience the event of interest. Under this circumstance, one remedy is to adopt a mixture cure Cox model to analyze the survival data. However, if there clearly exhibits an acceleration (or deceleration) factor among their survival times, then an accelerated failure time (AFT) model will be preferred, leading to a mixture cure AFT model. In this paper, we consider a penalized likelihood method to estimate the mixture cure semiparametric AFT models, where the unknown baseline hazard is approximated using Gaussian basis functions. We allow partly interval-censored survival data which can include event times and left-, right-, and interval-censoring times. The penalty function helps to achieve a smooth estimate of the baseline hazard function. We will also provide asymptotic properties to the estimates so that inf

<a id="1"></a>
# <div style="text-align:center; border-radius:3px 50px; padding:7px; color:white; margin:0; font-size:100%; font-family:Pacifico; background-color:#a0b1b0; "><b>⭐ Saving data⭐</b></div>

In [28]:
abstract_bank = []
failed_fetech = 0
for i in range(num_article):
    try:
        abstract_text = records["PubmedArticle"][i]['MedlineCitation']["Article"]["Abstract"]["AbstractText"]
        if len(abstract_text)>1: # get rid of those abstracts which has more than 1 part. They are not of standrd order.
            continue
    except:
        # print('one article is not fetched!')
        failed_fetech+=1
    abstract_bank.append(abstract_text[0])
    # print(abstract_text[0])
print(f"{failed_fetech/num_article *100}% of the data is corrput and failed!")    

7.7% of the data is corrput and failed!


In [29]:
with open("pubmed_skin_cancer_articles.txt", "w") as f:
    for abstract in abstract_bank:
        f.write(abstract + "\n\n")
print("Fetched articles saved to 'pubmed_skin_cancer_articles.txt'")

Fetched articles saved to 'pubmed_skin_cancer_articles.txt'


In [30]:
len(records["PubmedArticle"])

997

In [31]:
# !pip install transformers datasets
# !pip install torch torchvision torchaudio
# !pip install accelerate
# !pip install --upgrade accelerate
# !pip install --upgrade transformers


<a id="1"></a>
# <div style="text-align:center; border-radius:3px 50px; padding:7px; color:white; margin:0; font-size:100%; font-family:Pacifico; background-color:#a0b1b0; "><b>⭐ Fine-tune the GPT-2 ⭐</b></div>

In [32]:
import accelerate

In [64]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Load GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Tokenize the dataset
def load_dataset(file_path, block_size=128):
    """
    Load the dataset and tokenize it for GPT-2
    """
    dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=block_size,
    )
    return dataset

# Data collator for dynamic padding
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # GPT-2 uses causal language modeling, not masked language modeling (MLM)
)

# Load the custom dataset (skin cancer-related PubMed abstracts)
dataset = load_dataset("pubmed_skin_cancer_articles.txt")

# Fine-tuning setup using Trainer API
training_args = TrainingArguments(
    output_dir="./fine_tuned_gpt2",
    overwrite_output_dir=True,
    num_train_epochs=300,  # Adjust epochs depending on dataset size
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

# Trainer API for fine-tuning
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Start fine-tuning
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_gpt2")
tokenizer.save_pretrained("./fine_tuned_gpt2")

print("Fine-tuning complete. Model saved to './fine_tuned_gpt2'")

Step,Training Loss
500,0.2747
1000,0.0149


Fine-tuning complete. Model saved to './fine_tuned_gpt2'


<a id="1"></a>
# <div style="text-align:center; border-radius:3px 50px; padding:7px; color:white; margin:0; font-size:100%; font-family:Pacifico; background-color:#a0b1b0; "><b>⭐ Inference and Test⭐</b></div>



In [77]:


import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the fine-tuned model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("./fine_tuned_gpt2")
model = GPT2LMHeadModel.from_pretrained("./fine_tuned_gpt2")

# Function to generate a response from the fine-tuned model
def generate_finetuned_response(prompt, max_length=200, temperature=0.7):
    inputs = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True)

    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            max_length=max_length,
            temperature=temperature,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example: Generate a response to a medical query
prompt = "What's the skin cancer?"
response = generate_finetuned_response(prompt)
print("Generated Response:", response)



Generated Response: What's the skin cancer?

Pheochromocytoma, also called macrophage hyperplasia, is a rare cutaneous neuroendocrine cancer that has the potential to metastasize. However, brain metastasis is infrequent in this type of cancer. To illustrate the importance of molecular diagnostic approaches in detecting this rare condition, a real case study involving melanoma recurrence is conducted and reported. Our model is implemented in our R package aftQnp which is available from https://github.com/Isabellee4555/aftQnP.
 Merkel cell carcinoma (MMC) is one of the important causes of death in patients with MEN2A. MCC is aggressive and has a high median survival. Although no significant difference was observed between treated and untreated patients, an increased likelihood of metastatic melanomas was noted in treated MEN1A patients. The results highlight the need for further investigation and understanding of MEN4A symptoms and indicate
