In [2]:
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [20]:
# Negative-binomial parameters for the 21 hidden states (index = state - 1).
# Computed once at module level instead of on every draw: generate_nb is called
# once per simulated nucleotide.
NB_MEAN = np.array([60, 350, 300, 250, 200, 150, 120, 320, 250, 220, 60, 450, 400, 300, 250, 230, 100, 430, 350, 320, 60])  # mean count per state
NB_VARIANCE = NB_MEAN * 1.2  # variance is 1.2x the mean (for E=1)
NB_BETA = 1 / (NB_VARIANCE / NB_MEAN - 1)  # equals ~5 for every state given the 1.2 ratio
NB_ALPHA = NB_MEAN * NB_BETA  # numpy's `n` parameter


def generate_nb(state, E=1):
    '''
    Given the state, draw one count from its negative-binomial (NB) distribution.

    state: int in 1..21, the hidden state (1-based index into NB_MEAN)
    E: int. Gene expression level (1,5,10,15,20); the mean becomes NB_MEAN[state-1] * E

    Uses numpy's NB(n, p) with n = alpha and p = beta / (E + beta), whose mean is
    n(1-p)/p = alpha * E / beta = NB_MEAN[state-1] * E. For E=1 the variance is
    1.2 * mean.

    Returns a non-negative integer drawn from numpy's global RNG (seed it with
    np.random.seed for reproducibility).
    Raises ValueError if state is not in 1..len(NB_MEAN); previously out-of-range
    states silently wrapped around via negative indexing.
    '''
    if not 1 <= state <= len(NB_MEAN):
        raise ValueError(
            "state must be in 1..{}, got {}".format(len(NB_MEAN), state))
    alpha = NB_ALPHA[state - 1]
    beta = NB_BETA[state - 1]
    return np.random.negative_binomial(alpha, beta / (E + beta))

In [260]:
# Nucleotide alphabet and stop codons used by simulate_single and its helpers.
_NUCLEOTIDES = ['A', 'C', 'G', 'U']
_STOP_CODONS = ('UAA', 'UGA', 'UAG')


def _draw_utr_bases(context, k):
    '''
    Draw k uniform random bases that create no AUG when appended to `context`.

    Only the last two bases of `context` can combine with new bases into an AUG.
    Checking them together with the new bases therefore rejects AUGs in every
    reading frame, including ones straddling the junction.
    '''
    tail = list(context[-2:])
    while True:
        bases = random.choices(_NUCLEOTIDES, k=k)
        if 'AUG' not in ''.join(tail + bases):
            return bases


def _draw_sense_codon():
    '''Draw a uniform random codon, redrawing until it is not a stop codon.'''
    while True:
        codon = random.choices(_NUCLEOTIDES, k=3)
        if ''.join(codon) not in _STOP_CODONS:
            return codon


def _draw_stop_codon():
    '''Draw one of the three stop codons uniformly, returned as a list of bases.'''
    return list(random.choice(_STOP_CODONS))


