In [None]:
import heapq
import os
import pickle
import sys
import time

import ray
import torch
from lean_dojo import (
    Pos,
    Dojo,
    Theorem,
    LeanGitRepo,
    ProofFinished,
    DojoInitError,
    DojoCrashError,
    DojoHardTimeoutError,
)
from lean_dojo.constants import LEAN3_DEPS_DIR, LEAN4_DEPS_DIR
from ray.util.actor_pool import ActorPool

from common import zip_strict
from generator.model import RetrievalAugmentedGenerator
from prover.new_search_tree import *

# Location of the tactic-generator checkpoint, relative to the working directory.
ckpt_path = "gen.ckpt"

# Load the retrieval-augmented generator onto the GPU; `freeze=True` is passed
# through to the loader (presumably disabling further training - confirm
# against RetrievalAugmentedGenerator.load).
tac_gen = RetrievalAugmentedGenerator.load(ckpt_path, device=torch.device("cuda"), freeze=True)


In [None]:
# Echo the loaded generator so the notebook displays its repr.
tac_gen

In [None]:
# Proof state to generate tactics for, and the device the generator lives on.
# Both are defined here, before their first use, so that this cell works on a
# fresh kernel (Restart & Run All) instead of raising NameError.
state = "n : ℕ\n⊢ gcd n n = n"
device = 'cuda'

# Tokenize the proof state into PyTorch tensors and move them to `device`.
tokenized_state = tac_gen.tokenizer(state, return_tensors="pt").to(device)

In [None]:

# Pull the token ids and the attention mask out of the tokenizer output,
# each placed on `device`.
state_ids, state_mask = (
    tokenized_state.input_ids.to(device),
    tokenized_state.attention_mask.to(device),
)

In [None]:
# Inspect the token ids produced for `state`.
state_ids

In [None]:
# Inspect the attention mask that accompanies `state_ids`.
state_mask

In [None]:
# Generate a single tactic from the tokenized proof state, with the output
# capped at 1024 tokens.
tactic_ids = tac_gen.generator.generate(
    tokenized_state.input_ids,
    max_length=1024,
)

In [None]:
# Inspect the raw generated token ids before decoding them to text.
tactic_ids

In [None]:

# Decode the first generated sequence back to text (special tokens dropped)
# and print it followed by one blank line.
tactic = tac_gen.tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
print(tactic)
print()

In [None]:
# Generate multiple tactics via beam search.
# Decoding options are collected in one place: 4 beams, all 4 returned,
# no sampling, no length penalty, no early stopping.
beam_search_options = dict(
    max_length=1024,
    num_beams=4,
    length_penalty=0.0,
    do_sample=False,
    num_return_sequences=4,
    early_stopping=False,
)
tactic_candidates_ids = tac_gen.generator.generate(
    tokenized_state.input_ids, **beam_search_options
)

In [None]:
# Inspect the raw token ids of the beam-search candidates before decoding.
tactic_candidates_ids

In [None]:

# Decode every beam-search candidate to text (special tokens dropped) and
# print them one per line.
tactic_candidates = tac_gen.tokenizer.batch_decode(tactic_candidates_ids, skip_special_tokens=True)
for candidate in tactic_candidates:
    print(candidate)


In [None]:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from trl import AutoModelForSeq2SeqLMWithValueHead

# Hugging Face Hub id shared by the tokenizer and the value-head model.
# Or "lean3" -> "lean4"
pretrained_name = "kaiyuy/leandojo-lean3-tacgen-byt5-small"

tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(pretrained_name)

In [None]:
# Device used for the tensors and the model in the cells below.
device = "cuda"
# Proof state used as the query: hypotheses, then the goal after the turnstile.
state = "n : ℕ\n⊢ gcd n n = n"

In [None]:
# initialize trainer
# One (query, response) pair is trained per PPO step, so the batch size is 1.
# `mini_batch_size` is pinned to 1 as well: its default differs between TRL
# releases and `batch_size` must be a multiple of it, so relying on the
# default can make trainer construction fail.
ppo_config = PPOConfig(
    batch_size=1,
    mini_batch_size=1,
)

In [None]:
# Derive the reference model from the current policy; it is handed to
# PPOTrainer alongside `model` below. (Presumably a frozen copy used for the
# KL penalty - confirm against TRL's create_reference_model docs.)
model_ref = create_reference_model(model)

In [None]:
# Put the policy on the target device, then encode the proof state as a
# PyTorch tensor and move it to the same device.
model = model.to(device)
query_tensor = tokenizer.encode(state, return_tensors="pt")
query_tensor = query_tensor.to(device)

In [None]:
# Generate a response for the query with default generation settings.
# `query_tensor` was already moved to `device` when it was created, so the
# `.to(device)` here only guards against the tensor living elsewhere.
response_tensor = model.generate(query_tensor.to(device))

In [None]:
# create a ppo trainer
# Arguments are passed by keyword so each object's role is explicit.
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=model_ref,
    tokenizer=tokenizer,
)

In [None]:
import torch

# Define the reward for the response.
# (This could be any reward, such as human feedback or the output of another model.)
# The list holds one scalar tensor for the single response trained on below.
reward_value = 100.0
reward = [torch.tensor(reward_value)]

In [None]:
# Length of the first sequence in `response_tensor`.
len(response_tensor[0])

In [None]:

# Train the model for one PPO step on the single (query, response, reward) triple.
queries = [query_tensor[0]]
responses = [response_tensor[0]]
train_stats = ppo_trainer.step(queries, responses, reward)

# Display the advantages recorded in the step statistics.
train_stats["ppo/policy/advantages"]