# Extract SQ_TP, SQ_FP and SQ_R papers from the oa_comm database

In [53]:
from Bio import Entrez
import csv
import os
import xml.etree.ElementTree as ET
import copy

## Function that searches PubMed with one of our search queries (SQs) and returns a list of PubMed IDs

In [54]:
def search_pubmed_for_ids(query, max_results=13):
    """Run a PubMed ``esearch`` for *query* and return the matching PubMed IDs.

    Args:
        query: PubMed search string (one search query / SQ).
        max_results: Maximum number of IDs requested (passed as ``retmax``).

    Returns:
        The ``"IdList"`` entry of the parsed esearch record.
    """
    # Contact address sent along with the E-utilities requests.
    Entrez.email = "zeynep.korkmaz@tum.de"

    handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
    try:
        record = Entrez.read(handle)
    finally:
        # Release the response handle even when parsing it fails.
        handle.close()

    return record["IdList"]

## Function that reads the keywords and SQs from a directory of CSV files and returns a dictionary keyed by publication title

In [55]:

def _new_keyword_entry(pub_title):
    """Return an empty keyword/SQ record for the publication *pub_title*."""
    return {
        "Pub Title": pub_title,
        "Keywords": [],
        "SQ_TP": [],
        "SQ_FP": [],
        "SQ_R": [],
    }


def read_keywords_from_directory(directory):
    """Parse every ``.csv`` file in *directory* into a dict keyed by publication title.

    Each CSV is read row by row. Cells are stripped of surrounding spaces and
    commas. A row whose first cell is ``"Pub Title"`` starts a new publication
    entry; its second cell is the title. Rows whose first cell is
    ``"Keywords"``, ``"SQ_TP"``, ``"SQ_FP"`` or ``"SQ_R"`` append their
    remaining non-empty cells to that list of the current entry. Empty rows
    and rows whose first cell is purely digits are ignored. So are label rows
    that appear before any title.

    Args:
        directory: Path of the directory containing the CSV files.

    Returns:
        dict mapping each publication title to a dict with the keys
        ``"Pub Title"``, ``"Keywords"``, ``"SQ_TP"``, ``"SQ_FP"`` and ``"SQ_R"``.
        Entries with an empty title are dropped. When a title occurs more than
        once, the entry parsed last (files in sorted name order) wins.
    """
    # Row labels whose remaining cells are collected into a list.
    list_labels = ("Keywords", "SQ_TP", "SQ_FP", "SQ_R")
    keywords_dict = {}

    def store(entry):
        # Nothing to store before the first title, or for an empty title cell.
        if entry is not None and entry["Pub Title"]:
            keywords_dict[entry["Pub Title"]] = entry

    # Sorted so that, when titles collide, the surviving entry is reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".csv"):
            continue
        csv_file = os.path.join(directory, filename)

        # 'utf-8-sig' drops a leading byte-order mark (common in spreadsheet
        # CSV exports). Otherwise the first cell reads '\ufeffPub Title',
        # never equals "Pub Title", and the whole file is silently lost.
        # newline='' is what the csv module requires for file objects.
        with open(csv_file, 'r', encoding='utf-8-sig', newline='') as file:
            entry = None
            for row in csv.reader(file):
                row = [item.strip(', ') for item in row]
                if not row or row[0].isdigit():
                    continue
                label = row[0]
                if label == "Pub Title":
                    store(entry)
                    # Guard against a "Pub Title" row that has no title cell.
                    entry = _new_keyword_entry(row[1] if len(row) > 1 else "")
                elif label in list_labels and entry is not None:
                    entry[label].extend(item for item in row[1:] if item)
            store(entry)

    return keywords_dict



## Example usage of read_keywords_from_directory

In [None]:
# Directory that holds the keyword / search-query CSV files.
input_dir = "Keyword_CSVs"
# input_dir = "less_keywords"  # alternative input directory

# Parse all CSVs into {pub title: {"Pub Title", "Keywords", "SQ_TP", "SQ_FP", "SQ_R"}}.
keywords_dict = read_keywords_from_directory(input_dir)

# NOTE(review): len(keywords_dict) was observed to be ~99 although ~140 were
# expected. Entries are keyed by publication title, so duplicate titles
# collapse into one entry. Rows only start an entry when their first cell is
# exactly "Pub Title".
#
# Debug output:
# print(len(keywords_dict))
# for data in keywords_dict.values():
#     print(f"Pub Title: {data['Pub Title']}")
#     for key in ("Keywords", "SQ_TP", "SQ_FP", "SQ_R"):
#         print(f"{key}: {', '.join(data[key])}")
#     print("\n" + "=" * 80 + "\n")

## Function that takes keywords_dict (input_dict) and returns a dict with lists of PubMed IDs based on the SQs

