In [1]:
from typing import Any
import os
from pathlib import Path
from tactic_gen.tactic_data import TEST_LM_EXAMPLE, example_collator_from_conf, ExampleCollator
from tactic_gen.train_decoder import get_model, get_tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
import torch
from util.constants import TRAINING_CONF_NAME
import yaml



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if Path(os.curdir).resolve().name == "tactic_gen":
    os.chdir("../..")
elif Path(os.curdir).resolve().name == "coq-modeling": 
    pass
else:
    raise ValueError(f"In an unexpected directory: {os.curdir}")

In [3]:
def get_training_conf(checkpoint_loc: Path) -> Any:
    training_conf_loc = checkpoint_loc.parent / TRAINING_CONF_NAME
    with training_conf_loc.open('r') as f:
        training_conf = yaml.safe_load(f)
    return training_conf

In [4]:
def get_example_collator(checkpoint_loc: Path) -> ExampleCollator:
    training_conf = get_training_conf(checkpoint_loc)
    example_collator_conf = training_conf['example_collator']
    example_collator = example_collator_from_conf(example_collator_conf) 
    return example_collator

In [5]:
CHECKPOINT_LOC = Path("models/deepseek-1.3b-basic/checkpoint-48000")
training_conf = get_training_conf(CHECKPOINT_LOC)
example_collator = get_example_collator(CHECKPOINT_LOC)
tokenizer = get_tokenizer(training_conf, add_eos=False) 

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
model = get_model(str(CHECKPOINT_LOC))
None

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [7]:
N = 4
BEAM = False 

In [7]:
collated_input = example_collator.collate_input(tokenizer, TEST_LM_EXAMPLE)
inputs = tokenizer(collated_input, return_tensors='pt')
with torch.no_grad():
    out = model.generate(
        inputs["input_ids"], 
        max_new_tokens=64, 
        do_sample=not BEAM, 
        temperature=None if BEAM else 1.0,
        num_beams=(),
        return_dict_in_generate=True,
        output_scores=True,
        num_return_sequences=64, 
    )


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.




In [44]:
gen_out = out.sequences[:, inputs["input_ids"].shape[1]:]

In [55]:
gen_out == -torch.inf 

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [50]:
gen_out

tensor([[22698,   284,    13,  ..., 32021, 32021, 32021],
        [  185,   207, 22698,  ..., 32021, 32021, 32021],
        [  185,   207, 22698,  ..., 32021, 32021, 32021],
        ...,
        [22698,   284,    13,  ..., 32021, 32021, 32021],
        [22698,   284,    13,  ..., 32021, 32021, 32021],
        [22698,   284,    13,  ..., 32021, 32021, 32021]])

In [66]:
scores = model.compute_transition_scores(gen_out, out.scores, normalize_logits=True)

In [67]:
scores

tensor([[-1.0525, -0.0125, -0.2690,  ...,    -inf,    -inf,    -inf],
        [-0.7967, -0.0052, -0.1003,  ...,    -inf,    -inf,    -inf],
        [-0.7946, -0.0052, -0.0786,  ...,    -inf,    -inf,    -inf],
        ...,
        [-1.0528, -0.0119, -0.1985,  ...,    -inf,    -inf,    -inf],
        [-1.0960, -0.0064, -0.1984,  ...,    -inf,    -inf,    -inf],
        [-1.1107, -0.0103, -0.2682,  ...,    -inf,    -inf,    -inf]])

In [68]:
(scores != -torch.inf).sum(axis=1).tolist()
scores.where(scores != -torch.inf, torch.tensor(0.0)).sum(axis=1).tolist()

