In [11]:
import json
from random import choice
import copy
import random

# Directory prefixes (relative to the working directory) that the helpers
# below concatenate with file names: entity JSON definitions, tab-separated
# value lists, generated train/test output, and intent JSON definitions.
entity_path = "../entity/"
csv_path = "../csv/"
train_path = "../train/"
intent_path = "../intent/"

import csv

def read_csv(file, column):
    """Return one column of a tab-separated file stored under ``csv_path``.

    Args:
        file: File name, appended to ``csv_path``.
        column: 1-based index of the column to extract.

    Returns:
        A list holding the selected field of every non-empty row, in file
        order.

    Raises:
        IndexError: If a non-empty row has fewer than ``column`` fields.
    """
    # The context manager closes the handle (the original leaked it), and
    # newline='' is the mode the csv module documents for its input files.
    with open(csv_path + file, 'r', encoding='utf-8', newline='') as handle:
        reader = csv.reader(handle, delimiter="\t")
        # csv.reader yields [] for a blank line; skip those instead of
        # failing with IndexError on e.g. a trailing empty line.
        return [row[column - 1] for row in reader if row]

def load_json(path, file):
    """Decode the UTF-8 JSON document stored at ``path + file + ".json"``.

    Args:
        path: Directory prefix, used verbatim (no separator is inserted).
        file: Base file name without the ``.json`` extension.

    Returns:
        The decoded JSON value.
    """
    location = path + file + ".json"
    with open(location, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    return json.loads(text)

def load_entity(file):
    """Load the entity definition ``file`` (no extension) from ``entity_path``."""
    definitions = load_json(entity_path, file)
    return definitions

def load_intent(file):
    """Load the intent definition ``file`` (no extension) from ``intent_path``."""
    definitions = load_json(intent_path, file)
    return definitions

# Cache of enum value lists keyed by entity name: filled by
# handle_enum_type() and read by generate_enum_samples().
samples_dict = {}


def find_entity(entities, obj_name):
    """Return the first definition whose ``'entity'`` equals ``obj_name``.

    Args:
        entities: Iterable of dicts, each with an ``'entity'`` key.
        obj_name: Name to look for.

    Returns:
        The matching dict, or ``None`` when nothing matches.
    """
    matching = (candidate for candidate in entities
                if candidate['entity'] == obj_name)
    return next(matching, None)

def handle_enum_type(entity):
    """Make sure the enum's values are cached and describe it as a leaf node.

    The value list is read with ``read_csv`` from the ``'source'`` file and
    ``'column'`` given in ``entity['enum']`` the first time an entity name is
    seen; later calls reuse ``samples_dict``.

    Returns:
        ``(number_of_values, entity, [])``.
    """
    spec = entity['enum']
    key = entity['entity']
    if key in samples_dict:
        values = samples_dict[key]
    else:
        values = read_csv(spec['source'], spec['column'])
        samples_dict[key] = values

    return len(values), entity, []
        
def handle_compound_type(js, entity):
    """Resolve every member of a compound entity.

    Args:
        js: List of all entity definitions.
        entity: Definition with a ``'compound'`` list of
            ``{'name', 'type', ...}`` members.

    Returns:
        ``(count, entity, children)`` where ``count`` is the largest count
        reported by any member and ``children`` holds one
        ``(member_name, member_entity, member_children)`` tuple per member.
    """
    largest = 0
    resolved = []
    for member in entity['compound']:
        count, member_entity, member_children = handle_entity(js, member['type'])
        if count > largest:
            largest = count
        resolved.append((member['name'], member_entity, member_children))

    return largest, entity, resolved
        
def handle_choice_type(js, entity):
    """Resolve every alternative of a choice entity.

    Args:
        js: List of all entity definitions.
        entity: Definition whose ``'choice'`` list names the alternatives.

    Returns:
        ``(count, entity, children)`` where ``children`` are the resolved
        trees of the alternatives and ``count`` is the sum of their counts.
    """
    resolved = [handle_entity(js, option) for option in entity['choice']]
    combined = sum(node[0] for node in resolved)
    return combined, entity, resolved

def handle_entity(js, c):
    """Resolve the entity named ``c`` into a ``(count, entity, children)`` tree.

    Args:
        js: List of all entity definitions.
        c: Name of the entity to resolve.

    Returns:
        The tuple produced by the enum/compound/choice handler matching the
        definition. A definition with none of those keys is reported with
        ``print`` and yields ``(0, 0, 0)``.

    Raises:
        LookupError: If no definition named ``c`` exists in ``js``.
    """
    entity = find_entity(js, c)
    if entity is None:
        # find_entity returns None for an unknown name; without this guard the
        # membership tests below fail with an uninformative TypeError.
        raise LookupError("entity %r is not defined" % (c,))

    if 'enum' in entity:
        return handle_enum_type(entity)
    if 'compound' in entity:
        return handle_compound_type(js, entity)
    if 'choice' in entity:
        return handle_choice_type(js, entity)

    print("ERROR", entity)
    return 0, 0, 0



def get_generated_number_of_samples(total, n, allocated):
    """Scale ``n`` by ``allocated / total`` (floor division), never below 10."""
    share = n * allocated // total
    return share if share > 10 else 10

def generate_enum_samples(entity, n_samples):
    """Draw ``n_samples`` values (with replacement) from an enum's cached list.

    Returns:
        ``(entity_name, drawn_values, 'enum')``.
    """
    key = entity['entity']
    drawn = []
    for _ in range(n_samples):
        drawn.append(choice(samples_dict[key]))
    return key, drawn, 'enum'

def split_pattern(pattern):
    """Split a pattern into literal pieces and ``@{name}`` placeholders.

    Example: ``"从@{from}到@{end}sh"`` becomes
    ``["从", "@{from}", "到", "@{end}", "sh"]``.

    Args:
        pattern: Pattern text, possibly containing ``@{...}`` placeholders.

    Returns:
        The pieces in order; joining them gives back ``pattern``. A pattern
        without any placeholder is returned as a one-element list.
    """
    if '@{' not in pattern:
        return [pattern]

    pieces = []
    cursor = 0
    while cursor < len(pattern):
        start = pattern.find('@{', cursor)
        # Search for the closing brace only after the opening marker. The
        # original looked for the first '}' anywhere, which produced wrong
        # pieces for a stray '}' and recursed forever when it was missing.
        end = pattern.find('}', start) if start >= 0 else -1
        if end < 0:
            # No further complete placeholder: the remainder is literal text.
            pieces.append(pattern[cursor:])
            break
        if start > cursor:
            pieces.append(pattern[cursor:start])
        pieces.append(pattern[start:end + 1])
        cursor = end + 1

    return pieces

def get_pattern(pattern):
    """Split ``pattern`` and index its placeholders.

    Returns:
        ``(tags, pieces)`` where ``pieces`` is the output of
        ``split_pattern`` and ``tags`` maps each ``@{name}`` piece to its
        position in ``pieces`` (the last position wins for a repeated
        placeholder).
    """
    pieces = split_pattern(pattern)
    slots = {piece: position
             for position, piece in enumerate(pieces)
             if piece.startswith('@{')}
    return slots, pieces


def get_pos(sample, index):
    """Locate piece ``index`` inside the string obtained by joining ``sample``.

    Returns:
        ``(start, length)``: the summed length of all preceding pieces and
        the length of the piece itself.
    """
    offset = sum(len(sample[i]) for i in range(index))
    width = len(sample[index])
    return offset, width
        
def generate_compound_samples(entity, children, n_samples):
    """Generate text samples for a compound entity by filling its patterns.

    Args:
        entity: Compound definition; its 'compound' members (each with 'name'
            and 'mandatory'), 'patterns' and 'entity' keys are read here.
        children: ``(member_name, member_entity, member_children)`` tuples.
        n_samples: Number of samples to return.

    Returns:
        ``(entity_name, samples, 'compound')`` where every sample is
        ``(tags, text)`` and ``tags`` maps each ``@{name}`` placeholder of the
        pattern used to its ``(start, length)`` span inside ``text``.
    """
    samples = []

    # One list of plain-string values per member; nested compound/choice
    # members go through the flatten_* helpers. A member whose definition has
    # none of the three keys is skipped.
    for child in children:
        name = child[0]
        this_entity = child[1]
        this_children = child[2]
        if 'enum' in this_entity:
            _, sample, tp = generate_enum_samples(this_entity, n_samples)
            samples.append((name, sample, tp))
        elif 'compound' in this_entity:
            _, sample, tp = flatten_generate_compound_samples(this_entity, this_children, n_samples)
            samples.append((name, sample, tp))
        elif 'choice' in this_entity:
            _, sample, tp = flatten_generate_choice_samples(this_entity, this_children, n_samples)
            samples.append((name, sample, tp))

    # Member name -> whether the member must appear in every sample.
    mandatory_dict = {}
    for ent in entity['compound']:
        mandatory_dict[ent['name']] = ent['mandatory']

    patterns = entity['patterns']
    all_samples = []
    # n_samples candidates are built for EVERY pattern; the surplus is dropped
    # after the shuffle below.
    for p in patterns:
        pattern = get_pattern(p)

        for i in range(n_samples):
            # Work on copies so the parsed pattern can be reused.
            s = copy.deepcopy(pattern[1])
            tags = copy.deepcopy(pattern[0])

            for sample_group in samples:
                sample_group_data = sample_group[1]
                tag = '@{' + sample_group[0] + '}'

                if tag in tags:
                    select_data = sample_group_data[i]
                    # Optional members are blanked out on a coin flip; choice()
                    # is only called for non-mandatory members.
                    if (not mandatory_dict[sample_group[0]]) and choice([False, True]):
                        select_data = ""

                    s[tags[tag]] = select_data

            # Replace each placeholder's piece index by its (start, length)
            # span in the joined text.
            for k in tags:
                tags[k] = get_pos(s, tags[k])

            all_samples.append((tags, ''.join(s)))

    # Mix the samples of all patterns, then keep n_samples of them.
    random.shuffle(all_samples)
    all_samples = all_samples[:n_samples]

    return entity['entity'], all_samples, 'compound'
    
def flatten_generate_compound_samples(entity, children, n_samples):
    """Generate compound samples and keep only their text.

    Returns:
        ``(entity_name, texts, 'enum')`` - the span information produced by
        ``generate_compound_samples`` is dropped.
    """
    generated = generate_compound_samples(entity, children, n_samples)
    texts = [pair[1] for pair in generated[1]]
    return entity['entity'], texts, 'enum'
    
def flatten_samples(entity, samples):
    """Merge several sample groups into one flat list of strings.

    Args:
        entity: Value returned unchanged as the first tuple element.
        samples: ``(name, data, type)`` groups. ``'compound'`` data is a list
            of ``(tags, text)`` pairs, ``'enum'`` data a list of strings; any
            other type prints ``ERROR`` and contributes nothing.

    Returns:
        ``(entity, texts, 'enum')``.
    """
    merged = []
    for group in samples:
        kind = group[2]
        if kind == 'compound':
            merged.extend(pair[1] for pair in group[1])
        elif kind == 'enum':
            merged.extend(group[1])
        else:
            print("ERROR")

    return entity, merged, 'enum'
    
def flatten_generate_choice_samples(entity, children, n_samples):
    """Generate choice samples and merge all alternatives into one flat list.

    Returns:
        ``(entity_name, texts, 'enum')`` as built by ``flatten_samples``.
    """
    generated = generate_choice_samples(entity, children, n_samples)
    return flatten_samples(generated[0], generated[1])
            

def generate_choice_samples(entity, children, n_samples):
    """Generate one group of flat samples per alternative of a choice entity.

    Alternatives are processed in ascending order of their count (first tuple
    element) and each receives ``n_samples // len(children) + 1`` samples.

    Returns:
        ``(entity_name, groups, 'choice')``; a group is ``None`` when the
        alternative's definition is neither enum, compound nor choice.
    """
    ordered = list(children)
    ordered.sort(key=lambda node: node[0])

    groups = []
    for node in ordered:
        quota = n_samples // len(children) + 1
        definition = node[1]

        if 'enum' in definition:
            group = generate_enum_samples(definition, quota)
        elif 'compound' in definition:
            group = flatten_generate_compound_samples(definition, node[2], quota)
        elif 'choice' in definition:
            group = flatten_generate_choice_samples(definition, node[2], quota)
        else:
            group = None

        groups.append(group)

    return entity['entity'], groups, 'choice'
    

def generate_sample(tree, n_samples):
    """Generate samples for a resolved ``(count, entity, children)`` tree.

    Dispatches on the definition's kind, checking 'choice', then 'compound',
    then 'enum'. An unrecognised definition prints ``ERROR`` and gives ``[]``.
    """
    definition, subtree = tree[1], tree[2]

    if 'choice' in definition:
        return generate_choice_samples(definition, subtree, n_samples)
    if 'compound' in definition:
        return generate_compound_samples(definition, subtree, n_samples)
    if 'enum' in definition:
        return generate_enum_samples(definition, n_samples)

    print("ERROR")
    return []

def make_enum_label(samples, label):
    """Pair every character of every sample string with ``label``.

    Returns:
        One list per sample, each holding ``[char, label]`` pairs.
    """
    return [[[char, label] for char in text] for text in samples]

def make_compound_label(samples, label):
    """Pair every character of each sample's text (element 1) with ``label``.

    Returns:
        One list per sample, each holding ``[char, label]`` pairs.
    """
    return [[[char, label] for char in sample[1]] for sample in samples]

def generated_labeled_choice_sample(samples):
    """Label each group's characters with ``"C<position>"``.

    Groups of type 'enum' and 'compound' are labelled with the matching
    helper; groups of any other type are ignored.
    """
    labeled = []
    for position, group in enumerate(samples):
        tag = "C" + str(position)
        kind = group[2]
        if kind == 'enum':
            labeled.extend(make_enum_label(group[1], tag))
        elif kind == 'compound':
            labeled.extend(make_compound_label(group[1], tag))

    return labeled


def make_enum_test_data(samples):
    """Turn each sample string into single-character rows ending with ``["。"]``.

    Returns:
        One list per sample: ``[char]`` for every character, then ``["。"]``.
    """
    return [[[char] for char in text] + [["。"]] for text in samples]
    
def make_compound_test_label(samples):
    """Turn each sample's text (element 1) into rows ending with ``["。"]``.

    Returns:
        One list per sample: ``[char]`` for every character, then ``["。"]``.
    """
    return [[[char] for char in sample[1]] + [["。"]] for sample in samples]
    
def generate_test_data(samples):
    """Build unlabeled per-character rows for every 'enum'/'compound' group.

    Groups of any other type are ignored.
    """
    rows = []
    for group in samples:
        kind = group[2]
        if kind == 'enum':
            rows.extend(make_enum_test_data(group[1]))
        elif kind == 'compound':
            rows.extend(make_compound_test_label(group[1]))

    return rows

# Label symbols handed out by position: generate_compound_labels() indexes
# this list by compound-member position and label_choice_samples() by group
# position, so at most 8 members/groups can be labelled.
range_tags = ["A", "B","C","D","E","F","G","H"]

def generate_compound_labels(entity):
    """Map each member placeholder of a compound entity to a label symbol.

    The entity is printed first. The n-th member of ``entity['compound']``
    gets ``range_tags[n]`` under the key ``"@{<member name>}"``.
    """
    print(entity)
    members = entity['compound']
    return {"@{" + member["name"] + "}": range_tags[position]
            for position, member in enumerate(members)}

def label_compound_samples(entity, samples):
    """Produce per-character training rows for compound samples.

    Every character starts with label ``'O'``; characters inside a
    placeholder's ``(start, length)`` span get that placeholder's label from
    ``generate_compound_labels``. Each sample is closed by ``['。', 'O']``.

    Args:
        entity: Compound entity definition.
        samples: ``(tags, text)`` pairs.

    Returns:
        One flat list of ``[char, label]`` rows covering all samples.
    """
    tag_labels = generate_compound_labels(entity)

    rows = []
    for sample in samples:
        spans = sample[0]
        cells = [[char, 'O'] for char in sample[1]]

        for tag, span in spans.items():
            begin, size = span
            for offset in range(size):
                cells[begin + offset][1] = tag_labels[tag]

        cells.append(['。', 'O'])
        rows.extend(cells)

    return rows

def label_choice_samples(samples):
    """Produce per-character training rows for choice sample groups.

    The groups are printed first. Every character of a text in the n-th group
    is labelled ``range_tags[n]`` and each text is closed by ``['。', 'O']``.

    Returns:
        One flat list of ``[char, label]`` rows.
    """
    rows = []

    print(samples)
    for position, group in enumerate(samples):
        for text in group[1]:
            rows.extend([[char, range_tags[position]] for char in text])
            rows.append(['。', 'O'])

    return rows

def do_write_csv_data(entity, samples, ty):
    """Write ``samples`` as tab-separated rows to ``<train_path><entity>/<ty>.data``.

    Args:
        entity: Entity name, used as the sub-directory.
        samples: Iterable of rows.
        ty: File base name.
    """
    target = train_path + entity + "/" + ty + '.data'
    with open(target, 'w', encoding='utf-8') as out:
        writer = csv.writer(out, delimiter='\t', lineterminator='\n')
        writer.writerows(samples)
            
def do_write_training_data(entity, samples):
    """Print ``samples`` and store them as the entity's ``train.data`` file."""
    print(samples)
    do_write_csv_data(entity, samples, ty="train")

def do_write_test_data(entity, samples):
    """Store ``samples`` as the entity's ``test.data`` file."""
    do_write_csv_data(entity, samples, ty="test")

        
def write_training_data(js, samples):
    """Label generated samples and write them as the entity's training file.

    Args:
        js: List of all entity definitions (used to look up the compound
            definition for labelling).
        samples: ``(entity_name, data, type)`` as returned by
            ``generate_sample``. Only the types 'compound' and 'choice' are
            written; any other type is ignored.
    """
    entity = samples[0]
    typ = samples[2]
    data = samples[1]

    if typ == 'compound':
        # Compound data is a list of independent samples; mixing them is safe.
        random.shuffle(data)
        do_write_training_data(entity, label_compound_samples(find_entity(js, entity), data))
    elif typ == 'choice':
        # Choice data is a list of groups and label_choice_samples() labels a
        # group by its position. Shuffling the groups (as the original did)
        # would reassign labels to alternatives at random on every run, so the
        # order produced by generate_choice_samples() is kept.
        do_write_training_data(entity, label_choice_samples(data))

def make_samples(js, entity_name, num):
    """Resolve ``entity_name`` within ``js`` and generate ``num`` samples for it."""
    tree = handle_entity(js, entity_name)
    return generate_sample(tree, num)
    
def make_training_file(js, entity_name, num):
    """Generate ``num`` samples for ``entity_name`` and write its training file."""
    generated = make_samples(js, entity_name, num)
    write_training_data(js, generated)

def make_compound_test_samples(samples):
    """Flatten compound samples into unlabeled test rows.

    Each sample contributes the characters of its text (element 1) as bare
    one-character strings, followed by the terminator row ``['。']``.
    """
    rows = []
    for sample in samples:
        rows.extend(sample[1])
        rows.append(['。'])

    return rows

def make_choice_test_samples(samples):
    """Flatten choice sample groups into unlabeled test rows.

    Every text of every group contributes one ``[char]`` row per character,
    followed by the terminator row ``['。']``.
    """
    rows = []
    for group in samples:
        for text in group[1]:
            rows.extend([char] for char in text)
            rows.append(['。'])

    return rows


def write_test_data(samples):
    """Shuffle generated samples in place and write them as the entity's test file.

    Args:
        samples: ``(entity_name, data, type)``; only the types 'compound' and
            'choice' are written, any other type is ignored after the shuffle.
    """
    entity, data, typ = samples[0], samples[1], samples[2]
    random.shuffle(data)

    if typ == 'compound':
        rows = make_compound_test_samples(data)
        do_write_test_data(entity, rows)
    elif typ == 'choice':
        rows = make_choice_test_samples(data)
        do_write_test_data(entity, rows)

def make_test_file(js, entity_name, num):
    """Generate ``num`` samples for ``entity_name`` and write its test file."""
    generated = make_samples(js, entity_name, num)
    write_test_data(generated)

def make_entity_data(js, entity_name, train_n=100, test_n=10):
    """Write both the training and the test file for ``entity_name``.

    Args:
        js: List of all entity definitions.
        entity_name: Entity to generate data for.
        train_n: Sample count requested for the training file.
        test_n: Sample count requested for the test file.
    """
    make_training_file(js, entity_name, num=train_n)
    make_test_file(js, entity_name, num=test_n)
    
    
# Concatenate the loaded entity definition files and the book_ticket intent
# into the single definition list that find_entity() searches by name.
js = load_entity('datetime') + load_entity('date') + load_entity('time') 
js += load_entity('city') + load_entity('ticket')
js += load_intent('book_ticket')
#result
#get_pattern("从@{from}到@{end}sh")

In [6]:
make_entity_data(js, "datetime",1000)

{'entity': 'datetime', 'compound': [{'name': 'date', 'type': 'date', 'mandatory': True}, {'name': 'time', 'type': 'time', 'mandatory': False}], 'patterns': ['@{date}@{time}']}
[['下', 'A'], ['星', 'A'], ['期', 'A'], ['三', 'A'], ['。', 'O'], ['今', 'A'], ['天', 'A'], ['傍', 'B'], ['晚', 'B'], ['。', 'O'], ['二', 'A'], ['号', 'A'], ['。', 'O'], ['今', 'A'], ['天', 'A'], ['凌', 'B'], ['晨', 'B'], ['。', 'O'], ['四', 'A'], ['日', 'A'], ['中', 'B'], ['午', 'B'], ['一', 'B'], ['点', 'B'], ['四', 'B'], ['十', 'B'], ['二', 'B'], ['分', 'B'], ['左', 'B'], ['右', 'B'], ['。', 'O'], ['明', 'A'], ['天', 'A'], ['。', 'O'], ['今', 'A'], ['天', 'A'], ['傍', 'B'], ['晚', 'B'], ['。', 'O'], ['二', 'A'], ['零', 'A'], ['一', 'A'], ['七', 'A'], ['年', 'A'], ['一', 'A'], ['月', 'A'], ['三', 'A'], ['号', 'A'], ['。', 'O'], ['一', 'A'], ['九', 'A'], ['三', 'A'], ['六', 'A'], ['年', 'A'], ['下', 'A'], ['个', 'A'], ['月', 'A'], ['十', 'A'], ['五', 'A'], ['号', 'A'], ['。', 'O'], ['上', 'A'], ['上', 'A'], ['礼', 'A'], ['拜', 'A'], ['二', 'A'], ['。', 'O'], ['四', 'A'], ['日', '

In [12]:
make_entity_data(js, "time", 1000)

TypeError: flatten_generate_choice_samples() missing 1 required positional argument: 'n_samples'

In [8]:
make_entity_data(js, "range-time", 100)

{'entity': 'range-time', 'compound': [{'name': 'from', 'type': 'single-time', 'mandatory': True}, {'name': 'to', 'type': 'single-time', 'mandatory': True}], 'patterns': ['@{from}到@{to}', '从@{from}到@{to}', '@{from}至@{to}之间']}
[['上', 'A'], ['午', 'A'], ['四', 'A'], ['点', 'A'], ['零', 'A'], ['五', 'A'], ['分', 'A'], ['左', 'A'], ['右', 'A'], ['到', 'O'], ['下', 'B'], ['午', 'B'], ['十', 'B'], ['五', 'B'], ['点', 'B'], ['三', 'B'], ['十', 'B'], ['七', 'B'], ['分', 'B'], ['。', 'O'], ['下', 'A'], ['午', 'A'], ['十', 'A'], ['五', 'A'], ['点', 'A'], ['二', 'A'], ['十', 'A'], ['七', 'A'], ['分', 'A'], ['至', 'O'], ['凌', 'B'], ['晨', 'B'], ['四', 'B'], ['点', 'B'], ['左', 'B'], ['右', 'B'], ['之', 'O'], ['间', 'O'], ['。', 'O'], ['从', 'O'], ['傍', 'A'], ['晚', 'A'], ['二', 'A'], ['十', 'A'], ['三', 'A'], ['点', 'A'], ['零', 'A'], ['五', 'A'], ['分', 'A'], ['左', 'A'], ['右', 'A'], ['到', 'O'], ['凌', 'B'], ['晨', 'B'], ['十', 'B'], ['一', 'B'], ['点', 'B'], ['四', 'B'], ['十', 'B'], ['四', 'B'], ['分', 'B'], ['左', 'B'], ['右', 'B'], ['。', 'O'], ['中', 

In [9]:
make_entity_data(js, "general-city", 100)

{'entity': 'general-city', 'compound': [{'name': 'province', 'type': 'province', 'mandatory': False}, {'name': 'city', 'type': 'city', 'mandatory': True}], 'patterns': ['@{province}@{city}', '@{province}@{city}市']}
[['广', 'A'], ['东', 'A'], ['梅', 'B'], ['河', 'B'], ['口', 'B'], ['。', 'O'], ['广', 'A'], ['西', 'A'], ['思', 'B'], ['茅', 'B'], ['。', 'O'], ['广', 'A'], ['西', 'A'], ['湘', 'B'], ['潭', 'B'], ['市', 'O'], ['。', 'O'], ['浙', 'A'], ['江', 'A'], ['本', 'B'], ['溪', 'B'], ['。', 'O'], ['河', 'B'], ['津', 'B'], ['市', 'O'], ['。', 'O'], ['盐', 'B'], ['城', 'B'], ['。', 'O'], ['桂', 'B'], ['林', 'B'], ['。', 'O'], ['宁', 'A'], ['夏', 'A'], ['武', 'B'], ['安', 'B'], ['市', 'O'], ['。', 'O'], ['新', 'A'], ['疆', 'A'], ['思', 'B'], ['茅', 'B'], ['市', 'O'], ['。', 'O'], ['重', 'B'], ['庆', 'B'], ['。', 'O'], ['临', 'B'], ['河', 'B'], ['市', 'O'], ['。', 'O'], ['山', 'A'], ['东', 'A'], ['建', 'B'], ['阳', 'B'], ['市', 'O'], ['。', 'O'], ['内', 'A'], ['蒙', 'A'], ['儋', 'B'], ['州', 'B'], ['市', 'O'], ['。', 'O'], ['广', 'A'], ['东', 'A'], ['锦'

In [45]:
make_entity_data(js, "book_ticket", 100)

{'entity': 'book_ticket', 'compound': [{'name': 'datetime', 'type': 'datetime', 'mandatory': True}, {'name': 'from', 'type': 'general-city', 'mandatory': False}, {'name': 'to', 'type': 'general-city', 'mandatory': True}, {'name': 'ticket', 'type': 'ticket', 'mandatory': True}], 'patterns': ['我想订@{datetime}@{from}到@{to}的@{ticket}', '帮我看看@{datetime}从@{from}到@{to}的@{ticket}', '查一下@{datetime}从@{from}到@{to}的@{ticket}', '查下@{datetime}到@{to}的@{ticket}', '问下@{datetime}去@{to}的@{ticket}', '查下@{datetime}从@{from}去@{to}的@{ticket}', '帮我看看到@{to}的@{ticket}的情况', '帮我看看到@{to}的@{ticket}还有吗', '帮我看看有没有@{datetime}去@{to}的@{ticket}', '@{datetime}@{from}到@{to}的@{ticket}', '@{datetime}去@{to}的@{ticket}']}


In [10]:
make_samples(js, 'single-time', 10)

('single-time',
 ['早晨十三点四十二分左右',
  '下午二十一点零二分左右',
  '凌晨十一点三十七分左右',
  '早晨一点五十四分左右',
  '上午十五点三十八分',
  '中午三点四十三分左右',
  '晚上三点五十七分左右',
  '上午十一点五十二分',
  '傍晚十八点三十四分左右',
  '十八点十三分'],
 'enum')