In [182]:
import json
from random import choice
import copy
import random

# Directory prefix for the entity-definition JSON files (read by load_json).
path = "../entity/"
# Directory prefix for the tab-separated value lists (read by read_csv).
csv_path = "../csv/"
# Directory prefix under which the generated train/test files are written.
train_path = "../train/"

import csv

def read_csv(file, column):
    """Return one column of a tab-separated file stored under ``csv_path``.

    Args:
        file: File name, appended to the module-level ``csv_path`` prefix.
        column: 1-based index of the field to take from every row.

    Returns:
        A list holding the selected field of each row, in file order.

    Raises:
        IndexError: If a row has fewer than ``column`` fields.
    """
    # The context manager closes the handle once reading is done (it used to
    # be left open); newline='' lets the csv module do its own line-ending
    # handling, as its documentation requires.
    with open(csv_path + file, 'r', encoding='utf-8', newline='') as handle:
        reader = csv.reader(handle, delimiter="\t")
        return [row[column - 1] for row in reader]

def load_json(file):
    """Decode and return the JSON document ``<path><file>.json``.

    ``file`` is the bare name without extension; it is appended to the
    module-level ``path`` prefix and ``.json`` is added.
    """
    json_location = path + file + ".json"
    with open(json_location, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Cache of enum value lists keyed by entity name; filled lazily by
# handle_enum_type and read back by generate_enum_samples.
samples_dict = {}


def find_entity(entities, obj_name):
    """Return the first item of ``entities`` whose ``'entity'`` is ``obj_name``.

    Returns ``None`` when no item matches.
    """
    matching = (
        candidate for candidate in entities
        if candidate['entity'] == obj_name
    )
    return next(matching, None)

def handle_enum_type(entity):
    """Make sure an enum entity's values are cached and report their count.

    The values are read with ``read_csv`` from ``entity['enum']['source']`` /
    ``entity['enum']['column']`` the first time the entity name is seen and
    stored in the module-level ``samples_dict``.

    Returns:
        ``(number_of_values, entity)``.
    """
    spec = entity['enum']
    key = entity['entity']
    if key in samples_dict:
        values = samples_dict[key]
    else:
        values = read_csv(spec['source'], spec['column'])
        samples_dict[key] = values

    return len(values), entity
        
def handle_compound_type(js, entity):
    """Analyse every part of a compound entity.

    Args:
        js: The list of entity definitions to resolve part types against.
        entity: A definition with a ``'compound'`` list; each item carries a
            ``'name'`` (the slot name) and a ``'type'`` (an entity name).

    Returns:
        ``(count, entity, children)`` where ``count`` is the largest count
        among the parts and ``children`` holds one tuple per part:
        ``(name, part_entity)`` for enum parts and
        ``(name, part_entity, part_children)`` for compound/choice parts.
    """
    children = []
    total_number = 0
    for part in entity['compound']:
        subtree = handle_entity(js, part['type'])
        total_number = max(subtree[0], total_number)
        # handle_entity yields 2-tuples for enums but 3-tuples for nested
        # compound/choice nodes. Unpacking into two names failed on the
        # latter, so keep everything after the count; the grandchildren are
        # what generate_compound_samples reads back as child[2].
        children.append((part['name'],) + tuple(subtree[1:]))

    return total_number, entity, children
        
def handle_choice_type(js, entity):
    """Analyse every alternative of a choice entity.

    Each name listed in ``entity['choice']`` is resolved with
    ``handle_entity``.

    Returns:
        ``(count, entity, children)`` where ``count`` is the sum of the
        alternatives' counts and ``children`` is the list of their analysis
        results, in definition order.
    """
    subtrees = [handle_entity(js, option) for option in entity['choice']]
    combined = sum(subtree[0] for subtree in subtrees)
    return combined, entity, subtrees

def handle_entity(js, c):
    """Resolve entity name ``c`` in ``js`` and analyse it by its kind.

    Args:
        js: The list of entity definitions.
        c: Name of the entity to analyse.

    Returns:
        The result of the matching handler: ``(count, entity)`` for enums,
        ``(count, entity, children)`` for compound and choice entities.
        A definition with none of the known kinds is reported on stdout and
        yields ``(0, 0)``.

    Raises:
        KeyError: If no definition named ``c`` exists in ``js``.
    """
    entity = find_entity(js, c)
    if entity is None:
        # Without this guard the membership tests below fail with an
        # unhelpful "argument of type 'NoneType' is not iterable".
        raise KeyError("entity is not defined: {!r}".format(c))

    if 'enum' in entity:
        return handle_enum_type(entity)
    elif 'compound' in entity:
        return handle_compound_type(js, entity)
    elif 'choice' in entity:
        return handle_choice_type(js, entity)
    else:
        print("ERROR", entity)
        return 0,0



def get_generated_number_of_samples(total, n, allocated):
    """Return ``n``'s proportional share of ``allocated``, but at least 10.

    The share is ``(n * allocated) // total``.
    """
    proportional = (n * allocated) // total
    return proportional if proportional > 10 else 10

def generate_enum_samples(entity, n_samples):
    """Draw ``n_samples`` random values (with replacement) for an enum entity.

    Values come from the ``samples_dict`` entry for ``entity['entity']``.

    Returns:
        ``(entity_name, drawn_values, 'enum')``.
    """
    label = entity['entity']
    drawn = []
    for _ in range(n_samples):
        drawn.append(choice(samples_dict[label]))
    return label, drawn, 'enum'

def split_pattern(pattern):
    """Split ``pattern`` into literal pieces and ``@{tag}`` placeholders.

    Example: ``"a@{x}b"`` becomes ``["a", "@{x}", "b"]``.

    Args:
        pattern: Template text that may contain ``@{name}`` placeholders.

    Returns:
        The pieces in order; joining them reproduces ``pattern``. A pattern
        without placeholders (including the empty string) comes back as a
        single-element list.
    """
    if not pattern:
        return [pattern]

    pieces = []
    cursor = 0
    while cursor < len(pattern):
        start = pattern.find('@{', cursor)
        # Look for the closing brace only after the opening marker, so a
        # stray '}' earlier in the text cannot be mistaken for it.
        end = pattern.find('}', start) if start >= 0 else -1
        if end < 0:
            # No further complete placeholder (none at all, or an
            # unterminated '@{'): the remainder is literal text.
            pieces.append(pattern[cursor:])
            break

        if start > cursor:
            pieces.append(pattern[cursor:start])
        pieces.append(pattern[start:end + 1])
        cursor = end + 1

    return pieces

def get_pattern(pattern):
    """Split a pattern and index its placeholders.

    Returns:
        ``(tags, pieces)`` where ``pieces`` is the ``split_pattern`` result
        and ``tags`` maps each piece starting with ``@{`` to its index in
        ``pieces`` (for a repeated placeholder the last index is kept).
    """
    segments = split_pattern(pattern)
    tag_index = {
        segment: position
        for position, segment in enumerate(segments)
        if segment.startswith('@{')
    }
    return tag_index, segments


def get_pos(sample, index):
    """Locate piece ``index`` inside the concatenation of ``sample``.

    Returns:
        ``(offset, length)``: the summed length of all pieces before
        ``index`` and the length of the piece itself.
    """
    offset = sum(len(sample[i]) for i in range(index))
    return offset, len(sample[index])
        
def generate_compound_samples(entity, children, n_samples):
    """Fill a compound entity's patterns with generated slot values.

    Args:
        entity: Compound definition providing ``'entity'`` (its name) and
            ``'patterns'`` (templates with ``@{slot}`` placeholders).
        children: One tuple per slot, ``(slot_name, slot_entity[, slot_children])``.
        n_samples: Number of values drawn per slot and upper bound on the
            number of samples returned.

    Returns:
        ``(entity_name, samples, 'compound')`` where each sample is
        ``(spans, text)``: ``text`` is the filled pattern and ``spans`` maps
        every placeholder of the pattern to its ``(offset, length)`` in
        ``text``. Every pattern is filled ``n_samples`` times; the combined
        list is shuffled and cut to ``n_samples`` entries.

    Raises:
        KeyError: If a slot has no matching ``@{slot}`` placeholder in a pattern.
    """
    slot_samples = []
    for child in children:
        slot_name, slot_entity = child[0], child[1]
        if 'enum' in slot_entity:
            _, values, kind = generate_enum_samples(slot_entity, n_samples)
        elif 'compound' in slot_entity:
            # NOTE(review): nested compound values are (spans, text) tuples,
            # which the "".join below cannot concatenate - confirm whether
            # nested compounds are meant to be supported.
            _, values, kind = generate_compound_samples(slot_entity, child[2], n_samples)
        elif 'choice' in slot_entity:
            # NOTE(review): generate_choice_samples returns a list of
            # per-alternative results, so this three-way unpack only succeeds
            # when that list happens to have three items - confirm intent.
            _, values, kind = generate_choice_samples(child[2], n_samples, n_samples)
        else:
            # Slots of an unknown kind contribute nothing.
            continue
        slot_samples.append((slot_name, values, kind))

    all_samples = []
    for raw_pattern in entity['patterns']:
        tag_positions, segments = get_pattern(raw_pattern)

        for i in range(n_samples):
            # The pieces are strings and the tag map holds ints, so shallow
            # copies are enough to keep the parsed pattern untouched.
            filled = list(segments)
            spans = dict(tag_positions)
            for slot_name, values, _ in slot_samples:
                filled[spans['@{' + slot_name + '}']] = values[i]

            # Turn piece indices into character spans of the filled text.
            for tag in spans:
                spans[tag] = get_pos(filled, spans[tag])

            all_samples.append((spans, "".join(filled)))

    random.shuffle(all_samples)
    all_samples = all_samples[:n_samples]

    return entity['entity'], all_samples, 'compound'
    
        
def generate_choice_samples(children, n_samples, allocated):
    """Generate samples for every alternative of a choice entity.

    Args:
        children: Analysis results of the alternatives, each
            ``(count, entity[, entity_children])``.
        n_samples: Not used by the computation; kept for the existing callers.
        allocated: Number of samples to spread evenly over the alternatives
            (each one gets ``allocated // len(children)``).

    Returns:
        A flat list with one ``(name, samples, kind)`` result per enum or
        compound alternative, visited in ascending order of ``count``.
        Alternatives of a nested choice are spliced into the same list.
        An alternative of unknown kind contributes ``None``.
    """
    all_samples = []
    if not children:
        return all_samples

    share = allocated // len(children)
    for child in sorted(children, key=lambda node: node[0]):
        entity = child[1]
        if 'enum' in entity:
            all_samples.append(generate_enum_samples(entity, share))
        elif 'compound' in entity:
            all_samples.append(generate_compound_samples(entity, child[2], share))
        elif 'choice' in entity:
            # Recurse on the alternative's own children (child[2]); the
            # entity dict itself is not a child list. The nested results are
            # spliced in so every item stays a (name, samples, kind) tuple,
            # and the nested choice only spends its own share.
            all_samples.extend(generate_choice_samples(child[2], share, share))
        else:
            all_samples.append(None)

    return all_samples
    

def generate_sample(tree, generated_samples):
    """Generate samples for an analysed entity tree rooted at a choice.

    ``tree`` is ``(count, entity, children)``. When the root entity is a
    choice, ``generated_samples`` is passed to ``generate_choice_samples`` as
    the amount to allocate; any other root yields an empty list.
    """
    expected_total, root = tree[0], tree[1]
    if 'choice' not in root:
        return []

    return generate_choice_samples(tree[2], expected_total, generated_samples)

def make_enum_label(samples, label):
    """Pair every character of every sample string with ``label``.

    Returns one list per sample, each holding ``[char, label]`` rows.
    """
    return [[[char, label] for char in sample] for sample in samples]

def make_compound_label(samples, label):
    """Pair every character of each sample's text with ``label``.

    Each sample is indexed at ``[1]`` to get its text. Returns one list per
    sample, each holding ``[char, label]`` rows.
    """
    return [[[char, label] for char in sample[1]] for sample in samples]

def generated_labled_sample(samples):
    """Label every generated sample group with a class tag.

    Group ``i`` of ``samples`` (a ``(name, samples, kind)`` tuple) gets the
    label ``"C<i>"``. Only groups of kind ``'enum'`` or ``'compound'``
    contribute rows; the per-sample row lists are concatenated in order.
    """
    labelled = []
    for position, group in enumerate(samples):
        tag = "C" + str(position)
        kind = group[2]
        if kind == 'enum':
            labelled.extend(make_enum_label(group[1], tag))
        elif kind == 'compound':
            labelled.extend(make_compound_label(group[1], tag))

    return labelled


def make_enum_test_data(samples):
    """Turn sample strings into unlabelled one-character rows.

    Each sample becomes a list of ``[char]`` rows followed by a closing
    ``["。"]`` row.
    """
    return [[[char] for char in sample] + [["。"]] for sample in samples]
    
def make_compound_test_label(samples):
    """Turn compound samples into unlabelled one-character rows.

    The text of each sample is taken from index ``[1]``; it becomes a list of
    ``[char]`` rows followed by a closing ``["。"]`` row.
    """
    return [[[char] for char in sample[1]] + [["。"]] for sample in samples]
    
def generate_test_data(samples):
    """Convert generated sample groups into unlabelled test rows.

    Only groups (``(name, samples, kind)`` tuples) of kind ``'enum'`` or
    ``'compound'`` contribute; their per-sample row lists are concatenated in
    order.
    """
    rows = []
    for group in samples:
        kind = group[2]
        if kind == 'enum':
            rows.extend(make_enum_test_data(group[1]))
        elif kind == 'compound':
            rows.extend(make_compound_test_label(group[1]))

    return rows

#result
#get_pattern("从@{from}到@{end}sh")

In [185]:
#total_samples = result[0]
# Build the labelled training set for the "time" entity.
key = "time"    
js = load_json(key)

# Analyse the entity tree once; the result is reused when test data is built.
time_anylize = handle_entity(js, key)

labelled = generated_labled_sample(generate_sample(time_anylize, 100))
random.shuffle(labelled)

# Flatten the samples into [char, label] rows, closing each sample with a
# '。' row labelled 'O'.
time_data = []
for i in labelled:
    time_data += i
    time_data.append(['。' , 'O'])


# newline='' stops text mode from translating the row terminator that
# csv.writer emits (otherwise rows end in "\r\r\n" on Windows).
with open(train_path + 'time/' + 'train.data', 'w', encoding='utf-8', newline='') as csvfile:
    my_writer = csv.writer(csvfile, delimiter='\t')
    for i in time_data:
        my_writer.writerow(i)

In [186]:
# Build a small unlabelled test set from the same analysed entity tree.
test_sample = generate_test_data(generate_sample(time_anylize, 10))
random.shuffle(test_sample)

# Flatten the per-sample row lists into one sequence of single-column rows.
time_test_data = []
for i in test_sample:
    time_test_data += i

# newline='' stops text mode from translating the row terminator that
# csv.writer emits (otherwise rows end in "\r\r\n" on Windows).
with open(train_path + 'time/' + 'test.data', 'w', encoding='utf-8', newline='') as csvfile:
    my_writer = csv.writer(csvfile, delimiter='\t')
    for i in time_test_data:
        my_writer.writerow(i)
