In [1]:
from neo4j import GraphDatabase
import pandas as pd
import numpy as np
class NeoPyExample:
    """Small convenience wrapper around a Neo4j driver.

    Offers helpers to create nodes/edges and to run read queries.  Every helper
    opens its own short-lived session on ``self.driver``.
    """

    def __init__(self, url, user, password):
        """Create a driver for ``url`` authenticated with ``(user, password)``."""
        self.driver = GraphDatabase.driver(url, auth=(user, password))

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def print_greeting(self, message):
        """Create a ``Greeting`` node carrying ``message`` and print the returned text."""
        with self.driver.session() as session:
            greeting = session.write_transaction(self._create_and_return_greeting, message)
            print(greeting)

    def general_query(self, query, parameters=None):
        """Run ``query`` through ``session.run`` and return its records as a list.

        This helper is best-effort: any exception is printed and ``None`` is
        returned instead of the record list.
        """
        assert self.driver is not None, "Driver not initialized!"
        session = None
        response = None
        try:
            session = self.driver.session()
            # Materialise the records before the session is closed in ``finally``.
            response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        finally:
            if session is not None:
                session.close()
        return response

    def create_node_query(self, node_type, attributes):
        """Create one node labelled ``node_type`` whose properties are ``attributes``.

        ``node_type`` is concatenated into the query text, so it must be a
        trusted identifier; ``attributes`` is passed as a query parameter.
        """
        query = 'CREATE (n:' + node_type + ' $attributes)'
        params = {'attributes': attributes}
        self.general_query(query, parameters=params)

    def create_edge_query(self, relationship, pkn1, pk1, pkn2, pk2):
        """Create ``(a)-[:relationship]->(b)`` for nodes with ``a.pkn1 = pk1`` and ``b.pkn2 = pk2``.

        ``relationship``, ``pkn1`` and ``pkn2`` are concatenated into the query
        text and must be trusted identifiers; ``pk1``/``pk2`` are parameters.
        """
        query = 'MATCH (a),(b) WHERE a.' + pkn1 + '= $pk1 AND b.' + pkn2 + '= $pk2 CREATE (a)-[r:' + relationship + ']->(b)'
        params = {'pk1': pk1, 'pk2': pk2}
        self.general_query(query, parameters=params)

    def create_write_edge_query(self, scholar_id, paper_id):
        """Create a ``write`` edge from the scholar with ``scholar_id`` to the paper with ``paper_id``."""
        query = 'MATCH (a:scholar_entity {id:$id1}), (b:paper {id:$id2}) CREATE (a) -[r:write]->(b)'
        params = {'id1': scholar_id, 'id2': paper_id}
        self.general_query(query, parameters=params)

    def create_belong_to_industry_edge_query(self, industry_name, industry_id):
        """Link every company whose ``industry`` equals ``industry_name`` to the industry node ``industry_id``."""
        query = 'MATCH (a:company {industry:$id1}), (b:industry {industry_id:$id2}) CREATE (a) -[' \
                'r:belong_to_industry]->(b) '
        params = {'id1': industry_name, 'id2': industry_id}
        self.general_query(query, parameters=params)

    def find_number_of_node(self):
        """Print the node count returned by ``_find_number_of_node``."""
        with self.driver.session() as session:
            result = session.read_transaction(self._find_number_of_node)
            for row in result:
                print("Found node number: {row}".format(row=row))

    def general_read_query(self, query):
        """Run a read query, print the first field of every record, and return the record list."""
        with self.driver.session() as session:
            result = session.read_transaction(self._general_read_query_command, query)
            for row in result:
                print(row[0])
                print('\n')
            return result

    def get_read_query(self, query):
        """Run a read query and return its records as a new list (nothing is printed)."""
        with self.driver.session() as session:
            result = session.read_transaction(self._general_read_query_command, query)
            # ``result`` is already a fully materialised list, so copying it
            # cannot fail; the former bare ``except`` around this copy was dead
            # code.  Errors raised by the query itself propagate as before.
            return list(result)

    def fetch_data_from_neo(self, query):
        """Run ``query`` through ``session.run`` and return its records as a list."""
        with self.driver.session() as session:
            # Read the records while the session is still open: a result handed
            # out after the ``with`` block has closed its session can no longer
            # be iterated by the caller.
            return list(session.run(query))

    def show_type_of_nodes(self):
        """Print every node label known to the database."""
        query = 'call db.labels()'
        self.general_read_query(query)

    def show_type_of_edges(self):
        """Print every distinct relationship type found in the database."""
        query = 'MATCH ()-[relationship]->() RETURN distinct TYPE(relationship) AS type'
        self.general_read_query(query)

    @staticmethod
    def _general_read_query_command(tx, query):
        """Transaction function: run ``query`` on ``tx`` and return all records as a list."""
        return list(tx.run(query))

    @staticmethod
    def _create_and_return_greeting(tx, message):
        """Transaction function: create a ``Greeting`` node and return its greeting text."""
        result = tx.run("CREATE (a:Greeting) "
                        "SET a.message = $message "
                        "RETURN a.message + ', from node ' + id(a)", message=message)
        return result.single()[0]

    @staticmethod
    def _find_number_of_node(tx):
        """Transaction function: return the records of a ``count(n)`` query over all nodes."""
        # The clauses are separated by explicit spaces; the adjacent string
        # literals were previously glued together without any whitespace.
        query = (
            "MATCH (n) "
            "RETURN count(n) "
            "LIMIT 10"
        )
        return list(tx.run(query))

class FinNeo(NeoPyExample):
    def __init__(self, **kwargs):
        """
        Keyword-only constructor; all three keys are required.

        url: database endpoint
        user: user name
        pwd: user password
        """
        url = kwargs['url']
        user = kwargs['user']
        password = kwargs['pwd']
        super().__init__(url, user, password)


In [2]:
# Load the stock-symbol -> matrix-index mapping. The cells below use
# `index_dict` as a dict (`.keys()`, `index_dict[symbol]`); presumably the file
# holds a pickled dict in a 0-d object array, which `.tolist()` unwraps.
# NOTE(review): allow_pickle=True can execute arbitrary code when the file is
# untrusted -- only load files from a trusted source.
stock_index = np.load('data/csi300_stock_index.npy', allow_pickle=True)
index_dict = stock_index.tolist()

# Build the stock-to-stock (s2s) industry adjacency matrix for the survey

In [6]:
import re
from tqdm import tqdm
# NOTE(review): the connection URL and credentials are hard-coded in the notebook;
# consider supplying them through environment variables instead.
fintech = FinNeo(url="bolt://localhost:7687", user='neo4j', pwd='kaisa')
# Numeric part of every stock symbol in index_dict.
all_num_dict = [re.sub("[^0-9]", "", x) for x in index_dict.keys()]
# Same codes as a set: O(1) membership tests inside the nested loop below.
tracked_codes = set(all_num_dict)
tuple_set = []

try:
    for i in tqdm(all_num_dict):
        # `i` holds digits only (see the regex above), so concatenating it into
        # the query text cannot inject Cypher.
        query = 'match (n:company) -[r1]- (i:industry) -[r2]- (c:company) where n.code contains \''+i+'\' return c.code as code '
        records = fintech.get_read_query(query)
        # Drop empty / missing codes and duplicates without mutating in place.
        code_rows = (
            pd.DataFrame([dict(record) for record in records])
            .replace('', np.nan)
            .dropna()
            .drop_duplicates()
            .values.tolist()
        )
        for code_row in code_rows:
            neighbor_code = re.sub("[^0-9]", "", code_row[0])
            if neighbor_code in tracked_codes:
                tuple_set.append([i, neighbor_code])
finally:
    # This connection is not used again below (a later cell opens its own),
    # so release it instead of leaking the driver.
    fintech.close()

# Original author's note: all pairs of tracked A-share stocks connected through
# an industry node; 4561 pairs in the recorded run.

static_neighbor = np.zeros([len(index_dict), len(index_dict)])

for src_code, dst_code in tuple_set:
    # Rebuild the index_dict key: codes starting with '6' get the 'SH' prefix,
    # every other code gets 'SZ'.
    src_index = index_dict[('SH' if src_code[0] == '6' else 'SZ') + src_code]
    dst_index = index_dict[('SH' if dst_code[0] == '6' else 'SZ') + dst_code]
    static_neighbor[src_index, dst_index] = 1

# Every stock is a neighbour of itself.
np.fill_diagonal(static_neighbor, 1)

# Original author's note: 4561 pairs plus 735 diagonal entries, 4561*2+735=9857.
np.save('data/survey_industry.npy', static_neighbor)


100%|████████████████████████████████████████| 735/735 [00:06<00:00, 110.02it/s]


In [9]:
# Merge the direct "hold" relation with the 2-hop "hold" relation.
# NOTE(review): the two matrices are added element-wise, so an entry that is
# non-zero in both inputs ends up as the sum of the two values -- confirm that
# a non-binary result is intended.
hold1 = np.load('data/csi300_stock2stock_hold.npy')
hold2 = np.load('data/csi300_stock2stock_hold_2_hop.npy')
hold = hold1 + hold2
np.save('data/survey_hold.npy', hold)

In [11]:
# Merge the three levels (L1, L2, L3) of the SW relation into one matrix.
# NOTE(review): the matrices are added element-wise, so an entry that is
# non-zero at several levels ends up as the sum of those values -- confirm that
# a non-binary result is intended.
sw1 = np.load('data/csi300_stock2stock_SWL1.npy')
sw2 = np.load('data/csi300_stock2stock_SWL2.npy')
sw3 = np.load('data/csi300_stock2stock_SWL3.npy')
sw = sw1 + sw2 + sw3
np.save('data/survey_sw.npy',sw)

In [12]:
import re
from tqdm import tqdm
import re
# NOTE(review): the connection URL and credentials are hard-coded in the notebook;
# consider supplying them through environment variables instead.
fintech = FinNeo(url="bolt://143.89.126.53:5001", user='neo4j', pwd='csproject')
# Numeric part of every stock symbol in index_dict.
all_num_dict = [re.sub("[^0-9]", "", x) for x in index_dict.keys()]

# Ask the database for the distinct relationship types directly, instead of
# fetching every company-company relationship and regex-parsing its string form.
query = 'match (n:company)-[r1]-(c:company) return distinct type(r1) as type'
records = fintech.get_read_query(query)
found_types = {record[0] for record in records}

# Fixed ordering of the known relation types.  Later cells address `dyset` by
# position, and the former `list(set(...))` produced an order that is not stable
# between kernel sessions, so the positions are pinned to the order recorded in
# this notebook.
known_order = [
    'increase_holding',
    'same_industry',
    'dispute',
    'superior',
    'cooperate',
    'invest',
    'be_supplied',
    'be_increased_holding',
    'supply',
    'fall',
    'rise',
    'compete',
    'reduce_holding',
    'be_reduced_holding',
    'be_invested',
]
missing_types = [name for name in known_order if name not in found_types]
if missing_types:
    # Known types stay in the list even when absent, so positions never shift.
    print('relation types not present in the database:', missing_types)

# Known types first (fixed positions), then any new types in sorted order.
relation_type = known_order + sorted(found_types.difference(known_order))

In [13]:
# Display the relation types; `dyset` is later built in this same order.
relation_type

['increase_holding',
 'same_industry',
 'dispute',
 'superior',
 'cooperate',
 'invest',
 'be_supplied',
 'be_increased_holding',
 'supply',
 'fall',
 'rise',
 'compete',
 'reduce_holding',
 'be_reduced_holding',
 'be_invested']

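## Order of the relation types between companies (the order of `relation_type` in this session; later cells index `dyset` by these positions)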
'increase_holding',
 'same_industry',
 'dispute',
 'superior',
 'cooperate',
 'invest',
 'be_supplied',
 'be_increased_holding',
 'supply',
 'fall',
 'rise',
 'compete',
 'reduce_holding',
 'be_reduced_holding',
 'be_invested'

In [14]:
# dyset[k] collects, for relation_type[k], every [source code, neighbour code, time]
# triple whose two endpoints are both in the tracked stock list.
dyset = []
# Same codes as a set: O(1) membership tests in the innermost loop.
tracked_codes = set(all_num_dict)
for relation in relation_type:
    temp_set = []
    for i in tqdm(all_num_dict):
        query = 'match (n:company)-[r1:'+relation+']-(c:company) where n.code contains \''+i+'\' return c.code as code, r1.time as time '
        records = fintech.get_read_query(query)
        # Drop rows with an empty / missing code or time, and duplicate rows,
        # without mutating the frame in place.
        rows = (
            pd.DataFrame([dict(record) for record in records])
            .replace('', np.nan)
            .dropna()
            .drop_duplicates()
            .values.tolist()
        )
        for row in rows:
            neighbor_code = re.sub("[^0-9]", "", row[0])
            if neighbor_code in tracked_codes:
                temp_set.append([i, neighbor_code, row[1]])

    dyset.append(temp_set)
        

100%|█████████████████████████████████████████| 735/735 [01:23<00:00,  8.82it/s]
100%|█████████████████████████████████████████| 735/735 [01:22<00:00,  8.93it/s]
100%|█████████████████████████████████████████| 735/735 [01:19<00:00,  9.26it/s]
100%|█████████████████████████████████████████| 735/735 [01:18<00:00,  9.39it/s]
100%|█████████████████████████████████████████| 735/735 [01:20<00:00,  9.14it/s]
100%|█████████████████████████████████████████| 735/735 [01:20<00:00,  9.17it/s]
100%|█████████████████████████████████████████| 735/735 [01:17<00:00,  9.53it/s]
100%|█████████████████████████████████████████| 735/735 [01:18<00:00,  9.33it/s]
100%|█████████████████████████████████████████| 735/735 [01:18<00:00,  9.40it/s]
100%|█████████████████████████████████████████| 735/735 [01:18<00:00,  9.40it/s]
100%|█████████████████████████████████████████| 735/735 [01:18<00:00,  9.32it/s]
100%|█████████████████████████████████████████| 735/735 [01:20<00:00,  9.08it/s]
100%|███████████████████████

In [18]:
# Adjacency matrix for the "increase holding" relation (both directions).
# The relation lists are selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = (
    dyset[relation_type.index('increase_holding')]
    + dyset[relation_type.index('be_increased_holding')]
)
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_increase_holding.npy', neighbor)

In [20]:
# Adjacency matrix for the "same industry" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('same_industry')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_same_industry.npy', neighbor)

In [24]:
# Adjacency matrix for the "dispute" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('dispute')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_dispute.npy', neighbor)

In [22]:
# Adjacency matrix for the "superior" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('superior')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_superior.npy', neighbor)

In [23]:
# Adjacency matrix for the "cooperate" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('cooperate')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

# NOTE(review): the file name is spelled "cooprate"; it is kept unchanged because
# other code may load this exact path.
np.save('data/survey_cooprate.npy', neighbor)

In [25]:
# Adjacency matrix for the "invest" relation (both directions).
# The relation lists are selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = (
    dyset[relation_type.index('invest')]
    + dyset[relation_type.index('be_invested')]
)
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_invest.npy', neighbor)

In [26]:
# Adjacency matrix for the "supply" relation (both directions).
# The relation lists are selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = (
    dyset[relation_type.index('be_supplied')]
    + dyset[relation_type.index('supply')]
)
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_supply.npy', neighbor)

In [27]:
# Adjacency matrix for the "fall" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('fall')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_fall.npy', neighbor)

In [28]:
# Adjacency matrix for the "rise" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('rise')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_rise.npy', neighbor)

In [29]:
# Adjacency matrix for the "compete" relation.
# The relation list is selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = dyset[relation_type.index('compete')]
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_compete.npy', neighbor)

In [30]:
# Adjacency matrix for the "reduce holding" relation (both directions).
# The relation lists are selected by name rather than by a hard-coded position
# in `dyset` (which is built in the order of `relation_type`), so the result
# does not depend on how `relation_type` happens to be ordered.
target = (
    dyset[relation_type.index('reduce_holding')]
    + dyset[relation_type.index('be_reduced_holding')]
)
neighbor = np.zeros([len(index_dict), len(index_dict)])

for group in target:
    # group = [source code, neighbour code, time]; only the two codes are used.
    # Codes starting with '6' get the 'SH' prefix, every other code gets 'SZ'.
    index0, index1 = (
        index_dict[('SH' if code[0] == '6' else 'SZ') + code]
        for code in group[:2]
    )
    neighbor[index0, index1] = 1

np.save('data/survey_reduce_holding.npy', neighbor)

In [37]:
# Shape check: the most recently built `neighbor` matrix with a trailing axis added.
np.expand_dims(neighbor, axis=2).shape

(735, 735, 1)