In [None]:
# Shows that when we train DPO using data from sampling (rather than beam search),
# sampling results in worse outcomes. Much more unlikely sequences (lower logprob)
# appear to be generated with sampling, and the distribution of logprobs is much more spread out.
# This possibly suggests that this makes the model more confident in only a small subset of paths.
# Timing indicates that a massive increase occurs for DPO-trained models when sampling (likely because of the longer sequences);
# beam-search timings are similar.

In [None]:
from refactor.dpo.model import DPOTrainModule

# Load the run's `last.ckpt` checkpoint onto the GPU.
# NOTE(review): freeze=True presumably disables training of the weights — confirm in DPOTrainModule.load.
tac_gen_last = DPOTrainModule.load(
    'experiments/runs/lean_dojo_dpo/dpo_beamsearch_2023_12_05/10_09_00/checkpoints/last.ckpt',
    device='cuda',
    freeze=True,
)

In [None]:
# Load the epoch-0 / step-10000 checkpoint of the same run onto the GPU, for comparison with `last.ckpt`.
# NOTE(review): freeze=True presumably disables training of the weights — confirm in DPOTrainModule.load.
tac_gen_first = DPOTrainModule.load(
    'experiments/runs/lean_dojo_dpo/dpo_beamsearch_2023_12_05/10_09_00/checkpoints/epoch=0-step=10000-Pass@1_val=0.46.ckpt',
    device='cuda',
    freeze=True,
)


In [None]:
from omegaconf import OmegaConf
# Generation settings that are assigned to both checkpoints' `gen_config` further down.
# NOTE(review): the notes at the top of the notebook discuss sampling, but this selects
# 'beam' — confirm this is the strategy intended for the recorded run.
conf = OmegaConf.create(dict(strategy='beam', length_penalty=0.0))

In [None]:
# Point the first checkpoint at the generation settings held in `conf`.
tac_gen_first.gen_config = conf

In [None]:
# Display the generation config currently attached to the first checkpoint.
tac_gen_first.gen_config

In [None]:
# Point the last checkpoint at the same `conf` object as the first checkpoint,
# so both models generate with identical settings.
tac_gen_last.gen_config = conf

In [None]:
import time


In [None]:
# Time one batched generation call (64 samples for a single proof state) on the first checkpoint.
# NOTE(review): this is a single, un-warmed measurement; any one-off GPU start-up cost
# would be included in it — confirm before comparing it with other timings.

# Lean goal used as the prompt; built before the clock starts so only generation is timed.
proof_state = (
    'E : Type u_3,\n'
    '_inst_1 : normed_add_comm_group E,\n'
    'f : ℝ → E,\n'
    'a b c d : ℝ,\n'
    'μ : measure ℝ,\n'
    'h : [c, d] ⊆ [a, b],\n'
    'hf : integrable_on f (Ι a b) μ,\n'
    'hc : c = d\n'
    '⊢ integrable_on f (Ι c d) μ'
)

t0 = time.monotonic()
first_samples = tac_gen_first.batch_generate(
    state=[proof_state], retriever_args=None, num_samples=64
)
first_time = time.monotonic() - t0

# Surface the measurement; otherwise it is computed but never shown in the notebook.
print(f'tac_gen_first.batch_generate: {first_time:.2f}s')


In [None]:
# Time one batched generation call (64 samples for a single proof state) on the last checkpoint.
# NOTE(review): this is a single measurement — confirm it is representative before
# comparing it with other timings.

# Lean goal used as the prompt; built before the clock starts so only generation is timed.
proof_state = (
    'E : Type u_3,\n'
    '_inst_1 : normed_add_comm_group E,\n'
    'f : ℝ → E,\n'
    'a b c d : ℝ,\n'
    'μ : measure ℝ,\n'
    'h : [c, d] ⊆ [a, b],\n'
    'hf : integrable_on f (Ι a b) μ,\n'
    'hc : c = d\n'
    '⊢ integrable_on f (Ι c d) μ'
)

t0 = time.monotonic()
last_samples = tac_gen_last.batch_generate(
    state=[proof_state], retriever_args=None, num_samples=64
)
last_time = time.monotonic() - t0

# Surface the measurement; otherwise it is computed but never shown in the notebook.
print(f'tac_gen_last.batch_generate: {last_time:.2f}s')

In [None]:
# Samples generated for the (single) input state by the first checkpoint,
# ordered by their second field, largest first.
sorted(
    first_samples[0],
    key=lambda sample: sample[1],
    reverse=True,
)

In [None]:
# Samples generated for the (single) input state by the last checkpoint,
# ordered by their second field, largest first.
sorted(
    last_samples[0],
    key=lambda sample: sample[1],
    reverse=True,
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Second field of every sample from the first checkpoint
# (presumably the sequence log-probability — confirm against batch_generate).
first_dist = [sample[1] for sample in first_samples[0]]


In [None]:
# Display the values extracted from the first checkpoint's samples.
first_dist

In [None]:
# Second field of every sample from the last checkpoint
# (presumably the sequence log-probability — confirm against batch_generate).
last_dist = [sample[1] for sample in last_samples[0]]


In [None]:
# Histogram (60 bins) of the values collected for the first checkpoint.
plt.hist(first_dist, bins=60)

In [None]:
# Histogram (100 bins) of the values collected for the last checkpoint.
# Only values strictly above the cutoff are drawn, presumably so that a few extreme
# values do not squash the rest of the distribution — confirm the threshold is still appropriate.
SCORE_CUTOFF = -200
last_dist_shown = [score for score in last_dist if score > SCORE_CUTOFF]

plt.hist(last_dist_shown, bins=100)
# State the truncation on the figure itself so the hidden part of the distribution is not overlooked.
plt.title(
    f'last checkpoint: {len(last_dist) - len(last_dist_shown)} of {len(last_dist)} values omitted '
    f'(only values > {SCORE_CUTOFF} shown)'
);