In [2]:
import pandas as pd
import pickle
import os
from itertools import chain
from six.moves import reduce
import numpy as np
import torch
from torch import optim, nn
import torch.nn.functional as F
from sklearn import metrics

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:2" if USE_CUDA else "cpu")
random_state = 42
learning_rate = 0.0002
batch_size = 16
epsilon = 1e-8
max_grad_norm = 40.0
evaluation_interval = 1
hops = 3
epochs = 20
embedding_size = 20

In [3]:
import re
from pyhanlp import *
from pyltp import Segmentor, Postagger
import random
from gensim.models import KeyedVectors


LTP_DATA_DIR = '/home/wujipeng/data/ltp_data_v3.4.0/'  # ltp模型目录的路径
cws_model_path = os.path.join(LTP_DATA_DIR, 'cws.model')  # 分词模型路径，模型名称为`cws.model`
segmentor = Segmentor()  # 初始化实例
segmentor.load(cws_model_path)  # 加载模型


def hanlp_cut(sentence):
    return [term.word for term in HanLP.segment(sentence)]


def ltp_cut(sentence):
    return segmentor.segment(sentence)


def maxS(alist):
    maxScore = 0.0
    maxIndex = -1
    for i in range(len(alist)):
        if alist[i] > maxScore:
            maxScore = alist[i]
            maxIndex = i
    return maxScore, maxIndex


def sampling():
    numOfSentence = 2105
    testSentence = set()
    for i in range(1, numOfSentence + 1):
        if random.random() > 0.9:
            testSentence.add(i)
    rate = float(len(testSentence)) / numOfSentence
    #    print(rate)
    if rate > 0.1 and rate - 0.1 < 0.0003:
        print(len(testSentence))
        return testSentence
    else:
        return sampling()


