In [1]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from jax import random
import sys
from time import sleep
import json
import pexpect
import re
import os
import argparse
import logging
import timeit
import torch

import seq2seq
from batch_predictor import BatchPredictor
from checkpoint import Checkpoint

from policy_networks import *
import policy_networks

import utp_model

from new_env import *

from jax.config import config
config.update("jax_debug_nans", True) 
import numpy as np
#jax.config.update("jax_enable_x64", False)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Command line used to launch the local HOL binary (executable path plus heap flag).
# NOTE(review): absolute, machine-specific path; HOLPATH is not referenced in the
# visible cells - confirm whether new_env reads it before making it configurable.
#HOLPATH = "/home/sean/Documents/hol/HOL/bin/hol --maxheap=256"
HOLPATH = "/home/sean/Documents/PhD/HOL4/HOL/bin/hol --maxheap=256"

#tactic_zero_path = "/home/sean/Documents/PhD/git/repo/PhD/tacticzero/holgym/"

In [2]:
# Toggle between the baseline tactic set and the extended one.
MORE_TACTICS = True

# Baseline tactics, grouped by the kind of argument the agent supplies:
# several theorems, a single theorem, a term, or no argument at all.
thms_tactic = ["simp", "fs", "metis_tac"]
thm_tactic = ["irule"]
term_tactic = ["Induct_on"]
no_arg_tactic = ["strip_tac"]

if MORE_TACTICS:
    # The extended set adds one tactic to each group except the term group.
    thms_tactic = thms_tactic + ["rw"]
    thm_tactic = thm_tactic + ["drule"]
    no_arg_tactic = no_arg_tactic + ["EQ_TAC"]

# Flat pool; the tactic policy samples an index into this list.
tactic_pool = [*thms_tactic, *thm_tactic, *term_tactic, *no_arg_tactic]

In [3]:
#TODO Move to another file 

def get_polish(raw_goal):
    """Return the polished (prefix-printed) and plain forms of ``raw_goal``.

    Drives the module-level ``process`` object (not defined in this cell;
    presumably a pexpect child running HOL - confirm): sets the goal, switches
    the term printer, reads back the output of ``top_goals()``, then drops the
    goal and restores the default printer.

    Returns a list of dicts with ``"polished"`` and ``"plain"`` entries, each
    holding ``"assumptions"`` and ``"goal"``. Because the printed goals are
    zipped against a single plain goal, the list has at most one element.
    """
    prompt = "\r\n>"

    def _send_and_wait(command):
        # Send one command and block until the prompt comes back.
        process.sendline(command.encode("utf-8"))
        process.expect(prompt)

    _send_and_wait(construct_goal(raw_goal))
    _send_and_wait("val _ = set_term_printer (HOLPP.add_string o pt);")

    process.sendline("top_goals();".encode("utf-8"))
    process.expect("val it =")
    process.expect([": goal list", ":\r\n +goal list"])

    # Everything printed between "val it =" and the type annotation is the goal list.
    printed = process.before.decode("utf-8")
    printed = re.sub("“|”", "\"", printed)
    printed = re.sub("\r\n +", " ", printed)

    # NOTE(review): eval() runs text produced by the external process as Python.
    # It is only safe while that output is trusted; consider ast.literal_eval.
    parsed_goals = eval(printed)

    process.expect(prompt)
    _send_and_wait("drop();")
    _send_and_wait("val _ = set_term_printer default_pt;")

    plain_goals = [([], raw_goal)]
    return [
        {"polished": {"assumptions": polished[0], "goal": polished[1]},
         "plain": {"assumptions": plain[0], "goal": plain[1]}}
        for polished, plain in zip(parsed_goals, plain_goals)
    ]
    
def construct_goal(goal):
    """Build the ``g `<goal>`;`` command string for ``goal``."""
    pieces = ["g ", "`", goal, "`;"]
    return "".join(pieces)

