In [1]:
import os
import random
import pickle

def is_pp_useful(pp):
    lines = pp.split('\n')
    targets = []
    conditions = []
    min_freq = 1000000
    for line in lines:
        if line.startswith("⊢"):
            target = line.strip()
            targets.append(target)
        else:
            conditions.append(line.strip())
    if '⊢ False' in targets:
        return False
    #if '⊢ R' in targets:
    #    return False
    if '?m.' in pp:
        return False
    return True

# Given a pkl file which contains the provability data
# Returns (list(state_pp, next_tactic), list(state_pp, steps_to_proven(int or -1)))
def get_states_and_shortest_proofs(pkl_path, verbose=False):
    next_tac_list = []
    provability_list = []
    fin = open(pkl_path, 'rb')
    theorem_states = dict()
    if verbose: print("Loading", pkl_path)
    #
    while True:
        try:
            result = pickle.load(fin)
        except Exception as e:
            print(e)
            break
        file_path, full_name, theorem, state_pair = result
        if state_pair is None:
            continue
        state, proven_state, dist, tactics = state_pair
        state_pp = state.pp
        if full_name not in theorem_states:
            theorem_states[full_name] = dict()
            if verbose: print("Loading theorem", full_name)
        if state_pp not in theorem_states[full_name]:
            theorem_states[full_name][state_pp] = list()
        if tactics is not None and len(tactics) > 0:
            theorem_states[full_name][state_pp].append(tactics)
    #
    if verbose: print("Summarizing data ...")
    for full_name, states in theorem_states.items():
        negative_provability_list = list()
        num_tac_added = 0
        for state_pp, proofs in states.items():
            if not is_pp_useful(state_pp):
                continue
            if len(proofs) == 0:
                negative_provability_list.append((file_path, full_name, state_pp, -1))
            else:
                shortest_tac_len = min([len(str(proof)) for proof in proofs])
                shortest_proofs = [proof for proof in proofs if len(str(proof)) == shortest_tac_len]
                dist = min([len(proof) for proof in shortest_proofs])
                provability_list.append((file_path, full_name, state_pp, dist))
                next_tac_list.append((file_path, full_name, state_pp, random.choice(shortest_proofs)))
                num_tac_added += 1
        negative_provability_list = random.sample(negative_provability_list, 
                                                  min(num_tac_added, len(negative_provability_list)))
        provability_list.extend(negative_provability_list)
    #
    if verbose: print("Done")
    return next_tac_list, provability_list


def get_all_theorems_processed(folder_paths, verbose=True):
    # Enumerate all .pkl files in the folder
    all_proof_list = list()
    all_provability_list = list()
    for folder_path in folder_paths:
        for filename in os.listdir(folder_path):
            if filename.endswith(".pkl"):
                found = False
                file_path = os.path.join(folder_path, filename)
                next_tac_list, provability_list = get_states_and_shortest_proofs(
                    file_path,
                    verbose=True
                )
                all_proof_list.extend(next_tac_list)
                all_provability_list.extend(provability_list)
    return all_proof_list, all_provability_list

previous_output_paths = [
    "/home/mcwave/code/automath/atp/datasets/provability/rag_20240621",
    #"/home/mcwave/code/automath/atp/datasets/provability/rag_20240622",
    #"/home/mcwave/code/automath/atp/datasets/provability/rag_20240623"
]

# all_proof_list, all_provability_list = get_all_theorems_processed(previous_output_paths)

# fout = open('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_state_proof.pkl', 'wb')
# pickle.dump(all_proof_list, fout)
# fout.close()

# fout = open('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_state_provability.pkl', 'wb')
# pickle.dump(all_provability_list, fout)
# fout.close()

Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__Combinatorics__Enumerative__Partition.lean.pkl
Loading theorem Nat.Partition.ofComposition_surj
Loading theorem Nat.Partition.indiscrete_parts
Loading theorem Nat.Partition.partition_one_parts
Loading theorem Nat.Partition.count_ofSums_of_ne_zero
Loading theorem Nat.Partition.count_ofSums_zero
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__SetTheory__Game__State.lean.pkl
Loading theorem SetTheory.PGame.turnBound_ne_zero_of_left_move
Loading theorem SetTheory.PGame.turnBound_ne_zero_of_right_move
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__GroupTheory__CommutingProbability.lean.pkl
Loading theorem commProb_def
Loading theorem commProb_prod
Loading theorem commProb_pi
Loading theorem commProb_function
Loading theorem commProb_eq_zero_of_infinite
Loadi

