In [1]:
import os
import pickle


## Data Preprocessing
### Default Data path

In [2]:
data_path = 'data'
train_origin_data = 'data.orig'
train_annotated_data = 'annotation.man.af'
entity_type_list = ["d", "s", "c", "i", "a", "b", "t", "p"]
all_data = ["raw_data.orig"]

In [3]:
def resolve_line(line_num, origin_line, line, entity_type_list):
    words = []
    tags = []
    for part in line.split(" "):
        part = part.strip("\n")
        if len(part) >= 2 and part[-2] == '\\' and part[-1] in entity_type_list:
            words.extend(list(part[ :-2]))
            part_tag = ["B-" + part[-1].upper()] + ["I-" + part[-1].upper()] * (len(part) - 3)
            tags.extend(part_tag)
        else:
            words.extend(list(part))
            part_tag = ["O"] * len(part)
            tags.extend(part_tag)
    assert len(words) == len(tags)
    if(list(origin_line.strip()) != words):
        print("At line {} origin and annotated line don't match! ".format(line_num))
    assert words == list(origin_line.strip())
    return words, tags

In [4]:
def read_corpora(data_path, train_origin_data,train_annotated_data,entity_type_list, save_path):
    data_origin = os.path.join(data_path, train_origin_data)
    annotation = os.path.join(data_path, train_annotated_data)
    with open(data_origin) as f1:
        origin_lines = f1.readlines()
    with open(annotation) as f2:
        annotated_lines = f2.readlines()
    assert len(origin_lines) == len(annotated_lines)
    train_data = []
    for i in range(len(origin_lines)):
        words, tags = resolve_line(i+1, origin_lines[i], annotated_lines[i], entity_type_list)
        train_data.append((words, tags))
    with open(save_path, "wb") as f:
        pickle.dump(train_data, f)
    return train_data

Check the preprocessed data existed or not

In [5]:
if os.path.exists(os.path.join(data_path, "train_data.pkl")):
    with open(os.path.join(data_path, "train_data.pkl"), "rb") as f:
        train_data = pickle.load(f)
    print('Data is loaded!')
else:
    train_data = read_corpora(data_path, train_origin_data,train_annotated_data, entity_type_list, os.path.join(data_path, "train_data.pkl"))
    print('Data is preprocessed!')

Data is loaded!


### Build vocab based on the original raw data

In [6]:
def build_vocab(all_data,data_path):
    lines = []
    for path in all_data:
        with open(os.path.join(data_path, path)) as f:
            lines_ = f.readlines()
        lines.extend(lines_)
    word2idx={}
    for line in lines:
        for word in line:
            if word not in word2idx:
                word2idx[word] = len(word2idx)
    return word2idx

In [7]:
# Build vocab 
word_to_ix = build_vocab(all_data,data_path)
print('Vocab is built!')

Vocab is built!


### Test data Load

In [8]:
test_data = "raw_data.orig"

In [9]:
# set testing data
if os.path.exists(os.path.join(data_path, test_data)):
    with open(os.path.join(data_path, test_data), "r") as f:
        test_data = f.readlines()
print('Test data is loaded!')

Test data is loaded!


In [10]:
print(f'test_data:{test_data[0]}')

test_data:左膝活动后疼痛10日,患者10日前活动后出现左膝关节疼痛，伴行走不便，以上下楼酸胀不适为主，平路步行可，无肿胀及关节畸形，无晨僵，无关节绞锁病史，无腰痛、下肢放射痛。近1周来关节症状无好转，无发热，无关节周围红肿。我院膝关节磁共振扫描（左）示：左膝关节半月板损伤。患者为进一步诊疗来我院，门诊以“左膝半月板损伤、左膝骨性关节炎”收入我科。患者自发病来精神、食欲可，大小便正常，体重未见明显降低。膝关节磁共振扫描（左）示（2016-6-14，北京清华长庚医院）：左膝骨关节病，外侧半月板撕裂，关节积液，滑膜炎，腘窝囊肿。,左膝关节对位如常，关节间隙略窄，关节缘、髁间棘及髌骨缘骨质增生。关节面硬化。左侧膝关节退变,实施了关节镜手术,服用了止痛类药物,使用了外用类药物,接受了康复治疗。



In [11]:
print(f'train data:{train_data[0]}')
print(len(train_data))
print(len(train_data[0]))
print(f'Data:{train_data[0][0]}')
print(f'train label:{train_data[0][1]}')



