In [1]:
import numpy as np
import pandas as pd

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Prepare data

In [68]:
# read vocab
with open('/data/wujipeng/ec/data/raw_data/ltp_vocab.txt') as f:
    word_unk = f.readline().strip()
    vocab = ['<pad>', word_unk] + f.readline().strip().split(' ')
i2w = {i: w for i, w in enumerate(vocab)}
w2i = {i: w for w, i in enumerate(vocab)}

In [69]:
# load pos vocab
pos_label = [
    'AAAA', 'AAAB', 'AAAC', 'AAAD', 'AABA', 'AABB', 'AABC', 'AABD', 'AACA',
    'AACB', 'AACC', 'AACD', 'AADA', 'AADB', 'AADC', 'AADD', 'ABAA', 'ABAB',
    'ABAC', 'ABAD', 'ABBA', 'ABBB', 'ABBC', 'ABBD', 'ABCA', 'ABCB', 'ABCC',
    'ABCD', 'ABDA', 'ABDB', 'ABDC', 'ABDD', 'ACAA', 'ACAB', 'ACAC', 'ACAD',
    'ACBA', 'ACBB', 'ACBC', 'ACBD', 'ACCA', 'ACCB', 'ACCC', 'ACCD', 'ACDA',
    'ACDB', 'ACDC', 'ACDD', 'ADAA', 'ADAB', 'ADAC', 'ADAD', 'ADBA', 'ADBB',
    'ADBC', 'ADBD', 'ADCA', 'ADCB', 'ADCC', 'ADCD', 'ADDA', 'ADDB', 'ADDC',
    'ADDD', 'BAAA', 'BAAB', 'BAAC', 'BAAD', 'BABA', 'BABB', 'BABC', 'BABD',
    'BACA', 'BACB', 'BACC', 'BACD', 'BADA', 'BADB', 'BADC', 'BADD', 'BBAA',
    'BBAB', 'BBAC', 'BBAD', 'BBBA', 'BBBB', 'BBBC', 'BBBD', 'BBCA', 'BBCB',
    'BBCC', 'BBCD', 'BCAA', 'BCAB', 'BCAC', 'BCAD', 'BDAA', 'BDAB', 'BDAC',
    'BDAD', 'BCBA', 'BCBB', 'BCBC'
]
p2l = {i+1: w for i, w in enumerate(pos_label)}
l2p = {w: i+1 for i, w in enumerate(pos_label)}

In [70]:
vocab += pos_label

In [41]:
with open('/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', 'w') as f:
    f.write('<unk>\n')
    f.write(' '.join(vocab[2:]) + '\n')

In [71]:
LTP_DATA_DIR = '/home/wujipeng/data/ltp_data_v3.4.0/'  # ltp模型目录的路径
cws_model_path = os.path.join(LTP_DATA_DIR, 'cws.model')  # 分词模型路径，模型名称为`cws.model`
pos_model_path = os.path.join(LTP_DATA_DIR, 'pos.model')

from pyltp import Segmentor, Postagger

segmentor = Segmentor()  # 初始化实例
segmentor.load(cws_model_path)  # 加载模型
postagger = Postagger()
postagger.load(pos_model_path)

def ltp_seg(sentence):
    words, natures = [], []
    for sen in sentence.split('\x01'):
        seg_words = segmentor.segment(sen)
        seg_natures = postagger.postag(seg_words)
        words += list(seg_words) + ['\x01']
        natures += [nature[0] for nature in list(seg_natures)] + ['\x01']
    
    return words[:-1], natures[:-1]

In [86]:
def load_corpus(raw_corpus):
    #  Hanlp分词
    print('开始分词...')
    corpus = []
    natures = []
    with open('/data/wujipeng/ec/data/raw_data/ltp_corpus.txt', 'w') as fc:
        with open('/data/wujipeng/ec/data/raw_data/ltp_natures.txt', 'w') as fn:
            for line in raw_corpus:
                words, tags = ltp_seg(line)
                words = ' '.join(words)
                tags = ' '.join(tags)
                corpus.append(words)
                natures.append(tags)
                fc.write(words + '\n')
                fn.write(tags + '\n')
    print('完成分词')
    return corpus, natures

In [96]:
from collections import Counter
pos_label = [
    'AAAA', 'AAAB', 'AAAC', 'AAAD', 'AABA', 'AABB', 'AABC', 'AABD', 'AACA',
    'AACB', 'AACC', 'AACD', 'AADA', 'AADB', 'AADC', 'AADD', 'ABAA', 'ABAB',
    'ABAC', 'ABAD', 'ABBA', 'ABBB', 'ABBC', 'ABBD', 'ABCA', 'ABCB', 'ABCC',
    'ABCD', 'ABDA', 'ABDB', 'ABDC', 'ABDD', 'ACAA', 'ACAB', 'ACAC', 'ACAD',
    'ACBA', 'ACBB', 'ACBC', 'ACBD', 'ACCA', 'ACCB', 'ACCC', 'ACCD', 'ACDA',
    'ACDB', 'ACDC', 'ACDD', 'ADAA', 'ADAB', 'ADAC', 'ADAD', 'ADBA', 'ADBB',
    'ADBC', 'ADBD', 'ADCA', 'ADCB', 'ADCC', 'ADCD', 'ADDA', 'ADDB', 'ADDC',
    'ADDD', 'BAAA', 'BAAB', 'BAAC', 'BAAD', 'BABA', 'BABB', 'BABC', 'BABD',
    'BACA', 'BACB', 'BACC', 'BACD', 'BADA', 'BADB', 'BADC', 'BADD', 'BBAA',
    'BBAB', 'BBAC', 'BBAD', 'BBBA', 'BBBB', 'BBBC', 'BBBD', 'BBCA', 'BBCB',
    'BBCC', 'BBCD', 'BCAA', 'BCAB', 'BCAC', 'BCAD', 'BDAA', 'BDAB', 'BDAC',
    'BDAD', 'BCBA', 'BCBB', 'BCBC'
]