def gather_encoded_content_(history, encoder):
    """Encode every context of every fringe in ``history``.

    Each history entry contributes its ``"content"`` list. Every context is
    flattened to a single token list via ``revert_with_polish`` and passed to
    ``encoder.encode`` one at a time.

    Returns ``(representations, contexts, fringe_sizes)``: one tensor per
    context, the flat list of contexts in history order, and the number of
    contexts contributed by each history entry.
    """
    contexts = []
    fringe_sizes = []
    for fringe in history:
        content = fringe["content"]
        contexts.extend(content)
        fringe_sizes.append(len(content))

    # Tokenise all contexts first, then encode them.
    token_lists = [revert_with_polish(ctx).strip().split() for ctx in contexts]

    representations = []
    for tokens in token_lists:
        encoded, _ = encoder.encode([tokens])
        # Split along dim 0 and concatenate the pieces along the last dimension.
        representations.append(torch.cat(encoded.split(1), dim=2).squeeze(0))

    return representations, contexts, fringe_sizes

def parse_theory(pg):
    """List the distinct theory names in ``pg``, minus ``EXCLUDED_THEORIES``.

    A theory name is the word between the dollar signs of a ``C$<theory>$ ``
    token. The order of the returned list is unspecified because it comes
    from a set.
    """
    found = set(re.findall(r'C\$(\w+)\$ ', pg))
    found.difference_update(EXCLUDED_THEORIES)
    return list(found)

def revert_with_polish(context):
    """Fold the polished assumptions of ``context`` back into its goal.

    Starting from the polished goal, each assumption (last one first) is
    prefixed as ``@ @ C$min$ ==> <assumption> <goal>``, so the first
    assumption ends up outermost.
    """
    polished = context["polished"]
    assumptions = polished["assumptions"]
    term = polished["goal"]
    implication = "@ @ C$min$ ==> {} {}"
    for assumption in reversed(assumptions):
        term = implication.format(assumption, term)
    return term

def split_by_fringe(goal_set, goal_scores, fringe_sizes):
    """Cut the flat goal and score sequences into consecutive per-fringe chunks.

    ``fringe_sizes[k]`` is the number of consecutive items that belong to
    fringe ``k``. Returns ``(goals_by_fringe, scores_by_fringe)``.
    """
    goals_by_fringe = []
    scores_by_fringe = []
    start = 0
    for size in fringe_sizes:
        stop = start + size
        scores_by_fringe.append(goal_scores[start:stop])
        goals_by_fringe.append(goal_set[start:stop])
        start = stop
    return goals_by_fringe, scores_by_fringe

In [37]:
'''

High level agent class 

'''
class Agent:
    """Base class for agents: stores the tactic pool and defines no-op hooks."""

    def __init__(self, tactic_pool):
        self.tactic_pool = tactic_pool
        # Construction always triggers the encoder hook; subclasses override it.
        self.load_encoder()

    def load_agent(self):
        """Hook for loading agent state; does nothing in the base class."""
        return None

    def load_encoder(self):
        """Hook for loading the goal encoder; does nothing in the base class."""
        return None

    def run(self, env, max_steps):
        """Hook for running an episode in ``env``; does nothing in the base class."""
        return None

    def update_params(self):
        """Hook for a parameter update; does nothing in the base class."""
        return None
    

    