train data:(['左', '膝', '活', '动', '后', '疼', '痛', '1', '0', '日', ',', '患', '者', '1', '0', '日', '前', '活', '动', '后', '出', '现', '左', '膝', '关', '节', '疼', '痛', '，', '伴', '行', '走', '不', '便', '，', '以', '上', '下', '楼', '酸', '胀', '不', '适', '为', '主', '，', '平', '路', '步', '行', '可', '，', '无', '肿', '胀', '及', '关', '节', '畸', '形', '，', '无', '晨', '僵', '，', '无', '关', '节', '绞', '锁', '病', '史', '，', '无', '腰', '痛', '、', '下', '肢', '放', '射', '痛', '。', '近', '1', '周', '来', '关', '节', '症', '状', '无', '好', '转', '，', '无', '发', '热', '，', '无', '关', '节', '周', '围', '红', '肿', '。', '我', '院', '膝', '关', '节', '磁', '共', '振', '扫', '描', '（', '左', '）', '示', '：', '左', '膝', '关', '节', '半', '月', '板', '损', '伤', '。', '患', '者', '为', '进', '一', '步', '诊', '疗', '来', '我', '院', '，', '门', '诊', '以', '“', '左', '膝', '半', '月', '板', '损', '伤', '、', '左', '膝', '骨', '性', '关', '节', '炎', '”', '收', '入', '我', '科', '。', '患', '者', '自', '发', '病', '来', '精', '神', '、', '食', '欲', '可', '，', '大', '小', '便', '正', '常', '，', '体', '重', '未', '见', '明', '显', '降', '低', '。', '膝

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data,tag_to_ix,to_ix,mode='Train'):

        self.data = data
        self.tag_to_ix = tag_to_ix
        self.to_ix = to_ix
        self.mode = mode

    def __len__(self):
        
        # return length
        return len(self.data)

    def make_sequence(self,sentence,to_ix):
        idxs = [to_ix[w] for w in sentence]
        return torch.tensor(idxs, dtype=torch.long)
    
    def get_label(self,tags):
        targets = torch.tensor([self.tag_to_ix[t] for t in tags], dtype=torch.long)
        return targets

    def __getitem__(self, i): 
        
        if self.mode == 'Train':
            text = self.data[i][0]
            tags = self.data[i][1]
            seq = self.make_sequence(text,self.to_ix)
            label = self.get_label(tags)
            return seq, label,text
        elif  self.mode == 'Test':
            text = self.data[i]
            split_text = text.strip("\n")
            seq = self.make_sequence(split_text,self.to_ix)
            return split_text,seq
        else:
            print('Wrong mode, please use "Train" or "Test" !')

        

In [13]:
# batch size should be equal to 1, because we didn't set the fixed sequence length in this project
# You will get error if batch size not equal to 1
tag_to_ix = { "O": 0,
              "B-D": 1, "I-D": 2,
              "B-S": 3, "I-S": 4,
              "B-T": 5, "I-T": 6,
              "B-I": 7, "I-I": 8,
              "B-C": 9, "I-C": 10,
              "B-A": 11, "I-A": 12,
              "B-B": 13, "I-B": 14,
              "B-P": 15, "I-P": 16,
              "<START>": 17, "<STOP>": 18}

batch_size = 1

TrainDataset = MyDataset(train_data,tag_to_ix,word_to_ix,'Train')
TestDataset = MyDataset(test_data,tag_to_ix,word_to_ix,'Test')

TrainLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=False)
TestLoader = DataLoader(TestDataset, batch_size=batch_size, shuffle=False)

In [14]:
for index, (sentence,seq) in enumerate(TestLoader):
    print(f'index:{index}')
    print(f'sentence:{sentence}')
    print(f'seq:{seq}')
    break

index:0
sentence:('左膝活动后疼痛10日,患者10日前活动后出现左膝关节疼痛，伴行走不便，以上下楼酸胀不适为主，平路步行可，无肿胀及关节畸形，无晨僵，无关节绞锁病史，无腰痛、下肢放射痛。近1周来关节症状无好转，无发热，无关节周围红肿。我院膝关节磁共振扫描（左）示：左膝关节半月板损伤。患者为进一步诊疗来我院，门诊以“左膝半月板损伤、左膝骨性关节炎”收入我科。患者自发病来精神、食欲可，大小便正常，体重未见明显降低。膝关节磁共振扫描（左）示（2016-6-14，北京清华长庚医院）：左膝骨关节病，外侧半月板撕裂，关节积液，滑膜炎，腘窝囊肿。,左膝关节对位如常，关节间隙略窄，关节缘、髁间棘及髌骨缘骨质增生。关节面硬化。左侧膝关节退变,实施了关节镜手术,服用了止痛类药物,使用了外用类药物,接受了康复治疗。',)
seq:tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,   7,
           8,   9,  13,   2,   3,   4,  14,  15,   0,   1,  16,  17,   5,   6,
          18,  19,  20,  21,  22,  23,  18,  24,  25,  26,  27,  28,  29,  22,
          30,  31,  32,  18,  33,  34,  35,  20,  36,  18,  37,  38,  29,  39,
          16,  17,  40,  41,  18,  37,  42,  43,  18,  37,  16,  17,  44,  45,
          46,  47,  18,  37,  48,   6,  49,  26,  50,  51,  52,   6,  53,  54,
           7,  55,  56,  16,  17,  57,  58,  37,  59,  60,  18,  37,  61,  62,
          18,  37,  16,  17,  55,  63,  64,  38,  53,  65,  66,   1,  16,  17,


