In [None]:
import multiprocessing as mp
import os
import sys

from tqdm import tqdm

In [11]:
# Input: PubTator3 gene annotation file (tab-separated). Set the GENE2PUBTATOR3_PATH
# environment variable to point elsewhere instead of editing this cell.
INPUT_FILE = os.environ.get("GENE2PUBTATOR3_PATH", "/home/ubuntu/Desktop/pubtator3/gene2pubtator3")
# Output: TSV of gene ID -> number of distinct PMIDs (no filtering applied).
OUTPUT_FILE = os.environ.get("GENE_COUNTS_OUTPUT", "results/gene_counts_noFilt.txt")
CHUNK_SIZE = 1_000_000  # lines per work unit handed to one worker process
processes = 8  # number of worker processes in the multiprocessing pool


In [17]:
def process_chunk(lines):
    """
    Build a gene -> set-of-PMIDs map from a chunk of tab-separated lines.

    Each line: PMID in col 1 (index 0), gene in col 3 (index 2); any other
    columns are ignored. Lines with fewer than three columns, or whose PMID or
    gene field is empty (after stripping whitespace), are skipped.

    Args:
        lines: iterable of raw text lines (a trailing "\\n" or "\\r\\n" is optional).

    Returns:
        dict mapping gene ID string -> set of PMID strings seen with it.
    """
    gene_to_pmids = {}
    for line in lines:
        # Remove only the line terminator. A full strip() would also eat a leading
        # tab (empty first field), shifting every column left by one so that a
        # later column would be read as the gene.
        cols = line.rstrip('\r\n').split('\t')
        if len(cols) < 3:
            continue
        pmid = cols[0].strip()
        gene = cols[2].strip()
        if not pmid or not gene:
            # Otherwise empty fields would be counted under a "" key.
            continue
        # NOTE(review): if the gene field can hold several IDs in one value (e.g.
        # joined by ';'), they are counted here as a single combined key — confirm
        # whether they should be split.
        pmids = gene_to_pmids.get(gene)
        if pmids is None:
            gene_to_pmids[gene] = {pmid}
        else:
            pmids.add(pmid)
    return gene_to_pmids

def merge_dicts(dict_list):
    """
    Combine several gene -> set-of-PMIDs dictionaries into a single one.

    Genes appear in first-seen order across dict_list. For each gene, the set
    object from the first dictionary containing it becomes the merged value, and
    sets from later dictionaries are unioned into it in place. No set is copied,
    so the input sets are mutated; pass copies if the inputs must stay unchanged.
    """
    combined = {}
    for partial in dict_list:
        for gene_id, pmid_set in partial.items():
            if gene_id in combined:
                combined[gene_id].update(pmid_set)
                continue
            combined[gene_id] = pmid_set
    return combined

def main():
    """Count distinct PMIDs per gene ID in INPUT_FILE and write the counts to OUTPUT_FILE.

    Steps:
      1. Create OUTPUT_FILE's parent directory if needed (before any heavy work).
      2. Count the lines of INPUT_FILE; the count is used only to size the progress bar.
      3. Read INPUT_FILE in chunks of CHUNK_SIZE lines and map ``process_chunk`` over
         them with a pool of ``processes`` worker processes.
      4. Fold each partial gene -> set-of-PMIDs map into one map as results arrive.
         ``imap`` yields results in chunk order, so gene order in the output is the
         first-seen order in the file.
      5. Write a TSV (header ``entrezgene_id``, ``count``) to OUTPUT_FILE.

    Progress messages are printed to stderr.

    Raises:
        ValueError: if CHUNK_SIZE is not positive.
        OSError: if INPUT_FILE cannot be read or OUTPUT_FILE cannot be written.
    """
    if CHUNK_SIZE <= 0:
        raise ValueError(f"CHUNK_SIZE must be positive, got {CHUNK_SIZE}")

    output_file = OUTPUT_FILE
    # Create e.g. "results/" up front, so that a missing directory fails fast instead
    # of after all the expensive processing.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("Counting total lines in the file...", file=sys.stderr)
    with open(INPUT_FILE, 'r') as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines: {total_lines}", file=sys.stderr)

    def chunk_reader(file_path, chunk_size):
        """Yield consecutive lists of up to chunk_size raw lines from file_path."""
        with open(file_path, 'r') as f:
            chunk = []
            for line in f:
                chunk.append(line)
                if len(chunk) == chunk_size:
                    yield chunk
                    chunk = []
            if chunk:  # last partial chunk
                yield chunk

    # Ceiling division gives the exact number of chunks chunk_reader yields.
    # `total_lines // CHUNK_SIZE + 1` is one too many when total_lines is an exact
    # multiple of CHUNK_SIZE, including an empty file.
    n_chunks = (total_lines + CHUNK_SIZE - 1) // CHUNK_SIZE

    final_dict = {}
    # The pool's context manager shuts the workers down even if reading or processing
    # raises. Its exit calls terminate(), which is safe here because the loop has
    # already consumed every result by then.
    with mp.Pool(processes=processes) as pool:
        with tqdm(total=n_chunks, desc="Processing Chunks") as pbar:
            chunks = chunk_reader(INPUT_FILE, CHUNK_SIZE)
            for partial_dict in pool.imap(process_chunk, chunks, chunksize=1):
                # Merge incrementally instead of keeping every partial map alive until
                # the end. partial_dict is not used after this iteration, so adopting
                # its sets as merged values is safe.
                for gene, pmids in partial_dict.items():
                    if gene in final_dict:
                        final_dict[gene].update(pmids)
                    else:
                        final_dict[gene] = pmids
                pbar.update(1)

    print("Converting sets to counts...", file=sys.stderr)
    gene_to_count = {gene: len(pmids) for gene, pmids in final_dict.items()}
    del final_dict  # release the PMID sets; only the counts are needed from here on

    print(f"Writing final results to {output_file}...", file=sys.stderr)
    with open(output_file, 'w') as out:
        out.write("entrezgene_id\tcount\n")
        for gene, count in gene_to_count.items():
            out.write(f"{gene}\t{count}\n")

    total_genes = len(gene_to_count)
    print(f"Done. Total genes counted: {total_genes}", file=sys.stderr)

In [18]:
# Run the whole pipeline when this cell executes. In a notebook kernel, as in a
# directly executed script, __name__ is "__main__". The guard keeps main() from
# running if this code is imported as a module instead.
# NOTE(review): multiprocessing start methods that re-import the main module
# (e.g. "spawn") depend on this guard — keep it if the code moves to a .py file.
if __name__ == "__main__":
    main()

Counting total lines in the file...
Total lines: 70567596
Processing Chunks: 100%|██████████| 71/71 [01:44<00:00,  1.47s/it]
Merging partial results...
Converting sets to counts...
Writing final results to results/gene_counts_noFilt.txt...
Done. Total genes counted: 6733356