def build_vocab(corpus):
    # build vocab
    vocab_cnt = Counter()
    for line in corpus:
        for word in line.replace('\x01', '').split(' '):
            if word.strip():
                vocab_cnt[word] += 1
#     vocab = [word for word, freq in vocab_cnt.most_common()]
    vocab = [word for word, freq in vocab_cnt.items()]
    vocab += pos_label
    with open('/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', 'w') as f:
        f.write('<unk>\n')
        f.write(' '.join(vocab))
    print('词典大小: ', len(vocab) + 2)
    print('保存词典')
    return vocab

In [74]:
def save_data(data):
    data['id'] = list(range(1, len(data) + 1))
    data[['id', 'clause', 'nature', 'keyword', 'emotion', 'clause_pos', 'label']].to_csv(
        '/data/wujipeng/ec/data/han/memnet_ltp_processed_data.csv', index=False)
    print('保存处理后数据')

In [97]:
data_root = '/data/wujipeng/ec/data/'
data = pd.read_csv(os.path.join(data_root, 'raw_data', 'process_data_3.csv'), index_col=0)
corpus = [text.replace(' ', '') for text in data['clause'].tolist()]
keyword = data['keyword'].tolist()
poses = [list(map(int, pos.split(' '))) for pos in data['clause_pos'].tolist()]
min_pos = min(min(poses))
poses = [' '.join([str(p - min_pos + 1) for p in pos]) for pos in poses]

# corpus = load_corpus(corpus)
corpus, natures = load_corpus(corpus)
vocab = build_vocab(corpus + keyword)

data['clause'] = corpus
data['nature'] = natures
data['keyword'] = keyword
data['clause_pos'] = poses
save_data(data)

开始分词...
完成分词
词典大小:  19909
保存词典
保存处理后数据


In [136]:
# read vocab
with open('/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt') as f:
    word_unk = f.readline().strip()
    vocab = ['<pad>', word_unk] + f.readline().strip().split(' ')
i2w = {i: w for i, w in enumerate(vocab)}
w2i = {i: w for w, i in enumerate(vocab)}

In [9]:
import os
def load_data(path):
    if not os.path.exists(path):
        raise FileNotFoundError
    data = []
    with open(path, 'r') as f:
        for line in f.readlines():
            data.append(line.strip().split(','))
    return data

In [108]:
# read data
raw_data = load_data('/data/wujipeng/ec/data/ltp_test/train_set.txt')

In [123]:
data = []
for item in raw_data:
    n_id, n_clauses, n_natures, n_keyword, n_emotion, n_pos, n_label = item
    n_id = int(n_id)
    n_clauses = [clause.strip().split(' ') for clause in n_clauses.split('\x01')]
    n_natures = [nature.strip().split(' ') for nature in n_natures.split('\x01')]
    n_keyword = n_keyword.replace(' ', '')
    n_pos = list(map(int, n_pos.strip().split(' ')))
    n_label = list(map(int, n_label.strip().split(' ')))
    for cid, (clause, nature, pos, label) in enumerate(zip(n_clauses, n_natures, n_pos, n_label)):
        data.append(tuple([n_id, cid+1, clause, nature, n_keyword, n_emotion, pos, label]))

In [124]:
data[0]

(1,
 1,
 ['当', '我', '看到', '建议', '被', '采纳'],
 ['p', 'r', 'v', 'n', 'p', 'v'],
 '激动',
 '0',
 63,
 0)

In [125]:
# get item
batched_data = []
for item in data:
    batched_data.append(item)

In [134]:
def word2idx(words, batched=False):
    if not batched:
        indices = [w2i[w] if w in w2i else w2i[word_unk] for w in words]
    else:
        indices = [[w2i[w] if w in w2i else w2i[word_unk] for w in item] for item in words]
    return indices

def idx2word(indices, batched=False):
    if not batched:
        words = [i2w[i] for i in indices]
    else:
        words = [[i2w[i] for i in item] for item in indices]
    return words

def pos2label(poses, batched=False):
    if not batched:
        indices = [p2l[p] for p in poses]
    else:
        indices = [[p2l[p] for p in pos] for pos in poses]
    return indices

In [195]:
memory_size = 40
sequence_size = 3
batch_size = 16
batch_data = list(zip(*batched_data[:batch_size-6]))
ids, cids, clauses, natures, keywords, emotions, poses, labels = batch_data
ids = list(ids)
cids = list(cids)
clauses = word2idx(clauses, batched=True)
keywords = word2idx(keywords, batched=False)
emotions = list(map(int, emotions))
poses = word2idx(pos2label(poses, batched=False))
labels = list(labels)

In [164]:
sentence_length

40

