In [28]:
import inspect
import json
import os.path
import random
import re
from functools import wraps
import hashlib

import numpy as np
import requests
from Bio import Entrez
from math import ceil

from dotenv import load_dotenv
from lxml import etree
from zhipuai import ZhipuAI

load_dotenv()

# Created after load_dotenv() so values from a .env file are already in the
# environment (presumably where the ZhipuAI API key lives -- confirm).
ai_client = ZhipuAI()

# Contact address sent to NCBI with every Entrez request. Overridable through
# ENTREZ_EMAIL (e.g. in .env). The previous literal "jameshu@live.ccom" had a
# typo in its top-level domain.
Entrez.email = os.getenv("ENTREZ_EMAIL", "jameshu@live.com")

# NOTE(review): these globals are not referenced by any cell shown here;
# search_pubmed() uses its own retmax default (2000) and retstart=0.
pubmed_ids = []
retstart = 0
retmax = 1000


class Cache:
    """Persistent on-disk memoization for functions returning JSON-able data.

    Each decorated function is assigned a *bucket* directory under
    ``root_path``. A call's cache key is the MD5 of the JSON-encoded, sorted
    list of ``(parameter name, value)`` pairs with defaults filled in, so
    ``f(1)``, ``f(1, b=2)`` and ``f(b=2, a=1)`` all share one entry.
    """

    def __init__(self, root_path):
        # Directory under which every bucket directory is created.
        self.root_path = root_path

    def bucket(self, name):
        """Return a decorator that caches results under ``root_path/name``.

        Results round-trip through JSON, so a cache hit returns plain JSON
        types even if the first call returned subclasses of them. ``None``
        results are never stored, which lets a function return ``None`` to
        mean "don't remember this, try again next time".
        """
        def decorator(func):
            # Resolve the signature once instead of on every call.
            sig = inspect.signature(func)

            @wraps(func)
            def decorated(*args, **kwargs):
                hashkey = self._hash_call(sig, args, kwargs)
                cache_file = os.path.join(self.root_path, name, hashkey[:2], hashkey)

                result = self._load(cache_file)
                if result is not None:
                    return result

                result = func(*args, **kwargs)
                if result is not None:
                    self._store(cache_file, result)
                return result

            return decorated
        return decorator

    @staticmethod
    def _hash_call(sig, args, kwargs):
        """Return the hex MD5 cache key for calling a function with *sig*.

        Raises TypeError (from ``Signature.bind``) if the arguments do not
        match the signature, or if a value is not JSON-serialisable.
        """
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Key format: JSON of the sorted (name, value) pairs. Keep it stable,
        # otherwise every existing cache file becomes unreachable.
        key = json.dumps(sorted(bound.arguments.items()))
        return hashlib.md5(key.encode('utf8')).hexdigest()

    @staticmethod
    def _load(cache_file):
        """Return the value cached in *cache_file*, or None on a miss."""
        if not os.path.exists(cache_file):
            return None
        with open(cache_file, 'r', encoding='utf8') as fp:
            try:
                return json.load(fp)
            except ValueError:
                # Corrupt or truncated entry (JSONDecodeError and
                # UnicodeDecodeError both derive from ValueError): treat it as
                # a miss so the value is recomputed and rewritten.
                return None

    @staticmethod
    def _store(cache_file, result):
        """Write *result* as JSON to *cache_file* atomically."""
        os.makedirs(os.path.dirname(cache_file), 0o777, exist_ok=True)
        # Write to a temp file then rename, so an interrupted write never
        # leaves a half-written entry at the real path.
        tmp_file = f"{cache_file}.{os.getpid()}.tmp"
        try:
            with open(tmp_file, 'w', encoding='utf8') as fp:
                json.dump(result, fp)
            os.replace(tmp_file, cache_file)
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)


cache = Cache(".cache")


_JUSTSCIENCE_URL = "https://sci.justscience.cn/list"
_JUSTSCIENCE_TIMEOUT = 30  # seconds; without a timeout requests.get can block forever


def _justscience_first_row(issn, sci):
    """Return the first result ``<tr>`` of a justscience search, or None.

    ``sci`` is the site's ``sci`` query flag ("1" or "0"). Raises
    ``requests.RequestException`` on network errors and HTTP error statuses.
    """
    params = {
        "sci": sci,
        "q": issn,
        "research_area": "",
        "If_range_min": "",
        "If_range_max": "",
        "jcr_quartile": "0",
        "oa": "2",
        "Self_cites_ratio_min": "",
        "Self_cites_ratio_max": "",
        "mainclass": "0",
        "subclass": "0",
        "pub_country": "",
        "not_pub_country": "",
        "sci_type": "2",
        "pub_frequency": "7",
        "adv": "1",
    }
    response = requests.get(_JUSTSCIENCE_URL, params=params, timeout=_JUSTSCIENCE_TIMEOUT)
    response.raise_for_status()
    root = etree.HTML(response.text)
    if root is None:
        return None
    return root.find(".//table[@class=\"s-result-table\"]//tbody//tr")