Loading theorem Complex.ofNat_log
Loading theorem Complex.log_ofReal_re
Loading theorem Complex.log_ofReal_mul
Loading theorem Complex.log_mul_ofReal
Loading theorem Complex.log_mul_eq_add_log_iff
Loading theorem Complex.log_zero
Loading theorem Complex.log_one
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__Topology__Category__Compactum.lean.pkl
Loading theorem Compactum.str_incl
Loading theorem Compactum.str_hom_commute
Loading theorem Compactum.join_distrib
Loading theorem Compactum.isClosed_iff
Loading theorem Compactum.basic_inter
Loading theorem Compactum.cl_cl
Loading theorem Compactum.isClosed_cl
Loading theorem Compactum.str_eq_of_le_nhds
Loading theorem Compactum.le_nhds_of_str_eq
Loading theorem Compactum.lim_eq_str
Loading theorem Compactum.cl_eq_closure
Loading theorem Compactum.continuous_of_hom
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability

Loading theorem Polynomial.splits_of_splits_of_dvd
Loading theorem Polynomial.splits_prod_iff
Loading theorem Polynomial.degree_eq_one_of_irreducible_of_splits
Loading theorem Polynomial.rootOfSplits'_eq_rootOfSplits
Loading theorem Polynomial.degree_eq_card_roots
Loading theorem Polynomial.image_rootSet
Loading theorem Polynomial.adjoin_rootSet_eq_range
Loading theorem Polynomial.eq_prod_roots_of_splits
Loading theorem Polynomial.eq_prod_roots_of_splits_id
Loading theorem Polynomial.eq_prod_roots_of_monic_of_splits_id
Loading theorem Polynomial.eq_X_sub_C_of_splits_of_single_root
Loading theorem Polynomial.mem_lift_of_splits_of_roots_mem_range
Loading theorem Polynomial.splits_of_comp
Loading theorem Polynomial.splits_of_algHom
Loading theorem Polynomial.splits_iff_card_roots
Loading theorem Polynomial.aeval_root_derivative_of_splits
Loading theorem Polynomial.prod_roots_eq_coeff_zero_of_monic_of_split
Loading theorem Polynomial.sum_roots_eq_nextCoeff_of_monic_of_split
Ran out of inpu

Loading theorem Finset.orderIsoOfFin_symm_apply
Loading theorem Finset.orderEmbOfFin_apply
Loading theorem Finset.range_orderEmbOfFin
Loading theorem Finset.orderEmbOfFin_zero
Loading theorem Finset.orderEmbOfFin_last
Loading theorem Finset.orderEmbOfFin_singleton
Loading theorem Finset.orderEmbOfFin_unique
Loading theorem Finset.orderEmbOfFin_eq_orderEmbOfFin_iff
Loading theorem Finset.orderEmbOfCardLe_mem
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__NumberTheory__RamificationInertia.lean.pkl
Loading theorem Ideal.ramificationIdx_spec
Loading theorem Ideal.ramificationIdx_lt
Loading theorem Ideal.ramificationIdx_ne_zero
Loading theorem Ideal.le_pow_of_le_ramificationIdx
Loading theorem Ideal.IsDedekindDomain.ramificationIdx_eq_normalizedFactors_count
Loading theorem Ideal.IsDedekindDomain.ramificationIdx_eq_factors_count
Loading theorem Ideal.IsDedekindDomain.ramificationIdx_ne_zero
Loading theorem Ideal.i

