In [207]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from jax import random
import sys
from time import sleep
import json
import pexpect
import re
import os
import argparse
import logging
import timeit
import torch

import seq2seq
from batch_predictor import BatchPredictor
from checkpoint import Checkpoint

from policy_networks import *
import policy_networks

import utp_model

from new_env import *

from jax.config import config
config.update("jax_debug_nans", True) 
import numpy as np
#jax.config.update("jax_enable_x64", False)

In [208]:
import pickle

path_dir = "/home/sean/Documents/PhD/tactic_zero_jax/env/model_params"

def save(params, path):
    with open(path, 'wb') as fp:
        pickle.dump(params, fp)

def load(path):
    with open(path, 'rb') as fp:
        return pickle.load(fp)

In [3]:
#HOLPATH = "/home/sean/Documents/hol/HOL/bin/hol --maxheap=256"
HOLPATH = "/home/sean/Documents/PhD/HOL4/HOL/bin/hol --maxheap=256"

#tactic_zero_path = "/home/sean/Documents/PhD/git/repo/PhD/tacticzero/holgym/"

In [4]:
with open("typed_database.json") as f:
    database = json.load(f)

In [5]:
MORE_TACTICS = True
if not MORE_TACTICS:
    thms_tactic = ["simp", "fs", "metis_tac"]
    thm_tactic = ["irule"]
    term_tactic = ["Induct_on"]
    no_arg_tactic = ["strip_tac"]
else:
    thms_tactic = ["simp", "fs", "metis_tac", "rw"]
    thm_tactic = ["irule", "drule"]
    term_tactic = ["Induct_on"]
    no_arg_tactic = ["strip_tac", "EQ_TAC"]
    
tactic_pool = thms_tactic + thm_tactic + term_tactic + no_arg_tactic

In [6]:
#TODO Move to another file 

def get_polish(raw_goal):
        goal = construct_goal(raw_goal)
        process.sendline(goal.encode("utf-8"))
        process.expect("\r\n>")
        process.sendline("val _ = set_term_printer (HOLPP.add_string o pt);".encode("utf-8"))
        process.expect("\r\n>")
        process.sendline("top_goals();".encode("utf-8"))
        process.expect("val it =")
        process.expect([": goal list", ":\r\n +goal list"])

        polished_raw = process.before.decode("utf-8")
        polished_subgoals = re.sub("“|”","\"", polished_raw)
        polished_subgoals = re.sub("\r\n +"," ", polished_subgoals)

        pd = eval(polished_subgoals)
        
        process.expect("\r\n>")
        process.sendline("drop();".encode("utf-8"))
        process.expect("\r\n>")
        process.sendline("val _ = set_term_printer default_pt;".encode("utf-8"))
        process.expect("\r\n>")

        data = [{"polished":{"assumptions": e[0][0], "goal":e[0][1]},
                 "plain":{"assumptions": e[1][0], "goal":e[1][1]}}
                for e in zip(pd, [([], raw_goal)])]
        return data 
    
def construct_goal(goal):
    s = "g " + "`" + goal + "`;"
    return s

def gather_encoded_content_(history, encoder):
    fringe_sizes = []
    contexts = []
    reverted = []
    for i in history:
        c = i["content"]
        contexts.extend(c)
        fringe_sizes.append(len(c))
    for e in contexts:
        g = revert_with_polish(e)
        reverted.append(g.strip().split())
    out = []
    sizes = []
    for goal in reverted:
        out_, sizes_ = encoder.encode([goal])
        out.append(torch.cat(out_.split(1), dim=2).squeeze(0))
        sizes.append(sizes_)

    representations = out

    return representations, contexts, fringe_sizes

def parse_theory(pg):
    theories = re.findall(r'C\$(\w+)\$ ', pg)
    theories = set(theories)
    for th in EXCLUDED_THEORIES:
        theories.discard(th)
    return list(theories)

def revert_with_polish(context):
    target = context["polished"]
    assumptions = target["assumptions"]
    goal = target["goal"]
    for i in reversed(assumptions): 
        #goal = "@ @ D$min$==> {} {}".format(i, goal)
        goal = "@ @ C$min$ ==> {} {}".format(i, goal)

    return goal 

def split_by_fringe(goal_set, goal_scores, fringe_sizes):
    # group the scores by fringe
    fs = []
    gs = []
    counter = 0
    for i in fringe_sizes:
        end = counter + i
        fs.append(goal_scores[counter:end])
        gs.append(goal_set[counter:end])
        counter = end
    return gs, fs

In [7]:
with open("include_probability.json") as f:
    database = json.load(f)

with open("polished_def_dict.json") as f:
    defs = json.load(f)

fact_pool = list(defs.keys())

encoded_database = torch.load('encoded_include_probability.pt')

TARGET_THEORIES = ["bool", "min", "list"]
GOALS = [(key, value[4]) for key, value in database.items() if value[3] == "thm" and value[0] in TARGET_THEORIES]

print (GOALS[0][1])

LIST_REL (R :α -> β -> bool) (l1 :α list) (l2 :β list) ∧ LIST_REL R (l3 :α list) (l4 :β list) ⇒ LIST_REL R (l1 ++ l3) (l2 ++ l4)


In [8]:
#checkpoint_path = "models/2020_04_22_20_36_50" # 91% accuracy model, only core theories
#checkpoint_path = "models/2020_04_26_20_11_28" # 95% accuracy model, core theories + integer + sorting
#checkpoint_path = "models/2020_09_24_23_38_06" # 98% accuracy model, core theories + integer + sorting | separate theory tokens
#checkpoint_path = "models/2020_11_28_16_45_10" # 96-98% accuracy model, core theories + integer + sorting + real | separate theory tokens
#checkpoint_path = "models/2020_12_04_03_47_22" # 97% accuracy model, core theories + integer + sorting + real + bag | separate theory tokens

#checkpoint_path = "models/2021_02_21_15_46_04" # 98% accuracy model, up to probability theory

checkpoint_path = "models/2021_02_22_16_07_03" # 97-98% accuracy model, up to and include probability theory

