# Prepare sort-seq dataset for use in MAVE-NN

In [1]:
# Standard imports
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline

# Insert mavenn at beginning of path
# Putting the local checkout first on sys.path makes `import mavenn` below
# resolve to it ahead of any other copy found later on the path.
# NOTE: the path is relative, so it is resolved against the kernel's current
# working directory.
import sys
path_to_mavenn_local = '../../../../'
sys.path.insert(0,path_to_mavenn_local)

#Load mavenn and check path
# Printing mavenn.__path__ lets the reader confirm which copy was imported.
import mavenn
print(mavenn.__path__)

# For testing
# NOTE(review): vec_data_to_mat_data is not referenced in the cells visible
# here -- confirm whether this import is still needed.
from mavenn.src.utils import vec_data_to_mat_data

['../../../../mavenn']


In [2]:
# Choose which sort-seq experiment to prepare. Exactly one assignment should
# be active; the alternatives are kept so they can be toggled by hand.
#experiment_name = 'full-500'
#experiment_name = 'full-150'
experiment_name = 'full-0'
#experiment_name = 'rnap-wt'


# Load raw data file
# The file lives inside the imported mavenn package directory and is
# whitespace-delimited; its first column becomes the index.
raw_data_file = (mavenn.__path__[0] +
                 '/examples/datasets/sort_seq/' + experiment_name + '/data.txt')
# sep=r'\s+' is the documented equivalent of delim_whitespace=True, which is
# deprecated as of pandas 2.2.
raw_df = pd.read_csv(raw_data_file,
                     index_col=[0],
                     sep=r'\s+')
print(len(raw_df))
raw_df.head()

23251


Unnamed: 0_level_0,ct_0,ct_1,ct_2,ct_3,ct_4
seq,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
AAAAAATCTGTGTTTGCTCACCCATAAGGCACCGCCGGCTTTACACTTTATGCTTCCGGCTTGTCTGTTGTGTGG,0.0,1.0,0.0,0.0,0.0
AAAAAATGCGAGGTAGCTCACTCATTAGGAGTCCCAGGCTTTACACTTTATGCTTCCGGCTCGTATGTTGTGTGG,0.0,0.0,0.0,1.0,0.0
AAAAAATGTCAGATTGCTCACTCATTAGGCACCCCGGGCTCTACACTTTTTGCTTCCGGATCGTATGTTGTGTGG,0.0,1.0,0.0,0.0,0.0
AAAAAATGTCAGTTAGCTGACTCATTAGGCACCCCTGGCTTTACGTTTTCTGCTTTCGGCTCGTATGTATGGTGG,1.0,0.0,0.0,0.0,0.0
AAAAAATGTGACTTAGCTCACTCATTAGGTACCCCAGGCCTTGCACTTTATGCTTCCGGCTCGTATGTTGTATGG,0.0,0.0,0.0,0.0,1.0


In [3]:
# Earlier refinement step, kept for reference only (not executed):
# # Refine contents of raw data file
# sequences = raw_df['seq'].values
# raw_df.columns = ['x','y','ct']
# raw_df['ct'] = raw_df['ct'].astype(int)
# raw_df.head()

# Label the index of the raw table 'x'.
raw_df.index = raw_df.index.rename('x')

In [4]:
# Earlier pivot-based preparation, kept for reference only (not executed):
# data_df = pd.pivot(raw_df, values='ct', index='x', columns='y').fillna(0).astype(int)
# data_df.columns.name = None
# data_df.columns = [f'ct_{x}' for x in data_df.columns]

# Every column of the raw table is treated as a count column.
y_cols = raw_df.columns.tolist()

# Work on an integer-typed copy so raw_df itself is left untouched.
data_df = raw_df.copy().astype(int)

# Sanity check: report how many rows have counts that total zero.
print('rows summing to 0:', data_df.sum(axis=1).eq(0).sum())
data_df.head()

rows summing to 0: 0


