In [None]:
import os
import pickle

data_folder = '/home/mcwave/code/automath/atp/datasets/provability/mathlib4_states_w_proof/'
file_names = os.listdir(data_folder)

data = []

count = 0
for file_name in file_names:
    if not file_name.endswith("pkl"):
        continue
    if not 'Algebra' in file_name:
        continue
    count += 1
    if count <= 1:
        continue
    print("Loading", file_name)
    file_path = os.path.join(data_folder, file_name)
    fin = open(file_path, 'rb')
    while True:
        try:
            pair = pickle.load(fin)
            data.append(pair) #(pair[1][0], pair[1][2][0]))
        except:
            break
    break

print(len(data), "examples loaded")

In [2]:
data[1]

(('https://github.com/leanprover-community/mathlib4',
  '27c6744e1c0e25d676be5eb252cd4b6d30c6acc7',
  'Mathlib/Algebra/Ring/Subring/Pointwise.lean',
  'Subring.pointwise_smul_def'),
 ('case h\nM : Type u_1\nR : Type u_2\ninst✝² : Monoid M\ninst✝¹ : Ring R\ninst✝ : MulSemiringAction M R\na : M\nS : Subring R\nnvar0 : R\n⊢ nvar0 ∈ SMul.smul a S ↔ nvar0 ∈ map (MulSemiringAction.toRingHom M R a) S',
  2,
  ['simp_rw [mem_map]', 'exact']))

In [3]:
from utils.lean_math_utils import *
from utils.lean_theorem_utils import *

def count_lines(string):
    # Split the string into lines
    lines = string.splitlines()
    # Count the number of lines
    return len(lines)

def extract_first_case(state_pp):
    state_pp = state_pp.strip()
    if not state_pp.startswith('case'):
        return state_pp
    lines = state_pp.split('\n')
    first_case = []
    for line in lines[1:]:
        if line.strip().startswith('case'):
            break
        if line.strip() != '':
            first_case.append(line)
    return '\n'.join(first_case)


# Params:
#   hyp: tuple(name, type)
#   tactics: list(tactic)
def is_hypothesis_useful(hyp, tactics):
    for tactic in tactics:
        tokens = tokenize_lean_tactic(tactic)
        if hyp[0] in tokens:
            idx = tokens.index(hyp[0])
            if idx > 0:
                if hyp[0].startswith('h'):
                    return True
                if tokens[idx - 1] == 'exact':
                    return True
                if tokens[idx - 1] == 'at':
                    return True
                if tokens[idx - 1] == '[':
                    return True
                if idx < len(tokens) - 1 and tokens[idx + 1] == ']':
                    return True
                for operator in TargetNode.operators:
                    if operator in hyp[1]:
                        return True
    return False

def create_hypothesis_predict_data(raw_state_pp, tactics, theorem_name):
    is_case = raw_state_pp.strip().startswith('case')
    state_pp = extract_first_case(raw_state_pp)
    if is_case and count_lines(state_pp) < count_lines(raw_state_pp) - 2:
        tactics = tactics[0:1]
    #
    premise = Premise()
    premise.theorem_name = theorem_name
    premise.parse_state(state_pp)
    #
    useful_hypotheses, useless_hypotheses = [], OrderedDict()
    for hyp in premise.hypotheses.items():
        useful = is_hypothesis_useful(hyp, tactics)
        if useful:
            #print("YES:", hyp)
            useful_hypotheses.append(hyp)
        else:
            #print("NO :", hyp)
            useless_hypotheses[hyp[0]] = hyp[1]
    premise.hypotheses = useless_hypotheses
    return premise, useful_hypotheses

idx = 120

state_pp = data[idx][1][0]
tactics = data[idx][1][2]
theorem_name = data[idx][0][3]

print("STATE_PP:\n" + state_pp)

print("TACTICS:\n" + "\n".join(tactics))

premise, useful_hypotheses = create_hypothesis_predict_data(state_pp, tactics, theorem_name)

print("STATE_PP:\n" + premise.to_theorem_code())
print("\nHYPOTHESES:\n", useful_hypotheses)


