In [2]:
%matplotlib inline
import os 
from collections import deque
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import editdistance
import sys
import RNA
from typing import Dict, List, Tuple

import torch
from torch import nn
from tqdm import tqdm_notebook as tqdm
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

# import path 
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.sequence_utils import translate_one_hot_to_string,generate_random_mutant
from models.Theoretical_models import *
from models.Noise_wrapper import *
from exploration_strategies.CE import *
from utils.landscape_utils import *
from models.RNA_landscapes import *
from models.Multi_dimensional_model import *

from segment_tree import MinSegmentTree, SumSegmentTree

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
RAA="UGCA" #alphabet
length=40
wt=generate_random_sequences(length,1,alphabet=RAA)[0]
print(wt)
#make a simple folding landscape starting at wt
landscape1=RNA_landscape(wt)
noise_alpha=1
virtual_per_measure_ratio=15
temperature=0.1
# there are multiple abstract "noise models" you can use, or you can try to train your own model, using skM
noisy_landscape_CE=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
noisy_landscape_RL=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
noisy_landscape_RL_multiple=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
#noisy_landscape=Gaussian_noise_landscape(base_landscape,noise_alpha=0.15)
#noisy_landscape=DF_noise_landscape(base_landscape,noise_alpha=0.5)
batch_size = 1000
initial_genotypes=list(set([wt]+[generate_random_mutant(wt,0.05,RAA) for i in range(batch_size)]))[:batch_size]

CAGGCAGGUUGCACCGUAUGCGACGUUAAUGUAAGGACCA


In [4]:
import numpy as np
import random
import bisect
from utils.sequence_utils import translate_string_to_one_hot, translate_one_hot_to_string

def renormalize_moves(one_hot_input, rewards_output):
    """ensures that staying in place gives no reward"""
    zero_current_state = (one_hot_input - 1) * (-1)
    return np.multiply(rewards_output, zero_current_state)

def walk_away_renormalize_moves(one_hot_input, one_hot_wt, rewards_output):
    """ensures that moving toward wt is also not useful"""
    zero_current_state=(one_hot_input-1)*-1
    zero_wt=((one_hot_wt-1)*-1)
    zero_conservative_moves=np.multiply(zero_wt,zero_current_state)
    return np.multiply(rewards_output,zero_conservative_moves)

def get_all_singles_fitness(model,sequence,alphabet):
    prob_singles=np.zeros((len(alphabet),len(sequence)))
    for i in range(len(sequence)):
        for j in range(len(alphabet)):
            putative_seq=sequence[:i]+alphabet[j]+sequence[i+1:]
           # print (putative_seq)
            prob_singles[j][i]=model.get_fitness(putative_seq)
    return prob_singles

def get_all_mutants(sequence):
    mutants = []
    for i in range(sequence.shape[0]):
        for j in range(sequence.shape[1]):
            putative_seq = sequence.copy()
            putative_seq[:, j] = 0
            putative_seq[i, j] = 1
            mutants.append(putative_seq)
    return np.array(mutants)

def sample_greedy(matrix):
    i,j=matrix.shape
    max_arg=np.argmax(matrix)
    y=max_arg%j
    x=int(max_arg/j)
    output=np.zeros((i,j))
    output[x][y]=matrix[x][y]
    return output

def sample_multi_greedy(matrix):
    n = 5 # the number of base positions to greedily change
    max_args = np.argpartition(matrix.flatten(), -n)[-n:]
    i,j=matrix.shape
    output=np.zeros((i,j))
    for max_arg in max_args:
        y=max_arg%j
        x=int(max_arg/j)
        output[x][y]=matrix[x][y]
    return output

def sample_random(matrix):
    i,j=matrix.shape
    non_zero_moves=np.nonzero(matrix)
   # print (non_zero_moves)
    k=len(non_zero_moves)
    l=len(non_zero_moves[0])
    if k!=0 and l!=0:
        rand_arg=random.choice([[non_zero_moves[alph][pos] for alph in range(k)] for pos in range(l)])
    else:
        rand_arg=[random.randint(0,i-1),random.randint(0,j-1)]
    #print (rand_arg)
    y=rand_arg[1]
    x=rand_arg[0]
    output=np.zeros((i,j))
    output[x][y] = 1
    return output   

def action_to_scalar(matrix):
    matrix = matrix.ravel()
    for i in range(len(matrix)):
        if matrix[i] != 0:
            return i
    
def construct_mutant_from_sample(pwm_sample, one_hot_base):
    one_hot = np.zeros(one_hot_base.shape)
    one_hot += one_hot_base
    nonzero = np.nonzero(pwm_sample)
    nonzero = list(zip(nonzero[0], nonzero[1]))
    for nz in nonzero: # this can be problematic for non-positive fitnesses
        i, j = nz
        one_hot[:,j]=0
        one_hot[i,j]=1
    return one_hot

def best_predicted_new_gen(actor, genotypes, alphabet, pop_size):
    mutants = get_all_mutants(genotypes)
    one_hot_mutants = np.array([translate_string_to_one_hot(mutant, alphabet) for mutant in mutants])
    torch_one_hot_mutants = torch.from_numpy(np.expand_dims(one_hot_mutants, axis=0)).float()
    predictions = actor(torch_one_hot_mutants)
    predictions = predictions.detach().numpy()
    best_pred_ind = predictions.argsort()[-pop_size:]
    return mutants[best_pred_ind]

def make_one_hot_train_test(genotypes, model, alphabet):
    genotypes_one_hot = np.array([translate_string_to_one_hot(genotype, alphabet) for genotype in genotypes])
    genotype_fitnesses = []
    for genotype in genotypes:
        genotype_fitnesses.append(model.get_fitness(genotype))
    genotype_fitnesses = np.array(genotype_fitnesses)

    return genotypes_one_hot, genotype_fitnesses

## Test ability of model to predict

In [5]:
import keras 
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv1D, BatchNormalization, Flatten
from keras.regularizers import l1, l2
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso

model = Sequential()
model.add(Dense(160, activation='relu', input_shape=(160 + 160,)))
model.add(Dense(40, activation='relu', input_shape=(160,)))
model.add(Dense(1, activation='relu', input_shape=(40,)))

print(model.summary())

def construct_mutant_from_sample(pwm_sample, one_hot_base):
    one_hot = np.zeros(one_hot_base.shape)
    one_hot += one_hot_base
    nonzero = np.nonzero(pwm_sample)
    nonzero = list(zip(nonzero[0], nonzero[1]))
    for nz in nonzero: # this can be problematic for non-positive fitnesses
        i, j = nz
        one_hot[:,j]=0
        one_hot[i,j]=1
    return one_hot

def make_one_hot_train_test(genotypes, model, alphabet):
    genotypes_one_hot = np.array([translate_string_to_one_hot(genotype, alphabet) for genotype in genotypes])
    genotype_fitnesses = []
    for genotype in genotypes:
        genotype_fitnesses.append(model.get_fitness(genotype))
    genotype_fitnesses = np.array(genotype_fitnesses)

    return genotypes_one_hot, genotype_fitnesses

def build_genotype_data(samples):
    genotypes = [generate_random_mutant(wt, 0.05, RAA) for i in range(samples)]
    genotypes = np.array([translate_string_to_one_hot(mutant, RAA) for mutant in genotypes])
    actions = np.zeros((samples, 160))
    actions_ind = np.random.choice(160, size=samples)
    for i in range(samples):
        actions[i, actions_ind[i]] = 1
    actions = actions.reshape((samples, 4, 40))
    genotypes_after_action = [construct_mutant_from_sample(actions[i], genotypes[i]) 
                              for i in range(samples)]
    genotypes_after_action = [translate_one_hot_to_string(genotype, RAA) 
                              for genotype in genotypes_after_action]
    fitnesses = [noisy_landscape_RL.get_fitness(genotype) for genotype in genotypes_after_action]
    actions = actions.reshape((-1, 160))
    genotypes_one_hot = genotypes.reshape((-1, 160))
    combined = np.hstack((genotypes_one_hot, actions))
    return combined, fitnesses

train_data, train_labels = build_genotype_data(10000)
test_data, test_labels = build_genotype_data(5000)
model.compile(loss=keras.losses.mean_squared_error,
              optimizer=keras.optimizers.Adam(),
              metrics=['mse'])
model.fit(train_data, train_labels,
          batch_size=128,
          epochs=10,
          verbose=1,
          validation_data=(test_data, test_labels))