'''

Vanilla Torch implementation of TacticZero

'''
class TorchVanilla(Agent):
    """PyTorch policy-gradient agent with separate context/tactic/argument/term policies.

    Each of the four networks has its own RMSprop optimizer. After every
    episode the summed loss ``-log_prob * discounted_reward`` over all steps is
    back-propagated once and every optimizer takes one step.

    NOTE(review): ``run`` relies on module-level names ``F``, ``Categorical``
    and ``gather_encoded_content`` that are not defined in the visible cells
    (presumably provided by the star imports - confirm), and on the
    notebook-level ``split_by_fringe``, ``no_arg_tactic`` and ``thm_tactic``.
    """

    def __init__(self, tactic_pool):
        """Create the policy networks and optimizers for ``tactic_pool``."""
        super().__init__(tactic_pool)

        # Number of theorem arguments sampled for tactics taking a theorem list.
        self.ARG_LEN = 5

        # Per-network learning rates.
        self.context_rate = 5e-5
        self.tac_rate = 5e-5
        self.arg_rate = 5e-5
        self.term_rate = 5e-5

        # Discount factor used in update_params.
        self.gamma = 0.99 # 0.9

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # 256 is the second constructor argument of ArgPolicy/TermPolicy
        # (presumably an embedding/hidden size - confirm in utp_model).
        self.context_net = utp_model.ContextPolicy().to(self.device)
        self.tac_net = utp_model.TacPolicy(len(tactic_pool)).to(self.device)
        self.arg_net = utp_model.ArgPolicy(len(tactic_pool), 256).to(self.device)
        self.term_net = utp_model.TermPolicy(len(tactic_pool), 256).to(self.device)

        self.optimizer_context = torch.optim.RMSprop(list(self.context_net.parameters()), lr=self.context_rate)
        self.optimizer_tac = torch.optim.RMSprop(list(self.tac_net.parameters()), lr=self.tac_rate)
        self.optimizer_arg = torch.optim.RMSprop(list(self.arg_net.parameters()), lr=self.arg_rate)
        self.optimizer_term = torch.optim.RMSprop(list(self.term_net.parameters()), lr=self.term_rate)

    def load_encoder(self):
        """Load the seq2seq checkpoint and wrap it in a ``BatchPredictor`` as ``self.encoder``."""
        # 97-98% accuracy model, up to and including probability theory
        checkpoint_path = "models/2021_02_22_16_07_03"

        checkpoint = Checkpoint.load(checkpoint_path)
        # Local name chosen so it does not shadow the imported ``seq2seq`` module.
        model = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab

        self.encoder = BatchPredictor(model, input_vocab, output_vocab)
        return

    def _build_arg_candidates(self, encoded_fact_pool, allowed_arguments_ids, hidden0, hidden1):
        # One row per allowed fact: the fact encoding followed by the current
        # hidden and cell states (both are appended, as in the original code).
        hc = torch.cat([hidden0.squeeze(), hidden1.squeeze()])
        hiddenl = torch.cat([hc.unsqueeze(0) for _ in allowed_arguments_ids])
        candidates = torch.cat([encoded_fact_pool, hiddenl], dim=1)
        return candidates.to(self.device)

    def run(self, env, encoded_fact_pool, allowed_arguments_ids, candidate_args, max_steps=50):
        """Play one episode of at most ``max_steps`` steps and update the policies.

        Every step samples a fringe, a tactic and (depending on the tactic) a
        term or theorem arguments, then applies the assembled tactic to the
        first context of the chosen fringe via ``env.step``.

        Args:
            env: environment exposing ``history``, ``step`` and ``assemble_tactic``.
            encoded_fact_pool: tensor indexed by sampled argument position and
                concatenated row-wise with the hidden state.
            allowed_arguments_ids: one entry per row of ``encoded_fact_pool``;
                only its length is used here.
            candidate_args: theorem names aligned with ``encoded_fact_pool``.
            max_steps: step budget for the episode.

        Returns:
            ``(trace, steps, done)`` for a completed episode, or a 2-tuple
            ``(error_label, details)`` if encoding, context selection, the term
            policy or ``env.step`` fails. No parameter update happens on the
            error paths. ``trace`` holds, per step, the action followed by
            ``(reward, action)``.
        """
        fringe_pool = []
        tac_pool = []
        arg_pool = []
        reward_pool = []
        trace = []
        steps = 0
        # Defined up front so the final return is valid even if no step is played.
        done = False
        last_step = max_steps - 1

        for t in range(max_steps):

            # gather all the goals in the history
            try:
                representations, context_set, fringe_sizes = gather_encoded_content(env.history, self.encoder)
            except Exception as e:
                print ("Encoder error {}".format(e))
                return ("Encoder error", str(e))

            representations = torch.stack([r.to(self.device) for r in representations])
            context_scores = self.context_net(representations)
            contexts_by_fringe, scores_by_fringe = split_by_fringe(context_set, context_scores, fringe_sizes)

            # A fringe's score is the sum of the scores of its contexts.
            fringe_scores = torch.stack([torch.sum(s) for s in scores_by_fringe])
            fringe_probs = F.softmax(fringe_scores, dim=0)
            fringe_m = Categorical(fringe_probs)
            fringe = fringe_m.sample()
            fringe_pool.append(fringe_m.log_prob(fringe))

            # take the first context in the chosen fringe for now
            try:
                target_context = contexts_by_fringe[fringe][0]
            except Exception as e:
                # Previously this only printed and then continued with an
                # unbound (or stale) target_context; report it like other errors.
                print ("error {} {}".format(contexts_by_fringe, fringe))
                return ("Context error", str(e))

            target_goal = target_context["polished"]["goal"]
            target_representation = representations[context_set.index(target_context)]

            tac_input = target_representation.to(self.device)
            tac_probs = self.tac_net(tac_input)
            tac_m = Categorical(tac_probs)
            tac = tac_m.sample()
            tac_pool.append(tac_m.log_prob(tac))
            tac_tensor = tac.to(self.device)

            # Index the pool the networks were sized for, not the notebook global.
            tactic_name = self.tactic_pool[tac]

            if tactic_name in no_arg_tactic:
                tactic = tactic_name
                # Placeholder so arg_pool stays aligned with the step index.
                arg_pool.append([torch.tensor(0)])

            elif tactic_name == "Induct_on":
                arg_probs = []

                # Unique tokens in order of appearance, keeping those prefixed with "V".
                tokens = target_goal.split()
                tokens = list(dict.fromkeys(tokens))
                tokens = [[tok] for tok in tokens if tok[0] == "V"]
                if tokens:
                    token_representations, _ = self.encoder.encode(tokens)
                    encoded_tokens = torch.cat(token_representations.split(1), dim=2).squeeze(0)

                    target_representations = torch.cat([target_representation for _ in tokens])

                    candidates = torch.cat([encoded_tokens, target_representations], dim=1)
                    candidates = candidates.to(self.device)

                    scores = self.term_net(candidates, tac_tensor)
                    term_probs = F.softmax(scores, dim=0)
                    try:
                        term_m = Categorical(term_probs.squeeze(1))
                    except Exception as e:
                        print("probs: {}".format(term_probs))
                        print("candidates: {}".format(candidates.shape))
                        print("scores: {}".format(scores))
                        print("tokens: {}".format(tokens))
                        # Report instead of calling exit(), which would stop the kernel.
                        return ("Term policy error", str(e))
                    term = term_m.sample()
                    arg_probs.append(term_m.log_prob(term))
                    tm = tokens[term][0][1:] # remove headers, e.g., "V" / "C" / ...
                    arg_pool.append(arg_probs)
                    if tm:
                        tactic = "Induct_on `{}`".format(tm)
                    else:
                        print("tm is empty")
                        print(tokens)
                        # only to raise an error
                        tactic = "Induct_on"
                else:
                    arg_probs.append(torch.tensor(0))
                    arg_pool.append(arg_probs)
                    tactic = "Induct_on"
            else:
                hidden0 = target_representation.to(self.device)
                hidden1 = target_representation.to(self.device)
                hidden = (hidden0, hidden1)

                # concatenate the candidates with hidden states.
                candidates = self._build_arg_candidates(
                    encoded_fact_pool, allowed_arguments_ids, hidden0, hidden1)

                arg_input = tac_tensor

                # run it once before predicting the first argument
                hidden, _ = self.arg_net(arg_input, candidates, hidden)

                # the indices of chosen args
                arg_step = []
                arg_step_probs = []

                # Single-theorem tactics take one argument, the rest take ARG_LEN.
                arg_len = 1 if tactic_name in thm_tactic else self.ARG_LEN

                for _ in range(arg_len):
                    hidden, scores = self.arg_net(arg_input, candidates, hidden)
                    arg_probs = F.softmax(scores, dim=0)
                    arg_m = Categorical(arg_probs.squeeze(1))
                    arg = arg_m.sample()
                    arg_step.append(arg)
                    arg_step_probs.append(arg_m.log_prob(arg))

                    hidden0 = hidden[0].squeeze().repeat(1, 1, 1)
                    hidden1 = hidden[1].squeeze().repeat(1, 1, 1)

                    # encoded chosen argument
                    arg_input = encoded_fact_pool[arg].unsqueeze(0)

                    # renew candidates
                    candidates = self._build_arg_candidates(
                        encoded_fact_pool, allowed_arguments_ids, hidden0, hidden1)

                arg_pool.append(arg_step_probs)

                chosen_args = [candidate_args[j] for j in arg_step]
                tactic = env.assemble_tactic(tactic_name, chosen_args)

            # The 0 selects the first context of the chosen fringe.
            action = (fringe.item(), 0, tactic)

            trace.append(action)

            try:
                reward, done = env.step(action)
            except Exception:
                print("Step exception raised.")
                return ("Step error", action)

            # NOTE(review): the final-step penalty is applied even if this step
            # proved the goal, as in the original code - confirm this is intended.
            if t == last_step:
                reward = -5

            #could add environment state, but would grow rapidly
            trace.append((reward, action))

            reward_pool.append(reward)
            steps += 1

            if done:
                print ("Goal Proved in {} steps".format(t+1))
                break

            if t == last_step:
                print("Failed.")

        # Skip the update when nothing was played (e.g. max_steps <= 0).
        if steps:
            self.update_params(reward_pool, fringe_pool, arg_pool, tac_pool, steps)

        return trace, steps, done

    def update_params(self, reward_pool, fringe_pool, arg_pool, tac_pool, step_count):
        """Apply one policy-gradient update from the pools of a finished episode.

        Args:
            reward_pool: per-step rewards; overwritten in place with the
                discounted values.
            fringe_pool: per-step log-probability of the sampled fringe.
            arg_pool: per-step list of argument log-probabilities.
            tac_pool: per-step log-probability of the sampled tactic.
            step_count: number of steps to use from the pools.
        """
        if step_count <= 0:
            # An empty episode has no loss tensor to back-propagate.
            return

        # Update policy
        # Discount reward
        print("Updating parameters ... ")
        running_add = 0
        for i in reversed(range(step_count)):
            # A zero reward resets the running return and is left as zero.
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add

        optimizers = (
            self.optimizer_context,
            self.optimizer_tac,
            self.optimizer_arg,
            self.optimizer_term,
        )
        for optimizer in optimizers:
            optimizer.zero_grad()

        total_loss = 0

        for i in range(step_count):
            reward = reward_pool[i]

            fringe_loss = -fringe_pool[i] * (reward)
            arg_loss = -torch.sum(torch.stack(arg_pool[i])) * (reward)
            tac_loss = -tac_pool[i] * (reward)

            total_loss += fringe_loss + tac_loss + arg_loss

        total_loss.backward()

        for optimizer in optimizers:
            optimizer.step()

        return