STATE_PP:
case h
M : Type u_1
R : Type u_2
inst✝² : Monoid M
inst✝¹ : Ring R
inst✝ : MulSemiringAction M R
a : M
S : Subring R
nvar0 : R
⊢ nvar0 ∈ map (MulSemiringAction.toRingHom M R a) S ↔ nvar0 ∈ a • S
TACTICS:
symm
simp [eq_comm (a := a)]
cases S
rw [smul_neg]
rw [← eq_f₀']
STATE_PP:
theorem Subring.pointwise_smul_def (M: Type u_1) (R: Type u_2) (inst✝²: Monoid M) (inst✝¹: Ring R) (inst✝: MulSemiringAction M R) (a: M) (S: Subring R) (nvar0: R) : nvar0 ∈ map (MulSemiringAction.toRingHom M R a) S ↔ nvar0 ∈ a • S :=

HYPOTHESES:
 []


In [None]:
import os
import pickle

data_folder = '/home/mcwave/code/automath/atp/datasets/provability/mathlib4_states_w_proof/'
file_names = os.listdir(data_folder)

fout = open('/home/mcwave/code/axiomatization/datasets/mathlib4_all_states_w_proof_hyp_pred.pkl', 'wb')

count = 0
for file_name in file_names:
    if not file_name.endswith("pkl"):
        continue
    count += 1
    data = []
    print("Loading", file_name)
    file_path = os.path.join(data_folder, file_name)
    fin = open(file_path, 'rb')
    while True:
        try:
            pair = pickle.load(fin)
            data.append(pair) #(pair[1][0], pair[1][2][0]))
        except:
            break
    #
    fin.close()
    print(len(data), "examples loaded")
    hyp_data = []
    for pair in data:
        state_pp = pair[1][0]
        tactics = pair[1][2]
        if tactics is None or len(tactics) == 0:
            continue
        full_path = pair[0][2]
        theorem_name = pair[0][3]
        premise, useful_hypotheses = create_hypothesis_predict_data(state_pp, tactics, theorem_name)
        if len(useful_hypotheses) == 0:
            continue
        premise.full_path = full_path
        hyp_data.append((premise, useful_hypotheses))
        pickle.dump((premise, useful_hypotheses), fout)
    #
    fout.flush()
    print(len(hyp_data), "hypotheses data found")

fout.close()

Loading Mathlib__CategoryTheory__Sites__SheafOfTypes.lean.pkl
3083 examples loaded
995 hypotheses data found
Loading Mathlib__Data__Fin__Basic.lean.pkl
368494 examples loaded
58243 hypotheses data found
Loading Mathlib__CategoryTheory__Adjunction__Limits.lean.pkl
44 examples loaded
0 hypotheses data found
Loading Mathlib__Algebra__Algebra__NonUnitalHom.lean.pkl
260187 examples loaded
42581 hypotheses data found
Loading Mathlib__Topology__ContinuousOn.lean.pkl
144094 examples loaded
25240 hypotheses data found
Loading Mathlib__Topology__DiscreteSubset.lean.pkl
1444 examples loaded
169 hypotheses data found
Loading .lake__packages__aesop__Aesop__BuiltinRules.lean.pkl
17291 examples loaded
7694 hypotheses data found
Loading Mathlib__Data__Finsupp__Lex.lean.pkl
10298 examples loaded
545 hypotheses data found
Loading Mathlib__Algebra__Algebra__Subalgebra__Pointwise.lean.pkl
17029 examples loaded
1398 hypotheses data found
Loading Mathlib__Algebra__Homology__ShortComplex__Basic.lean.pkl
3736

4724 hypotheses data found
Loading Mathlib__CategoryTheory__Preadditive__InjectiveResolution.lean.pkl
9965 examples loaded
0 hypotheses data found
Loading Mathlib__Algebra__Order__Hom__Basic.lean.pkl
6183 examples loaded
455 hypotheses data found
Loading Mathlib__Topology__Connected__LocallyConnected.lean.pkl
136 examples loaded
19 hypotheses data found
Loading Mathlib__GroupTheory__Frattini.lean.pkl
58 examples loaded
25 hypotheses data found
Loading Mathlib__Data__NNRat__Lemmas.lean.pkl
2402 examples loaded
50 hypotheses data found
Loading Mathlib__Tactic__PushNeg.lean.pkl
100567 examples loaded
4940 hypotheses data found
Loading Mathlib__Analysis__Calculus__ContDiff__FiniteDimension.lean.pkl
60 examples loaded
15 hypotheses data found
Loading Mathlib__Algebra__Homology__HomologicalComplexBiprod.lean.pkl
76 examples loaded
0 hypotheses data found
Loading Mathlib__Algebra__Lie__NonUnitalNonAssocAlgebra.lean.pkl
172 examples loaded
34 hypotheses data found
Loading Mathlib__Topology__Al

33175 examples loaded
9850 hypotheses data found
Loading Mathlib__Algebra__CharZero__Defs.lean.pkl
23893 examples loaded
655 hypotheses data found
Loading Mathlib__LinearAlgebra__Charpoly__Basic.lean.pkl
19265 examples loaded
1673 hypotheses data found
Loading Mathlib__Data__Nat__GCD__BigOperators.lean.pkl
1092 examples loaded
106 hypotheses data found
Loading Mathlib__Algebra__Order__UpperLower.lean.pkl
44335 examples loaded
2258 hypotheses data found
Loading Mathlib__RingTheory__Adjoin__PowerBasis.lean.pkl
58 examples loaded
26 hypotheses data found
Loading Mathlib__Analysis__Convex__StrictConvexSpace.lean.pkl
264 examples loaded
125 hypotheses data found
Loading Mathlib__Topology__Order__Lattice.lean.pkl
1210 examples loaded
463 hypotheses data found
Loading Mathlib__Data__Rat__Floor.lean.pkl
1918 examples loaded
162 hypotheses data found
Loading Mathlib__CategoryTheory__Preadditive__Basic.lean.pkl
50077 examples loaded
22455 hypotheses data found
Loading Mathlib__GroupTheory__Order

132 hypotheses data found
Loading Mathlib__RingTheory__Derivation__Lie.lean.pkl
8390 examples loaded
296 hypotheses data found
Loading Mathlib__Algebra__Homology__ShortComplex__ModuleCat.lean.pkl
5982 examples loaded
110 hypotheses data found
Loading Mathlib__Algebra__Polynomial__Derivative.lean.pkl
220107 examples loaded
39460 hypotheses data found
Loading Mathlib__Data__Nat__Cast__Order.lean.pkl
52266 examples loaded
3107 hypotheses data found
Loading Mathlib__CategoryTheory__Sites__Pretopology.lean.pkl
1810 examples loaded
232 hypotheses data found
Loading Mathlib__Data__Finsupp__Indicator.lean.pkl
7914 examples loaded
1900 hypotheses data found
Loading Mathlib__CategoryTheory__Bicategory__LocallyDiscrete.lean.pkl
20203 examples loaded
8548 hypotheses data found
Loading Mathlib__CategoryTheory__GradedObject__Unitor.lean.pkl
10163 examples loaded
8076 hypotheses data found
Loading Mathlib__Order__Max.lean.pkl
48846 examples loaded
8502 hypotheses data found
Loading Mathlib__MeasureTh

135541 examples loaded
23294 hypotheses data found
Loading Mathlib__Algebra__Order__Positive__Ring.lean.pkl
17532 examples loaded
7006 hypotheses data found
Loading Mathlib__LinearAlgebra__PerfectPairing.lean.pkl
53215 examples loaded
12383 hypotheses data found
Loading Mathlib__Data__PFunctor__Multivariate__W.lean.pkl
13251 examples loaded
1095 hypotheses data found
Loading Mathlib__GroupTheory__PushoutI.lean.pkl
26924 examples loaded
11279 hypotheses data found
Loading Mathlib__Algebra__Field__Opposite.lean.pkl
30257 examples loaded
15308 hypotheses data found
Loading Mathlib__Analysis__SpecialFunctions__Complex__Analytic.lean.pkl
18 examples loaded
7 hypotheses data found
Loading Mathlib__Data__Sum__Order.lean.pkl
270797 examples loaded
115274 hypotheses data found
Loading Mathlib__Analysis__Normed__Group__Tannery.lean.pkl
16 examples loaded
8 hypotheses data found
Loading Mathlib__Data__Nat__Factorization__Root.lean.pkl
15418 examples loaded
350 hypotheses data found
Loading Mathli

427890 examples loaded
68296 hypotheses data found
Loading Mathlib__Topology__FiberBundle__Trivialization.lean.pkl
66964 examples loaded
17097 hypotheses data found
Loading Mathlib__LinearAlgebra__LinearIndependent.lean.pkl
51411 examples loaded
13540 hypotheses data found
Loading Mathlib__MeasureTheory__Covering__LiminfLimsup.lean.pkl
98 examples loaded
47 hypotheses data found
Loading Mathlib__Combinatorics__SimpleGraph__Matching.lean.pkl
70 examples loaded
25 hypotheses data found
Loading Mathlib__Combinatorics__SimpleGraph__Regularity__Energy.lean.pkl
3674 examples loaded
374 hypotheses data found
Loading Mathlib__Topology__Algebra__MulAction.lean.pkl
24 examples loaded
12 hypotheses data found
Loading .lake__packages__batteries__Batteries__Data__List__Pairwise.lean.pkl
15481 examples loaded
893 hypotheses data found
Loading Mathlib__Topology__MetricSpace__Algebra.lean.pkl
6 examples loaded
0 hypotheses data found
Loading Mathlib__Analysis__Convolution.lean.pkl
23007 examples loade

753 hypotheses data found
Loading Mathlib__CategoryTheory__Extensive.lean.pkl
116 examples loaded
48 hypotheses data found
Loading Mathlib__Analysis__NormedSpace__HomeomorphBall.lean.pkl
2092 examples loaded
627 hypotheses data found
Loading Mathlib__CategoryTheory__Groupoid__Basic.lean.pkl
8 examples loaded
1 hypotheses data found
Loading Mathlib__Algebra__BigOperators__Intervals.lean.pkl
15074 examples loaded
3388 hypotheses data found
Loading Mathlib__Algebra__CharP__Algebra.lean.pkl
1790 examples loaded
544 hypotheses data found
Loading Mathlib__Order__BoundedOrder.lean.pkl
332806 examples loaded
78333 hypotheses data found
Loading Mathlib__Probability__Martingale__Convergence.lean.pkl
114 examples loaded
55 hypotheses data found
Loading Mathlib__Order__WithBot.lean.pkl
868691 examples loaded
101693 hypotheses data found
Loading Mathlib__CategoryTheory__Filtered__Basic.lean.pkl
54 examples loaded
15 hypotheses data found
Loading Mathlib__Analysis__Calculus__LineDeriv__Basic.lean.pk

713 hypotheses data found
Loading Mathlib__Algebra__Tropical__BigOperators.lean.pkl
10687 examples loaded
1252 hypotheses data found
Loading Mathlib__Algebra__Category__GroupCat__Limits.lean.pkl
3453 examples loaded
1659 hypotheses data found
Loading Mathlib__Probability__Process__Adapted.lean.pkl
970 examples loaded
20 hypotheses data found
Loading Mathlib__LinearAlgebra__QuadraticForm__QuadraticModuleCat.lean.pkl
51399 examples loaded
8128 hypotheses data found
Loading Mathlib__Data__Nat__Nth.lean.pkl
578 examples loaded
193 hypotheses data found
Loading Mathlib__MeasureTheory__Function__AEEqFun.lean.pkl
70658 examples loaded
27834 hypotheses data found
Loading Mathlib__Algebra__Homology__HomologicalComplex.lean.pkl
154296 examples loaded
52281 hypotheses data found
Loading Mathlib__Data__Nat__Cast__Field.lean.pkl
11715 examples loaded
4723 hypotheses data found
Loading Mathlib__Logic__Function__Conjugate.lean.pkl
70448 examples loaded
27904 hypotheses data found
Loading Mathlib__Alg

485 hypotheses data found
Loading Mathlib__MeasureTheory__Integral__MeanInequalities.lean.pkl
864 examples loaded
367 hypotheses data found
Loading Mathlib__Data__Real__Hyperreal.lean.pkl
200989 examples loaded
20884 hypotheses data found
Loading Mathlib__SetTheory__Game__PGame.lean.pkl
365313 examples loaded
82927 hypotheses data found
Loading Mathlib__Analysis__Convex__Exposed.lean.pkl
10427 examples loaded
1891 hypotheses data found
Loading Mathlib__Algebra__Module__LinearMap__End.lean.pkl
90860 examples loaded
7138 hypotheses data found
Loading Mathlib__Data__Sign.lean.pkl
168927 examples loaded
32837 hypotheses data found
Loading Mathlib__RingTheory__PowerSeries__Derivative.lean.pkl
5390 examples loaded
25 hypotheses data found
Loading Mathlib__NumberTheory__Padics__RingHoms.lean.pkl
610 examples loaded
191 hypotheses data found
Loading Mathlib__Probability__ProbabilityMassFunction__Constructions.lean.pkl
131810 examples loaded
39645 hypotheses data found
Loading Mathlib__Analysis

577124 examples loaded
33206 hypotheses data found
Loading Mathlib__RingTheory__Nakayama.lean.pkl
86 examples loaded
43 hypotheses data found
Loading Mathlib__Analysis__BoxIntegral__Box__SubboxInduction.lean.pkl
5048 examples loaded
35 hypotheses data found
Loading Mathlib__Data__List__Sections.lean.pkl
12 examples loaded
2 hypotheses data found
Loading Mathlib__Order__Disjoint.lean.pkl
197090 examples loaded
37925 hypotheses data found
Loading Mathlib__CategoryTheory__ChosenFiniteProducts.lean.pkl
3972 examples loaded
2016 hypotheses data found
Loading Mathlib__NumberTheory__Cyclotomic__PID.lean.pkl
24 examples loaded
0 hypotheses data found
Loading Mathlib__Algebra__Group__Opposite.lean.pkl
109132 examples loaded
30462 hypotheses data found
Loading Mathlib__Topology__Algebra__Equicontinuity.lean.pkl
18 examples loaded
9 hypotheses data found
Loading Mathlib__Algebra__MvPolynomial__Expand.lean.pkl
22312 examples loaded
2021 hypotheses data found
Loading Mathlib__Algebra__Module__Submo

788 hypotheses data found
Loading Mathlib__Order__Filter__Curry.lean.pkl
11854 examples loaded
2508 hypotheses data found
Loading Mathlib__CategoryTheory__Limits__Shapes__Multiequalizer.lean.pkl
219098 examples loaded
103029 hypotheses data found
Loading Mathlib__Data__NNRat__BigOperators.lean.pkl
15011 examples loaded
1591 hypotheses data found
Loading Mathlib__Topology__UniformSpace__AbstractCompletion.lean.pkl
9246 examples loaded
1776 hypotheses data found
Loading Mathlib__Order__SupClosed.lean.pkl
179059 examples loaded
44235 hypotheses data found
Loading Mathlib__CategoryTheory__Sites__Types.lean.pkl
17948 examples loaded
14297 hypotheses data found
Loading Mathlib__Analysis__InnerProductSpace__Orientation.lean.pkl
2698 examples loaded
838 hypotheses data found
Loading Mathlib__Algebra__Order__Interval__Set__Instances.lean.pkl
171968 examples loaded
21629 hypotheses data found
Loading Mathlib__Geometry__Manifold__LocalInvariantProperties.lean.pkl
4784 examples loaded
494 hypothes

6694 examples loaded
35 hypotheses data found
Loading Mathlib__RingTheory__Ideal__IsPrincipal.lean.pkl
17037 examples loaded
2323 hypotheses data found
Loading Mathlib__FieldTheory__Finiteness.lean.pkl
2994 examples loaded
30 hypotheses data found
Loading Mathlib__NumberTheory__NumberField__CanonicalEmbedding__ConvexBody.lean.pkl
11539 examples loaded
1884 hypotheses data found
Loading Mathlib__CategoryTheory__Subobject__FactorThru.lean.pkl
12978 examples loaded
2544 hypotheses data found
Loading Mathlib__NumberTheory__ArithmeticFunction.lean.pkl
258822 examples loaded
32732 hypotheses data found
Loading Mathlib__NumberTheory__Cyclotomic__Embeddings.lean.pkl
78 examples loaded
37 hypotheses data found
Loading Mathlib__Algebra__Polynomial__Monomial.lean.pkl
76 examples loaded
14 hypotheses data found
Loading Mathlib__Analysis__Convex__GaugeRescale.lean.pkl
11446 examples loaded
2404 hypotheses data found
Loading Mathlib__Algebra__Homology__ConcreteCategory.lean.pkl
18 examples loaded
7 

12351 hypotheses data found
Loading Mathlib__Algebra__Homology__ExactSequence.lean.pkl
2098 examples loaded
670 hypotheses data found
Loading Mathlib__Data__Matrix__Kronecker.lean.pkl
155374 examples loaded
40335 hypotheses data found
Loading Mathlib__LinearAlgebra__CliffordAlgebra__Prod.lean.pkl
6608 examples loaded
803 hypotheses data found
Loading Mathlib__SetTheory__Cardinal__Cofinality.lean.pkl
31554 examples loaded
4370 hypotheses data found
Loading Mathlib__Analysis__Calculus__Deriv__Basic.lean.pkl
137210 examples loaded
30326 hypotheses data found
Loading Mathlib__Data__Fintype__BigOperators.lean.pkl
30841 examples loaded
7610 hypotheses data found
Loading Mathlib__Probability__Process__HittingTime.lean.pkl
1308 examples loaded
104 hypotheses data found
Loading Mathlib__RingTheory__AdicCompletion__Functoriality.lean.pkl
56245 examples loaded
9593 hypotheses data found
Loading Mathlib__LinearAlgebra__Dimension__RankNullity.lean.pkl
814 examples loaded
250 hypotheses data found
L

334 hypotheses data found
Loading Mathlib__Analysis__NormedSpace__Star__Multiplier.lean.pkl
387991 examples loaded
78284 hypotheses data found
Loading Mathlib__Algebra__Ring__AddAut.lean.pkl
8866 examples loaded
1062 hypotheses data found
Loading Mathlib__Algebra__Polynomial__Splits.lean.pkl
9406 examples loaded
2159 hypotheses data found
Loading Mathlib__Order__Monotone__Union.lean.pkl
24 examples loaded
12 hypotheses data found
Loading Mathlib__CategoryTheory__CatCommSq.lean.pkl
30 examples loaded
3 hypotheses data found
Loading Mathlib__Analysis__NormedSpace__Real.lean.pkl
15033 examples loaded
2717 hypotheses data found
Loading Mathlib__Data__PFunctor__Univariate__M.lean.pkl
32150 examples loaded
7296 hypotheses data found
Loading Mathlib__Order__GaloisConnection.lean.pkl
28410 examples loaded
5429 hypotheses data found
Loading Mathlib__Order__SuccPred__CompleteLinearOrder.lean.pkl
238 examples loaded
117 hypotheses data found
Loading Mathlib__Algebra__Order__Group__Basic.lean.pkl


1149 hypotheses data found
Loading Mathlib__CategoryTheory__FullSubcategory.lean.pkl
41627 examples loaded
18228 hypotheses data found
Loading Mathlib__Topology__Homeomorph.lean.pkl
179731 examples loaded
61006 hypotheses data found
Loading Mathlib__Data__Finsupp__Defs.lean.pkl
531382 examples loaded
123635 hypotheses data found
Loading Mathlib__Order__Ideal.lean.pkl
86326 examples loaded
13599 hypotheses data found
Loading Mathlib__CategoryTheory__ConcreteCategory__Basic.lean.pkl
23844 examples loaded
6582 hypotheses data found
Loading Mathlib__Order__Partition__Finpartition.lean.pkl
36237 examples loaded
4069 hypotheses data found
Loading Mathlib__Topology__Category__Profinite__Nobeling.lean.pkl
50224 examples loaded
13793 hypotheses data found
Loading Mathlib__Data__List__FinRange.lean.pkl
2018 examples loaded
296 hypotheses data found
Loading Mathlib__Analysis__SpecialFunctions__Complex__LogBounds.lean.pkl
17355 examples loaded
511 hypotheses data found
Loading Mathlib__Analysis__S

858 hypotheses data found
Loading Mathlib__Data__Rat__Defs.lean.pkl
187779 examples loaded
53555 hypotheses data found
Loading Mathlib__FieldTheory__Normal.lean.pkl
5291 examples loaded
822 hypotheses data found
Loading Mathlib__Order__Category__PartOrd.lean.pkl
17355 examples loaded
2153 hypotheses data found
Loading Mathlib__Algebra__Order__Ring__WithTop.lean.pkl
11241 examples loaded
4509 hypotheses data found
Loading Mathlib__Data__Set__Pairwise__Lattice.lean.pkl
144 examples loaded
61 hypotheses data found
Loading Mathlib__RingTheory__Filtration.lean.pkl
23365 examples loaded
3905 hypotheses data found
Loading Mathlib__Algebra__Lie__UniversalEnveloping.lean.pkl
8794 examples loaded
3434 hypotheses data found
Loading Mathlib__MeasureTheory__Integral__DominatedConvergence.lean.pkl
133 examples loaded
65 hypotheses data found
Loading Mathlib__MeasureTheory__Group__FundamentalDomain.lean.pkl
17711 examples loaded
4191 hypotheses data found
Loading Mathlib__RingTheory__Polynomial__Cycl

In [4]:
# from datasets import Dataset

# state_pps = []
# target_hyps = []
# fin = open('/home/mcwave/code/axiomatization/datasets/mathlib4_states_w_proof_hyp_pred.pkl', 'rb')

# while True:
#     try:
#         premise, hypotheses = pickle.load(fin)
#         state_pp = premise.to_theorem_code()
#         target_hyp = str([x[1] for x in hypotheses])
#         #data.append((state_pp, target_hyp))
#         state_pps.append(state_pp)
#         target_hyps.append(target_hyp)
#     except:
#         break
    
# fin.close()

# ds = Dataset.from_dict({'state_pp': state_pps, 'target_hyp':target_hyps})
# tmp = ds.train_test_split(test_size=0.01)
# raw_train_dataset = tmp['train']
# raw_test_dataset = tmp['test']

In [5]:
raw_test_dataset

Dataset({
    features: ['state_pp', 'target_hyp'],
    num_rows: 30438
})

In [9]:
import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from tqdm import tqdm
import os

tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")

def tokenize_function(example):
    instruction = example['state_pp']
    response = example['target_hyp']

    # Tokenize the instruction and response separately
    instruction_tokens = tokenizer(instruction, padding="max_length", truncation=True, max_length=512)
    response_tokens = tokenizer(response, padding="max_length", truncation=True, max_length=512)

    # Combine the instruction and response into a single input
    input_ids = instruction_tokens['input_ids']

    # The labels are the same as the input_ids for causal language modeling
    labels = response_tokens['input_ids']

    return {'input_ids': input_ids, 'labels': labels}

print("Mapping test dataset ...")
tokenized_test = raw_test_dataset.map(tokenize_function, batched=True, num_proc=1, desc="Tokenizing dataset")
print("Mapping train dataset ...")
tokenized_train = raw_train_dataset.map(tokenize_function, batched=True, num_proc=6, desc="Tokenizing dataset")
print("Done")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Mapping test dataset ...


Tokenizing dataset:   0%|          | 0/30438 [00:00<?, ? examples/s]

Mapping train dataset ...


Tokenizing dataset (num_proc=6):   0%|          | 0/3013264 [00:00<?, ? examples/s]

Done


In [22]:
# def evaluate(model, dataloader, device):
#     model.eval()
#     total_loss = 0
#     with torch.no_grad():
#         for batch in dataloader:
#             input_ids = batch["input_ids"].to(device)
#             attention_mask = batch["attention_mask"].to(device)
#             labels = batch["labels"].to(device)

#             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#             loss = outputs.loss
#             total_loss += loss.item()

#     return total_loss / len(dataloader)

# Initialize tokenizer and model
model_name = "google/byt5-small"  # You can change this to any other Seq2Seq model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Set up optimizer
#optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

In [None]:
from transformers import Trainer, TrainingArguments
from datasets import load_dataset,load_metric
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="datasets/byt5-small",
    evaluation_strategy="steps", #"epochs"
    learning_rate=2e-5,  # PAY ATTENTION TO LEARNING RATE!
    weight_decay=0.01,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    num_train_epochs=4,
    fp16=False,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    push_to_hub=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
)

