#### Imports

In [None]:
import re
import copy
import itertools

import pandas as pd
import numpy as np

from Bio import SeqIO
from Bio.Seq import Seq

#### Functions

In [None]:
## Finds all PAMs on a strand of a sequence record
## strand specifies the strand (1,-1) that the search is being conducted on
## startcoord specifies the beginning coordinate of the seqrecord on the underlying reference genome
def find_strand_pams(seqrecord, strand, startcoord=0):
    """Locate every PAM-adjacent 20-nt target on one strand of a sequence record.

    A PAM is detected as a ``CC`` dinucleotide on the searched strand. The
    target is the 20 characters that begin 3 positions after the first ``C``.
    PAMs too close to the end of the sequence to leave a full 20-character
    target are ignored.

    Args:
        seqrecord: Object exposing ``.seq`` with a ``reverse_complement()``
            method.
        strand: ``1`` searches the record as given; any other value searches
            its reverse complement.
        startcoord: Coordinate of the record's first base on the underlying
            reference genome.

    Returns:
        A list of ``[start, end, sequence, strand]`` entries. ``start``/``end``
        are end-exclusive coordinates expressed on the record's given
        orientation (offset by ``startcoord``); ``sequence`` is the target as
        read on the searched strand.
    """
    if strand == 1:
        seq = str(seqrecord.seq)
    else:
        seq = str(seqrecord.seq.reverse_complement())

    seq_len = len(seq)
    pam_list = []

    # Zero-width lookahead so overlapping PAMs are all reported: a plain "CC"
    # pattern consumes both bases and would skip the second PAM in "CCC".
    for match in re.finditer(r"(?=CC)", seq):
        target_lo = match.start() + 3
        target_hi = match.start() + 23
        if target_hi > seq_len:
            # Matches arrive in increasing order, so every later PAM is also
            # too close to the end to leave a full 20-character target.
            break

        if strand == 1:
            start = startcoord + target_lo
            end = startcoord + target_hi
        else:
            # Mirror the offsets back onto the record's given orientation.
            start = startcoord + seq_len - target_hi
            end = startcoord + seq_len - target_lo
        pam_list.append([start, end, seq[target_lo:target_hi], strand])
    return pam_list


## Finds all PAMs in target sequence record
## startcoord specifies the beginning coordinate of the seqrecord on the underlying reference genome
## target_strand specifies the strand (1,-1) that the seq record is on
def find_pams(seqrecord, startcoord=0, target_strand=1):
    """Collect the PAM-adjacent targets found on both strands of a record.

    Args:
        seqrecord: Record passed through to ``find_strand_pams``.
        startcoord: Coordinate of the record's first base on the underlying
            reference genome.
        target_strand: Strand the record's feature of interest lies on.
            ``1`` means the forward strand; any other value is treated as the
            reverse strand.

    Returns:
        A DataFrame with columns ``start``, ``end``, ``sequence``,
        ``ref_strand`` (strand the PAM was found on) and ``target_strand``
        (``ref_strand`` re-expressed relative to ``target_strand``).
    """
    rows = [
        row
        for strand in (1, -1)
        for row in find_strand_pams(seqrecord, strand, startcoord=startcoord)
    ]
    pam_df = pd.DataFrame(rows, columns=["start", "end", "sequence", "ref_strand"])

    # Orientation relative to the target is the product of the two strand
    # signs. The forward-target case previously labelled every row as 1, even
    # rows found on the reverse strand, which disagreed with the reverse-target
    # case (-ref_strand).
    orientation = 1 if target_strand == 1 else -1
    pam_df["target_strand"] = pam_df["ref_strand"] * orientation
    return pam_df


## Uses a reference csv of bad seeds to eliminate bad seeds from the PAM DataFrame
def remove_bad_seeds(pam_df, bad_seed_path):
    """Drop rows whose target begins with a listed bad seed.

    Args:
        pam_df: DataFrame with a ``sequence`` column of target strings.
        bad_seed_path: Path to a CSV file that has a ``seeds`` column.

    Returns:
        The rows of ``pam_df`` whose first five ``sequence`` characters are
        not the reverse complement of any listed seed.
    """
    listed_seeds = pd.read_csv(bad_seed_path)["seeds"].tolist()
    # Seeds are reverse complemented so they can be compared directly against
    # the target sequences stored in pam_df.
    excluded = frozenset(
        str(Seq(seed.upper()).reverse_complement()) for seed in listed_seeds
    )

    def has_clean_seed(target_seq):
        return target_seq[:5] not in excluded

    keep_mask = pam_df["sequence"].apply(has_clean_seed)
    return pam_df[keep_mask]