def division(testSentence, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testSentenceFile = open(os.path.join(data_path, 'testSentenceFile_{}.csv'.format(n)), 'w')
    for index, item in enumerate(testSentence):
        testSentenceFile.write(str(item) + '\n')
    testSentenceFile.close()

    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    outputFile1 = open(os.path.join(data_path, 'clause_test_{}.csv'.format(n)), 'w')
    outputFile2 = open(os.path.join(data_path, 'clause_train_{}.csv'.format(n)), 'w')
    test_pos_count, test_neg_count, train_pos_count, train_neg_count = 0, 0, 0, 0
    for _, line in enumerate(inputFile):
        sentenceID = int(line.strip().split(',')[0])
        sentence = line.strip().split(',')[-1]
        keyword = None
        keyPos = -1
        clauseList = re.split('，|。|？|！|；|……', sentence)
        causeOfSent = set()
        for index, item in enumerate(clauseList):
            match = re.search(r'\[f\]([^\[]*)\[/f\]', item)
            if match:
                keyword = match.group(1)
                keyPos = index
            match = re.search(r'\[\d[nv]\][^\[]*\[/\d[nv]\]', item)
            if match:
                causeOfSent.add(index)
            match = re.search(r'\[[-\d]*\*\d[nv]\][^\[]*\[/[-\d]*\*\d[nv]\]', item)
            if match:
                causeOfSent.add(index)
        for index, item in enumerate(clauseList):
            clause = re.sub(r'\[[^\[]*\]', '', item)
            clause = re.sub(r'“', '', clause)
            clause = re.sub(r'”', '', clause)
            clause = re.sub(r'：', '', clause)
            clause = re.sub(r'\(', '', clause)
            clause = re.sub(r'\)', '', clause)
            clause = re.sub(r'、', '', clause)
            clause = re.sub(r'’', '', clause)
            clause = re.sub(r'‘', '', clause)
            clause = re.sub(r'》', '', clause)
            clause = re.sub(r'《', '', clause)
            clause = re.sub(r'~', '', clause)
            words = ltp_cut(clause)
            pos = index - keyPos
            if (sentenceID in testSentence) and (clause.split()):
                if index in causeOfSent:
                    test_pos_count += 1
                    outputFile1.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile1.write(str(pos) + ',yes,' + ' '.join(words) + '\n')
                else:
                    test_neg_count += 1
                    outputFile1.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile1.write(str(pos) + ',no,' + ' '.join(words) + '\n')
            elif (sentenceID not in testSentence) and (clause.split()):
                if index in causeOfSent:
                    train_pos_count += 1
                    outputFile2.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile2.write(str(pos) + ',yes,' + ' '.join(words) + '\n')
                else:
                    train_neg_count += 1
                    outputFile2.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile2.write(str(pos) + ',no,' + ' '.join(words) + '\n')
    print('********************************')
    print('test_pos_count', test_pos_count)
    print('test_neg_count', test_neg_count)
    print('train_pos_count', train_pos_count)
    print('train_neg_count', train_neg_count)
    print('********************************')
    inputFile.close()
    outputFile1.close()
    outputFile2.close()


def statisticPos(mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    posDict = dict()
    for _, line in enumerate(inputFile):
        content = line.strip().split(',')[-1]
        clauseList = re.split('，|。|？|！|；|……', content)
        keyPos = -1
        for index, item in enumerate(clauseList):
            match = re.search(r'\[f\]([^\[]*)\[/f\]', item)
            if match:
                keyPos = index
        for index, item in enumerate(clauseList):
            pos = index - keyPos
            if posDict.get(pos):
                posDict[pos] += 1
            else:
                posDict[pos] = 1
    inputFile.close()
    return posDict


def changePos(posDict):
    count = 0
    countDict = dict()
    for key, value in posDict.items():
        countDict[key] = count
        count += 1
    posList = ['AAAA', 'AAAB', 'AAAC', 'AAAD', 'AABA', 'AABB', 'AABC', 'AABD', 'AACA', 'AACB', 'AACC', 'AACD',
               'AADA', 'AADB', 'AADC', 'AADD', 'ABAA', 'ABAB', 'ABAC', 'ABAD', 'ABBA', 'ABBB', 'ABBC', 'ABBD',
               'ABCA', 'ABCB', 'ABCC', 'ABCD', 'ABDA', 'ABDB', 'ABDC', 'ABDD', 'ACAA', 'ACAB', 'ACAC', 'ACAD',
               'ACBA', 'ACBB', 'ACBC', 'ACBD', 'ACCA', 'ACCB', 'ACCC', 'ACCD', 'ACDA', 'ACDB', 'ACDC', 'ACDD',
               'ADAA', 'ADAB', 'ADAC', 'ADAD', 'ADBA', 'ADBB', 'ADBC', 'ADBD', 'ADCA', 'ADCB', 'ADCC', 'ADCD',
               'ADDA', 'ADDB', 'ADDC', 'ADDD', 'BAAA', 'BAAB', 'BAAC', 'BAAD', 'BABA', 'BABB', 'BABC', 'BABD',
               'BACA', 'BACB', 'BACC', 'BACD', 'BADA', 'BADB', 'BADC', 'BADD', 'BBAA', 'BBAB', 'BBAC', 'BBAD',
               'BBBA', 'BBBB', 'BBBC', 'BBBD', 'BBCA', 'BBCB', 'BBCC', 'BBCD', 'BCAA', 'BCAB', 'BCAC', 'BCAD',
               'BDAA', 'BDAB', 'BDAC', 'BDAD', 'BCBA', 'BCBB', 'BCBC']
    strPosDict = dict()
    for key, value in countDict.items():
        strPosDict[key] = posList[value]
    return strPosDict


def construction(strPosDict, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile1 = open(os.path.join(data_path, 'clause_test_{}.csv'.format(n)), 'r')
    inputFile2 = open(os.path.join(data_path, 'clause_train_{}.csv'.format(n)), 'r')
    outputFile1 = open(os.path.join(data_path, 'emotion_cause_clause_level_test_{}.csv'.format(n)), 'w')
    outputFile2 = open(os.path.join(data_path, 'emotion_cause_clause_level_train_{}.csv'.format(n)), 'w')
    for _, line in enumerate(inputFile1):
        keyword = line.strip().split(',')[2]
        position = int(line.strip().split(',')[3])
        posStr = strPosDict[position]
        label = line.strip().split(',')[4]
        clause = line.strip().split(',')[5]
        wordList = clause.strip().split(' ')
        phraseList = []
        if len(wordList) >= 3:
            window = 3
            begin = 0
            for index, item in enumerate(wordList):
                end = begin + window
                if end <= len(wordList):
                    phraseList.append(wordList[begin: end])
                begin += 1
        else:
            phraseList.append(wordList[:])
        lineNum = 1
        for index, item in enumerate(phraseList):
            outputFile1.write(str(lineNum))
            for i in range(len(item)):
                outputFile1.write(' ' + item[i])
            outputFile1.write('\n')
            lineNum += 1
        outputFile1.write(str(lineNum) + ' ' + posStr + ' ' + posStr + ' ' + posStr + '\n')
        lineNum += 1
        outputFile1.write(str(lineNum) + ' ' + keyword + ' ' + keyword + ' ' + keyword + '\t' + label + '\n')
    inputFile1.close()
    outputFile1.close()
    for _, line in enumerate(inputFile2):
        keyword = line.strip().split(',')[2]
        position = int(line.strip().split(',')[3])
        posStr = strPosDict[position]
        label = line.strip().split(',')[4]
        clause = line.strip().split(',')[5]
        wordList = clause.strip().split(' ')
        phraseList = []
        if len(wordList) >= 3:
            window = 3
            begin = 0
            for index, item in enumerate(wordList):
                end = begin + window
                if end <= len(wordList):
                    phraseList.append(wordList[begin: end])
                begin += 1
        else:
            phraseList.append(wordList[:])
        lineNum = 1
        for index, item in enumerate(phraseList):
            outputFile2.write(str(lineNum))
            for i in range(len(item)):
                outputFile2.write(' ' + item[i])
            outputFile2.write('\n')
            lineNum += 1
        outputFile2.write(str(lineNum) + ' ' + posStr + ' ' + posStr + ' ' + posStr + '\n')
        lineNum += 1
        outputFile2.write(str(lineNum) + ' ' + keyword + ' ' + keyword + ' ' + keyword + '\t' + label + '\n')
    inputFile2.close()
    outputFile2.close()


def extractAll(mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    allClause = []
    for index, line in enumerate(inputFile):
        sent = int(line.strip().split(',')[0])  # int
        content = line.strip().split(',')[-1]
        clauseList = re.split('，|。|？|！|；|……', content)
        for index, clause in enumerate(clauseList):
            cause = 'n'
            match = re.search(r'\[\d[nv]\][^\[]*\[/\d[nv]\]', clause)
            if match:
                cause = 'c'
            match = re.search(r'\[[-\d]*\*\d[nv]\][^\[]*\[/[-\d]*\*\d[nv]\]', clause)
            if match:
                cause = 'c'
            clause = re.sub(r'\[[^\[]*\]', '', clause)
            clause = re.sub(r'“', '', clause)
            clause = re.sub(r'”', '', clause)
            clause = re.sub(r'：', '', clause)
            clause = re.sub(r'\(', '', clause)
            clause = re.sub(r'\)', '', clause)
            clause = re.sub(r'、', '', clause)
            clause = re.sub(r'’', '', clause)
            clause = re.sub(r'‘', '', clause)
            clause = re.sub(r'》', '', clause)
            clause = re.sub(r'《', '', clause)
            if clause.split():
                allClause.append([sent, index + 1, cause])
    print('allClause', len(allClause))
    inputFile.close()
    return allClause


def extractRealRight(allClause, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testSentenceFile = open(os.path.join(data_path, 'testSentenceFile_{}.csv'.format(n)), 'r')
    testSentence = set()
    for index, item in enumerate(testSentenceFile):
        item = int(item.strip())
        testSentence.add(item)
    testSentenceFile.close()
    testClause = []
    for index, item in enumerate(allClause):
        sent = allClause[index][0]
        if sent in testSentence:
            testClause.append(item)
    print('testClause(all)', len(testClause))
    realRight = []
    realSet = set()
    for index, item in enumerate(testClause):
        cSent = item[0]
        cClause = item[1]
        cCause = item[2]
        if cCause == 'c':
            realRight.append([cSent, cClause])
            key = str(cSent) + ',' + str(cClause)
            realSet.add(key)
    print('realRight', len(realRight))
    print('realSet', len(realSet))
    return realRight, realSet


def extractPredRight(n, m, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile = open(os.path.join(data_path, 'clause_test_{}.csv'.format(n)), 'r')
    sentAndClause = []
    clauseTest = []
    keywords = []
    for index, line in enumerate(inputFile):
        sent = line.strip().split(',')[0]
        clause = line.strip().split(',')[1]
        keyword = line.strip().split(',')[2]
        content = []
        content = line.strip().split(',')[-1].split(' ')
        sentAndClause.append([sent, clause])
        keywords.append(keyword)
        clauseTest.append(content)
    print('sentAndClause', len(sentAndClause))  # the num of testing events
    print('clauseTest', len(clauseTest))
    print('keywords', len(keywords))
    inputFile.close()
    inputFile = open(os.path.join(data_path, 'prediction_{}_{}.csv'.format(n, m)), 'r')
    scoreList = []
    for index, item in enumerate(inputFile):
        score = float(item.strip())
        scoreList.append(score)
    print('score', len(scoreList))
    inputFile.close()
    predict = dict()
    maxScore = 0
    for index, item in enumerate(sentAndClause):
        cSent = int(item[0])
        cClause = int(item[1])
        score = scoreList[index]
        if predict.get(cSent):
            if score > maxScore:
                predict[cSent] = cClause
                maxScore = score
        else:
            predict[cSent] = cClause
            maxScore = score
    print('predict', len(predict))
    predictRight = []
    predSet = set()
    for key, value in predict.items():
        predictRight.append([key, value])
        key = str(key) + ',' + str(value)
        predSet.add(key)
    print('predictRight', len(predictRight))
    print('predSet', len(predSet))
    return predictRight, predSet, sentAndClause, keywords, clauseTest


def statistics(n, m, realRight, realSet, predictRight, predSet, sentAndClause, keywords, clauseTest, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    suc = 0
    sucSet = set()
    for i in range(len(realRight)):
        for j in range(len(predictRight)):
            if realRight[i] == predictRight[j]:
                suc += 1
                key = str(realRight[i][0]) + ',' + str(realRight[i][1])
                sucSet.add(key)
    print('suc', suc)
    precision = float(suc) / len(predictRight)
    recall = float(suc) / len(realRight)
    f1 = 2 * precision * recall / (precision + recall)
    print('precision', precision)
    print('recall', recall)
    print('f1', f1)
    print('****************************************')
    inputFile = open(os.path.join(data_path, 'statistics_{}_{}.csv'.format(n, m)), 'r')
    outputFile = open(os.path.join(data_path, 'statistics_final_{}_{}.csv'.format(n, m)), 'w')
    for index, line in enumerate(inputFile):
        if index == 0:
            line = line.strip() + '\tsentence\tclause\tpredict\treal\tsucc\tkeyword\tcontent\n'
        else:
            line = line.strip() + '\t' + str(sentAndClause[index - 1][0]) + '\t' + str(
                sentAndClause[index - 1][1]) + '\t'
            key = str(sentAndClause[index - 1][0]) + ',' + str(sentAndClause[index - 1][1])
            if key in predSet:
                line += 'c' + '\t'
            else:
                line += 'n' + '\t'
            if key in realSet:
                line += 'c' + '\t'
            else:
                line += 'n' + '\t'
            if key in sucSet:
                line += 'suc' + '\t'
            else:
                line += '###' + '\t'
            line += keywords[index - 1] + '\t' + ' '.join(clauseTest[index - 1]) + '\n'
        outputFile.write(line)
    inputFile.close()
    outputFile.close()
    return precision, recall, f1

import numpy as np
def load_data(n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testFile = open(os.path.join(data_path, 'emotion_cause_clause_level_test_{}.csv'.format(n)), 'r')
    trainFile = open(os.path.join(data_path, 'emotion_cause_clause_level_train_{}.csv'.format(n)), 'r')
    testData = []
    testStory = []
    for _, line in enumerate(testFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            testStory = []
            items = nLine[:]
            testStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in testStory if x]
            testData.append((subStory, q, a))
        else:
            items = nLine[:]
            testStory.append(items)
    testFile.close()
    trainData = []
    trainStory = []
    for _, line in enumerate(trainFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            trainStory = []
            items = nLine[:]
            trainStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in trainStory if x]
            trainData.append((subStory, q, a))
        else:
            items = nLine[:]
            trainStory.append(items)
    trainFile.close()
    return testData, trainData


def vectorize_data(data, word_idx, memory_size, sentence_size):
    S = []
    Q = []
    A = []
    for story, query, answer in data:
        ss = []
        for index, sentence in enumerate(story):
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)
        # take only the most recent sentences that fit in memory
        ss = ss[::-1][:memory_size][::-1]

        # pad to memory_size
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)

        q = [word_idx[w] for w in query]

        a = np.zeros(2)  # the answer is yes or no
        for an in answer:
            if an == 'yes':
                a[1] = 1
            elif an == 'no':
                a[0] = 1

        S.append(ss)
        Q.append(q)
        A.append(a)
    return np.array(S), np.array(Q), np.array(A)

def create_embedding(vocab, mode='train'):
    print('读取预训练Embbeding')
    word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/ec/data/embedding/seg_resource.bin', binary=False)
    dim = word2vec.vector_size
    embedding = [np.zeros(dim)]  # pad
    cnt = 0
    for word in vocab:
        if word2vec.vocab.get(word):
            embedding.append(word2vec.get_vector(word))
            cnt += 1
        else:
            embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
    embedding = np.array(embedding)
    print('Embedding shape', embedding.shape)
    print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
    return embedding
# def create_embedding(vocab, mode='train'):
#     from gensim.models import KeyedVectors
#     print('读取预训练Embbeding')
#     word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/ec/data/embedding/vec_new.txt', binary=False)
#     dim = word2vec.vector_size
#     embedding = [np.zeros(dim)]  # pad
#     cnt = 0
#     for word in vocab:
#         if word2vec.vocab.get(word):
#             embedding.append(word2vec.get_vector(word))
#             cnt += 1
#         else:
#             embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
#     embedding = np.array(embedding)
#     print('Embedding shape', embedding.shape)
#     print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
#     return np.array(embedding)    

# def create_embedding(vocab, mode='train'):
#     from gensim.models import KeyedVectors
#     print('读取预训练Embbeding')
#     word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/embedding/Wikipedia/sgns.wiki.word', binary=False)
#     dim = word2vec.vector_size
#     embedding = [np.zeros(dim)]  # pad
#     cnt = 0
#     for word in vocab:
#         if word2vec.vocab.get(word):
#             embedding.append(word2vec.get_vector(word))
#             cnt += 1
#         else:
#             embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
#     embedding = np.array(embedding)
#     print('Embedding shape', embedding.shape)
#     print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
#     return np.array(embedding)

In [4]:
import numpy as np
def load_data(n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testFile = open(os.path.join('/home/wujipeng/git/ECA/ec/MemNet/', data_path, 'emotion_cause_clause_level_test_{}.csv'.format(n)), 'r')
    trainFile = open(os.path.join('/home/wujipeng/git/ECA/ec/MemNet/', data_path, 'emotion_cause_clause_level_train_{}.csv'.format(n)), 'r')
    testData = []
    testStory = []
    for _, line in enumerate(testFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            testStory = []
            items = nLine[:]
            testStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in testStory if x]
            testData.append((subStory, q, a))
        else:
            items = nLine[:]
            testStory.append(items)
    testFile.close()
    trainData = []
    trainStory = []
    for _, line in enumerate(trainFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            trainStory = []
            items = nLine[:]
            trainStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in trainStory if x]
            trainData.append((subStory, q, a))
        else:
            items = nLine[:]
            trainStory.append(items)
    trainFile.close()
    return testData, trainData

In [5]:
posDict = statisticPos()
strPosDict = changePos(posDict)
allClause = extractAll()
pre_all = 0.0
rec_all = 0.0
f1_all = 0.0
pre_all_ranking = 0.0
rec_all_ranking = 0.0
f1_all_ranking = 0.0
pre_all_list = []
rec_all_list = []
f1_all_list = []
pre_all_list_ranking = []
rec_all_list_ranking = []
f1_all_list_ranking = []
time = 1
time_list = []

allClause 31296


In [6]:
testSentence = sampling()
division(testSentence, time)
construction(strPosDict, time)
realRight, realSet = extractRealRight(allClause, time)

test, train = load_data(time)
data = test + train
vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q) for s, q, _ in data)))
word_idx = dict((c, k + 1) for k, c in enumerate(vocab))

embedding = create_embedding(vocab)
print('embedding', embedding.shape)
memory_size = max(map(len, (s for s, _, _ in data)))
sentence_size = 3
vocab_size = len(word_idx) + 1

211
********************************
test_pos_count 218
test_neg_count 2913
train_pos_count 1949
train_neg_count 26216
********************************
testClause(all) 3131
realRight 218
realSet 218
读取预训练Embbeding
Embedding shape (19908, 20)
Embedding rate: 62.58%
embedding (19908, 20)


In [7]:
def position_encoding(sentence_size, embedding_size):
    encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_size + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
    encoding = 1 + 4 * encoding / embedding_size / sentence_size
    return np.transpose(encoding)

class MemN2N(nn.Module): 
    def __init__(self, embedding, batch_size, vocab_size, sentence_size, memory_size, embedding_size,
                 answer_size=2, hops=3, max_grad_norm=40.0, encoding=position_encoding, name='MemN2N', 
                dropout=0.9):
        super(MemN2N, self).__init__()
        self._batch_size = batch_size
        self._vocab_size = vocab_size
        self._sentence_size = sentence_size
        self._memory_size = memory_size
        self._embedding_size = embedding_size
        self._answer_size = answer_size
        self._hops = hops
        self._max_grad_norm = max_grad_norm
        self._name = name
        self._dropout = dropout

        self._embedding = embedding
        
        self._encoding = torch.tensor(encoding(1, 3 * self._embedding_size)).to(device) # (1, 60)
        
        self.embedding = nn.Embedding(self._vocab_size, self._embedding_size)
        self.embedding.weight.data.copy_(torch.tensor(embedding))
        self.dropout = nn.Dropout(self._dropout)
        self.out = nn.Linear(3 * self._embedding_size, self._answer_size)
        
    

    def forward(self, stories, queries):
        q_emb0 = self.embedding(torch.tensor(queries).to(device)) # output_size: (16, 3, 300)
        q_emb = q_emb0.view(-1, 1, 3 * self._embedding_size)
        u_0 = torch.sum(q_emb * self._encoding, 1)
        u = [u_0] # (1, 16, 60)

        for i in range(self._hops):
            m_emb0 = self.embedding(torch.tensor(stories).to(device)) # (16, 40, 3, 300)
            m_emb = m_emb0.view(-1, self._memory_size, 1, 3 * self._embedding_size)
            m = torch.sum(m_emb * self._encoding, 2) # (16, 40, 900)
            
            u_temp = u[-1].unsqueeze(2).transpose(1, 2) # (16, 900, 1) -> (16, 1, 900)
            dotted = torch.sum(m * u_temp, 2) # (16, 40, 900) ->  # (16, 40)
            probs = F.softmax(dotted, 1) # (16, 40)
            probs_temp = probs.unsqueeze(2).transpose(1, 2) # (16, 40, 1) -> (16, 1, 40)
            
            c_emb0 = self.embedding(torch.tensor(stories).to(device)) # (16, 40, 3, 300)
            c_emb = c_emb0.view(-1, self._memory_size, 1, 3 * self._embedding_size) # (16, 40, 1, 900)
            c_temp = torch.sum(c_emb * self._encoding, 2) # (16, 40, 900)
            c = c_temp.transpose(1, 2) # (16, 900, 40)

            o_k = torch.sum(c * probs_temp, 2) # (16, 900)
            u_k = u[-1] + o_k # (16, 900)
            u.append(u_k)
        return self.dropout(self.out(u_k)) # (16, 2)
    

    def predict_log_proba(self, stories, queries):
        feed_dict = {self._stories: stories, self._queries: queries, self._dropout: 1.0}
        return self._sess.run(self._predict_log_proba_op, feed_dict=feed_dict)
    
    def gradient_noise_and_clip(self, parameters,
                                 noise_stddev=1e-3, max_clip=40.0):
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        norm = nn.utils.clip_grad_norm_(parameters, max_clip)

        for p in parameters:
            noise = torch.randn(p.size()) * noise_stddev
            noise = noise.to(device)
            p.grad.data.add_(noise)
        return norm

In [8]:
mode = 'train'
data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'

In [9]:
time = 1

In [10]:
while (True):    
    if (time > 25):
        break
    testSentence = set(map(int, [line.split(',')[0] for line in open('/data/wujipeng/ec/data/ltp_static/static.{}/val_set.txt'.format(time)).readlines()]))
    division(testSentence, time)
    construction(strPosDict, time)
    realRight, realSet = extractRealRight(allClause, time)

    test, train = load_data(time)
    data = test + train
    vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q) for s, q, _ in data)))
    word_idx = dict((c, k + 1) for k, c in enumerate(vocab))

    embedding = create_embedding(vocab, embedding_size)
    print('embedding', embedding.shape)
    memory_size = max(map(len, (s for s, _, _ in data)))
    sentence_size = 3
    vocab_size = len(word_idx) + 1
    print('memory_size', memory_size)
    print('sentence_size', sentence_size)
    print('vocab_size', vocab_size)

    trainS, trainQ, trainA = vectorize_data(train, word_idx, memory_size, sentence_size)
    testS, testQ, testA = vectorize_data(test, word_idx, memory_size, sentence_size)
    print('testS.shape', testS.shape)
    print('testQ.shape', testQ.shape)
    print('testA.shape', testA.shape)
    print('trainS.shape', trainS.shape)
    print('trainQ.shape', trainQ.shape)
    print('trainA.shape', trainA.shape)

    n_train = trainS.shape[0]
    n_test = testS.shape[0]
    print('n_train', n_train)
    print('n_test', n_test)

    train_labels = np.argmax(trainA, axis=1)
    test_labels = np.argmax(testA, axis=1)

    torch.cuda.manual_seed(random_state)
    batch_size = 16

    batches = zip(range(0, n_train - batch_size, batch_size), range(batch_size, n_train, batch_size))
    batches = [(start, end) for start, end in batches]

    model = MemN2N(embedding, batch_size, vocab_size, sentence_size, memory_size, 
                   embedding_size,hops=hops, max_grad_norm=max_grad_norm, dropout=0.1)
    model.to(device)
    criterion = nn.CrossEntropyLoss(reduction='sum')
    pre_list = []
    rec_list = []
    f1_list = []
    pre_ranking_list = []
    rec_ranking_list = []
    f1_ranking_list = []
    for t in range(1, epochs + 1):
        np.random.shuffle(batches)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        model.train()
        total_cost = 0.0
        for start, end in batches:
            loss = 0
            optimizer.zero_grad()
            s = trainS[start:end]
            q = trainQ[start:end]
            a = trainA[start:end]
            output = model(s, q)
            labels = torch.tensor(np.argmax(a, axis=1)).to(device)
            loss = criterion(output, labels)
            total_cost += loss.item()
            loss.backward()
            # Clip gradients: gradients are modified in place
            grad_norm = model.gradient_noise_and_clip(model.parameters(),
                    noise_stddev=1e-3, max_clip=max_grad_norm)
            # _ = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            # grads_and_vars = [(add_gradient_noise(g), v) for g, v in grads_and_vars] # Todo

            # Adjust model weights
            optimizer.step()

        if t % evaluation_interval == 0:
            train_preds = []
            model.eval()
            for start in range(0, n_train, batch_size):
                end = start + batch_size
                s = trainS[start:end]
                q = trainQ[start:end]
                out = model(s, q)
                pred = torch.argmax(out, 1)
                train_preds += pred.tolist()
            train_acc = metrics.accuracy_score(train_labels, np.array(train_preds))

            test_out = model(testS, testQ)
            test_preds = torch.argmax(test_out, 1)
            test_results = torch.softmax(test_out, 1)

            outputFile = open(os.path.join(data_path, 'prediction_{}_{}.csv'.format(time, t)), 'w')
            for e in range(len(test_results)):
                outputFile.write(str(test_results[e][1].item()) + '\n')
            outputFile.close()

            test_acc = metrics.accuracy_score(test_labels.tolist(), test_preds.tolist())
            test_pre = metrics.precision_score(test_labels.tolist(), test_preds.tolist())
            test_rec = metrics.recall_score(test_labels.tolist(), test_preds.tolist())
            test_f1 = metrics.f1_score(test_labels.tolist(), test_preds.tolist())
            pre_list.append(test_pre)
            rec_list.append(test_rec)
            f1_list.append(test_f1)

            print('****************************************')
            print('-----------------------')
            print('The time', time)
            print('Epoch', t)
            print('Total Cost:', total_cost)
            print('Training Accuracy:', train_acc)
            print('Testing Accuracy:', test_acc)
            print('Testing Precision:', test_pre)
            print('Testing Recall:', test_rec)
            print('Testing F1:', test_f1)
            print('-----------------------')

            outputFile = open(os.path.join(data_path, 'statistics_{}_{}.csv'.format(time, t)), 'w')
            outputFile.write('SUCC\ttest_labels\ttest_preds\ttest_results[0]\ttest_results[1]\n')
            labelcount = 0
            predcount = 0
            suc = 0
            for j in range(n_test):
                if test_labels[j].item() == 1:
                    labelcount += 1
                if test_preds[j].item() == 1:
                    predcount += 1
                if (test_labels[j].item() == 1) and (test_preds[j].item() == 1):
                    suc += 1
                    text = 'SUC\t'
                else:
                    text = '###\t'
                text += str(test_labels[j].item()) + '\t' + str(test_preds[j].item()) + '\t' + str(
                    test_results[j][0].item()) + '\t' + str(test_results[j][1].item()) + '\n'
                outputFile.write(text)
            outputFile.close()
            print('test_labels', labelcount)
            print('test_pred', predcount)
            print('suc', suc)
            print('****************************************')

            predictRight, predSet, sentAndClause, keywords, clauseTest = extractPredRight(time, t)
            pre_ranking, rec_ranking, f1_ranking = statistics(time, t, realRight, realSet, predictRight,
                                                              predSet, sentAndClause, keywords, clauseTest)
            pre_ranking_list.append(pre_ranking)
            rec_ranking_list.append(rec_ranking)
            f1_ranking_list.append(f1_ranking)
    maxScore, maxIndex = maxS(f1_ranking_list)
    print('#########################################')
    pre_all += pre_list[maxIndex]
    rec_all += rec_list[maxIndex]
    f1_all += f1_list[maxIndex]
    pre_all_ranking += pre_ranking_list[maxIndex]
    rec_all_ranking += rec_ranking_list[maxIndex]
    f1_all_ranking += maxScore
    print('The time', time)
    print('pre', pre_list[maxIndex])
    print('rec', rec_list[maxIndex])
    print('f1', f1_list[maxIndex])
    print('pre_ranking', pre_ranking_list[maxIndex])
    print('rec_ranking', rec_ranking_list[maxIndex])
    print('f1_ranking', maxScore)
    pre_all_list.append(pre_list[maxIndex])
    rec_all_list.append(rec_list[maxIndex])
    f1_all_list.append(f1_list[maxIndex])
    pre_all_list_ranking.append(pre_ranking_list[maxIndex])
    rec_all_list_ranking.append(rec_ranking_list[maxIndex])
    f1_all_list_ranking.append(maxScore)
    time_list.append(time)
    print('#########################################')
    time += 1