In [15]:
for index, (sentence, label,seq) in enumerate(TrainLoader):
    print(f'index:{index}')
    print(f'sentence:{sentence.shape}')
    print(f'seq:{seq}')
    print(f'tags:{label}')
    cache = sentence.squeeze(0)
    print(f'cache shape:{cache.shape}')
    break

index:0
sentence:torch.Size([1, 341])
seq:[('左',), ('膝',), ('活',), ('动',), ('后',), ('疼',), ('痛',), ('1',), ('0',), ('日',), (',',), ('患',), ('者',), ('1',), ('0',), ('日',), ('前',), ('活',), ('动',), ('后',), ('出',), ('现',), ('左',), ('膝',), ('关',), ('节',), ('疼',), ('痛',), ('，',), ('伴',), ('行',), ('走',), ('不',), ('便',), ('，',), ('以',), ('上',), ('下',), ('楼',), ('酸',), ('胀',), ('不',), ('适',), ('为',), ('主',), ('，',), ('平',), ('路',), ('步',), ('行',), ('可',), ('，',), ('无',), ('肿',), ('胀',), ('及',), ('关',), ('节',), ('畸',), ('形',), ('，',), ('无',), ('晨',), ('僵',), ('，',), ('无',), ('关',), ('节',), ('绞',), ('锁',), ('病',), ('史',), ('，',), ('无',), ('腰',), ('痛',), ('、',), ('下',), ('肢',), ('放',), ('射',), ('痛',), ('。',), ('近',), ('1',), ('周',), ('来',), ('关',), ('节',), ('症',), ('状',), ('无',), ('好',), ('转',), ('，',), ('无',), ('发',), ('热',), ('，',), ('无',), ('关',), ('节',), ('周',), ('围',), ('红',), ('肿',), ('。',), ('我',), ('院',), ('膝',), ('关',), ('节',), ('磁',), ('共',), ('振',), ('扫',), ('描',), ('（',), ('左',), ('）',

## Model

In [16]:
import torch
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import pickle
START_TAG = "<START>"
STOP_TAG = "<STOP>"

class BiLSTM_CRF(nn.Module):

    def __init__(self, vocab_size, tag_to_ix, embedding_dim,hidden_dim,cuda_tag=False):
        super(BiLSTM_CRF, self).__init__()
        self.device = "cuda:0" if cuda_tag else "cpu" 
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        self.hidden2tag = nn.Linear(self.hidden_dim, self.tagset_size)

        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size)).to(self.device)

        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000


    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2).to(self.device),
                torch.randn(2, 1, self.hidden_dim // 2).to(self.device))

    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000., device=self.device)
        # START_TAG has all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # Wrap in a variable so that we will get automatic backprop
        forward_var = init_alphas

        # Iterate through the sentence
        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.tagset_size):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score

                alphas_t.append(torch.logsumexp(next_tag_var,dim=1).view(1))

            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = torch.logsumexp(terminal_var,dim=1)[0]

        return alpha

    def _get_lstm_features(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1).to(self.device)
        tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=self.device), tags])
        for i, feat in enumerate(feats):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_vvars = torch.full((1, self.tagset_size), -10000.,device=self.device)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):

                next_tag_var = forward_var + self.transitions[next_tag]

                best_tag_id = torch.argmax(next_tag_var).item()

                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]

        best_tag_id = torch.argmax(terminal_var).item()


        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

In [17]:
init_model = ""

embedding_dim = 300
hidden_dim = 600
cuda_tag = False

### Model training and saving

### Model initialization

In [18]:
# init model for training
if init_model != str(""):
    model = torch.load(os.path.join(data_path, init_model))
    print('Model loading is done!')
