In [1]:
from das.distributed_atom_space import DistributedAtomSpace, QueryOutputFormat
from das.database.db_interface import UNORDERED_LINK_TYPES
from das.pattern_matcher.pattern_matcher import PatternMatchingAnswer, OrderedAssignment, UnorderedAssignment, CompositeAssignment, Node, Link, Variable, Not, And, Or, TypedVariable, LinkTemplate
from das.database.db_interface import WILDCARD
from das.expression_hasher import ExpressionHasher
import warnings
import numpy as np
import time
import random
from itertools import combinations
warnings.filterwarnings('ignore')
TARGET_NODES = None
das = DistributedAtomSpace()
db = das.db
das.count_atoms()

Log initialized. Log file: /tmp/das.log


(2584508, 27871440)

In [2]:
def get_gene_node(name):
    verbatim_node = das.get_node("Verbatim", name)
    schema_node = das.get_node("Schema", "Schema:sql_gene_name")
    print(f"verbatim_node = {verbatim_node}")
    print(f"schema_node = {schema_node}")
    v1 = Variable("v1")
    links = das.get_links("Execution", None, [schema_node, WILDCARD, verbatim_node])
    print(f"links = {links}")
    link = das.get_atom(links[0], output_format=QueryOutputFormat.ATOM_INFO)
    print(f"link = {link}")
    gene_node_handle = link["targets"][1]
    print(f"gene_node_handle = {gene_node_handle}")
    gene_node = das.get_atom(gene_node_handle, output_format=QueryOutputFormat.ATOM_INFO)
    print(f"gene_node = {gene_node}")
    return Node("gene", gene_node["name"])

print(das.get_node("gene", "3106709"))

80ef77c79ab33f7a7e5d3070a09ded02


In [3]:
GENE_LIST = [
    "mud"
]

In [4]:
USE_SUBSTRING = False

if USE_SUBSTRING:
    TARGET_TYPE = "Concept"
    TARGET_SUBSTRING = "gl"
else:
    TARGET_NODES = [
        get_gene_node(gene) for gene in GENE_LIST
    ]

NGRAM = 3
SUPPORT = 0
HALO_LENGTH = 2
DEPTH_WEIGTH = [1, 1]
ISURPRISINGNESS_REPORT_THRESHOLD = 0
EPOCHS = 1000
NORMALIZED_ISURPRISINGNESS = False
LINK_RATE = 0.01

verbatim_node = 7ce10bc07c9bee3d5c4b0075b56eee2f
schema_node = 7494788453289a15a95f81a916c9cc21
links = ['61841d9b332a12dc9e094c924052636e']
link = {'handle': '61841d9b332a12dc9e094c924052636e', 'type': 'Execution', 'template': ['Execution', 'Schema', 'gene', 'Verbatim'], 'targets': ['7494788453289a15a95f81a916c9cc21', '728c20ff5bbd6dac888840801842d303', '7ce10bc07c9bee3d5c4b0075b56eee2f']}
gene_node_handle = 728c20ff5bbd6dac888840801842d303
gene_node = {'handle': '728c20ff5bbd6dac888840801842d303', 'type': 'gene', 'name': '8347745'}


In [5]:
assert len(DEPTH_WEIGTH) == HALO_LENGTH
halo_levels = [i for i in range(HALO_LENGTH)]
if TARGET_NODES is None:
    atomspace_nodes = db.get_matched_node_name(TARGET_TYPE, TARGET_SUBSTRING)
    print(atomspace_nodes)
    TARGET_NODES = [Node(TARGET_TYPE, db.get_node_name(h)) for h in atomspace_nodes]
print(f"TARGET_NODES = {TARGET_NODES}")

TARGET_NODES = [<gene: 8347745>]


In [6]:
def print_ordered_assignment(assignment):
    if assignment is not None:
        for key, value in assignment.mapping.items():
            print(f"{key}: {db.get_node_name(value)}")

def print_unordered_assignment(assignment):
    if assignment is not None:
        symbols = []
        for key in assignment.symbols:
            for i in range(assignment.symbols[key]):
                symbols.append(key)
        values = []
        for key in assignment.values:
            for i in range(assignment.values[key]):
                values.append(key)
        mapping_keys = []
        mapping_values = []
        for symbol, value in zip(symbols, values):
            mapping_keys.append(symbol)
            mapping_values.append(db.get_node_name(value))
        print(f"{mapping_keys} = {mapping_values}")
        
def build_pattern_from_template(template):
    targets = []
    count_variables = 1
    for target in template[1:]:
        if target == WILDCARD:
            targets.append(Variable(f"V{count_variables}"))
            count_variables += 1
        else:
            #node_document = das.get_atom(target, output_format=QueryOutputFormat.ATOM_INFO)
            try:
                node_type = das.get_node_type(target)
                node_name = das.get_node_name(target)
                targets.append(Node(node_type, node_name))
            except:
                return None
    return Link(template[0], ordered=(template[0] not in UNORDERED_LINK_TYPES), targets=targets)