In [38]:
class Experiment:
    """Runs an agent over a list of goals for a number of iterations.

    ``goals`` is a sequence of ``(polished_goal, raw_goal)`` pairs: ``goal[1]``
    is what the environment is reset with and ``goal[0]`` is used both to find
    the theories a goal mentions and as the key of its trace.
    """

    def __init__(self, agent, goals, db_dir, encoded_db_dir, num_iterations):
        """Store the configuration and load the theorem database and its encodings."""
        self.agent = agent
        self.goals = goals
        self.num_iterations = num_iterations
        self.load_db(db_dir)
        self.load_encoded_db(encoded_db_dir)

    def train(self):
        """Run ``num_iterations`` passes over all goals.

        Returns:
            ``(full_trace, env_errors, agent_errors)``. ``full_trace`` has one
            dict per iteration mapping ``goal[0]`` to ``(trace, steps, done)``.
            ``env_errors`` holds ``(goal, error, goal_index)`` for goals whose
            reset or fact-pool generation failed. ``agent_errors`` holds
            ``(result, goal_index)`` for runs that returned an error 2-tuple.
        """
        env = HolEnv("T")

        env_errors = []
        agent_errors = []
        full_trace = []

        for _ in range(self.num_iterations):
            iter_trace = {}
            for i, goal in enumerate(self.goals):
                print ("Goal #{}".format(str(i+1)))

                try:
                    env.reset(goal[1])
                except Exception as e:
                    print ("Restarting environment..")
                    env = HolEnv("T")
                    env_errors.append((goal, e, i))
                    continue

                try:
                    encoded_fact_pool, allowed_arguments_ids, candidate_args = self.gen_fact_pool(env, goal)
                except Exception:
                    env_errors.append((goal, "Error generating fact pool", i))
                    continue

                result = self.agent.run(env,
                                        encoded_fact_pool,
                                        allowed_arguments_ids, candidate_args, max_steps=50)

                #agent run returns (error_msg, details) if error, else 3-tuple
                if len(result) == 2:
                    agent_errors.append((result, i))
                else:
                    trace, steps, done = result
                    iter_trace[goal[0]] = (trace, steps, done)

            full_trace.append(iter_trace)

        # Returned after all iterations; the original returned from inside the
        # loop, so only the first iteration ever ran.
        return full_trace, env_errors, agent_errors

    def load_encoded_db(self, encoded_db_dir):
        """Load the pre-encoded database tensor from ``encoded_db_dir``."""
        # torch.load unpickles the file, so only load files from a trusted source.
        self.encoded_database = torch.load(encoded_db_dir)

    def load_db(self, db_dir):
        """Load the JSON theorem database from ``db_dir``."""
        with open(db_dir) as f:
            self.database = json.load(f)

    def gen_fact_pool(self, env, goal):
        """Select the database facts the agent may use as arguments for ``goal``.

        A fact is allowed when its theory (entry ``[0]``) is mentioned in
        ``goal[0]`` and it either belongs to a different theory than the goal
        or has a smaller index (entry ``[2]``) than the goal itself. The goal's
        theory is also passed to ``env.toggle_simpset("diminish", ...)``.

        Returns:
            ``(encoded_fact_pool, allowed_arguments_ids, candidate_args)``.

        Raises:
            Exception: "Theorem not found in database." when the current
                polished goal is not a database key, or "Index select error
                ..." when the encoded rows cannot be selected.
        """
        allowed_theories = set(re.findall(r'C\$(\w+)\$ ', goal[0]))

        polished_goal = env.fringe["content"][0]["polished"]["goal"]

        try:
            goal_entry = self.database[polished_goal]
        except KeyError as err:
            raise Exception("Theorem not found in database.") from err

        # The goal's theory comes from its database entry. goal[1] is the raw
        # goal used for env.reset, so comparing it with theory names never
        # matched and silently disabled the ordering filter below.
        goal_theory = goal_entry[0]
        goal_index = int(goal_entry[2])

        allowed_arguments_ids = []
        candidate_args = []
        # Enumeration order is used as the row index into the encoded database.
        for i, t in enumerate(self.database):
            entry = self.database[t]
            theory = entry[0]
            if theory in allowed_theories and (theory != goal_theory or int(entry[2]) < goal_index):
                allowed_arguments_ids.append(i)
                candidate_args.append(t)

        env.toggle_simpset("diminish", goal_theory)

        try:
            encoded_fact_pool = torch.index_select(self.encoded_database, 0, torch.tensor(allowed_arguments_ids))
        except Exception as e:
            raise Exception("Index select error {}".format(e))
        return encoded_fact_pool, allowed_arguments_ids, candidate_args

