In [1]:
%load_ext autoreload
%autoreload 2

# Create ColBERTRetriever

In [2]:
from reprover.retrieval.colbert.model import ColBERTPremiseRetriever
from colbert.modeling.checkpoint import Checkpoint
from colbert.infra.config import ColBERTConfig

corpus_path = "../data/leandojo_benchmark_4/corpus.jsonl"
collection_path = "../data/leandojo_benchmark_4/random/colbert_collection.tsv"
checkpoint_path = "../checkpoints/colbertv2.0/"
num_negatives = 1
num_retrieved = 2

config = ColBERTConfig.load_from_checkpoint(checkpoint_path)
checkpoint = Checkpoint(checkpoint_path, colbert_config=config)

model = ColBERTPremiseRetriever(
    index_name="colbert_v1",
    experiment_name="lean",
    checkpoint_path_or_name=checkpoint_path,
    collection=collection_path,
    # config: Optional[Union[str, ColBERTConfig]] = None,
    index_root="../experiments/",
    num_retrieved=num_retrieved,
)
model.load_corpus(corpus_path)

[2024-04-06 21:25:22,443] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[Apr 06, 21:25:23] #> Loading collection...
[Apr 06, 21:25:24] #> Loading codec...
[Apr 06, 21:25:24] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 06, 21:25:24] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 06, 21:25:24] #> Loading IVF...
[Apr 06, 21:25:24] #> Loading doclens...


100%|██████████| 7/7 [00:00<00:00, 2561.52it/s]

[Apr 06, 21:25:24] #> Loading codes and residuals...



100%|██████████| 7/7 [00:00<00:00, 140.86it/s]


## Test corpus reindexing

In [3]:
# model.reindex_corpus(64)

## Test retrieve() function

In [4]:
from reprover.retrieval.datamodule import ColBERTRetrievalDataModule

data_path = "../data/leandojo_benchmark_4/random/"

data_module = ColBERTRetrievalDataModule(
    data_path,
    corpus_path,
    num_negatives=num_negatives,
    num_in_file_negatives=1,
    colbert_config=model.config,
    batch_size=1,
    eval_batch_size=1,
    max_seq_len=128,
    verbose=3,
    num_workers=0,
)
data_module.setup()

example = data_module.ds_val.data[0]
ctx = example["context"]


inputs = dict(
    state=[ctx.state],
    file_name=[ctx.path],
    theorem_full_name=[ctx.theorem_full_name],
    theorem_pos=[ctx.theorem_pos],
)