********************************
test_pos_count 216
test_neg_count 3099
train_pos_count 1951
train_neg_count 26030
********************************
testClause(all) 3315
realRight 216
realSet 216
读取预训练Embbeding
Embedding shape (19908, 20)
Embedding rate: 62.58%
embedding (19908, 20)
memory_size 40
sentence_size 3
vocab_size 19908
testS.shape (3315, 40, 3)
testQ.shape (3315, 3)
testA.shape (3315, 2)
trainS.shape (27981, 40, 3)
trainQ.shape (27981, 3)
trainA.shape (27981, 2)
n_train 27981
n_test 3315
****************************************
-----------------------
The time 1
Epoch 1
Total Cost: 8200.883011221886
Training Accuracy: 0.9307744540938494
Testing Accuracy: 0.9336349924585219
Testing Precision: 0.3
Testing Recall: 0.013888888888888888
Testing F1: 0.02654867256637168
-----------------------
test_labels 216
test_pred 10
suc 3
****************************************
sentAndClause 3315
clauseTest 3315
keywords 3315
score 3315
predict 211
predictRight 211
predSet 211
suc 85
precisio

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


****************************************
-----------------------
The time 2
Epoch 1
Total Cost: 8300.477170228958
Training Accuracy: 0.9306459066510255
Testing Accuracy: 0.9320695102685624
Testing Precision: 0.0
Testing Recall: 0.0
Testing F1: 0.0
-----------------------
test_labels 215
test_pred 0
suc 0
****************************************
sentAndClause 3165
clauseTest 3165
keywords 3165
score 3165
predict 211
predictRight 211
predSet 211
suc 67
precision 0.3175355450236967
recall 0.3116279069767442
f1 0.3145539906103287
****************************************
****************************************
-----------------------
The time 2
Epoch 2
Total Cost: 6237.408138215542
Training Accuracy: 0.9377555010486651
Testing Accuracy: 0.9358609794628752
Testing Precision: 0.575
Testing Recall: 0.21395348837209302
Testing F1: 0.311864406779661
-----------------------
test_labels 215
test_pred 80
suc 46
****************************************
sentAndClause 3165
clauseTest 3165
keywords 31

