In [30]:
import json
from random import choice
import copy
import random

entity_path = "../entity/"
csv_path = "../csv/"
train_path = "../train/"
intent_path = "../intent/"

import csv

def read_chinese():
    reader = csv.reader(open("../raw-data/chinese.txt", 'r', encoding='utf-8'), delimiter=" ")
    result = []
    for i in reader:
        result += i
        
    return result

noise_lib = read_chinese()

def read_csv(file, column):
    reader = csv.reader(open(csv_path + file, 'r', encoding='utf-8'), delimiter="\t")
    result = []
    for i in reader:
        result.append(i[column-1])
        
    return result

def load_json(path, file):
    with open(path + file + ".json", 'r', encoding='utf-8') as jsonfile:
        return json.load(jsonfile)

def load_entity(file):
    return load_json(entity_path, file)

def load_intent(file):
    return load_json(intent_path, file)

samples_dict = {}


def find_entity(entities, obj_name):
    for i in entities:
        if i['entity'] == obj_name:
            return i

def handle_enum_type(entity):
    e = entity['enum']
    name = entity['entity']
    if name not in samples_dict:
        samples_dict[name] = read_csv(e['source'], e['column'])
        
    return len(samples_dict[name]), entity, []
        
def handle_compound_type(js, entity):
    children = []
    total_number = 0
    for i in entity['compound']:
        c, sample, tp = handle_entity(js, i['type'])
        total_number = max(c, total_number)
        children.append((i['name'], sample, tp))
        
    return total_number, entity, children
        
def handle_choice_type(js, entity):
    children = []
    total_number = 0
    for i in entity['choice']:
        c = handle_entity(js, i)
        children.append(c)
        total_number += c[0]
        
    return total_number, entity, children   

def handle_entity(js, c):
    entity = find_entity(js, c)
    if 'enum' in entity:
        return handle_enum_type(entity)
    elif 'compound' in entity:
        return handle_compound_type(js, entity)
    elif 'choice' in entity:
        return handle_choice_type(js, entity)
    
    print("ERROR", entity)
    return 0,0,0



def get_generated_number_of_samples(total, n, allocated):
    number = (n * allocated) // total
    return max(10, number)

def generate_enum_samples_by_pattern(patterns, samples, n_samples):
    result = []
    for p in patterns:
        pos = p.find("@{this}")
        result += [((pos, len(s)), p.replace('@{this}', s)) for s in samples]
        
    random.shuffle(result)
    
    return result[:n_samples]
    
def flatten_generate_enum_samples(entity, n_samples):
    name, samples, typ = generate_enum_samples(entity, n_samples)
    return name, [s[1] for s in samples], typ
        
    
def generate_enum_samples(entity, n_samples):
    name = entity['entity']
    samples = [choice(samples_dict[name]) for i in range(n_samples)]
    
    if "patterns" in entity:
        samples = generate_enum_samples_by_pattern(entity['patterns'], samples, n_samples)
    else:
        samples = [((0, len(s)), s) for s in samples]
        
    return name, samples, 'enum'

def split_pattern(pattern):
    pos = pattern.find('@{')
    if pos < 0:
        return [pattern]
    
    end = pattern.find('}')
    
    result = []
    if pos > 0:
        result += [pattern[:pos]]
        
    result += [pattern[pos:end+1]]
    
    if end < len(pattern) - 1:
        result += split_pattern(pattern[end+1:])
    
    return  result

def get_pattern(pattern):
    result = split_pattern(pattern)
    
    tags = {}
    for s, i in zip(result, range(len(result))):
        if s.startswith('@{'):
            tags[s] = i
    
    return tags, result


def get_pos(sample, index):
    start = 0
    for i in range(index):
        start += len(sample[i])
        
    return start, len(sample[index])


    