#cp_path = 'gpt-neo-350m-202310/checkpoint-525000'

trainer.train()



Step,Training Loss,Validation Loss
1000,0.0,
2000,0.0,


In [32]:
# num_epochs = 4
# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0

#     for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         labels = batch["labels"].to(device)

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         total_loss += loss.item()

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
    
#     # Save the model after each epoch
#     save_path = f"/home/mcwave/code/axiomatization/datasets/hyp_pred_byt5_small/trained_model_epoch_{epoch + 1}"
#     os.makedirs(save_path, exist_ok=True)
#     model.save_pretrained(save_path)
    
#     avg_train_loss = total_loss / len(train_dataloader)
#     avg_test_loss = evaluate(model, test_dataloader, device)

#     print(f"Epoch {epoch + 1}/{num_epochs}")
#     print(f"Average Train Loss: {avg_train_loss:.4f}")
#     print(f"Average Test Loss: {avg_test_loss:.4f}")

Epoch 1/4:  36%|███████████████████████████▌                                                 | 106651/298283 [21:38:44<38:53:35,  1.37it/s]


KeyboardInterrupt: 

In [36]:
avg_test_loss = evaluate(model, test_dataloader, device)

In [35]:
# save_path = f"/home/mcwave/code/axiomatization/datasets/hyp_pred_byt5_small/trained_model_epoch_{epoch + 1}"
# print(save_path)
# os.makedirs(save_path, exist_ok=True)
# model.save_pretrained(save_path)