checkpoint = Checkpoint.load(checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

batch_encoder_ = BatchPredictor(seq2seq, input_vocab, output_vocab)





In [9]:
#function to give the log probability of pi(f | s) so gradient can be computed directly
#also returns sampled index and contexts to determine goal to give tactic network
def sample_fringe(context_params, context_net, rng_key, jax_reps, context_set, fringe_sizes):
    context_scores = context_net(context_params, rng_key, jax_reps)
    contexts_by_fringe, scores_by_fringe = split_by_fringe(context_set, context_scores, fringe_sizes)
    fringe_scores = []
    for s in scores_by_fringe:
        fringe_score = jnp.sum(s)
        fringe_scores.append(fringe_score)
    #TODO some fringes can be empty, but still give value 0 which assigns nonzero probability?
    fringe_scores = jnp.stack(fringe_scores)
    fringe_probs = jax.nn.log_softmax(fringe_scores)

    #samples, gives an index (looks like it does gumbel softmax under the hood to keep differentiability?)
    sampled_idx = jax.random.categorical(rng_key,fringe_probs)

    log_prob = fringe_probs[sampled_idx]
    #log_prob = jnp.log(prob)
    return log_prob, (sampled_idx, contexts_by_fringe)
                                                           
#grad_log_context, (fringe_idx, contexts_by_fringe) = jax.grad(sample_fringe, has_aux=True)(context_params, apply_context, rng_key, jax_reps, context_set, fringe_sizes)

#takes a goal encoding and samples tactic from network, and returns log prob for gradient 
def sample_tactic(tactic_params, tac_net, rng_key, goal_endcoding, action_size=len(tactic_pool)):
    tac_scores = tac_net(tactic_params, rng_key, goal_endcoding, action_size)
    tac_scores = jnp.ravel(tac_scores)
    #tac_scores = tac_scores - max(tac_scores)
    #print (tac_scores)
    #subtract max element for numerical stability 
    tac_probs = jax.nn.log_softmax(tac_scores)
    tac_idx = jax.random.categorical(rng_key, tac_probs)
    log_prob = tac_probs[tac_idx]#jnp.log(tac_probs[tac_idx])
    #print (jnp.exp(log_prob).primal)#, jnp.exp(tac_probs), rng_key)
    return log_prob, tac_idx

#grad_log_tac, tac_idx = jax.grad(sample_tactic, has_aux=True)(tactic_params, apply_tac, rng_key, jnp.expand_dims(target_representation,0), len(tactic_pool))

#sampled_tac = tactic_pool[tac_idx]

def sample_term(term_params, term_net, rng_key, candidates):
    term_scores = term_net(term_params, rng_key, candidates)
    term_scores = jnp.ravel(term_scores)
    term_probs = jax.nn.log_softmax(term_scores)
    term_idx = jax.random.categorical(rng_key, term_probs)
    log_prob = term_probs[term_idx]#jnp.log(term_probs[term_idx])
    return log_prob, term_idx

#grad_log_term, term_idx = jax.grad(sample_term, has_aux=True)(term_params, apply_term, rng_key, candidates)#, tac_idx, len(tactic_pool), candidates.shape[1])

#function for sampling single argument given previous arguments, 
def sample_arg(arg_params, arg_net, rng_key, input_, candidates, hidden, tactic_size, embedding_dim):
    hidden, arg_scores = arg_net(arg_params, rng_key, input_, candidates, hidden, tactic_size, embedding_dim)
    arg_scores = jnp.ravel(arg_scores)
    arg_probs = jax.nn.log_softmax(arg_scores)
    arg_idx = jax.random.categorical(rng_key, arg_probs)
    log_prob = arg_probs[arg_idx]#jnp.log(arg_probs[arg_idx])
    return log_prob, (arg_idx, hidden)

In [10]:
def episode_loss(context_params, tactic_params, term_params, arg_params, apply_context, apply_tac, apply_term, apply_arg, rng_key, env, encoded_fact_pool, candidate_args):
    log_list = []
    discounted_reward_list = []
    trace = []
    gamma = 0.99
    for i in range(2):
        _, rng_key = jax.random.split(rng_key)
        
        #print ("Proof step {} of 50\n".format(i+1))
        
        try:
            jax_reps, context_set, fringe_sizes = gather_encoded_content_(env.history, batch_encoder_)
            #convert to jax
        except:
            print ("Encoder error")
            if len(log_list) > 0:
                return sum([i[0] * i[1] for i in zip(log_list, discounted_reward_list)])
            else:
                return 0
            
        jax_reps = jnp.stack([jnp.array(jax_reps[i][0].cpu()) for i in range(len(jax_reps))])
        logs, reward, action = run_iter(context_params, tactic_params, term_params, arg_params, apply_context, apply_tac, apply_term, apply_arg, jax_reps, context_set, fringe_sizes, rng_key, env, encoded_fact_pool, candidate_args)

        log_list.append(logs)
        discounted_reward_list.append(reward * (gamma ** i))
        
        trace.append((env.history, action))

                
        #if goal proven
        if reward == 5:
            print ("Goal proved in {} steps".format(i+1))
            return sum([i[0] * i[1] for i in zip(log_list, discounted_reward_list)]), trace
        
        #timeout
        if i == 49:
            discounted_reward_list[-1] = -5.
            
    loss = sum([i[0] * i[1] for i in zip(log_list, discounted_reward_list)])
    
    return loss, trace


In [11]:
def run_iter(context_params, tactic_params, term_params, arg_params, context_net, tactic_net, term_net, arg_net, jax_reps, context_set, fringe_sizes, rng_key, env, encoded_fact_pool, candidate_args):
    
    log_context, (fringe_idx, contexts_by_fringe) = sample_fringe(context_params, context_net, rng_key, jax_reps, context_set, fringe_sizes)
    
    try:
        target_context = contexts_by_fringe[fringe_idx][0]
    except:
        print ("error {} {}".format(contexts_by_fringe), fringe_idx)
    target_goal = target_context["polished"]["goal"]
    target_representation = jax_reps[context_set.index(target_context)]
    
    log_tac, tac_idx = sample_tactic(tactic_params, tactic_net, rng_key, jnp.expand_dims(target_representation,0), len(tactic_pool))
    
    sampled_tac = tactic_pool[tac_idx]
    arg_logs = []

    tactic = sampled_tac
    #for testing
    
    #sampled_tac = "Induct_on"

    #if tactic requires no argument
    if sampled_tac in no_arg_tactic:
        full_tactic = sampled_tac #tactic_pool[tac]


    #Induct_on case; use term policy to find which term to induct on 
    elif sampled_tac in term_tactic:

        goal_tokens = target_goal.split()
        term_tokens = [[t] for t in set(goal_tokens) if t[0] == "V"]
        #add conditional if tokens is empty 

        #now want encodings for terms from AE

        term_reps = []

        for term in term_tokens:
            term_rep, _ = batch_encoder_.encode([term])
            #output is bidirectional so concat vectors
            term_reps.append(torch.cat(term_rep.split(1), dim=2).squeeze(0))
        
        #no terms in expression, only contains literals (e.g. induct_on `0`)
        if len(term_reps) == 0:
            print ("No variables to induct on for goal {}".format(target_goal))
            #return negative loss for now (positive overall as negative of log prob is positive)
            return 1., -1., "Induct no vars"
            
            
        # convert to jax
        term_reps = jnp.stack([jnp.array(term_reps[i][0].cpu()) for i in range(len(term_reps))])

        # now want inputs to term_net to be target_representation (i.e. goal) concatenated with terms
        # models the policies conditional dependence of the term given the goal

        #stack goal representation for each token
        goal_stack = jnp.concatenate([jnp.expand_dims(target_representation,0) for _ in term_tokens])

        #concat with term encodings to give candidate matrix
        candidates = jnp.concatenate([goal_stack, term_reps], 1)

        log_term, term_idx = sample_term(term_params, term_net, rng_key, candidates)

        sampled_term = term_tokens[term_idx]

        tm = sampled_term[0][1:] # remove headers, e.g., "V" / "C" / ...
    
        arg_logs = [log_term]
        
        if tm:
            tactic = "Induct_on `{}`".format(tm)
        else:
            # only to raise an error
            tactic = "Induct_on"
        
    #argument tactic
    else:
        #stack goals to possible arguments to feed into FFN
        goal_stack = jnp.concatenate([jnp.expand_dims(target_representation,0) for _ in encoded_fact_pool])
        candidates = jnp.concatenate([encoded_fact_pool, goal_stack], 1)
        
        #initial state set as goal
        hidden = jnp.expand_dims(target_representation,0)
        init_state = hk.LSTMState(hidden,hidden)
    
        # run through first with tac_idx to initialise state with tactic as c_0
        hidden, _ = arg_net(arg_params, rng_key, tac_idx, candidates, init_state, len(tactic_pool), 256)
        
        ARG_LEN = 5
        arg_inds = []
        arg_logs = []
        input_ = tac_idx
        for _ in range(ARG_LEN):
            log_arg, (arg_idx, hidden) = sample_arg(arg_params, arg_net, rng_key, input_, candidates, hidden, len(tactic_pool), 256)
            arg_logs.append(log_arg)
            arg_inds.append(arg_idx)
            input_ = jnp.expand_dims(encoded_fact_pool[arg_idx], 0)
        
        arg = [candidate_args[i] for i in arg_inds]

        tactic = env.assemble_tactic(sampled_tac, arg)
        
    
    
    action = (int(fringe_idx), 0, tactic)
    #print ("Action {}:\n".format(action))
    
    try:
        reward, done = env.step(action)

    except:
        print("Step exception raised.")
        # print("Fringe: {}".format(env.history))
        print("Handling: {}".format(env.handling))
        print("Using: {}".format(env.using))
        # try again
        # counter = env.counter
        frequency = env.frequency
        env.close()
        print("Aborting current game ...")
        print("Restarting environment ...")
        print(env.goal)
        env = HolEnv(env.goal)
        flag = False
        return 
        
    #print ("Result: Reward {}".format(reward))#, env.history[-1]))

    
    #negative as we want gradient ascent 
    
    logs = (-log_tac - log_context  - sum(arg_logs))

    return logs, reward, action

In [12]:
def train(goals):

    rng_key = jax.random.PRNGKey(11)

    init_context, apply_context = hk.transform(policy_networks._context_forward)
    apply_context = jax.jit(apply_context)

    init_tac, apply_tac = hk.transform(policy_networks._tac_forward)
    apply_tac = partial(jax.jit, static_argnums=3)(apply_tac)

    init_term, apply_term = hk.transform(policy_networks._term_no_tac_forward)
    apply_term = jax.jit(apply_term)

    init_arg, apply_arg = hk.transform(policy_networks._arg_forward)
    apply_arg = partial(jax.jit, static_argnums=(5,6))(apply_arg)

    #initialise these with e.g. random uniform, glorot, He etc. should exist outside function for action selection 
    context_params = init_context(rng_key, jax.random.normal(rng_key, (1,256)))

    tactic_params = init_tac(rng_key, jax.random.normal(rng_key, (1,256)), len(tactic_pool))

    #term_policy for now is only considering variables for induction, hence does not need any arguments 
    term_params = init_term(rng_key, jax.random.normal(rng_key, (1,512)))

    hidden = jax.random.normal(rng_key, (1,256))

    init_state = hk.LSTMState(hidden, hidden)

    arg_params = init_arg(rng_key, jax.random.randint(rng_key, (), 0, len(tactic_pool)), jax.random.normal(rng_key, (1,512)), init_state, len(tactic_pool), 256)

        
    context_lr = 1e-4
    tactic_lr = 1e-4
    arg_lr = 1e-4
    term_lr = 1e-4

    context_optimiser = optax.rmsprop(context_lr)
    tactic_optimiser = optax.rmsprop(tactic_lr)
    arg_optimiser = optax.rmsprop(arg_lr)
    term_optimiser = optax.rmsprop(term_lr)

    opt_state_context = context_optimiser.init(context_params)
    opt_state_tactic = tactic_optimiser.init(tactic_params)
    opt_state_arg = arg_optimiser.init(arg_params)
    opt_state_term = term_optimiser.init(term_params)
    
    proof_dict = {}
    

    for goal in goals:
        g = goal[1]
            
        env = HolEnv(g)

        theories = re.findall(r'C\$(\w+)\$ ', goal[0])
        theories = set(theories)
        theories = list(theories)

        allowed_theories = theories

        goal_theory = g

        #print ("Target goal: {}".format(g))
        
        try:
            allowed_arguments_ids = []
            candidate_args = []
            goal_theory = g#database[polished_goal][0] # plain_database[goal][0]
            for i,t in enumerate(database):
                if database[t][0] in allowed_theories and (database[t][0] != goal_theory or int(database[t][2]) < int(database[polished_goal][2])):
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)

            env.toggle_simpset("diminish", goal_theory)
            #print("Removed simpset of {}".format(goal_theory))

        except:
            allowed_arguments_ids = []
            candidate_args = []
            for i,t in enumerate(database):
                if database[t][0] in allowed_theories:
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)
            #print("Theorem not found in database.")

        #print ("Number of candidate facts to use: {}".format(len(candidate_args)))

        encoded_database = torch.load('encoded_include_probability.pt')

        encoded_fact_pool = torch.index_select(encoded_database, 0, torch.tensor(allowed_arguments_ids))
        
        encoded_fact_pool = jnp.array(encoded_fact_pool)
        
        try:
            gradients, trace = jax.grad(episode_loss, argnums=(0,1,2,3), has_aux=True)(context_params, tactic_params, term_params, arg_params, apply_context, apply_tac, apply_term, apply_arg,  rng_key, env, encoded_fact_pool, candidate_args)
        except:
            print ("error")
            continue
        
        proof_dict[goal] = trace
            

        #update parameters
        context_updates, opt_state_context = context_optimiser.update(gradients[0], opt_state_context)
        context_params = optax.apply_updates(context_params, context_updates)

        tactic_updates, opt_state_tactic = tactic_optimiser.update(gradients[1], opt_state_tactic)
        tactic_params = optax.apply_updates(tactic_params, tactic_updates)

        term_updates, opt_state_term = term_optimiser.update(gradients[2], opt_state_term)
        term_params = optax.apply_updates(term_params, term_updates)

        arg_updates, opt_state_arg = arg_optimiser.update(gradients[3], opt_state_arg)
        arg_params = optax.apply_updates(arg_params, arg_updates)
        break
        #save trace 
