In [1]:
import time
import difflib
import pickle
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph

import networkx as nx
from sklearn.metrics import accuracy_score

In [24]:
# Adapted from the DGLGraph tutorial.
class NodeApplyModule(nn.Module):
    """Per-node update: a linear layer followed by an optional activation.

    Args:
        in_feats: size of the ``'h'`` feature read from each node.
        out_feats: size of the ``'h'`` feature produced for each node.
        activation: callable applied to the linear output, or ``None`` to
            return the linear output unchanged.
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Transform ``node.data['h']`` and return it under the key ``'h'``."""
        projected = self.linear(node.data['h'])
        if self.activation is None:
            return {'h': projected}
        return {'h': self.activation(projected)}

# Message function: every edge forwards its source node's 'h' feature as message 'm'.
gcn_msg = fn.copy_src(src='h', out='m')
# Reduce function: every node sums its incoming 'm' messages and stores the result in 'h'.
gcn_reduce = fn.sum(msg='m', out='h')

class GCN(nn.Module):
    """One graph-convolution step: aggregate neighbour features, then update.

    Neighbour features are combined with the module-level ``gcn_msg`` /
    ``gcn_reduce`` pair (copy the source feature, sum at the destination) and
    the result is passed through a ``NodeApplyModule``.

    NOTE(review): the aggregation is a plain sum with no degree normalisation
    and no self-loop is added here, so the scale of the output grows with a
    node's in-degree -- confirm that this is intended.

    Args:
        in_feats: size of the incoming node feature.
        out_feats: size of the node feature produced by this layer.
        activation: activation handed to ``NodeApplyModule`` (may be ``None``).
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        """Run the layer on graph ``g`` and return the new node features.

        ``feature`` is stored on the graph as ``ndata['h']`` for the duration
        of the call and removed again before returning.
        """
        g.ndata['h'] = feature
        g.update_all(message_func=gcn_msg, reduce_func=gcn_reduce)
        g.apply_nodes(self.apply_mod)
        updated = g.ndata.pop('h')
        return updated
    
class Net(nn.Module):
    """Two stacked GCN layers followed by a linear two-class classifier.

    Layer widths are fixed: 768 -> 192 -> 96 -> 2. Both GCN layers use ReLU.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.gcn1 = GCN(768, 192, F.relu)
        self.gcn2 = GCN(192, 96, F.relu)
        self.fc = nn.Linear(96, 2)

    def forward(self, g, features):
        """Return unnormalised class scores (logits), one row per node.

        Args:
            g: graph the two GCN layers operate on.
            features: node feature matrix fed to the first GCN layer.
        """
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return self.fc(x)

    def predict(self, g, features):
        """Return the predicted class (0 or 1) of every node.

        Args:
            g: graph passed on to ``forward``.
            features: node feature matrix passed on to ``forward``.

        Returns:
            A 1-D int64 CPU tensor with one entry per node. A node is labelled
            0 only when its class-0 probability is strictly greater than its
            class-1 probability; ties (and NaN scores) are labelled 1.
        """
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            # Softmax over the class axis; an explicit ``dim`` avoids the
            # deprecated implicit-dimension behaviour (and its warning).
            probs = F.softmax(self(g, features), dim=1)
            # ``not (p0 > p1)`` rather than ``p0 <= p1`` so NaN rows map to 1.
            labels = (~(probs[:, 0] > probs[:, 1])).long()
        return labels.cpu()
# Instantiate the model. No random seed is set in the visible cells, so the
# initial weights (and therefore the results below) can differ between runs.
net = Net()

In [3]:
# Load the golden data produced by DealWithData.
# Claims table: tab-separated text file.
data_golden = pd.read_csv( './DataSet/book/golden/claims_golden.txt' , sep='\t' )
# Attach the saved encodings as the 'encode' column.
# NOTE(review): presumably one 768-dim vector per claim, in the same row order
# as the claims file -- confirm against the notebook that wrote the .pt file.
data_golden['encode'] = torch.load('./DataSet/book/golden/claims_golden_encode.pt')
# Golden labels: header-less tab-separated file, read with columns 'isbn' and 'author'.
GoldenLabel = pd.read_table("./DataSet/book/book_golden.txt" , sep='\t' , header=None , names=['isbn','author'])

In [4]:
# Split into train and test sets in a rough-and-ready way; intuition says the
# split should be made by 'isbn' so all claims about one book stay together.
def _claims_for_isbns(claims, isbns):
    """Return the rows of ``claims`` whose 'isbn' is in ``isbns``.

    Rows are grouped in the order the ISBNs are listed (rows for the first
    ISBN, then the second, ...), keeping the original row order inside each
    group. The result has a fresh 0..n-1 index. An ISBN listed twice
    contributes its rows twice.
    """
    groups = [claims[claims['isbn'] == isbn] for isbn in isbns]
    groups = [group for group in groups if not group.empty]
    if not groups:
        # Nothing matched: an empty frame with the same columns and dtypes.
        return claims.iloc[0:0].reset_index(drop=True)
    # A single concat instead of DataFrame.append in a loop: append was
    # removed in pandas 2.0 and re-copied the accumulated frame every time.
    return pd.concat(groups, ignore_index=True)


# The first half (plus one) of the golden ISBNs go to training, the rest to test.
_split_at = len(GoldenLabel) // 2 + 1
_golden_isbns = GoldenLabel['isbn'].tolist()
data_train = _claims_for_isbns(data_golden, _golden_isbns[:_split_at])
data_test = _claims_for_isbns(data_golden, _golden_isbns[_split_at:])

In [None]:
# Graph-construction helper; applied separately to the training set and the
# test set to link their rows.
# TODO: can the pair scan be reduced from ~n**2/2 comparisons to about k*n?
# TODO: further edges could be added based on 'book_name' similarity.
def generate_DGLGraph(df, author_sim_threshold=0.8):
    """Build a DGLGraph with one node per row of ``df``.

    Two rows ``i < j`` are linked (with both directed edges ``i->j`` and
    ``j->i``) when either

    * they have the same ``'source'``, or
    * they have the same ``'isbn'`` and their ``'author'`` strings have a
      ``difflib.SequenceMatcher(...).quick_ratio()`` above
      ``author_sim_threshold``.

    Args:
        df: DataFrame with ``'source'``, ``'isbn'`` and ``'author'`` columns.
            Rows are addressed by position, so node ``k`` is the k-th row.
        author_sim_threshold: similarity a same-ISBN pair must exceed to be
            linked. Defaults to 0.8.

    Returns:
        The constructed ``DGLGraph``. Edges are inserted in pair-scan order
        (``i`` ascending, then ``j`` ascending, ``i->j`` before ``j->i``).
    """
    n_rows = len(df)
    # Pull the columns out once; per-cell df.loc lookups inside the O(n^2)
    # loop dominated the running time.
    sources = df["source"].tolist()
    isbns = df["isbn"].tolist()
    authors = df["author"].tolist()

    src_ids, dst_ids = [], []
    for i in range(n_rows):
        for j in range(i + 1, n_rows):
            if sources[i] == sources[j]:
                linked = True
            elif isbns[i] == isbns[j]:
                # quick_ratio() is a cheap, character-count based upper bound
                # on the full ratio().
                matcher = difflib.SequenceMatcher(None, authors[i], authors[j])
                linked = matcher.quick_ratio() > author_sim_threshold
            else:
                linked = False
            if linked:
                src_ids += (i, j)
                dst_ids += (j, i)

    g = DGLGraph()
    g.add_nodes(n_rows)
    if src_ids:
        # One bulk insertion; the order of the lists preserves the edge ids.
        g.add_edges(src_ids, dst_ids)
    return g

# Build one graph per split; each graph only links rows of its own DataFrame,
# and node k corresponds to row k of that DataFrame.
graph_train = generate_DGLGraph(data_train)
graph_test = generate_DGLGraph(data_test)

In [None]:
# Save the graph structures.
# ``with`` closes each file even if pickle.dump raises part-way through.
with open('./DataSet/book/golden/graph_train.pickle', 'wb') as file:
    pickle.dump(graph_train, file)
with open('./DataSet/book/golden/graph_test.pickle', 'wb') as file:
    pickle.dump(graph_test, file)

In [5]:
# Load the graph structures.
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
# NOTE(review): the node order of the stored graphs must match the current row
# order of data_train / data_test -- rebuild the pickles if the split changes.
with open('./DataSet/book/golden/graph_train.pickle', 'rb') as file:
    graph_train =pickle.load(file)
with open('./DataSet/book/golden/graph_test.pickle', 'rb') as file:
    graph_test =pickle.load(file)

In [None]:
# Failed attempt at saving and loading the graph structure via networkx/GEXF (kept for reference).
'''
graph_train_netx = graph_train.to_networkx()
graph_test_netx = graph_test.to_networkx()
nx.write_gexf(graph_train_netx,'./DataSet/book/golden/train_graph.gexf')
nx.write_gexf(graph_test_netx,'./DataSet/book/golden/test_graph.gexf')

graph_train_netx = nx.read_gexf('./DataSet/book/golden/train_graph.gexf')
graph_test_netx = nx.read_gexf('./DataSet/book/golden/test_graph.gexf')

graph_train2 = DGLGraph(graph_train_netx)
#graph_train2.from_networkx(graph_train_netx)
graph_test2 = DGLGraph(graph_test_netx)
#graph_test2.from_networkx(graph_test_netx)
'''

In [6]:
def extract_featureNlabel(df):
    """Turn the 'encode' and 'label' columns of ``df`` into model tensors.

    Args:
        df: DataFrame whose ``'encode'`` column holds tensors that can be
            reshaped to ``(-1, 768)`` and whose ``'label'`` column holds
            truthy/falsy values.

    Returns:
        ``(features, labels)``: ``features`` is the row-wise concatenation of
        all reshaped encodings, shape ``(rows, 768)``; ``labels`` is a 1-D
        int64 tensor holding 1 for each truthy label and 0 otherwise. An
        empty ``df`` yields shapes ``(0, 768)`` and ``(0,)``.
    """
    # The leading empty float32 block keeps the result well defined for an
    # empty frame; everything is concatenated once instead of re-copying the
    # accumulated tensor on every row.
    blocks = [torch.zeros(0, 768)]
    blocks.extend(encoding.reshape([-1, 768]) for encoding in df['encode'])
    features = torch.cat(blocks, 0)
    labels = torch.tensor([1 if flag else 0 for flag in df['label']],
                          dtype=torch.long)
    return features, labels
# Feature matrix and label vector for each split; row k lines up with node k
# of the matching graph as long as the DataFrame row order is unchanged.
train_feature,train_label = extract_featureNlabel(data_train)
test_feature,test_label = extract_featureNlabel(data_test)

In [30]:
# Test accuracy before the training cell below (in top-to-bottom order).
pred_label = net.predict(graph_test, test_feature)
# accuracy_score is documented as (y_true, y_pred); the arguments are passed the
# other way round here, which does not change the accuracy value.
print('accu',accuracy_score(pred_label,test_label))

accu 0.6070549630844955




In [36]:
# Full-batch training: every epoch runs the whole training graph through the
# network once, takes one Adam step, then reports train and test accuracy.
# NOTE(review): in the recorded output below the loss stays at 0.6930 and both
# accuracies never change, i.e. the model is not learning in that run --
# possibly related to the unnormalised sum aggregation in GCN; investigate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []  # per-epoch wall-clock durations (includes the evaluation passes)
for epoch in range(100):
    
    t0 = time.time()

    # Despite the name, pred_prob holds the raw scores returned by Net.forward.
    pred_prob = net.forward(graph_train, train_feature)
    loss = criterion(pred_prob,train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    

    # Evaluate on both splits after the update.
    pred_label = net.predict(graph_train, train_feature)
    train_accu = accuracy_score(pred_label,train_label)
    pred_label = net.predict(graph_test, test_feature)
    print('test_accu',accuracy_score(pred_label,test_label))

    dur.append(time.time() - t0)
    # Time(s) is the mean duration over all epochs so far. '{:4f}' sets a
    # minimum width of 4 (not 4 decimals), so Train_Accu prints 6 decimals.
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | Train_Accu {:4f}".format(
            epoch, loss.item(), np.mean(dur), train_accu))



test_accu 0.6300246103363413
Epoch 00000 | Loss 0.6930 | Time(s) 0.7068 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00001 | Loss 0.6930 | Time(s) 0.7026 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00002 | Loss 0.6930 | Time(s) 0.6980 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00003 | Loss 0.6930 | Time(s) 0.7017 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00004 | Loss 0.6930 | Time(s) 0.7012 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00005 | Loss 0.6930 | Time(s) 0.7000 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00006 | Loss 0.6930 | Time(s) 0.7058 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00007 | Loss 0.6930 | Time(s) 0.7060 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00008 | Loss 0.6930 | Time(s) 0.7024 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00009 | Loss 0.6930 | Time(s) 0.7021 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00010 | Loss 0.6930 | Tim

test_accu 0.6300246103363413
Epoch 00088 | Loss 0.6930 | Time(s) 0.7196 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00089 | Loss 0.6930 | Time(s) 0.7192 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00090 | Loss 0.6930 | Time(s) 0.7200 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00091 | Loss 0.6930 | Time(s) 0.7214 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00092 | Loss 0.6930 | Time(s) 0.7212 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00093 | Loss 0.6930 | Time(s) 0.7207 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00094 | Loss 0.6930 | Time(s) 0.7203 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00095 | Loss 0.6930 | Time(s) 0.7199 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00096 | Loss 0.6930 | Time(s) 0.7198 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00097 | Loss 0.6930 | Time(s) 0.7198 | Train_Accu 0.509356
test_accu 0.6300246103363413
Epoch 00098 | Loss 0.6930 | Tim

In [None]:
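# Toy dataset for checking that graph construction produces the expected edges.
# Testing showed that difflib's similarity measure works at the character level.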
'''
dataframe = pd.DataFrame([
        ["a", "111222", "computer Science", "bruce"],
        ["b", "111222", "computer Science", "Bruce Lee"],
        ["c", "111222", "computer Science", "mike ,john"],
        ["a", "111223", "Hassdsdsaad", "kkl"],
        ["d", "111223", "Hassdsdaaad", "kkkl"],
        ["c", "111224", "asdfgh", "zxcr"]
    ],
    columns=["source", "isbn", "name", "author"]
)
g = generate_DGLGraph(dataframe)
'''