In [481]:
import itertools
import os
import re
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# !pip install git+https://github.com/yumoh/torchcrf.git
from torchcrf import CRF
from tqdm import tqdm
import numpy as np

In [482]:
# Location of the training corpus (one sample per line).
train_data_path = './train.txt'
# test_data_path = './test.txt'

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [483]:
# Read the training corpus. Each usable line has the form "<address>/<tags>".
train_texts, train_tags = [], []
with open(train_data_path, 'rt', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        # Lines yielded by a file are never None/empty (a blank line is "\n"),
        # so test the stripped content, and skip lines with no separator
        # instead of failing with IndexError.
        if not line or '/' not in line:
            continue
        # Split on the LAST '/' so that an address which itself contains '/'
        # keeps its full text; the tag string follows the final separator.
        raw_text, raw_tag = line.rsplit('/', 1)
        train_texts.append(raw_text.strip())
        train_tags.append(raw_tag.strip())

In [512]:
# Preview the first five address strings together with their tag strings.
train_texts[:5], train_tags[:5]

(['江苏南通启东市黄金海滩景区',
  '江苏省镇江市句容市边城镇赵庄村150号',
  '江苏镇江新区金港大道98号',
  '江苏省镇江市句容某某镇莲花新村150号',
  '江苏省苏州市常熟市新海路月亮小区17幢1号'],
 ['aibiciidiiiii',
  'aiibiiciidiifiiiiii',
  'aibicieiiifii',
  'aiibiicidiifiiiiiii',
  'aiibiiciieiifiiiiiiii'])

In [485]:
# Character frequencies over the whole training corpus.
counter = Counter(itertools.chain.from_iterable(train_texts))
# Ids are assigned by descending frequency, starting at 1; 0 is kept for pad(mask).
char2id = dict(zip((ch for ch, _ in counter.most_common()), itertools.count(1)))
# Tag alphabet, id = position in the string below:
#   a: province, b: city, c: district, d: town, e: road, f: home
#   'i' (id 7) continues the piece opened by the preceding a-f tag,
#   'o' (id 0) shares its id with padding.
tag2id = {tag: idx for idx, tag in enumerate('oabcdefi')}

In [486]:
def text2seq(text, char2id):
    """Map every character of `text` to its id in `char2id`.

    Characters missing from `char2id` are encoded as 0.
    Returns a list of ints with one entry per character.
    """
    ids = []
    for ch in text:
        ids.append(char2id.get(ch, 0))
    return ids

def tag2seq(tag, tag2id):
    """Map every tag character of `tag` to its id in `tag2id`.

    Tag characters missing from `tag2id` are encoded as 0.
    Returns a list of ints with one entry per tag character.
    """
    return list(map(lambda symbol: tag2id.get(symbol, 0), tag))

def padding(l, pad_id=0):
    """Pad a batch of sequences to equal length and transpose it.

    Input:  [[1, 1, 1], [2, 2], [3]]
    Output: [(1, 2, 3), (1, 2, 0), (1, 0, 0)]  -- already transposed to [L, B]

    Shorter sequences are filled with `pad_id`.
    """
    time_major = itertools.zip_longest(*l, fillvalue=pad_id)
    return [step for step in time_major]

def masking(l, pad_id=0):
    """Build a 0/1 mask with the same nested shape as `l`.

    Entries equal to `pad_id` become 0, every other entry becomes 1.
    """
    return [[0 if token == pad_id else 1 for token in seq] for seq in l]

In [487]:
def data_gen(texts, tags, batch_size, char2id, tag2id):
    """Yield padded, sequence-first mini-batches as tensors on `device`.

    Args:
        texts: iterable of address strings.
        tags: iterable of tag strings, aligned with `texts`.
        batch_size: number of samples per batch; the last batch may be smaller.
        char2id: character -> id mapping (id 0 is the pad id).
        tag2id: tag character -> id mapping.

    Yields:
        (X, Y, mask) where X and Y are long tensors of shape [L, B] and mask is
        a uint8 tensor of the same shape with 1 on real tokens and 0 on padding.
        Inside a batch the samples are ordered by decreasing length.

    Pairs whose text is empty or whose text and tag lengths differ are skipped.
    """
    def _collate(pairs):
        # Longest first, as expected by pack_padded_sequence's default mode.
        pairs.sort(key=lambda pair: len(pair[0]), reverse=True)
        X, Y = zip(*pairs)
        # Derive the mask from the true lengths rather than from "token == 0":
        # unknown characters are also encoded as 0 and must not be masked out.
        mask = padding([[1] * len(x) for x in X])
        X = padding(X)
        Y = padding(Y)
        return (torch.tensor(X, dtype=torch.long).to(device),
                torch.tensor(Y, dtype=torch.long).to(device),
                torch.tensor(mask, dtype=torch.uint8).to(device))

    batch = []
    for s, t in zip(texts, tags):
        x = text2seq(s, char2id)
        y = tag2seq(t, tag2id)
        if len(x) > 0 and len(x) == len(y):
            batch.append((x, y))
            if len(batch) == batch_size:
                yield _collate(batch)
                batch = []
    # Flush the trailing partial batch. Counting valid samples against
    # len(texts) dropped it whenever at least one pair had been skipped.
    if batch:
        yield _collate(batch)

In [488]:
# Batch generator with 10 samples per batch, kept for inspecting batch shapes.
gen = data_gen(
    train_texts,
    train_tags,
    batch_size=10,
    char2id=char2id,
    tag2id=tag2id,
)
# for X, Y, mask in gen:
#     print(X.size(), Y.size(), mask.size())

In [489]:
class BLSTM_CRF(nn.Module):
    """Embedding -> bidirectional LSTM -> linear -> CRF sequence tagger.

    All tensors are sequence-first: id matrices are [L, B] (sequence length,
    batch size) and per-step features are [L, B, T] (T = number of tags).
    """

    def __init__(self, vocab_size, tag_size, embedding_size, hidden_size):
        """
        Args:
            vocab_size: number of known characters (ids 1..vocab_size).
            tag_size: number of distinct tag ids.
            embedding_size: width of the character embeddings.
            hidden_size: LSTM hidden width per direction.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        # One extra row so that id 0 also has an embedding slot.
        self.embedding = nn.Embedding(vocab_size + 1, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_size, tag_size)
        self.crf = CRF(tag_size)

    def _get_features(self, x, mask=None):
        """Compute per-step tag scores.

        Args:
            x: long tensor of character ids, [L, B].
            mask: optional [L, B] tensor, non-zero on real tokens. When given,
                padded steps are excluded from the LSTM via sequence packing.

        Returns:
            Float tensor [L, B, T].
        """
        seq_len = x.size(0)
        x = self.embedding(x)  # [L, B, E]
        if mask is None:
            x, _ = self.lstm(x)  # [L, B, 2H]
        else:
            # Real-token count per batch column, on the CPU as packing expects.
            lengths = (mask > 0).sum(dim=0).cpu()
            # enforce_sorted=False lets callers pass batches in any order.
            packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
            x, _ = self.lstm(packed)
            # Pin the padded length to the input length so the features stay
            # aligned with the targets and the mask.
            x, _ = pad_packed_sequence(x, total_length=seq_len)  # [L, B, 2H]
        x = F.elu(self.fc(x))  # [L, B, T]
        return x

    def get_loss(self, x, y, mask=None):
        """Return the negated CRF score of the target tags `y` for inputs `x`.

        NOTE(review): assumes `self.crf(...)` returns a log-likelihood, making
        this a negative log-likelihood -- confirm against the torchcrf fork in use.
        """
        x = self._get_features(x, mask)
        loss = self.crf(x, y, mask=mask)
        return -loss

    def decode(self, x, mask=None):
        """Decode tag sequences for `x` with the CRF layer.

        The mask is forwarded to the feature extractor as well, so padded steps
        do not leak into the backward LSTM; this matches how features are
        computed in `get_loss`.
        """
        x = self._get_features(x, mask)  # [L, B, T]
        x = self.crf.decode(x, mask)
        return x

In [490]:
def train(model, optimizer, epochs=150, batch_size=5):
    """Fit `model` on the training corpus.

    Args:
        model: module exposing `get_loss(X, Y, mask)`.
        optimizer: optimizer built over the model's parameters.
        epochs: number of full passes over the training data.
        batch_size: number of samples per mini-batch.

    The most recent batch loss is printed in place on a single line.
    """
    # Make sure the model is in training mode before optimizing.
    model.train()
    for _ in range(epochs):
        for X, Y, mask in data_gen(train_texts, train_tags, batch_size, char2id, tag2id):
            model.zero_grad()
            loss = model.get_loss(X, Y, mask)
            print('loss:{:.4f}'.format(loss.item()), end='\r')
            loss.backward()
            optimizer.step()

In [491]:
# Tagger sized to the corpus vocabulary and tag set: 128-d embeddings,
# 128 hidden units per LSTM direction.
model = BLSTM_CRF(
    vocab_size=len(char2id),
    tag_size=len(tag2id),
    embedding_size=128,
    hidden_size=128,
)
model = model.to(device)
# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [492]:
# Fit the model; `train` prints the latest batch loss in place.
train(model, optimizer)

loss:0.033310

In [507]:
def _tags_to_pieces(text, tags):
    """Group per-character tag ids into named address pieces.

    A piece starts at a tag id 1-6 and extends over the following run of
    tag id 7. Any other tag id is skipped. If the same kind of piece occurs
    more than once, the last occurrence wins.
    """
    # Keys are kept exactly as callers already receive them, including the
    # existing spellings 'provice' and 'country'.
    start_tag_to_key = {
        1: 'provice',
        2: 'city',
        3: 'country',
        4: 'town',
        5: 'road',
        6: 'house',
    }
    inside_tag = 7
    pieces = {}
    i = 0
    while i < len(tags):
        key = start_tag_to_key.get(tags[i])
        if key is None:
            i += 1
            continue
        j = i + 1
        while j < len(tags) and tags[j] == inside_tag:
            j += 1
        pieces[key] = text[i: j]
        i = j
    return pieces


def predict(text, model):
    """Tag `text` with `model` and return the recognised address pieces.

    Args:
        text: raw address string.
        model: trained tagger exposing `decode(x)`.

    Returns:
        dict mapping piece name ('provice', 'city', 'country', 'town', 'road',
        'house') to the matching substring of `text`. Empty text yields {}.

    The decoded tag ids are also printed.
    """
    # Nothing to tag; a zero-length sequence cannot be fed to the LSTM.
    if not text:
        return {}
    seq = text2seq(text, char2id)
    x = torch.tensor(seq, dtype=torch.long).view(-1, 1).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        tags = model.decode(x)[0]
    # tolist() works for CPU and CUDA tensors alike, whereas .numpy() raises
    # for tensors living on the GPU.
    if hasattr(tags, 'tolist'):
        tags = tags.tolist()
    else:
        tags = list(tags)
    print(tags)
    return _tags_to_pieces(text, tags)

In [508]:
# Demo: short address fragment with no province prefix.
predict('常熟市某某路62号附近', model)

[3, 7, 7, 5, 7, 7, 6, 7, 7, 7, 7]


{'country': '常熟市', 'road': '某某路', 'house': '62号附近'}

In [509]:
# Demo: an organisation name (a bank branch) that contains place names.
predict('江苏苏州银行南通分行', model)

[1, 7, 0, 0, 0, 0, 2, 7, 0, 0]


{'provice': '江苏', 'city': '南通'}

In [510]:
# Demo: address preceded by conversational text ('我家住在' = 'my home is at').
predict('我家住在张家港市锦丰镇锦都花苑15幢103-105室', model)

[3, 7, 3, 3, 3, 7, 7, 7, 4, 7, 7, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]


{'country': '张家港市', 'town': '锦丰镇', 'house': '锦都花苑15幢103-105室'}

In [511]:
# Demo: free-form sentence where punctuation and filler text separate the address parts.
predict('位于江苏省“人间天堂”之称的苏州，地址为：姑苏区林泉街道189号', model)

[1, 1, 1, 7, 7, 3, 7, 3, 3, 3, 3, 0, 0, 0, 2, 7, 0, 3, 3, 3, 3, 3, 7, 7, 5, 7, 7, 7, 6, 7, 7, 7]


{'provice': '江苏省',
 'country': '姑苏区',
 'city': '苏州',
 'road': '林泉街道',
 'house': '189号'}