def simulate_single():
    '''
    Simulate one RNA sequence with per-base counts and hidden states.

    Two scenarios are produced: uORF followed by CDS, or only CDS.
    start codon: AUG
    stop codons: UAA, UGA, UAG

    Hidden-state layout (one label per base):
      1         5'U leader (no AUG in any frame)
      2, 3, 4   uORF start codon (AUG)
      5, 6, 7   uORF sense codons (10-30, never a stop codon)
      8, 9, 10  uORF stop codon
      11        5'U2 between the ORFs (no AUG in any frame)
      12,13,14  CDS start codon (AUG)
      15,16,17  CDS sense codons (32-52, never a stop codon)
      18,19,20  CDS stop codon
      21        3'U trailer (10-30 unconstrained codons)

    Returns [bases, counts, states, [state_1_2, state_1_12, state_11_12]]:
      - bases: list of single-letter nucleotides;
      - counts[i] = generate_nb(states[i]);
      - the three flags are 1 when the transition 1->2 (uORF initiation),
        1->12 (direct CDS initiation) or 11->12 (CDS after the uORF) occurred.
    '''
    rna, counts, states = [], [], []
    state_1_2 = state_1_12 = state_11_12 = 0

    def emit(bases, labels):
        # Append bases, their hidden-state labels, and one NB count per label.
        rna.extend(bases)
        states.extend(labels)
        counts.extend(generate_nb(s) for s in labels)

    # 5'U: an AUG-free 5-base leader, then codons until initiation is chosen
    emit(_draw_utr_bases(rna, 5), [1] * 5)
    state = 1
    while state == 1:
        emit(_draw_utr_bases(rna, 3), [1, 1, 1])
        state = random.choices([1, 2, 12], weights=[0.2, 0.5, 0.3], k=1)[0]

    if state == 2:
        state_1_2 = 1

        # uORF: start codon, shorter elongation than the CDS, stop codon
        emit(['A', 'U', 'G'], [2, 3, 4])
        for _ in range(random.randint(10, 30)):
            emit(_draw_sense_codon(), [5, 6, 7])
        emit(_draw_stop_codon(), [8, 9, 10])

        # 5'U2: AUG-free 10 bases, then codons until re-initiation at the CDS
        emit(_draw_utr_bases(rna, 10), [11] * 10)
        state = 11
        while state == 11:
            emit(_draw_utr_bases(rna, 3), [11, 11, 11])
            state = random.choices([11, 12], weights=[0.2, 0.8], k=1)[0]
        state_11_12 = 1
    else:
        state_1_12 = 1

    # CDS (every path reaches state 12 here): start codon, longer elongation, stop codon
    emit(['A', 'U', 'G'], [12, 13, 14])
    for _ in range(random.randint(32, 52)):
        emit(_draw_sense_codon(), [15, 16, 17])
    emit(_draw_stop_codon(), [18, 19, 20])

    # 3'U: unconstrained codons
    for _ in range(random.randint(10, 30)):
        emit(random.choices(_NUCLEOTIDES, k=3), [21, 21, 21])

    return [rna, counts, states, [state_1_2, state_1_12, state_11_12]]


In [261]:
def _convert_to_uorf_only(record):
    '''
    Turn a uORF+CDS record from simulate_single into an "only uORF" record, in place.

    Everything from the CDS start (first state 12) to the end is relabelled as 5'U2
    (state 11), and its counts are redrawn with generate_nb(11). Its bases are redrawn
    one at a time so that the relabelled region contains no AUG. This includes AUGs
    straddling the junction with the preceding bases, matching the AUG-free 5'U2 of
    simulate_single. The 11->12 transition flag is cleared.

    record: [bases, counts, states, flags] as returned by simulate_single
    Returns the same (mutated) record.
    '''
    rna, counts, states, flags = record
    start = states.index(12)
    n_tail = len(states) - start

    states[start:] = [11] * n_tail
    counts[start:] = [generate_nb(11) for _ in range(n_tail)]

    # Track the last two bases so no new base completes an AUG.
    last_two = ''.join(rna[max(start - 2, 0):start])
    new_bases = []
    for _ in range(n_tail):
        base = random.choice('ACGU')
        while last_two + base == 'AUG':
            base = random.choice('ACGU')
        new_bases.append(base)
        last_two = (last_two + base)[-2:]
    rna[start:] = new_bases

    flags[2] = 0
    return record


def simulate_main(length):
    '''
    Simulate multiple RNA sequences covering 3 scenarios: only uORF, only CDS, and both.
    length: the number of RNA sequences

    Each sequence comes from simulate_single(), which yields either "both" or
    "only CDS". For every index i with i % 3 == 0, a "both" sequence is converted
    to "only uORF" by _convert_to_uorf_only. Exactly one record is appended per
    iteration; previously converted records were appended twice.

    start codon: AUG
    stop codons: UAA, UGA, UAG

    Returns [RNA_data, counts_data, states_data, transition_list], each of length `length`.
    '''
    RNA_data = []
    counts_data = []
    states_data = []
    transition_list = []

    for i in range(length):
        record = simulate_single()
        has_uorf_and_cds = record[3][0] == 1 and record[3][2] == 1
        if i % 3 == 0 and has_uorf_and_cds:
            _convert_to_uorf_only(record)

        RNA_data.append(record[0])
        counts_data.append(record[1])
        states_data.append(record[2])
        transition_list.append(record[3])

    return [RNA_data, counts_data, states_data, transition_list]

