<a href="https://colab.research.google.com/github/rpadmanabhan/mdp-rl-project/blob/main/mdp_rl_explore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-1.5.0-py3-none-any.whl (177 kB)
[K     |████████████████████████████████| 177 kB 5.3 MB/s 
[?25hCollecting gym==0.21
  Downloading gym-0.21.0.tar.gz (1.5 MB)
[K     |████████████████████████████████| 1.5 MB 33.3 MB/s 
Collecting ale-py~=0.7.4
  Downloading ale_py-0.7.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 34.8 MB/s 
Collecting autorom[accept-rom-license]~=0.4.2
  Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)
Collecting AutoROM.accept-rom-license
  Downloading AutoROM.accept-rom-license-0.4.2.tar.gz (9.8 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: gym, AutoROM.accept-rom-license
  Building wheel for gym (setup.py) ... [?25l[?25hdone
  Created wheel for gym: filename=gym-0.

In [None]:
from dataclasses import dataclass, field
from typing import Dict, List

In [None]:
import gym
from gym import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN, PPO, A2C
from stable_baselines3.common.cmd_util import make_vec_env


import numpy as np

In [None]:
@dataclass(frozen=True)
class RegionOfInterest:
    """Immutable record describing a target interval on a chromosome.

    Attributes are documented inline. Instances cannot be modified after
    construction because the dataclass is frozen.
    """

    # Name of the chromosome the region lies on.
    chrom: str
    # First coordinate of the region.
    start: int
    # Last coordinate of the region.
    end: int
    # Optional free-form metadata; every instance gets its own empty dict
    # unless one is supplied.
    info: Dict = field(default_factory=dict)

@dataclass(frozen=True)
class SinglePrimerOligo:
    """Immutable description of one candidate oligo.

    ``strand`` selects on which side of ``loc3`` the covered span lies
    (see ``ret_coverage_range``).
    """

    seq: str
    chrom: str
    loc5: int
    loc3: int
    strand: int

    def ret_coverage_range(self):
        """Return the ``(start, end)`` span covered by this oligo.

        For ``strand == 0`` the span is ``(loc3 + 1, loc3 + 150)``; for any
        other strand value it is ``(loc3 - 150, loc3 - 1)``.
        """
        anchor = self.loc3
        if self.strand != 0:
            return (anchor - 150, anchor - 1)
        return (anchor + 1, anchor + 150)

    def covered_bases_for_roi(self, roi: RegionOfInterest):
        """Clip this oligo's coverage span to the bounds of ``roi``.

        Returns ``(start, end)`` where ``start`` is the larger of the two
        start coordinates and ``end`` the smaller of the two end
        coordinates. Chromosomes are not compared, and when the spans do
        not overlap the returned ``start`` is greater than ``end``.
        """
        span_start, span_end = self.ret_coverage_range()
        clipped_start = max(span_start, roi.start)
        clipped_end = min(span_end, roi.end)
        return (clipped_start, clipped_end)

    def oligo_score(self):
        """Return the score of this oligo (currently the constant 1)."""
        return 1


@dataclass(frozen=True)
class ProbeOligo:
    """Immutable probe sequence plus the extension applied on each side.

    ``hyb_extension`` is the number of positions added before and after the
    probe's match when computing its coverage range.
    """

    seq: str
    hyb_extension: int = 50

    def ret_coverage_range(self, ref_seq):
        """Return the half-open ``[start, end)`` coverage range in ``ref_seq``.

        The probe is located at its first exact occurrence in ``ref_seq``
        and the match is widened by ``hyb_extension`` on both sides. The
        result is not clipped: ``start`` may be negative and ``end`` may
        exceed ``len(ref_seq)``, so callers that index into the reference
        must clamp it.

        Raises:
            AssertionError: if the probe sequence does not occur in
                ``ref_seq``.
        """
        # Attributes must be read through ``self``; a bare ``seq`` is not
        # in scope inside the method.
        start = ref_seq.find(self.seq)
        assert start != -1, "probe sequence not found in reference sequence"
        end = start + len(self.seq)
        return start - self.hyb_extension, end + self.hyb_extension

@dataclass(eq=False)
class RefSequences:
    """Named reference sequences with a per-base coverage counter for each.

    ``sequences`` maps a sequence name to its sequence string. After
    construction ``coverage_profile`` maps each name to a zero-filled numpy
    array with one counter per position of that sequence.
    """

    sequences: Dict[str, str]
    # Built in __post_init__, so it is not a constructor argument.
    coverage_profile: Dict[str, np.ndarray] = field(init=False, repr=False)

    def __post_init__(self):
        """Allocate an all-zero coverage array for every reference sequence."""
        self.coverage_profile = {
            seq_name: np.zeros(len(seq)) for seq_name, seq in self.sequences.items()
        }

    def update_coverage_profile(self, probe: "ProbeOligo"):
        """Update base level coverage for all reference sequences.

        For every reference sequence the probe's ``[start, end)`` coverage
        range is obtained from ``probe.ret_coverage_range`` and each
        position inside it is incremented by one. Any exception raised by
        ``ret_coverage_range`` propagates to the caller.
        """
        for seq_name, seq in self.sequences.items():
            start, end = probe.ret_coverage_range(seq)
            # The probe range may extend past either end of the sequence.
            # Clamp it so a negative start does not wrap around to the tail
            # of the array and an oversized end does not run off it.
            lo = max(start, 0)
            hi = min(end, len(seq))
            self.coverage_profile[seq_name][lo:hi] += 1

    def ret_num_bases_covered(self):
        """Return, per reference sequence, the count of positions with
        non-zero coverage (in the iteration order of ``sequences``).
        """
        return [
            int(np.count_nonzero(self.coverage_profile[seq_name]))
            for seq_name in self.sequences
        ]

    def ret_mean_bases_covered(self):
        """Return mean bases covered across all reference sequences.

        Raises:
            statistics.StatisticsError: if there are no reference sequences.
        """
        # Imported locally so the method does not rely on a module-level
        # import of ``statistics``.
        import statistics

        return statistics.mean(self.ret_num_bases_covered())


class OligoEnv(gym.Env):
    """Gym environment in which an agent picks oligos for a region of interest.

    * Action: the index of one oligo in ``oligos`` (``Discrete``).
    * Observation: a binary vector with one entry per oligo; entry ``i`` is
      1 once oligo ``i`` has been picked in the current episode
      (``MultiBinary``).
    * Reward: for a newly picked oligo, the length of its coverage span
      clipped to the ROI plus a fixed per-oligo penalty; for an oligo that
      was already picked, a fixed penalty.
    * Termination: the episode is done once every oligo has been picked.
    """

    def __init__(self, *args, roi: RegionOfInterest, oligos: List[SinglePrimerOligo], **kwargs):
        """Create the environment.

        Args:
            roi: region the oligos are scored against.
            oligos: all candidate oligos; action ``i`` selects ``oligos[i]``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super().__init__()

        self.roi = roi
        self.oligos = oligos

        # One action and one observation bit per candidate oligo.
        self.action_space = spaces.Discrete(len(self.oligos))
        self.observation_space = spaces.MultiBinary(len(self.oligos))

        # Reward constants: cost added for every newly picked oligo, and
        # the reward for re-selecting an oligo that was already picked.
        self._reward_oligo_penalty = -10
        self._reward_action_masking = -5

        # Indices picked so far in the current episode; set by reset().
        self._actions_taken = None

    def _check_termination(self, observation):
        """Return whether every entry of ``observation`` equals 1."""
        return np.all(observation == 1)

    def _compute_reward(self, action):
        """Return the reward for picking oligo ``action`` for the first time."""
        oligo = self.oligos[action]
        cov_start, cov_end = oligo.covered_bases_for_roi(self.roi)
        # When the oligo does not overlap the ROI the clipped span comes
        # back inverted (start > end); count that as zero covered bases
        # rather than a negative number.
        total_covered_bases = max(0, cov_end - cov_start)

        return total_covered_bases + self._reward_oligo_penalty

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self._actions_taken is None:
            raise RuntimeError("reset() must be called before step()")

        # Policies may hand back a numpy scalar; normalise it so set
        # membership and list indexing always see a plain int.
        action = int(action)

        if action in self._actions_taken:
            reward = self._reward_action_masking
        else:
            self._actions_taken.add(action)
            reward = self._compute_reward(action)

        observation = np.zeros(len(self.oligos), dtype=np.int8)
        for picked in self._actions_taken:
            observation[picked] = 1

        done = bool(self._check_termination(observation))
        info = {}
        return observation, reward, done, info

    def reset(self):
        """Start a new episode with no oligos picked; return the all-zero observation."""
        self._actions_taken = set()
        return np.zeros(len(self.oligos), dtype=np.int8)

    def render(self, mode="console"):
        """Rendering is not implemented; this is a no-op."""
        pass

    def close(self):
        """Nothing to release; this is a no-op."""
        pass

In [None]:
# Toy problem: three strand-0 candidate oligos on chromosome "1", listed as
# (seq, loc5, loc3), and one region of interest on the same chromosome.
oligos = [
    SinglePrimerOligo(seq=seq, chrom="1", loc5=loc5, loc3=loc3, strand=0)
    for seq, loc5, loc3 in (
        ("ACCG", 10, 20),
        ("ACCT", 2, 14),
        ("ACGG", 18, 25),
    )
]
roi = RegionOfInterest(chrom="1", start=15, end=300)


In [None]:
# Build the environment for the toy ROI and candidate oligos defined above.
env = OligoEnv(oligos=oligos, roi=roi)

In [None]:
# Run stable-baselines3's environment checker on the custom environment.
check_env(env=env)


In [None]:
# Scratch check: shape of a sample drawn from a 5-element MultiBinary space.
test_space = spaces.MultiBinary(n=5)
test_space.sample().shape

(5,)

In [None]:
# Scratch check: shape of the int8 zero vector used as an observation.
np.zeros(shape=5, dtype=np.int8).shape

(5,)

In [None]:
# Scratch check of an "every entry is 0" test. Compare the array's elements
# with 0; comparing its ``.shape`` tuple with 0 is False for any array.
np.all(np.array([0, 1, 0, 1], dtype=np.int8) == 0)

False

In [None]:
# Scratch check: an int8 zero vector of length 5 is a member of the space.
np.zeros(5, dtype=np.int8) in test_space

True

In [None]:
# Vectorised wrapper around the single environment. Note that the model
# below is given ``env`` directly, not ``env_wrap``.
env_wrap = make_vec_env(lambda: env, n_envs=1)

# Build a DQN agent with an MLP policy and train it for 100 timesteps.
model = DQN("MlpPolicy", env, verbose=1)
model = model.learn(total_timesteps=100)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.75     |
|    ep_rew_mean      | 408      |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 5358     |
|    time_elapsed     | 0        |
|    total_timesteps  | 19       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.62     |
|    ep_rew_mean      | 399      |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 4210     |
|    time_elapsed     | 0        |
|    total_timesteps  | 53       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.5      |
|    ep_rew_mean      | 400    

In [None]:
# Roll out the trained policy deterministically for at most n_steps steps,
# printing the chosen action and the resulting transition at each step.
obs = env.reset()
n_steps = 20
for step in range(n_steps):
    action, _ = model.predict(obs, deterministic=True)
    print(step, action)
    # Carry the new observation forward so the next prediction sees which
    # oligos have already been picked instead of the initial state.
    obs, reward, done, info = env.step(action)
    print((obs, reward, done, info))
    if done:
        # Every oligo has been picked; the episode is over.
        break

0 1
(array([0, 1, 0], dtype=int8), 139, False, {})
1 1
(array([0, 1, 0], dtype=int8), -5, False, {})
2 1
(array([0, 1, 0], dtype=int8), -5, False, {})
3 1
(array([0, 1, 0], dtype=int8), -5, False, {})
4 1
(array([0, 1, 0], dtype=int8), -5, False, {})
5 1
(array([0, 1, 0], dtype=int8), -5, False, {})
6 1
(array([0, 1, 0], dtype=int8), -5, False, {})
7 1
(array([0, 1, 0], dtype=int8), -5, False, {})
8 1
(array([0, 1, 0], dtype=int8), -5, False, {})
9 1
(array([0, 1, 0], dtype=int8), -5, False, {})
10 1
(array([0, 1, 0], dtype=int8), -5, False, {})
11 1
(array([0, 1, 0], dtype=int8), -5, False, {})
12 1
(array([0, 1, 0], dtype=int8), -5, False, {})
13 1
(array([0, 1, 0], dtype=int8), -5, False, {})
14 1
(array([0, 1, 0], dtype=int8), -5, False, {})
15 1
(array([0, 1, 0], dtype=int8), -5, False, {})
16 1
(array([0, 1, 0], dtype=int8), -5, False, {})
17 1
(array([0, 1, 0], dtype=int8), -5, False, {})
18 1
(array([0, 1, 0], dtype=int8), -5, False, {})
19 1
(array([0, 1, 0], dtype=int8), -5, 