In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
%matplotlib inline
import os


In [10]:
# Guide sequence for each target site, keyed by site name (underscore form).
# Every entry is 23 characters long.
gRNA = dict(
    AAVS1_s14="GGGGCCACTAGGGACAGGATTGG",
    CTLA4_s9="GGACTGAGGGCCATGGACACGGG",
    TRAC_s1="GTCAGGGTTCTGGATATCTGTGG",
    LAG3_s9="GAAGGCTGAGATCCTGGAGGGGG",
    CXCR4_s8="GTCCCCTGAGCCCATTTCCTCGG",
    CCR5_s8="GGACAGTAAGAAGGAAAAACAGG",
)

# Sites that are exported, in processing order.
# NOTE(review): CXCR4_s8 has a guide above but is not listed here -- confirm
# that leaving it out is intentional.
SITE_LIST = [
    "CCR5_s8",
    "LAG3_s9",
    "TRAC_s1",
    "CTLA4_s9",
    "AAVS1_s14",
]

# Input table read with pandas.
DATA_PATH = "data/combined_scaled_log1p_df_pivot_activity_20241126.csv"

# Output file template: the literal substring 'site' is replaced by the site
# name when each per-site file is written.
OUTPUT_PATH = "data/randomized_change_seq_large_data_log21p_activity_11262024_site.hdf5"


In [11]:
# Variable-length string dtype; passed as ``dtype`` when the 'seq' dataset is
# created in the export loop further down this cell.
dt = h5py.special_dtype(vlen=str)
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(42)

# IUPAC nucleotide code -> weights over (A, C, G, T). Ambiguity codes spread
# their weight evenly over the bases they stand for; '-' (gap) is all zeros.
# Built once at import time rather than on every call, because the encoder is
# invoked once per row of the dataset.
_NUCLEOTIDE_ENCODING = {
    'A': [1, 0, 0, 0], 'C': [0, 1, 0, 0], 'G': [0, 0, 1, 0], 'T': [0, 0, 0, 1],
    'N': [1/4, 1/4, 1/4, 1/4], '-': [0, 0, 0, 0],
    'M': [0.5, 0.5, 0, 0], 'R': [0.5, 0, 0.5, 0], 'W': [0.5, 0, 0, 0.5],
    'S': [0, 0.5, 0.5, 0], 'Y': [0, 0.5, 0, 0.5], 'K': [0, 0, 0.5, 0.5],
    'V': [1/3, 1/3, 1/3, 0], 'H': [1/3, 1/3, 0, 1/3],
    'D': [1/3, 0, 1/3, 1/3], 'B': [0, 1/3, 1/3, 1/3],
}


def genome_onehot_encoding_pair(seq1, seq2):
    """Encode two nucleotide strings and place them side by side.

    Each character of ``seq1`` and ``seq2`` (case-insensitive) is mapped to a
    4-vector of A/C/G/T weights; the two per-sequence matrices are then joined
    column-wise.

    Parameters
    ----------
    seq1, seq2 : str
        Sequences over the IUPAC alphabet plus ``'-'``. They must have the
        same length for the column-wise join to succeed.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(len(seq1), 8)``: columns 0-3 encode
        ``seq1`` and columns 4-7 encode ``seq2``.

    Raises
    ------
    KeyError
        If a character is not in the encoding table.
    ValueError
        If the sequences differ in length (raised by ``np.hstack``).
    """
    # NOTE(review): both sequences use the same hard one-hot table. A
    # label-smoothed table (0.97 / 0.01) used to be defined alongside it but
    # was never read -- confirm that encoding seq2 this way is intended.
    mat1 = np.array([_NUCLEOTIDE_ENCODING[base] for base in seq1.upper()])
    mat2 = np.array([_NUCLEOTIDE_ENCODING[base] for base in seq2.upper()])
    return np.hstack((mat1, mat2)).astype(np.float32)


def seq_to_num(seq):
    """Translate each base of ``seq`` into its integer code.

    The mapping is A=0, C=1, G=2, T=3. Lookup is case-sensitive, and any
    other symbol raises ``KeyError``.

    Returns a list with one integer per element of ``seq``.
    """
    base_index = dict(zip('ACGT', range(4)))
    return list(map(base_index.__getitem__, seq))

# Load the activity table; the first row of the CSV holds the column names.
df = pd.read_csv(DATA_PATH, header=0)

# Plain log2 of the activity. Zero activities map to -inf here (NumPy emits a
# RuntimeWarning when that happens).
df['log2_relative_activity_percent'] = np.log2(df['relative_activity_percent'])
# Despite the "log1p" name this is log2(1 + x), not the natural-log log1p.
# This is the column written out as the target 'y'.
df['log1p_relative_activity_percent'] = np.log2(1 + df['relative_activity_percent'])

# Collapse the two replicates to a single value per row by taking the smaller count.
df['control_counts'] = np.minimum(df['control_counts_rep1'], df['control_counts_rep2'])
df['Cas9_counts'] = np.minimum(df['Cas9_counts_rep1'], df['Cas9_counts_rep2'])
print(df.head(10))