preds = model.predict(test_data)
print('NN MSE', mean_squared_error(preds, test_labels))

# lasso regression benchmark 
lasso = Lasso()
lasso.fit(train_data, train_labels)
preds = lasso.predict(test_data)
print('Lasso MSE', mean_squared_error(preds, test_labels))

# random forest benchmark
rf = RandomForestRegressor()
rf.fit(train_data, train_labels)
preds = rf.predict(test_data)
print('Random Forest MSE', mean_squared_error(preds, test_labels))

Using TensorFlow backend.
W0714 19:53:02.173346 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0714 19:53:02.296686 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0714 19:53:02.358793 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.



_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 160)               51360     
_________________________________________________________________
dense_2 (Dense)              (None, 40)                6440      
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 41        
Total params: 57,841
Trainable params: 57,841
Non-trainable params: 0
_________________________________________________________________
None


W0714 19:53:08.868296 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0714 19:53:10.626376 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.

W0714 19:53:10.695435 139657751041856 deprecation_wrapper.py:119] From /home/alexander/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:973: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.



Train on 10000 samples, validate on 5000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
NN MSE 0.006337005695032697
Lasso MSE 0.00047121443167561677




Random Forest MSE 0.00018418122808724896


In [6]:
from collections import deque
# from utils.RL_utils import *
from utils.sequence_utils import translate_one_hot_to_string,generate_random_mutant
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Q_Network(nn.Module):
    def __init__(self, sequence_len, alphabet_len):
        super(Q_Network, self).__init__()
        self.sequence_len = sequence_len
        self.alphabet_len = alphabet_len
        self.linear1 = nn.Linear(2 * alphabet_len * sequence_len, alphabet_len * sequence_len)
        self.bn1 = nn.BatchNorm1d(alphabet_len * sequence_len)
        self.linear2 = nn.Linear(alphabet_len * sequence_len, sequence_len)
        self.bn2 = nn.BatchNorm1d(sequence_len)
        self.linear3 = nn.Linear(sequence_len, 1)
    
    def forward(self, x):
        x = self.bn1(F.relu(self.linear1(x)))
        x = self.bn2(F.relu(self.linear2(x)))
        x = F.relu(self.linear3(x))
        return x
    
def build_q_network(sequence_len, alphabet_len, device):
    model = Q_Network(sequence_len, alphabet_len).to(device)
    print(model)
    return model


class RL_agent_DQN():
    '''
    Based off https://colab.research.google.com/drive/1NsbSPn6jOcaJB_mp9TmkgQX7UrRIrTi0
    '''
    def __init__(self, start_sequence, alphabet, gamma=0.9, 
                 memory_size=1000, batch_size=1000, experiment_batch_size=1000,
                 device = "cpu", noise_alpha=1):
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.state = translate_string_to_one_hot(start_sequence, self.alphabet)
        self.seq_size = len(start_sequence)
        self.q_network = build_q_network(self.seq_size, len(self.alphabet), device)
        self.q_network.eval()
        self.start_sequence = translate_string_to_one_hot(start_sequence,self.alphabet)
        self.memory_size = memory_size
        self.gamma = gamma
        self.batch_size = batch_size
        self.experiment_batch_size = experiment_batch_size
        self.memory = []
        self.seen_sequences = []
        self.landscape = Noise_wrapper(RNA_landscape(start_sequence), 
                                       noise_alpha=noise_alpha, always_costly=True)
        self.best_fitness = 0

    def reset_position(self,sequence):
        self.state=translate_string_to_one_hot(sequence,self.alphabet)

    def get_position(self):
        return translate_one_hot_to_string(self.state,self.alphabet)

    def translate_pwm_to_sequence(self,input_seq_one_hot,output_pwm):
        diff=output_pwm-input_seq_one_hot
        most_likely=np.argmax(diff,axis=0)
        out_seq=""
        for m in most_likely:
            out_seq+=self.alphabet[m]
        return out_seq
    
    def sample(self):
        indices = np.random.choice(len(self.memory), self.batch_size)
        rewards, actions, states, next_states = zip(*[self.memory[ind] for ind in indices])
        return np.array(rewards), np.array(actions), np.array(states), np.array(next_states) 
    
    def calculate_next_q_values(self, state_v):
        dim = self.alphabet_size * self.seq_size
        states_repeated = state_v.repeat(1, dim).reshape(-1, dim)
        actions_repeated = torch.FloatTensor(np.identity(dim)).repeat(len(state_v), 1)
        next_states_actions = torch.cat((states_repeated, actions_repeated), 1)
        next_states_values = self.q_network(next_states_actions)
        next_states_values = next_states_values.reshape(len(state_v), -1)
        
        return next_states_values
    
    def q_network_loss(self, batch, device="cpu"):
        """
        Calculate MSE between actual state action values,
        and expected state action values from DQN
        """
        rewards, actions, states, next_states = batch
        
        state_action_v = torch.FloatTensor(np.hstack((states, actions)))
        rewards_v = torch.FloatTensor(rewards)
        next_states_v = torch.FloatTensor(next_states)
    
        state_action_values = self.q_network(state_action_v).view(-1)
        next_state_values = self.calculate_next_q_values(next_states_v)
        next_state_values = next_state_values.max(1)[0].detach()
        expected_state_action_values = next_state_values * self.gamma + rewards_v
        
        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def train_actor(self, train_epochs=10):
        total_loss = 0.
        # train Q network on new samples 
        optimizer = optim.Adam(self.q_network.parameters())
        for epoch in range(train_epochs):
            batch = self.sample()
            optimizer.zero_grad()
            loss = self.q_network_loss(batch)
            loss.backward()
            clip_grad_norm_(self.q_network.parameters(), 1.0, norm_type=1)
            optimizer.step()
            total_loss += loss.item()
        return (total_loss / train_epochs)

    def pick_action(self, epsilon):
        state_tensor = torch.FloatTensor([self.state.ravel()])
        prediction = self.calculate_next_q_values(state_tensor).detach().numpy()
        prediction = prediction.reshape((len(self.alphabet), self.seq_size))
        # make action
        moves = renormalize_moves(self.state, prediction)
        p = random.random()
        action = sample_random(moves) if p < epsilon else sample_greedy(moves)
        # get next state (mutant)
        mutant = construct_mutant_from_sample(action, self.state)
        mutant_string = translate_one_hot_to_string(mutant, self.alphabet)
        self.state = mutant

        return action, mutant
    
    def run_RL(self, generations=10, train_epochs=10, epsilon_min=0.1):
        while self.landscape.cost < self.experiment_batch_size*generations:
            eps = max(epsilon_min, (0.5 - self.landscape.cost / (self.experiment_batch_size * generations)))
            b = 0
            new = []
            while(b < self.experiment_batch_size):
                state = self.state.copy() 
                action, new_state = agent.pick_action(eps)
                new_state_string = translate_one_hot_to_string(new_state, self.alphabet)
                reward = self.landscape.get_fitness(new_state_string)
                if not new_state_string in self.landscape.measured_sequences:
                    if reward > self.best_fitness:
                        state_tensor = torch.FloatTensor([self.state.ravel()])
                        prediction = self.calculate_next_q_values(state_tensor).detach().numpy()
                        prediction = prediction.reshape((len(self.alphabet), self.seq_size))
                        print(prediction)
                    self.best_fitness = max(self.best_fitness, reward)
                    self.memory.append((reward, action.ravel(), state.ravel(), new_state.ravel()))
                    b += 1
            
            self.memory = sorted(self.memory, key = lambda x: x[0])
            avg_loss = agent.train_actor(train_epochs)
            print (self.landscape.cost, self.memory[-1][0], avg_loss)
        
        best_string = translate_one_hot_to_string(
            self.memory[-1][2].reshape(self.alphabet_size, self.seq_size), self.alphabet)
        return (best_string, self.memory[-1][0])

In [92]:
# benchmark
batch_size = 1000
generations = 20
agent = RL_agent_DQN(wt, alphabet=RAA, gamma=0, memory_size=10000, device=device)
agent.run_RL(generations=generations, train_epochs=40, epsilon_min=0.2)