def _random_selection(v):
    return v[np.random.randint(len(v))]

def random_selection(v, n=1):
    if n == 1:
        return _random_selection(v)
    assert n <= (len(v) / 2)
    a = v.copy()
    selected = []
    for i in range(n):
        s = _random_selection(a)
        a.remove(s)
        selected.append(s)
    return selected

def build_roulette(w):
    answer = []
    s = sum(w)
    acc = 0
    for v in w:
        acc += v / s
        answer.append(acc)
    answer[-1] = 1
    return answer

def roulette_selection(v, w):
    assert len(v) == len(w)
    random = np.random.random()
    for i in range(len(v)):
        if random <= w[i]:
            return v[i]
    
def compute_count(logical_expression):
    query_answer = PatternMatchingAnswer()
    matched = logical_expression.matched(db, query_answer)
    return len(query_answer.assignments) if matched else 0
        
def prob(count):
    return count / universe_size

def compute_isurprisingness(count, terms, term_handles, counts, normalized = False):
    n = len(term_handles)
    if n == 2:
        subset_probs = [prob(counts[0]) * prob(counts[1])]
    elif n == 3:
        subset_probs = [
            prob(counts[0]) * prob(counts[1]) * prob(counts[2]),
            prob(compute_count(And([terms[0], terms[1]]))) * prob(counts[2]), 
            prob(compute_count(And([terms[0], terms[2]]))) * prob(counts[1]),
            prob(compute_count(And([terms[1], terms[2]]))) * prob(counts[0])
        ]
    elif n == 4:
        subset_probs = [
            prob(counts[0]) * prob(counts[1]) * prob(counts[2]) * prob(counts[3]),
            prob(compute_count(And([terms[0], terms[1]]))) * prob(compute_count(And([terms[2], terms[3]]))),
            prob(compute_count(And([terms[0], terms[2]]))) * prob(compute_count(And([terms[1], terms[3]]))),
            prob(compute_count(And([terms[0], terms[3]]))) * prob(compute_count(And([terms[1], terms[2]]))),
            prob(compute_count(And([terms[0], terms[1], terms[2]]))) * prob(counts[3]),
            prob(compute_count(And([terms[0], terms[1], terms[3]]))) * prob(counts[2]),
            prob(compute_count(And([terms[0], terms[2], terms[3]]))) * prob(counts[1]),
            prob(compute_count(And([terms[1], terms[2], terms[3]]))) * prob(counts[0])
        ]
    else:
        raise NotImplementedError()
    p = prob(count)
    isurprisingness = max([p - max(subset_probs), min(subset_probs) - p])
    if normalized:
        return isurprisingness / p
    else:
        return isurprisingness
    
def build_patterns(links):
    chunk_size = 1000
    pattern = {}
    pattern_count = {}
    link_count = 0
    for link in links:
        link_count += 1
        if link_count % chunk_size == 0 or link_count == 1 or link_count == len(links):
            if link_count != 1 and link_count != len(links):
                end = time.perf_counter()
                wall_time = f"{(end - start):.0f} seconds"
                time_per_query = f"{(((end - start) * 1000) / (8 * chunk_size)):.0f} ms/query"
                print(f"link {link_count}/{len(links)} {wall_time} {time_per_query}")
            else:
                print(f"link {link_count}/{len(links)}")
            start = time.perf_counter()
#         link_document = das.get_atom(link, output_format=QueryOutputFormat.ATOM_INFO)
#         targets = link_document['targets']
#         link_type = link_document['type']
        targets = das.get_link_targets(link)
        link_type = das.get_link_type(link)
        arity = len(targets)
        if arity == 2:
            templates = [
                [link_type, WILDCARD, targets[1]],
                [link_type, targets[0], WILDCARD],
                #[link_type, WILDCARD, WILDCARD],
            ]
        elif arity == 3:
            templates = [
                [link_type, WILDCARD, targets[1], targets[2]],
                [link_type, targets[0], WILDCARD, targets[2]],
                [link_type, targets[0], targets[1], WILDCARD],
                [link_type, WILDCARD, WILDCARD, targets[2]],
                [link_type, WILDCARD, targets[1], WILDCARD],
                [link_type, targets[0], WILDCARD, WILDCARD],
                #[link_type, WILDCARD, WILDCARD, WILDCARD],
            ]
        else:
            raise NotImplementedError()
        for template in templates:
            p = build_pattern_from_template(template)
            if p is not None:
                template_handle = ExpressionHasher.composite_hash(template)
                pattern[template_handle] = p
                pattern_count[template_handle] = len(das.get_links(template[0], None, template[1:]))
    return tuple([pattern, pattern_count])
        