# Keep only rows with at most 6 mismatches; the export loop writes groups 0..6.
df = df[df['MM'] <= 6]
print(len(df))

def _append_rows(dataset, values):
    """Grow a resizable HDF5 dataset along axis 0 and write ``values`` into the new rows."""
    n_new = len(values)
    if n_new == 0:
        # Nothing to add; also avoids ``dataset[-0:]`` addressing the whole dataset.
        return
    dataset.resize(dataset.shape[0] + n_new, axis=0)
    dataset[-n_new:] = values


# Rows of one mismatch group are encoded and written in slices of this size,
# so that only one slice of encoded matrices is held in memory at a time.
batch_size = 10000000

# Write one HDF5 file per site. Inside each file there is one group per
# mismatch count ("0" .. "6") holding the datasets X, y, seq, control_counts
# and cas9_counts. Mismatch counts with no rows get no group.
for site in SITE_LIST:
    np.random.seed(42)
    output_path = OUTPUT_PATH.replace('site', site)
    # Every row of this site shares the same guide, so look it up once.
    guide = gRNA[site]
    # The table spells site names with '.', the constants above with '_'.
    df_site = df[df['site'] == site.replace('_', '.')]
    # Remove indels for now: keep only targets that are exactly 23 characters long.
    df_site = df_site[df_site['control_target_seq'].map(len) == 23]
    print(len(df_site))
    # Start from a clean file so reruns do not collide with existing groups.
    if os.path.exists(output_path):
        os.remove(output_path)
    # The context manager closes the file even if encoding or writing fails.
    with h5py.File(output_path, 'a') as f:
        for count in range(0, 7):
            df_small = df_site[df_site['MM'] == count].reset_index()
            if len(df_small) == 0:
                # No rows for this mismatch count: nothing is written, and the
                # summary line reports zeros instead of reading a group that
                # was never created.
                print(count, 0, 0, 0, 0)
                continue

            g = f.create_group(str(count))
            for start in range(0, len(df_small), batch_size):
                batch = df_small.iloc[start:start + batch_size]
                X = np.array([genome_onehot_encoding_pair(guide, target)
                              for target in batch['control_target_seq']])
                y = batch['log1p_relative_activity_percent'].to_numpy()
                seq = batch['control_target_seq'].to_numpy()
                control_counts = batch['control_counts'].to_numpy()
                cas9_counts = batch['Cas9_counts'].to_numpy()

                if start == 0:
                    # First slice creates the datasets, resizable along axis 0.
                    g.create_dataset('X', data=X, chunks=True, maxshape=(None, 23, 8))
                    g.create_dataset('y', data=y, chunks=True, maxshape=(None,))
                    g.create_dataset('seq', data=seq, dtype=dt, chunks=True, maxshape=(None,))
                    g.create_dataset('control_counts', data=control_counts, chunks=True, maxshape=(None,))
                    g.create_dataset('cas9_counts', data=cas9_counts, chunks=True, maxshape=(None,))
                else:
                    # Later slices are appended; each dataset grows by the
                    # length of its own array.
                    _append_rows(g['X'], X)
                    _append_rows(g['y'], y)
                    _append_rows(g['seq'], seq)
                    _append_rows(g['control_counts'], control_counts)
                    _append_rows(g['cas9_counts'], cas9_counts)

                print(count, X.shape, y.shape, seq.shape, control_counts.shape, cas9_counts.shape)
                # Release the slice before the next one is built.
                del X, y, seq
            print(count, len(g['X']), len(g['y']), len(g['seq']), len(g['control_counts']))
# Free the full table once all sites have been written.
del df

  result = getattr(ufunc, method)(*inputs, **kwargs)


   Unnamed: 0       site  MM       control_target_seq  Cas9_counts_rep1  \
0        8000  AAVS1.s14   3  AGAGCCACTAGGAACAGGATCGG             736.0   
1       79968  AAVS1.s14   3  TCGACCACTAGGGACAGGATGGG             130.0   
2        1606  AAVS1.s14   2  GGGACCACTAGGGACAGGATGGT              88.0   
3      384446  AAVS1.s14   4  TCCGCCACTAGGGACAGCATTGG              81.0   
4        7583  AAVS1.s14   2  TGGGCCACTAAGGACAGGATAGG             145.0   
5       84604  AAVS1.s14   4  ACCGCCACTAGGTACAGGATAGG             147.0   
6         394  AAVS1.s14   2  AGGGCCACTAGGGACAGGTTCGG             124.0   
7       11393  AAVS1.s14   3  CCGGCCACTAGGGACAGAATAGG              90.0   
8      120737  AAVS1.s14   4  CTGGCCGCTAGGTACAGGATTGG              46.0   
9         387  AAVS1.s14   2  AGGGCCACTAGGGACAGGCTGGG              98.0   

   Cas9_counts_rep2  Cas9_counts_scaled_rep1  Cas9_counts_scaled_rep2  \
0             516.0                61.769835                65.073526   
1              93.0         