Q_Network(
  (linear1): Linear(in_features=320, out_features=160, bias=True)
  (bn1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=160, out_features=40, bias=True)
  (bn2): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=40, out_features=1, bias=True)
)
[[0.09849628 0.0980573  0.10156537 0.10032272 0.09564329 0.09486055
  0.09921077 0.09721259 0.09928502 0.10053869 0.09756131 0.09804366
  0.09946192 0.09666634 0.10001573 0.09647773 0.09691897 0.09808229
  0.09518256 0.09561601 0.09828803 0.09979536 0.09995455 0.09739996
  0.10088108 0.09770073 0.09435686 0.09720631 0.09816606 0.09897108
  0.09884693 0.09624495 0.09786817 0.09708573 0.09798492 0.09966846
  0.0986013  0.09680902 0.09925673 0.0969498 ]
 [0.09828604 0.09926087 0.0973086  0.09588106 0.09817992 0.09603836
  0.10054604 0.09753614 0.09649698 0.09743056 0.09737143 0.0984416
  0.0955975  0.09508903 

[[0.08599368 0.08575141 0.08544821 0.08705402 0.0833724  0.08486178
  0.08582024 0.08529055 0.08463743 0.08469973 0.0853029  0.08531225
  0.08566035 0.08361115 0.08661414 0.085941   0.08540709 0.08599079
  0.08512361 0.08589385 0.08626834 0.08844391 0.08745922 0.08513679
  0.08781096 0.08519876 0.08374922 0.08453084 0.08532017 0.08556687
  0.08442335 0.08920017 0.08376519 0.08740889 0.08549121 0.08793779
  0.08633094 0.08557513 0.08884195 0.08441208]
 [0.08568334 0.08493818 0.08612244 0.08635601 0.08743615 0.08403195
  0.08550511 0.08894008 0.08530397 0.08448776 0.08529653 0.08510256
  0.08536816 0.08540368 0.08535302 0.08656752 0.08808789 0.0864891
  0.08715577 0.08680648 0.0861918  0.08513808 0.08446181 0.08593494
  0.08484889 0.08649675 0.08451816 0.08618023 0.08439279 0.08367777
  0.08559471 0.08591068 0.085557   0.08634594 0.08527707 0.08437828
  0.08407574 0.08735018 0.08347726 0.08609459]
 [0.08863545 0.08640711 0.08612807 0.08818457 0.08818594 0.0873469
  0.08708733 0.08600853 

3088 0.2517647013944738 0.00038697067720931954
[[0.12687457 0.131329   0.14519061 0.13223119 0.12954338 0.13953322
  0.13226731 0.13645865 0.13563749 0.13777176 0.13345781 0.13599974
  0.13091075 0.14282592 0.12315944 0.12373605 0.12631385 0.14382392
  0.12803912 0.12507921 0.12844358 0.14274748 0.13832818 0.14924866
  0.14045253 0.12864271 0.12740864 0.13930438 0.1390669  0.12827355
  0.13834722 0.13300559 0.12598252 0.13190019 0.14211878 0.13681534
  0.12957522 0.1335283  0.13698827 0.12473497]
 [0.13494413 0.1456582  0.13973314 0.14221698 0.13641301 0.14316109
  0.1477069  0.15049164 0.14033997 0.15841115 0.13835104 0.14328767
  0.13157596 0.13593782 0.1300586  0.13524228 0.14503644 0.13729185
  0.14449178 0.13556051 0.13952917 0.14215624 0.13389587 0.15065517
  0.14343591 0.14274319 0.1371997  0.13785094 0.13590246 0.14490753
  0.13549446 0.14507395 0.13854015 0.13726164 0.13686983 0.14200974
  0.12661636 0.13199444 0.12882327 0.13101262]
 [0.13891788 0.12318632 0.13845183 0.141106

4263 0.3752941131591797 0.00037022692995378745
4716 0.3752941131591797 0.0003261293164541712
[[0.25082734 0.257146   0.2711579  0.25002658 0.24404663 0.26409522
  0.25316125 0.24544078 0.25855103 0.24647987 0.23735908 0.25718126
  0.23425066 0.24793226 0.23292619 0.23850802 0.2575135  0.2683945
  0.2498194  0.25039768 0.24829134 0.25194907 0.26135918 0.25108963
  0.251179   0.24192491 0.24832076 0.25642005 0.26444438 0.25326264
  0.2627028  0.23892656 0.23869944 0.25672618 0.2505945  0.2589484
  0.22207892 0.24159858 0.25688466 0.24697623]
 [0.26548842 0.2775159  0.27298    0.26047978 0.27127278 0.26019564
  0.27698356 0.28498274 0.26962924 0.2853014  0.25393426 0.27150768
  0.26269278 0.26301154 0.24352765 0.25596884 0.27821314 0.25970596
  0.27403858 0.26219186 0.2720495  0.2615572  0.25172338 0.29212454
  0.2739051  0.27967325 0.26612943 0.27410147 0.24963239 0.2532789
  0.2653725  0.26578024 0.2659829  0.25245324 0.2758364  0.27597907
  0.2525725  0.23724943 0.2392785  0.25240773]


5210 0.3964705972110524 0.0003304491281596711
5620 0.3964705972110524 0.0003337086676765466
5964 0.3964705972110524 0.0002847883901267778
6353 0.3964705972110524 0.0003538203163770959
6703 0.3964705972110524 0.0003608990031352732
[[0.3219269  0.32225755 0.30837664 0.32019693 0.318805   0.32391638
  0.32696456 0.3044304  0.32313788 0.31703869 0.30431628 0.3237266
  0.2972765  0.34047565 0.31244928 0.30468613 0.33037487 0.3218754
  0.32086012 0.33504978 0.32401785 0.31560677 0.31649527 0.32016847
  0.32588905 0.31318092 0.3237852  0.3232655  0.32122785 0.32479525
  0.32482773 0.32428625 0.30941856 0.33037093 0.3192626  0.32884032
  0.3071769  0.31734252 0.3351801  0.33181834]
 [0.33079213 0.36675546 0.36231074 0.35121733 0.30963883 0.3221323
  0.3731946  0.34933954 0.34106404 0.35689822 0.33267665 0.3591297
  0.33258015 0.31571272 0.32716614 0.34246212 0.35720715 0.3380249
  0.34251663 0.34497988 0.35463473 0.33401555 0.33248627 0.36091265
  0.3410445  0.3584005  0.31349432 0.33754238 0.

7382 0.43647057028377756 0.00038914132164791226
[[0.33266976 0.31631804 0.29887655 0.3206292  0.3062063  0.3250556
  0.3303041  0.31602278 0.32347652 0.32680985 0.3125226  0.32848716
  0.30697355 0.33742085 0.32378018 0.30596572 0.3362371  0.33519807
  0.33292863 0.3351625  0.33804613 0.32358778 0.3304815  0.31199813
  0.33247885 0.30686745 0.32430473 0.3256581  0.32575607 0.33117497
  0.3311129  0.32747716 0.3264378  0.34409466 0.31256896 0.3276029
  0.3003737  0.32273012 0.32729062 0.3369061 ]
 [0.33764932 0.38055348 0.37775487 0.3679203  0.30687785 0.32919434
  0.38158894 0.3523625  0.35132572 0.3591417  0.34513772 0.37766537
  0.34427655 0.32080132 0.333638   0.35748506 0.3701707  0.34458232
  0.37092242 0.36002937 0.36177906 0.34947845 0.3319211  0.37087137
  0.34957042 0.35873845 0.3284326  0.35947245 0.33612287 0.3550872
  0.36259627 0.33490887 0.32691985 0.33535716 0.38315323 0.36912075
  0.31806564 0.32833675 0.31920537 0.34801328]
 [0.33937436 0.31553707 0.29795676 0.32680815

8753 0.5047059003044577 0.0003954143154260237
9010 0.5047059003044577 0.0003248438868467929
9310 0.5047059003044577 0.00029600235320685896
[[0.45810083 0.4045487  0.3846604  0.39064983 0.43439457 0.4253162
  0.41735193 0.40245146 0.40744722 0.4240185  0.42706782 0.40886083
  0.4117223  0.4355599  0.41795048 0.43005654 0.4366906  0.44195503
  0.45521644 0.4567748  0.4456064  0.44080815 0.44692823 0.40880036
  0.4239996  0.40608233 0.4350908  0.43723252 0.4397817  0.44456962
  0.43576017 0.44585946 0.43564165 0.44425324 0.41084036 0.39209035
  0.4366552  0.4303943  0.44376054 0.45167968]
 [0.4739953  0.49093106 0.507313   0.50827974 0.40517727 0.4155694
  0.5084779  0.48130682 0.47758898 0.44278404 0.449099   0.5020021
  0.4202037  0.4471557  0.4066666  0.44135553 0.49229982 0.46791473
  0.48081878 0.46687308 0.46337947 0.45406792 0.4408954  0.4919589
  0.48486003 0.49611512 0.441717   0.49452937 0.45081988 0.47362316
  0.46597457 0.4116796  0.42595825 0.43386987 0.51470214 0.49430788
  

10877 0.5952940996955423 0.0004474141951504862
11184 0.5952940996955423 0.00032706506754038857
11462 0.5952940996955423 0.0004417754891619552
11702 0.5952940996955423 0.0003059738581214333
11930 0.5952940996955423 0.0002621670813823584
12119 0.5952940996955423 0.000293583155144006
12249 0.5952940996955423 0.00037506581575144084
12356 0.5952940996955423 0.00029172879221732727
12468 0.5952940996955423 0.00034513502650952433
12606 0.5952940996955423 0.00033839273328339914
12799 0.5952940996955423 0.000316920924524311
12942 0.5952940996955423 0.0002597338005216443
13087 0.5952940996955423 0.00025845355921774173
13242 0.5952940996955423 0.0002603638145956211
13348 0.5952940996955423 0.00022966175893088803
13513 0.5952940996955423 0.0003111951937171398
13684 0.5952940996955423 0.00022840354740765177
13813 0.5952940996955423 0.00014381993332790444
13969 0.5952940996955423 0.00028932570912729714
14084 0.5952940996955423 0.0002269593325763708
14208 0.5952940996955423 0.00018212925715488382
1433

IndexError: tuple index out of range

In [89]:
batch_size = 1000
generations = 20
agent = RL_agent_DQN(wt, alphabet=RAA, gamma=0.8, memory_size=10000, device=device)
agent.run_RL(generations=generations, train_epochs=40)

Q_Network(
  (linear1): Linear(in_features=320, out_features=160, bias=True)
  (bn1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=160, out_features=40, bias=True)
  (bn2): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=40, out_features=1, bias=True)
)
[[0.10886748 0.1082293  0.10756049 0.11009314 0.10837949 0.10887124
  0.10647732 0.10948335 0.10864244 0.10753663 0.10887988 0.10718796
  0.1111569  0.107638   0.10908968 0.10950546 0.10755908 0.10818593
  0.10794933 0.10994601 0.1091218  0.11104104 0.10967084 0.11119866
  0.10673211 0.10675288 0.11108978 0.10861164 0.10942063 0.10810453
  0.1074435  0.10767917 0.10711201 0.1080021  0.10951816 0.11034293
  0.10781036 0.10862392 0.11116369 0.1081534 ]
 [0.10908154 0.11072846 0.1102531  0.10831727 0.10673152 0.10987413
  0.10925047 0.10944591 0.10806155 0.10778305 0.10817356 0.1091143
  0.11066602 0.10635109 

[[0.11340502 0.11348581 0.11434987 0.11294103 0.11511695 0.11529327
  0.1156567  0.11461997 0.11411908 0.10990737 0.11392678 0.11232558
  0.11595234 0.11343956 0.11422257 0.1128429  0.11198209 0.11477475
  0.1143647  0.11492611 0.11670383 0.11639124 0.11363545 0.11423774
  0.11398153 0.11497438 0.11729731 0.11179688 0.11351047 0.11440536
  0.11315665 0.11374545 0.11305211 0.11368693 0.11426721 0.11591186
  0.11345571 0.11408823 0.11598316 0.11607098]
 [0.11551539 0.11412752 0.11666215 0.11477638 0.11531121 0.11463407
  0.11485516 0.11490667 0.11176835 0.1152928  0.11208297 0.11074246
  0.11817511 0.11380734 0.11192022 0.11487076 0.11527473 0.11388115
  0.11546792 0.11405826 0.11668836 0.11492103 0.11443335 0.11267626
  0.11487715 0.11503382 0.11160627 0.11267374 0.115697   0.11171456
  0.11313654 0.11404003 0.11364217 0.11289286 0.1128913  0.11426701
  0.11525901 0.11354117 0.11254192 0.11200604]
 [0.11520814 0.11430109 0.11293518 0.1128066  0.11462593 0.11439352
  0.11614902 0.1142773

907 0.21529410867130055 0.0018026919555268251
[[0.65793884 0.62797964 0.6518051  0.64086884 0.6722742  0.6371562
  0.648887   0.64269936 0.6432517  0.64395326 0.64836925 0.63551974
  0.6388084  0.64912003 0.64218813 0.6395246  0.64490134 0.6601647
  0.65848386 0.6567338  0.65792006 0.65735763 0.6406536  0.6434386
  0.66783416 0.6389699  0.6501233  0.66932744 0.65161586 0.63847786
  0.63504183 0.6499111  0.6499929  0.64803165 0.6589226  0.6512177
  0.6603044  0.6422274  0.6703437  0.64307255]
 [0.6730438  0.674522   0.64707655 0.6474208  0.6629781  0.64640707
  0.63623625 0.6300835  0.6435849  0.6702617  0.64121896 0.66575485
  0.66578877 0.6447683  0.65718526 0.6455508  0.66061527 0.66832227
  0.6794742  0.6643744  0.6602896  0.66582495 0.6689794  0.67508143
  0.6684359  0.66974354 0.65257853 0.6647485  0.6712675  0.6695724
  0.64113873 0.6555214  0.6412674  0.65042615 0.6308719  0.6600257
  0.6719901  0.640152   0.6521617  0.6604679 ]
 [0.63713557 0.6634259  0.6541968  0.66514874 0.65

[[1.2677    1.2869525 1.2897506 1.3097279 1.3093905 1.2599891 1.3338655
  1.280807  1.2776287 1.2911614 1.2965305 1.259557  1.2733746 1.3074458
  1.3043227 1.3111697 1.25302   1.3267637 1.325682  1.337512  1.3181124
  1.2947356 1.2730861 1.3066877 1.3132188 1.2965603 1.280112  1.3250375
  1.2885034 1.3049271 1.3150635 1.2912624 1.338009  1.2894957 1.3049693
  1.3225665 1.3082445 1.3054101 1.3389473 1.3051653]
 [1.3526284 1.3396944 1.3007162 1.3475454 1.3606286 1.3383148 1.3076415
  1.3018446 1.336555  1.397331  1.2941239 1.2821473 1.3341041 1.3038718
  1.3341062 1.3076451 1.3454884 1.3603058 1.3566867 1.327408  1.3397362
  1.326723  1.363493  1.3257273 1.324801  1.3427479 1.3301241 1.3128283
  1.3426973 1.3488438 1.2720351 1.3479507 1.2983924 1.3378172 1.31986
  1.3176775 1.3396041 1.2873137 1.3109138 1.2999259]
 [1.2997155 1.301337  1.3124433 1.3097546 1.3156197 1.3236731 1.3440373
  1.3383462 1.3466563 1.2807771 1.3504677 1.3674942 1.3079865 1.3700684
  1.3080996 1.3545644 1.3175118 

4119 0.4105882532456342 0.0007677281100768596
[[1.709213  1.6905217 1.6954904 1.6875913 1.7080753 1.6644454 1.7402642
  1.6926138 1.6892381 1.677932  1.6740673 1.6928661 1.713814  1.689424
  1.6761832 1.6791444 1.6877725 1.7343228 1.7145207 1.6956358 1.7065821
  1.6990123 1.6499698 1.7367713 1.7552357 1.7231197 1.6995275 1.7623465
  1.6750538 1.6927965 1.7329524 1.6703322 1.6920276 1.712641  1.7103662
  1.6947467 1.7409434 1.6810126 1.7241771 1.7386508]
 [1.774164  1.7573888 1.736675  1.7624843 1.7779756 1.7466583 1.7422068
  1.7207413 1.725004  1.7882233 1.7298019 1.6866777 1.7411587 1.6981056
  1.7016807 1.7448926 1.7430739 1.7493336 1.736171  1.756105  1.7613299
  1.7409105 1.790082  1.7727885 1.7742167 1.7737317 1.7704327 1.7475986
  1.7627933 1.7570915 1.7450836 1.7693496 1.7786222 1.6968286 1.7298038
  1.7424121 1.7343767 1.7470784 1.7528849 1.7153549]
 [1.7240994 1.7639108 1.7698951 1.7116253 1.7106206 1.7391913 1.7457135
  1.7628341 1.7615659 1.6907713 1.7748408 1.7786388 1.725

5778 0.4823529411764706 0.0007364778881310486
[[2.0326555 2.0210357 2.0328743 2.0221148 1.9723017 1.9694409 2.0561593
  1.9995694 2.0354097 1.98598   2.038094  1.9614973 2.0721729 2.0383854
  2.0463002 2.0184977 2.0112793 2.0454667 2.0199733 2.0159142 2.021616
  2.0165892 1.9776754 2.0018318 2.0134873 1.9634085 1.9781632 2.0567772
  2.0073035 2.0079143 2.042093  2.021     1.9834759 2.0166621 2.0323896
  2.0080664 2.0796242 2.0690494 2.0405273 2.0775194]
 [2.09553   2.068849  2.0461638 2.0995686 2.1132376 2.04935   2.0624943
  2.042703  2.0536053 2.1120348 2.0623918 1.9975989 2.0597062 2.0660312
  1.9942806 2.0478086 2.0531452 2.0844636 2.0640433 2.1014228 2.1270647
  2.0997806 2.120849  2.1162632 2.0473833 2.1175458 2.1228724 2.126288
  2.114874  2.0490942 2.078122  2.095052  2.1370082 1.9969962 2.028467
  2.099844  2.0573697 2.0595398 2.0957801 2.0811317]
 [2.0117733 2.0956278 2.1225789 2.035212  2.0108473 2.1309578 2.1282248
  2.1243553 2.0887365 2.0174513 2.0937102 2.1439521 2.07808

9985 0.5141176560345818 0.00028460795838327615
10083 0.5141176560345818 0.00027961137457168663
10191 0.5141176560345818 0.00020519223598967074
10312 0.5141176560345818 0.0002119232432960416
10406 0.5141176560345818 0.00027992292434646514
10497 0.5141176560345818 0.00015760004389449024
10613 0.5141176560345818 0.00019957026142947143
10725 0.5141176560345818 0.0001977562213141937
10833 0.5141176560345818 0.00020469144510570914
10911 0.5141176560345818 0.0002044718819888658
11010 0.5141176560345818 0.0002513963392630103
11091 0.5141176560345818 0.0001718883146168082
11214 0.5141176560345818 0.00022614170775341335
11299 0.5141176560345818 0.0001785913176718168
11414 0.5141176560345818 0.00020722248464153382
11488 0.5141176560345818 0.00019931404385715724
11579 0.5141176560345818 0.0001271977391297696
11681 0.5141176560345818 0.0002044340294560243
11779 0.5141176560345818 0.0001532514772407012
11873 0.5141176560345818 0.00015921536469249987
11953 0.5141176560345818 0.00011898616194230271
12

IndexError: tuple index out of range

In [8]:
batch_size = 1000
generations = 20
agent = RL_agent_DQN(wt, alphabet=RAA, gamma=0.5, memory_size=10000, device=device)
agent.run_RL(generations=generations, train_epochs=40)

Q_Network(
  (linear1): Linear(in_features=320, out_features=160, bias=True)
  (bn1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=160, out_features=40, bias=True)
  (bn2): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=40, out_features=1, bias=True)
)
[[0.14735202 0.14921135 0.14877278 0.15073323 0.14842862 0.14758505
  0.14793728 0.14747332 0.1491582  0.14850947 0.1487302  0.14969522
  0.1469557  0.14886816 0.14839628 0.1474535  0.14841598 0.14668922
  0.14813066 0.14698865 0.15097365 0.14952713 0.14764981 0.14940028
  0.14866854 0.14341514 0.14928868 0.14674093 0.14796102 0.15126503
  0.15080267 0.1477365  0.14508921 0.14854127 0.14991939 0.14751215
  0.15200609 0.14912866 0.14955488 0.1495286 ]
 [0.14785463 0.15096432 0.14933215 0.14617406 0.14969236 0.14941216
  0.14844482 0.14862289 0.148649   0.14767459 0.14917839 0.14840712
  0.15053472 0.1448522 

[[0.13757142 0.14019836 0.1381578  0.13760681 0.13525166 0.13974859
  0.1378876  0.13689135 0.1378916  0.1374065  0.13650586 0.13884488
  0.1384271  0.1364093  0.13721292 0.13748392 0.13768815 0.1367936
  0.13740417 0.13682377 0.13746418 0.13825989 0.13676442 0.13843091
  0.13673504 0.13604048 0.1366764  0.13636966 0.13864619 0.13836864
  0.13965277 0.13699995 0.13794181 0.13689154 0.13890156 0.13605364
  0.1405822  0.13927618 0.14033705 0.13750368]
 [0.1361753  0.13987507 0.13493547 0.1376552  0.13706592 0.1383407
  0.13784103 0.14057492 0.13788438 0.13671312 0.13903274 0.13731161
  0.13886134 0.13621654 0.13979822 0.13651945 0.1385349  0.13631868
  0.13662615 0.13746107 0.1384034  0.13829473 0.13723221 0.13706084
  0.1394131  0.13956372 0.13866565 0.13853396 0.13769966 0.1371633
  0.13690448 0.13859086 0.13604276 0.13788056 0.13747397 0.1401807
  0.1377308  0.13952737 0.13871764 0.1369336 ]
 [0.13529056 0.13859792 0.13823102 0.13538122 0.13806091 0.13895175
  0.14182341 0.13740681 0.

[[0.23167658 0.20907512 0.21244209 0.21001823 0.2193685  0.21279316
  0.20715508 0.21825464 0.2039446  0.21694675 0.2100511  0.21709628
  0.22989674 0.22275212 0.2141869  0.22398175 0.20165098 0.20313787
  0.22799155 0.2012151  0.20721143 0.20926788 0.20786509 0.20567
  0.20711568 0.2164681  0.22343662 0.22485952 0.20045194 0.22369082
  0.23606119 0.20075612 0.20858786 0.21156654 0.21434768 0.21703465
  0.22131196 0.20555136 0.22516653 0.22356054]
 [0.21520741 0.22385448 0.2140106  0.21132877 0.20876215 0.21577898
  0.2247061  0.2231542  0.22256136 0.22786583 0.22759753 0.22828254
  0.21802846 0.21608704 0.23211162 0.21301599 0.23410024 0.22143207
  0.22472885 0.21684697 0.21866113 0.19953227 0.21468309 0.20754892
  0.23687041 0.2257664  0.22138584 0.21995671 0.21566737 0.2276322
  0.22486983 0.22629607 0.23290926 0.21473512 0.23175535 0.21869028
  0.2183412  0.22735037 0.20587616 0.21317147]
 [0.20377643 0.21288894 0.21753876 0.23713456 0.22237791 0.23078027
  0.22673222 0.21814503 0.

[[0.5661403  0.53132457 0.52033854 0.535601   0.5143013  0.49404106
  0.543797   0.54430914 0.53175974 0.5023456  0.53045183 0.5318737
  0.52015173 0.5185964  0.51726717 0.5132133  0.5181257  0.52147615
  0.5353756  0.51499176 0.53690803 0.5369457  0.5279419  0.52692807
  0.5497961  0.52008057 0.5028001  0.54290843 0.51924634 0.5113325
  0.5116751  0.51735675 0.51678085 0.5207578  0.5271826  0.49486443
  0.5184455  0.5488732  0.54503655 0.5394151 ]
 [0.5218215  0.54960203 0.55778563 0.527377   0.54493034 0.53010833
  0.5581174  0.52028394 0.5318129  0.540661   0.5655278  0.5213053
  0.5478542  0.56883967 0.60119325 0.5801784  0.5757824  0.5424763
  0.53063023 0.55210674 0.5041187  0.5420054  0.5393088  0.55255455
  0.54407203 0.5651257  0.57817936 0.5368575  0.5882269  0.56888646
  0.5676597  0.5392226  0.5825399  0.5664829  0.5726303  0.54872763
  0.5556841  0.53218985 0.53924024 0.5294741 ]
 [0.5397942  0.5478771  0.5680159  0.5861608  0.548779   0.5759006
  0.54451525 0.55498505 0.5

[[0.5987724  0.56514806 0.55387795 0.5727761  0.55347025 0.5350853
  0.56879485 0.58508086 0.57439864 0.54515016 0.5729501  0.56466913
  0.55495095 0.56100583 0.5682025  0.55982995 0.5672352  0.5722114
  0.5762081  0.5585511  0.5658238  0.5777831  0.5593958  0.56639206
  0.59008193 0.5600636  0.5517881  0.568149   0.5608719  0.553712
  0.55531144 0.55445796 0.5534879  0.567959   0.5652427  0.54368067
  0.5617682  0.59561574 0.5856981  0.5736246 ]
 [0.56935656 0.5895728  0.59567016 0.56854725 0.5830738  0.573037
  0.58950263 0.555884   0.57249093 0.56258714 0.6058749  0.5640092
  0.5850334  0.60805535 0.6224835  0.61185324 0.61337984 0.57409835
  0.56709826 0.5823262  0.54981613 0.57098734 0.5750408  0.5913638
  0.5839735  0.60282356 0.6085088  0.56263494 0.6222979  0.6020255
  0.59934485 0.5786289  0.6140591  0.59836185 0.5975481  0.58137816
  0.58630204 0.5761547  0.58135355 0.5750948 ]
 [0.57274044 0.5946244  0.5914676  0.61617506 0.58232296 0.60437584
  0.5891924  0.5884801  0.60625

4631 0.5247058644014246 0.0005823361316288356
5054 0.5247058644014246 0.0004735983930004295
[[1.0407147  0.96679044 0.96732706 0.9431033  0.96125925 0.9433973
  0.9836408  0.9586685  0.95613617 0.96075165 0.95550424 0.9604473
  0.9731346  0.9552313  0.94627416 0.922901   0.94570374 1.0063658
  0.9938721  0.96394134 0.97213507 0.9509333  0.96608007 0.94467205
  1.0043898  0.9725112  0.9347869  0.9760246  0.95400405 0.9304049
  0.9493973  1.0110513  0.9700142  0.9407582  0.94031256 0.9432101
  0.96741575 1.0190498  1.0180671  1.026077  ]
 [1.0053494  0.98191935 0.954188   0.95028174 0.95567906 0.9694499
  1.0432111  0.93529785 0.9712378  0.94784975 1.0590434  0.9560071
  0.9549024  1.0388916  1.0697448  1.0738646  1.0372281  1.018541
  0.9927042  0.994447   0.9603274  0.9785975  0.9632872  0.9737245
  0.9619049  1.070768   1.0527964  0.9683876  1.0738568  1.0609611
  1.0514264  0.99159676 1.0509814  1.0512648  1.0617334  1.0555162
  1.0689416  0.9975453  1.0121337  0.9883218 ]
 [1.002920

6050 0.5670588325051701 0.0003029861120012356
[[1.1080836 1.0829036 1.0812021 1.0457423 1.071649  1.0623246 1.0695705
  1.0565909 1.0715461 1.0701922 1.0498084 1.0754789 1.0726867 1.0503252
  1.0381728 1.0220793 1.0888143 1.1237658 1.1303331 1.1331644 1.1115752
  1.0831094 1.0857024 1.053863  1.1034789 1.0695113 1.045423  1.0873048
  1.0441555 1.0459177 1.0226388 1.1136835 1.0461102 1.0550181 1.0487258
  1.0390607 1.0786268 1.1104089 1.1432568 1.1394411]
 [1.1497235 1.0912713 1.0583802 1.0510755 1.0547161 1.0616131 1.1588322
  1.0454576 1.0677917 1.0424453 1.1851661 1.0639331 1.0696179 1.183824
  1.1887972 1.192212  1.1753262 1.141694  1.1182932 1.1134851 1.0997148
  1.0958283 1.0627488 1.0507945 1.0559709 1.1955833 1.1893579 1.0707606
  1.2019438 1.1891953 1.1759312 1.0861884 1.1750318 1.1734835 1.2052827
  1.1762856 1.1680535 1.1168725 1.1201701 1.1290164]
 [1.1073861 1.1834561 1.1885849 1.1784768 1.1992006 1.183935  1.0699155
  1.2014942 1.1835495 1.1968156 1.051024  1.1898494 1.165

11148 0.5858823439654182 0.00010155298559766379
11182 0.5858823439654182 6.505770763851615e-05
11264 0.5858823439654182 0.0001579546152470357
11315 0.5858823439654182 9.578949945989734e-05
11369 0.5858823439654182 6.759993368632422e-05
11395 0.5858823439654182 8.375181755582162e-05
11460 0.5858823439654182 0.00010411785017367947
11486 0.5858823439654182 0.00011475498390609574
11535 0.5858823439654182 0.00011971223809723597
11588 0.5858823439654182 0.00011391811358407722
11611 0.5858823439654182 0.00011041017389743502
11635 0.5858823439654182 6.108835989380169e-05
11677 0.5858823439654182 8.988256270185957e-05
11699 0.5858823439654182 0.00010869311931855918
11716 0.5858823439654182 6.316390142728779e-05
11731 0.5858823439654182 0.000152768465909503
11750 0.5858823439654182 6.939155637155636e-05
11811 0.5858823439654182 0.0001381176869472256
11847 0.5858823439654182 0.00010239277653454338
11902 0.5858823439654182 6.251134821013692e-05
11936 0.5858823439654182 0.0001316716526162054
11962 

17251 0.5858823439654182 3.595099561835013e-05
17267 0.5858823439654182 3.638822539642206e-05
17292 0.5858823439654182 3.734490910574095e-05
17358 0.5858823439654182 3.683835010406256e-05
17409 0.5858823439654182 3.623251552085094e-05
17423 0.5858823439654182 3.753838071816062e-05
17456 0.5858823439654182 3.642708870756906e-05
17472 0.5858823439654182 3.7300646113180846e-05
17481 0.5858823439654182 3.7797445679643714e-05
17505 0.5858823439654182 3.866619728114529e-05
17515 0.5858823439654182 3.651593511904139e-05
17543 0.5858823439654182 3.7453607239967825e-05
17574 0.5858823439654182 3.604703438782053e-05
17592 0.5858823439654182 3.635208403238721e-05
17606 0.5858823439654182 3.779268427592797e-05
17631 0.5858823439654182 3.6988410460025986e-05
17651 0.5858823439654182 3.596211413423589e-05
17659 0.5858823439654182 3.5873213221293554e-05
17668 0.5858823439654182 3.5829850116897435e-05
17735 0.5858823439654182 3.796458341298603e-05
17781 0.5858823439654182 3.716292979447644e-05
17813 0

('GCCCCCGCCCGCCGGGGGUCACCCCGGGGGGCGGGGGCGA', 0.5858823439654182)

In [95]:
batch_size = 1000
generations = 20
agent = RL_agent_DQN(wt, alphabet=RAA, gamma=0.8, memory_size=10000, device=device)
agent.run_RL(generations=generations, train_epochs=40, epsilon_min=0.2)

Q_Network(
  (linear1): Linear(in_features=320, out_features=160, bias=True)
  (bn1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=160, out_features=40, bias=True)
  (bn2): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=40, out_features=1, bias=True)
)
[[0.04411029 0.04457659 0.04416889 0.04366009 0.04530866 0.04545609
  0.04315792 0.04468696 0.0454899  0.04390033 0.0455474  0.04321003
  0.04528911 0.04315555 0.04648664 0.04387166 0.04403219 0.04273931
  0.04135934 0.046217   0.04358989 0.04244828 0.04555681 0.04594287
  0.04445239 0.04573286 0.04504759 0.04524492 0.04446078 0.04526094
  0.04536694 0.04401646 0.0452577  0.0443987  0.04472012 0.04260904
  0.04370743 0.04369771 0.04347298 0.04435972]
 [0.04532193 0.04438268 0.04338095 0.04432257 0.04468241 0.04289825
  0.04318272 0.04557195 0.04361975 0.04393742 0.04373531 0.04686577
  0.04358487 0.04407912

[[0.04648048 0.04720763 0.04772173 0.04684468 0.04927927 0.05061628
  0.04810569 0.0481355  0.04842921 0.04751866 0.04934393 0.04811952
  0.04920854 0.04711197 0.05044554 0.04876403 0.04751249 0.04563624
  0.04748053 0.05019916 0.04636972 0.04844607 0.04972139 0.04899543
  0.04770904 0.04786824 0.0493934  0.04749478 0.04884774 0.04923942
  0.05035321 0.04849773 0.04742054 0.04752643 0.04790811 0.04617831
  0.04807077 0.04609822 0.04839518 0.04782272]
 [0.04763775 0.04699089 0.04842916 0.04723283 0.0466551  0.04627895
  0.0483825  0.04841419 0.04831254 0.04837641 0.04935526 0.04824306
  0.04628921 0.0482379  0.0456062  0.04721131 0.04853883 0.04857325
  0.04621228 0.04787657 0.0486586  0.04892348 0.04649453 0.04729834
  0.04912991 0.04849939 0.05041817 0.04832759 0.04685383 0.04648029
  0.0478062  0.04835465 0.04654623 0.05022266 0.04756201 0.04921216
  0.04868393 0.04819804 0.04672492 0.04805734]
 [0.04846421 0.04688146 0.04800907 0.04823876 0.04497555 0.04792868
  0.04727015 0.0475725

908 0.21176470588235294 0.0015866621819441208
1624 0.21176470588235294 0.0007282373582711444
[[1.0193946  1.0319134  0.99092466 0.9882752  1.0197338  1.0011703
  0.95465255 0.96964157 1.0392693  0.95119643 0.9791161  0.9591
  0.98679024 1.0117412  0.9977269  0.96276814 1.0110565  0.9826301
  0.9999847  0.98857415 1.0033519  0.9946439  1.0071219  0.98714024
  1.0221919  0.9803995  1.010869   0.97954553 0.99566096 0.9721862
  1.0251998  0.98261595 0.9738212  1.0363551  0.9980835  1.0019999
  0.9701896  0.996087   1.0151594  0.9619797 ]
 [1.0153959  1.0228354  1.0187258  1.0190483  0.9855235  1.0477217
  1.0216067  1.0467199  1.0160522  1.0648441  1.0261118  1.0300982
  1.0108693  1.0093075  1.0427248  1.0472065  1.020764   0.99368656
  1.0224183  1.0583118  1.0226227  1.0487174  0.99799746 1.003344
  0.9978563  1.0116524  1.0097544  1.0065866  1.0057918  1.0284126
  1.0382074  0.9935026  0.98298866 1.0050981  1.0074693  1.0671617
  1.0134357  1.0525773  0.98474187 1.0049448 ]
 [1.0218712

2244 0.28823529411764703 0.0006979635829338804
[[1.271355  1.2748914 1.264628  1.2093782 1.2248813 1.2237451 1.2164321
  1.2255183 1.2609751 1.1875999 1.2147217 1.187618  1.2229815 1.2453408
  1.232827  1.1992688 1.246754  1.2472456 1.2661164 1.2429076 1.2765098
  1.19948   1.2331612 1.2319002 1.2719724 1.2345438 1.2321697 1.215628
  1.227917  1.2040656 1.2905817 1.2346215 1.204646  1.2718079 1.2283664
  1.2182038 1.2132305 1.241289  1.2248815 1.2233164]
 [1.2768586 1.2440042 1.2969465 1.2912772 1.2618647 1.2815087 1.2754388
  1.276881  1.2419841 1.3064058 1.3076407 1.2792596 1.2857022 1.2983937
  1.2992189 1.2919725 1.2793148 1.2661586 1.2813958 1.301949  1.2761397
  1.3279164 1.2423224 1.2318559 1.2284746 1.2349226 1.2813642 1.2355568
  1.23638   1.3018169 1.2937897 1.2529984 1.2326813 1.248377  1.2672601
  1.318892  1.2819041 1.2778633 1.2395328 1.2573022]
 [1.2523121 1.2477381 1.2092211 1.274952  1.2698019 1.2653353 1.2841979
  1.271212  1.2808381 1.2527517 1.2344341 1.2719188 1.21

3948 0.40588235294117647 0.000825471403368283
[[1.8731582 1.805605  1.8082666 1.7927619 1.7914026 1.7770673 1.833289
  1.8054463 1.8284172 1.7869904 1.7992309 1.7517562 1.7679852 1.8025153
  1.8008605 1.775695  1.8342391 1.8426776 1.8749952 1.8123076 1.836623
  1.7890722 1.7739896 1.774298  1.8343138 1.8176191 1.8115792 1.771332
  1.7670625 1.8011793 1.8153423 1.863714  1.8030113 1.833131  1.7710052
  1.814868  1.8289583 1.8513006 1.8238465 1.8272691]
 [1.8027463 1.8411882 1.9025213 1.8741571 1.8661109 1.846719  1.8537792
  1.8653568 1.8637519 1.8924593 1.9184741 1.8735865 1.9132276 1.8728101
  1.8767633 1.8973972 1.8553818 1.8657489 1.908236  1.8738776 1.8883419
  1.8932878 1.8277977 1.7920916 1.8252373 1.784512  1.8414905 1.7845381
  1.8228046 1.8656849 1.8835943 1.8527858 1.7398549 1.803012  1.8573357
  1.8910729 1.8703567 1.8592588 1.8405081 1.8027574]
 [1.8713726 1.9124657 1.786961  1.8415369 1.8768325 1.8628201 1.8562071
  1.8381344 1.8435532 1.7819312 1.7670034 1.8307328 1.84042

6086 0.49882354736328127 0.00045710437771049326
6334 0.49882354736328127 0.0005452376026369166
[[2.1796587 2.1635408 2.1228194 2.09585   2.205381  2.1107085 2.142476
  2.1693485 2.1375196 2.0976224 2.1096811 2.0966096 2.0975618 2.104663
  2.0345986 2.1465232 2.1320882 2.161766  2.1921391 2.1976435 2.1709528
  2.1616063 2.1250353 2.1131964 2.1220808 2.0726728 2.112135  2.0764935
  2.1284032 2.1823452 2.181083  2.1345286 2.10438   2.1643891 2.1456041
  2.148061  2.1868615 2.2126632 2.185653  2.198542 ]
 [2.1911807 2.1557808 2.2384906 2.2722762 2.205863  2.1087503 2.1761246
  2.231794  2.2373388 2.2534425 2.2420492 2.2218475 2.2391698 2.2257524
  2.2816586 2.2044516 2.1585546 2.2091846 2.1964147 2.1916404 2.214494
  2.208638  2.1100411 2.13096   2.0946379 2.0409667 2.11568   2.0787234
  2.1115894 2.1715512 2.224743  2.2457128 2.1182976 2.1657605 2.2106164
  2.234882  2.1902957 2.1671572 2.1908088 2.1902566]
 [2.2262614 2.2416415 2.1462142 2.106243  2.1813185 2.2248926 2.231319
  2.1674266

8942 0.5599999820484834 0.0004546081800071988
9153 0.5599999820484834 0.0005142606030858587
9335 0.5599999820484834 0.0004283723043045029
9482 0.5599999820484834 0.00039947588666109366
9634 0.5599999820484834 0.000448013239656575
9819 0.5599999820484834 0.0004840663052164018
9977 0.5599999820484834 0.0004112386486667674
10165 0.5599999820484834 0.0004216093941067811
10350 0.5599999820484834 0.000420387131816824
10520 0.5599999820484834 0.00037537124881055207
10739 0.5599999820484834 0.000419453038557549
10904 0.5599999820484834 0.00038931094713916535
11092 0.5599999820484834 0.00031686831898696254
11277 0.5599999820484834 0.00031644196460547393
11468 0.5599999820484834 0.0004111582871701103
11649 0.5599999820484834 0.0003745518970390549
11883 0.5599999820484834 0.00044645986054092646
12057 0.5599999820484834 0.0004036308586364612
12218 0.5599999820484834 0.0003904027082171524
12408 0.5599999820484834 0.0004966883909219178
12638 0.5599999820484834 0.0003284110560343834
12840 0.559999982

('GCGGGCCGGGGGGGGCCGUAGGCCCCCCCCGGCCCGGAAU', 0.5599999820484834)

In [98]:
batch_size = 1000
generations = 20
agent = RL_agent_DQN(wt, alphabet=RAA, gamma=0.9, memory_size=10000, device=device)
agent.run_RL(generations=generations, train_epochs=40, epsilon_min=0.2)

Q_Network(
  (linear1): Linear(in_features=320, out_features=160, bias=True)
  (bn1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear2): Linear(in_features=160, out_features=40, bias=True)
  (bn2): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear3): Linear(in_features=40, out_features=1, bias=True)
)
[[0.09011858 0.08911628 0.08818015 0.08731744 0.0885158  0.08561381
  0.08771721 0.08607282 0.08683001 0.08800399 0.08710308 0.08721759
  0.08684646 0.08627951 0.08667099 0.08822922 0.08850691 0.08800587
  0.08866739 0.08786146 0.08764639 0.08604883 0.08863344 0.08868428
  0.08567745 0.08813688 0.08755046 0.08790256 0.0863748  0.08782889
  0.08939397 0.08685362 0.08772202 0.08933338 0.08729862 0.08806799
  0.08684321 0.08864309 0.08756786 0.08796078]
 [0.08964147 0.08862317 0.08523662 0.08785106 0.08531138 0.08866674
  0.08889156 0.08956146 0.08763635 0.08634494 0.08878747 0.08699483
  0.08573465 0.08840132

[[0.09127852 0.09171647 0.09180443 0.09257285 0.09236062 0.09478053
  0.09391802 0.0902044  0.09183078 0.09324512 0.09142454 0.08911175
  0.09244213 0.09114654 0.09468263 0.09125858 0.0908707  0.09521677
  0.09480374 0.09428667 0.08941168 0.08947717 0.09271494 0.0923089
  0.09311461 0.09437818 0.09547638 0.0897334  0.09203114 0.09283774
  0.09646374 0.09277332 0.09468607 0.09225109 0.09504288 0.09337446
  0.09201381 0.09111067 0.09635884 0.09224555]
 [0.09706237 0.09128821 0.09392506 0.09363927 0.08893854 0.09247713
  0.09520275 0.09027994 0.09148595 0.08929574 0.09162077 0.09464076
  0.08923014 0.09242853 0.09213421 0.09444877 0.09330337 0.09038904
  0.09266429 0.09167816 0.09535472 0.08933228 0.09386561 0.09318266
  0.09135047 0.09345129 0.0900932  0.09227142 0.09303822 0.09228127
  0.09066796 0.09248479 0.09230172 0.09463345 0.09175315 0.09185679
  0.09224024 0.09397871 0.09247837 0.09545069]
 [0.09159286 0.09166729 0.09433231 0.09151268 0.0879123  0.09462917
  0.09346266 0.09341143

929 0.1964705972110524 0.0015984674828359857
1647 0.1964705972110524 0.000920925107493531
[[1.758633  1.7110314 1.7668954 1.749196  1.6938018 1.7124294 1.7621971
  1.7142819 1.711146  1.735474  1.7396709 1.738087  1.7625574 1.7654563
  1.7240481 1.7531701 1.7673931 1.7024401 1.7113091 1.734536  1.7437189
  1.724228  1.7361511 1.7381378 1.7249063 1.7251521 1.7390541 1.7409037
  1.7125713 1.7191104 1.7477441 1.7399508 1.7662336 1.7357959 1.7504239
  1.7485346 1.733569  1.7115486 1.7607714 1.7648691]
 [1.7770107 1.7857257 1.6723121 1.778286  1.7767985 1.7630569 1.745224
  1.7383785 1.7737066 1.773421  1.7128575 1.8173083 1.7560586 1.7791046
  1.7517678 1.768656  1.7699114 1.7589276 1.804192  1.7320849 1.7654823
  1.7837654 1.7573953 1.7202576 1.750765  1.7530578 1.7577295 1.7795948
  1.775751  1.752872  1.7772585 1.7206813 1.7413937 1.7575349 1.7820914
  1.7937887 1.7738672 1.7585284 1.7380892 1.7341024]
 [1.6971325 1.7223636 1.7495857 1.7348553 1.7700864 1.7752302 1.7261206
  1.7649914 1

2891 0.2529411764705882 0.0022007855906849725
3500 0.2529411764705882 0.0013616791227832436
4051 0.2529411764705882 0.0011507421382702886
4608 0.2529411764705882 0.0011461411602795124
[[2.335156  2.353938  2.3750167 2.4014387 2.3128629 2.3569064 2.3533196
  2.3537283 2.3172536 2.2992697 2.30905   2.3808289 2.3692265 2.3007545
  2.3658772 2.3083644 2.3432655 2.3306007 2.272254  2.3548694 2.3280134
  2.3293562 2.3365889 2.3526077 2.3109946 2.3325067 2.339108  2.3752184
  2.352827  2.3652549 2.333394  2.315144  2.31668   2.368719  2.3418226
  2.36054   2.2893538 2.337037  2.362131  2.375187 ]
 [2.4085255 2.379086  2.3053355 2.3723044 2.3968248 2.3383164 2.3988338
  2.3636718 2.3887658 2.3871584 2.373632  2.3933463 2.370028  2.3964262
  2.3708448 2.4048114 2.3845701 2.3896399 2.3994694 2.3493881 2.3620315
  2.383707  2.3727407 2.3616505 2.3586736 2.3778863 2.3697376 2.3375163
  2.3800902 2.3645778 2.4097033 2.3988824 2.3065329 2.3755274 2.3751988
  2.3884873 2.388136  2.3749223 2.391873  2

5046 0.34705882352941175 0.0010560053706285544
5529 0.34705882352941175 0.000993012747494504
5986 0.34705882352941175 0.0008976863959105685
[[2.703847  2.6744947 2.668744  2.7051098 2.6340868 2.6453035 2.682036
  2.7140799 2.6454544 2.6782317 2.6057823 2.6688309 2.70938   2.6720352
  2.6850305 2.6332526 2.7014465 2.6888108 2.6523447 2.6810513 2.7039728
  2.7068682 2.639412  2.678547  2.6811228 2.692008  2.664733  2.6978428
  2.6930106 2.7128468 2.687522  2.698657  2.635005  2.664174  2.7055233
  2.6931548 2.595463  2.628058  2.7144961 2.7188046]
 [2.722836  2.7342854 2.7241974 2.7279525 2.7293553 2.738585  2.7272096
  2.6979332 2.7656684 2.7389598 2.7212934 2.7271428 2.7091198 2.7256582
  2.757554  2.7481165 2.7239127 2.734407  2.7263525 2.7261782 2.703783
  2.7126021 2.7332652 2.698431  2.717904  2.699744  2.7278051 2.675199
  2.7500606 2.704671  2.720467  2.728132  2.6897805 2.7055128 2.697626
  2.727298  2.7281442 2.706737  2.7341628 2.6765032]
 [2.7065606 2.7111976 2.7319608 2.7021

7177 0.45999998204848347 0.0009265978864277713
7469 0.45999998204848347 0.0008447989821434021
7795 0.45999998204848347 0.0008205781457945704
[[3.6517978 3.5682034 3.5702147 3.6802776 3.581441  3.5296152 3.6026826
  3.6597314 3.5880132 3.5619216 3.6043265 3.6768315 3.6631877 3.538494
  3.5329041 3.5555246 3.6312103 3.6346383 3.5591054 3.5842457 3.627867
  3.677808  3.5524223 3.5726757 3.5932302 3.584341  3.5809498 3.6294193
  3.6257033 3.5827816 3.6549153 3.5862527 3.6181002 3.608848  3.6551712
  3.6242056 3.5501912 3.5758862 3.640982  3.6271446]
 [3.6828828 3.6841657 3.6475773 3.662856  3.70517   3.7049038 3.6608047
  3.6413386 3.6856782 3.6954973 3.651517  3.6920338 3.6546392 3.7064533
  3.697166  3.687614  3.6832979 3.6769655 3.6948824 3.657124  3.6509688
  3.6780245 3.691679  3.5875034 3.621473  3.591554  3.6618786 3.603088
  3.6976895 3.6113772 3.6410794 3.6818018 3.6071694 3.6401188 3.6790094
  3.7105618 3.6369033 3.6558776 3.6642592 3.5904512]
 [3.6715689 3.6324825 3.6983027 3.68

11061 0.5176470588235295 0.0009936598333297297
11316 0.5176470588235295 0.0008205029153032228
11568 0.5176470588235295 0.000861290299508255
11869 0.5176470588235295 0.000786099533434026
12134 0.5176470588235295 0.000805779336951673
12358 0.5176470588235295 0.0007951204621349462
12569 0.5176470588235295 0.0007211751653812826
12810 0.5176470588235295 0.0008493719884427265
13021 0.5176470588235295 0.0008875985120539553
13215 0.5176470588235295 0.0009347571205580607
13418 0.5176470588235295 0.0007286922089406289
13622 0.5176470588235295 0.0008589908575231675
13835 0.5176470588235295 0.0008046297385590151
14093 0.5176470588235295 0.0008892029174603522
14334 0.5176470588235295 0.0008027724863495678
14584 0.5176470588235295 0.0007382695461274124
14813 0.5176470588235295 0.0008513906097505242
15057 0.5176470588235295 0.0007962061710713897
15233 0.5176470588235295 0.0008322358997247647
15463 0.5176470588235295 0.0007232286741782445
15705 0.5176470588235295 0.0007334349218581337
15971 0.51764705

('GGCCGGCCGGGGGGGGCCCGAGGCCCCCGCGGCCGGCCGA', 0.5176470588235295)