/home/mcwave/code/axiomatization/datasets/hyp_pred_byt5_small/trained_model_epoch_1


In [6]:
save_path = f"/home/mcwave/code/axiomatization/datasets/hyp_pred_byt5_small/trained_model_epoch_1"
model = AutoModelForSeq2SeqLM.from_pretrained(save_path)

In [21]:
test_case = tokenized_test[7]
input_ids = torch.tensor(test_case['input_ids']).unsqueeze(0).to('cuda')  # Add batch dimension
print("inputs:", tokenizer.decode(test_case['input_ids'], skip_special_tokens=True))
labels = torch.tensor(test_case['labels']).to('cuda')
print("labels:", tokenizer.decode(labels, skip_special_tokens=True))

# Generate output
with torch.no_grad():
    outputs = model.generate(input_ids)

# Decode the generated output and the true labels
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
true_text = tokenizer.decode(labels, skip_special_tokens=True)

# Compare the results
print("Generated:", generated_text)
print("True:", true_text)

inputs: theorem Function.Exact.apply_apply_eq_zero (R: Type u_1) (M: Type u_2) (M': Type u_3) (N: Type u_4) (N': Type u_5) (P: Type u_6) (P': Type u_7) (f: M → N) (g: N → P) (inst✝: Zero P) (x: M) : ¬¬¬¬¬Exact f g :=
labels: ['¬¬¬¬¬g (f x) = 0']




Generated: _ (S__u_7)) (G) (R
True: ['¬¬¬¬¬g (f x) = 0']


In [10]:
labels

tensor([94, 42, 74, 42, 96,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 