In [31]:
%matplotlib inline
import os 
from collections import deque
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import editdistance
import sys
import RNA
from typing import Dict, List, Tuple

import torch
from torch import nn
from tqdm import tqdm_notebook as tqdm
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

# import path 
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.sequence_utils import translate_one_hot_to_string,generate_random_mutant
from models.Theoretical_models import *
from models.Noise_wrapper import *
from exploration_strategies.CE import *
from utils.landscape_utils import *
from models.RNA_landscapes import *
from models.Multi_dimensional_model import *

from segment_tree import MinSegmentTree, SumSegmentTree

In [32]:
# --- Experiment setup -------------------------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
RAA="UGCA" # RNA alphabet used for every sequence in this notebook
length=40
# Take the first sequence returned by the project helper as the starting ("wild type") sequence.
# NOTE(review): no random seed is set, so `wt` presumably differs between runs - confirm.
wt=generate_random_sequences(length,1,alphabet=RAA)[0]
print(wt)
#make a simple folding landscape starting at wt
landscape1=RNA_landscape(wt)
noise_alpha=1
# NOTE(review): `virtual_per_measure_ratio` and `temperature` are not used anywhere in the visible cells.
virtual_per_measure_ratio=15
temperature=0.1
# there are multiple abstract "noise models" you can use, or you can try to train your own model, using skM
# Three wrappers around the same base landscape, one per exploration strategy being compared.
noisy_landscape_CE=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
noisy_landscape_RL=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
noisy_landscape_RL_multiple=Noise_wrapper(landscape1,noise_alpha=noise_alpha)
#noisy_landscape=Gaussian_noise_landscape(base_landscape,noise_alpha=0.15)
#noisy_landscape=DF_noise_landscape(base_landscape,noise_alpha=0.5)
batch_size = 1000
# wt plus `batch_size` mutants of wt, de-duplicated through a set and truncated to `batch_size`.
# Because of the set, the resulting order is arbitrary and the list may be shorter than `batch_size`.
# NOTE(review): 0.05 is presumably a per-position mutation rate - confirm against generate_random_mutant.
initial_genotypes=list(set([wt]+[generate_random_mutant(wt,0.05,RAA) for i in range(batch_size)]))[:batch_size]

GCCGCGGUACAAUGAUUUCGGAGUGUGGCGCGUAACGCCC


## Rainbow DQN RL

In [33]:
import numpy as np
import random
import bisect
from utils.sequence_utils import translate_string_to_one_hot, translate_one_hot_to_string

def renormalize_moves(one_hot_input, rewards_output):
    """Zero the reward of every move that leaves the sequence unchanged.

    Args:
        one_hot_input: one-hot array of the current sequence (1 at occupied entries).
        rewards_output: array of per-move rewards, broadcastable with ``one_hot_input``.

    Returns:
        ``rewards_output`` with the entries occupied in ``one_hot_input`` multiplied by zero.
    """
    # 1 everywhere except at the entries the current sequence already occupies
    not_current_mask = -1 * (one_hot_input - 1)
    return np.multiply(rewards_output, not_current_mask)

def walk_away_renormalize_moves(one_hot_input, one_hot_wt, rewards_output):
    """Zero the reward of moves that stay in place or that land on a wild-type entry.

    Args:
        one_hot_input: one-hot array of the current sequence.
        one_hot_wt: one-hot array of the wild-type sequence (same shape).
        rewards_output: array of per-move rewards.

    Returns:
        ``rewards_output`` masked so that only entries occupied by neither the
        current sequence nor the wild type keep their reward.
    """
    away_from_current = -1 * (one_hot_input - 1)
    away_from_wt = -1 * (one_hot_wt - 1)
    # a move survives only if it is allowed by both masks
    allowed_moves = np.multiply(away_from_wt, away_from_current)
    return np.multiply(rewards_output, allowed_moves)

def get_all_singles_fitness(model,sequence,alphabet):
    """Query ``model`` for the fitness of every single-position substitution.

    Args:
        model: object exposing ``get_fitness(sequence)``.
        sequence: the base sequence (sliceable and concatenable with alphabet items).
        alphabet: ordered collection of letters to substitute.

    Returns:
        Array of shape ``(len(alphabet), len(sequence))`` whose ``[j][i]`` entry is
        the fitness of ``sequence`` with position ``i`` replaced by ``alphabet[j]``.
        Substituting the letter already present re-queries the unchanged sequence.
    """
    fitness_grid = np.zeros((len(alphabet), len(sequence)))
    for position in range(len(sequence)):
        prefix = sequence[:position]
        suffix = sequence[position + 1:]
        # queries are issued position by position, letters in alphabet order
        for letter_index, letter in enumerate(alphabet):
            fitness_grid[letter_index][position] = model.get_fitness(prefix + letter + suffix)
    return fitness_grid