def generate_compound_samples(entity, children, n_samples):
    samples = []

    for child in children:
        name = child[0]
        this_entity = child[1]
        this_children = child[2]
        if 'enum' in this_entity:
            _, sample, tp = flatten_generate_enum_samples(this_entity, n_samples)
            samples.append((name, sample, tp))
        elif 'compound' in this_entity:
            _, sample, tp = flatten_generate_compound_samples(this_entity, this_children, n_samples)
            samples.append((name, sample, tp))
        elif 'choice' in this_entity:
            _, sample, tp = flatten_generate_choice_samples(this_entity, this_children, n_samples)
            samples.append((name, sample, tp))
    
    #print("samples: ", samples, n_samples)
    
    #mandatory_dict = {}
    #for ent in entity['compound']:
        #print(ent)
    #    mandatory_dict[ent['name']] = ent['mandatory']
    
    patterns = entity['patterns']
    all_samples = []
    #print(n_samples)
    for p in patterns:
        pattern = get_pattern(p)
        #print(pattern)
        
        for i in range(n_samples):
            s = copy.deepcopy(pattern[1])
            #print(s)
            #record = {}
            tags = copy.deepcopy(pattern[0])
           
            for sample_group in samples:
                sample_group_data = sample_group[1]
                tag = '@{' + sample_group[0] + '}'
                
                if tag in tags: 
                    #print("===: ", i, tags[tag])
                    select_data = sample_group_data[i]
                    #if (not mandatory_dict[sample_group[0]]) and choice([False, True]):
                    #    select_data = ""
                    
                    s[tags[tag]] = select_data
                
            for k in tags:
                tags[k] = get_pos(s, tags[k])
            
            #print(str(len(s)) + " : " + str(n_samples))
            #print("ss: ", s, tags)
            all_samples.append((tags, ''.join(s)))
            
    
    random.shuffle(all_samples)
    all_samples = all_samples[:n_samples]

    #print(all_samples)
    return entity['entity'], all_samples, 'compound'
    
def flatten_generate_compound_samples(entity, children, n_samples):
    e, samples, _ = generate_compound_samples(entity, children, n_samples)
    all_samples = [i[1] for i in samples]
    random.shuffle(all_samples)
    return entity['entity'], all_samples, 'enum'
    
def flatten_samples(entity, samples):
    results = []
    for s in samples:
        #print(s)
        typ = s[2]
        if typ == 'compound':
            results += [i[1] for i in s[1]]
        elif typ == 'enum':
            results += s[1]
        else:
            print("ERROR")
    
    random.shuffle(results)
    
    return entity, results, 'enum'
    
def flatten_generate_choice_samples(entity, children, n_samples):
    e, samples, _ = generate_choice_samples(entity, children, n_samples)
    return flatten_samples(e, samples)
            

