In [6]:
import os
import torch


# Load the model ensemble: ten TorchScript checkpoints (fold-00 ... fold-09)
# from models/gru-opt-cv10-sym, mapped onto the CPU and switched to eval mode.
model_path = os.path.join('models', 'gru-opt-cv10-sym')
fold_files = [os.path.join(model_path, f'fold-{fold:02d}-scripted.pt') for fold in range(10)]

model_ensemble = []
for fold_file in fold_files:
    model = torch.jit.load(fold_file, map_location='cpu')
    model.eval()  # inference mode for every ensemble member
    model_ensemble.append(model)

In [7]:
import numpy as np

# Generate all binary sequences of length n, grouped by composition, i.e. by
# the number of '1' characters they contain.
n = 20
allow_symmetry = False  # False: keep only one member of each (sequence, reversed sequence) pair
limit = 2**n

# Zero-padded binary strings of ascending integers are already in ascending
# lexicographic order, so collecting them in lists yields a sorted and fully
# deterministic order. Iterating over sets instead would make the order depend
# on string-hash randomisation and therefore change from kernel to kernel.
ordered_by_frac = {k: [] for k in range(n + 1)}
bit_format = f'0{n}b'
for i in range(limit):
    sequence = format(i, bit_format)
    # A sequence and its mirror image are treated as equivalent; keep the
    # lexicographically smaller one (palindromes equal their mirror and stay).
    if allow_symmetry or sequence <= sequence[::-1]:
        ordered_by_frac[sequence.count('1')].append(sequence)

# Same mapping as before: composition -> set of sequences with that many '1's.
all_seq_by_frac = {k: set(v) for k, v in ordered_by_frac.items()}

# create a master list of all possible sequences (ordered by composition,
# then lexicographically, so indices into it are reproducible)
all_sequences = []
for k in range(n + 1):
    all_sequences += ordered_by_frac[k]
print(f'generated {len(all_sequences)} sequences')

# Sequences with exactly 8 ones, in sorted order so that index-based sampling
# with a seeded RNG always picks the same sequences.
possible_sequences = np.array(ordered_by_frac[8])

generated 524800 sequences


In [21]:
from deap import algorithms, base, creator, tools
import json
import message_utils, model_utils
import numpy as np
import random
from target_defs import archetype_predictions, archetype_sequences
import time


# Run a DEAP genetic algorithm for every archetype morphology and store the
# sequences it evaluated, generation by generation, as a fake chat log.
# Relies on `os`, `model_ensemble` and `possible_sequences` from earlier cells.
batch_prompt = "Here\n<result>\nNote"  # an empty/fake prompt to facilitate the message_utils
use_feasibility = False  # True: enforce the composition constraint via tools.DeltaPenalty

seq_length = 20  # number of 0/1 genes per individual
target_ones = 8  # number of 1s a sequence must contain to be feasible


def bits_to_ab(bits):
    """Translate an iterable of 0/1 genes into an 'A'/'B' string (0 -> 'A', 1 -> 'B')."""
    return ''.join(str(bit) for bit in bits).replace('0', 'A').replace('1', 'B')


# DEAP stores these classes as attributes of the global `creator` module and
# warns when a name is created twice. Create them once, outside the loops, and
# guard against re-running the cell in the same kernel.
if not hasattr(creator, "FitnessMin"):
    creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
if not hasattr(creator, "Individual"):
    creator.create("Individual", list, fitness=creator.FitnessMin)

