In [1]:
class Dict(dict):
    """`dict` with attribute-style access: `d.key` reads, writes and deletes `d['key']`.

    Missing keys raise AttributeError (not KeyError) on attribute access, so `hasattr`,
    `getattr(d, name, default)` and `copy.deepcopy` (which probes optional hooks such as
    `__deepcopy__` on the instance) behave as they do for ordinary objects.
    """
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

# Global configuration, read with attribute access (args.seed, args.fp_round, ...).
args = Dict({
    'seed': 998244353,  # passed to random.seed() at the top of each test cell
    'fp_round': 20,     # digits kept by quantize() in hardmax()/normalize()
    'fp_guard': 20,     # added to fp_round to form the decimal context precision
})

## Main

In [3]:
class Tester:
    """Compare an execution engine against a reference function on a list of inputs.

    Args:
        func: reference implementation mapping an input string to the expected output string.
        prog: program (list of instruction strings) handed to every engine instance.
        inputs: input strings to run.
    """
    def __init__(self, func, prog: list, inputs: list):
        self.func = func
        self.prog = prog
        self.inputs = inputs

    def test(self, engine_cls):
        """Run `engine_cls` on every input, print an AC/WA line per case and return the verdicts.

        `engine_cls` is instantiated as `engine_cls(prog=..., input=...)` and must provide
        `prepare()`, `execute()` and `readout()`.

        Returns:
            list of bool, one per input: True when the engine's readout equals the reference
            answer. The printed log is the same as before, so callers that ignore the return
            value are unaffected.
        """
        print(f'<engine={engine_cls.__name__}>', flush = True)
        verdicts = []
        for i, case in enumerate(self.inputs):  # `case` avoids shadowing the input() builtin
            answer = self.func(case)
            print(f'[#{i}] input={case}, answer={answer}', flush = True)
            engine = engine_cls(prog = self.prog, input = case)
            engine.prepare()
            engine.execute()
            output = engine.readout()
            ok = output == answer
            verdict = 'AC' if ok else 'WA'
            print(f'({verdict}) input={case}, output={output}', flush = True)
            verdicts.append(ok)
        return verdicts

from defaultlist import defaultlist
from collections import defaultdict

def Tape():
    """Return a fresh, empty tape: a `defaultdict` whose unset entries read as '0'.

    NOTE(review): keys are presumably cell positions; `Tape` is not used in this part of the notebook.
    """
    return defaultdict(lambda: '0')

class DefaultDict(defaultdict):
    """`defaultdict` with attribute-style access: `d.name` reads, writes and deletes `d['name']`.

    Reading a missing attribute inserts it through the default factory, so `d.xs[0] = 1`
    works without prior initialisation. Dunder names (``__xxx__``) are excluded and raise
    AttributeError instead. Protocols that probe for optional hooks with getattr/hasattr
    (e.g. `copy.deepcopy` looking up `__deepcopy__`) therefore neither crash nor insert junk keys.
    """
    def __getattr__(self, name):
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self[name]

    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__

def format_goto(inst: str, cond = None, absolute = None, return_list = False):
    """Expand a conditional-jump instruction such as 'A!+4' or 'B?-11' into tokens.

    `inst` has the form <tape><'!' or '?'><sign><count>. The result is
    [head] + body + ['@'], where:
      * head is `cond` if given, otherwise the first two characters of `inst` (e.g. 'A!');
      * body is `count` copies of the sign character ('+' or '-') when `absolute` is None,
        otherwise the single string str(absolute + signed offset), i.e. the target index.

    Returns the token list if `return_list` is true, else the tokens joined into a string.
    Raises ValueError if the offset is not an integer literal.
    """
    head = (inst[0] + inst[1]) if cond is None else cond
    if absolute is None:
        # Offsets are parsed with int(), never eval(): instructions are data, not code.
        body = [inst[2]] * int(inst[3 :])
    else:
        body = [str(absolute + int(inst[2 :]))]
    seq = [head] + body + ['@']
    return seq if return_list else ''.join(seq)

def format_prog(raw_prog: list, absolute = False, return_list = False):
    """Expand a program (list of instruction strings) into its token sequence.

    Conditional jumps ('T!<offset>' / 'T?<offset>') are expanded by `format_goto`, using
    the instruction's own index as the base when `absolute` is true. Every other
    instruction is copied verbatim. Returns the token list if `return_list`, otherwise
    the tokens concatenated into a single string.
    """
    tokens = []
    for idx, inst in enumerate(raw_prog):
        is_goto = len(inst) > 1 and inst[1] in '!?'
        if not is_goto:
            tokens.append(inst)
            continue
        tokens += format_goto(inst, absolute = idx if absolute else None, return_list = True)
    if return_list:
        return tokens
    return ''.join(tokens)

## Transformer

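Implementation of the hand-constructed Transformer, which executes a program over two binary tapes (`A` and `B`) by predicting one token at a time.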

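- Instructions: `TL`, `TR` (move the head of tape `T` left/right), `T0`, `T1` (write 0/1 under the head), `T!`, `T?` (jump if the cell under the head is 0/1), `#` (halt), where `T` $\in$ {`A`,`B`}
- Auxiliaries: `^` (start of the prompt), `$` (end of the program listing, and end of the output), `/` (jump not taken), `=` (jump taken), `-`, `+` (step the program pointer back/forward), `@` (terminates a jump sequence), `:` (start of the output)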
- Outputs: `0`, `1`

In [4]:
import decimal
# Decimal context precision (significant digits): args.fp_round plus args.fp_guard extra digits.
decimal.getcontext().prec = args.fp_round + args.fp_guard
FP_DTYPE = decimal.Decimal  # scalar type used by every Transformer computation below
FP_CONST_0, FP_CONST_1, FP_CONST_2, FP_CONST_3 = (FP_DTYPE(value) for value in range(4))

from collections import defaultdict
from defaultlist import defaultlist

class DefaultDict(defaultdict):
    """`defaultdict` with attribute-style access: `d.name` reads, writes and deletes `d['name']`.

    Reading a missing attribute inserts it through the default factory, so
    `mem.feature[pos] = value` works without prior initialisation. Dunder names (``__xxx__``)
    are excluded and raise AttributeError instead. Protocols that probe for optional hooks
    with getattr/hasattr (e.g. `copy.deepcopy` looking up `__deepcopy__`) therefore neither
    crash nor insert junk keys.
    """
    def __getattr__(self, name):
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self[name]

    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__

def quantize(x, digits):
    """Round the Decimal `x` to `digits` places after the decimal point.

    The rounding mode is the current decimal context's. The exponent 10**-digits is
    built exactly from a (sign, digits, exponent) tuple, so `digits <= 0` also works:
    0 rounds to an integer, -2 to a multiple of 100.
    """
    return x.quantize(FP_DTYPE((0, (1,), -digits)))

def hardmax(sim):
    """Hard attention weights: uniform over the positions whose score ties for the maximum.

    Scores are quantized to `args.fp_round` places first, so values that differ only by
    rounding noise count as ties. Each tied position gets weight 1/(number of ties); all
    other positions get 0.
    """
    rounded = [quantize(score, digits = args.fp_round) for score in sim]
    top = max(rounded)
    mask = [FP_DTYPE(score == top) for score in rounded]
    n_top = sum(mask)
    return [flag / n_top for flag in mask]

def clip2(x: FP_DTYPE) -> FP_DTYPE:
    """Cap an attention score at 2: scores above 2 become FP_CONST_2, others pass through."""
    if x > 2:
        return FP_CONST_2
    return x

def attend(qry_list, key_list, val_list, sim_fn = clip2):
    """One hard-attention read for the newest position.

    The query is taken at the last position (index len(val_list[0]) - 1). Each earlier or
    equal position k is scored with sim_fn(sum_j qry_list[j][last] * key_list[j][k]).
    `hardmax` turns the scores into weights, and the result is the weighted sum of every
    value channel in `val_list` (one output per channel, in order).
    """
    n_pos = len(val_list[0])
    last = n_pos - 1
    scores = []
    for k in range(n_pos):
        dot = sum([qry_list[j][last] * key_list[j][k] for j in range(len(qry_list))])
        scores.append(sim_fn(dot))
    weights = hardmax(scores)
    outputs = []
    for val in val_list:
        outputs.append(sum([weights[k] * val[k] for k in range(n_pos)]))
    return outputs

def normalize(*val_list):
    """Scale the arguments to unit Euclidean norm.

    If every argument rounds to zero at `args.fp_round` places, the arguments are returned
    unchanged (as the received tuple) rather than divided by a ~0 norm. Otherwise a list of
    the scaled values is returned.
    """
    if all(quantize(value, digits = args.fp_round) == 0 for value in val_list):
        return val_list
    length = sum([value ** 2 for value in val_list]).sqrt()
    scaled = [value / length for value in val_list]
    return scaled

def relu(x: FP_DTYPE) -> FP_DTYPE:
    """Rectified linear unit: negative inputs become FP_CONST_0, everything else passes through."""
    if x < 0:
        return FP_CONST_0
    return x

def calc_pos_enc_1(pos: int) -> FP_DTYPE:
    """1 - cos of the angle between the vectors (pos, 1) and (pos + 1, 1).

    The value is positive for every position (the two vectors are never parallel).
    """
    p = FP_DTYPE(pos)
    q = p + FP_CONST_1
    dot = p * q + FP_CONST_1
    norms = ((p ** 2 + FP_CONST_1) * (q ** 2 + FP_CONST_1)).sqrt()
    return FP_CONST_1 - dot / norms

def calc_pos_enc_2_3(pos: int) -> list:
    """Return the unit vector along (pos, 1) as a two-element list [pos / r, 1 / r], r = sqrt(pos**2 + 1).

    The second component is never ~0, so `normalize` always takes its scaling branch and
    returns a list (the old `-> FP_DTYPE` annotation was wrong).
    """
    pos = FP_DTYPE(pos)
    return normalize(pos, FP_CONST_1)

def tfm_compute(seq: list, mem: DefaultDict) -> str:
    """Compute one decoding step of the constructed Transformer.

    Only the last token of `seq` (position `pos = len(seq) - 1`) is processed. Every
    intermediate feature for that position is stored as `mem.<feature>[pos]` (one list
    per feature, indexed by position). Attention layers read the values that earlier
    calls stored for positions 0..pos, so the function is meant to be called once per
    position, in order, with the same `mem` (as `Tfm.execute` does).

    Args:
        seq: tokens so far (strings such as '^', 'AL', 'A1', '+', '@', ':', '0', '$').
        mem: attribute-style dict of per-position feature lists; updated in place.

    Returns:
        The predicted next token: the key of `mem.logits[pos]` with the largest score.
        On ties `max` keeps the first such key in the dict literal.
    """
    char = seq[-1]
    pos = len(seq) - 1
    # ---- Embedding of the current token ----
    # Constant channels and positional encodings: pos_enc_0 = 1/(pos+1),
    # pos_enc_1 = calc_pos_enc_1(pos+1), (pos_enc_2, pos_enc_3) = unit vector along (pos+1, 1).
    mem.one[pos] = FP_CONST_1
    mem.neg_one[pos] = -FP_CONST_1
    mem.pos_enc_0[pos] = FP_CONST_1 / FP_DTYPE(pos + 1)
    mem.pos_enc_1[pos] = calc_pos_enc_1(pos + 1)
    mem.pos_enc_2[pos], mem.pos_enc_3[pos] = calc_pos_enc_2_3(pos + 1)
    mem.pos_enc_0_neg[pos] = -mem.pos_enc_0[pos]
    # One-hot token flags: 1 if the current token is the given symbol, else 0.
    mem.is_start[pos] = FP_DTYPE(char == '^')
    mem.is_Al[pos] = FP_DTYPE(char == 'AL')
    mem.is_Ar[pos] = FP_DTYPE(char == 'AR')
    mem.is_A0[pos] = FP_DTYPE(char == 'A0')
    mem.is_A1[pos] = FP_DTYPE(char == 'A1')
    mem.is_Aif0[pos] = FP_DTYPE(char == 'A!')
    mem.is_Aif1[pos] = FP_DTYPE(char == 'A?')
    mem.is_Bl[pos] = FP_DTYPE(char == 'BL')
    mem.is_Br[pos] = FP_DTYPE(char == 'BR')
    mem.is_B0[pos] = FP_DTYPE(char == 'B0')
    mem.is_B1[pos] = FP_DTYPE(char == 'B1')
    mem.is_Bif0[pos] = FP_DTYPE(char == 'B!')
    mem.is_Bif1[pos] = FP_DTYPE(char == 'B?')
    mem.is_goto_prev[pos] = FP_DTYPE(char == '-')
    mem.is_goto_next[pos] = FP_DTYPE(char == '+')
    mem.is_goto_do[pos] = FP_DTYPE(char == '@')
    mem.is_halt[pos] = FP_DTYPE(char == '#')
    mem.is_delim[pos] = FP_DTYPE(char == '$')
    mem.is_goto_unsat[pos] = FP_DTYPE(char == '/')
    mem.is_goto_sat[pos] = FP_DTYPE(char == '=')
    mem.is_read_start[pos] = FP_DTYPE(char == ':')
    mem.is_read_0[pos] = FP_DTYPE(char == '0')
    mem.is_read_1[pos] = FP_DTYPE(char == '1')
    # ---- Global counters via uniform attention ----
    # With query = key = `one`, every position gets the same score, so hardmax attends
    # uniformly and the output is the mean of the value over positions 0..pos.
    # one_disc = mean(is_start); Tfm.prepare puts a single '^' first, so this is 1/(pos+1).
    mem.one_disc[pos], = attend(qry_list = [mem.one], key_list = [mem.one], val_list = [mem.is_start])
    # after_delim: ~0 until a '$' has appeared, then 1 (normalize() maps a positive scalar to 1).
    mem._after_delim[pos], = attend(qry_list = [mem.one], key_list = [mem.one], val_list = [mem.is_delim])
    mem.after_delim[pos], = normalize(mem._after_delim[pos])
    # ---- Tapes A and B ----
    # Instruction tokens only act once they appear after '$' (i.e. while executing, not in the
    # program listing). relu(flag + after_delim - 1) is the AND of two 0/1 signals.
    mem.is_A_write[pos] = relu(mem.is_A0[pos] + mem.is_A1[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.is_B_write[pos] = relu(mem.is_B0[pos] + mem.is_B1[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.A_move[pos] = relu(mem.is_Ar[pos] + mem.after_delim[pos] - FP_CONST_1) - relu(mem.is_Al[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.B_move[pos] = relu(mem.is_Br[pos] + mem.after_delim[pos] - FP_CONST_1) - relu(mem.is_Bl[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.A_write[pos] = relu(mem.is_A1[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.B_write[pos] = relu(mem.is_B1[pos] + mem.after_delim[pos] - FP_CONST_1)
    # Head offset: the mean of the +1/-1 moves, i.e. (net moves)/(pos+1). Normalizing it together
    # with one_disc = 1/(pos+1) cancels the shared factor, giving a unit vector along (offset, 1).
    mem.A_cur_disc[pos], mem.B_cur_disc[pos] = attend(qry_list = [mem.one], key_list = [mem.one], val_list = [mem.A_move, mem.B_move])
    mem.A_cur_norm[pos], mem.A_cur_one_norm[pos] = normalize(mem.A_cur_disc[pos], mem.one_disc[pos])
    mem.B_cur_norm[pos], mem.B_cur_one_norm[pos] = normalize(mem.B_cur_disc[pos], mem.one_disc[pos])
    # Read the cell under each head. The score for position k is:
    #   is_*_write[k] (a write happened at k)
    #   + cosine between the current head vector and the head vector at k
    #   + pos_enc_1[pos] * (-1/(k+1)), which grows with k so the most recent matching write wins.
    # Also returned: that write's head vector and its is_*_write flag.
    mem.A_retr[pos], mem.A_retr_cur_norm[pos], mem.A_retr_is_write[pos] = attend(qry_list = [mem.one, mem.A_cur_norm, mem.A_cur_one_norm, mem.pos_enc_1], key_list = [mem.is_A_write, mem.A_cur_norm, mem.A_cur_one_norm, mem.pos_enc_0_neg], val_list = [mem.A_write, mem.A_cur_norm, mem.is_A_write])
    mem.B_retr[pos], mem.B_retr_cur_norm[pos], mem.B_retr_is_write[pos] = attend(qry_list = [mem.one, mem.B_cur_norm, mem.B_cur_one_norm, mem.pos_enc_1], key_list = [mem.is_B_write, mem.B_cur_norm, mem.B_cur_one_norm, mem.pos_enc_0_neg], val_list = [mem.B_write, mem.B_cur_norm, mem.is_B_write])
    # *_not_found = |normalize(current head vector - retrieved head vector)|: ~0 when the
    # retrieved write is at the current cell, 1 otherwise (relu(x) + relu(-x) = |x|).
    mem._A_not_found[pos], = normalize(mem.A_cur_norm[pos] - mem.A_retr_cur_norm[pos])
    mem.A_not_found[pos] = relu(mem._A_not_found[pos]) + relu(-mem._A_not_found[pos])
    mem._B_not_found[pos], = normalize(mem.B_cur_norm[pos] - mem.B_retr_cur_norm[pos])
    mem.B_not_found[pos] = relu(mem._B_not_found[pos]) + relu(-mem._B_not_found[pos])
    # *_val: 1 iff a write at the current cell was found and it wrote 1; otherwise 0
    # (assuming clean 0/1 inputs), so never-written cells read as 0.
    mem.A_val[pos] = relu(mem.A_retr[pos] - mem.A_not_found[pos] - FP_CONST_1 + mem.A_retr_is_write[pos])
    mem.B_val[pos] = relu(mem.B_retr[pos] - mem.B_not_found[pos] - FP_CONST_1 + mem.B_retr_is_write[pos])
    # ---- Program listing indices ----
    # prog_move: +1 for '+', -1 for '-' after '$'; its mean is paired with one_disc as above.
    mem.prog_move[pos] = relu(mem.is_goto_next[pos] + mem.after_delim[pos] - FP_CONST_1) - relu(mem.is_goto_prev[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.prog_move_disc[pos], = attend(qry_list = [mem.one], key_list = [mem.one], val_list = [mem.prog_move])
    mem.prog_move_norm[pos], mem.prog_move_one_norm[pos] = normalize(mem.prog_move_disc[pos], mem.one_disc[pos])
    # is_inst: instruction tokens of the program listing (before '$'). prog_idx_disc is
    # (number of such tokens so far - 1)/(pos+1), i.e. the 0-based index of the latest listed
    # instruction, encoded in the same (value, one) unit-vector form.
    mem.is_inst[pos] = relu(mem.is_Al[pos] + mem.is_Ar[pos] + mem.is_A0[pos] + mem.is_A1[pos] + mem.is_Aif0[pos] + mem.is_Aif1[pos] + mem.is_Bl[pos] + mem.is_Br[pos] + mem.is_B0[pos] + mem.is_B1[pos] + mem.is_Bif0[pos] + mem.is_Bif1[pos] + mem.is_halt[pos] - mem.after_delim[pos])
    mem._prog_idx_disc[pos], = attend(qry_list = [mem.one], key_list = [mem.one], val_list = [mem.is_inst])
    mem.prog_idx_disc[pos] = mem._prog_idx_disc[pos] - mem.one_disc[pos]
    mem.prog_idx_norm[pos], mem.prog_idx_one_norm[pos] = normalize(mem.prog_idx_disc[pos], mem.one_disc[pos])
    # goto_idx_norm / goto_one_norm: extra key channels used by the program-token fetch below;
    # conditional-jump tokens (is_goto_cond = 1) get +1 / -1 added to them.
    mem.is_goto_cond[pos] = mem.is_Aif0[pos] + mem.is_Aif1[pos] + mem.is_Bif0[pos] + mem.is_Bif1[pos]
    mem.goto_one_disc[pos], = attend(qry_list = [mem.one], key_list = [mem.prog_idx_norm], val_list = [mem.is_goto_cond])
    mem._goto_idx_norm[pos], mem._goto_one_norm[pos] = normalize(mem.one[pos] - mem.goto_one_disc[pos], mem.goto_one_disc[pos])
    mem.goto_idx_norm[pos] = mem._goto_idx_norm[pos] + mem.is_goto_cond[pos]
    mem.goto_one_norm[pos] = mem._goto_one_norm[pos] - mem.is_goto_cond[pos]
    # ---- Program counter / jump bookkeeping ----
    # NOTE(review): only locally visible facts are documented below; the overall design of the
    # prog_rec_* / prog_tmp_* channels is not spelled out in code.
    # rec_is_start marks '^' and, after '$', executed tape instructions and '/' / '=' tokens.
    # prog_rec_one_disc = mean of is_start over those positions = 1/(their count).
    mem.rec_is_start[pos] = mem.is_start[pos] + relu(mem.is_Al[pos] + mem.is_Ar[pos] + mem.is_A0[pos] + mem.is_A1[pos] + mem.is_Bl[pos] + mem.is_Br[pos] + mem.is_B0[pos] + mem.is_B1[pos] + mem.is_goto_unsat[pos] + mem.is_goto_sat[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.prog_rec_one_disc[pos], = attend(qry_list = [mem.one], key_list = [mem.rec_is_start], val_list = [mem.is_start])
    mem.prog_rec_norm[pos], mem.prog_rec_one_norm[pos] = normalize(mem.one[pos], mem.prog_rec_one_disc[pos])
    mem.prog_rec_norm_neg[pos] = -mem.prog_rec_norm[pos]
    mem.prog_rec_one_norm_neg[pos] = -mem.prog_rec_one_norm[pos]
    mem.prog_tmp_one_disc[pos], = attend(qry_list = [mem.neg_one], key_list = [mem.prog_rec_one_disc], val_list = [mem.is_goto_sat])
    # rec_is_end marks '$' and, after '$', executed tape instructions and '@'.
    mem.rec_is_end[pos] = relu(mem.is_delim[pos] + mem.is_Al[pos] + mem.is_Ar[pos] + mem.is_A0[pos] + mem.is_A1[pos] + mem.is_Bl[pos] + mem.is_Br[pos] + mem.is_B0[pos] + mem.is_B1[pos] + mem.is_goto_do[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem._prog_tmp_norm[pos], mem._prog_tmp_one_norm[pos] = normalize(mem.one[pos], mem.prog_tmp_one_disc[pos])
    mem.prog_tmp_norm[pos] = relu(mem._prog_tmp_norm[pos] - mem.rec_is_end[pos]) + mem.rec_is_end[pos]
    mem.prog_tmp_one_norm[pos] = relu(mem._prog_tmp_one_norm[pos] - mem.rec_is_end[pos])
    mem.prog_rec_diff[pos], = attend(qry_list = [mem.prog_rec_norm, mem.prog_rec_one_norm], key_list = [mem.pos_enc_2, mem.pos_enc_3], val_list = [mem.pos_enc_1])
    mem.rec_is_goto[pos] = relu(mem.is_goto_sat[pos] + mem.is_goto_prev[pos] + mem.is_goto_next[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.prog_rec_bias[pos] = FP_CONST_3 - relu(mem.prog_rec_diff[pos] / FP_CONST_2 - FP_CONST_1 + mem.rec_is_goto[pos])
    # prog_cur_move: +1 for executed tape instructions, '+' and '/', -1 for '-' (all after '$').
    mem.prog_cur_move[pos] = relu(mem.is_Al[pos] + mem.is_Ar[pos] + mem.is_A0[pos] + mem.is_A1[pos] + mem.is_Bl[pos] + mem.is_Br[pos] + mem.is_B0[pos] + mem.is_B1[pos] + mem.is_goto_next[pos] + mem.is_goto_unsat[pos] + mem.after_delim[pos] - FP_CONST_1) - relu(mem.is_goto_prev[pos] + mem.after_delim[pos] - FP_CONST_1)
    mem.prog_cur_disc[pos], mem.prog_cur_one_disc[pos] = attend(qry_list = [mem.prog_rec_bias, mem.prog_rec_norm_neg, mem.prog_rec_one_norm_neg], key_list = [mem.one, mem.prog_rec_norm, mem.prog_rec_one_norm], val_list = [mem.prog_cur_move, mem.is_start])
    mem.prog_cur_norm[pos], mem.prog_cur_one_norm[pos] = normalize(mem.prog_cur_disc[pos], mem.prog_cur_one_disc[pos])
    # ---- Fetch the current program token ----
    # Match the program counter (prog_cur_*) and jump state (prog_tmp_*) against the listing's
    # index keys. pos_enc_1[pos] * pos_enc_0[k] = c/(k+1) favours earlier positions, and the
    # neg_one/one channel adds the same -1 to every score. Outputs are the fetched token's 0/1 flags.
    mem.token_is_Al[pos], mem.token_is_Ar[pos], mem.token_is_A0[pos], mem.token_is_A1[pos], mem._prog_is_Aif0[pos], mem._prog_is_Aif1[pos], mem.token_is_Bl[pos], mem.token_is_Br[pos], mem.token_is_B0[pos], mem.token_is_B1[pos], mem._prog_is_Bif0[pos], mem._prog_is_Bif1[pos], mem.token_is_goto_prev[pos], mem.token_is_goto_next[pos], mem.token_is_goto_do[pos], mem.token_is_halt[pos] = attend(
        qry_list = [mem.prog_cur_norm, mem.prog_cur_one_norm, mem.prog_tmp_norm, mem.prog_tmp_one_norm, mem.pos_enc_1, mem.neg_one],
        key_list = [mem.prog_idx_norm, mem.prog_idx_one_norm, mem.goto_idx_norm, mem.goto_one_norm, mem.pos_enc_0, mem.one],
        val_list = [mem.is_Al, mem.is_Ar, mem.is_A0, mem.is_A1, mem.is_Aif0, mem.is_Aif1, mem.is_Bl, mem.is_Br, mem.is_B0, mem.is_B1, mem.is_Bif0, mem.is_Bif1, mem.is_goto_prev, mem.is_goto_next, mem.is_goto_do, mem.is_halt])
    # Resolve conditional jumps against the cell under the head. 'A!'/'B!' are taken (sat) when
    # the cell reads 0, 'A?'/'B?' when it reads 1. Taken -> '=' logit, not taken -> '/' logit.
    mem.token_is_Aif0[pos] = relu(mem._prog_is_Aif0[pos] - mem.A_val[pos])
    mem.token_is_Aif1[pos] = relu(mem._prog_is_Aif1[pos] + mem.A_val[pos] - FP_CONST_1)
    mem.token_is_Bif0[pos] = relu(mem._prog_is_Bif0[pos] - mem.B_val[pos])
    mem.token_is_Bif1[pos] = relu(mem._prog_is_Bif1[pos] + mem.B_val[pos] - FP_CONST_1)
    mem.token_is_goto_sat[pos] = mem.token_is_Aif0[pos] + mem.token_is_Aif1[pos] + mem.token_is_Bif0[pos] + mem.token_is_Bif1[pos]
    mem.token_is_goto_unsat[pos] = mem._prog_is_Aif0[pos] + mem._prog_is_Aif1[pos] + mem._prog_is_Bif0[pos] + mem._prog_is_Bif1[pos] - mem.token_is_goto_sat[pos]
    # ---- Read-out after halting ----
    # Once ':' is emitted (see the logits below), the output is read from tape A. Over the
    # read-key positions (':', '0', '1'), mean(shift) / mean(is_read_start) gives the read
    # pointers 2n and 2n + 1, where n is the number of bits already emitted. Those two cells
    # are looked up exactly like the tape read above.
    mem.is_read_key[pos] = mem.is_read_start[pos] + mem.is_read_0[pos] + mem.is_read_1[pos]
    mem.read_cur0_shift[pos] = FP_CONST_2 * (mem.is_read_0[pos] + mem.is_read_1[pos])
    mem.read_cur1_shift[pos] = mem.is_read_start[pos] + FP_CONST_2 * (mem.is_read_0[pos] + mem.is_read_1[pos])
    mem.read_cur0_disc[pos], mem.read_cur1_disc[pos], mem.read_one_disc[pos] = attend(qry_list = [mem.one], key_list = [mem.is_read_key], val_list = [mem.read_cur0_shift, mem.read_cur1_shift, mem.is_read_start])
    mem.read_cur0_norm[pos], mem.read_cur0_one_norm[pos] = normalize(mem.read_cur0_disc[pos], mem.read_one_disc[pos])
    mem.read_cur1_norm[pos], mem.read_cur1_one_norm[pos] = normalize(mem.read_cur1_disc[pos], mem.read_one_disc[pos])
    mem.read_retr0[pos], mem.read_retr0_cur_norm[pos], mem.read_retr0_is_write[pos] = attend(qry_list = [mem.one, mem.read_cur0_norm, mem.read_cur0_one_norm, mem.pos_enc_1], key_list = [mem.is_A_write, mem.A_cur_norm, mem.A_cur_one_norm, mem.pos_enc_0_neg], val_list = [mem.A_write, mem.A_cur_norm, mem.is_A_write])
    mem.read_retr1[pos], mem.read_retr1_cur_norm[pos], mem.read_retr1_is_write[pos] = attend(qry_list = [mem.one, mem.read_cur1_norm, mem.read_cur1_one_norm, mem.pos_enc_1], key_list = [mem.is_A_write, mem.A_cur_norm, mem.A_cur_one_norm, mem.pos_enc_0_neg], val_list = [mem.A_write, mem.A_cur_norm, mem.is_A_write])
    mem._read_retr0_not_found[pos], = normalize(mem.read_cur0_norm[pos] - mem.read_retr0_cur_norm[pos])
    mem.read_retr0_not_found[pos] = relu(mem._read_retr0_not_found[pos]) + relu(-mem._read_retr0_not_found[pos])
    mem._read_retr1_not_found[pos], = normalize(mem.read_cur1_norm[pos] - mem.read_retr1_cur_norm[pos])
    mem.read_retr1_not_found[pos] = relu(mem._read_retr1_not_found[pos]) + relu(-mem._read_retr1_not_found[pos])
    mem.read_retr0_val[pos] = relu(mem.read_retr0[pos] - mem.read_retr0_not_found[pos] - FP_CONST_1 + mem.read_retr0_is_write[pos])
    mem.read_retr1_val[pos] = relu(mem.read_retr1[pos] - mem.read_retr1_not_found[pos] - FP_CONST_1 + mem.read_retr1_is_write[pos])
    # Emit '0' if cell 2n = 1 and cell 2n+1 = 0, '1' if both are 1, and '$' (end) if cell 2n = 0.
    # NOTE(review): the factor 2 presumably makes these beat the instruction logits, which are
    # attention averages of 0/1 flags.
    mem.next_read_0[pos] = FP_CONST_2 * relu(mem.is_read_key[pos] + mem.read_retr0_val[pos] - mem.read_retr1_val[pos] - FP_CONST_1)
    mem.next_read_1[pos] = FP_CONST_2 * relu(mem.is_read_key[pos] + mem.read_retr0_val[pos] + mem.read_retr1_val[pos] - FP_CONST_2)
    mem.next_end[pos] = FP_CONST_2 * relu(mem.is_read_key[pos] - mem.read_retr0_val[pos])
    # Next-token scores. ':' (start of read-out) is predicted when the fetched program token is '#'.
    mem.logits[pos] = {'AL': mem.token_is_Al[pos], 'AR': mem.token_is_Ar[pos], 'A0': mem.token_is_A0[pos], 'A1': mem.token_is_A1[pos], 'BL': mem.token_is_Bl[pos], 'BR': mem.token_is_Br[pos], 'B0': mem.token_is_B0[pos], 'B1': mem.token_is_B1[pos], '/': mem.token_is_goto_unsat[pos], '=': mem.token_is_goto_sat[pos], '-': mem.token_is_goto_prev[pos], '+': mem.token_is_goto_next[pos], '@': mem.token_is_goto_do[pos], ':': mem.token_is_halt[pos], '0': mem.next_read_0[pos], '1': mem.next_read_1[pos], '$': mem.next_end[pos]}
    # Argmax over the scores; max() returns the first key among equal maxima.
    return max(mem.logits[pos].items(), key = lambda char_logit: char_logit[1])[0]

class Tfm:
    """Run a program on an input with the constructed Transformer (`tfm_compute`).

    Usage: `Tfm(prog, input)`, then `prepare()`, `execute()`, `readout()`.

    Args:
        prog: program as a list of instruction strings (expanded by `format_prog`).
        input: input bit string; every character equal to '1' is written as a 1 bit,
            anything else as 0.
        verbose: if True, echo the prompt and every generated token to stdout.
    """
    def __init__(self, prog: list, input: str, verbose: bool = True):
        self.prog = prog
        self.input = input
        self.verbose = verbose

    def prepare(self):
        """Build `self.prompt`, the forced token prefix fed to the Transformer.

        Layout: '^', the program with relative jumps, '$'. For a non-empty input this is
        followed by tape-A instructions that write the input, then '=', a run of '-', '@'.
        Each input bit occupies two tape cells: a marker cell set to 1, then the bit cell
        right after it (written only when the bit is '1').
        """
        prompt = ['^'] + format_prog(self.prog, absolute = False, return_list = True) + ['$']
        bits = self.input  # local name avoids shadowing the input() builtin
        if bits:
            n_injected = len(bits) * 2
            # Move right past the 2 * len(bits) cells, then walk back leftwards writing the
            # (marker, bit) pairs from the last bit to the first.
            prompt.extend(['AR'] * n_injected)
            for char in reversed(bits):
                n_injected += 1
                prompt.append('AL')
                if char == '1':
                    n_injected += 1
                    prompt.append('A1')
                n_injected += 2
                prompt.extend(['AL', 'A1'])
            # A taken jump ('=') back by one step per injected tape instruction.
            # NOTE(review): presumably rewinds the program counter, which tfm_compute advances
            # for every executed tape instruction, to the start of the program.
            prompt.append('=')
            prompt.extend(['-'] * n_injected)
            prompt.append('@')
        self.prompt = prompt

    def execute(self, max_steps = None):
        """Feed the prompt, then generate tokens autoregressively until '$' is produced.

        Args:
            max_steps: optional upper bound on the number of generated tokens. When it is
                reached without a '$', RuntimeError is raised instead of looping forever
                on a non-halting program. None (default) means no limit.
        """
        self.seq = []
        self.mem = DefaultDict(lambda: defaultlist(FP_DTYPE))
        # All prompt tokens but the last only populate `mem`. Their predictions are discarded
        # because the next token is forced by the prompt.
        for char in self.prompt[: -1]:
            self.seq.append(char)
            tfm_compute(seq = self.seq, mem = self.mem)
        self.seq.append(self.prompt[-1])
        if self.verbose:
            print(''.join(self.prompt), end = '', flush = True)
        n_generated = 0
        while True:
            if max_steps is not None and n_generated >= max_steps:
                if self.verbose:
                    print('', flush = True)
                raise RuntimeError(f'no halting token ($) within max_steps={max_steps} generated tokens')
            self.seq.append(tfm_compute(seq = self.seq, mem = self.mem))
            n_generated += 1
            if self.verbose:
                print(self.seq[-1], end = '', flush = True)
            if self.seq[-1] == '$':
                break
        if self.verbose:
            print('', flush = True)

    def readout(self) -> str:
        """Return the tokens after the last ':' marker, excluding the final token ('$' after `execute`).

        Raises:
            ValueError: if the sequence contains no ':' (e.g. `execute` has not run).
        """
        seq = ''.join(self.seq)
        start = seq.rfind(':')
        if start < 0:
            raise ValueError("no ':' read-out marker in the sequence; did the program halt?")
        return seq[start + 1 : -1]

## Test 1: Parity

In [5]:
import random
# Reseed so this cell's random test inputs are reproducible on every re-run.
random.seed(args.seed)

def decide_parity(input):
    """Reference answer for the parity test: '1' if the digit sum of `input` is even, else '0'."""
    digit_sum = sum(int(ch) for ch in input)
    return str((digit_sum + 1) % 2)

# Parity program; the trailing comments give instruction index ranges.
PROG_PARITY = [
    'A!+4', 'AR', 'AR', 'A?-2', 'A1', 'AR', 'A1', 'A?+3', # 0--7: find the end of input and initialize the parity cell
    'AL', 'AL', 'AL', 'AL', 'AL', 'A?+2', '#', # 8--14: halt if there is no remaining cell
    'AR', 'AR', 'A0', 'AR', 'A!-11', # 15--19: if the current parity is 0, just move on
    'A0', 'AL', 'AL', 'A?+3', 'A1', 'A?+2', 'A0', 'AL', 'A?-17', # 20--28: if the current parity is 1, clean the parity cell and flip the preceding cell
]
print('prog:', format_prog(PROG_PARITY, absolute = True))

# Test inputs: the empty string plus 10 random bit strings. Odd i draws 10 random bits;
# even i draws 9 random bits and appends '0'. The draws come from the generator seeded at
# the top of this cell, so reordering these calls would change the inputs.
tester = Tester(func = decide_parity, prog = PROG_PARITY, inputs = [''] + [
    ''.join(random.choices('01', k = 10))  if i % 2 else (''.join(random.choices('01', k = 9)) + '0')
    for i in range(10)
])
# Prints one AC/WA verdict per input.
tester.test(Tfm)

prog: A!4@ARARA?1@A1ARA1A?10@ALALALALALA?15@#ARARA0ARA!8@A0ALALA?26@A1A?27@A0ALA?11@
<engine=Tfm>
[#0] input=, answer=1
^A!++++@ARARA?--@A1ARA1A?+++@ALALALALALA?++@#ARARA0ARA!-----------@A0ALALA?+++@A1A?++@A0ALA?-----------------@$=++++@A1ARA1=+++@ALALAL/:1$
(AC) input=, output=1
[#1] input=1011110010, answer=1
^A!++++@ARARA?--@A1ARA1A?+++@ALALALALALA?++@#ARARA0ARA!-----------@A0ALALA?+++@A1A?++@A0ALA?-----------------@$ARARARARARARARARARARARARARARARARARARARARALALA1ALA1ALA1ALALA1ALALA1ALA1ALA1ALA1ALA1ALA1ALA1ALA1ALA1ALALA1ALA1ALA1=--------------------------------------------------------@/ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR=--@ARAR/A1ARA1=+++@ALALAL=++@ARARA0AR/A0ALAL/A1=++@AL=-----------------@ALAL=++@ARARA0AR/A0ALAL=+++@A0AL=-----------------@ALAL=++@ARARA0AR=-----------@ALALALALAL=++@ARARA0AR=-----------@ALALALALAL=++@ARARA0AR=-----------@ALALALALAL=++@ARARA0AR/A0ALAL=+++@A0AL=-----------------@ALAL=++@ARARA0AR=-----------@ALALALALAL=++@ARARA0AR/A0ALA

## Test 2: Dyck

In [6]:
import random
# Reseed so this cell's random test inputs are reproducible on every re-run.
random.seed(args.seed)

def decide_dyck(input):
    """Reference answer for the Dyck test, with '0' as the opening and any other character as the closing symbol.

    Returns '1' if the string is balanced (depth never negative and ends at zero), else '0'.
    """
    depth = 0
    for ch in input:
        depth += 1 if ch == '0' else -1
        if depth < 0:
            return '0'
    return '1' if depth == 0 else '0'

# Dyck-language program (tape B holds the counter); trailing comments give instruction index ranges.
PROG_DYCK = [
    'A?+14', 'A0', 'AL', 'A0', 'AL', 'A?-4', 'AR', 'AR', 'A1', 'AR', 'BL', 'B?+2', 'A1', '#', # 0--13: no more char, so return 1 if the counter is 0
    'AR', 'A?+4', 'B1', 'BR', 'B!+3', 'BL', 'B!+4', 'B0', 'AR', 'B!-23', # 14--23: increment the counter for (, or decrement it for )
    'AL', 'AR', 'AR', 'A?-2', 'A0', 'AL', 'A0', 'AL', 'A?-4', 'AR', 'AR', 'A1', '#', # 24--36: invalid, so clean up and return 0
]
print('prog:', format_prog(PROG_DYCK, absolute = True))

def rand_parens(size: int, valid: bool = True, par: str = '()', sep = ''):
    """Generate a random bracket string over the two characters of `par` (open, close).

    Args:
        size: requested length, rounded up to the next even number (an odd-length string
            can never be balanced). For ``valid=False`` the result has at least 2 characters.
        valid: if True, return a balanced (Dyck) word; otherwise return a string that is
            guaranteed to be unbalanced.
        par: two-character string: opening symbol, then closing symbol.
        sep: separator inserted between the top-level balanced components of a valid word.
            If None, the list of components is returned instead of a string (only allowed
            with ``valid=True``).

    Raises:
        ValueError: if ``valid`` is False and ``sep`` is not ''.
    """
    if not valid:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if sep != '':
            raise ValueError("rand_parens: sep must be '' when valid=False")
        total = size + size % 2
        inner = max(total - 2, 0)  # room left for the two balanced halves
        left = inner // 4 * 2      # both halves must have even length
        right = inner - left
        # A balanced prefix followed by close-then-open drives the depth below zero, so the
        # result can never be balanced.
        return rand_parens(size = left, valid = True, par = par) + par[1] + par[0] + rand_parens(size = right, valid = True, par = par)
    size += size % 2
    seq = []
    cnt = 0  # current nesting depth
    for i in range(size):
        # Must open at depth 0. Otherwise open at random while there is still room to close
        # every open bracket before the end.
        if cnt == 0 or (cnt < size - i and random.randint(0, 1)):
            if cnt == 0:
                seq.append([])  # start a new top-level component
            seq[-1].append(par[0])
            cnt += 1
        else:
            seq[-1].append(par[1])
            cnt -= 1
    seq = [''.join(subseq) for subseq in seq]
    if sep is not None:
        seq = sep.join(seq)
    return seq

# Hand-picked short cases, then 10 random strings over '0' (open) / '1' (close):
# even i -> unbalanced (valid=False), odd i -> balanced (valid=True). The random draws
# depend on the seed set at the top of this cell.
tester = Tester(func = decide_dyck, prog = PROG_DYCK, inputs = ['', '0', '1', '00', '11', '010', '101', '0010110', '0010111'] + [
    rand_parens(size = 10, valid = bool(i % 2), par = '01') for i in range(10)
])
# Prints one AC/WA verdict per input.
tester.test(Tfm)

prog: A?14@A0ALA0ALA?1@ARARA1ARBLB?13@A1#ARA?19@B1BRB!21@BLB!24@B0ARB!0@ALARARA?25@A0ALA0ALA?28@ARARA1#
<engine=Tfm>
[#0] input=, answer=1
^A?++++++++++++++@A0ALA0ALA?----@ARARA1ARBLB?++@A1#ARA?++++@B1BRB!+++@BLB!++++@B0ARB!-----------------------@ALARARA?--@A0ALA0ALA?----@ARARA1#$/A0ALA0AL/ARARA1ARBL/A1:1$
(AC) input=, output=1
[#1] input=0, answer=0
^A?++++++++++++++@A0ALA0ALA?----@ARARA1ARBLB?++@A1#ARA?++++@B1BRB!+++@BLB!++++@B0ARB!-----------------------@ALARARA?--@A0ALA0ALA?----@ARARA1#$ARARALALA1=-----@=++++++++++++++@AR/B1BR=+++@B0AR=-----------------------@/A0ALA0AL=----@A0ALA0AL/ARARA1ARBL=++@:0$
(AC) input=0, output=0
[#2] input=1, answer=0
^A?++++++++++++++@A0ALA0ALA?----@ARARA1ARBLB?++@A1#ARA?++++@B1BRB!+++@BLB!++++@B0ARB!-----------------------@ALARARA?--@A0ALA0ALA?----@ARARA1#$ARARALA1ALA1=------@=++++++++++++++@AR=++++@BL=++++@ALARAR/A0ALA0AL=----@A0ALA0AL/ARARA1:0$
(AC) input=1, output=0
[#3] input=00, answer=0
^A?++++++++++++++@A0ALA0ALA?----@ARARA1ARBLB?++@A1#ARA?++++