In [4]:
from typing import Any
from pathlib import Path
from tactic_gen.tactic_data import TEST_LM_EXAMPLE, example_collator_from_conf, ExampleCollator
from tactic_gen.train_decoder import get_model, get_tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
import torch
from util.constants import TRAINING_CONF_NAME
import yaml



In [2]:
def get_training_conf(checkpoint_loc: Path) -> Any:
    training_conf_loc = checkpoint_loc.parent / TRAINING_CONF_NAME
    with training_conf_loc.open('r') as f:
        training_conf = yaml.safe_load(f)
    return training_conf

In [3]:
def get_example_collator(checkpoint_loc: Path) -> ExampleCollator:
    training_conf = get_training_conf(checkpoint_loc)
    example_collator_conf = training_conf['example_collator']
    example_collator = example_collator_from_conf(example_collator_conf) 
    return example_collator

In [12]:
CHECKPOINT_LOC = Path("/home/ubuntu/coq-modeling/models/deepseek-1.3b-basic/checkpoint-17000")
training_conf = get_training_conf(CHECKPOINT_LOC)
example_collator = get_example_collator(CHECKPOINT_LOC)
tokenizer = get_tokenizer(training_conf, add_eos=False) 

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
model = get_model(str(CHECKPOINT_LOC))
None

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [35]:
collated_input = example_collator.collate_input(tokenizer, TEST_LM_EXAMPLE)
inputs = tokenizer(collated_input, return_tensors='pt')
with torch.no_grad():
    out = model.generate(inputs["input_ids"], max_new_tokens=64, do_sample=True, num_return_sequences=64, temperature=1)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.




In [37]:
input_num_tokens = inputs["input_ids"].shape[1]
tokenizer.batch_decode(out[:, input_num_tokens:], skip_special_tokens=True)


[' induction l.',
 '\n  induction l.',
 ' auto.',
 ' induction l.',
 ' simpl.',
 ' simpl.',
 "\n  induction l as [|n l'].",
 '\n  simpl.',
 "\n  induction l as [| n l' Hrecl'].",
 ' induction l as [|x1 l Hrec].',
 " induction l as [| x' l'].",
 ' simpl.',
 ' induction l.',
 ' induction l.',
 '\n  induction l.',
 '\n  induction l; simpl.',
 ' reflexivity.',
 ' induction l.',
 ' induction l.',
 ' induction l.',
 ' destruct l.',
 " induction l as [|l0 l'' IHl].",
 ' \n  remember (rev l).',
 '\n  auto.',
 "\n  induction l as [|x' l].",
 '\n  simpl.',
 ' simpl.',
 "\n  induction l as [|n l' IHl'].",
 '\n  induction l as [|h tl IH].',
 " induction l as [|h t IH] using rev_rect'.",
 '\n  induction l.',
 ' induction l; simpl;\n    reflexivity.',
 '\n  destruct l.',
 "\n  induction l as [\n                 | x l' IHl'].",
 ' induction l.',
 ' induction l.',
 ' induction l.',
 '\n  *',
 ' simpl.',
 ' induction l as [|h t IH].',
 " induction tl as [|x' l' IHl].",
 ' induction l as [|hd tl].',
 '\

In [2]:
from model_deployment.model_wrapper import DecoderLocalWrapper

In [5]:
CHECKPOINT_LOC = Path("/home/ubuntu/coq-modeling/models/deepseek-1.3b-basic/checkpoint-17000")
model_wrapper = DecoderLocalWrapper.from_checkpoint(CHECKPOINT_LOC)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [7]:
result = model_wrapper.get_recs(TEST_LM_EXAMPLE, 64, "")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.


In [8]:
result.next_tactic_list

['\n  induction l.',
 ' induction l as [|h t].',
 ' simpl.',
 ' induction l.',
 '\n  functional induction (rev l).',
 '\n  simpl.',
 ' induction l; simpl; congs;auto.',
 '\n  induction l.',
 '\n  induction l.',
 ' induction l.',
 ' induction l.',
 "\n  induction l as [|x' l].",
 '\n  induction l.',
 '\n  induction l.',
 '\n  induction l;\n    simpl (rev l);\n    simpl (rev (x :: l));\n    repeat rewrite push_app;\n    try rewrite rev_push.',
 ' induction l.',
 ' reflexivity.',
 '\n  induction l; cbn; reflexivity.',
 ' symmetry.',
 ' induction l.',
 ' induction l.',
 ' induction l.',
 ' induction l.',
 "\n  induction l as [ | n l'].",
 ' induction l.',
 ' induction l.',
 ' simpl.',
 "\n  induction l as [|h t'].",
 "\n  induction l as [| x' r].",
 ' induction l.',
 '\n  simpl.',
 "\n  induction l as [| x' l' IHl].",
 ' simpl.',
 "\n  induction l as [|r l' IH].",
 ' induction l.',
 "\n  induction l as [| x' l' IH].",
 '\n  induction l;simpl;f_equal;eauto.',
 ' induction l.',
 ' induction 