In [13]:
np.mean(f1_all_list), np.mean(f1_all_list_ranking)

(0.5681482175573875, 0.6579699780761963)

In [20]:
model.embedding.weight

Parameter containing:
tensor([[-0.0840, -0.1043, -0.0956,  ..., -0.1884,  0.0875, -0.0471],
        [-0.2592,  0.3215, -0.4803,  ..., -0.0737, -0.0183, -0.0924],
        [ 0.1671,  0.0251,  0.0963,  ..., -0.0674, -0.0128, -0.0150],
        ...,
        [-0.3293,  0.0378, -0.5015,  ...,  0.5334,  0.4352, -0.3587],
        [-0.1532,  0.0333, -0.4117,  ...,  0.2888,  0.3627, -0.2814],
        [ 0.0906, -0.1481,  0.0886,  ...,  0.0398,  0.0305, -0.0578]],
       device='cuda:3', requires_grad=True)

In [202]:
def sampling():
    numOfSentence = 2105
    testSentence = set()
    for i in range(1, numOfSentence + 1):
        if random.random() > 0.9:
            testSentence.add(i)
    rate = float(len(testSentence)) / numOfSentence
    #    print(rate)
    if rate > 0.1 and rate - 0.1 < 0.0003:
        print(len(testSentence))
        return testSentence
    else:
        return sampling()