Loading theorem mem_rootsOfUnity_prime_pow_mul_iff'
Loading theorem mem_primitiveRoots
Loading theorem primitiveRoots_zero
Loading theorem isPrimitiveRoot_of_mem_primitiveRoots
Loading theorem IsPrimitiveRoot.mk_of_lt
Loading theorem IsPrimitiveRoot.of_subsingleton
Loading theorem IsPrimitiveRoot.isUnit
Loading theorem IsPrimitiveRoot.pow_inj
Loading theorem IsPrimitiveRoot.one
Loading theorem IsPrimitiveRoot.one_right_iff
Loading theorem IsPrimitiveRoot.coe_submonoidClass_iff
Loading theorem IsPrimitiveRoot.coe_units_iff
Loading theorem IsPrimitiveRoot.pow_of_coprime
Loading theorem IsPrimitiveRoot.pow_iff_coprime
Loading theorem IsPrimitiveRoot.orderOf
Loading theorem IsPrimitiveRoot.iff
Loading theorem IsPrimitiveRoot.pow_mul_pow_lcm
Loading theorem IsPrimitiveRoot.pow_of_dvd
Loading theorem IsPrimitiveRoot.mem_rootsOfUnity
Loading theorem IsPrimitiveRoot.pow
Loading theorem IsPrimitiveRoot.injOn_pow
Loading theorem IsPrimitiveRoot.map_iff_of_injective
Ran out of input
Summarizing d

Loading theorem Set.einfsep_insert_le
Loading theorem Set.le_einfsep_pair
Loading theorem Set.einfsep_pair_le_right
Loading theorem Set.einfsep_eq_iInf
Loading theorem Set.einfsep_of_fintype
Loading theorem Set.Finite.einfsep
Loading theorem Set.Finset.coe_einfsep
Loading theorem Set.Nontrivial.einfsep_exists_of_finite
Loading theorem Set.einfsep_pair
Loading theorem Set.einfsep_insert
Loading theorem Set.einfsep_triple
Loading theorem Set.le_einfsep_pi_of_le
Loading theorem Set.subsingleton_of_einfsep_eq_top
Loading theorem Set.Nontrivial.einfsep_ne_top
Loading theorem Set.Nontrivial.einfsep_lt_top
Loading theorem Set.einfsep_pos_of_finite
Loading theorem Set.relatively_discrete_of_finite
Loading theorem Set.infsep_zero
Loading theorem Set.infsep_pos
Loading theorem Set.nontrivial_of_infsep_pos
Loading theorem Set.infsep_pair_le_toReal_inf
Loading theorem Set.infsep_pair_eq_toReal
Loading theorem Set.Nontrivial.le_infsep_iff
Loading theorem Set.Nontrivial.infsep_lt_iff
Loading theorem

Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__RingTheory__TensorProduct__Basic.lean.pkl
Loading theorem LinearMap.baseChange_tmul
Loading theorem LinearMap.baseChange_eq_ltensor
Loading theorem LinearMap.baseChange_add
Loading theorem LinearMap.baseChange_zero
Loading theorem LinearMap.baseChange_smul
Loading theorem LinearMap.baseChange_id
Loading theorem LinearMap.baseChange_comp
Loading theorem LinearMap.baseChange_one
Loading theorem LinearMap.baseChange_mul
Loading theorem LinearMap.baseChange_sub
Loading theorem LinearMap.baseChange_neg
Loading theorem Algebra.TensorProduct.one_def
Loading theorem Algebra.TensorProduct.natCast_def
Loading theorem Algebra.TensorProduct.natCast_def'
Loading theorem Algebra.TensorProduct.mul_apply
Loading theorem Algebra.TensorProduct.tmul_mul_tmul
Loading theorem Algebra.TensorProduct.one_mul
Loading theorem Algebra.TensorProduct.mul_one
Loading theorem Algebra.TensorPro