def generate_choice_samples(entity, children, n_samples):
    all_samples = []
    for i in sorted(children, key=lambda x: x[0]):
        s = (n_samples // len(children)) + 1 #get_generated_number_of_samples(n, i[0], m)
        
        ent = i[1]
        if 'enum' in ent:
            r = flatten_generate_enum_samples(ent, s)
        elif 'compound' in ent:
            r = flatten_generate_compound_samples(ent, i[2], s)
        elif 'choice' in ent:
            r = flatten_generate_choice_samples(ent, i[2], s)
        else:
            r = None
        
        all_samples.append(r)

    return entity['entity'], all_samples, 'choice'
    

def generate_sample(tree, n_samples):
    entry = tree[1]
    children = tree[2]
    if 'choice' in entry:
        return generate_choice_samples(entry, children, n_samples)
    elif 'compound' in entry:
        return generate_compound_samples(entry, children, n_samples)
    elif 'enum' in entry:
        return generate_enum_samples(entry, n_samples)
    
    print("ERROR")
    
    return []

def make_enum_label(samples, label):
    result = []
    for sample in samples:
        s = []
        for i in sample:
            s.append([i, label])
        result.append(s)
        
    return result

def make_compound_label(samples, label):
    result = []
    for sample in samples:
        s = []
        for i in sample[1]:
            s.append([i, label])
        result.append(s)
        
    return result

def generated_labeled_choice_sample(samples):
    result = []
    for i in range(len(samples)):
        label = "C" + str(i)
        ss = samples[i]
        if ss[2] == 'enum':
            result += make_enum_label(ss[1], label)
        elif ss[2] == 'compound':
            result += make_compound_label(ss[1], label)
            
    return result


def make_enum_test_data(samples):
    result = []
    for sample in samples:
        s = []
        for i in sample:
            s.append([i])
        s.append(["。"])
        result.append(s)
        
    return result
    
def make_compound_test_label(samples):
    result = []
    for sample in samples:
        s = []
        for i in sample[1]:
            s.append([i])
        s.append(["。"])
        result.append(s)
        
    return result
    
def generate_test_data(samples):
    result = []
    for i in range(len(samples)):
        ss = samples[i]
        if ss[2] == 'enum':
            result += make_enum_test_data(ss[1])
        elif ss[2] == 'compound':
            result += make_compound_test_label(ss[1])
            
    return result

range_tags = ["A", "B","C","D","E","F","G","H"]

def generate_compound_labels(entity):
    labels = {}
    
    print(entity)
    elems = entity['compound']
    
    for n, i in zip(elems, range(len(elems))):
        labels["@{" + n["name"] + "}"] = range_tags[i]
        
    return labels

def make_noise():
    num = choice(range(1,5,1))
    return [[choice(noise_lib), 'O'] for i in range(num)]

def make_noised_sample(data):
    return make_noise() + data + make_noise() + [['。', 'O']]
    
def label_compound_samples(entity, samples):
    results = []
    
    labels = generate_compound_labels(entity)
    
    #print(samples)
    for i in samples:
        label = i[0]
        data = [[c, 'O'] for c in i[1]]

        for tag in label:
            begin, size = label[tag]
            #print("---", begin, size)
            for k in range(begin, begin+size):
                data[k][1] = labels[tag]
                
        results += make_noised_sample(data)
        
    
    #print(data)
    return results

def label_choice_samples(samples):
    results = []
    
    print(samples)
    for s, i in zip(samples, range(len(samples))):
        for d in s[1]:
            data = [[c, range_tags[i]] for c in d]
            results += make_noised_sample(data)
    
    return results

def do_write_csv_data(entity, samples, ty):
    with open(train_path + entity + "/" + ty + '.data', 'w', encoding='utf-8') as csvfile:
        my_writer = csv.writer(csvfile, delimiter='\t',lineterminator='\n')
        for i in samples:
            my_writer.writerow(i)
            
def do_write_training_data(entity, samples):
    do_write_csv_data(entity, samples, "train")

def do_write_test_data(entity, samples):
    do_write_csv_data(entity, samples, "test")

        
def write_training_data(js, samples):
    entity = samples[0]
    typ = samples[2]
    data = samples[1]
    random.shuffle(data)

    if typ == 'compound':
        do_write_training_data(entity, label_compound_samples(find_entity(js, entity), data))
    elif typ == 'choice':
        do_write_training_data(entity, label_choice_samples(data))

def make_samples(js, entity_name, num):
    return generate_sample(handle_entity(js, entity_name), num)
    
def make_training_file(js, entity_name, num):
    write_training_data(js, make_samples(js, entity_name, num))

def make_compound_test_samples(samples):
    results = []
    #print(samples)
    for i in samples:
        results += i[1]
        results.append(['。'])

    return results

def make_choice_test_samples(samples):
    results = []
    
    for s in samples:
        for d in s[1]:
            results += [[c] for c in d]
            results += [['。']]
    
    return results


def write_test_data(samples):
    entity = samples[0]
    typ = samples[2]
    data = samples[1]
    random.shuffle(data)
    
    if typ == 'compound':
        do_write_test_data(entity, make_compound_test_samples(data))
    elif typ == 'choice':
        do_write_test_data(entity, make_choice_test_samples(data))
    elif typ == 'enum':
        do_write_test_data(entity, make_enum_test_samples(data))

def make_test_file(js, entity_name, num):
    write_test_data(make_samples(js, entity_name, num))

def make_entity_data(js, entity_name, train_n=100, test_n=10):
    make_training_file(js, entity_name, train_n)
    make_test_file(js, entity_name, test_n)
    
    
js = load_entity('datetime') + load_entity('date') + load_entity('time') 
js += load_entity('city') + load_entity('ticket')
js += load_intent('book_ticket')
#result
#get_pattern("从@{from}到@{end}sh")

In [21]:
make_samples(js, 'huge-muni-city', 10)

('huge-muni-city',
 [((0, 2), '重庆市'),
  ((0, 2), '重庆市'),
  ((0, 2), '重庆市'),
  ((0, 2), '重庆'),
  ((0, 2), '重庆'),
  ((0, 2), '重庆'),
  ((0, 2), '重庆'),
  ((0, 2), '重庆市'),
  ((0, 2), '重庆市'),
  ((0, 2), '重庆')],
 'enum')

In [27]:
make_entity_data(js, "datetime",10000)

{'entity': 'datetime', 'compound': [{'name': 'date', 'type': 'date', 'mandatory': True}, {'name': 'time', 'type': 'time', 'mandatory': False}], 'patterns': ['@{date}@{time}', '@{date}', '@{time}']}


In [8]:
make_entity_data(js, "time", 100)

[('range-time', ['十一点整左右到中午三点五十六分左右', '从十六点半左右到十点二十分', '十三点二十一分到凌晨十八点零七分', '从上午一点二十四分左右到三点四十四分左右', '中午十九点三十九分左右到上午二十三点五十七分左右', '凌晨二十点四十九分左右到二点三刻左右', '十五点整到二十三点半', '九点一刻左右至十六点三刻左右之间', '凌晨二十点四十九分左右至二点三刻左右之间', '六点二十二分左右到晚上二十点四十分左右', '零点三十一分到十七点零八分', '十二点半至早晨二十一点零四分之间', '从十六点三刻左右到下午十点零九分', '从九点零九分左右到二十二点半', '从六点一刻到三点四十七分左右', '从三点十三分左右到下午二十点四十六分左右', '凌晨二十三点四十四分至八点三刻左右之间', '从晚上一点十四分到一点整', '二十点半至八点半左右之间', '四点半到十五点半', '十三点三刻到十点整左右', '中午十九点三十九分左右至上午二十三点五十七分左右之间', '从二十三点半左右到二十点三十九分', '早晨十九点三十分左右到七点十三分', '从十五点整到二十三点半', '六点一刻到三点四十七分左右', '三点十三分左右至下午二十点四十六分左右之间', '六点一刻至三点四十七分左右之间', '二十点半到八点半左右', '二十点二十七分至十点三刻之间', '二十三点十五分左右到早晨十二点二十三分左右', '上午一点二十四分左右到三点四十四分左右', '下午十五点三十五分至上午八点三十五分左右之间', '从十九点四十八分左右到五点四十七分'], 'enum'), ('single-time', ['二点整左右', '十五点半左右', '晚上十九点三十三分左右', '早晨九点三十五分', '二十四点五十七分', '十九点四十二分左右', '五点二十五分左右', '晚上九点半', '十七点三刻左右', '七点半', '十八点三刻', '凌晨一点四十五分', '四点', '凌晨十九点零一分左右', '七点左右', '下午十点五十分', '三点一刻', '十五点零四分左右', '十七点十七分', '十九点三十分', '十九点三十二分左右', '四点一刻', '十八点五十五分', '下午二点二十八分', '晚上十七点零二分', '上午二十

In [9]:
make_entity_data(js, "range-time", 100)

{'entity': 'range-time', 'compound': [{'name': 'from', 'type': 'single-time', 'mandatory': True}, {'name': 'to', 'type': 'single-time', 'mandatory': True}], 'patterns': ['@{from}到@{to}', '从@{from}到@{to}', '@{from}至@{to}之间']}


In [16]:
make_entity_data(js, "province-city", 20000)

{'entity': 'province-city', 'compound': [{'name': 'province', 'type': 'province', 'mandatory': False}, {'name': 'city', 'type': 'city', 'mandatory': True}], 'patterns': ['@{province}@{city}']}


In [15]:
make_entity_data(js, "general-city", 20000)

[('muni-city', ['北京市', '天津', '重庆市', '北京市', '天津市', '重庆', '天津', '上海', '北京市', '重庆市', '北京', '上海', '重庆市', '上海市', '北京', '重庆市', '重庆市', '重庆市', '重庆市', '重庆', '重庆市', '北京', '重庆', '重庆市', '上海市', '北京', '上海市', '重庆市', '北京市', '上海市', '上海市', '重庆', '重庆市', '重庆', '天津', '天津', '上海市', '上海市', '重庆市', '重庆', '重庆市', '北京', '重庆', '重庆市', '重庆市', '重庆', '重庆', '重庆市', '重庆市', '上海市', '重庆市', '天津市', '天津', '重庆市', '上海', '上海市', '重庆市', '重庆市', '天津市', '北京市', '重庆市', '重庆市', '重庆', '重庆', '上海', '重庆', '北京市', '重庆', '重庆市', '重庆市', '重庆市', '上海市', '天津市', '上海市', '重庆', '重庆市', '北京市', '北京', '重庆市', '北京市', '上海', '重庆', '天津市', '重庆市', '上海市', '重庆', '北京', '北京市', '重庆市', '重庆', '天津市', '上海市', '重庆', '重庆市', '天津市', '重庆', '重庆市', '北京市', '重庆', '重庆市', '重庆', '重庆市', '天津', '重庆', '重庆市', '重庆市', '北京市', '北京', '重庆市', '天津', '上海', '天津', '重庆市', '重庆', '重庆', '重庆市', '天津市', '天津', '上海市', '天津', '重庆市', '重庆市', '重庆', '天津市', '北京', '重庆市', '重庆', '上海市', '重庆', '重庆市', '重庆市', '上海市', '重庆', '天津市', '重庆市', '天津市', '北京', '重庆市', '天津市', '上海市', '重庆', '重庆', '重庆市', '北京市', '重庆市', '北京', '上海', '上海市', '北京', 

In [32]:
make_entity_data(js, "book_ticket", 10000)

{'entity': 'book_ticket', 'compound': [{'name': 'datetime', 'type': 'datetime', 'mandatory': True}, {'name': 'from', 'type': 'general-city', 'mandatory': False}, {'name': 'to', 'type': 'general-city', 'mandatory': True}, {'name': 'ticket', 'type': 'ticket', 'mandatory': True}], 'patterns': ['我想订@{datetime}@{from}到@{to}的@{ticket}', '帮我看看@{datetime}从@{from}到@{to}的@{ticket}', '查一下@{datetime}从@{from}到@{to}的@{ticket}', '查下@{datetime}到@{to}的@{ticket}', '问下@{datetime}去@{to}的@{ticket}', '看下@{datetime}去@{to}的@{ticket}', '看下有@{datetime}去@{to}的@{ticket}吗', '查下@{datetime}从@{from}去@{to}的@{ticket}', '帮我看看到@{to}的@{ticket}的情况', '帮我看看到@{to}的@{ticket}还有吗', '帮我看看有没有@{datetime}去@{to}的@{ticket}', '@{datetime}@{from}到@{to}的@{ticket}', '@{datetime}去@{to}的@{ticket}']}


In [29]:
make_samples(js, 'single-time', 10)

('single-time',
 [('usual-time', ['八点一刻左右', '零点整左右', '九点一刻左右', '三点整'], 'enum'),
  ('exact-time', ['二十二点四十五分', '二十三点三十七分左右', '三点零一分', '六点十五分左右'], 'enum'),
  ('qualified-time',
   ['上午二十点二十分左右', '傍晚三点零八分', '下午十五点五十五分', '傍晚八点四十分'],
   'enum')],
 'choice')