#         save(proof_dict, path_dir + "/trace")
        
#         #save params after each proof attempt
#         save(context_params, path_dir + "/context_params")
#         save(opt_state_context, path_dir+"/context_state")
#         save(tactic_params, path_dir+"/tactic_params")
#         save(opt_state_tactic, path_dir+"/tactic_state")
#         save(term_params, path_dir+"/term_params")
#         save(opt_state_term, path_dir+"/term_state")
#         save(arg_params, path_dir+"/arg_params")
#         save(opt_state_arg, path_dir+"/arg_state")
            

In [13]:
TARGET_THEORIES = ["list"]
GOALS = [(key, value[4]) for key, value in database.items() if value[3] == "thm" and value[0] in TARGET_THEORIES]

train(GOALS[:2])


2022-05-31 12:41:10.043147: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


Importing theories...
Loading modules...
Configuration done.
Removing simp lemmas from LIST_REL (R :α -> β -> bool) (l1 :α list) (l2 :β list) ∧ LIST_REL R (l3 :α list) (l4 :β list) ⇒ LIST_REL R (l1 ++ l3) (l2 ++ l4)


In [26]:
def gather_encoded_content(history, encoder):
    # figure out why this is slower than tests
    # figured out: remember to do strip().split()
    fringe_sizes = []
    contexts = []
    reverted = []
    for i in history:
        c = i["content"]
        contexts.extend(c)
        fringe_sizes.append(len(c))
    for e in contexts:
        g = revert_with_polish(e)
        reverted.append(g.strip().split())
    # print(reverted)
    # s1 = timeit.default_timer()
    out, sizes = encoder.encode(reverted)
    # merge two hidden variables
    representations = torch.cat(out.split(1), dim=2).squeeze(0)
    # print(representations.shape)
    # s2 = timeit.default_timer()    
    # print(s2-s1)

    return representations, contexts, fringe_sizes


In [33]:
import utp_model

import time