testSentence = sampling()

211


In [217]:
time = 'test'
testSentence = set(map(int, [line.split(',')[0] for line in open('/data/wujipeng/ec/data/test/val_set.txt').readlines()]))

In [250]:
def division(testSentence, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testSentenceFile = open(os.path.join(data_path, 'testSentenceFile_{}.csv'.format(n)), 'w')
    for index, item in enumerate(testSentence):
        testSentenceFile.write(str(item) + '\n')
    testSentenceFile.close()

    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    outputFile1 = open(os.path.join(data_path, 'clause_test_{}.csv'.format(n)), 'w')
    outputFile2 = open(os.path.join(data_path, 'clause_train_{}.csv'.format(n)), 'w')
    test_pos_count, test_neg_count, train_pos_count, train_neg_count = 0, 0, 0, 0
    for _, line in enumerate(inputFile):
        sentenceID = int(line.strip().split(',')[0])
        sentence = line.strip().split(',')[-1]
        keyword = None
        keyPos = -1
        clauseList = re.split('，|。|？|！|；|……', sentence)
        causeOfSent = set()
        for index, item in enumerate(clauseList):
            match = re.search(r'\[f\]([^\[]*)\[/f\]', item)
            if match:
                keyword = match.group(1)
                keyPos = index
            match = re.search(r'\[\d[nv]\][^\[]*\[/\d[nv]\]', item)
            if match:
                causeOfSent.add(index)
            match = re.search(r'\[[-\d]*\*\d[nv]\][^\[]*\[/[-\d]*\*\d[nv]\]', item)
            if match:
                causeOfSent.add(index)
        for index, item in enumerate(clauseList):
            clause = re.sub(r'\[[^\[]*\]', '', item)
            clause = re.sub(r'“', '', clause)
            clause = re.sub(r'”', '', clause)
            clause = re.sub(r'：', '', clause)
            clause = re.sub(r'\(', '', clause)
            clause = re.sub(r'\)', '', clause)
            clause = re.sub(r'、', '', clause)
            clause = re.sub(r'’', '', clause)
            clause = re.sub(r'‘', '', clause)
            clause = re.sub(r'》', '', clause)
            clause = re.sub(r'《', '', clause)
            clause = re.sub(r'~', '', clause)
            words = ltp_cut(clause)
            pos = index - keyPos
            if (sentenceID in testSentence) and (clause.split()):
                if index in causeOfSent:
                    test_pos_count += 1
                    outputFile1.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile1.write(str(pos) + ',yes,' + ' '.join(words) + '\n')
                else:
                    test_neg_count += 1
                    outputFile1.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile1.write(str(pos) + ',no,' + ' '.join(words) + '\n')
            elif (sentenceID not in testSentence) and (clause.split()):
                if index in causeOfSent:
                    train_pos_count += 1
                    outputFile2.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile2.write(str(pos) + ',yes,' + ' '.join(words) + '\n')
                else:
                    train_neg_count += 1
                    outputFile2.write(str(sentenceID) + ',' + str(index + 1) + ',' + keyword + ',')
                    outputFile2.write(str(pos) + ',no,' + ' '.join(words) + '\n')
    print('********************************')
    print('test_pos_count', test_pos_count)
    print('test_neg_count', test_neg_count)
    print('train_pos_count', train_pos_count)
    print('train_neg_count', train_neg_count)
    print('********************************')
    inputFile.close()
    outputFile1.close()
    outputFile2.close()
division(testSentence, 'test')

********************************
test_pos_count 220
test_neg_count 3007
train_pos_count 1947
train_neg_count 26122
********************************


In [251]:
def statisticPos(mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    posDict = dict()
    for _, line in enumerate(inputFile):
        content = line.strip().split(',')[-1]
        clauseList = re.split('，|。|？|！|；|……', content)
        keyPos = -1
        for index, item in enumerate(clauseList):
            match = re.search(r'\[f\]([^\[]*)\[/f\]', item)
            if match:
                keyPos = index
        for index, item in enumerate(clauseList):
            pos = index - keyPos
            if posDict.get(pos):
                posDict[pos] += 1
            else:
                posDict[pos] = 1
    inputFile.close()
    return posDict
posDict = statisticPos()

In [252]:
def changePos(posDict):
    count = 0
    countDict = dict()
    for key, value in posDict.items():
        countDict[key] = count
        count += 1
    posList = ['AAAA', 'AAAB', 'AAAC', 'AAAD', 'AABA', 'AABB', 'AABC', 'AABD', 'AACA', 'AACB', 'AACC', 'AACD',
               'AADA', 'AADB', 'AADC', 'AADD', 'ABAA', 'ABAB', 'ABAC', 'ABAD', 'ABBA', 'ABBB', 'ABBC', 'ABBD',
               'ABCA', 'ABCB', 'ABCC', 'ABCD', 'ABDA', 'ABDB', 'ABDC', 'ABDD', 'ACAA', 'ACAB', 'ACAC', 'ACAD',
               'ACBA', 'ACBB', 'ACBC', 'ACBD', 'ACCA', 'ACCB', 'ACCC', 'ACCD', 'ACDA', 'ACDB', 'ACDC', 'ACDD',
               'ADAA', 'ADAB', 'ADAC', 'ADAD', 'ADBA', 'ADBB', 'ADBC', 'ADBD', 'ADCA', 'ADCB', 'ADCC', 'ADCD',
               'ADDA', 'ADDB', 'ADDC', 'ADDD', 'BAAA', 'BAAB', 'BAAC', 'BAAD', 'BABA', 'BABB', 'BABC', 'BABD',
               'BACA', 'BACB', 'BACC', 'BACD', 'BADA', 'BADB', 'BADC', 'BADD', 'BBAA', 'BBAB', 'BBAC', 'BBAD',
               'BBBA', 'BBBB', 'BBBC', 'BBBD', 'BBCA', 'BBCB', 'BBCC', 'BBCD', 'BCAA', 'BCAB', 'BCAC', 'BCAD',
               'BDAA', 'BDAB', 'BDAC', 'BDAD', 'BCBA', 'BCBB', 'BCBC']
    strPosDict = dict()
    for key, value in countDict.items():
        strPosDict[key] = posList[value]
    return strPosDict
strPosDict = changePos(posDict)

In [253]:
def extractAll(mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile = open(os.path.join(data_path, 'datacsv_2105.csv'), 'r')
    allClause = []
    for index, line in enumerate(inputFile):
        sent = int(line.strip().split(',')[0])  # int
        content = line.strip().split(',')[-1]
        clauseList = re.split('，|。|？|！|；|……', content)
        for index, clause in enumerate(clauseList):
            cause = 'n'
            match = re.search(r'\[\d[nv]\][^\[]*\[/\d[nv]\]', clause)
            if match:
                cause = 'c'
            match = re.search(r'\[[-\d]*\*\d[nv]\][^\[]*\[/[-\d]*\*\d[nv]\]', clause)
            if match:
                cause = 'c'
            clause = re.sub(r'\[[^\[]*\]', '', clause)
            clause = re.sub(r'“', '', clause)
            clause = re.sub(r'”', '', clause)
            clause = re.sub(r'：', '', clause)
            clause = re.sub(r'\(', '', clause)
            clause = re.sub(r'\)', '', clause)
            clause = re.sub(r'、', '', clause)
            clause = re.sub(r'’', '', clause)
            clause = re.sub(r'‘', '', clause)
            clause = re.sub(r'》', '', clause)
            clause = re.sub(r'《', '', clause)
            if clause.split():
                allClause.append([sent, index + 1, cause])
    print('allClause', len(allClause))
    inputFile.close()
    return allClause
allClause = extractAll()

allClause 31296


In [254]:
def construction(strPosDict, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    inputFile1 = open(os.path.join(data_path, 'clause_test_{}.csv'.format(n)), 'r')
    inputFile2 = open(os.path.join(data_path, 'clause_train_{}.csv'.format(n)), 'r')
    outputFile1 = open(os.path.join(data_path, 'emotion_cause_clause_level_test_{}.csv'.format(n)), 'w')
    outputFile2 = open(os.path.join(data_path, 'emotion_cause_clause_level_train_{}.csv'.format(n)), 'w')
    for _, line in enumerate(inputFile1):
        keyword = line.strip().split(',')[2]
        position = int(line.strip().split(',')[3])
        posStr = strPosDict[position]
        label = line.strip().split(',')[4]
        clause = line.strip().split(',')[5]
        wordList = clause.strip().split(' ')
        phraseList = []
        if len(wordList) >= 3:
            window = 3
            begin = 0
            for index, item in enumerate(wordList):
                end = begin + window
                if end <= len(wordList):
                    phraseList.append(wordList[begin: end])
                begin += 1
        else:
            phraseList.append(wordList[:])
        lineNum = 1
        for index, item in enumerate(phraseList):
            outputFile1.write(str(lineNum))
            for i in range(len(item)):
                outputFile1.write(' ' + item[i])
            outputFile1.write('\n')
            lineNum += 1
        outputFile1.write(str(lineNum) + ' ' + posStr + ' ' + posStr + ' ' + posStr + '\n')
        lineNum += 1
        outputFile1.write(str(lineNum) + ' ' + keyword + ' ' + keyword + ' ' + keyword + '\t' + label + '\n')
    inputFile1.close()
    outputFile1.close()
    for _, line in enumerate(inputFile2):
        keyword = line.strip().split(',')[2]
        position = int(line.strip().split(',')[3])
        posStr = strPosDict[position]
        label = line.strip().split(',')[4]
        clause = line.strip().split(',')[5]
        wordList = clause.strip().split(' ')
        phraseList = []
        if len(wordList) >= 3:
            window = 3
            begin = 0
            for index, item in enumerate(wordList):
                end = begin + window
                if end <= len(wordList):
                    phraseList.append(wordList[begin: end])
                begin += 1
        else:
            phraseList.append(wordList[:])
        lineNum = 1
        for index, item in enumerate(phraseList):
            outputFile2.write(str(lineNum))
            for i in range(len(item)):
                outputFile2.write(' ' + item[i])
            outputFile2.write('\n')
            lineNum += 1
        outputFile2.write(str(lineNum) + ' ' + posStr + ' ' + posStr + ' ' + posStr + '\n')
        lineNum += 1
        outputFile2.write(str(lineNum) + ' ' + keyword + ' ' + keyword + ' ' + keyword + '\t' + label + '\n')
    inputFile2.close()
    outputFile2.close()
construction(strPosDict, time)

In [258]:
def extractRealRight(allClause, n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testSentenceFile = open(os.path.join(data_path, 'testSentenceFile_{}.csv'.format(n)), 'r')
    testSentence = set()
    for index, item in enumerate(testSentenceFile):
        item = int(item.strip())
        testSentence.add(item)
    testSentenceFile.close()
    testClause = []
    for index, item in enumerate(allClause):
        sent = allClause[index][0]
        if sent in testSentence:
            testClause.append(item)
    print('testClause(all)', len(testClause))
    realRight = []
    realSet = set()
    for index, item in enumerate(testClause):
        cSent = item[0]
        cClause = item[1]
        cCause = item[2]
        if cCause == 'c':
            realRight.append([cSent, cClause])
            key = str(cSent) + ',' + str(cClause)
            realSet.add(key)
    print('realRight', len(realRight))
    print('realSet', len(realSet))
    return realRight, realSet
realRight, realSet = extractRealRight(allClause, time)

testClause(all) 3227
realRight 220
realSet 220


In [259]:
import numpy as np
def load_data(n, mode='train'):
    data_path = '/home/wujipeng/git/ECA/ec/MemNet/' + 'data' if mode == 'train' else 'test_data'
    testFile = open(os.path.join('/home/wujipeng/git/ECA/ec/MemNet/', data_path, 'emotion_cause_clause_level_test_{}.csv'.format(n)), 'r')
    trainFile = open(os.path.join('/home/wujipeng/git/ECA/ec/MemNet/', data_path, 'emotion_cause_clause_level_train_{}.csv'.format(n)), 'r')
    testData = []
    testStory = []
    for _, line in enumerate(testFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            testStory = []
            items = nLine[:]
            testStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in testStory if x]
            testData.append((subStory, q, a))
        else:
            items = nLine[:]
            testStory.append(items)
    testFile.close()
    trainData = []
    trainStory = []
    for _, line in enumerate(trainFile):
        nid = line.strip().split(' ')[0]
        nLine = line.strip().split(' ')[1:]
        if nid == '1':
            trainStory = []
            items = nLine[:]
            trainStory.append(items)
        if '\t' in nLine[-1]:
            q = []
            q.append(nLine[0])
            q.append(nLine[1])
            q.append(nLine[-1].split('\t')[0])
            a = [nLine[-1].split('\t')[1]]
            subStory = [x for x in trainStory if x]
            trainData.append((subStory, q, a))
        else:
            items = nLine[:]
            trainStory.append(items)
    trainFile.close()
    return testData, trainData
test, train = load_data(time)
data = test + train
vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q) for s, q, _ in data)))
word_idx = dict((c, k + 1) for k, c in enumerate(vocab))

In [236]:
def create_embedding(vocab, mode='train'):
    from gensim.models import KeyedVectors
    print('读取预训练Embbeding')
    word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/embedding/Wikipedia/sgns.wiki.word', binary=False)
    dim = word2vec.vector_size
    embedding = [np.zeros(dim)]  # pad
    cnt = 0
    for word in vocab:
        if word2vec.vocab.get(word):
            embedding.append(word2vec.get_vector(word))
            cnt += 1
        else:
            embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
    embedding = np.array(embedding)
    print('Embedding shape', embedding.shape)
    print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
    return np.array(embedding)
embedding = create_embedding(vocab)

读取预训练Embbeding
Embedding shape (23150, 300)
Embedding rate: 84.62%
embedding (23150, 300)


In [241]:
len(vocab)

23149

In [263]:
def create_embedding(vocab, mode='train'):
    from gensim.models import KeyedVectors
    print('读取预训练Embbeding')
    word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/ec/data/embedding/vec_new.txt', binary=False)
    dim = word2vec.vector_size
    embedding = [np.zeros(dim)]  # pad
    cnt = 0
    for word in vocab:
        if word2vec.vocab.get(word):
            embedding.append(word2vec.get_vector(word))
            cnt += 1
        else:
            embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
    embedding = np.array(embedding)
    print('Embedding shape', embedding.shape)
    print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
    return np.array(embedding)
embedding = create_embedding(vocab)
print('embedding', embedding.shape)

读取预训练Embbeding
Embedding shape (19908, 30)
Embedding rate: 95.80%
embedding (19908, 30)