else:
    model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, embedding_dim,hidden_dim,cuda_tag)
    print('New model is created!')
if cuda_tag:
    model = model.cuda() 
    device = "cuda:0"
else:
    device = "cpu" 

print('Model initialization is done!')
print('Use %s to train'%device)

New model is created!
Model initialization is done!
Use cpu to train


### Model training

In [19]:
# train and save model 
print(f'device:{device}')

epoch_num = 1

optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

for epoch in range(epoch_num):
    # for index, (sentence, tags) in enumerate(train_data):
    for index, (sentence, tags,_) in enumerate(TrainLoader):

        model.zero_grad()

        # Inputs ready for the network, that is,turn them into Tensors of word indices.

        # sentence_in = prepare_sequence_torch(sentence, word_to_ix).to(device)
        # targets = torch.tensor([tag_to_ix[t[0]] for t in tags], dtype=torch.long).to(device)
        # print(f'sentence shape:{sentence.shape}')
        sentence_in = sentence.squeeze(0).to(device)
        targets = tags.squeeze(0).to(device)
        loss = model.neg_log_likelihood(sentence_in, targets)

        loss.backward()
        optimizer.step()
        if index % 20 == 0:
            print("iteration {}/{} completed, loss: {:.4f}".format(index, len(train_data), loss.item()))
    print("epoch {}/{} completed, loss: {:.4f}".format(epoch+1, epoch_num, loss.item()))

with open(os.path.join(data_path, "model.pkl"), "wb") as f:
    torch.save(model, f)
print('Model is saved!')

device:cpu
iteration 0/197 completed, loss: 1073.4303
iteration 20/197 completed, loss: 196.7087
iteration 40/197 completed, loss: 347.7083
iteration 60/197 completed, loss: 4.6668
iteration 80/197 completed, loss: 77.8952
iteration 100/197 completed, loss: 333.7498
iteration 120/197 completed, loss: 324.1519
iteration 140/197 completed, loss: 181.4973
iteration 160/197 completed, loss: 137.3928
iteration 180/197 completed, loss: 73.8999
epoch 1/1 completed, loss: 247.7454
Model is saved!


In [20]:
def decode_tags(sent, ixs, ix_to_tag):
    tags = []
    # print(f'ixs:{ixs}')
    # print(f'ix_to_tag:{ix_to_tag}')
    for ix in ixs:
        tags.append(ix_to_tag[ix])

    # correct the annotated tags for the sentence
    for i, tag in enumerate(tags):
        if tag[0] == "I" and not (tags[i - 1] == str("B-" + tag[-1]) or tags[i - 1] == tag):
            tags[i] = "O"
    
    # output the annotated sentence
    i = 0
    output = []
    while i < len(tags):
        if tags[i][0] == "B":
            output.append(" ")
            end = i + 1
            while end < len(tags) and tags[end] == "I" + tags[i][1: ]:
                end += 1
            output.extend(sent[i: end]) 
            output.append("\\" + tags[i][-1].lower() + " ")
            i = end
        else:
            output.extend(sent[i])
            i += 1
    return "".join(output)

In [21]:
# test model
test_output = "output"


with torch.no_grad():
    
    ix_to_tag = {} 
    for key in tag_to_ix:
        ix_to_tag[tag_to_ix[key]] = key
    # print(f'ix_to_tag:{ix_to_tag}')
    annotated_lines = []
    for index, (line,seq) in enumerate(TestLoader):
        with torch.no_grad():
            
            encoded_sent = seq.squeeze(0)
            tag_ixs = model(encoded_sent)[1]
            annotated_line = decode_tags(line[0].strip("\n"), tag_ixs, ix_to_tag) 
        annotated_lines.append(annotated_line.strip() + "\n")
        if index % 20 == 0:
            print("{}/{} items complete testing".format(index, len(test_data)))

with open(os.path.join(data_path, test_output), "w") as f:
    for line in annotated_lines:
        f.write(line)
print('Output file is saved!')

0/446 items complete testing
20/446 items complete testing
40/446 items complete testing
60/446 items complete testing
80/446 items complete testing
100/446 items complete testing
120/446 items complete testing
140/446 items complete testing
160/446 items complete testing
180/446 items complete testing
200/446 items complete testing
220/446 items complete testing
240/446 items complete testing
260/446 items complete testing
280/446 items complete testing
300/446 items complete testing
320/446 items complete testing
340/446 items complete testing
360/446 items complete testing
380/446 items complete testing
400/446 items complete testing
420/446 items complete testing
440/446 items complete testing
Output file is saved!