def run_iteration(goals, mode="training", ARG_LEN=5):
    

    learning_rate = 1e-5

    context_rate = 5e-5
    tac_rate = 5e-5
    arg_rate = 5e-5
    term_rate = 5e-5

    gamma = 0.99 # 0.9

    # for entropy regularization
    trade_off = 1e-2

    context_net = utp_model.ContextPolicy()

    tac_net = utp_model.TacPolicy(len(tactic_pool))

    arg_net = utp_model.ArgPolicy(len(tactic_pool), 256)

    term_net = utp_model.TermPolicy(len(tactic_pool), 256)

    context_net = context_net.to(device)
    tac_net = tac_net.to(device)
    arg_net = arg_net.to(device)
    term_net = term_net.to(device)

    optimizer_context = torch.optim.RMSprop(list(context_net.parameters()), lr=context_rate)

    optimizer_tac = torch.optim.RMSprop(list(tac_net.parameters()), lr=tac_rate)

    optimizer_arg = torch.optim.RMSprop(list(arg_net.parameters()), lr=arg_rate)

    optimizer_term = torch.optim.RMSprop(list(term_net.parameters()), lr=term_rate)


    torch.set_grad_enabled(mode=="training" or mode=="subgoals")
    global iteration_counter
    # state_pool = []
    fringe_pool = []
    tac_pool = []
    arg_pool = []
    reward_pool = []
    reward_print = []
    action_pool = []
    steps = 0
    flag = True
    replay_flag = False
    tac_print = []

    induct_arg = []
    proved = 0
    iteration_rewards = []

    for goal in goals:
        start_t = time.time()
        g = goal[1]
            
        print (g)
        env = HolEnv(g)

        theories = re.findall(r'C\$(\w+)\$ ', goal[0])
        theories = set(theories)
        theories = list(theories)

        allowed_theories = theories

        goal_theory = g

        #print ("Target goal: {}".format(g))
        
        try:
            allowed_arguments_ids = []
            candidate_args = []
            goal_theory = g#database[polished_goal][0] # plain_database[goal][0]
            for i,t in enumerate(database):
                if database[t][0] in allowed_theories and (database[t][0] != goal_theory or int(database[t][2]) < int(database[polished_goal][2])):
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)

            env.toggle_simpset("diminish", goal_theory)
            #print("Removed simpset of {}".format(goal_theory))

        except:
            allowed_arguments_ids = []
            candidate_args = []
            for i,t in enumerate(database):
                if database[t][0] in allowed_theories:
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)
            #print("Theorem not found in database.")

        #print ("Number of candidate facts to use: {}".format(len(candidate_args)))

        encoded_database = torch.load('encoded_include_probability.pt')




        encoded_fact_pool = torch.index_select(encoded_database, 0, torch.tensor(allowed_arguments_ids, device=device))


        for i in range(50):

            # gather all the goals in the history
            try:
                representations, context_set, fringe_sizes = gather_encoded_content(env.history, batch_encoder_)
            except Exception as e:
                print (e)
                continue




            representations = representations.to(device)
            context_scores = context_net(representations)
            contexts_by_fringe, scores_by_fringe = split_by_fringe(context_set, context_scores, fringe_sizes)
            fringe_scores = []
            for s in scores_by_fringe:
                # fringe_score = torch.prod(s) # TODO: make it sum
                fringe_score = torch.sum(s) # TODO: make it sum
                fringe_scores.append(fringe_score)
            fringe_scores = torch.stack(fringe_scores)
            fringe_probs = F.softmax(fringe_scores, dim=0)
            fringe_m = Categorical(fringe_probs)
            fringe = fringe_m.sample()
            fringe_pool.append(fringe_m.log_prob(fringe))

            # take the first context in the chosen fringe for now
            try:
                target_context = contexts_by_fringe[fringe][0]
            except:
                print ("error {} {}".format(contexts_by_fringe, fringe))

           # target_context = contexts_by_fringe[fringe][0]
            target_goal = target_context["polished"]["goal"]
            target_representation = representations[context_set.index(target_context)]
            # print(target_representation.shape)
            # exit()

            # size: (1, max_contexts, max_assumptions+1, max_len)
            tac_input = target_representation.unsqueeze(0)
            tac_input = tac_input.to(device)

            # compute scores of tactics
            tac_probs = tac_net(tac_input)
            # print(tac_probs)
            tac_m = Categorical(tac_probs)
            tac = tac_m.sample()
            # log directly the log probability
            tac_pool.append(tac_m.log_prob(tac))
            action_pool.append(tactic_pool[tac])
            tac_print.append(tac_probs.detach())
            # print(len(fact_pool[0].strip().split()))
            # exit()

            tac_tensor = tac.to(device)


            if tactic_pool[tac] in no_arg_tactic:
                tactic = tactic_pool[tac]
                arg_probs = []
                arg_probs.append(torch.tensor(0))
                arg_pool.append(arg_probs)
            elif tactic_pool[tac] == "Induct_on":
                arg_probs = []
                candidates = []
                # input = torch.cat([target_representation, tac_tensor], dim=1)
                tokens = target_goal.split()
                tokens = list(dict.fromkeys(tokens))
                tokens = [[t] for t in tokens if t[0] == "V"]
                if tokens:
                    # concatenate target_representation to token
                    # use seq2seq to compute the representation of a token
                    # also we don't need to split an element in tokens because they are singletons
                    # but we need to make it a list containing a singleton list, i.e., [['Vl']]

                    token_representations, _ = batch_encoder_.encode(tokens)
                    # reshaping
                    encoded_tokens = torch.cat(token_representations.split(1), dim=2).squeeze(0)
                    target_representation_list = [target_representation.unsqueeze(0) for _ in tokens]

                    target_representations = torch.cat(target_representation_list)
                    # size: (len(tokens), 512)
                    candidates = torch.cat([encoded_tokens, target_representations], dim=1)
                    candidates = candidates.to(device)

                    # concat = [torch.cat([torch.tensor([input_vocab.stoi[i] for _ in range(256)], dtype=torch.float), target_representation]) for i in tokens]

                    # candidates = torch.stack(concat)
                    # candidates = candidates.to(device)

                    scores = term_net(candidates, tac_tensor)
                    term_probs = F.softmax(scores, dim=0)
                    try:
                        term_m = Categorical(term_probs.squeeze(1))
                    except:
                        print("probs: {}".format(term_probs))                                          
                        print("candidates: {}".format(candidates.shape))
                        print("scores: {}".format(scores))
                        print("tokens: {}".format(tokens))
                        exit()
                    term = term_m.sample()
                    arg_probs.append(term_m.log_prob(term))
                    induct_arg.append(tokens[term])                
                    tm = tokens[term][0][1:] # remove headers, e.g., "V" / "C" / ...
                    arg_pool.append(arg_probs)
                    if tm:
                        tactic = "Induct_on `{}`".format(tm)
                    else:
                        print("tm is empty")
                        print(tokens)
                        # only to raise an error
                        tactic = "Induct_on"
                else:
                    arg_probs.append(torch.tensor(0))
                    induct_arg.append("No variables")
                    arg_pool.append(arg_probs)
                    tactic = "Induct_on"
            else:
                hidden0 = hidden1 = target_representation.unsqueeze(0).unsqueeze(0)

                hidden0 = hidden0.to(device)
                hidden1 = hidden1.to(device)

                hidden = (hidden0, hidden1)

                # concatenate the candidates with hidden states.

                hc = torch.cat([hidden0.squeeze(), hidden1.squeeze()])
                hiddenl = [hc.unsqueeze(0) for _ in allowed_arguments_ids]

                hiddenl = torch.cat(hiddenl)

                # size: (len(fact_pool), 512)
                candidates = torch.cat([encoded_fact_pool, hiddenl], dim=1)
                candidates = candidates.to(device)

                input = tac_tensor
                # run it once before predicting the first argument
                hidden, _ = arg_net(input, candidates, hidden)

                # the indices of chosen args
                arg_step = []
                arg_step_probs = []
                if tactic_pool[tac] in thm_tactic:
                    arg_len = 1
                else:
                    arg_len = ARG_LEN

                    
                for _ in range(arg_len):
                    hidden, scores = arg_net(input, candidates, hidden)
                    arg_probs = F.softmax(scores, dim=0)
                    arg_m = Categorical(arg_probs.squeeze(1))
                    arg = arg_m.sample()
                    arg_step.append(arg)
                    arg_step_probs.append(arg_m.log_prob(arg))

                    hidden0 = hidden[0].squeeze().repeat(1, 1, 1)
                    hidden1 = hidden[1].squeeze().repeat(1, 1, 1)
                    # encoded chosen argument
                    input = encoded_fact_pool[arg].unsqueeze(0).unsqueeze(0)
                    # print(input.shape)

                    # renew candidates                
                    hc = torch.cat([hidden0.squeeze(), hidden1.squeeze()])
                    hiddenl = [hc.unsqueeze(0) for _ in allowed_arguments_ids]

                    hiddenl = torch.cat(hiddenl)

                    # size: (len(fact_pool), 512)
                    candidates = torch.cat([encoded_fact_pool, hiddenl], dim=1)
                    candidates = candidates.to(device)

                arg_pool.append(arg_step_probs)

                tac = tactic_pool[tac]
                arg = [candidate_args[j] for j in arg_step]

                tactic = env.assemble_tactic(tac, arg)

            action = (fringe.item(), 0, tactic)


            print (action)
            # reward, done = env.step(action)
            try:
                # when step is performed, env.history (probably) changes
                # if goal_index == 0:
                #     raise "boom"
                reward, done = env.step(action)

            except:
                print("Step exception raised.")
                # print("Fringe: {}".format(env.history))
                print("Handling: {}".format(env.handling))
                print("Using: {}".format(env.using))
                # try again
                # counter = env.counter
                frequency = env.frequency
                env.close()
                print("Aborting current game ...")
                print("Restarting environment ...")
                print(env.goal)
                env = HolEnv(env.goal)
                flag = False
                break

            if t == 49:
                reward = -5
            # state_pool.append(state)
            reward_print.append(reward)
            # reward_pool.append(reward+trade_off*entropy)
            reward_pool.append(reward)

                # pg = ng

            steps += 1
            total_reward = float(np.sum(reward_print))

            if done == True:
                print ("Goal Proved in {} steps".format(i+1))
                break
                
            if t == 49:
                print("Failed.")
                print("Rewards: {}".format(reward_print))
                # print("Rewards: {}".format(reward_pool))
                print("Tactics: {}".format(action_pool))
                # print("Mean reward: {}\n".format(np.mean(reward_pool)))
                print("Total: {}".format(total_reward))
                iteration_rewards.append(total_reward)

        # Update policy
        # Discount reward
        print("Updating parameters ... ")
        running_add = 0
        for i in reversed(range(steps)):
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * gamma + reward_pool[i]
                reward_pool[i] = running_add

        optimizer_context.zero_grad()
        optimizer_tac.zero_grad()
        optimizer_arg.zero_grad()
        optimizer_term.zero_grad()

        for i in range(steps):
            # size : (1,1,4,128)
            total_loss = 0

            # state = state_pool[i]
            reward = reward_pool[i]

            fringe_loss = -fringe_pool[i] * (reward)
            arg_loss = -torch.sum(torch.stack(arg_pool[i])) * (reward)

            tac_loss = -tac_pool[i] * (reward)

            # entropy = fringe_pool[i] + torch.sum(torch.stack(arg_pool[i])) + tac_pool[i]

            # loss = fringe_loss + tac_loss + arg_loss + trade_off*entropy
            loss = fringe_loss + tac_loss + arg_loss
            total_loss += loss
            #loss.backward()

        total_loss.backward()

        # optimizer.step()

        optimizer_context.step()
        optimizer_tac.step()
        optimizer_arg.step()
        optimizer_term.step()

        fringe_pool = []
        tac_pool = []
        arg_pool = []
        action_pool = []
        reward_pool = []
        reward_print = []
        steps = 0
        elapsed = time.time() - start_t

        print (elapsed)

    return