## Converts a string to an integer representation
def str_to_int(string):
    """Encode a nucleotide string as a NumPy array of small integers.

    ``A``, ``C``, ``G`` and ``T`` map to 0, 1, 2 and 3 respectively. Any other
    character raises ``KeyError``.
    """
    base_codes = {"A": 0, "C": 1, "G": 2, "T": 3}
    return np.array([base_codes[base] for base in string])


## Determines the maximum matching (minimum edit distance) of each target sequence
## to sequences in the reference
## subseq_range can be set to subset the sequences to a region of interest (i.e. the PAM adjacent region)
## remove_matching_starts can be set to eliminate sequences from consideration when they are at the same location
def compare_seqs(
    target_df, reference_df, subseq_range=None, remove_matching_starts=True
):
    """Return, per target, its best positional agreement with any reference.

    Agreement between two sequences is the number of positions that hold the
    same base. All sequences must have the same length.

    Args:
        target_df: DataFrame with ``sequence`` and ``start`` columns.
        reference_df: DataFrame with ``sequence`` and ``start`` columns.
        subseq_range: Optional positions (e.g. a ``range``) to restrict the
            comparison to, such as the PAM-adjacent region. ``None`` compares
            whole sequences.
        remove_matching_starts: When true, a reference entry is ignored for a
            given target if the two share the same ``start`` value, so a
            target is not scored against its own site.

    Returns:
        1-D integer array with one entry per row of ``target_df``: the highest
        agreement count found among the reference entries considered.
    """
    target_int_arr = np.array(
        [str_to_int(seq) for seq in target_df["sequence"].values], dtype="uint8"
    )
    reference_int_arr = np.array(
        [str_to_int(seq) for seq in reference_df["sequence"].values], dtype="uint8"
    )

    if subseq_range is not None:
        target_int_arr = target_int_arr[:, subseq_range]
        reference_int_arr = reference_int_arr[:, subseq_range]

    # Broadcast to (n_target, n_reference, n_positions) and count equal bases.
    bool_arr = target_int_arr[:, np.newaxis, :] == reference_int_arr[np.newaxis, :, :]
    agreement_arr = np.sum(bool_arr, axis=2, dtype=int)

    if remove_matching_starts:
        same_start = (
            target_df["start"].values[:, np.newaxis]
            == reference_df["start"].values[np.newaxis, :]
        )
        # Zero only the (target, reference) pairs that share a start. Blanking
        # whole reference columns would hide that reference from every other
        # target as well and let real matches slip through.
        agreement_arr[same_start] = 0

    return np.max(agreement_arr, axis=1)


### Exhaustively generate mismatched versions of the query with num_mismatch substitutions starting from the PAM distal end
def generate_all_mismatchs(in_str, num_mismatch):
    """Enumerate every variant of ``in_str`` with its tail fully substituted.

    Each of the last ``num_mismatch`` characters is replaced by one of the
    three other bases; the characters before them are kept unchanged.

    Returns:
        A list of strings covering all combinations of replacements
        (``3 ** num_mismatch`` entries).
    """
    # For each base, the alternatives in the order they are enumerated.
    alternatives = {"A": "TCG", "T": "ACG", "C": "TAG", "G": "TCA"}

    total_len = len(in_str)
    first_mutated = total_len - num_mismatch
    untouched = in_str[:first_mutated]
    per_position = (
        alternatives[in_str[pos]] for pos in range(first_mutated, total_len)
    )
    return [untouched + "".join(combo) for combo in itertools.product(*per_position)]


### Randomly (with replacement) generate mismatched versions of the query with num_mismatch substitutions
def generate_mismatch(in_str, num_mismatch):
    """Return one random variant of ``in_str`` with its tail substituted.

    Each of the last ``num_mismatch`` characters is swapped for a base drawn
    with ``np.random.choice`` from the three other bases, working from the
    final character backwards. The characters before them are kept.
    """
    alternatives = {
        "A": ["T", "C", "G"],
        "T": ["A", "C", "G"],
        "C": ["T", "A", "G"],
        "G": ["T", "C", "A"],
    }
    last = len(in_str) - 1
    mutated = list(in_str)
    for pos in range(last, last - num_mismatch, -1):
        # The replacement is always chosen relative to the original base.
        mutated[pos] = np.random.choice(alternatives[in_str[pos]])
    return "".join(mutated)