Loading theorem Matrix.fromBlocks_neg
Loading theorem Matrix.fromBlocks_zero
Loading theorem Matrix.fromBlocks_add
Loading theorem Matrix.fromBlocks_multiply
Loading theorem Matrix.fromBlocks_mulVec
Loading theorem Matrix.vecMul_fromBlocks
Loading theorem Matrix.toBlock_diagonal_self
Loading theorem Matrix.toBlock_diagonal_disjoint
Loading theorem Matrix.fromBlocks_diagonal
Loading theorem Matrix.toBlocks₁₁_diagonal
Loading theorem Matrix.toBlocks₂₂_diagonal
Loading theorem Matrix.toBlocks₁₂_diagonal
Ran out of input
Summarizing data ...
Done
Loading /home/mcwave/code/automath/atp/datasets/provability/rag_20240621/Mathlib__MeasureTheory__Function__EssSup.lean.pkl
Loading theorem essSup_eq_sInf
Loading theorem essInf_eq_sSup
Loading theorem meas_essSup_lt
Loading theorem meas_lt_essInf
Loading theorem essSup_mono_ae
Loading theorem essSup_le_of_ae_le
Loading theorem le_essInf_of_ae_le
Loading theorem OrderIso.essSup_apply
Loading theorem essSup_mono_measure
Loading theorem essInf_antito

In [1]:
import os
import pickle

fin = open('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_20240623_state_proof.pkl', 'rb')
all_proof_list = pickle.load(fin)
fin.close()

print(len(all_proof_list))

465717


In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config
from datasets import Dataset

MAX_LENGTH = 512  # You can adjust this value as needed

# Prepare the dataset
def prepare_dataset(qa_pairs):
    return Dataset.from_dict({
        "question": [pair[2] for pair in qa_pairs],
        "answer": [pair[3][0] for pair in qa_pairs]
    })

dataset = prepare_dataset(all_proof_list)

# Split the dataset
split_dataset = dataset.train_test_split(test_size=0.05, seed=42)

train_dataset = split_dataset['train']
test_dataset = split_dataset['test']

# Load the model and tokenizer
print("Loading model")
# Load the model and tokenizer
model_name = "google/flan-t5-base"
#tokenizer = T5Tokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")

# Tokenize the datasets
def preprocess_function(examples):
    inputs = tokenizer(examples["question"], padding="max_length", truncation=True, max_length=512)
    outputs = tokenizer(examples["answer"], padding="max_length", truncation=True, max_length=512)
    
    inputs["labels"] = outputs["input_ids"]
    return inputs

tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True)

tokenized_train_dataset.save_to_disk('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_20240623_state_proof_train.dataset')
tokenized_test_dataset.save_to_disk('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_20240623_state_proof_test.dataset')

Loading model


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/442431 [00:00<?, ? examples/s]

Map:   0%|          | 0/23286 [00:00<?, ? examples/s]

Saving the dataset (0/7 shards):   0%|          | 0/442431 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/23286 [00:00<?, ? examples/s]

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config
from datasets import Dataset, load_from_disk

MAX_LENGTH = 512  # You can adjust this value as needed

# Load the model and tokenizer
print("Loading model")
# Load the model and tokenizer
model_name = "google/flan-t5-base"
config = T5Config.from_pretrained(model_name, max_length=MAX_LENGTH)
model = T5ForConditionalGeneration.from_pretrained(model_name, config=config)

tokenized_train_dataset = load_from_disk('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_20240623_state_proof_train.dataset')
tokenized_test_dataset = load_from_disk('/home/mcwave/code/automath/atp/datasets/provability/rag_20240621_20240623_state_proof_test.dataset')

Loading model




In [None]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir="./datasets/rag_20240621_20240623_state_proof",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    evaluation_strategy="steps",
    eval_steps=5000,
    save_strategy="steps",
    save_steps=5000,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
)

# Train the model
trainer.train()



[2024-06-24 10:07:29,634] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/mcwave/anaconda3/envs/atp/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status




Step,Training Loss,Validation Loss
5000,0.0293,0.028081
10000,0.0276,0.024412


Non-default generation parameters: {'max_length': 512}
Non-default generation parameters: {'max_length': 512}


In [None]:
# Evaluate the model
eval_results = trainer.evaluate()
print(f"Final evaluation results: {eval_results}")

# Save the model
model.save_pretrained("./flan-t5-qa-model")
tokenizer.save_pretrained("./flan-t5-qa-model")

# Function to generate answer for a new question
def generate_answer(question):
    input_text = f"question: {question}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    outputs = model.generate(input_ids, max_length=512, num_return_sequences=1)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return answer

# Example usage
new_question = "What is the capital of France?"
generated_answer = generate_answer(new_question)
print(f"Question: {new_question}")
print(f"Generated Answer: {generated_answer}")