Unnamed: 0_level_0,ct_0,ct_1,ct_2,ct_3,ct_4
x,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
AAAAAATCTGTGTTTGCTCACCCATAAGGCACCGCCGGCTTTACACTTTATGCTTCCGGCTTGTCTGTTGTGTGG,0,1,0,0,0
AAAAAATGCGAGGTAGCTCACTCATTAGGAGTCCCAGGCTTTACACTTTATGCTTCCGGCTCGTATGTTGTGTGG,0,0,0,1,0
AAAAAATGTCAGATTGCTCACTCATTAGGCACCCCGGGCTCTACACTTTTTGCTTCCGGATCGTATGTTGTGTGG,0,1,0,0,0
AAAAAATGTCAGTTAGCTGACTCATTAGGCACCCCTGGCTTTACGTTTTCTGCTTTCGGCTCGTATGTATGGTGG,1,0,0,0,0
AAAAAATGTGACTTAGCTCACTCATTAGGTACCCCAGGCCTTGCACTTTATGCTTCCGGCTCGTATGTTGTATGG,0,0,0,0,1


In [5]:
# Flag a random subset of rows as training data: each row is selected when
# its uniform draw falls below training_frac. The global NumPy seed is fixed
# so the same draws are produced on every run.
training_frac = .8
N = data_df.shape[0]
np.random.seed(0)

# Add the boolean flag, then move the index back into an ordinary column.
data_df = data_df.assign(training_set=np.random.rand(N) < training_frac).reset_index()
data_df.head()

Unnamed: 0,x,ct_0,ct_1,ct_2,ct_3,ct_4,training_set
0,AAAAAATCTGTGTTTGCTCACCCATAAGGCACCGCCGGCTTTACAC...,0,1,0,0,0,True
1,AAAAAATGCGAGGTAGCTCACTCATTAGGAGTCCCAGGCTTTACAC...,0,0,0,1,0,True
2,AAAAAATGTCAGATTGCTCACTCATTAGGCACCCCGGGCTCTACAC...,0,1,0,0,0,True
3,AAAAAATGTCAGTTAGCTGACTCATTAGGCACCCCTGGCTTTACGT...,1,0,0,0,0,True
4,AAAAAATGTGACTTAGCTCACTCATTAGGTACCCCAGGCCTTGCAC...,0,0,0,0,1,True


In [6]:
# Remove rows whose counts total zero across all count columns.
# ix is True for rows to keep.
ix = data_df[y_cols].sum(axis=1) > 0
# The filter below removes rows (not columns), so the message says "rows".
print(f'Dropping {(~ix).sum()} rows with 0 counts.')
data_df = data_df.loc[ix].reset_index(drop=True)
data_df.head()

Dropping 0 columns with 0 counts.


Unnamed: 0,x,ct_0,ct_1,ct_2,ct_3,ct_4,training_set
0,AAAAAATCTGTGTTTGCTCACCCATAAGGCACCGCCGGCTTTACAC...,0,1,0,0,0,True
1,AAAAAATGCGAGGTAGCTCACTCATTAGGAGTCCCAGGCTTTACAC...,0,0,0,1,0,True
2,AAAAAATGTCAGATTGCTCACTCATTAGGCACCCCGGGCTCTACAC...,0,1,0,0,0,True
3,AAAAAATGTCAGTTAGCTGACTCATTAGGCACCCCTGGCTTTACGT...,1,0,0,0,0,True
4,AAAAAATGTGACTTAGCTCACTCATTAGGTACCCCAGGCCTTGCAC...,0,0,0,0,1,True


In [7]:
# Assign each row to the training, validation, or test set.
# One uniform draw r in [0, 1) is made per row:
#   r < test_frac                          -> 'test'
#   test_frac <= r < test_frac + val_frac  -> 'validation'
#   test_frac + val_frac <= r              -> 'training'
# so the training share is 1 - test_frac - val_frac.
N = len(data_df)
test_frac = .2
val_frac = .2
np.random.seed(0)
r = np.random.rand(N)
ix_train = (test_frac + val_frac <= r)
ix_val = (test_frac <= r) & (r < test_frac + val_frac)
ix_test = (r < test_frac)

# The three masks are mutually exclusive; the '' default would mark any row
# that none of them covers, and the assertion below rejects that case.
data_df['set'] = np.select(
    [ix_train, ix_val, ix_test],
    ['training', 'validation', 'test'],
    default='')
assert (data_df['set'].str.len() > 0).all()

# Shuffle data for extra safety
# No random_state is passed, so the shuffle presumably draws from NumPy's
# global RNG seeded above -- keep the seed/rand/sample order unchanged to
# reproduce the same row order.
data_df = data_df.sample(frac=1).reset_index(drop=True)

# Order columns: set label, count columns, then the sequence.
# Any other columns are dropped here.
data_df = data_df[['set'] + y_cols + ['x']]
data_df.head(20)