### Generate a mismatch DataFrame with PAMs containing k[i] mismatches, taking n_samples for each k[i]
### If k[i] <= max_exhaustive, the mismatches are generated exhaustively; otherwise they are sampled randomly
def generate_mismatch_df(pam_df, k=(1, 2, 4, 8, 10), n_samples=50, max_exhaustive=5):
    """Expand each target into itself plus mismatched variants.

    Args:
        pam_df: DataFrame with ``start``, ``end``, ``sequence``,
            ``ref_strand`` and ``target_strand`` columns.
        k: Iterable of mismatch counts to generate for every target.
        n_samples: Number of random draws per target for each mismatch count
            above ``max_exhaustive``; duplicates among the draws are dropped.
        max_exhaustive: Mismatch counts up to and including this value are
            enumerated exhaustively instead of sampled.

    Returns:
        A DataFrame with the input columns plus ``num_mismatch``. Each target
        contributes its unmodified sequence (``num_mismatch == 0``) followed
        by its variants, which keep the original ``start``/``end``/strands.
    """
    columns = [
        "start",
        "end",
        "sequence",
        "ref_strand",
        "target_strand",
        "num_mismatch",
    ]
    rows = []
    for _, row in pam_df.iterrows():
        seq = row["sequence"]
        start = row["start"]
        end = row["end"]
        ref_strand = row["ref_strand"]
        target_strand = row["target_strand"]
        rows.append([start, end, seq, ref_strand, target_strand, 0])

        # Iterate the caller-supplied counts; a hard-coded list here used to
        # shadow the ``k`` argument and silently ignore it.
        for num_mismatch in k:
            if num_mismatch <= max_exhaustive:
                variants = generate_all_mismatchs(seq, num_mismatch)
            else:
                # dict.fromkeys de-duplicates while keeping draw order, so the
                # output order does not depend on string hash randomization.
                variants = list(
                    dict.fromkeys(
                        generate_mismatch(seq, num_mismatch)
                        for _ in range(n_samples)
                    )
                )
            rows.extend(
                [start, end, variant, ref_strand, target_strand, num_mismatch]
                for variant in variants
            )

    return pd.DataFrame(rows, columns=columns)

#### Generate target sequences

In [None]:
# Window of the reference, as Python slice bounds, that is used as the target.
ref_start = 807758
ref_end = 808585

# Read the single reference record and cut the target window out of it.
genome = SeqIO.read("./DE161_reference.gb", "gb")
target = genome[ref_start:ref_end]

In [None]:
# Candidate sites across the whole reference, used later as the comparison set.
genome_pam_df = find_pams(genome)

# Candidate sites inside the target window (treated as a reverse-strand
# target), with any site carrying a listed bad seed removed.
target_pam_df = remove_bad_seeds(
    find_pams(target, startcoord=ref_start, target_strand=-1),
    "./bad_seed_list.csv",
)

In [None]:
# Preview the first ten genome-wide candidate sites.
genome_pam_df.head(10)

In [None]:
# Number of leading sequence positions (indices 0..PAM_adj_size-1) compared.
PAM_adj_size = 10

# Best agreement of each target with any genome site over that window.
most_agreement = compare_seqs(
    target_pam_df, genome_pam_df, subseq_range=range(PAM_adj_size)
)

# Keep only targets that no considered genome site matches across the window.
past_threshold = most_agreement < PAM_adj_size
target_pam_df_nooff = target_pam_df.loc[past_threshold]

In [None]:
# The mean of the boolean mask is a 0-1 fraction; format it as a real
# percentage so the value agrees with the "Percent" label.
print(f"Percent Past Threshold: {np.mean(past_threshold):.1%}")

In [None]:
# Expand every retained target into itself plus mismatched variants, drawing
# 20 random samples per target for the mismatch counts that are sampled.
mismatch_df = generate_mismatch_df(target_pam_df_nooff, n_samples=20)

In [None]:
# Display the expanded table of targets and their mismatched variants.
mismatch_df

In [None]:
# Restrict to sites starting after coordinate 808496 (presumably the promoter
# region, going by the variable name -- confirm against the annotation).
promoter_guides_df = mismatch_df.loc[mismatch_df["start"] > 808496].reset_index(
    drop=True
)

# Re-run the window comparison on the mismatched variants and keep those that
# no considered genome site matches across the whole window.
most_agreement = compare_seqs(
    promoter_guides_df, genome_pam_df, subseq_range=range(PAM_adj_size)
)
past_threshold = most_agreement < PAM_adj_size
promoter_guides_df_nooff = promoter_guides_df.loc[past_threshold]

In [None]:
# The mean of the boolean mask is a 0-1 fraction; format it as a real
# percentage so the value agrees with the "Percent" label.
print(f"Percent Past Threshold: {np.mean(past_threshold):.1%}")