def build_composite_pattern(terms):
    assert len(terms) > 1
    for i in range(len(terms)):
        if i == 0:
            first_term = terms[i]
        else:
            second_term = terms[i]
            composite_pattern = And([first_term, second_term])
            first_term = composite_pattern
    return composite_pattern
    
def print_query(pattern):
    print(pattern)
    query_answer = PatternMatchingAnswer()
    pattern.matched(db, query_answer)
    for assignment in query_answer.assignments:
        if type(assignment) is OrderedAssignment:
            print_ordered_assignment(assignment)
        elif type(assignment) is UnorderedAssignment:
            print_unordered_assignment(assignment)
        elif type(assignment) is CompositeAssignment:
            print_ordered_assignment(assignment.ordered_mapping)
            for unordered_assignment in assignment.unordered_mappings:
                print_unordered_assignment(unordered_assignment)
        print("")
        
halo_level_roulette = build_roulette(DEPTH_WEIGTH)

In [7]:
node_handle_list = set([ExpressionHasher.terminal_hash(n.atom_type, n.name) for n in TARGET_NODES])
#print(f"node_handle_list = {node_handle_list}")
links = [set() for i in range(HALO_LENGTH)]
for level in range(HALO_LENGTH):
    new_level_node_handles = set()
    node_handle_count = 0
    for node_handle in node_handle_list:
        node_handle_count += 1
        #print(f"===========================================")
        template_list = [
            [node_handle, WILDCARD], 
            [WILDCARD, node_handle], 
            [node_handle, WILDCARD, WILDCARD], 
            [WILDCARD, node_handle, WILDCARD], 
            [WILDCARD, WILDCARD, node_handle]
        ]
        start = time.perf_counter()
        num_queries = 0
        for template in template_list:
            #print(f"template = {template}")
            link_list = set(das.get_links(None, None, template))
            num_queries += len(link_list) + 1
            #print(f"len(link_list) = {len(link_list)}")
            for link in link_list:
                #link_document = das.get_atom(link, output_format=QueryOutputFormat.ATOM_INFO)
                for h in das.get_link_targets(link):
                    new_level_node_handles.add(h)
            links[level].update(link_list)
        end = time.perf_counter()
        wall_time = f"{(end - start):.0f} seconds"
        time_per_query = f"{(((end - start) * 1000) / num_queries):.3f} ms/query"        
        print(f"Halo level {level+1}/{HALO_LENGTH} node_handle {node_handle_count}/{len(node_handle_list)} {wall_time} {time_per_query}")
    node_handle_list.update(new_level_node_handles)
for level in range(HALO_LENGTH):
    if level == 0:
        all_links = set([link for link in links[level]])
    else:
        links[level] = links[level].difference(all_links)
        all_links.update(links[level])
universe_size = len(all_links)
print(f"===========================================")
print(f"Done - universe_size = {universe_size}")
print(f"===========================================")

Halo level 1/2 node_handle 1/1 0 seconds 2.289 ms/query
Halo level 2/2 node_handle 1/16 25 seconds 0.100 ms/query
Halo level 2/2 node_handle 2/16 25 seconds 0.102 ms/query
Halo level 2/2 node_handle 3/16 25 seconds 0.101 ms/query
Halo level 2/2 node_handle 4/16 24 seconds 0.099 ms/query
Halo level 2/2 node_handle 5/16 0 seconds 0.099 ms/query
Halo level 2/2 node_handle 6/16 0 seconds 0.097 ms/query
Halo level 2/2 node_handle 7/16 0 seconds 0.101 ms/query
Halo level 2/2 node_handle 8/16 26 seconds 0.106 ms/query
Halo level 2/2 node_handle 9/16 0 seconds 0.131 ms/query
Halo level 2/2 node_handle 10/16 0 seconds 0.104 ms/query
Halo level 2/2 node_handle 11/16 71 seconds 0.103 ms/query
Halo level 2/2 node_handle 12/16 8 seconds 0.114 ms/query
Halo level 2/2 node_handle 13/16 28 seconds 0.112 ms/query
Halo level 2/2 node_handle 14/16 3 seconds 0.104 ms/query
Halo level 2/2 node_handle 15/16 21 seconds 0.104 ms/query
Halo level 2/2 node_handle 16/16 27 seconds 0.110 ms/query
Done - universe_

In [8]:
#print(node_handle_list)
#print(links)

In [9]:
total = 0
for level in range(HALO_LENGTH):
    total += len(links[level])
    print(len(links[level]))
print("----------")
print(total)
#links

8
1941490
----------
1941498