[32m2024-04-06 21:25:59.387[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m56[0m - [1mLoading data from ../data/leandojo_benchmark_4/random/train.json[0m
100%|██████████| 98514/98514 [00:14<00:00, 6810.28it/s]
[32m2024-04-06 21:26:23.956[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m100[0m - [1mLoaded 297243 examples.[0m
[32m2024-04-06 21:26:23.972[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m56[0m - [1mLoading data from ../data/leandojo_benchmark_4/random/val.json[0m
100%|██████████| 2000/2000 [00:00<00:00, 7052.53it/s]
[32m2024-04-06 21:26:24.303[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m100[0m - [1mLoaded 4641 examples.[0m
[32m2024-04-06 21:26:24.304[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m56[0m - [1mLoading data from ../data/leandojo_benchmar

In [5]:
ranking, premises, scores = model.retrieve(**inputs, k=2, reindex_batch_size=32, do_reindex=False)


#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . I : Type w₀
J : Type w₁
C : I → Type u₁
inst✝² : (i : I) → Category.{v₁, u₁} (C i)
D : I → Type u₁
inst✝¹ : (i : I) → Category.{v₁, u₁} (D i)
A : Type u₁
inst✝ : Category.{u₁, u₁} A
f : (i : I) → A ⥤ C i
i : I
⊢ pi' f ⋙ Pi.eval C i = f i, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101, 30522,  1045,  1024,  2828,  1059, 17110,  1046,  1024,  2828,
         1059, 11871,  1039,  1024,  1045,  1585,  2828,  1057, 11871,   100,
         1024,  1006,  1045,  1024,  1045,  1007,  1585,  4696,  1012,  1063,
         1058,   102])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])



1it [00:00, 136.89it/s]


In [6]:
premises, scores

([[Premise(path='Mathlib/CategoryTheory/Comma.lean', full_name='CategoryTheory.Comma.post', code='@[simps]\ndef post (L : A ⥤ T) (R : B ⥤ T) (F : T ⥤ C) : Comma L R ⥤ Comma (L ⋙ F) (R ⋙ F) where\n  obj X :=\n    { left := X.left\n      right := X.right\n      hom := F.map X.hom }\n  map f :=\n    { left := f.left\n      right := f.right\n      w := by simp only [Functor.comp_map, ← F.map_comp, f.w] }'),
   Premise(path='Mathlib/CategoryTheory/Sites/SheafOfTypes.lean', full_name='CategoryTheory.Equalizer.Presieve.firstMap', code='def firstMap : FirstObj P R ⟶ SecondObj P R :=\n  Pi.lift fun fg =>\n    haveI := Presieve.hasPullbacks.has_pullbacks fg.1.2.2 fg.2.2.2\n    Pi.π _ _ ≫ P.map pullback.fst.op')]],
 [(24.421875, 24.015625)])

## Test retrieve_from_preprocessed() function

In [7]:
batch = next(iter(data_module.val_dataloader()))
model.retrieve_from_preprocessed(batch)

100%|██████████| 1/1 [00:00<00:00, 165.59it/s]


([[Premise(path='Mathlib/CategoryTheory/Comma.lean', full_name='CategoryTheory.Comma.post', code='@[simps]\ndef post (L : A ⥤ T) (R : B ⥤ T) (F : T ⥤ C) : Comma L R ⥤ Comma (L ⋙ F) (R ⋙ F) where\n  obj X :=\n    { left := X.left\n      right := X.right\n      hom := F.map X.hom }\n  map f :=\n    { left := f.left\n      right := f.right\n      w := by simp only [Functor.comp_map, ← F.map_comp, f.w] }'),
   Premise(path='Mathlib/CategoryTheory/Sites/SheafOfTypes.lean', full_name='CategoryTheory.Equalizer.Presieve.firstMap', code='def firstMap : FirstObj P R ⟶ SecondObj P R :=\n  Pi.lift fun fg =>\n    haveI := Presieve.hasPullbacks.has_pullbacks fg.1.2.2 fg.2.2.2\n    Pi.π _ _ ≫ P.map pullback.fst.op')]],
 [24.421875, 24.015625])

## Test lighnint-like interface

In [11]:
checkpoint_path

'../checkpoints/colbertv2.0/'

In [18]:
from reprover.retrieval.colbert.model import ColBERTPremiseRetrieverLightning

model_pl = ColBERTPremiseRetrieverLightning(
    index_name="colbert_v1",
    experiment_name="lean",
    checkpoint_path_or_name=checkpoint_path,
    collection=collection_path,
    index_root="../experiments/",
    num_retrieved=num_retrieved,
)
model_pl.load_corpus(corpus_path)

[Apr 06, 21:45:06] #> Loading collection...
[Apr 06, 21:45:07] #> Loading codec...
[Apr 06, 21:45:07] #> Loading IVF...
[Apr 06, 21:45:07] #> Loading doclens...


100%|██████████| 7/7 [00:00<00:00, 2126.00it/s]

[Apr 06, 21:45:07] #> Loading codes and residuals...



100%|██████████| 7/7 [00:00<00:00, 150.16it/s]


In [21]:
batch

{'context': [Context(path='Mathlib/CategoryTheory/Pi/Basic.lean', theorem_full_name="CategoryTheory.Functor.pi'_eval", theorem_pos=(233, 1), state="I : Type w₀\nJ : Type w₁\nC : I → Type u₁\ninst✝² : (i : I) → Category.{v₁, u₁} (C i)\nD : I → Type u₁\ninst✝¹ : (i : I) → Category.{v₁, u₁} (D i)\nA : Type u₁\ninst✝ : Category.{u₁, u₁} A\nf : (i : I) → A ⥤ C i\ni : I\n⊢ pi' f ⋙ Pi.eval C i = f i")],
 'context_ids': tensor([[  101, 30522,  1045,  1024,  2828,  1059, 17110,  1046,  1024,  2828,
           1059, 11871,  1039,  1024,  1045,  1585,  2828,  1057, 11871,   100,
           1024,  1006,  1045,  1024,  1045,  1007,  1585,  4696,  1012,  1063,
           1058,   102]]),
 'context_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1]]),
 'url': ['https://github.com/leanprover-community/mathlib4'],
 'commit': ['3ce43c18f614b76e161f911b75a3e1ef641620ff'],
 'file_path': ['Mathlib/CategoryTheory/Pi/Basic.lean'],
 'full_n

In [30]:
model_pl.predict_step_outputs = []
model_pl.predict_step(batch, [None])

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 183.57it/s]


In [32]:
model_pl.validation_step(batch, 0)

100%|██████████| 1/1 [00:00<00:00, 184.20it/s]


AttributeError: 'NoneType' object has no attribute 'experiment'