In [57]:
def dict_to_pubmed_id(input_dict, search=None):
    """Attach search results for every SQ to each publication entry.

    For each entry, every query string in ``SQ_TP``, ``SQ_FP`` and ``SQ_R``
    is searched in order. The returned IDs are concatenated into
    ``PubMed_IDs_TP``, ``PubMed_IDs_FP`` and ``PubMed_IDs_R`` respectively.

    Args:
        input_dict: dict mapping publication title -> entry dict (as built by
            ``read_keywords_from_directory``). A missing SQ list is treated as
            empty.
        search: Callable that takes one query string and returns an iterable
            of IDs. Defaults to ``search_pubmed_for_ids``.

    Returns:
        A new dict with the same keys as *input_dict*. Each value is a deep
        copy of the input entry plus the three ``PubMed_IDs_*`` lists.
        *input_dict* itself is not modified.
    """
    if search is None:
        search = search_pubmed_for_ids

    # (query-list key, result-list key) pairs processed for every entry.
    sq_to_id_key = (
        ('SQ_TP', 'PubMed_IDs_TP'),
        ('SQ_FP', 'PubMed_IDs_FP'),
        ('SQ_R', 'PubMed_IDs_R'),
    )

    result_dict = {}
    for pub_title, pub_data in input_dict.items():
        # Deep copy: a shallow copy would share the Keywords/SQ lists with the
        # input, so mutating one dict would silently change the other.
        pub_result = copy.deepcopy(pub_data)
        for sq_key, id_key in sq_to_id_key:
            ids = []
            for query in pub_data.get(sq_key, []):
                ids.extend(search(query))
            pub_result[id_key] = ids
        result_dict[pub_title] = pub_result

    return result_dict



In [None]:
# Search every SQ of every publication and collect the returned PubMed IDs.
result_dict = dict_to_pubmed_id(input_dict=keywords_dict)
result_dict  # last expression: shown as the notebook cell output

In [None]:
# Print the PubMed IDs found by the SQ_TP queries of each publication.
for entry in result_dict.values():
    print(entry['PubMed_IDs_TP'])
    # Full dump of an entry instead:
    # for key, value in entry.items():
    #     print(f"{key}: {value}")
    # print("\n" + "=" * 80 + "\n")



## Function that takes the dict with lists of PubMed IDs for the SQs as input, searches the specified directory for the corresponding XML papers, and combines them all into one large XML file (output)

In [70]:
import xml.etree.ElementTree as ET

def _article_pmids(root):
    """Return the stripped texts of all ``<article-id pub-id-type="pmid">`` elements under *root*."""
    return {
        element.text.strip()
        for element in root.iter('article-id')
        if element.attrib.get('pub-id-type') == 'pmid' and element.text
    }


def _iter_xml_paths(input_dir):
    """Yield the path of every ``.xml`` file below *input_dir*, in sorted order."""
    for root_dir, dirs, files in os.walk(input_dir):
        dirs.sort()  # in-place sort makes os.walk descend in a stable order
        for name in sorted(files):
            if name.endswith('.xml'):
                yield os.path.join(root_dir, name)


def extract_xml_files(input_dict, input_dir, output_file):
    """Combine the XML articles whose PMID was returned by the SQs into one file.

    PubMed IDs from all entries of *input_dict* are pooled into three
    categories: ``PubMed_IDs_TP`` -> ``SQ_TP_IDs``, ``PubMed_IDs_FP`` ->
    ``SQ_FP_IDs`` and ``PubMed_IDs_R`` -> ``SQ_R_IDs``. Every ``.xml`` file
    below *input_dir* is parsed once. A file belongs to a category when one of
    its ``<article-id pub-id-type="pmid">`` values is among that category's
    IDs. Files that fail to parse are reported and skipped.

    The output is ``<root>`` containing one ``<SQ_TP_IDs>``, ``<SQ_FP_IDs>``
    and ``<SQ_R_IDs>`` section, each holding the root elements of its matching
    articles. An article appears at most once per section, but may appear in
    several sections.

    Args:
        input_dict: dict of publication entries carrying the
            ``PubMed_IDs_TP`` / ``PubMed_IDs_FP`` / ``PubMed_IDs_R`` lists
            (missing lists are treated as empty).
        input_dir: Directory searched recursively for article XML files.
        output_file: Path of the combined XML file to write (overwritten).
    """
    # Sets for O(1) membership tests. IDs are normalised to plain str so they
    # compare equal to the text read from the XML.
    SQ_IDs = {
        category: {
            str(pmid)
            for data in input_dict.values()
            for pmid in data.get(id_key, [])
        }
        for category, id_key in (
            ('SQ_TP_IDs', 'PubMed_IDs_TP'),
            ('SQ_FP_IDs', 'PubMed_IDs_FP'),
            ('SQ_R_IDs', 'PubMed_IDs_R'),
        )
    }

    # Pass 1: parse each file once and record which categories it matches.
    output_path = os.path.abspath(output_file)
    matches = {category: [] for category in SQ_IDs}
    for xml_file_path in _iter_xml_paths(input_dir):
        if os.path.abspath(xml_file_path) == output_path:
            continue  # never read back an output file placed inside input_dir
        try:
            tree = ET.parse(xml_file_path)
        except ET.ParseError:
            print(f"Skipping file due to ParseError: {xml_file_path}")
            continue
        pmids = _article_pmids(tree.getroot())
        for category, desired_ids in SQ_IDs.items():
            if not pmids.isdisjoint(desired_ids):
                matches[category].append(xml_file_path)

    # Pass 2: stream each section to disk, re-parsing only the matched files,
    # so at most one article tree is held in memory at a time.
    with open(output_file, 'wb') as f:
        f.write(b'<root>\n')
        for category, paths in matches.items():
            f.write(f'<{category}>\n'.encode('utf-8'))
            for path in paths:
                f.write(ET.tostring(ET.parse(path).getroot(), encoding='utf-8'))
                f.write(b'\n')
            f.write(f'</{category}>\n'.encode('utf-8'))
        f.write(b'</root>')

In [75]:
# Directory with the extracted article XML files (machine-specific path; adjust locally).
input_dir = '/Users/tillohlendorf/Downloads/Extraced_XML'
# Combined output goes into the current working directory.
output_file = os.path.join(os.getcwd(), "output.xml")

extract_xml_files(result_dict, input_dir, output_file)