In [46]:
TARGET_THEORIES = ["list"]
GOALS = [(key, value[4]) for key, value in database.items() if value[3] == "thm" and value[0] in TARGET_THEORIES]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#run_iteration(GOALS)

In [23]:
import time
t = time.time()
# do stuff
elapsed = time.time() - t

In [24]:
elapsed

5.245208740234375e-05

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 11.2 µs


2

In [194]:
'''

High level agent class 

'''
class Agent:
    def __init__(self, tactic_pool):
        self.tactic_pool = tactic_pool    
        self.load_encoder()
    
    def load_agent(self):
        pass
    
    def load_encoder(self):
        pass
        
    def run(self, env, max_steps):
        pass
    
    def update_params(self):
        pass
    

    
'''

Vanilla Torch implementation of TacticZero

'''
class TorchVanilla(Agent):
    def __init__(self, tactic_pool):
        super().__init__(tactic_pool)

        learning_rate = 1e-5

        self.context_rate = 5e-5
        self.tac_rate = 5e-5
        self.arg_rate = 5e-5
        self.term_rate = 5e-5

        self.gamma = 0.99 # 0.9

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.context_net = utp_model.ContextPolicy().to(self.device)
        self.tac_net = utp_model.TacPolicy(len(tactic_pool)).to(self.device)
        self.arg_net = utp_model.ArgPolicy(len(tactic_pool), 256).to(self.device)
        self.term_net = utp_model.TermPolicy(len(tactic_pool), 256).to(self.device)

        self.optimizer_context = torch.optim.RMSprop(list(self.context_net.parameters()), lr=self.context_rate)
        self.optimizer_tac = torch.optim.RMSprop(list(self.tac_net.parameters()), lr=self.tac_rate)
        self.optimizer_arg = torch.optim.RMSprop(list(self.arg_net.parameters()), lr=self.arg_rate)
        self.optimizer_term = torch.optim.RMSprop(list(self.term_net.parameters()), lr=self.term_rate)

     
    def load_encoder(self):
        checkpoint_path = "models/2021_02_22_16_07_03" # 97-98% accuracy model, up to and include probability theory

        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab

        self.encoder = BatchPredictor(seq2seq, input_vocab, output_vocab)
        return 

        
    def run(self, env, encoded_fact_pool, allowed_arguments_ids, candidate_args, max_steps=50):
        
        
        fringe_pool = []
        tac_pool = []
        arg_pool = []
        action_pool = []
        reward_pool = []
        reward_print = []
        tac_print = []
        induct_arg = []
        proved = 0
        iteration_rewards = []
        steps = 0        
        
        start_t = time.time()

        for t in range(max_steps):
            
            # gather all the goals in the history
            try:
                representations, context_set, fringe_sizes = gather_encoded_content(env.history, self.encoder)
            except Exception as e:
                print ("Encoder error {}".format(e))
                return


            representations = torch.stack([i.to(self.device) for i in representations])
            context_scores = self.context_net(representations)
            contexts_by_fringe, scores_by_fringe = split_by_fringe(context_set, context_scores, fringe_sizes)
            fringe_scores = []
            for s in scores_by_fringe:
                fringe_score = torch.sum(s) 
                fringe_scores.append(fringe_score)
            fringe_scores = torch.stack(fringe_scores)
            fringe_probs = F.softmax(fringe_scores, dim=0)
            fringe_m = Categorical(fringe_probs)
            fringe = fringe_m.sample()
            fringe_pool.append(fringe_m.log_prob(fringe))

            # take the first context in the chosen fringe for now
            try:
                target_context = contexts_by_fringe[fringe][0]
            except:
                print ("error {} {}".format(contexts_by_fringe, fringe))

            target_goal = target_context["polished"]["goal"]
            target_representation = representations[context_set.index(target_context)]


            tac_input = target_representation#.unsqueeze(0)
            tac_input = tac_input.to(self.device)

            tac_probs = self.tac_net(tac_input)
            tac_m = Categorical(tac_probs)
            tac = tac_m.sample()
            tac_pool.append(tac_m.log_prob(tac))
            action_pool.append(tactic_pool[tac])
            tac_print.append(tac_probs.detach())


            tac_tensor = tac.to(self.device)

            if tactic_pool[tac] in no_arg_tactic:
                tactic = tactic_pool[tac]
                arg_probs = []
                arg_probs.append(torch.tensor(0))
                arg_pool.append(arg_probs)
                
            elif tactic_pool[tac] == "Induct_on":
                arg_probs = []
                candidates = []

                tokens = target_goal.split()
                tokens = list(dict.fromkeys(tokens))
                tokens = [[t] for t in tokens if t[0] == "V"]
                if tokens:
                    token_representations, _ = self.encoder.encode(tokens)
                
                    encoded_tokens = torch.cat(token_representations.split(1), dim=2).squeeze(0)
                
                    target_representation_list = [target_representation for _ in tokens]

                    target_representations = torch.cat(target_representation_list)

                    candidates = torch.cat([encoded_tokens, target_representations], dim=1)
                    candidates = candidates.to(self.device)


                    scores = self.term_net(candidates, tac_tensor)
                    term_probs = F.softmax(scores, dim=0)
                    try:
                        term_m = Categorical(term_probs.squeeze(1))
                    except:
                        print("probs: {}".format(term_probs))                                          
                        print("candidates: {}".format(candidates.shape))
                        print("scores: {}".format(scores))
                        print("tokens: {}".format(tokens))
                        exit()
                    term = term_m.sample()
                    arg_probs.append(term_m.log_prob(term))
                    induct_arg.append(tokens[term])                
                    tm = tokens[term][0][1:] # remove headers, e.g., "V" / "C" / ...
                    arg_pool.append(arg_probs)
                    if tm:
                        tactic = "Induct_on `{}`".format(tm)
                    else:
                        print("tm is empty")
                        print(tokens)
                        # only to raise an error
                        tactic = "Induct_on"
                else:
                    arg_probs.append(torch.tensor(0))
                    induct_arg.append("No variables")
                    arg_pool.append(arg_probs)
                    tactic = "Induct_on"
            else:
                hidden0 = hidden1 = target_representation#.unsqueeze(0).unsqueeze(0)
                hidden0 = hidden0.to(self.device)
                hidden1 = hidden1.to(self.device)

                hidden = (hidden0, hidden1)
                
                # concatenate the candidates with hidden states.

                hc = torch.cat([hidden0.squeeze(), hidden1.squeeze()])
                hiddenl = [hc.unsqueeze(0) for _ in allowed_arguments_ids]
                
                
                hiddenl = torch.cat(hiddenl)

                # size: (len(fact_pool), 512)
                candidates = torch.cat([encoded_fact_pool, hiddenl], dim=1)
                candidates = candidates.to(self.device)
                            
                input = tac_tensor
                # run it once before predicting the first argument
                hidden, _ = self.arg_net(input, candidates, hidden)

                # the indices of chosen args
                arg_step = []
                arg_step_probs = []
                if tactic_pool[tac] in thm_tactic:
                    arg_len = 1
                else:
                    arg_len = 5#ARG_LEN


                for _ in range(arg_len):
                    hidden, scores = self.arg_net(input, candidates, hidden)
                    arg_probs = F.softmax(scores, dim=0)
                    arg_m = Categorical(arg_probs.squeeze(1))
                    arg = arg_m.sample()
                    arg_step.append(arg)
                    arg_step_probs.append(arg_m.log_prob(arg))

                    hidden0 = hidden[0].squeeze().repeat(1, 1, 1)
                    hidden1 = hidden[1].squeeze().repeat(1, 1, 1)
                    
                    # encoded chosen argument
                    input = encoded_fact_pool[arg].unsqueeze(0)#.unsqueeze(0)

                    # renew candidates                
                    hc = torch.cat([hidden0.squeeze(), hidden1.squeeze()])
                    hiddenl = [hc.unsqueeze(0) for _ in allowed_arguments_ids]

                    hiddenl = torch.cat(hiddenl)
                    #appends both hidden and cell states (when paper only does hidden?)
                    candidates = torch.cat([encoded_fact_pool, hiddenl], dim=1)
                    candidates = candidates.to(self.device)

                arg_pool.append(arg_step_probs)

                tac = tactic_pool[tac]
                arg = [candidate_args[j] for j in arg_step]

                tactic = env.assemble_tactic(tac, arg)

            action = (fringe.item(), 0, tactic)


            #print (action)
            # reward, done = env.step(action)
            try:
                reward, done = env.step(action)

            except:
                print("Step exception raised.")
                return
                # print("Fringe: {}".format(env.history))
                print("Handling: {}".format(env.handling))
                print("Using: {}".format(env.using))
                # try again
                # counter = env.counter
                frequency = env.frequency
                env.close()
                print("Aborting current game ...")
                print("Restarting environment ...")
                print(env.goal)
                env = HolEnv(env.goal)
                flag = False
                break

            if t == 49:
                reward = -5
                
            reward_print.append(reward)
            reward_pool.append(reward)

            steps += 1
            total_reward = float(np.sum(reward_print))

            if done == True:
                print ("Goal Proved in {} steps".format(t+1))
                break

            if t == 49:
                print("Failed.")
                #print("Rewards: {}".format(reward_print))
                # print("Rewards: {}".format(reward_pool))
                #print("Tactics: {}".format(action_pool))
                # print("Mean reward: {}\n".format(np.mean(reward_pool)))
                #print("Total: {}".format(total_reward))
                iteration_rewards.append(total_reward)

        
        self.update_params(reward_pool, fringe_pool, arg_pool, tac_pool, steps)
        
        elapsed = time.time() - start_t

        #print (elapsed)

        return

    def update_params(self, reward_pool, fringe_pool, arg_pool, tac_pool, step_count):
        # Update policy
        # Discount reward
        print("Updating parameters ... ")
        running_add = 0
        for i in reversed(range(step_count)):
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add

        self.optimizer_context.zero_grad()
        self.optimizer_tac.zero_grad()
        self.optimizer_arg.zero_grad()
        self.optimizer_term.zero_grad()

        total_loss = 0

        for i in range(step_count):
            reward = reward_pool[i]
            
            fringe_loss = -fringe_pool[i] * (reward)
            arg_loss = -torch.sum(torch.stack(arg_pool[i])) * (reward)
            tac_loss = -tac_pool[i] * (reward)
            
            loss = fringe_loss + tac_loss + arg_loss
            total_loss += loss

        total_loss.backward()

        self.optimizer_context.step()
        self.optimizer_tac.step()
        self.optimizer_arg.step()
        self.optimizer_term.step()
        
        return 