arch_morphs = list(archetype_predictions.keys())
for morph in arch_morphs:

    print(f'computing results for {morph}...')
    target = archetype_predictions[morph]

    # one timestamp per morphology; the run index is appended to it in the file name
    start_time = int(time.time())
    for ridx in range(5):

        # (individual, score, is_feasible) for every evalSeq call, in call order
        individuals = []

        delta = 0.0

        def feasible(individual):
            """Feasibility function for the individual. Returns True if feasible False
            otherwise."""
            return bool(np.sum(individual) == target_ones)

        def distance(individual):
            """A distance function to the feasibility region."""
            return (np.sum(individual) - float(target_ones))**2

        def evalSeq(individual):
            """Score one individual against the current `target`.

            Records the call in `individuals` and returns the score as a
            1-tuple, the form DEAP expects for a single-objective fitness.
            """
            sequence = bits_to_ab(individual)
            out = model_utils.evaluate_sequences([sequence], target, model_ensemble)

            # `out` is parsed as newline-separated "<label>:<number>" records
            # (format inferred from this parsing); one sequence was submitted,
            # so only the first record is read.
            score = float(out.split('\n')[0].split(':')[1])
            if not use_feasibility:  # only add this to score if not using feasibility
                score += delta + np.abs(np.sum(individual) - float(target_ones))

            individuals.append((individual, score, feasible(individual)))
            return score,

        n_init = 5
        n_batch = 5
        use_seed = True

        # initial population: n_init distinct sequences drawn reproducibly per run
        rng = np.random.RandomState(ridx)
        init_idx = rng.choice(np.arange(len(possible_sequences)), n_init, replace=False)
        init_bitstr = [possible_sequences[it] for it in init_idx]
        if use_seed:
            # replace the first draw with the archetype's own sequence
            init_bitstr[0] = archetype_sequences[morph].replace('A', '0').replace('B', '1')
        init_pop = [[int(x) for x in it] for it in init_bitstr]

        # DEAP's operators draw from the stdlib `random` module
        random.seed(ridx)

        toolbox = base.Toolbox()
        toolbox.register("attr_bool", random.randint, 0, 1)
        toolbox.register("individual", tools.initRepeat, creator.Individual,
            toolbox.attr_bool, seq_length)
        toolbox.register("population", tools.initRepeat, list, toolbox.individual)
        toolbox.register("evaluate", evalSeq)

        if use_feasibility:
            toolbox.decorate("evaluate", tools.DeltaPenalty(feasible, delta, distance))

        toolbox.register("mate", tools.cxTwoPoint)
        toolbox.register("mutate", tools.mutFlipBit, indpb=0.10)
        toolbox.register("select", tools.selTournament, tournsize=3)

        # random individuals fill whatever part of the batch the initial set does not
        pop = toolbox.population(n=(n_batch - n_init))
        pop += [creator.Individual(p) for p in init_pop]
        if use_feasibility:
            not_feasible_init = [not feasible(it) for it in pop]
        else:
            not_feasible_init = 0

        hof = tools.HallOfFame(1000)
        stats = tools.Statistics(lambda ind: ind.fitness.values)
        stats.register("avg", np.mean)
        stats.register("std", np.std)
        stats.register("min", np.min)
        stats.register("max", np.max)

        cxpb, mutpb, ngen = 0.5, 0.2, 75
        pop, log = algorithms.eaSimple(pop, toolbox, cxpb=cxpb, mutpb=mutpb, ngen=ngen,
                                       stats=stats, halloffame=hof, verbose=True)

        print(sum([it['nevals'] for it in log]))

        # Split the flat `individuals` record back into generations using the
        # per-generation evaluation counts from the logbook.
        # NOTE(review): with use_feasibility=True only the initial generation
        # is corrected for infeasible individuals; if DeltaPenalty also skips
        # evalSeq for infeasible offspring, later slices will drift - confirm
        # before enabling that option.
        ind_by_gen = []
        offset = 0
        for gen_idx, gen in enumerate(log):
            n_evals = gen['nevals']
            if gen_idx == 0:
                n_evals -= int(np.sum(not_feasible_init))
            ind_by_gen.append(individuals[offset:(offset + n_evals)])
            offset += n_evals

        params = {'n_batch': n_batch,
                  'n_init': n_init,
                  'target': target.tolist(),
                  'morph': morph,
                  'use_seed': use_seed,
                  'cxpb': cxpb,
                  'mutpb': mutpb,
                  'ngen': ngen,
                  'delta': delta}

        # Build one user message per generation, holding only sequences that
        # have not appeared before (neither in earlier generations nor earlier
        # in the same generation).
        fake_payload = [{"role": "user", "content": [{"type": "text", "text": "N/A"}]}]
        old_sequences = []
        seen_sequences = set()
        for gen in ind_by_gen:
            sequences = []
            for it in gen:
                candidate = bits_to_ab(it[0])
                if candidate not in seen_sequences:
                    seen_sequences.add(candidate)
                    sequences.append(candidate)
            if len(sequences) == 0:
                continue
            out = model_utils.evaluate_sequences(sequences, target, model_ensemble)
            fake_payload.append(message_utils.build_user_message(batch_prompt, out))
            old_sequences += sequences

        if len(old_sequences) < 50:
            raise RuntimeError('Failed to generate 50 unique sequences')

        param_hash = message_utils.hash_dict(params)
        buffer = {'params': params, 'messages': fake_payload}
        suffix = str(ridx)
        seed_hash = 'seeded' if use_seed else 'unseeded'
        logdir = f'data/llm-logs/{seed_hash}/evolutionary/{morph}/'
        logfile = os.path.join(logdir, f'deap-{param_hash}-{start_time}{suffix}.json')
        # makedirs creates missing parent directories too; plain mkdir fails
        # on a fresh checkout where data/llm-logs/... does not exist yet
        os.makedirs(logdir, exist_ok=True)
        with open(logfile, 'w') as fid:
            json.dump(buffer, fid)

computing results for wormlike micelle...
gen	nevals	avg   	std    	min	max  
0  	5     	7.3178	6.14828	0  	17.97
1  	3     	4.7972	5.20781	0.189	14.894
2  	4     	8.9352	8.96595	0.189	22.637
3  	3     	1.411 	2.444  	0.189	6.299 
4  	3     	1.3958	2.4136 	0.189	6.223 
5  	1     	1.076 	1.774  	0.189	4.624 
6  	2     	1.1764	1.25086	0.189	3.163 
7  	2     	0.189 	0      	0.189	0.189 
8  	5     	2.2372	2.53086	0.189	5.84  
9  	4     	2.3022	2.81015	0.189	7.203 
10 	3     	1.423 	1.54875	0.189	3.809 
11 	3     	2.5506	4.7232 	0.189	11.997
12 	5     	5.1836	4.84464	0.189	11.406
13 	4     	1.298 	1.49416	0.189	3.946 
14 	2     	0.5466	0.7152 	0.189	1.977 
15 	1     	1.8186	3.2592 	0.189	8.337 
16 	2     	0.189 	0      	0.189	0.189 
17 	4     	0.214 	0.05   	0.189	0.314 
18 	4     	1.595 	2.812  	0.189	7.219 
19 	2     	0.189 	0      	0.189	0.189 
20 	2     	0.5306	0.6832 	0.189	1.897 
21 	4     	0.189 	0      	0.189	0.189 
22 	4     	0.8102	1.2424 	0.189	3.295 
23 	2     	0.189 	0      	0.