In [None]:
import numpy as np
import os
import torch


# generate all sequences by composition
# Every binary string of length n is enumerated and bucketed by how many '1'
# characters it contains. When allow_symmetry is False, only the
# lexicographically smaller member of each sequence/mirror pair is kept.
n = 20
allow_symmetry = True
all_seq_by_frac = {k: set() for k in range(n+1)}
limit = 2**n
for i in range(limit):
    sequence = bin(i)[2:].zfill(n)
    mirror_sequence = sequence[::-1]
    if sequence <= mirror_sequence or allow_symmetry:
        all_seq_by_frac[sequence.count('1')].add(sequence)

# create a master list of all possible sequences
# Sets of strings iterate in an order that depends on per-process string hash
# randomization, so each bucket is sorted to make the ordering identical on
# every kernel restart.
all_sequences = []
for k in sorted(all_seq_by_frac):
    all_sequences.extend(sorted(all_seq_by_frac[k]))
print(f'generated {len(all_sequences)} sequences')

# Candidate pool: the sequences containing exactly 8 '1' characters. It is
# sorted so that positions in this array (which seeded index draws refer to)
# always denote the same sequences.
possible_sequences = np.array(sorted(all_seq_by_frac[8]))

# load the model ensemble
# Ten TorchScript checkpoints (fold-00 ... fold-09) are read from
# models/gru-opt-cv10-sym onto the CPU and switched to eval mode.
model_path = os.path.join('models', 'gru-opt-cv10-sym')
model_ensemble = []
for i in range(10):
    fold_file = os.path.join(model_path, f'fold-{i:02d}-scripted.pt')
    model = torch.jit.load(fold_file, map_location='cpu')
    model.eval()
    model_ensemble.append(model)

In [None]:
from sklearn import ensemble, model_selection