@cache.bucket("justscience")
def query_justscience(issn):
    """Look up a journal by ISSN on sci.justscience.cn.

    Tries the search with ``sci=1`` first and falls back to ``sci=0``.

    Returns:
        None if *issn* is empty or the site could not be reached (None is not
        cached, so the lookup is retried on the next run); ``[]`` if the site
        answered but had no matching row; otherwise the stripped text lines of
        the first result row.
    """
    if not issn:
        return None

    try:
        tr_node = _justscience_first_row(issn, "1")
        if tr_node is None:
            tr_node = _justscience_first_row(issn, "0")
    except requests.RequestException:
        # Transient failure: don't let the cache record it as "no journal".
        return None

    if tr_node is None:
        return []

    tr_text = etree.tostring(tr_node, method='text', encoding='unicode')
    return [line.strip() for line in tr_text.split("\n")]


@cache.bucket("efetch")
def fetch_pubmed(pmid):
    """Fetch the PubMed record for *pmid* via Entrez efetch and return it as text.

    The result is what ``parse_pubmed`` parses with ``etree.XML``.
    """
    handle = Entrez.efetch(db="pubmed", id=pmid)
    try:
        data = handle.read()
    finally:
        # Release the HTTP connection even if reading fails.
        handle.close()
    # Bytes are decoded as UTF-8; an already-decoded str is returned unchanged.
    return data.decode('utf8') if isinstance(data, bytes) else data


@cache.bucket("esearch")
def search_pubmed(term, retmax: int = 2000):
    """Run an Entrez esearch on PubMed for *term* and return the parsed result.

    Returns at most *retmax* ids (starting at offset 0), requested with
    ``sort="Pub Date"``. The caller reads the ids from ``result['IdList']``.
    """
    handle = Entrez.esearch(db="pubmed", retmax=retmax, retstart=0, term=term, sort="Pub Date")
    try:
        return Entrez.read(handle)
    finally:
        # Release the HTTP connection once the response has been parsed.
        handle.close()


@cache.bucket("embedding")
def get_embedding(text):
    """Return the ZhipuAI ``embedding-2`` vector for *text*.

    Results are cached on disk in the ``embedding`` bucket, keyed by *text*,
    so each distinct text is sent to the API only once.
    """
    response = ai_client.embeddings.create(input=text, model='embedding-2')
    return response.data[0].embedding


def string_to_float(text, default: float = 0.) -> float:
    """Convert *text* to float, returning *default* if it cannot be parsed.

    Used for table cells that may be blank or non-numeric. Only conversion
    errors are caught; KeyboardInterrupt and other errors still propagate.
    """
    try:
        return float(text)
    except (TypeError, ValueError, OverflowError):
        return default


def _element_text(node) -> str:
    """Return all text inside *node* (including nested markup), stripped; "" for None."""
    if node is None:
        return ""
    return "".join(node.itertext()).strip()


def parse_pubmed(content: str):
    """Extract title, abstract and ISSN from a PubMed efetch XML record.

    Text is gathered with ``itertext()`` so inline markup such as ``<i>`` or
    ``<sup>`` inside a title or abstract no longer truncates it. Every
    ``AbstractText`` element is kept (structured abstracts have one per
    section), joined with newlines. Missing fields become ``""``.

    Returns:
        dict with keys ``title``, ``abstract`` and ``issn``.
    """
    root = etree.XML(content)

    title = _element_text(root.find(".//ArticleTitle"))

    sections = (_element_text(node) for node in root.findall(".//AbstractText"))
    abstract = "\n".join(section for section in sections if section)

    # First ISSN element in the record.
    issn = _element_text(root.find(".//ISSN"))

    return dict(title=title, abstract=abstract, issn=issn)


candidates = []

search_result = search_pubmed("stroke")
for pmid in search_result['IdList']:
    content = fetch_pubmed(pmid)
    info = parse_pubmed(content)

    if_info = query_justscience(info['issn'])
    # Keep papers with a title of at least 30 characters whose journal row has
    # a value >= 4.0 in column 7 (presumably the impact factor -- confirm
    # against the justscience table layout). Rows with fewer than 8 columns are
    # skipped instead of raising IndexError.
    if (
        if_info is None
        or len(if_info) <= 7
        or len(info['title']) < 30
        or string_to_float(if_info[7]) < 4.0
    ):
        continue

    info['embedding'] = get_embedding(info['title'] + "\n" + info['abstract'])
    candidates.append(info)

# Stable position of each paper in `candidates`; later cells use it to
# de-duplicate oversampled papers.
for idx, candidate in enumerate(candidates):
    candidate['index'] = idx

from sklearn.cluster import KMeans

X = [info['embedding'] for info in candidates]

# Split all papers into 10 clusters. n_init=10 is passed explicitly: it keeps
# the results produced so far and silences scikit-learn's FutureWarning about
# the default n_init changing.
cluster = KMeans(n_clusters=10, n_init=10, random_state=0).fit(X)
for candidate, label in zip(candidates, cluster.labels_):
    candidate['label'] = label

  super()._check_params_vs_input(X, default_n_init=10)


In [29]:

from imblearn.over_sampling import SMOTE
import numpy as np


# Balance the clusters by oversampling so every label has as many entries as
# the largest cluster. The single feature is each paper's row index.
# RandomOverSampler duplicates existing rows. SMOTE, used previously, creates
# synthetic rows by interpolating between two indices of the same cluster and
# casting back to int, and the resulting index can belong to a paper from a
# *different* cluster.
from imblearn.over_sampling import RandomOverSampler

X = np.reshape(range(len(candidates)), (-1, 1))
y = [candidate['label'] for candidate in candidates]

sampler = RandomOverSampler()
# Variable names kept because the next cell reads them.
x_smote_resampled, y_smote_resampled = sampler.fit_resample(X, y)

x_smote_resampled, y_smote_resampled

(array([[  0],
        [  1],
        [  2],
        [  3],
        [  4],
        [  5],
        [  6],
        [  7],
        [  8],
        [  9],
        [ 10],
        [ 11],
        [ 12],
        [ 13],
        [ 14],
        [ 15],
        [ 16],
        [ 17],
        [ 18],
        [ 19],
        [ 20],
        [ 21],
        [ 22],
        [ 23],
        [ 24],
        [ 25],
        [ 26],
        [ 27],
        [ 28],
        [ 29],
        [ 30],
        [ 31],
        [ 32],
        [ 33],
        [ 34],
        [ 35],
        [ 36],
        [ 37],
        [ 38],
        [ 39],
        [ 40],
        [ 41],
        [ 42],
        [ 43],
        [ 44],
        [ 45],
        [ 46],
        [ 47],
        [ 48],
        [ 49],
        [ 50],
        [ 51],
        [ 52],
        [ 53],
        [ 54],
        [ 55],
        [ 56],
        [ 57],
        [ 58],
        [ 59],
        [ 60],
        [ 61],
        [ 62],
        [ 63],
        [ 64],
        [ 65],
        [ 

In [30]:
import pandas as pd

# Map each cluster label to its (oversampled) list of papers, then shuffle each
# list so successive prompt batches draw papers in random order. The loop
# variables are named so they don't overwrite the `y` label list from the
# previous cell.
label_groups = {}

for row, label in zip(x_smote_resampled, y_smote_resampled):
    label_groups.setdefault(label, []).append(candidates[row[0]])

for group in label_groups.values():
    random.shuffle(group)


In [44]:
# Start querying the AI.
# Each round takes `batch_size` papers from every cluster, de-duplicates them,
# and asks the chat model for paper ideas based on their titles. Answers are
# appended to ideas.txt.

batch_size = 3
# Stop at the shortest group so values[start + offset] never goes out of range.
# The groups are all the same length when the clusters have been balanced.
num_loop = min((len(values) for values in label_groups.values()), default=0) // batch_size

with open("ideas.txt", "w", encoding="utf-8") as fout:
    for loop in range(num_loop):
        start = loop * batch_size

        # Papers for this round keyed by their index (oversampling may repeat a
        # paper). A separate name keeps the global `candidates` list intact.
        batch = {}
        for offset in range(batch_size):
            for values in label_groups.values():
                paper = values[start + offset]
                batch[paper['index']] = paper

        papers = list(batch.values())
        random.shuffle(papers)
        titles = "\n".join(
            f"{rank}、{paper['title'].strip()}" for rank, paper in enumerate(papers, start=1)
        )

        prompt = f"""
中国国家卒中登记研究-Ⅲ是一个多中心、前瞻性队列研究，拟通过标准诊断流程及目前公认的缺血性卒中病因分型，同时评价缺血性脑血管病相关危险因素，包括目前公认的血压、血脂、血糖等以及肾功能及心功能等预后影响因素，探索缺血性脑血管病病因及发病机制分布；在临床、影像、分子水平确定不同预后及其影响因素，探索包括影像学特征的TIA/卒中风险预测模型建立，认识卒中和TIA患者的预后影响因素，早期评估和识别高危患者。

国内外前沿研究成果：
{titles}

我想在中国国家卒中登记研究-Ⅲ队列研究基础上写一些论文。请参考国内外目前的相关领域前沿研究成果，扩展我的论文写作思路。每个思路展开说明主要研究背景，研究目的，方法，主要结局，危险因素，相关论文及影响因子。多想想再回答，思路不少于6个。输出格式为：

研究背景：-- 研究背景
研究目的：-- 研究目的
方法：-- 方法
主要结局：-- 主要结局
危险因素：-- 危险因素
相关论文及影响因子：-- 相关论文及影响因子

"""

        response = ai_client.chat.completions.create(
            model="glm-3-turbo",
            messages=[
                {'role': 'user', 'content': prompt}
            ])

        # Guard against an empty (None) message so one bad reply doesn't abort the run.
        fout.write((response.choices[0].message.content or "") + "\n")
        fout.flush()  # keep partial results on disk if a later request fails
        print(f"processed {loop}")

processed 0
processed 1
processed 2
processed 3
processed 4
processed 5
processed 6
processed 7
processed 8
processed 9
processed 10
processed 11
processed 12
processed 13
