In [202]:
%clear -f
import pandas as pd
from sqlalchemy import create_engine, update, Table, MetaData
from dotenv import load_dotenv
import os, json, threading, time 
import google.generativeai as genai
from labels_all import get_prompt
from txt_to_df import txt_to_df
from test import get_test_df

# Read the local .env file, then hand the Gemini key to the client library.
load_dotenv(".env")
GEMINI_KEY = os.getenv("GEMINI_KEY")
genai.configure(api_key=GEMINI_KEY)

# Gemini API handle. Note: currently need VPN to outside Europe.
model = genai.GenerativeModel("gemini-pro")
# Temperature 0 is passed with every request (library default: 1.0).
generation_config = genai.types.GenerationConfig(temperature=0)

# SQLite file used as persistent storage.
engine = create_engine("sqlite:///publications.db")


[H[2J

In [203]:
%autoreload
# Parse the text export and keep only the first 20 rows for this run.
publications = txt_to_df("tst.txt")[:20]
publications.head()
# Dump the working sample to HTML for inspection outside the notebook.
publications.to_html("new1.html")

In [204]:
#publications = get_test_df()
# publications["category"] = None
# publications["reasoning"] = None
# Add empty result columns; get_category fills them in later through SQL updates.
for result_column in ("category1", "reasoning1", "category2", "reasoning2"):
    publications[result_column] = None

# Replace the SQLite table with the whole frame, then reflect it so that
# UPDATE statements can be built against its columns.
publications.to_sql("publications", con=engine, if_exists="replace", index=False)
publications_table = Table("publications", MetaData(), autoload_with=engine)
publications.shape

(20, 9)

In [199]:
%autoreload
def _parse_json_reply(text):
    """Decode a model reply as JSON, tolerating text wrapped around the object.

    The raw text is tried first, so a reply that is already valid JSON is
    decoded exactly as before. If that fails, the span from the first ``{`` to
    the last ``}`` is decoded instead, which covers replies wrapped in a
    Markdown code fence or preceded/followed by prose.

    Raises:
        json.JSONDecodeError: if neither attempt yields valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(text[start:end + 1])


def get_category(i):
    """Classify publication ``i`` with the model and store the result in SQLite.

    Reads ``DOI``, ``Title`` and ``Abstract`` for label ``i`` from the
    module-level ``publications`` frame, builds a prompt with ``get_prompt``,
    and expects the reply to be a JSON object with a ``categories`` list whose
    items carry ``category`` and ``clear_reasoning``. The first item (and the
    second, when present) is written to the ``publications`` table row whose
    DOI matches.

    Args:
        i: label used to look up the row in ``publications``.

    Returns:
        None. Any exception is printed together with ``i`` instead of being
        raised, so a failing row does not stop the other worker threads.
    """
    try:
        print(i)
        # Get data from the Gemini API.
        doi = publications["DOI"][i]
        title = publications["Title"][i]
        abstract = publications["Abstract"][i]
        prompt = get_prompt(title, abstract)
        res = model.generate_content(prompt, generation_config=generation_config)
        json_res = _parse_json_reply(res.text)
        categories = json_res["categories"]
        category1 = categories[0]["category"]
        reasoning1 = categories[0]["clear_reasoning"]
        if len(categories) > 1:
            category2 = categories[1]["category"]
            reasoning2 = categories[1]["clear_reasoning"]
        else:
            category2 = None
            reasoning2 = None
        print(i, category1, category2)
        # Save in the database; the row is identified by its DOI.
        stmt = (
            update(publications_table)
            .where(publications_table.c.DOI == doi)
            .values(
                category1=category1,
                reasoning1=reasoning1,
                category2=category2,
                reasoning2=reasoning2,
            )
        )
        with engine.connect() as conn:
            conn.execute(stmt)
            conn.commit()
    except Exception as e:
        # Best effort: report the failing row and carry on.
        print(e)
        print(str(i)+"returned with err")

#test function
#get_category(1)


In [200]:
%autoreload
# Half-open window [start_idx, end_idx) of row labels to classify in this run.
start_idx, end_idx = 0, 20

if __name__ == "__main__":
    threads = []
    for i in range(start_idx, end_idx):
        # One worker per row; each worker calls get_category(i).
        worker = threading.Thread(target=get_category, args=(i,))
        worker.start()
        threads.append(worker)
        # Space out launches: current rate limit 60 requests/minute.
        time.sleep(1.01)
        if not i % 10:
            print(f"# running threads: {threading.active_count()}")

    # Wait until every worker has finished before the cell returns.
    for worker in threads:
        worker.join()


0
# running threads: 8
1
2
3
4
5
0 D5 A9
6
1 A7 None
4 F None
7
8
3 A7 A3
9
10
5 A7 A1a
2 A5 None
# running threads: 12
11
7 A1b A7
6 A5 C3
12
8 C1 C3
13
9 D4a A4
14
10 A7 A4
15
16
17
Expecting value: line 1 column 1 (char 0)
12returned with err
18
13 A10 C4
11 C3 None
15 A5 None
19
16 C3 A10
14 A10 D10
17 C3 A5
19 A10 C3
18 C3 C2b


In [201]:
# Read the table back from SQLite to pick up what the worker threads wrote.
query = "SELECT * FROM publications"
publications = pd.read_sql_query(query, engine)
publications
# Optional CSV export:
# publications.to_csv('new.csv', index=False, sep=",")
publications.to_html("new1.html")

# Accepted category codes; any other value is treated as a hallucinated label.
labels = ["F", "A1a", "A1b", "A1c", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10", "C1", "C2a", "C2b", "C2c", "C2d", "C3", "C4", "C5", "D1a", "D1b", "D2", "D3", "D4a", "D4b", "D5", "D6", "D7", "D8", "D9", "D10"]

# Replace every value that is not in the list with None.
for category_column in ("category1", "category2"):
    publications[category_column] = publications[category_column].apply(
        lambda value: value if value in labels else None
    )
# Optional write-back of the cleaned frame:
# proposals.to_sql('proposals', con=engine, if_exists='replace', index=False)

# Keep only the rows that do not have a first label yet. reset_index() keeps
# the previous row labels in a new "index" column.
still_unlabelled = publications["category1"].isnull()
publications = publications[still_unlabelled].reset_index()
print(publications.shape)
publications.head()

(1, 10)


Unnamed: 0,index,Index,Title,Abstract,Journal,DOI,category1,reasoning1,category2,reasoning2
0,12,14,Alzheimer's Disease and Aging Association: Ide...,BACKGROUND: Aging is considered a key risk fac...,J Prev Alzheimers Dis,10.14283/jpad.2023.101,,,,