[-1.3340169191360474,
 -1.1784601211547852,
 -1.1543546915054321,
 -0.9947277307510376,
 -6.49794340133667,
 -0.9908392429351807,
 -8.996271133422852,
 -1.3592931032180786,
 -6.107923984527588,
 -4.414965629577637,
 -6.603355884552002,
 -1.0965176820755005,
 -5.465126991271973,
 -1.0972161293029785,
 -0.9893094301223755,
 -9.866199493408203,
 -1.344957947731018,
 -1.120327353477478,
 -3.602762460708618,
 -1.3293654918670654,
 -3.1278340816497803,
 -9.600914001464844,
 -0.9180083870887756,
 -0.9945132732391357,
 -1.6933228969573975,
 -6.059633255004883,
 -8.38379192352295,
 -1.3060755729675293,
 -7.074195384979248,
 -1.3598614931106567,
 -1.0195387601852417,
 -1.1817296743392944,
 -11.261676788330078,
 -0.9186042547225952,
 -1.170425295829773,
 -1.260474443435669,
 -0.9372342228889465,
 -1.0401631593704224,
 -1.1500436067581177,
 -1.3038089275360107,
 -0.9233178496360779,
 -1.057296633720398,
 -1.0779473781585693,
 -1.2862298488616943,
 -0.8998291492462158,
 -6.451598167419434,
 -9.2837

In [60]:
input_num_tokens = inputs["input_ids"].shape[1]
tokenizer.batch_decode(out.sequences[:, input_num_tokens:], skip_special_tokens=True)


[' induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  generalize dependent x.',
 '\n  induction l.',
 "\n  induction l as [|l' H].",
 ' induction l.',
 ' induction l as [|h t IH].',
 '\n  simpl.',
 '\n  induction l as [| h t IH].',
 '\n  induction l.',
 "\n  induction l as [|y l'].",
 ' induction l.',
 '\n  induction l.',
 "\n  induction l as [|x0 l' IHl'].",
 ' induction l.',
 '\n  induction l.',
 ' simpl.',
 ' induction l.',
 ' simpl.',
 '\n  rewrite (rev_involutive l).',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l as [|h t].',
 "\n  induction l as [ | y l'].",
 ' induction l.',
 '\n  rewrite app_comm.',
 ' induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  simpl.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  inducti

In [5]:
from model_deployment.model_wrapper import DecoderLocalWrapper

In [9]:
CHECKPOINT_LOC = Path("models/deepseek-1.3b-basic/checkpoint-48000")
model_wrapper = DecoderLocalWrapper.from_checkpoint(CHECKPOINT_LOC)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [11]:
result = model_wrapper.get_recs(TEST_LM_EXAMPLE, 64, "", False)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.


Collated:  
[STATE]
x: X
l: list X

rev l ++ [x] = rev (x :: l)
[SCRIPT]
Theorem rev_app : forall x l, rev l ++ [x] = rev (x::l).
Proof.
  intros.
[TACTIC]





In [12]:
result.next_tactic_list

['\n  induction l.',
 ' induction l.',
 ' induction l.',
 ' induction l as [|h t IH].',
 ' induction l.',
 "\n  induction l as [ | y0 l' IHl'].",
 '\n  induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 ' induction l as [|h t].',
 '\n  induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l as [|h t].',
 "\n  induction l as [|h t I'].",
 '\n  induction l.',
 ' induction l.',
 '\n  induction l as [|m l IHl].',
 ' induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  rewrite rev_app_split.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l; simpl.',
 ' induction l.',
 ' induction l.',
 ' induction l.',
 ' unfold rev.',
 ' induction l.',
 '\n  revert x.',
 '\n  induction l; auto.',
 ' induction l.',
 '\n  induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l.',
 ' \n  rewrite <- rev_involutive.',
 ' symmetry.',
 ' simpl.',
 ' induction l; simpl; auto.',
 

In [13]:
result.score_list

[-1.091173529624939,
 -1.0829819440841675,
 -0.9765117168426514,
 -5.801501274108887,
 -1.3765190839767456,
 -13.580381393432617,
 -1.2010478973388672,
 -0.7581014633178711,
 -0.9752275943756104,
 -1.3491300344467163,
 -6.818093299865723,
 -0.9610227346420288,
 -1.306260347366333,
 -0.969602108001709,
 -7.681689739227295,
 -6.0971221923828125,
 -9.352446556091309,
 -0.9166039228439331,
 -0.9867836236953735,
 -10.54260540008545,
 -1.3574259281158447,
 -1.7184215784072876,
 -0.9940738677978516,
 -11.772241592407227,
 -1.07068932056427,
 -1.2064350843429565,
 -1.4584046602249146,
 -1.0888279676437378,
 -3.8191006183624268,
 -1.2318209409713745,
 -1.3284974098205566,
 -1.3770112991333008,
 -4.765250205993652,
 -1.0577272176742554,
 -8.313982009887695,
 -6.794739246368408,
 -1.26614511013031,
 -0.9316457509994507,
 -1.3570245504379272,
 -0.9584465026855469,
 -1.2043745517730713,
 -10.348194122314453,
 -7.571715354919434,
 -10.32215690612793,
 -5.461085796356201,
 -0.9433013796806335,
 -3.61

In [14]:
result.num_tokens_list

[6,
 4,
 4,
 11,
 4,
 16,
 6,
 6,
 6,
 4,
 9,
 6,
 6,
 6,
 5,
 11,
 12,
 6,
 4,
 14,
 4,
 4,
 6,
 10,
 6,
 6,
 4,
 6,
 8,
 4,
 4,
 4,
 5,
 4,
 7,
 8,
 4,
 6,
 4,
 6,
 6,
 14,
 3,
 4,
 8,
 6,
 3,
 12,
 6,
 10,
 6,
 4,
 9,
 6,
 6,
 4,
 6,
 6,
 3,
 12,
 4,
 3,
 6,
 4]