In [158]:
# Imports
import numpy as np
from typing import List, Tuple
from numpy.typing import ArrayLike
from nn.io import read_text_file, read_fasta_file
import random

In [232]:
def sample_seqs(seqs: List[str], labels: List[bool]) -> Tuple[List[str], List[bool]]:
    """
    This function should sample the given sequences to account for class imbalance. 
    Consider this a sampling scheme with replacement.
    
    Args:
        seqs: List[str]
            List of all sequences.
        labels: List[bool]
            List of positive/negative labels

    Returns:
        sampled_seqs: List[str]
            List of sampled sequences which reflect a balanced class size
        sampled_labels: List[bool]
            List of labels for the sampled sequences
    """
    
    label_names, label_counts = np.unique(labels, return_counts=True)
    print(label_counts)
    print(label_names)
    
    label0 = [seq for seq, label in zip(seqs, labels) if label==label_names[0]]
    label1 = [seq for seq, label in zip(seqs, labels) if label==label_names[1]]
        
  
    
    if label_counts[0]==label_counts[1]:
        return(seqs, labels)
    
    
    elif label_counts[0]>label_counts[1]:
        
        #sample more of label_names[1] 
        sampled_label1=random.choices(label1, k=label_counts[0])
        #sampled_label1=random.choices(range(0, len(label1)), k=label_counts[0])

        sampled_seqs=label0 + sampled_label1
        sampled_labels=([label_names[0]] * label_counts[0]) + ([label_names[1]]* label_counts[0])
        
        return(sampled_seqs, sampled_labels)
        
        
    
    elif label_counts[0]<label_counts[1]:
        #sample more of label_names[0]
        
        sampled_label0=random.choices(label0, k=label_counts[1])
        #sampled_label1=random.choices(range(0, len(label1)), k=label_counts[0])

        sampled_seqs=sampled_label0 + label1
        sampled_labels=([label_names[0]] * label_counts[1]) + ([label_names[1]] * label_counts[1])

        return(sampled_seqs, sampled_labels)
        
        
    
    
    
    
    

def one_hot_encode_seqs(seq_arr: List[str]) -> ArrayLike:
    """
    This function generates a flattened one-hot encoding of a list of DNA sequences
    for use as input into a neural network.

    Args:
        seq_arr: List[str]
            List of sequences to encode.

    Returns:
        encodings: ArrayLike
            Array of encoded sequences, with each encoding 4x as long as the input sequence.
            For example, if we encode:
                A -> [1, 0, 0, 0]
                T -> [0, 1, 0, 0]
                C -> [0, 0, 1, 0]
                G -> [0, 0, 0, 1]
            Then, AGA -> [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0].
    """
    encoded_allseq=[]
    
    nuc_dict={'A':[1,0,0,0], 
              'T':[0,1,0,0], 
              'C':[0,0,0,1], 
              'G':[0,0,0,1]}
    
    
    for i, seq in enumerate(seq_arr):
    
        encode_per_nuc=([nuc_dict[n] for n in seq.upper() if n in nuc_dict])
        #encode_seq= [i for per_nuc in encode_per_nuc for i in per_nuc]
        encoded_allseq.append(np.concatenate(encode_per_nuc))

    return(np.stack(encoded_allseq))

In [230]:
one_hot_encode_seqs(['AGACG', 'ACGCT', 'tggAc'])

array([[1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1]])

In [235]:
sampled_seqs, sampled_labels=sample_seqs(all_seq_list, labels)
print(len(sampled_seqs))
print(len(sampled_labels))

[3163  137]
['negative' 'positive']
6326
6326


In [33]:
# #read in fasta negatives
# neg_seq_list=[]
# with open('./data/yeast-upstream-1k-negative.fa', "r") as fasta:
#     for seq in fasta:
#         seq=seq.strip()
#         if not seq.startswith(">"):
#             neg_seq_list.append(seq)
            
            
# #read in txt file positives
# pos_seq_list=[]
# with open('./data/rap1-lieb-positives.txt', "rt") as txt:
#     for seq in txt:
#         pos_seq_list.append(seq.strip())

In [150]:
pos_seqs=read_text_file('./data/rap1-lieb-positives.txt')
neg_seqs=read_fasta_file('./data/yeast-upstream-1k-negative.fa')

min_seq_len=min(len(min(pos_seqs)), len(min(neg_seqs)))
min_seq_len


all_seqs=pos_seqs+neg_seqs



In [170]:
trimmed_all_seqs=[]
for seq in all_seqs:
    if len(seq)>min_seq_len:
        start=random.randint(0, len(seq)-min_seq_len)
        trimmed_all_seqs.append(seq[start:start+min_seq_len])
    else:
        trimmed_all_seqs.append(seq)
        
labels=(['positive'] * len(pos_seqs)) + (['negative'] * len(neg_seqs))        
        

In [171]:
sample_seqs(trimmed_all_seqs, labels)

[3163  137]
0


In [91]:
len(all_seq_list)

3300

In [92]:
len(pos_seqs)

137

In [83]:
len(neg_seqs)

3163