In [1]:
# generate all sequences by composition
n = 20                 # sequence length (number of bits / monomers)
n_ones = 8             # composition of the candidate pool: number of '1' bits
allow_symmetry = True  # False -> keep only one sequence of each mirror-image pair


def bucket_sequences_by_ones(length, keep_mirrors=True):
    """Enumerate every `length`-bit string, bucketed by its number of '1's.

    Parameters
    ----------
    length : int
        Number of bits per sequence; all 2**length strings are visited.
    keep_mirrors : bool
        If False, a string is kept only when it is lexicographically <= its
        reverse, so each sequence/mirror pair is represented once
        (palindromes are always kept).

    Returns
    -------
    dict[int, list[str]]
        Maps k in 0..length to the zero-padded bitstrings with exactly k ones,
        in ascending order (for equal-length zero-padded binary strings,
        numeric order equals lexicographic order).
    """
    buckets = {k: [] for k in range(length + 1)}
    for i in range(2 ** length):
        sequence = format(i, f'0{length}b')  # same as bin(i)[2:].zfill(length)
        if keep_mirrors or sequence <= sequence[::-1]:
            buckets[sequence.count('1')].append(sequence)
    return buckets


seq_lists_by_count = bucket_sequences_by_ones(n, allow_symmetry)
# keyed by the number of '1's (not a fraction, despite the name); values stay sets
all_seq_by_frac = {k: set(v) for k, v in seq_lists_by_count.items()}

# create a master list of all possible sequences; built from the ordered lists
# so its order does not depend on Python's per-process string-hash randomization
all_sequences = [seq for k in range(n + 1) for seq in seq_lists_by_count[k]]
print(f'generated {len(all_sequences)} sequences')
possible_sequences = all_seq_by_frac[n_ones]
print(f'choosing from {len(possible_sequences)} sequences')

generated 1048576 sequences
choosing from 125970 sequences


In [2]:
import os
import torch

# Load the cross-validation ensemble: one TorchScript file per fold index 0..9,
# each mapped onto the CPU and switched to eval mode before being collected.
model_path = os.path.join('models', 'gru-opt-cv10-sym')
model_ensemble = []
for fold_idx in range(10):
    checkpoint = os.path.join(model_path, f'fold-{fold_idx:02d}-scripted.pt')
    fold_model = torch.jit.load(checkpoint, map_location='cpu')
    fold_model.eval()
    model_ensemble.append(fold_model)

In [3]:
import json
import os
import time

import numpy as np

import model_utils, message_utils
from target_defs import archetype_predictions

# Random-search baseline: for each archetype target, score randomly drawn
# candidate sequences with the model ensemble and log them as messages.
n_batch = 5    # sequences evaluated per iteration
n_iter = 10    # iterations (batches) per run
n_runs = 1     # independent random runs per archetype
seed = 0       # RNG seed; recorded in each log's params for reproducibility
log_dir = 'logs'
batch_prompt = "Here\n<result>\nNote"  # an empty/fake prompt to facilitate the message_utils

# Sort the pool once: iterating a set of strings follows Python's per-process
# hash randomization, so sampling from list(set) is not reproducible even
# with a fixed seed.
candidate_pool = np.array(sorted(possible_sequences))
rng = np.random.default_rng(seed)
os.makedirs(log_dir, exist_ok=True)  # open(..., 'w') fails if the directory is missing

arch_morphs = list(archetype_predictions.keys())
for morph in arch_morphs:
    target = archetype_predictions[morph]
    params = {
        'n_batch': n_batch,
        'n_iter': n_iter,
        'seed': seed,
        'target': target.tolist(),
        'morph': morph,
    }

    start_time = time.time()
    for ridx in range(n_runs):
        fake_payload = [{"role": "user", "content": [{"type": "text", "text": "N/A"}]}]

        # draw n_batch * n_iter distinct sequences and map bits 0/1 onto monomers A/B
        these_bitstr = rng.choice(candidate_pool, n_batch * n_iter, replace=False)
        these_sequences = [it.replace('0', 'A').replace('1', 'B') for it in these_bitstr]
        seq_by_iter = np.array_split(these_sequences, n_iter)

        for batch in seq_by_iter:
            out = model_utils.evaluate_sequences(batch, target, model_ensemble)
            fake_payload.append(message_utils.build_user_message(batch_prompt, out))

        buffer = {'params': params, 'messages': fake_payload}
        # separate the run index from the float timestamp so it cannot be
        # misread as extra decimal digits of the time
        logfile = os.path.join(log_dir, f'random-{start_time}-{ridx}.json')
        with open(logfile, 'w') as fid:
            json.dump(buffer, fid)