In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
from pathlib import Path
from model_deployment.model_wrapper import FidT5LocalWrapper 

from data_management.sentence_db import SentenceDB
from data_management.dataset_file import DatasetFile, Proof
from data_management.splits import DATA_POINTS_NAME, REPOS_NAME, file_from_split, DataSplit, FileInfo, Split
from data_management.create_lm_dataset import LmDatasetConf

import torch
from torch import log_softmax

from tactic_gen.lm_example import LmExample, LmFormatter, formatter_from_conf

from util.constants import DATA_CONF_NAME

import yaml

  from .autonotebook import tqdm as notebook_tqdm


[2024-04-23 00:48:35,905] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
os.chdir(Path("/home/ubuntu/coq-modeling"))

In [3]:
SENTENCE_DB_LOC = Path("sentences.db")
DATA_LOC = Path("raw-data/coq-dataset")
DATA_SPLIT_LOC = Path("splits/final-split.json")
sentence_db = SentenceDB.load(SENTENCE_DB_LOC)
data_split = DataSplit.load(DATA_SPLIT_LOC)

In [4]:
FILE_NAME = Path("repos/coq-community-bertrand/theories/Binomial.v")
CHECKPOINT_LOC = Path("models/t5-fid-base-basic-final/checkpoint-110500")
THEOREM_NAME = "binomial_def2" 

In [5]:
file_info, split = file_from_split(str(FILE_NAME), data_split)

In [6]:
model_wrapper = FidT5LocalWrapper.from_checkpoint(str(CHECKPOINT_LOC))

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
def get_formatters(checkpoint_loc: Path) -> list[LmFormatter]:
    assert 0 < len(checkpoint_loc.parents)
    model_loc = checkpoint_loc.parents[0]
    lm_data_conf = model_loc / DATA_CONF_NAME
    assert lm_data_conf.exists()
    with lm_data_conf.open("r") as fin:
        yaml_data = yaml.load(fin, Loader=yaml.Loader)
    data_conf = LmDatasetConf.from_yaml(yaml_data)
    formatter_confs = data_conf.lm_formatter_confs
    formatters = [formatter_from_conf(f) for f in formatter_confs]
    return formatters 


In [8]:
dp_obj = file_info.get_dp(DATA_LOC, sentence_db)
proof = dp_obj.get_theorem(THEOREM_NAME)
formatter = get_formatters(CHECKPOINT_LOC)[0]

In [15]:
def get_tokens_and_logits(proof: Proof, dp_obj: DatasetFile, file_info: FileInfo, split: Split, data_loc: Path, model: FidT5LocalWrapper, formatter: LmFormatter) -> tuple[list[str], list[float]]:
    for i, step in enumerate(proof.steps[:1]):
        example = formatter.example_from_step(
            i,
            proof,
            dp_obj=dp_obj,
            file_info=file_info,
            split=split,
            data_loc=data_loc,
            ground_truth_steps=None,  # Not doing this right now
            key_record=None,
            cutoff_idx=None,
        )
        input_batch = model.local_dset.collate([example])
        with torch.no_grad():
            logits = model.model(
                input_batch["input_ids"].cuda(),
                input_batch["attention_mask"].cuda(),
                input_batch["labels"].cuda(),
            ).logits

        vocab_size = logits.shape[-1]
        print(vocab_size)

        # move labels to correct device to enable model parallelism
        labels = input_batch["labels"].to(logits.device)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        logit_mat = shift_logits.view(-1, vocab_size)
        print(logit_mat.shape)
        print(shift_labels.shape)
        log_probs = log_softmax(logit_mat, 1) 
        print(log_probs.shape)
        print(labels)
        #one_d_log_probs = log_probs[shift_labels.view(-1)]
        #return list(shift_labels), list(one_d_log_probs)


In [16]:
get_tokens_and_logits(proof, dp_obj, file_info, split, DATA_LOC, model_wrapper, formatter)

32128
torch.Size([63, 32128])
torch.Size([1, 63])
torch.Size([63, 32128])
tensor([[25029,     5,     1,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100]], device='cuda:0')
