In [39]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn import preprocessing

In [40]:
# Load the taxonomy table from CSV and print its (n_rows, n_columns) shape as a quick sanity check.
df_taxa = pd.read_csv('../data/taxa_400.csv')
print(df_taxa.shape)

(9174, 8)


In [41]:
def create_split_csv(data, labels, fpath_split):
    """Write one data split to a CSV file.

    Parameters
    ----------
    data : iterable of (row_id, sequence) pairs
        One pair per sample, e.g. the rows of a two-column array.
    labels : indexable sequence of class names
        ``labels[i]`` is the class name of the i-th entry of ``data``.
    fpath_split : str
        Destination CSV path. Missing parent directories are created.

    Notes
    -----
    The output columns are ``id``, ``sequence``, ``class_name`` and ``label``,
    where ``label`` is the integer code assigned by a ``LabelEncoder`` fit on
    ``labels``. The encoder is fit anew on every call, so the integer codes of
    two different files only agree when both calls received the same set of
    class names.
    """
    import os  # local import: only needed here to prepare the output directory

    label_encoder = preprocessing.LabelEncoder()
    label_nos = label_encoder.fit_transform(labels)

    data_rows = [
        {
            'id': row_id,
            'sequence': sequence,
            'class_name': labels[i],
            'label': label_nos[i],
        }
        for i, (row_id, sequence) in enumerate(data)
    ]

    # to_csv does not create directories; make sure the target folder exists
    # so the export does not fail on a level that has no folder yet.
    parent_dir = os.path.dirname(fpath_split)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Pin the column list so the header row is written even when there are no rows.
    df_split = pd.DataFrame(data_rows, columns=['id', 'sequence', 'class_name', 'label'])
    df_split.to_csv(fpath_split, index=False)

In [50]:
def get_xy_from_df(df_taxa, level):
    """Extract features and targets from a taxonomy frame.

    Returns a pair ``(X, y)``: ``X`` has one row per record holding the
    record's ``id`` and ``sequence`` values (in that order), and ``y`` holds
    the values of the ``level`` column.
    """
    id_column = df_taxa['id'].values
    seq_column = df_taxa['sequence'].values
    targets = df_taxa[level].values

    # Place the two 1-D columns side by side: one (id, sequence) row per record.
    features = np.column_stack((id_column, seq_column))
    return features, targets

def split_data(df_taxa, level='class', n_samples_per_class=6000,
               random_state=None, out_dir='../data/hierarchy'):
    """Create stratified train/test/val CSV files for one taxonomic level.

    The rows of ``df_taxa`` are shuffled and grouped by the ``level`` column,
    and at most ``n_samples_per_class`` rows are kept per group. The class
    counts are printed, then the data is split 60% train / 20% test / 20% val
    with stratification on the class labels. Each split is written with
    ``create_split_csv`` to ``<out_dir>/<level>/{train,test,val}.csv``.

    Parameters
    ----------
    df_taxa : pandas.DataFrame
        Must provide the columns ``id``, ``sequence`` and ``level``.
    level : str
        Name of the column holding the class labels.
    n_samples_per_class : int
        Upper bound on the number of rows kept per class.
    random_state : int, numpy.random.RandomState or None
        Seed forwarded to the shuffle and to both ``train_test_split`` calls.
        The default ``None`` keeps the original non-deterministic behavior;
        pass an int to make the generated splits reproducible.
    out_dir : str
        Root directory under which the ``<level>`` folder is expected.

    Returns
    -------
    None
    """
    max_per_class = int(n_samples_per_class)  # loop-invariant, convert once

    shuffled = df_taxa.sample(frac=1, random_state=random_state)

    X_parts = []
    y_parts = []
    for _, group in shuffled.groupby(by=level):
        gX, gy = get_xy_from_df(group, level)
        # Rows were shuffled above, so truncating keeps a random subset per class.
        X_parts.append(gX[:max_per_class])
        y_parts.append(gy[:max_per_class])

    X = np.vstack(X_parts)
    y = np.concatenate(y_parts)

    print(np.unique(y, return_counts=True))

    # 60% train, 40% held out ...
    train_data, testval_data, train_labels, testval_labels = train_test_split(
        X, y, test_size=0.4, stratify=y, shuffle=True, random_state=random_state)

    # ... and the held-out part is halved into test and val (20% each).
    test_data, val_data, test_labels, val_labels = train_test_split(
        testval_data, testval_labels, test_size=0.5, stratify=testval_labels,
        shuffle=True, random_state=random_state)

    splits = (
        ('train', train_data, train_labels),
        ('test', test_data, test_labels),
        ('val', val_data, val_labels),
    )
    for split_name, split_data_rows, split_labels in splits:
        create_split_csv(split_data_rows, split_labels,
                         '{}/{}/{}.csv'.format(out_dir, level, split_name))

