In [12]:
import time
from typing import Dict, Optional, Set

import pandas as pd
from Bio import Entrez, Medline

from logisticregression import LogisticRegression

In [2]:
# Load the PubMed IDs of the positive examples from a comma-separated file.
with open( 'positive_examples.txt' ) as file :
    # strip surrounding whitespace (e.g. a trailing newline) and drop empty
    # tokens so that no malformed ID is later sent to PubMed
    posids = { pmid.strip() for pmid in file.read().split( ',' ) if pmid.strip() }

In [3]:
# Load the PubMed IDs of the negative examples from a comma-separated file.
with open( 'negative_examples.txt' ) as file :
    # strip surrounding whitespace (e.g. a trailing newline) and drop empty
    # tokens so that no malformed ID is later sent to PubMed
    negids = { pmid.strip() for pmid in file.read().split( ',' ) if pmid.strip() }

In [4]:
# Type aliases: both a class label and a document are plain strings.
Class = str
Document = str

# Container mapping each class label to the set of documents collected for it.
data : Dict[ Class, Set[ Document ] ] = {
    label : set() for label in ( 'positive', 'negative' )
    }

In [5]:
def get_abstract( pmid : str ) -> Optional[ Document ] :
    '''
    Get the abstract of a single PubMed record.

    Parameters
    ----------
    pmid : str
        PubMed identifier of the record to fetch.

    Returns
    -------
    Optional[ Document ]
        Value of the record's 'AB' field in the MEDLINE response, or None
        when the record has no such field.
    '''
    handle = Entrez.efetch(
        db = 'pubmed',
        id = pmid,
        email = 'chiodini.zachary@epa.gov',
        retmode = 'text',
        rettype = 'medline'
        )
    try :
        return Medline.read( handle ).get( 'AB' )
    finally :
        handle.close() # always release the handle, even if parsing fails

In [6]:
# getting data: fetch the abstract of every ID and file it under its class
for label, pmids in ( ( 'positive', posids ), ( 'negative', negids ) ) :
    for pmid in pmids :
        abstract = get_abstract( pmid )
        if abstract is not None : # records without an 'AB' field yield None
            data[ label ].add( abstract )
        time.sleep( 1/3 ) # avoid PubMed ban

In [42]:
# Create the classifier from the project-local `logisticregression` module,
# using its default constructor arguments.
model = LogisticRegression()

In [43]:
# Evaluate the model on `data` with k-fold cross-validation ( k = 10 ).
# NOTE(review): `rate` and `convergence` are presumably the learning rate and
# the stopping tolerance, and `target` the label treated as the positive class
# - confirm against the LogisticRegression implementation.
model.kFoldValidate( data, rate = 1, convergence = 0.01, target = 'positive', k = 10 ) # 10-fold cross-validation

Examples Trained: 900
Examples Tested : 100
Total Examples  : 1000


In [44]:
# Show `model.predictions` as a table; cells without a value are rendered as
# empty strings. ( pandas is imported in the notebook's import cell. )
pd.DataFrame( model.predictions ).fillna( '' )

Unnamed: 0,positive,negative,model
positive,270.0,0.0,
negative,0.0,630.0,
truth rate,1.0,1.0,
false rate,0.0,0.0,
precision,1.0,1.0,
accuracy,,,1.0(0.0)


In [69]:
# Evaluate the model on `data` with Monte-Carlo cross-validation.
# NOTE(review): `iters` is presumably the number of random train/test splits
# and `batches` / `convergence` / `rate` training hyper-parameters - confirm
# against the LogisticRegression implementation.
model.trainAndTest( 
    data, target = 'positive', rate = 1, 
    batches = 1, convergence = 0.00001, iters = 1000 
    ) # Monte-Carlo cross-validation

Examples Trained: 49691
Examples Tested : 50309
Total Examples  : 100000


In [70]:
# Show `model.predictions` as a table; cells without a value are rendered as
# empty strings.
pd.DataFrame( model.predictions ).fillna( '' )

Unnamed: 0,positive,negative,model
positive,12165.0,846.0,
negative,3133.0,34165.0,
truth rate,0.795202,0.975836,
false rate,0.0241638,0.204798,
precision,0.934978,0.916001,
accuracy,,,0.94(0.06)