def get_all_mutants(sequence):
    """Enumerate every single-column rewrite of a 2-D one-hot array.

    For each (row, column) pair, in row-major order, a copy of ``sequence`` is
    made in which the whole column is cleared and only ``[row, column]`` is set
    to 1. Pairs that reproduce the original column are included as well.

    Args:
        sequence: 2-D numpy array (rows x columns).

    Returns:
        Array of shape ``(rows * columns, rows, columns)`` with the copies.
    """
    n_rows, n_cols = sequence.shape[0], sequence.shape[1]
    variants = []
    for flat_index in range(n_rows * n_cols):
        row, col = divmod(flat_index, n_cols)
        variant = sequence.copy()
        variant[:, col] = 0
        variant[row, col] = 1
        variants.append(variant)
    return np.array(variants)

def sample_greedy(matrix):
    """Keep only the single largest entry of a 2-D array.

    Args:
        matrix: 2-D numpy array of scores.

    Returns:
        A float array of the same shape that is zero everywhere except at the
        position of the (first) maximum, where it holds the maximum value.
    """
    n_rows, n_cols = matrix.shape
    # argmax indexes the row-major flattened array; split it back into (row, col)
    row, col = divmod(int(np.argmax(matrix)), n_cols)
    picked = np.zeros((n_rows, n_cols))
    picked[row][col] = matrix[row][col]
    return picked

def sample_multi_greedy(matrix, n=5):
    """Keep the ``n`` largest entries of a 2-D array and zero out the rest.

    Args:
        matrix: 2-D numpy array of scores.
        n: number of entries to keep (defaults to 5, the previously hard-coded
            value). It is clamped to the number of entries in ``matrix`` so small
            matrices no longer make ``np.argpartition`` raise.

    Returns:
        A float array of the same shape holding the selected entries at their
        original positions and 0 elsewhere. Ties are broken arbitrarily by
        ``np.argpartition``.
    """
    n_rows, n_cols = matrix.shape
    output = np.zeros((n_rows, n_cols))

    n_keep = min(int(n), matrix.size)
    if n_keep <= 0:
        # nothing to select (also avoids the `[-0:]` slice returning everything)
        return output

    # indices (into the row-major flattened matrix) of the n_keep largest values
    top_flat_indices = np.argpartition(matrix.flatten(), -n_keep)[-n_keep:]
    for flat_index in top_flat_indices:
        row, col = divmod(int(flat_index), n_cols)
        output[row][col] = matrix[row][col]
    return output

def sample_random(matrix):
    """Pick one random move as a one-hot array.

    A position is drawn uniformly from the non-zero entries of ``matrix``; when
    there are none, a position is drawn uniformly from the whole grid.

    Args:
        matrix: 2-D numpy array of move scores.

    Returns:
        A float array of the same shape with a single 1 at the chosen position.
    """
    n_rows, n_cols = matrix.shape
    nonzero_axes = np.nonzero(matrix)
    n_axes = len(nonzero_axes)
    n_hits = len(nonzero_axes[0])

    if n_axes == 0 or n_hits == 0:
        # no candidate moves: fall back to any cell of the grid
        chosen = [random.randint(0, n_rows - 1), random.randint(0, n_cols - 1)]
    else:
        candidates = [
            [axis_indices[hit] for axis_indices in nonzero_axes]
            for hit in range(n_hits)
        ]
        chosen = random.choice(candidates)

    row, col = chosen[0], chosen[1]
    one_hot_move = np.zeros((n_rows, n_cols))
    one_hot_move[row][col] = 1
    return one_hot_move

def action_to_scalar(matrix):
    """Return the flat (row-major) index of the first non-zero entry.

    Args:
        matrix: numpy array encoding an action.

    Returns:
        The integer index into ``matrix.ravel()`` of the first non-zero value,
        or ``None`` when every entry is zero.
    """
    flat_entries = matrix.ravel()
    return next(
        (index for index in range(len(flat_entries)) if flat_entries[index] != 0),
        None,
    )
    