Unnamed: 0,set,ct_0,ct_1,ct_2,ct_3,ct_4,x
0,training,0,2,0,0,0,AATTATTGTGACTTAGTTCACCCACTAGGCTCCACAGGCTTTTCAA...
1,training,0,0,0,0,4,AATTAATGTGAGTGAGCTCACTCATTCGGCACCTCAGGCTTTACAC...
2,validation,0,0,1,0,0,ATTTAATGTGAGTTAGCTCACTCATTCGGCACCGCAGGCTTTACAC...
3,validation,0,0,0,1,0,AATTAATGTGAGTTACTTCACTCATTAGGCGCACAAGGCTTGGCAC...
4,training,0,0,1,0,0,AATTCATGTTAATTATCTCAATCATTAGGCACCCCGGGATTTACAC...
5,training,0,0,2,0,0,AATTAATGTGAGTTAGCTCACTCAATAGGCACCCCAGGCTTTAAAC...
6,test,0,0,0,1,0,AATTAATGTGAGTTAGCTCACTCATTAGGCATCCTAGGCTTGCCAC...
7,validation,0,0,0,1,0,AATTAATGTGATTTAGCTCACTCATTAGGCACCCCAGGCTGTACAC...
8,training,0,1,0,0,0,AATTAATGTGAGTTATCTCAATCATCAGGCACCTCAGGCTTTCCAC...
9,training,0,0,0,1,0,AATTAATGTGAGTTACCTCACTCCTTAGGTACCCCAGGCTTTACAC...


In [8]:
# Write the gzip-compressed dataset and show its size.
# The file is written directly into the parent directory (where the original
# shell `mv` placed it), using pure Python so the cell does not depend on
# Unix-only `du`/`mv` commands.
file_name = 'sortseq_'+experiment_name+'_data.csv.gz'
out_path = Path('..') / file_name
data_df.to_csv(out_path, compression='gzip', index=False)
print('df (zipped):')
# Size of the file on disk, reported in KiB.
print(f'{out_path.stat().st_size / 1024:.0f}K\t{file_name}')

df (zipped):
324K	sortseq_full-500_data.csv.gz


In [9]:
# Called without a dataset name; per the saved output below, this prints the
# list of valid dataset names.
mavenn.load_example_dataset()

Please enter a dataset name. Valid choices are:
"ace2rbd"
"gb1"
"mpsa"
"mpsa_replicate"
"sortseq"
"sortseq_full-150"
"sortseq_full-500"
"sortseq_rnap-wt"


In [10]:
# # Test loading
# loaded_df = mavenn.load_example_dataset('sortseq')
# loaded_df.head()

In [11]:
# Load the dataset for the selected experiment through mavenn's loader.
# NOTE(review): the saved list of valid names in this notebook does not
# include 'sortseq_full-0' -- confirm the name is registered for the
# currently selected experiment_name.
mavenn.load_example_dataset(f'sortseq_{experiment_name}')

Unnamed: 0,set,ct_0,ct_1,ct_2,ct_3,ct_4,x
0,training,0,2,0,0,0,AATTATTGTGACTTAGTTCACCCACTAGGCTCCACAGGCTTTTCAA...
1,training,0,0,0,0,4,AATTAATGTGAGTGAGCTCACTCATTCGGCACCTCAGGCTTTACAC...
2,validation,0,0,1,0,0,ATTTAATGTGAGTTAGCTCACTCATTCGGCACCGCAGGCTTTACAC...
3,validation,0,0,0,1,0,AATTAATGTGAGTTACTTCACTCATTAGGCGCACAAGGCTTGGCAC...
4,training,0,0,1,0,0,AATTCATGTTAATTATCTCAATCATTAGGCACCCCGGGATTTACAC...
...,...,...,...,...,...,...,...
23246,test,0,0,0,1,0,AATTAATGTGAGTTAGCTCACTCATTAGGCACCTCGGGCTTTACCC...
23247,training,0,0,0,0,1,AATTAATATGAGTTAGCTCACTCATTTGGCACCCCAGGCTTTACAC...
23248,test,1,0,0,0,0,AATTTATGTGTGTTAGCTCACTCATTAGGCACCCCAGGCTTTACAC...
23249,validation,0,0,0,1,0,AACTAATGTGTGTTCGCTGACTCATTAGGCACCCCAGGCTTTACAC...