def group_labels(df_taxa, label_names, columns=None, grouped_label_name='Other'):
    """Collapse several label values into one catch-all label.

    Parameters
    ----------
    df_taxa : pandas.DataFrame
        Input frame; it is not modified.
    label_names : iterable
        Values to be replaced by ``grouped_label_name``.
    columns : str, iterable of str or None
        Column(s) in which the replacement is performed. With the default
        ``None`` every column of the frame is searched, as before. Pass the
        level column explicitly to avoid rewriting an identical name that
        also occurs in another column.
    grouped_label_name : str
        Replacement value, ``'Other'`` by default.

    Returns
    -------
    pandas.DataFrame
        A new frame with the replacements applied. When ``label_names`` is
        empty the input frame itself is returned unchanged.
    """
    # Drop duplicates while keeping order; replace() only needs each value once.
    labels = list(dict.fromkeys(label_names))
    if not labels:
        return df_taxa

    if columns is None:
        # Single pass over the frame instead of one full pass per label.
        return df_taxa.replace(labels, grouped_label_name)

    if isinstance(columns, str):
        columns = [columns]

    df_group = df_taxa.copy()
    for column in columns:
        df_group[column] = df_group[column].replace(labels, grouped_label_name)
    return df_group

In [51]:
# Write the phylum-level train/test/val CSVs; the per-class cap of 10000 exceeds every class count printed below, so no class is truncated.
split_data(df_taxa, level='phylum', n_samples_per_class=10000)

(array(['Actinobacteria', 'Firmicutes', 'Proteobacteria'], dtype=object), array([2211, 1880, 5083]))


In [52]:
# Write the class-level train/test/val CSVs; the per-class cap of 10000 exceeds every class count printed below, so no class is truncated.
split_data(df_taxa, level='class', n_samples_per_class=10000)

(array(['Actinobacteria', 'Alphaproteobacteria', 'Betaproteobacteria',
       'Clostridia', 'Gammaproteobacteria'], dtype=object), array([2211, 1617, 1746, 1880, 1720]))


In [53]:
# Order names that are merged into the single catch-all class 'Other' before splitting.
label_names = ['Coriobacteriales', 'Rhodocyclales', 'Xanthomonadales', 'Alteromonadales', 
               'Caulobacterales', 'Rhodobacterales', 'Rhodospirillales', 'Vibrionales', 
               'Oceanospirillales', 'Acidimicrobiales']
# NOTE(review): group_labels replaces these values wherever they occur in the frame, not only in the 'order' column.
df_group = group_labels(df_taxa, label_names)

# Write the order-level train/test/val CSVs from the regrouped frame.
split_data(df_group, level='order', n_samples_per_class=10000)

(array(['Actinomycetales', 'Bifidobacteriales', 'Burkholderiales',
       'Clostridiales', 'Enterobacteriales', 'Nitrosomonadales', 'Other',
       'Pseudomonadales', 'Rhizobiales', 'Sphingomonadales'], dtype=object), array([1351,  493, 1095, 1880,  324,  413, 1706,  788,  762,  362]))