In [39]:
# Instantiate the vanilla Torch agent over the configured tactic pool
# (construction also loads the encoder checkpoint).
test = TorchVanilla(tactic_pool)

In [29]:
# Load the goal dataset from dataset.json, resolved against the working directory.
with open("dataset.json") as fp:
    dataset = json.load(fp)
    

In [23]:
# Build (polished_goal, raw_goal) pairs for every dataset goal the environment
# can polish; goals that fail are printed and skipped.
env = HolEnv("T")
paper_goals = []
for goal in dataset:
    try:
        p_goal = env.get_polish(goal)
        paper_goals.append((p_goal[0]["polished"]['goal'], goal))
    except Exception:
        # Not a bare except, so interrupting the kernel still stops the loop.
        print (goal)

Importing theories...
Loading modules...
Configuration done.
DATATYPE ((sum :(α -> α + β) -> (β -> α + β) -> γ) (INL :α -> α + β) (INR :β -> α + β))
∀(A :α -> bool) (B :β -> bool) (C :α -> bool) (D :β -> bool). A × B ∩ (C × D) = A ∩ C × (B ∩ D)
DATATYPE ((list :α list -> (α -> α list -> α list) -> bool) ([] :α list) (CONS :α -> α list -> α list))


In [40]:
# Single-iteration experiment over paper_goals, using the include_probability
# JSON database and its pre-encoded tensor file.
exp = Experiment(test, paper_goals,"include_probability.json", 'encoded_include_probability.pt', 1)

In [49]:
# Run training; returns the per-iteration traces plus the environment and agent error lists.
full_trace, env_errors, agent_errors = exp.train()

In [17]:
import time

In [None]:
#TODO
#Add logging, saving of logs, experiment metadata (date, agent info, database used etc.)
#Add validation logic
#Add replays