In [1]:
from train_semi_auto import BracketFlowModule
from sampling import semiauto_euler_sampling, SamplingTraceDatapoint

In [2]:
checkpoint_path = "checkpoints/bracket-flow/last-v48.ckpt"
model = BracketFlowModule.load_from_checkpoint(checkpoint_path)

In [3]:
steps = 2000
batch_size = 10
samples, trace = semiauto_euler_sampling(
    model,
    model.interpolant,
    steps=steps,
    mask=0,
    pad=3,
    batch_size=10,
    max_length=64,
    return_trace=True,
)

  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


In [7]:
def process_trace(trace: list[SamplingTraceDatapoint]):
    event_type_mapping = dict(change="c", insertion="i")
    token_mapping = {0: "m", 1: "(", 2: ")"}

    def _process_datapoint(datapoint: SamplingTraceDatapoint):
        return dict(
            t=datapoint.t,
            a=event_type_mapping[datapoint.event_type],
            tk=token_mapping[datapoint.token],
            i=datapoint.position,
        )

    return [_process_datapoint(datapoint) for datapoint in trace]

In [8]:
from IPython.display import display, HTML
import json

# Select the index of the tensor to visualize
id = 0
assert 0 <= id < batch_size, "id must be in [0, batch_size)"

events = process_trace(trace[id])
print(events)
display(
    HTML(f"""
<style>
.token{{display:inline-block;padding:5px 10px;margin:2px;border:1px solid #aaa;border-radius:4px;
font-family:monospace;transition:background .5s}}
.hi-i{{background:lightgreen}}
.hi-c{{background:yellow}}
#s{{min-height:30px;margin-bottom:20px}}
</style>
<div id="s"></div>
<button id="r">Replay</button>
<script>
// Use events provided from Python
var ev = {json.dumps(events)}, s = [];
var render = function(hIdx, act){{
  var c = document.getElementById("s"); c.innerHTML="";
  s.forEach(function(t,k){{ 
    var sp = document.createElement("span");
    sp.className = "token" + (k===hIdx ? " hi-"+act : "");
    sp.textContent = t; c.appendChild(sp);
  }}); 
}},
run = function(){{
  s = []; render();
  ev.forEach(function(e){{
    setTimeout(function(){{
      if(e.a=="i"){{
        var p = e.i;
        s.splice(p, 0, e.tk);
        render(p, "i");
      }} else {{
        var p = (e.i===undefined || e.i<0 || e.i>=s.length) ? Math.floor(s.length/2) : e.i;
        s[p] = e.tk; render(p, "c");
      }}
      setTimeout(render,600);
    }}, e.t*5000);
  }});
}};
document.getElementById("r").onclick = run;
run();
</script>
""")
)

[{'t': 0.09800010919570923, 'a': 'i', 'tk': 'm', 'i': 0}, {'t': 0.13750000298023224, 'a': 'i', 'tk': 'm', 'i': 1}, {'t': 0.13999997079372406, 'a': 'i', 'tk': 'm', 'i': 2}, {'t': 0.1689995974302292, 'a': 'i', 'tk': 'm', 'i': 3}, {'t': 0.2329987734556198, 'a': 'i', 'tk': 'm', 'i': 4}, {'t': 0.2429986447095871, 'a': 'i', 'tk': 'm', 'i': 5}, {'t': 0.2569984793663025, 'a': 'c', 'tk': '(', 'i': 1}, {'t': 0.26299840211868286, 'a': 'i', 'tk': 'm', 'i': 6}, {'t': 0.2679983377456665, 'a': 'i', 'tk': 'm', 'i': 7}, {'t': 0.2779982089996338, 'a': 'i', 'tk': 'm', 'i': 8}, {'t': 0.29399800300598145, 'a': 'c', 'tk': '(', 'i': 2}, {'t': 0.29649797081947327, 'a': 'i', 'tk': 'm', 'i': 9}, {'t': 0.319497674703598, 'a': 'i', 'tk': 'm', 'i': 10}, {'t': 0.3339974880218506, 'a': 'i', 'tk': 'm', 'i': 11}, {'t': 0.34299737215042114, 'a': 'i', 'tk': 'm', 'i': 12}, {'t': 0.3479973077774048, 'a': 'i', 'tk': 'm', 'i': 13}, {'t': 0.35149726271629333, 'a': 'i', 'tk': 'm', 'i': 14}, {'t': 0.35549721121788025, 'a': 'i'