In [1]:
from typing import Any
import os
from pathlib import Path
from tactic_gen.tactic_data import TEST_LM_EXAMPLE, example_collator_conf_from_yaml, example_collator_from_conf, ExampleCollator, get_tokenizer
from tactic_gen.train_decoder import get_model
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
import torch
from util.constants import TRAINING_CONF_NAME
import yaml



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Normalize the working directory so the relative paths used later in the
# notebook (e.g. "models/...") resolve regardless of where the kernel started.
# Re-running this cell is safe: once inside "coq-modeling" it does nothing.
_start_dir = Path(os.curdir).resolve()
if _start_dir.name == "tactic_gen":
    # Started from the notebook's own directory: move two levels up.
    os.chdir("../..")
elif _start_dir.name != "coq-modeling":
    # os.curdir is the literal ".", so report the resolved path instead to make
    # the error message actually identify the unexpected directory.
    raise ValueError(f"In an unexpected directory: {_start_dir}")

In [3]:
def get_training_conf(checkpoint_loc: Path) -> Any:
    """Load the YAML training configuration stored beside a checkpoint.

    Args:
        checkpoint_loc: Path to a checkpoint directory. The configuration file
            is looked up under ``TRAINING_CONF_NAME`` in this path's parent.

    Returns:
        Whatever ``yaml.safe_load`` produces for that file.
    """
    conf_path = checkpoint_loc.parent / TRAINING_CONF_NAME
    with conf_path.open("r") as conf_file:
        return yaml.safe_load(conf_file)

In [4]:
def get_example_collator(checkpoint_loc: Path) -> ExampleCollator:
    """Build the example collator described by a checkpoint's training config.

    Args:
        checkpoint_loc: Path to a checkpoint directory; its training
            configuration is read via ``get_training_conf``.

    Returns:
        The collator constructed from the config's ``example_collator`` entry.
    """
    collator_yaml = get_training_conf(checkpoint_loc)["example_collator"]
    return example_collator_from_conf(example_collator_conf_from_yaml(collator_yaml))

In [5]:
# Checkpoint under inspection. Its training config is read from the
# checkpoint's parent directory (see get_training_conf).
CHECKPOINT_LOC = Path("models/deepseek-bm25-proof-tfidf-proj-thm-prem-final/checkpoint-54500")
training_conf = get_training_conf(CHECKPOINT_LOC)
# Note: get_example_collator re-reads the same training config internally.
example_collator = get_example_collator(CHECKPOINT_LOC)
# Tokenizer for the model named in the training config.
# NOTE(review): add_eos=False presumably stops an EOS token from being appended
# to encoded prompts -- confirm against get_tokenizer.
tokenizer = get_tokenizer(training_conf["model_name"], add_eos=False) 

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Load the model from the checkpoint; the path is passed to get_model as a string.
model = get_model(str(CHECKPOINT_LOC))
# Bare None as the final expression, so the cell displays no value.
None

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [7]:
# Decoding settings consumed by the model.generate(...) cell.
# N is passed as num_beams, and only when BEAM is True.
N = 4
# BEAM=False -> sampling (do_sample=True, temperature=1.0);
# BEAM=True  -> beam search with N beams and no temperature.
BEAM = False 

In [8]:
from enum import Enum
from typing import Optional
from tactic_gen.tactic_data import ProofPremiseCollator, NEWLINE_RESPONSE_TEMPLATE 
class TokenMask(Enum):
    """Prompt section that ``transform_attention_mask`` hides from attention."""

    # Each member selects the token span that begins at one collator separator
    # and stops just before another one.
    STATE = 0  # from STATE_SEP up to SCRIPT_SEP
    SCRIPT = 1  # from SCRIPT_SEP up to NEWLINE_RESPONSE_TEMPLATE
    PROOF = 2  # from PROOF_SEP up to STATE_SEP
    PREMISE = 3  # from PREMISE_SEP up to PROOF_SEP