In [None]:
# Display the filtered guides that start after coordinate 808496.
promoter_guides_df_nooff

In [None]:
# All filtered variants of the site that starts at coordinate 808513.
site_1_df = promoter_guides_df_nooff.loc[
    promoter_guides_df_nooff["start"] == 808513
].reset_index(drop=True)

# Pick one variant at random for each mismatch count (unseeded, so the
# selection differs between runs).
site_1_subsample = (
    site_1_df.groupby("num_mismatch").sample(n=1).reset_index(drop=True)
)

In [None]:
# Display the one-per-mismatch-count selection for the first site.
site_1_subsample

In [None]:
# Variants whose target_strand value is 1.
guides_df = mismatch_df.loc[mismatch_df["target_strand"] == 1].reset_index(
    drop=True
)

# Same window comparison as before, applied to this subset.
most_agreement = compare_seqs(
    guides_df, genome_pam_df, subseq_range=range(PAM_adj_size)
)
past_threshold = most_agreement < PAM_adj_size
guides_df_nooff = guides_df.loc[past_threshold]

In [None]:
# The mean of the boolean mask is a 0-1 fraction; format it as a real
# percentage so the value agrees with the "Percent" label.
print(f"Percent Past Threshold: {np.mean(past_threshold):.1%}")

In [None]:
# Third and fourth unmutated (num_mismatch == 0) guides. The explicit copy
# makes this an independent frame, so adding a column to it later does not
# write into a slice of the filtered guides_df (SettingWithCopyWarning).
extra_guides = guides_df[guides_df["num_mismatch"] == 0].iloc[2:4].copy()

In [None]:
# All variants of the site that starts at coordinate 808416.
# NOTE(review): this reads from guides_df rather than the filtered
# guides_df_nooff computed above -- confirm that skipping the filter here is
# intended.
site_2_df = guides_df.loc[guides_df["start"] == 808416].reset_index(drop=True)

# Pick one variant at random for each mismatch count (unseeded).
site_2_subsample = (
    site_2_df.groupby("num_mismatch").sample(n=1).reset_index(drop=True)
)

In [None]:
# Display the one-per-mismatch-count selection for the second site.
site_2_subsample

In [None]:
# Stack the two per-site selections into one table with a fresh 0..n-1 index.
selected_guides_df = pd.concat(
    [site_1_subsample, site_2_subsample], ignore_index=True
)

In [None]:
# Display the combined selection from both sites.
selected_guides_df

In [None]:
### Convert a target sequence to a spacer sequence
def target_to_spacer(target_str):
    """Return the reverse complement of an upper-cased target, as a ``str``."""
    return str(Seq(target_str.upper()).reverse_complement())


### Add bsai sites and primer sequences to the side for cloning
def add_bsaI_sites(
    spacer,
    site_1="AGGCACTTGCTCGTACGACGGAAGACATTAGT",
    site_2="GTTTTCGTCTTCTTAAGGTGCCGGGCCCACAT",
):
    """Flank a spacer with the cloning/priming sequences used for ordering.

    Args:
        spacer: Spacer sequence to embed.
        site_1: Sequence placed before the spacer. Defaults to the flank that
            was previously hard-coded.
        site_2: Sequence placed after the spacer. Defaults to the flank that
            was previously hard-coded.

    Returns:
        ``site_1 + spacer + site_2``.

    NOTE(review): the default flanks contain ``GAAGAC`` / ``GTCTTC``, which
    looks like the BbsI recognition motif rather than BsaI (``GGTCTC``) --
    confirm which enzyme the cloning scheme actually uses.
    """
    return site_1 + spacer + site_2


### Convert a target sequence to a spacer with bsai sites and basic PCR primers
def target_to_padded_spacer(target_str):
    """Turn a target into its spacer and wrap it in the default flanks."""
    return add_bsaI_sites(target_to_spacer(target_str))

In [None]:
# Build the full sequence to order for every selected guide.
selected_guides_df["sequence_to_order"] = selected_guides_df["sequence"].map(
    target_to_padded_spacer
)

In [None]:
# Written to its own file: extra_guides is saved to
# "DE161_extra_guide_df.csv" further down, and sharing that name would
# overwrite this table.
selected_guides_df.to_csv("DE161_selected_guide_df.csv")

In [None]:
# Build the full sequence to order for each extra guide.
extra_guides["sequence_to_order"] = extra_guides["sequence"].map(
    target_to_padded_spacer
)

In [None]:
# Save the extra guides, including their sequence_to_order column, as CSV.
extra_guides.to_csv("DE161_extra_guide_df.csv")