In [96]:
def pad_memory(sequences, poses, memory_size, sequence_size, pad=0):
    paded_sequences = []
    for sequence, pos in zip(sequences, poses):
        if len(sequence) < sequence_size:
            sequence += [pad] * (sequence_size - len(sequence))
        paded_sequence = [sequence[:sequence_size]]
        for i in range(memory_size-1):
            if i < len(sequence) - sequence_size + 1:
                paded_sequence += [sequence[i:i + 3]]
            elif i == len(sequence) - sequence_size + 1:
                paded_sequence += [[pos] * 3]
            else:
                paded_sequence += [[pad] * 3]
        paded_sequences.append(np.array(paded_sequence))
    return paded_sequences

In [196]:
clauses = pad_memory(clauses, poses, memory_size, sequence_size, pad=0)
keywords = [[keyword] * sequence_size for keyword in keywords]

In [197]:
if len(ids) < batch_size:
    bs = batch_size - len(ids)
    ids += [0] * bs
    cids += [0] * bs
    clauses += [[[0] * sequence_size] * memory_size] * bs
    keywords += [[0] * sequence_size] * bs
    emotions += [0] * bs
    labels += [-100] * bs

# Build embedding

In [15]:
with open('/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', 'r') as f:
    vocab = ['<pad>', f.readline().strip()]
    vocab = f.readline().strip().split(' ')

In [18]:
len(sorted(vocab)) + 2

19909

In [19]:
import os
import pickle
from gensim.models import KeyedVectors
data_root = '/data/wujipeng/ec/data/'
print('读取预训练Embbeding')
word2vec = KeyedVectors.load_word2vec_format('/data/wujipeng/ec/data/embedding/seg_resource.bin', binary=False)
dim = word2vec.vector_size
embedding = [np.zeros(dim), np.random.normal(loc=0., scale=0.1, size=dim)]  # pad unk
cnt = 0
for word in vocab:
    if word2vec.vocab.get(word):
        embedding.append(word2vec.get_vector(word))
        cnt += 1
    else:
        embedding.append(np.random.normal(loc=0., scale=0.1, size=dim))
embedding = np.array(embedding)
print('Embedding shape', embedding.shape)
print('Embedding rate: {:.2f}%'.format(cnt / len(vocab) * 100))
pickle.dump(embedding,
            open(os.path.join(data_root, 'embedding', '{}_embedding{}d.pkl'.format('memnet', dim)), 'wb'))

读取预训练Embbeding
Embedding shape (19909, 20)
Embedding rate: 62.58%


# Dataset

In [2]:
import sys
sys.path = ['..'] + sys.path
import numpy as np
import torch
from torch.utils.data import Dataset
from utils.data.process import load_data, pad_sequence, pad_memory
import os