In [10]:
pattern = [None for i in range(HALO_LENGTH)]
pattern_count = [None for i in range(HALO_LENGTH)]
pattern_handles = [None for i in range(HALO_LENGTH)]
all_patterns = {}
all_patterns_count = {}
for level in range(HALO_LENGTH):
    print(f"###########################################")
    print(f"Building patterns for level {level}")
    striped_links = [link for link in links[level] if level == 0 or random.random() < LINK_RATE]
    pattern[level], pattern_count[level] = build_patterns(striped_links)
    pattern_handles[level] = [key for key in pattern[level].keys()]
    for key, value in pattern[level].items():
        all_patterns[key] = value
    for key, value in pattern_count[level].items():
        all_patterns_count[key] = value
print(f"===========================================")
print(f"Done - len(all_patterns) = {len(all_patterns)}")
print(f"===========================================")

###########################################
Building patterns for level 0
link 1/8
link 8/8
###########################################
Building patterns for level 1
link 1/19172
link 1000/19172 830 seconds 104 ms/query
link 2000/19172 779 seconds 97 ms/query
link 3000/19172 788 seconds 98 ms/query
link 4000/19172 716 seconds 89 ms/query
link 5000/19172 785 seconds 98 ms/query
link 6000/19172 678 seconds 85 ms/query
link 7000/19172 719 seconds 90 ms/query
link 8000/19172 651 seconds 81 ms/query
link 9000/19172 713 seconds 89 ms/query
link 10000/19172 675 seconds 84 ms/query
link 11000/19172 654 seconds 82 ms/query
link 12000/19172 597 seconds 75 ms/query
link 13000/19172 637 seconds 80 ms/query
link 14000/19172 631 seconds 79 ms/query
link 15000/19172 635 seconds 79 ms/query
link 16000/19172 615 seconds 77 ms/query
link 17000/19172 596 seconds 74 ms/query
link 18000/19172 604 seconds 76 ms/query
link 19000/19172 681 seconds 85 ms/query
link 19172/19172
Done - len(all_patterns) = 71826


In [11]:
total = 0
for level in range(HALO_LENGTH):
    total += len(pattern_handles[level])
    print(len(pattern_handles[level]))
print("----------")
print(total)
#pattern_handles

40
71803
----------
71843


In [None]:
higher_isurprisingness = 0
best_pattern = None
for i in range(EPOCHS):
    if True or i % 1000 == 0 or i == EPOCHS - 1:
        print(f"Epoch {i + 1}/{EPOCHS}")
    selected_handle = random_selection(pattern_handles[0])
    term_handles = [tuple([selected_handle, 0])]
    terms = [pattern[0][selected_handle]]
    counts = [pattern_count[0][selected_handle]]
    for i in range(NGRAM - 1):
        while True:
            selected_level = roulette_selection(halo_levels, halo_level_roulette)
            selected_handle = random_selection(pattern_handles[selected_level])
            if tuple([selected_handle, selected_level]) not in term_handles:
                break
        term_handles.append(tuple([selected_handle, selected_level]))
        terms.append(pattern[selected_level][selected_handle])
        counts.append(pattern_count[selected_level][selected_handle])
    composite_pattern = build_composite_pattern(terms)
    count = compute_count(composite_pattern)
    if count > 0:
        print(f"Count: {count}")
    if count >= SUPPORT:
        isurprisingness = compute_isurprisingness(count, terms, term_handles, counts, normalized=NORMALIZED_ISURPRISINGNESS) 
        if isurprisingness > higher_isurprisingness:
            print(f"{count} {isurprisingness}: {terms} {term_handles} {counts}")
            higher_isurprisingness = isurprisingness
            best_pattern = composite_pattern
print_query(best_pattern)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

In [None]:
higher_isurprisingness = 0
best_pattern = None
all_patterns_handles = all_patterns.keys()

count_bh = 0
for basic_handle in pattern_handles[0]:
    count_bh += 1
    print(f"Cycle {count_bh}/{len(pattern_handles[0])}")
    for combination_handles in combinations(all_patterns, NGRAM - 1):
        if basic_handle in combination_handles:
            continue
        term_handles = [basic_handle, *combination_handles]
        terms = [all_patterns[handle] for handle in term_handles]
        counts = [all_patterns_count[handle] for handle in term_handles]
        composite_pattern = build_composite_pattern(terms)
        count = compute_count(composite_pattern)
        if count >= SUPPORT:
            isurprisingness = compute_isurprisingness(count, terms, term_handles, counts, normalized=NORMALIZED_ISURPRISINGNESS) 
            if isurprisingness > higher_isurprisingness:
                print(f"{count} {isurprisingness}: {terms} {counts}")
                higher_isurprisingness = isurprisingness
                best_pattern = composite_pattern    
print_query(best_pattern)