def construct_mutant_from_sample(pwm_sample, one_hot_base):
    """Apply the moves marked in ``pwm_sample`` to a copy of ``one_hot_base``.

    Every non-zero entry ``(i, j)`` of ``pwm_sample`` clears column ``j`` of the
    copy and sets ``[i, j]`` to 1. Moves are applied in ``np.nonzero`` order, so
    when several entries share a column the last one wins.

    Args:
        pwm_sample: 2-D array whose non-zero entries mark the moves to apply.
        one_hot_base: 2-D one-hot array of the starting sequence (not modified).

    Returns:
        A new float array with the moves applied.
    """
    mutant = np.zeros(one_hot_base.shape)
    mutant += one_hot_base
    move_rows, move_cols = np.nonzero(pwm_sample)[:2]
    # note: entries are detected by being non-zero, so zero-valued scores are never applied
    for row, col in zip(move_rows, move_cols):
        mutant[:, col] = 0
        mutant[row, col] = 1
    return mutant

def best_predicted_new_gen(actor, genotypes, alphabet, pop_size):
    """Return the ``pop_size`` mutants of ``genotypes`` that ``actor`` scores highest.

    Args:
        actor: callable taking a float tensor and returning a tensor of predictions.
        genotypes: input handed to ``get_all_mutants``.
        alphabet: alphabet passed to ``translate_string_to_one_hot``.
        pop_size: number of top-scoring mutants to return.

    Returns:
        The entries of the mutant array selected by the last ``pop_size`` indices
        of the argsorted predictions.

    NOTE(review): ``get_all_mutants`` produces one-hot arrays while
    ``translate_string_to_one_hot`` presumably expects strings - confirm that
    this helper is still in use and that the two agree.
    """
    candidate_mutants = get_all_mutants(genotypes)
    encoded_candidates = np.array(
        [translate_string_to_one_hot(candidate, alphabet) for candidate in candidate_mutants]
    )
    # add a leading batch axis before handing the candidates to the actor
    actor_input = torch.from_numpy(np.expand_dims(encoded_candidates, axis=0)).float()
    scores = actor(actor_input).detach().numpy()
    top_indices = scores.argsort()[-pop_size:]
    return candidate_mutants[top_indices]

def make_one_hot_train_test(genotypes, model, alphabet):
    """Build one-hot inputs and fitness targets for a collection of genotypes.

    Args:
        genotypes: iterable of sequences.
        model: object exposing ``get_fitness(sequence)``.
        alphabet: alphabet passed to ``translate_string_to_one_hot``.

    Returns:
        Tuple ``(genotypes_one_hot, genotype_fitnesses)`` of numpy arrays, both
        ordered like ``genotypes``.
    """
    # encode everything first, then query the model once per genotype
    encoded = np.array(
        [translate_string_to_one_hot(genotype, alphabet) for genotype in genotypes]
    )
    fitnesses = np.array([model.get_fitness(genotype) for genotype in genotypes])

    return encoded, fitnesses

In [75]:
from segment_tree import MinSegmentTree, SumSegmentTree