class EnsembleRFModel(object):
    """Ensemble of regressors, each trained on a different K-fold training split.

    ``n_models`` regressors are built with ``constructor(**model_params)``.
    ``fit`` trains the i-th regressor on the training portion of the i-th
    K-fold split; ``predict`` averages the members' predictions and can also
    return their spread.

    Parameters
    ----------
    n_models : int
        Number of ensemble members; also used as the number of K-fold splits.
    constructor : callable
        Factory called with ``**model_params`` to create each member.
    model_params : dict or None
        Keyword arguments for ``constructor``. ``None`` means no arguments.
    symmetrize : bool
        If True, training rows are augmented with their reversed copies and
        predictions are averaged over each row and its reverse.
    """

    def __init__(self, n_models=3, constructor=ensemble.RandomForestRegressor, model_params=None, symmetrize=False):
        self.n_models = n_models
        # Copy the mapping so instances never alias a shared default or a
        # dict the caller may mutate later.
        self.model_params = dict(model_params) if model_params is not None else {}
        self.symmetrize = symmetrize
        self.models = [constructor(**self.model_params) for _ in range(self.n_models)]

    @staticmethod
    def _reversed_rows(x):
        """Return a list holding each row of ``x`` in reversed order."""
        return [row[::-1] for row in x]

    def fit(self, x, y):
        """Fit every member on its own K-fold training subset; returns ``self``.

        ``x`` and ``y`` are converted with ``np.asarray`` so that plain lists
        can be indexed by the fold index arrays.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        cv = model_selection.KFold(n_splits=self.n_models, shuffle=True, random_state=0)
        # Only the training indices of each split are used; the validation
        # part is simply left out for that member.
        for member, (fold_train_idx, _fold_val_idx) in zip(self.models, cv.split(x)):
            this_x = x[fold_train_idx]
            this_y = y[fold_train_idx]
            if self.symmetrize:
                # A row and its reverse share the same target value.
                this_x = np.vstack([this_x, self._reversed_rows(this_x)])
                this_y = np.hstack([this_y, this_y])
            member.fit(this_x, this_y)
        return self

    def predict(self, x, return_std=False):
        """Return the ensemble-mean prediction for each row of ``x``.

        With ``return_std=True`` a ``(mean, std)`` tuple is returned, where
        ``std`` is ``np.std`` taken across ensemble members.
        """
        all_y = np.vstack([member.predict(x).reshape(1, -1) for member in self.models])
        if self.symmetrize:
            x_reversed = self._reversed_rows(x)
            all_yr = np.vstack([member.predict(x_reversed).reshape(1, -1) for member in self.models])
            # Average each member's forward and reversed predictions.
            all_y = 0.5 * (all_y + all_yr)

        if return_std:
            return np.mean(all_y, axis=0), np.std(all_y, axis=0)
        return np.mean(all_y, axis=0)

In [None]:
import itertools
import tqdm.notebook


def swap_monomers(seq):
    """Return ``seq`` with its 'A' and 'B' characters exchanged.

    A 'C' that is already present in ``seq`` comes out as 'B'; every other
    character is left untouched.
    """
    return seq.translate(str.maketrans('ABC', 'BAB'))


def make_base(degree, verbose=False):
    """Build the A/B patterns of length 1..``degree``, one per reversal pair.

    Parameters
    ----------
    degree : int
        Maximum pattern length. Values below 1 give an empty list.
    verbose : bool
        If True, print the number of patterns and the patterns themselves.

    Returns
    -------
    list of str
        Patterns in sorted order. Of a pattern and its reverse, only the one
        that sorts first is kept (palindromes are kept once).
    """
    # construct every A/B string of each length up to `degree`; product()
    # yields each string exactly once, so no de-duplication is needed
    candidates = sorted(
        ''.join(monomers)
        for length in range(1, degree + 1)
        for monomers in itertools.product('AB', repeat=length)
    )
    # Walking the sorted list and skipping entries whose reverse was already
    # kept is the same as keeping exactly the strings that do not sort after
    # their own reverse.
    base = [b for b in candidates if b <= b[::-1]]
    if verbose:
        print(f'Finding {len(base)} patterns:', base)

    return base


def featurize(chain_sequences, base, symmetric=False, verbose=False):
    """Count how often each pattern in ``base`` occurs in every chain.

    Parameters
    ----------
    chain_sequences : sequence
        Chains to featurize. A chain whose string form contains 'A' is used
        as given. Otherwise it is translated element by element: the
        character '0' or a number equal to 0 becomes 'A', anything else 'B'.
    base : sequence of str
        Patterns to count; one output column per pattern.
    symmetric : bool
        If True, counts taken on the reversed chain are added as well.
    verbose : bool
        If True, show a notebook progress bar while iterating.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(chain_sequences), len(base))``. Counts come
        from ``str.count``, i.e. non-overlapping occurrences.
    """
    x = np.zeros([len(chain_sequences), len(base)])
    if verbose:
        pbar = tqdm.notebook.tqdm(enumerate(chain_sequences),
                                  total=len(chain_sequences))
    else:
        # No progress bar requested: iterate directly so the function does
        # not need the notebook widget machinery at all.
        pbar = enumerate(chain_sequences)
    for i, chain in pbar:
        if 'A' in str(chain):
            seq = chain
        else:
            # Characters must be compared with '0', not the integer 0;
            # otherwise every '0'/'1' string would turn into all 'B'.
            seq = ''.join(
                'A' if ((m == '0') if isinstance(m, str) else (m == 0)) else 'B'
                for m in chain
            )
        x[i] = np.array([seq.count(b) for b in base])
        if symmetric:
            x[i] += np.array([seq[::-1].count(b) for b in base])

    return x


# Pattern vocabulary: A/B patterns of up to 4 monomers.
base = make_base(degree=4)
# possible_sequences holds '0'/'1' bit strings, so they are translated to the
# 'A'/'B' alphabet before pattern counting (patterns in `base` are written
# with 'A'/'B', and the bit strings contain neither letter).
possible_tokens = featurize(
    [bits.replace('0', 'A').replace('1', 'B') for bits in possible_sequences],
    base,
    symmetric=True,
)

In [None]:
# Placeholder prompt text handed to message_utils when building messages.
batch_prompt = "Here\n<result>\nNote"  # an empty/fake prompt to facilitate the message_utils
# Monomer letter -> bit value.
AB = dict(A=0, B=1)
# Integer matrix with one row per candidate sequence and one column per position.
all_arrs = np.array([list(map(int, bits)) for bits in possible_sequences])

In [None]:
import json
import message_utils, model_utils
from target_defs import archetype_predictions, archetype_sequences
import time
import tqdm


# Active-learning driver: for every archetype target, run 5 independently
# seeded rollouts in which an EnsembleRFModel surrogate proposes batches of
# sequences, the proposals are scored with model_utils.evaluate_sequences,
# and the resulting message history is written to a JSON log file.
arch_morphs = list(archetype_predictions.keys())
for morph in arch_morphs:
    
    target = archetype_predictions[morph]
    
    # Reset for each morph, so after the outer loop only the last morph's
    # rollouts remain in this list.
    rollouts = []
    
    n_batch = 5    # sequences proposed per round
    n_init = 5     # randomly drawn sequences used for the first fit
    n_total = 10   # enters the round count below
    xi = 0         # weight of the ensemble std in the ranking score
    
    use_tokens = False   # True: pattern-count features; False: raw 0/1 positions
    symmetrize = False   # only forwarded to the surrogate when use_tokens is False
    use_seed = False     # True: replace the first initial sequence by the archetype
    
    # Run settings stored alongside the messages in every log file.
    params = {'n_batch': n_batch,
            'n_init': n_init,
            'xi': xi,
            'use_tokens': use_tokens,
            'symmetrize': symmetrize,
            'target': target.tolist(),
            'morph': morph,
            'use_seed': use_seed}
    
    # Taken once per morph: all rollouts of a morph share this timestamp and
    # are told apart by the rollout index appended to the file name.
    start_time = int(time.time())
    for ridx in range(5):
        fake_payload = [{"role": "user", "content": [{"type": "text", "text": "N/A"}]}]
        
        # rf = ensemble.RandomForestRegressor(random_state=0)
        rf = EnsembleRFModel(model_params={'random_state': ridx, 'n_estimators': 24, 'max_depth': 6}, symmetrize=(symmetrize and not use_tokens))
        
        # Draw the initial design without replacement, seeded by the rollout index.
        rng = np.random.RandomState(ridx)
        init_idx = rng.choice(np.arange(len(possible_sequences)), n_init, replace=False)
        init_bitstr = [possible_sequences[it] for it in init_idx]
        if use_seed:
            init_bitstr[0] = archetype_sequences[morph]
        init_sequences = [it.replace('0', 'A').replace('1', 'B') for it in init_bitstr]
        
        if use_tokens:
            x_so_far = featurize(init_sequences, base, symmetric=True)
        else:
            x_so_far = np.array([[int(AB[x]) for x in s] for s in init_sequences])
        
        # Each line of `out` is split on ':' and its second field parsed as a
        # float -- presumably one "<sequence>: <score>" line per sequence;
        # confirm against model_utils.evaluate_sequences.
        out = model_utils.evaluate_sequences(init_sequences, target, model_ensemble)
        fake_payload.append(message_utils.build_user_message(batch_prompt, out))
        y_so_far = np.array([float(it.split(':')[1]) for it in out.split('\n')])
        
        _ = rf.fit(x_so_far, y_so_far)
        
        # Mask of candidates that have not been evaluated yet.
        is_avail = np.ones(len(possible_sequences)).astype(bool)
        is_avail[init_idx] = False
        
        # NOTE(review): `//` binds tighter than `-`, so this evaluates as
        # n_total - (n_init // n_batch), i.e. 9 rounds with the values above.
        # Confirm this (rather than (n_total - n_init) // n_batch) is intended.
        for _ in tqdm.tqdm(range(n_total-n_init//n_batch)):
            
            if use_tokens:
                mu, sigma = rf.predict(possible_tokens, return_std=True)
            else:
                mu, sigma = rf.predict(all_arrs, return_std=True)
    
            # Candidate indices ordered by ascending mu - xi * sigma, so the
            # lowest scores are picked first.
            o = np.hstack([np.argsort(mu - xi * sigma)])
            next_sequences = []
            k = 0
            # Take the first n_batch candidates in that order that are still
            # available, marking each as used.
            while len(next_sequences) < n_batch:
                if is_avail[o[k]]:
                    next_bitstr = possible_sequences[o[k]]
                    seq = next_bitstr.replace('0', 'A').replace('1', 'B')
                    next_sequences.append(seq)
                    is_avail[o[k]] = False
                k += 1
            
            if use_tokens:
                next_x = featurize(next_sequences, base, symmetric=True)
            else:
                next_x = np.array([[int(AB[x]) for x in s] for s in next_sequences])
            x_so_far = np.vstack([x_so_far, next_x])
            
            out = model_utils.evaluate_sequences(next_sequences, target, model_ensemble)
            fake_payload.append(message_utils.build_user_message(batch_prompt, out))
            y_so_far = np.hstack([y_so_far, np.array([float(it.split(':')[1]) for it in out.split('\n')])])
            
            # Refit the surrogate on everything evaluated so far.
            _ = rf.fit(x_so_far, y_so_far)
        
        # Persist this rollout. open(..., 'w') does not create directories,
        # so the `logs` directory has to exist beforehand.
        buffer = {'params': params, 'messages': fake_payload}
        suffix = str(ridx)
        logfile = f'logs/active-learning-{start_time}{suffix}.json'
        with open(logfile, 'w') as fid:
            json.dump(buffer, fid)

        rollouts.append(fake_payload)