In [201]:
class Experiment:
    def __init__(self, agent, goals, db_dir, encoded_db_dir, num_iterations):
        self.agent = agent
        self.goals = goals
        self.num_iterations = num_iterations
        self.load_db(db_dir)
        self.load_encoded_db(encoded_db_dir)
        
    def train(self):
        env = HolEnv("T")
        for iteration in range(self.num_iterations):
            for i, goal in enumerate(self.goals):
                try:
                    env.reset(goal[1])
                except Exception as e:
                    print ("Restarting environment..")
                    env = HolEnv("T")
                    continue
                 
                encoded_fact_pool, allowed_arguments_ids, candidate_args = self.gen_fact_pool(env, goal)
                self.agent.run(env, encoded_fact_pool, allowed_arguments_ids, candidate_args, max_steps=50)

    def load_encoded_db(self, encoded_db_dir):
        self.encoded_database = torch.load(encoded_db_dir)

    def load_db(self, db_dir):
        with open(db_dir) as f:
            self.database = json.load(f)

    def gen_fact_pool(self, env, goal):

        allowed_theories = list(set(re.findall(r'C\$(\w+)\$ ', goal[0])))
        
        goal_theory = goal[1]
    
        polished_goal = env.fringe["content"][0]["polished"]["goal"]
        
        try:
            allowed_arguments_ids = []
            candidate_args = []
            for i,t in enumerate(self.database):
                if self.database[t][0] in allowed_theories and (self.database[t][0] != goal_theory or int(self.database[t][2]) < int(self.database[polished_goal][2])):
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)

            env.toggle_simpset("diminish", goal_theory)
            #print("Removed simpset of {}".format(goal_theory))

        except:
            allowed_arsguments_ids = []
            candidate_args = []
            for i,t in enumerate(self.database):
                if self.database[t][0] in allowed_theories:
                    allowed_arguments_ids.append(i)
                    candidate_args.append(t)
            print("Theorem not found in database.")

        #print ("Number of candidate facts to use: {}".format(len(candidate_args)))
        try:
            encoded_fact_pool = torch.index_select(self.encoded_database, 0, torch.tensor(allowed_arguments_ids))
        except Exception as e:
            print ("Index select error {}".format(e))
        return encoded_fact_pool, allowed_arguments_ids, candidate_args

In [202]:
test = TorchVanilla(tactic_pool)

In [203]:
# with open("dataset.json") as fp:
#     dataset = json.load(fp)
    

In [204]:
# env = HolEnv("T")
# val_goals = []
# for goal in dataset:
#     try:
#         p_goal = env.get_polish(goal)
#         ret.append((p_goal[0]["polished"]['goal'], goal))
#     except:
#         print (goal)

In [205]:
exp = Experiment(test, ret,"include_probability.json", 'encoded_include_probability.pt', 1)

In [206]:
exp.train()

Importing theories...
Loading modules...
Configuration done.
Initialization done. Main goal is:
∀(s1 :α -> bool) (s2 :α -> bool). s1 ⊂ s2 ⇔ s1 ⊆ s2 ∧ ¬(s2 ⊆ s1).
Removing simp lemmas from ∀(s1 :α -> bool) (s2 :α -> bool). s1 ⊂ s2 ⇔ s1 ⊆ s2 ∧ ¬(s2 ⊆ s1)
same action
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(p :num -> bool). (∃(n :num). p n) ⇔ p ($LEAST p) ∧ ∀(n :num). n < $LEAST p ⇒ ¬p n.
Removing simp lemmas from ∀(p :num -> bool). (∃(n :num). p n) ⇔ p ($LEAST p) ∧ ∀(n :num). n < $LEAST p ⇒ ¬p n
same action
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(f :β -> γ) (g :α -> β) (l :α list). MAP f (MAP g l) = MAP (f ∘ g) l.
Removing simp lemmas from ∀(f :β -> γ) (g :α -> β) (l :α list). MAP f (MAP g l) = MAP (f ∘ g) l
same action
Goal Proved in 18 steps
Updating parameters ... 
Initialization done. Main goal is:
(x :α) ∈ RDOM ((R :α -> β -> bool) \\ (k :α)) ⇔ x ∈ RDOM 