In [None]:
class MECDataset(Dataset):
    def __init__(self, data_root, vocab_root, batch_size=16, train=True):
        super(MECDataset, self).__init__()
        self.train = train
        self.data_path = os.path.join(data_root, '{}_set.txt'.format('train' if train else 'val'))
        self.vocab_root = vocab_root
        self.batch_size = batch_size
        self.data = []
        self.read_data()
        self.read_vocab()
        self.read_pos()

    def read_data(self):
        data = load_data(self.data_path)
        for item in data:
            n_id, n_clauses, n_natures, n_keyword, n_emotion, n_pos, n_label = item
            n_id = int(n_id)
            n_clauses = [clause.strip().split(' ') for clause in n_clauses.split('\x01')]
            n_natures = [nature.strip().split(' ') for nature in n_natures.split('\x01')]
            n_keyword = n_keyword.replace(' ', '')
            n_pos = list(map(int, n_pos.strip().split(' ')))
            n_label = list(map(int, n_label.strip().split(' ')))
            for cid, (clause, nature, pos, label) in enumerate(zip(n_clauses, n_natures, n_pos, n_label)):
                self.data.append(tuple([n_id, cid + 1, clause, nature, n_keyword, n_emotion, pos, label]))

    def read_vocab(self):
        if os.path.isdir(self.vocab_root):
            self.vocab_root = os.path.join(self.vocab_root, 'vocab.txt')
        with open(self.vocab_root, 'r') as f:
            self.word_unk = f.readline().strip()
            self.vocab = ['<pad>', self.word_unk] + f.readline().strip().split(' ')
        self.i2w = {i: w for i, w in enumerate(self.vocab)}
        self.w2i = {w: i for i, w in enumerate(self.vocab)}

    def read_pos(self):
        pos_label = [
            'AAAA', 'AAAB', 'AAAC', 'AAAD', 'AABA', 'AABB', 'AABC', 'AABD', 'AACA',
            'AACB', 'AACC', 'AACD', 'AADA', 'AADB', 'AADC', 'AADD', 'ABAA', 'ABAB',
            'ABAC', 'ABAD', 'ABBA', 'ABBB', 'ABBC', 'ABBD', 'ABCA', 'ABCB', 'ABCC',
            'ABCD', 'ABDA', 'ABDB', 'ABDC', 'ABDD', 'ACAA', 'ACAB', 'ACAC', 'ACAD',
            'ACBA', 'ACBB', 'ACBC', 'ACBD', 'ACCA', 'ACCB', 'ACCC', 'ACCD', 'ACDA',
            'ACDB', 'ACDC', 'ACDD', 'ADAA', 'ADAB', 'ADAC', 'ADAD', 'ADBA', 'ADBB',
            'ADBC', 'ADBD', 'ADCA', 'ADCB', 'ADCC', 'ADCD', 'ADDA', 'ADDB', 'ADDC',
            'ADDD', 'BAAA', 'BAAB', 'BAAC', 'BAAD', 'BABA', 'BABB', 'BABC', 'BABD',
            'BACA', 'BACB', 'BACC', 'BACD', 'BADA', 'BADB', 'BADC', 'BADD', 'BBAA',
            'BBAB', 'BBAC', 'BBAD', 'BBBA', 'BBBB', 'BBBC', 'BBBD', 'BBCA', 'BBCB',
            'BBCC', 'BBCD', 'BCAA', 'BCAB', 'BCAC', 'BCAD', 'BDAA', 'BDAB', 'BDAC',
            'BDAD', 'BCBA', 'BCBB', 'BCBC'
        ]
        self.p2l = {i + 1: w for i, w in enumerate(pos_label)}
        self.l2p = {w: i + 1 for i, w in enumerate(pos_label)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def word2idx(self, words, batched=False):
        if not batched:
            indices = [self.w2i[w] if w in self.w2i else self.w2i[self.word_unk] for w in words]
        else:
            indices = [[self.w2i[w] if w in self.w2i else self.w2i[self.word_unk] for w in item] for item in words]
        return indices

    def idx2word(self, indices, batched=False):
        if not batched:
            words = [self.i2w[i] for i in indices]
        else:
            words = [[self.i2w[i] for i in item] for item in indices]
        return words

    def pos2label(self, poses, batched=False):
        if not batched:
            indices = [self.p2l[p] for p in poses]
        else:
            indices = [[self.p2l[p] for p in pos] for pos in poses]
        return indices

    def collate_fn(self, batch_data, pad=True, memory_size=41, sequence_size=3, batch_size=16):
        batch_data = list(zip(*batch_data))
        ids, cids, clauses, natures, keywords, emotions, poses, labels = batch_data
        ids = list(ids)
        cids = list(cids)
        clauses = self.word2idx(clauses, batched=True)
        keywords = self.word2idx(keywords, batched=False)
        emotions = list(map(int, emotions))
        poses = self.word2idx(self.pos2label(poses, batched=False))
        labels = list(labels)
        if pad:
            clauses = pad_memory(clauses, poses, memory_size, sequence_size, pad=0)
            keywords = [[keyword] * sequence_size for keyword in keywords]
            if len(ids) < batch_size:
                bs = batch_size - len(ids)
                ids += [0] * bs
                cids += [0] * bs
                clauses += [[[0] * sequence_size] * memory_size] * bs
                keywords += [[0] * sequence_size] * bs
                emotions += [0] * bs
                labels += [-100] * bs

        return ids, cids, np.array(clauses), natures, np.array(keywords), np.array(emotions), poses, np.array(
            labels)

    @staticmethod
    def batch2input(batch):
        return batch[2], batch[4]

    @staticmethod
    def batch2target(batch):
        return batch[-1]

In [3]:
from utils.dataset.memnet import MECDataset

In [4]:
from utils.dataloader.memnet import MECDataLoader

In [5]:
train_dataset = MECDataset(data_root='/data/wujipeng/ec/data/ltp_static/static.1/', vocab_root='/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', batch_size=16, train=True)
eval_dataset = MECDataset(data_root='/data/wujipeng/ec/data/ltp_static/static.1/', vocab_root='/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', batch_size=16, train=False)

In [8]:
dataset = train_dataset + eval_dataset

In [9]:
sentences = list(zip(*dataset))[2] + list(zip(*dataset))[3]

In [10]:
memory_size = max([len(sentence) for sentence in sentences])
memory_size

40

# Dataloader

In [11]:
import random
class ECDataLoader:
    def __init__(self,
                 dataset,
                 memory_size,
                 sequence_size,
                 batch_size=16,
                 shuffle=True,
                 auto_refresh=True,
                 collate_fn=None):
        self.dataset = dataset
        self.size = len(dataset)
        self.memory_size = memory_size
        self.sequence_size = sequence_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches = None
        self.num_batches = None
        self.auto_refresh = auto_refresh
        self.collate_fn = collate_fn
        self.inst_count = 0
        self.batch_count = 0
        self._curr_batch = 0
        self._curr_num_insts = None
        self.dataset = list(dataset)
        if self.auto_refresh:
            self.refresh()

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self.num_batches

    def next(self):
        if self._curr_batch and self._curr_batch + 1 >= self.num_batches:
            if self.auto_refresh:
                self.refresh()
            raise StopIteration
        data = self.get_data()
        return data

    def get_data(self):
        self._curr_batch = (
            self._curr_batch + 1) if self._curr_batch is not None else 0
        self._curr_num_insts = len(self.batches[self._curr_batch])

        self.inst_count += self._curr_num_insts
        self.batch_count += 1
        data = self.batches[self._curr_batch]
        if self.collate_fn:
            data = self.collate_fn(
                data,
                memory_size=self.memory_size,
                sequence_size=self.sequence_size,
                batch_size=self.batch_size)
        return data

    def refresh(self):
        self.batches = []
        batch_start = 0
        for i in range(self.size // self.batch_size):
            self.batches.append(
                self.dataset[batch_start:batch_start + self.batch_size])
            batch_start += self.batch_size
        if batch_start != self.size:
            self.batches.append(self.dataset[batch_start:])
        if self.shuffle:
            random.shuffle(self.batches)
        self.num_batches = len(self.batches)
        self._curr_batch = None
        self._curr_num_insts = None

    def state_dict(self):
        """
        Warning! side effect: np_randomstate will influence other
            potion of the program.
        """
        state = {
            "batch_size": self.batch_size,
            "shuffle": self.shuffle,
            "batches": self.batches,
            "num_batches": self.num_batches,
            "auto_refresh": self.auto_refresh,
            "inst_count": self.inst_count,
            "batch_count": self.batch_count,
            "_curr_batch": self._curr_batch,
            "_curr_num_insts": self._curr_num_insts,
            "np_randomstate": np.random.get_state(),
        }
        return state

    def load_state_dict(self, state):
        """
        Warning! side effect: np_randomstate will influence other
            potion of the program.
        """
        self.batch_size = state["batch_size"]
        self.shuffle = state["shuffle"]
        self.batches = state["batches"]
        self.num_batches = state["num_batches"]
        self.auto_refresh = state["auto_refresh"]
        self.inst_count = state["inst_count"]
        self.batch_count = state["batch_count"]
        self._curr_batch = state["_curr_batch"]
        self._curr_num_insts = state["_curr_num_insts"]
        np.random.set_state(state["np_randomstate"])

In [12]:
train_loader = ECDataLoader(dataset=train_dataset, memory_size=memory_size, sequence_size=3, batch_size=16, shuffle=True, collate_fn=train_dataset.collate_fn)

In [15]:
len(train_loader)

1755

In [16]:
for batch in train_loader:
    break

In [20]:
stories, queries = train_dataset.batch2input(batch)
stories, queries = torch.from_numpy(stories), torch.from_numpy(queries)

In [21]:
stories.size()

torch.Size([16, 40, 3])

In [22]:
queries.size()

torch.Size([16, 3])

# Model

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
import pickle

In [15]:
device = torch.device('cuda: 3')

In [16]:
train_loader = MECDataLoader(dataset=train_dataset, memory_size=memory_size, sequence_size=3, batch_size=16, shuffle=True, collate_fn=train_dataset.collate_fn)
eval_loader = MECDataLoader(dataset=eval_dataset, memory_size=memory_size, sequence_size=3, batch_size=16, shuffle=False, collate_fn=train_dataset.collate_fn)

## init

In [8]:
batch_size = 16
vocab_size = 19909
sentence_size = 3
memory_size = 40
embedding_dim = 20
num_classes = 2
hops = 3
dropout = 0.1

In [12]:
def position_encoding(sentence_size, embedding_size):
    encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_size + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
    encoding = 1 + 4 * encoding / embedding_size / sentence_size
    return np.transpose(encoding)

In [13]:
embeddings = pickle.load(open('/data/wujipeng/ec/data/embedding/memnet_embedding20d.pkl', 'rb')).astype(np.float32)
embeddings.shape

(19909, 20)

In [228]:
Embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0).from_pretrained(torch.from_numpy(embeddings), freeze=False).to(device)

In [36]:
encoding = torch.from_numpy(position_encoding(1, 3 * embedding_dim)).to(device)

In [37]:
dp = nn.Dropout(dropout).to(device)

In [38]:
fc = nn.Linear(3 * embedding_dim, num_classes).to(device)

## forward

In [34]:
stories, queries = stories.to(device), queries.to(device)
stories.size(), queries.size()

(torch.Size([16, 40, 3]), torch.Size([16, 3]))

In [39]:
q_emb0 = Embedding(queries)
q_emb = q_emb0.view(-1, 1, 3*embedding_dim)
u_0 = torch.sum(q_emb * encoding, 1)
u = [u_0]
q_emb0.size(), q_emb.size(), u_0.size()

(torch.Size([16, 3, 20]), torch.Size([16, 1, 60]), torch.Size([16, 60]))

### hops

In [41]:
m_emb0 = Embedding(stories)
m_emb = m_emb0.view(-1, memory_size, 1, 3 * embedding_dim)
m = torch.sum(m_emb * encoding, -2)
m_emb0.size(), m_emb.size(), m.size()

(torch.Size([16, 40, 3, 20]),
 torch.Size([16, 40, 1, 60]),
 torch.Size([16, 40, 60]))

In [42]:
u_temp = u[-1].unsqueeze(-1).transpose(-2, -1)
dotted = torch.sum(m * u_temp, -1)
probs = F.softmax(dotted, -1)
probs_temp = probs.unsqueeze(-1).transpose(-2, -1)
u_temp.size(), dotted.size(), probs.size(), probs_temp.size()

(torch.Size([16, 1, 60]),
 torch.Size([16, 40]),
 torch.Size([16, 40]),
 torch.Size([16, 1, 40]))

In [43]:
c_emb0 = Embedding(stories)
c_emb = c_emb0.view(-1, memory_size, 1, 3 * embedding_dim)
c_temp = torch.sum(c_emb * encoding, -2)
c = c_temp.transpose(-2, -1)
c_emb0.size(), c_emb.size(), c_temp.size(), c.size()

(torch.Size([16, 40, 3, 20]),
 torch.Size([16, 40, 1, 60]),
 torch.Size([16, 40, 60]),
 torch.Size([16, 60, 40]))

In [44]:
o_k = torch.sum(c * probs_temp, -1)
u_k = u[-1] + o_k
u.append(u_k)
o_k.size(), u_k.size()

(torch.Size([16, 60]), torch.Size([16, 60]))

### loop

In [71]:
for i in range(hops):
    m_emb0 = Embedding(stories)
    m_emb = m_emb0.view(batch_size, -1, memory_size, 1, 3 * embedding_dim)
    m = torch.sum(m_emb * encoding, -2)
    
    u_temp = u[-1].unsqueeze(-1).transpose(-2, -1)
    dotted = torch.sum(m * u_temp, -1)
    probs = F.softmax(dotted, -1)
    probs_temp = probs.unsqueeze(-1).transpose(-2, -1)
    
    c_emb0 = Embedding(stories)
    c_emb = c_emb0.view(batch_size, -1, memory_size, 1, 3 * embedding_dim)
    c_temp = torch.sum(c_emb * encoding, -2)
    c = c_temp.transpose(-2, -1)
    o_k = torch.sum(c * probs_temp, -1)
    u_k = u[-1] + o_k
    u.append(u_k)

In [75]:
outputs = fc(dp(u_k))
outputs.size()

torch.Size([16, 19, 2])

# Memnet

In [17]:
def position_encoding(sentence_size, embedding_size):
    encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_size + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
    encoding = 1 + 4 * encoding / embedding_size / sentence_size
    return np.transpose(encoding)


class MemN2N(nn.Module):
    def __init__(self,
                 batch_size,
                 vocab_size,
                 embedding_dim,
                 sentence_size,
                 memory_size,
                 hops,
                 num_classes,
                 dropout=0.5,
                 fix_embed=True,
                 name='MemN2N'):
        super(MemN2N, self).__init__()
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        self.sentence_size = sentence_size
        self.memory_size = memory_size
        self.hops = hops
        self.num_classes = num_classes
        self.fix_embed = fix_embed
        self.name = name

        self.encoding = torch.from_numpy(position_encoding(1, sentence_size * embedding_dim))
        self.Embedding = nn.Embedding(vocab_size, embedding_dim)

        self.fc = nn.Linear(3 * embedding_dim, num_classes)
        self.dropout = nn.Dropout(dropout)
        # self.fc = nn.Sequential(
        #     nn.Linear(2 * self.sentence_rnn_size, linear_hidden_dim),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(dropout),
        #     nn.Linear(linear_hidden_dim, num_classes)
        # )

    def init_weights(self, embeddings):
        if embeddings is not None:
            self.Embedding.weight.data.copy_(embeddings)
#             self.Embedding = self.Embedding.from_pretrained(
#                 embeddings, freeze=self.fix_embed)

    def set_device(self, device):
        self.encoding = self.encoding.to(device)

    def forward(self, stories, queries):
        q_emb0 = self.Embedding(queries)
        q_emb = q_emb0.view(-1, 1, 3 * self.embedding_dim)
        u_0 = torch.sum(q_emb * self.encoding, 1)
        u = [u_0]

        for i in range(self.hops):
            m_emb0 = self.Embedding(stories)
            m_emb = m_emb0.view(-1, self.memory_size, 1, 3 * self.embedding_dim)
            m = torch.sum(m_emb * self.encoding, -2)

            u_temp = u[-1].unsqueeze(-1).transpose(-2, -1)
            dotted = torch.sum(m * u_temp, -1)
            probs = F.softmax(dotted, -1)
            probs_temp = probs.unsqueeze(-1).transpose(-2, -1)

            c_emb0 = self.Embedding(stories)
            c_emb = c_emb0.view(-1, self.memory_size, 1, 3 * self.embedding_dim)
            c_temp = torch.sum(c_emb * self.encoding, -2)
            c = c_temp.transpose(-2, -1)

            o_k = torch.sum(c * probs_temp, -1)
            u_k = u[-1] + o_k
            u.append(u_k)

        outputs = self.dropout(self.fc(u_k))
        return outputs
    
    def gradient_noise_and_clip(self, parameters, device,
                                 noise_stddev=1e-3, max_clip=40.0):
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        norm = nn.utils.clip_grad_norm_(parameters, max_clip)

        for p in parameters:
            noise = torch.randn(p.size()) * noise_stddev
            noise = noise.to(device)
            p.grad.data.add_(noise)
        return norm

# Train

In [28]:
import sys
sys.path = ['..'] + sys.path
import numpy as np
import torch
from utils.dataset.memnet import MECDataset
from utils.dataloader.memnet import MECDataLoader
from utils.data.process import load_data, pad_sequence, pad_memory
import os

In [29]:
batch_size = 16
vocab_size = 19909
sentence_size = 3
memory_size = 40
embedding_dim = 20
num_classes = 2
hops = 3
dropout = 0.1

In [30]:
device = torch.device('cuda: 3')

In [31]:
model = MemN2N(batch_size, vocab_size, embedding_dim, sentence_size, memory_size, hops, num_classes, dropout=0.1)

In [32]:
criterion = nn.CrossEntropyLoss(reduction='sum').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

In [33]:
model = model.to(device)
model.set_device(device)

In [34]:
embeddings = pickle.load(open('/data/wujipeng/ec/data/embedding/memnet_embedding20d.pkl', 'rb')).astype(np.float32)
model.init_weights(torch.from_numpy(embeddings).to(device))

In [35]:
train_dataset = MECDataset(data_root='/data/wujipeng/ec/data/ltp_static/static.2/', vocab_root='/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', batch_size=16, train=True)
eval_dataset = MECDataset(data_root='/data/wujipeng/ec/data/ltp_static/static.2/', vocab_root='/data/wujipeng/ec/data/raw_data/memnet_ltp_vocab.txt', batch_size=16, train=False)

In [36]:
train_loader = MECDataLoader(dataset=train_dataset, memory_size=memory_size, sequence_size=3, batch_size=16, shuffle=True, collate_fn=train_dataset.collate_fn)
eval_loader = MECDataLoader(dataset=eval_dataset, memory_size=memory_size, sequence_size=3, batch_size=16, shuffle=False, collate_fn=train_dataset.collate_fn)

In [37]:
model.train()
for epoch in range(25):
    losses = 0.
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        inputs = eval_dataset.batch2input(batch)

        clauses, keywords = inputs
        clauses = torch.from_numpy(clauses).to(device)
        keywords = torch.from_numpy(keywords).to(device)

        labels = eval_dataset.batch2target(batch)
        labels = torch.from_numpy(labels).to(device)

        outputs = model(clauses,keywords)
        probs = outputs.view(-1, outputs.size(-1))
        loss = criterion(probs, labels)
        losses += loss.item()
        loss.backward()

        grad_norm = model.gradient_noise_and_clip(model.parameters(), device, noise_stddev=1e-3, max_clip=40.0)

        optimizer.step()

        predicts = probs.max(dim=-1)[1]
        total, acc, pre, rec, f1, auc = metrics(predicts.tolist(), labels.tolist(), probs[:, 1].tolist())
        
        if i % 100 == 0:
            print('E: [{:3}/{}][{:3}/{}] L: {:.4f} A: {:.4f} '
                                     'P: {:.4f} R: {:.4f} F: {:.4f} N: {:.4f}'.format(epoch, 10, i, len(train_loader),
                                                                                      losses/(i+1), auc, pre, rec, f1,
                                                                                      grad_norm))
    model.eval()
    losses = 0.
    all_probs = []
    all_preds = []
    all_targets = []
    for batch in eval_loader:
        inputs = eval_dataset.batch2input(batch)

        clauses, keywords = inputs
        clauses = torch.from_numpy(clauses).to(device)
        keywords = torch.from_numpy(keywords).to(device)

        labels = eval_dataset.batch2target(batch)
        labels = torch.from_numpy(labels).to(device)


        outputs = model(clauses,keywords)
        probs = outputs.view(-1, outputs.size(-1))
        predicts = probs.max(dim=-1)[1]

        loss = criterion(probs.view(-1, probs.size(-1)), labels.view(-1))

        losses += loss

        probs = F.softmax(probs, dim=-1)

        all_probs += probs.tolist()
        all_preds += predicts.tolist()
        all_targets += labels.tolist()

        pos_probs = np.array(all_probs)[:, 1]

    total, acc, pre, rec, f1, auc = metrics(all_preds, all_targets, pos_probs)
    print('L: {:.4f} A: {:.4f} P: {:.4f} R: {:.4f} F: {:.4f}'.format(losses, auc, pre, rec, f1))

E: [  0/10][  0/1759] L: 9.1852 A: 0.8929 P: 0.0000 R: 0.0000 F: 0.0000 N: 23.1952
E: [  0/10][100/1759] L: 6.8052 A: 0.0000 P: 0.0000 R: 0.0000 F: 0.0000 N: 19.5181
E: [  0/10][200/1759] L: 6.0974 A: 0.6000 P: 0.0000 R: 0.0000 F: 0.0000 N: 5.3694
E: [  0/10][300/1759] L: 5.7711 A: 0.9667 P: 0.0000 R: 0.0000 F: 0.0000 N: 6.5366
E: [  0/10][400/1759] L: 5.4885 A: 0.0000 P: 0.0000 R: 0.0000 F: 0.0000 N: 12.4790
E: [  0/10][500/1759] L: 5.3106 A: 0.9667 P: 0.0000 R: 0.0000 F: 0.0000 N: 7.4174
E: [  0/10][600/1759] L: 5.1997 A: 0.8667 P: 0.0000 R: 0.0000 F: 0.0000 N: 2.8252
E: [  0/10][700/1759] L: 5.0862 A: 1.0000 P: 0.0000 R: 0.0000 F: 0.0000 N: 2.9057
E: [  0/10][800/1759] L: 4.9731 A: 0.7333 P: 0.0000 R: 0.0000 F: 0.0000 N: 2.9998
E: [  0/10][900/1759] L: 4.8507 A: 0.6667 P: 0.0000 R: 0.0000 F: 0.0000 N: 1.8835
E: [  0/10][1000/1759] L: 4.7985 A: 0.0000 P: 0.0000 R: 0.0000 F: 0.0000 N: 6.1600
E: [  0/10][1100/1759] L: 4.7486 A: 0.9333 P: 0.0000 R: 0.0000 F: 0.0000 N: 5.6694
E: [  0/10]

In [25]:
model.Embedding.weight

Parameter containing:
tensor([[ 0.0201, -0.0440, -0.0023,  ...,  0.0272,  0.0265, -0.0060],
        [ 0.0904, -0.0236, -0.1043,  ..., -0.1321, -0.0570,  0.2497],
        [-0.0183,  0.1107, -0.0237,  ..., -0.1831, -0.0702,  0.0026],
        ...,
        [ 0.0278,  0.0905,  0.0961,  ...,  0.0057, -0.1061, -0.1791],
        [-0.0235, -0.0097,  0.0345,  ...,  0.0516, -0.0535, -0.0009],
        [ 0.1544,  0.0738, -0.0548,  ..., -0.0089,  0.0208,  0.0780]],
       device='cuda:3', requires_grad=True)

In [29]:
model.encoding

tensor([[0.0333, 0.0667, 0.1000, 0.1333, 0.1667, 0.2000, 0.2333, 0.2667, 0.3000,
         0.3333, 0.3667, 0.4000, 0.4333, 0.4667, 0.5000, 0.5333, 0.5667, 0.6000,
         0.6333, 0.6667, 0.7000, 0.7333, 0.7667, 0.8000, 0.8333, 0.8667, 0.9000,
         0.9333, 0.9667, 1.0000, 1.0333, 1.0667, 1.1000, 1.1333, 1.1667, 1.2000,
         1.2333, 1.2667, 1.3000, 1.3333, 1.3667, 1.4000, 1.4333, 1.4667, 1.5000,
         1.5333, 1.5667, 1.6000, 1.6333, 1.6667, 1.7000, 1.7333, 1.7667, 1.8000,
         1.8333, 1.8667, 1.9000, 1.9333, 1.9667, 2.0000]], device='cuda:3')

# Metrics

In [24]:
from sklearn.metrics import roc_auc_score

def metrics(pred_labels, true_labels, probs, ignore_index=-100):
    """
        Args:
            pred_labels: (bat, n(s)), 0-3
            true_labels: (bat, n(s)), 0-3
        """
    if type(pred_labels[0]) != int:
        pred_labels = list(itertools.chain.from_iterable(pred_labels))
        true_labels = list(itertools.chain.from_iterable(true_labels))
    tp, tn, fp, fn = 0, 0, 0, 0
    all_pred, all_true, all_probs = [], [], []
    for i in range(len(pred_labels)):
        if true_labels[i] == ignore_index:
            continue
        if pred_labels[i] == true_labels[i]:
            if true_labels[i] == 0:
                tn += 1
            else:
                tp += 1
        else:
            if true_labels[i] == 0:
                fp += 1
            else:
                fn += 1
        all_pred.append(pred_labels[i])
        all_true.append(true_labels[i])
        all_probs.append(probs[i])
    acc = (tp + tn) / (tp + tn + fp + fn)
    pre = tp / (tp + fp) if (tp + fp) > 0 else 0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * pre * rec / (pre + rec) if (pre + rec) > 0 else 0
    auc = roc_auc_score(all_true, all_probs) if sum(all_true) > 0 else 0.
    return tp + tn + fp + fn, acc, pre, rec, f1, auc

In [239]:
model.eval()
losses = 0.
all_probs = []
all_preds = []
all_targets = []

In [242]:
S, Q, A = [], [], []
for batch in eval_loader:
    inputs = eval_dataset.batch2input(batch)

    clauses, keywords = inputs
    S += clauses.tolist()
    Q += keywords.tolist()
    clauses = torch.from_numpy(clauses).to(device)
    keywords = torch.from_numpy(keywords).to(device)

    labels = eval_dataset.batch2target(batch)
    A += labels.tolist()
    labels = torch.from_numpy(labels).to(device)


    outputs = model(clauses,keywords)
    probs = outputs.view(-1, outputs.size(-1))
    predicts = probs.max(dim=-1)[1]
    
    loss = criterion(probs.view(-1, probs.size(-1)), labels.view(-1))
    
    losses += loss

    probs = F.softmax(probs, dim=-1)

    all_probs += probs.tolist()
    all_preds += predicts.tolist()
    all_targets += labels.tolist()


    pos_probs = np.array(all_probs)[:, 1]

total, acc, pre, rec, f1, auc = metrics(all_preds, all_targets, pos_probs)

In [243]:
outputs

tensor([[ 0.0879, -0.0325],
        [-0.0089, -0.0088],
        [-0.0081, -0.0391],
        [-0.0091, -0.0224],
        [-0.0479, -0.0298],
        [-0.0242, -0.0085],
        [-0.0275, -0.0233],
        [-0.0140, -0.0217],
        [-0.0184,  0.0200],
        [-0.0279, -0.0604],
        [ 0.0760,  0.0065],
        [-0.0200, -0.0136],
        [-0.0200, -0.0136],
        [-0.0200, -0.0136],
        [-0.0200, -0.0136],
        [-0.0200, -0.0136]], device='cuda:3', grad_fn=<AddmmBackward>)

In [244]:
probs

tensor([[0.5301, 0.4699],
        [0.5000, 0.5000],
        [0.5078, 0.4922],
        [0.5033, 0.4967],
        [0.4955, 0.5045],
        [0.4961, 0.5039],
        [0.4990, 0.5010],
        [0.5019, 0.4981],
        [0.4904, 0.5096],
        [0.5081, 0.4919],
        [0.5174, 0.4826],
        [0.4984, 0.5016],
        [0.4984, 0.5016],
        [0.4984, 0.5016],
        [0.4984, 0.5016],
        [0.4984, 0.5016]], device='cuda:3', grad_fn=<SoftmaxBackward>)

In [1]:
train_dataset

NameError: name 'train_dataset' is not defined