class ReplayBuffer:
    """A simple circular numpy replay buffer with optional n-step folding.

    Transitions are ``(obs, act, rew, next_obs)`` tuples (there is no ``done``
    flag). With ``n_step > 1`` the last ``n_step`` transitions are folded into a
    single entry: the observation and action of the oldest one, the discounted
    sum of all rewards, and the next observation of the newest one.
    """

    def __init__(
        self, 
        obs_dim, 
        size, 
        batch_size = 128, 
        n_step = 1, 
        gamma = 0.99
    ):
        """Allocate storage.

        Args:
            obs_dim: tuple giving the shape of one observation.
            size: maximum number of stored entries.
            batch_size: number of entries returned by ``sample_batch``.
            n_step: number of consecutive transitions folded into one entry.
            gamma: discount factor used when folding rewards.
        """
        self.obs_buf = np.zeros((size,) + obs_dim, dtype=np.float32)
        self.next_obs_buf = np.zeros((size,) + obs_dim, dtype=np.float32)
        self.acts_buf = np.zeros([size], dtype=np.float32)
        self.rews_buf = np.zeros([size], dtype=np.float32)
        self.max_size, self.batch_size = size, batch_size
        self.ptr, self.size, = 0, 0
        
        # for N-step Learning
        self.n_step_buffer = deque(maxlen=n_step)
        self.n_step = n_step
        self.gamma = gamma

    def store(
        self, 
        obs, 
        act, 
        rew, 
        next_obs
    ):
        """Append a transition; write an entry once ``n_step`` transitions are queued.

        Returns:
            ``False`` while fewer than ``n_step`` transitions have been seen,
            ``True`` once an entry has been written to the buffer.
        """
        transition = (obs, act, rew, next_obs)
        self.n_step_buffer.append(transition)

        # single step transition is not ready
        if len(self.n_step_buffer) < self.n_step:
            return False
        
        # make a n-step transition
        rew, next_obs = self._get_n_step_info(
            self.n_step_buffer, self.gamma
        )
        # the stored entry starts at the OLDEST queued transition, so both the
        # observation and the action must come from it (not from the newest one)
        obs, act = self.n_step_buffer[0][:2]
        
        self.obs_buf[self.ptr] = obs
        self.next_obs_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
        
        return True

    def sample_batch(self):
        """Sample ``batch_size`` stored entries uniformly (with replacement)."""
        idxs = np.random.choice(self.size, size=self.batch_size)

        return dict(
            obs=self.obs_buf[idxs],
            next_obs=self.next_obs_buf[idxs],
            acts=self.acts_buf[idxs],
            rews=self.rews_buf[idxs],
            indices=idxs,
        )
    
    def sample_batch_from_idxs(self, idxs):
        """Return the entries stored at the given indices (used for N-step learning)."""
        return dict(
            obs=self.obs_buf[idxs],
            next_obs=self.next_obs_buf[idxs],
            acts=self.acts_buf[idxs],
            rews=self.rews_buf[idxs]
        )
    
    def _get_n_step_info(
        self, n_step_buffer, gamma):
        """Return the n-step reward and the final next observation.

        The reward is ``r_0 + gamma * r_1 + ... + gamma^(n-1) * r_(n-1)`` over the
        queued transitions; the next observation is that of the newest one.
        """
        # transitions are (obs, act, rew, next_obs): start from the newest one
        rew, next_obs = n_step_buffer[-1][-2:]
        # fold the earlier rewards in, newest to oldest
        for transition in reversed(list(n_step_buffer)[:-1]):
            rew = transition[2] + gamma * rew

        return rew, next_obs

    def __len__(self):
        """Number of entries currently stored."""
        return self.size
    
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples entries in proportion to their priority.

    Attributes:
        max_priority (float): largest priority seen so far; given to new entries
        tree_ptr (int): index in the trees where the next entry's priority goes
        alpha (float): exponent applied to priorities before they are stored
        sum_tree (SumSegmentTree): stores priority ** alpha, used for sampling
        min_tree (MinSegmentTree): stores priority ** alpha, used for the max weight

    """

    def __init__(
        self,
        obs_dim,
        size,
        batch_size = 128,
        alpha = 0.6
    ):
        """Set up the underlying buffer and the two priority trees."""
        assert alpha >= 0

        super().__init__(obs_dim, size, batch_size)
        self.max_priority = 1.0
        self.tree_ptr = 0
        self.alpha = alpha

        # the segment trees need a capacity that is a positive power of two
        capacity = 1
        while capacity < self.max_size:
            capacity = capacity * 2

        self.sum_tree = SumSegmentTree(capacity)
        self.min_tree = MinSegmentTree(capacity)

    def store(self, obs, act, rew, next_obs):
        """Store a transition and give it the current maximum priority."""
        stored = super().store(obs, act, rew, next_obs)

        initial_priority = self.max_priority ** self.alpha
        self.sum_tree[self.tree_ptr] = initial_priority
        self.min_tree[self.tree_ptr] = initial_priority
        self.tree_ptr = (self.tree_ptr + 1) % self.max_size

        return stored

    def sample_batch(self, beta = 0.4):
        """Sample a prioritized batch together with its importance weights."""
        indices = self._sample_proportional()
        importance_weights = [self._calculate_weight(index, beta) for index in indices]

        return dict(
            obs=self.obs_buf[indices],
            next_obs=self.next_obs_buf[indices],
            acts=self.acts_buf[indices],
            rews=self.rews_buf[indices],
            weights=np.array(importance_weights),
            indices=indices,
        )

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities of the given entries and track the maximum."""
        assert len(indices) == len(priorities)

        for entry_index, new_priority in zip(indices, priorities):
            assert new_priority > 0
            assert 0 <= entry_index < len(self)

            scaled_priority = new_priority ** self.alpha
            self.sum_tree[entry_index] = scaled_priority
            self.min_tree[entry_index] = scaled_priority

            self.max_priority = max(self.max_priority, new_priority)

    def _sample_proportional(self):
        """Draw one index from each of ``batch_size`` equal slices of the priority mass."""
        total_priority = self.sum_tree.sum(0, len(self) - 1)
        slice_width = total_priority / self.batch_size

        sampled = []
        for slice_number in range(self.batch_size):
            low = slice_width * slice_number
            high = slice_width * (slice_number + 1)
            # pick a point inside this slice and look up which entry owns it
            point = random.uniform(low, high)
            sampled.append(self.sum_tree.retrieve(point))

        return sampled

    def _calculate_weight(self, idx, beta):
        """Importance-sampling weight of entry ``idx``, normalised by the largest weight."""
        total_priority = self.sum_tree.sum()
        n_entries = len(self)

        # the smallest probability yields the largest weight
        min_probability = self.min_tree.min() / total_priority
        largest_weight = (min_probability * n_entries) ** (-beta)

        probability = self.sum_tree[idx] / total_priority
        raw_weight = (probability * n_entries) ** (-beta)

        return raw_weight / largest_weight