Goal Proved in 6 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(l :α list) (m :num). m = LENGTH l ⇒ TAKE m l = l.
Removing simp lemmas from ∀(l :α list) (m :num). m = LENGTH l ⇒ TAKE m l = l
Goal Proved in 5 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(s :α -> bool). s ⊆ s.
Removing simp lemmas from ∀(s :α -> bool). s ⊆ s
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
(f :δ -> γ) ∘ UNCURRY (g :α -> β -> δ) = UNCURRY (($o f :(β -> δ) -> β -> γ) ∘ g).
Removing simp lemmas from (f :δ -> γ) ∘ UNCURRY (g :α -> β -> δ) = UNCURRY (($o f :(β -> δ) -> β -> γ) ∘ g)
same action
same action
Goal Proved in 15 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(t :bool). F ⇒ t ⇔ T.
Removing simp lemmas from ∀(t :bool). F ⇒ t ⇔ T
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(s :num -> bool). s ≠ (∅ :num -> bool) ∧ FINITE s ⇒ MIN_SET s ≤ MAX_SET s.
Removing simp lemmas

Step exception raised.
Timeout exceeded.
<pexpect.pty_spawn.spawn object at 0x7f4f0d959fa0>
command: /home/sean/Documents/PhD/HOL4/HOL/bin/hol
args: ['/home/sean/Documents/PhD/HOL4/HOL/bin/hol', '--maxheap=256']
buffer (last 100 chars): b'\xe2\x87\x92 x1 = x2) \xe2\x87\x92 \xe2\x88\x80(x :\xce\xb1). PMATCH_ROW_COND p g i x \xe2\x87\x92 (@(y :\xce\xb1). PMATCH_ROW_COND p g i y) = x`;\r\n'
before (last 100 chars): b'\xe2\x87\x92 x1 = x2) \xe2\x87\x92 \xe2\x88\x80(x :\xce\xb1). PMATCH_ROW_COND p g i x \xe2\x87\x92 (@(y :\xce\xb1). PMATCH_ROW_COND p g i y) = x`;\r\n'
after: <class 'pexpect.exceptions.TIMEOUT'>
match: None
match_index: None
exitstatus: None
flag_eof: False
pid: 50630
child_fd: 69
closed: False
timeout: 3
delimiter: <class 'pexpect.exceptions.EOF'>
logfile: None
logfile_read: None
logfile_send: None
maxread: 2000
ignorecase: False
searchwindowsize: None
delaybeforesend: None
delayafterclose: 0.1
delayafterterminate: 0.1
searcher: searcher_re:
    0: re.compile(b'\r\n>')
Impo