In [193]:
def simulate_neither(length):
    '''
    Simulate multiple RNA sequences without translated regions.
    length: the number of RNA sequences

    Each sequence consists of 30-80 uniformly random codons (no AUG or stop-codon
    filtering). Every base is labelled hidden state 1 and gets a count drawn from
    generate_nb(1). All three transition flags are 0.

    Returns [RNA_data, counts_data, states_data, transition_list], each with one
    entry per sequence.
    '''
    nucleotides = ['A', 'C', 'G', 'U']
    sequences, count_lists, label_lists, flag_lists = [], [], [], []

    for _ in range(length):
        n_codons = random.choices(np.arange(30, 81), weights=np.ones(51), k=1)[0]
        # One draw of 3*n bases consumes the RNG exactly like n draws of 3 bases.
        bases = random.choices(nucleotides, weights=np.ones(4), k=3 * int(n_codons))

        sequences.append(bases)
        count_lists.append([generate_nb(1) for _ in bases])
        label_lists.append([1] * len(bases))
        flag_lists.append([0, 0, 0])

    return [sequences, count_lists, label_lists, flag_lists]


In [None]:
# generate datasets
# Seed both RNGs: `random` drives the bases/states, while generate_nb draws counts
# from numpy's global RNG. Seeding only `random` left the counts non-reproducible.
SEED = 0
N_MAIN = 2700     # sequences from simulate_main (only uORF / only CDS / both)
N_NEITHER = 300   # sequences without any translated region
random.seed(SEED)
np.random.seed(SEED)
part1 = simulate_main(N_MAIN)
part2 = simulate_neither(N_NEITHER)
RNA_data = part1[0] + part2[0]
observed_data = part1[1] + part2[1]
states_seq = part1[2] + part2[2]
transition_list = part1[3] + part2[3]
assert len(RNA_data) == len(observed_data) == len(states_seq) == len(transition_list)

In [None]:
# save as file: RNA_data, observed_data, states_seq and transition_list
def saveListToFile(listname, pathtosave):
    '''
    Write each element of `listname` to `pathtosave`, one per line.

    Each element is formatted with "{}".format(...), so nested lists are written
    as their Python representation.

    listname: iterable of objects to write
    pathtosave: destination file path; an existing file is overwritten

    The `with` block guarantees the file is closed even if a write fails.
    '''
    with open(pathtosave, "w", encoding="utf-8") as out_file:
        for item in listname:
            out_file.write("{}\n".format(item))

# NOTE(review): `path` was never defined in this notebook, so this cell raised NameError
# on a fresh kernel. It is defined here; change it to the intended output location.
path = "simulated_data.txt"
listname = [RNA_data, observed_data, states_seq, transition_list]
saveListToFile(listname, path)

In [None]:
# visualization
# both
# x_tick = list(map(lambda num: "" if num % 10 != 0 else num, range(len(observed_data[2]))))
# plt.figure(figsize=(300, 100), dpi=10)
# plt.ylabel('Number of RPF counts')
# plt.grid()
# plt.bar(np.arange(len(observed_data[2])), observed_data[2], width=0.5)
# plt.ylim(0,500)
# plt.xlim(0,350)
# plt.xticks(range(len(x_tick)), x_tick, size=200)
# plt.yticks(size=200)
# plt.show()
# # only main coding region
# plt.bar(np.arange(len(observed_data[0])), observed_data[0])
# plt.ylim(0,500)
# plt.xlim(0,350)
# plt.show()
# # only uORF
# plt.bar(np.arange(len(observed_data[5])), observed_data[5])
# plt.ylim(0,500)
# plt.xlim(0,350)
# plt.show()
# # none
# plt.bar(np.arange(len(observed_data[2500])), observed_data[2500])
# plt.ylim(0,500)
# plt.xlim(0,350)
# plt.show()