In [106]:
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable noise (NoisyNet).

    The effective parameters are ``mu + sigma * epsilon``, where ``mu`` and
    ``sigma`` are trained and ``epsilon`` is factorised noise that is redrawn by
    ``reset_noise``.

    Attributes:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        std_init (float): value the sigma parameters are initialised from
        weight_mu (nn.Parameter): mean of the weight, shape (out, in)
        weight_sigma (nn.Parameter): noise scale of the weight, shape (out, in)
        bias_mu (nn.Parameter): mean of the bias, shape (out,)
        bias_sigma (nn.Parameter): noise scale of the bias, shape (out,)

    """

    def __init__(self, in_features: int, out_features: int, std_init: float = 0.5):
        """Create the parameters and noise buffers, then initialise both."""
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.Tensor(*weight_shape))
        self.weight_sigma = nn.Parameter(torch.Tensor(*weight_shape))
        # noise is a buffer: saved with the module but not trained
        self.register_buffer("weight_epsilon", torch.Tensor(*weight_shape))

        self.bias_mu = nn.Parameter(torch.Tensor(out_features))
        self.bias_sigma = nn.Parameter(torch.Tensor(out_features))
        self.register_buffer("bias_epsilon", torch.Tensor(out_features))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Re-initialise the trainable means and noise scales."""
        bound = 1 / np.sqrt(self.in_features)
        weight_sigma_value = self.std_init / np.sqrt(self.in_features)
        bias_sigma_value = self.std_init / np.sqrt(self.out_features)

        self.weight_mu.data.uniform_(-bound, bound)
        self.weight_sigma.data.fill_(weight_sigma_value)
        self.bias_mu.data.uniform_(-bound, bound)
        self.bias_sigma.data.fill_(bias_sigma_value)

    def reset_noise(self):
        """Draw fresh factorised noise for the weight and the bias."""
        noise_in = self.scale_noise(self.in_features)
        noise_out = self.scale_noise(self.out_features)

        # weight noise is the outer product of the two vectors; bias reuses noise_out
        self.weight_epsilon.copy_(noise_out.ger(noise_in))
        self.bias_epsilon.copy_(noise_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the noisy affine map to ``x``.

        The noise is applied regardless of train / eval mode.
        """
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(x, noisy_weight, noisy_bias)

    @staticmethod
    def scale_noise(size: int) -> torch.Tensor:
        """Return ``sign(x) * sqrt(|x|)`` for ``size`` standard-normal draws ``x``."""
        draws = torch.FloatTensor(np.random.normal(loc=0.0, scale=1.0, size=size))

        return torch.sign(draws) * torch.sqrt(torch.abs(draws))
    
class Network(nn.Module):
    """Dueling, distributional (categorical) Q-network with noisy layers.

    The input is flattened to ``sequence_len * alphabet_len`` features and there
    is one action per input feature. For every action the network outputs a
    probability distribution over ``atom_size`` return atoms located at
    ``support``.
    """

    def __init__(
        self, 
        sequence_len, 
        alphabet_len,
        atom_size, 
        support
    ):
        """Initialization.

        Args:
            sequence_len: length of the sequences being mutated.
            alphabet_len: number of letters in the alphabet.
            atom_size: number of atoms of the categorical return distribution.
            support: 1-D tensor with the return value of each atom. It is kept
                as a plain attribute, so it must already live on the device the
                network is moved to.
        """
        super(Network, self).__init__()
        
        self.support = support
        # both the flattened input size and the number of actions
        self.dim = sequence_len*alphabet_len
        self.atom_size = atom_size

        # set common feature layer
        self.feature_layer = nn.Sequential(
            nn.Linear(self.dim, self.dim), 
            nn.ReLU(),
        )
        
        # set advantage layer
        self.advantage_hidden_layer = NoisyLinear(self.dim, self.dim)
        self.advantage_layer = NoisyLinear(self.dim, self.dim*atom_size)

        # set value layer
        self.value_hidden_layer = NoisyLinear(self.dim, self.dim)
        self.value_layer = NoisyLinear(self.dim, atom_size)

    def forward(self, x):
        """Return the expected Q-value of every action, shape ``(batch, dim)``."""
        x = x.view(-1, self.dim)
        dist = self.dist(x)
        # expectation of each action's distribution over the support
        q = torch.sum(dist * self.support, dim=2)
        
        return q
    
    def dist(self, x):
        """Return the per-action atom probabilities, shape ``(batch, dim, atom_size)``."""
        x = x.view(-1, self.dim)
        feature = self.feature_layer(x)
        adv_hid = F.relu(self.advantage_hidden_layer(feature))
        val_hid = F.relu(self.value_hidden_layer(feature))
        
        advantage = self.advantage_layer(adv_hid).view(
            -1, self.dim, self.atom_size
        )
        value = self.value_layer(val_hid).view(-1, 1, self.atom_size)
        # dueling combination: value plus mean-centred advantage, per atom
        q_atoms = value + advantage - advantage.mean(dim=1, keepdim=True)
        
        dist = F.softmax(q_atoms, dim=-1)
        # keep every probability strictly positive: callers take log(dist), and an
        # exact zero would turn the loss into inf/NaN
        dist = dist.clamp(min=1e-3)
        
        return dist
    
    def reset_noise(self):
        """Reset all noisy layers."""
        self.advantage_hidden_layer.reset_noise()
        self.advantage_layer.reset_noise()
        self.value_hidden_layer.reset_noise()
        self.value_layer.reset_noise()
    
def build_network(sequence_len, alphabet_len, atom_size, support, device):
    """Create a ``Network``, move it to ``device``, print its layout and return it."""
    network = Network(sequence_len, alphabet_len, atom_size, support)
    network = network.to(device)
    print(network)
    return network

class RL_agent_Rainbow_DQN():
    '''
    Rainbow-DQN agent that walks a noisy RNA landscape one mutation at a time.

    The state is the one-hot encoding of the current sequence; an action is a
    flat index into that encoding (which letter to place at which position).

    Based off https://github.com/Curt-Park/rainbow-is-all-you-need/blob/master/08.rainbow.ipynb
    '''
    def __init__(self, 
                 start_sequence, 
                 alphabet, 
                 alpha = 0.6,
                 beta = 0.4,
                 gamma = 0.9, 
                 prior_eps = 1e-6,
                 memory_size = 100000, 
                 batch_size = 128, 
                 v_min = 0.0,
                 v_max = 10.0,
                 atom_size = 51,
                 n_step = 3,
                 device = "cpu", 
                 noise_alpha=1):
        """Build the networks, the replay memories and the landscape.

        Args:
            start_sequence: sequence (string) the agent starts from.
            alphabet: letters available at each position.
            alpha: priority exponent of the prioritized replay buffer.
            beta: importance-sampling exponent used when sampling batches.
            gamma: discount factor.
            prior_eps: small constant added to losses before they become priorities.
            memory_size: capacity of the replay buffers.
            batch_size: training batch size; also the number of transitions
                collected per generation in ``run_RL``.
            v_min, v_max: range of the categorical return support.
            atom_size: number of atoms of the categorical return distribution.
            n_step: number of steps folded together for the n-step loss
                (only used when greater than 1).
            device: torch device for the networks and training tensors.
            noise_alpha: forwarded to ``Noise_wrapper``.
        """
        self.alphabet = alphabet
        self.state = translate_string_to_one_hot(start_sequence, self.alphabet)
        self.seq_size = len(start_sequence)
        self.device = device
        # neural networks and their parameters 
        self.v_min = v_min
        self.v_max = v_max
        self.atom_size = atom_size
        self.support = torch.linspace(
            self.v_min, self.v_max, self.atom_size
        ).to(device)
        self.net = build_network(self.seq_size, len(self.alphabet), atom_size, 
                                 self.support, device)
        self.target_net = build_network(self.seq_size, len(self.alphabet), atom_size, 
                                        self.support, device)
        self.target_net.load_state_dict(self.net.state_dict())
        self.net.eval()
        self.target_net.eval()
        # other params
        self.start_sequence = translate_string_to_one_hot(start_sequence,self.alphabet)
        self.memory_size = memory_size
        self.gamma = gamma
        self.batch_size = batch_size
        # 1-Step Learning
        obs_dim = (len(self.alphabet), self.seq_size)
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory = PrioritizedReplayBuffer(
            obs_dim, memory_size, self.batch_size, alpha=alpha
        )
        # N-Step Learning 
        self.n_step = n_step
        self.use_n_step = n_step > 1
        if self.use_n_step:
            self.memory_n = ReplayBuffer(
                obs_dim, memory_size, self.batch_size, n_step=n_step, gamma=gamma
            )

        self.seen_sequences = []
        self.landscape = Noise_wrapper(RNA_landscape(start_sequence), noise_alpha=noise_alpha)
        self.best_fitness = 0

    def reset_position(self,sequence):
        """Move the agent to ``sequence`` (given as a string)."""
        self.state=translate_string_to_one_hot(sequence,self.alphabet)

    def get_position(self):
        """Return the current state as a string."""
        return translate_one_hot_to_string(self.state,self.alphabet)

    def translate_pwm_to_sequence(self,input_seq_one_hot,output_pwm):
        """Return, per position, the letter whose score rose most relative to the input."""
        diff=output_pwm-input_seq_one_hot
        most_likely=np.argmax(diff,axis=0)
        out_seq=""
        for m in most_likely:
            out_seq+=self.alphabet[m]
        return out_seq
    
    def q_network_loss(self, samples, gamma):
        """
        Categorical (distributional) DQN loss for a batch of transitions.

        The target return distribution is built with Double DQN (online net picks
        the next action, target net evaluates it), shifted by the reward,
        discounted by ``gamma`` and projected back onto the fixed support. The
        loss is the cross-entropy between that projection and the online net's
        distribution for the action actually taken.

        Args:
            samples: dict with numpy arrays under "obs", "next_obs", "acts" and
                "rews", each holding ``self.batch_size`` entries.
            gamma: discount applied to the support of the next-state distribution.

        Returns:
            Tensor of shape ``(batch_size,)`` with one loss value per transition.
        """
        state = torch.FloatTensor(samples["obs"]).to(self.device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(self.device)
        action = torch.LongTensor(samples["acts"]).to(self.device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(self.device)

        # Categorical DQN algorithm
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        with torch.no_grad():
            # Double DQN
            next_action = self.net(next_state).argmax(1)
            next_dist = self.target_net.dist(next_state)
            next_dist = next_dist[range(self.batch_size), next_action]

            t_z = reward + gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            # fractional atom index of every shifted atom; clamped so that float
            # rounding can never push an index outside [0, atom_size - 1]
            b = ((t_z - self.v_min) / delta_z).clamp(0, self.atom_size - 1)
            l = b.floor().long()
            u = b.ceil().long()
            # when b is exactly integral l == u and both interpolation weights
            # below would be 0, silently dropping that probability mass; move one
            # of the two indices so that the full mass lands on atom b
            l[(u > 0) & (l == u)] -= 1
            u[(l < (self.atom_size - 1)) & (l == u)] += 1

            # start of each batch row inside the flattened (batch * atom) view
            offset = (
                torch.arange(self.batch_size, device=self.device) * self.atom_size
            ).unsqueeze(1).expand(self.batch_size, self.atom_size)

            proj_dist = torch.zeros(next_dist.size(), device=self.device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
            )

        dist = self.net.dist(state)
        log_p = torch.log(dist[range(self.batch_size), action])
        elementwise_loss = -(proj_dist * log_p).sum(1)

        return elementwise_loss
    
    def train_actor(self, train_epochs=10):
        """Run ``train_epochs`` prioritized updates of the online network.

        Returns:
            The mean (importance-weighted) loss over the epochs.
        """
        total_loss = 0.
        # train Q network on new samples 
        # NOTE(review): a fresh Adam optimizer (and fresh moment estimates) is
        # created on every call - confirm this is intended.
        optimizer = optim.Adam(self.net.parameters())
        for epoch in range(train_epochs):
            # sample for 1-step learning loss first 
            samples = self.memory.sample_batch(self.beta)
            weights = torch.FloatTensor(
                samples["weights"].reshape(-1, 1)
            ).to(self.device)
            indices = samples["indices"]
            elementwise_loss = self.q_network_loss(samples, self.gamma)
            loss = torch.mean(elementwise_loss * weights)
            # add in N-step learning loss if it is being used 
            if self.use_n_step:
                gamma = self.gamma ** self.n_step
                # same indices as the 1-step batch, read from the n-step memory
                samples = self.memory_n.sample_batch_from_idxs(indices)
                elementwise_loss_n_loss = self.q_network_loss(samples, gamma)
                elementwise_loss += elementwise_loss_n_loss
                loss = torch.mean(elementwise_loss * weights)
            # train model 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # PER: update priorities
            loss_for_prior = elementwise_loss.detach().cpu().numpy()
            new_priorities = loss_for_prior + self.prior_eps
            self.memory.update_priorities(indices, new_priorities)
            total_loss += loss.item()
            self.net.reset_noise()
            self.target_net.reset_noise()
        return (total_loss / train_epochs)

    def pick_action(self):
        """Choose the greedy action under the (noisy) online network and apply it.

        Exploration comes from the noisy layers rather than epsilon-greedy.
        The agent's state is replaced by the resulting mutant.

        Returns:
            Tuple ``(action, mutant)``: the flat action index (0-d CPU tensor)
            and the new one-hot state.
        """
        state_tensor = torch.FloatTensor(self.state).to(self.device)
        with torch.no_grad():
            # bring the index back to the CPU so it can index numpy arrays
            action = self.net(state_tensor).argmax().cpu()
        action_matrix = np.zeros(len(self.alphabet) * self.seq_size)
        action_matrix[int(action)] = 1
        action_matrix = action_matrix.reshape((len(self.alphabet), self.seq_size))
        # get new state after action is performed 
        mutant = construct_mutant_from_sample(action_matrix, self.state)
        self.state = mutant

        return action, mutant
    
    def _print_q_values(self):
        """Print the online and target Q-values of the current state (debug aid)."""
        state_tensor = torch.tensor(self.state).float().to(self.device)
        print(self.net(state_tensor).detach().cpu().numpy()[0])
        print(self.target_net(state_tensor).detach().cpu().numpy()[0])

    def run_RL(self, generations=10, train_epochs=10):
        """Alternate between collecting transitions and training.

        Loops until ``self.landscape.cost`` reaches ``batch_size * generations``.
        Each round collects ``batch_size`` transitions whose resulting sequence is
        not in ``self.landscape.measured_sequences``, syncs the target network,
        trains for ``train_epochs`` epochs and prints
        ``(landscape cost, best fitness, average loss)``.
        """
        while self.landscape.cost < self.batch_size*generations:
            b = 0
            while(b < self.batch_size):
                state = self.state.copy() 
                action, new_state = self.pick_action()
                new_state_string = translate_one_hot_to_string(new_state, self.alphabet)
                reward = self.landscape.get_fitness(new_state_string)
                if not new_state_string in self.landscape.measured_sequences:
                    if reward > self.best_fitness:
                        self._print_q_values()
                    self.best_fitness = max(self.best_fitness, reward)
                    if self.use_n_step:
                        # add N-step transition
                        is_n_step_stored = self.memory_n.store(state, action, reward, new_state)
                        if is_n_step_stored:
                            # train_actor reads both memories with the same indices,
                            # so the 1-step entry must be the transition the n-step
                            # entry starts from: the oldest one still queued
                            self.memory.store(*self.memory_n.n_step_buffer[0])
                    else:
                        # add a single step transition
                        self.memory.store(state, action, reward, new_state)
                    b += 1
            # reset target network 
            self.target_net.load_state_dict(self.net.state_dict())
            
            avg_loss = self.train_actor(train_epochs)
            print (self.landscape.cost, self.best_fitness, avg_loss)

In [None]:
# Run configuration for the Rainbow DQN agent.
# NOTE(review): this notebook-level `batch_size` is not used by the agent, which
# gets its own batch_size=128 below.
batch_size = 1000
generations = 20
agent = RL_agent_Rainbow_DQN(wt, alphabet=RAA, gamma=0.9, atom_size=10,
                             batch_size=128, memory_size=10000, device=device)
# pass `generations` explicitly; otherwise run_RL silently falls back to its default of 10
agent.run_RL(generations=generations, train_epochs=40)