Loading modules...
Configuration done.
Initialization done. Main goal is:
RC ($PSUBSET :(α -> bool) -> (α -> bool) -> bool) = ($SUBSET :(α -> bool) -> (α -> bool) -> bool).
Removing simp lemmas from RC ($PSUBSET :(α -> bool) -> (α -> bool) -> bool) = ($SUBSET :(α -> bool) -> (α -> bool) -> bool)
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(x :bool) (x' :bool) (y :bool) (y' :bool). (¬y ⇒ x ⇒ x') ∧ (¬x' ⇒ y ⇒ y') ⇒ x ∨ y ⇒ x' ∨ y'.
Removing simp lemmas from ∀(x :bool) (x' :bool) (y :bool) (y' :bool). (¬y ⇒ x ⇒ x') ∧ (¬x' ⇒ y ⇒ y') ⇒ x ∨ y ⇒ x' ∨ y'
Goal Proved in 6 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(e :α) (s :α -> bool). POW (e INSERT s) = IMAGE ($INSERT e) (POW s) ∪ POW s.
Removing simp lemmas from ∀(e :α) (s :α -> bool). POW (e INSERT s) = IMAGE ($INSERT e) (POW s) ∪ POW s
same action
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(s :α -> bool). s DIFF s =

Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(P :α -> β) (v :γ). P (PMATCH v ([] :(γ -> α option) list)) = P (ARB :α).
Removing simp lemmas from ∀(P :α -> β) (v :γ). P (PMATCH v ([] :(γ -> α option) list)) = P (ARB :α)
Encoder error 'C$patternMatches$'
Initialization done. Main goal is:
∀(s :α -> bool) (t :α -> bool). INFINITE s ∧ FINITE t ⇒ ∃(x :α). x ∈ s ∧ x ∉ t.
Removing simp lemmas from ∀(s :α -> bool) (t :α -> bool). INFINITE s ∧ FINITE t ⇒ ∃(x :α). x ∈ s ∧ x ∉ t
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(v :α) (rows :(α -> β option) list) (n :num). n < LENGTH rows ∧ IS_SOME (EL n rows v) ⇒ PMATCH v rows = PMATCH v (TAKE (SUC n) rows).
Removing simp lemmas from ∀(v :α) (rows :(α -> β option) list) (n :num). n < LENGTH rows ∧ IS_SOME (EL n rows v) ⇒ PMATCH v rows = PMATCH v (TAKE (SUC n) rows)
Encoder error 'Vrows'
Initialization done. Main goal is:
∀(f :α -> β) (ls :α list). ALL_DISTINCT (MAP f ls) ⇒ ALL_DISTINCT ls.
R

Failed.
Updating parameters ... 
Initialization done. Main goal is:
{x | (P :α -> bool) x} = P.
Removing simp lemmas from {x | (P :α -> bool) x} = P
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(x :α) (s :α -> bool) (t :α -> bool). x INSERT s ⊆ t ⇔ x ∈ t ∧ s ⊆ t.
Removing simp lemmas from ∀(x :α) (s :α -> bool) (t :α -> bool). x INSERT s ⊆ t ⇔ x ∈ t ∧ s ⊆ t
Goal Proved in 3 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(t :bool). t ⇒ T.
Removing simp lemmas from ∀(t :bool). t ⇒ T
Goal Proved in 3 steps
Updating parameters ... 
Initialization done. Main goal is:
(∀(x :α) (y :β). (R1 :α -> β -> bool) x y ⇒ (R2 :β -> α -> bool) y x) ⇒ ∀(x :α list) (y :β list). LIST_REL R1 x y ⇒ LIST_REL R2 y x.
Removing simp lemmas from (∀(x :α) (y :β). (R1 :α -> β -> bool) x y ⇒ (R2 :β -> α -> bool) y x) ⇒ ∀(x :α list) (y :β list). LIST_REL R1 x y ⇒ LIST_REL R2 y x
same action
same action
Failed.
Updating parameters ... 
Initialization done. Mai

Goal Proved in 4 steps
Updating parameters ... 
Initialization done. Main goal is:
(x :α -> bool) ∩ COMPL x = (∅ :α -> bool) ∧ COMPL x ∩ x = (∅ :α -> bool).
Removing simp lemmas from (x :α -> bool) ∩ COMPL x = (∅ :α -> bool) ∧ COMPL x ∩ x = (∅ :α -> bool)
Goal Proved in 3 steps
Updating parameters ... 
Initialization done. Main goal is:
TAKE (n :num) (l :α list) = ([] :α list) ⇔ n = (0 :num) ∨ l = ([] :α list).
Removing simp lemmas from TAKE (n :num) (l :α list) = ([] :α list) ⇔ n = (0 :num) ∨ l = ([] :α list)
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
APPLY_REDUNDANT_ROWS_INFO ([] :bool list) ([] :α list) = ([] :α list) ∧ (∀(is :bool list) (x :β) (xs :β list). APPLY_REDUNDANT_ROWS_INFO (T::is) (x::xs) = APPLY_REDUNDANT_ROWS_INFO is xs) ∧ ∀(is :bool list) (x :γ) (xs :γ list). APPLY_REDUNDANT_ROWS_INFO (F::is) (x::xs) = x::APPLY_REDUNDANT_ROWS_INFO is xs.
Removing simp lemmas from APPLY_REDUNDANT_ROWS_INFO ([] :bool list) ([] :α list) = ([] :α lis

Failed.
Updating parameters ... 
Initialization done. Main goal is:
(x :α list) ≠ ([] :α list) ⇔ (0 :num) < LENGTH x.
Removing simp lemmas from (x :α list) ≠ ([] :α list) ⇔ (0 :num) < LENGTH x
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(y :α) (x :α) (l :α list). MEM y (SNOC x l) ⇔ y = x ∨ MEM y l.
Removing simp lemmas from ∀(y :α) (x :α) (l :α list). MEM y (SNOC x l) ⇔ y = x ∨ MEM y l
Goal Proved in 1 steps
Updating parameters ... 
Initialization done. Main goal is:
SUM_SET (∅ :num -> bool) = (0 :num).
Removing simp lemmas from SUM_SET (∅ :num -> bool) = (0 :num)
Goal Proved in 1 steps
Updating parameters ... 
Initialization done. Main goal is:
MAP (λ(x :α). x) (l :α list) = l ∧ MAP (I :α -> α) l = l.
Removing simp lemmas from MAP (λ(x :α). x) (l :α list) = l ∧ MAP (I :α -> α) l = l
Goal Proved in 4 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(s :α -> bool) (t :α -> bool). s ⊆ t ∧ countable t ⇒ countable s.
Removing simp lemm

same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(R :α -> α -> bool). transitive R ⇒ R⁺ = R.
Removing simp lemmas from ∀(R :α -> α -> bool). transitive R ⇒ R⁺ = R
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
(OPTION_GUARD (b :bool) = SOME () ⇔ b) ∧ (OPTION_GUARD b = (NONE :unit option) ⇔ ¬b).
Removing simp lemmas from (OPTION_GUARD (b :bool) = SOME () ⇔ b) ∧ (OPTION_GUARD b = (NONE :unit option) ⇔ ¬b)
Encoder error 'OPTION_GUARD'
Initialization done. Main goal is:
∀(f :α -> β) (s :β -> bool). PREIMAGE f s = s ∘ f.
Removing simp lemmas from ∀(f :α -> β) (s :β -> bool). PREIMAGE f s = s ∘ f
same action
same action
Goal Proved in 43 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(l1 :α list) (l2 :α list) (P :α -> bool) (P' :α -> bool). l1 = l2 ∧ (∀(x :α). MEM x l2 ⇒ (P x ⇔ P' x)) ⇒ (EXISTS P l1 ⇔ EXISTS P' l2).
Removing simp lemmas from ∀(l1 :α list) (l2 :α list) (P :

Initialization done. Main goal is:
∀(P :α -> bool) (Q :β -> bool) (x :α). (x INSERT P) × Q = {x} × Q ∪ P × Q.
Removing simp lemmas from ∀(P :α -> bool) (Q :β -> bool) (x :α). (x INSERT P) × Q = {x} × Q ∪ P × Q
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
(n :num) = LENGTH (l1 :α list) ⇒ ZIP (l1,COUNT_LIST n) = GENLIST (λ(n :num). (EL n l1,n)) (LENGTH l1).
Removing simp lemmas from (n :num) = LENGTH (l1 :α list) ⇒ ZIP (l1,COUNT_LIST n) = GENLIST (λ(n :num). (EL n l1,n)) (LENGTH l1)
Failed.
Updating parameters ... 
Initialization done. Main goal is:
(∀(x :α). MEM x (ls :α list) ⇒ (R :α -> α -> bool) x x) ⇒ LIST_REL R ls ls.
Removing simp lemmas from (∀(x :α). MEM x (ls :α list) ⇒ (R :α -> α -> bool) x x) ⇒ LIST_REL R ls ls
same action
Goal Proved in 45 steps
Updating parameters ... 
Initialization done. Main goal is:
LAST ((h :α)::(t :α list)) = if t = ([] :α list) then h else LAST t.
Removing simp lemmas from LAST ((h :α)::(t :α list)) = if t = ([] :α 

Goal Proved in 29 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(x :α -> bool) (y :β -> bool). FUNSET x y = DFUNSET x (K y :α -> β -> bool).
Removing simp lemmas from ∀(x :α -> bool) (y :β -> bool). FUNSET x y = DFUNSET x (K y :α -> β -> bool)
Goal Proved in 7 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(A :bool). ¬A ∧ A ⇔ F.
Removing simp lemmas from ∀(A :bool). ¬A ∧ A ⇔ F
Goal Proved in 1 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(x :α) (y :β). (INL x :α + β) ≠ (INR y :α + β).
Removing simp lemmas from ∀(x :α) (y :β). (INL x :α + β) ≠ (INR y :α + β)
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(ls :α list) (n :num). DROP n ls = ([] :α list) ⇔ LENGTH ls ≤ n.
Removing simp lemmas from ∀(ls :α list) (n :num). DROP n ls = ([] :α list) ⇔ LENGTH ls ≤ n
Goal Proved in 4 steps
Updating parameters ... 
Initialization done. Main goal is:
DIV2 (BIT1 (x :num)) = x.
Removing simp lemmas fro

same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
(∀(a :α list) (c :α list). (d :α list) ≠ a ++ [(b :α)] ++ c) ⇔ ¬MEM b d.
Removing simp lemmas from (∀(a :α list) (c :α list). (d :α list) ≠ a ++ [(b :α)] ++ c) ⇔ ¬MEM b d
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(a :α list) (b :α list) (c :α list). a ++ b ≼ a ++ c ⇔ b ≼ c.
Removing simp lemmas from ∀(a :α list) (b :α list) (c :α list). a ++ b ≼ a ++ c ⇔ b ≼ c
Goal Proved in 1 steps
Updating parameters ... 
Initialization done. Main goal is:
countable 𝕌(:α # β) ⇔ countable 𝕌(:α) ∧ countable 𝕌(:β).
Removing simp lemmas from countable 𝕌(:α # β) ⇔ countable 𝕌(:α) ∧ countable 𝕌(:β)
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(l1 :α list) (l2 :α list) (e :α). FRONT (l1 ++ e::l2) = l1 ++ FRONT (e::l2).
Removing simp lemmas from ∀(l1 :α list) (l2 :α list) (e :α). FRONT (l1 ++ e::l2) = l1 ++ FRONT (e::l2)

Goal Proved in 1 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(s :α -> bool). ¬(s ⊂ s).
Removing simp lemmas from ∀(s :α -> bool). ¬(s ⊂ s)
Goal Proved in 2 steps
Updating parameters ... 
Initialization done. Main goal is:
∀(l1 :α list) (n :num). LENGTH l1 ≤ n ⇒ ∀(l2 :α list). EL n (l1 ++ l2) = EL (n − LENGTH l1) l2.
Removing simp lemmas from ∀(l1 :α list) (n :num). LENGTH l1 ≤ n ⇒ ∀(l2 :α list). EL n (l1 ++ l2) = EL (n − LENGTH l1) l2
Goal Proved in 40 steps
Updating parameters ... 
Initialization done. Main goal is:
LINV_OPT (f :α -> β) (s :α -> bool) (y :β) = SOME (x :α) ⇒ x ∈ s ∧ f x = y.
Removing simp lemmas from LINV_OPT (f :α -> β) (s :α -> bool) (y :β) = SOME (x :α) ⇒ x ∈ s ∧ f x = y
same action
same action
same action
same action
same action
same action
same action
Failed.
Updating parameters ... 
Initialization done. Main goal is:
∀(P :α -> bool) (l :α list). EVERY P (REVERSE l) ⇔ EVERY P l.
Removing simp lemmas from ∀(P :α -> bool) (l :α list). EVERY P 

RuntimeError: index_select(): Expected dtype int32 or int64 for index

In [3]:
ret[0]

NameError: name 'ret' is not defined

In [136]:
e = HolEnv("T")

Importing theories...
Loading modules...
Configuration done.


In [137]:
e.goal

'T'

In [None]:
#TODO
#fix error handling for step exception 
#Add logging, saving of logs, experiment metadata (date, agent info, database used etc.)
#Add validation logic
#Run on paper dataset
#Add replays