def find_id_start_idx(t: torch.Tensor, s: torch.Tensor) -> Optional[int]:
    """Locate the first occurrence of ``s`` as a contiguous run inside ``t``.

    Args:
        t: Tensor to search, scanned along its first dimension.
        s: Tensor holding the sequence to look for.

    Returns:
        The smallest start index ``i`` such that ``t[i : i + len(s)]`` equals
        ``s`` element-wise, or ``None`` when there is no such index (including
        when ``s`` is longer than ``t``).
    """
    width = s.shape[0]
    # Only start positions that leave room for a full-width window are tried.
    candidates = range(t.shape[0] - width + 1)
    return next(
        (pos for pos in candidates if torch.all(t[pos : pos + width] == s)),
        None,
    )


def transform_attention_mask(
    collator: ExampleCollator,
    tokenizer: PreTrainedTokenizer,
    token_mask: Optional[TokenMask],
    input_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    assert isinstance(collator, ProofPremiseCollator)
    match token_mask:
        case None:
            return attn_mask
        case TokenMask.STATE:
            start_ids = tokenizer.encode(collator.STATE_SEP, add_special_tokens=False)
            end_ids = tokenizer.encode(collator.SCRIPT_SEP, add_special_tokens=False)
        case TokenMask.SCRIPT:
            start_ids = tokenizer.encode(collator.SCRIPT_SEP, add_special_tokens=False)
            end_ids = tokenizer.encode(
                NEWLINE_RESPONSE_TEMPLATE, add_special_tokens=False
            )
        case TokenMask.PROOF:
            start_ids = tokenizer.encode(collator.PROOF_SEP, add_special_tokens=False)
            end_ids = tokenizer.encode(collator.STATE_SEP, add_special_tokens=False)
        case TokenMask.PREMISE:
            start_ids = tokenizer.encode(collator.PREMISE_SEP, add_special_tokens=False)
            end_ids = tokenizer.encode(collator.PROOF_SEP, add_special_tokens=False)

    changed_mask = attn_mask.clone()
    for i, id_row in enumerate(input_ids):
        start_idx = find_id_start_idx(id_row, torch.tensor(start_ids))
        end_idx = find_id_start_idx(id_row, torch.tensor(end_ids))
        changed_mask[i, start_idx:end_idx] = 0
    return changed_mask

In [10]:
# Token ids of the premises separator when it is encoded on its own (no special
# tokens), wrapped in an outer list; the bare name displays them.
premise_seq = [tokenizer.encode("\n[PREMISES]\n", add_special_tokens=False)]
premise_seq

[[185, 58, 11787, 9572, 1871, 50, 60, 185]]

In [25]:
# Build the prompt for the bundled test example and tokenize it to PyTorch tensors.
collated_input = example_collator.collate_input(tokenizer, TEST_LM_EXAMPLE)
inputs = tokenizer(collated_input, return_tensors='pt')
# Zero the attention mask over the premise section of the prompt; the token ids
# themselves are left unchanged.
attention_mask = transform_attention_mask(example_collator, tokenizer, TokenMask.PREMISE, inputs["input_ids"], inputs["attention_mask"])
# Inference only, so gradient tracking is disabled.
# BEAM toggles between sampling (temperature 1.0) and beam search with N beams.
# At most 64 new tokens are generated and 2 sequences are returned; the modified
# mask is passed in place of the tokenizer's own attention mask.
with torch.no_grad():
    out = model.generate(
        inputs["input_ids"], 
        max_new_tokens=64, 
        do_sample=not BEAM, 
        temperature=None if BEAM else 1.0,
        num_beams=N if BEAM else None,
        return_dict_in_generate=True,
        output_scores=True,
        num_return_sequences=2, 
        attention_mask=attention_mask,
    )


Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.




In [23]:
# Shape of the product of the transposed mask with the mask itself.
(attention_mask.T @ attention_mask).shape

torch.Size([504, 504])

In [20]:
# Slice off the first prompt-length columns of every returned sequence, keeping
# only the positions after the prompt.
gen_out = out.sequences[:, inputs["input_ids"].shape[1]:]

In [21]:
# Number of prompt tokens, used to skip the prompt columns when decoding.
input_num_tokens = inputs["input_ids"].shape[1]
# Decode only the positions after the prompt, dropping special tokens.
tokenizer.batch_decode(out.sequences[:, input_num_tokens:], skip_special_tokens=True)


['\n  simpl.', '\n  induction l.']

In [15]:
# Display the transformed mask; zeros mark the span hidden by transform_attention_mask.
attention_mask

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1

In [None]:
from model_deployment.model_wrapper import DecoderLocalWrapper

In [None]:
# Switch to a different checkpoint and load it through the deployment wrapper.
# NOTE: this rebinds CHECKPOINT_LOC, so any earlier cell re-run after this one
# will see this path rather than the one set near the top of the notebook.
CHECKPOINT_LOC = Path("models/deepseek-1.3b-basic/checkpoint-48000")
model_wrapper = DecoderLocalWrapper.from_checkpoint(CHECKPOINT_LOC)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [None]:
# Ask the wrapper for recommendations on the bundled test example.
# NOTE(review): the positional arguments 64, "" and False are not explained in
# this notebook -- check DecoderLocalWrapper.get_recs for their meaning.
result = model_wrapper.get_recs(TEST_LM_EXAMPLE, 64, "", False)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.


Collated:  
[STATE]
x: X
l: list X

rev l ++ [x] = rev (x :: l)
[SCRIPT]
Theorem rev_app : forall x l, rev l ++ [x] = rev (x::l).
Proof.
  intros.
[TACTIC]





In [None]:
# Display the tactic strings stored on the get_recs result.
result.next_tactic_list

['\n  induction l.',
 ' induction l.',
 ' induction l.',
 ' induction l as [|h t IH].',
 ' induction l.',
 "\n  induction l as [ | y0 l' IHl'].",
 '\n  induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 ' induction l as [|h t].',
 '\n  induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l as [|h t].',
 "\n  induction l as [|h t I'].",
 '\n  induction l.',
 ' induction l.',
 '\n  induction l as [|m l IHl].',
 ' induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  rewrite rev_app_split.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l; simpl.',
 ' induction l.',
 ' induction l.',
 ' induction l.',
 ' unfold rev.',
 ' induction l.',
 '\n  revert x.',
 '\n  induction l; auto.',
 ' induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' \n  rewrite <- rev_involutive.',
 ' symmetry.',
 ' simpl.',
 ' induction l; simpl; auto.',
 

In [None]:
# Display the scores stored on the get_recs result.
# NOTE(review): presumably one score per entry of next_tactic_list -- confirm.
result.score_list

[-1.091173529624939,
 -1.0829819440841675,
 -0.9765117168426514,
 -5.801501274108887,
 -1.3765190839767456,
 -13.580381393432617,
 -1.2010478973388672,
 -0.7581014633178711,
 -0.9752275943756104,
 -1.3491300344467163,
 -6.818093299865723,
 -0.9610227346420288,
 -1.306260347366333,
 -0.969602108001709,
 -7.681689739227295,
 -6.0971221923828125,
 -9.352446556091309,
 -0.9166039228439331,
 -0.9867836236953735,
 -10.54260540008545,
 -1.3574259281158447,
 -1.7184215784072876,
 -0.9940738677978516,
 -11.772241592407227,
 -1.07068932056427,
 -1.2064350843429565,
 -1.4584046602249146,
 -1.0888279676437378,
 -3.8191006183624268,
 -1.2318209409713745,
 -1.3284974098205566,
 -1.3770112991333008,
 -4.765250205993652,
 -1.0577272176742554,
 -8.313982009887695,
 -6.794739246368408,
 -1.26614511013031,
 -0.9316457509994507,
 -1.3570245504379272,
 -0.9584465026855469,
 -1.2043745517730713,
 -10.348194122314453,
 -7.571715354919434,
 -10.32215690612793,
 -5.461085796356201,
 -0.9433013796806335,
 -3.61

In [None]:
# Display the token counts stored on the get_recs result.
# NOTE(review): presumably one count per entry of next_tactic_list -- confirm.
result.num_tokens_list

[6,
 4,
 4,
 11,
 4,
 16,
 6,
 6,
 6,
 4,
 9,
 6,
 6,
 6,
 5,
 11,
 12,
 6,
 4,
 14,
 4,
 4,
 6,
 10,
 6,
 6,
 4,
 6,
 8,
 4,
 4,
 4,
 5,
 4,
 7,
 8,
 4,
 6,
 4,
 6,
 6,
 14,
 3,
 4,
 8,
 6,
 3,
 12,
 6,
 10,
 6,
 4,
 9,
 6,
 6,
 4,
 6,
 6,
 3,
 12,
 4,
 3,
 6,
 4]