In [2]:
import pandas as pd
from google.cloud import bigquery
from google.oauth2 import service_account
#from params.py import PROJECT_ID, TABLE_ID
import sys; sys.path.append('..')
from params import PROJECT_ID, TABLE_ID
from sklearn.model_selection import train_test_split



credentials = service_account.Credentials.from_service_account_file("../bq_keys.json")
client = bigquery.Client(project=PROJECT_ID, credentials=credentials)
table = client.get_table(TABLE_ID)


# Construct the SQL query to retrieve the table data
query = f'SELECT * FROM `{table}`'

# Submit the query and fetch the results
df = client.query(query).to_dataframe()

In [3]:
def balance_df(df):
    if len(df.groupby('target')['target'].count().unique()) > 1:
        print("Not balanced")
        print(df.groupby('target')['target'].count().unique())
        min_ = min(df.groupby('target')['target'].count().unique())
        balanced_df = pd.DataFrame()
        for letter in df['target'].unique():
            sub_df = df[df['target']==letter]
            selected_row = sub_df.sample(n=min_)
            balanced_df = pd.concat([balanced_df,selected_row])
    else:
        print("Balanced")
        print(df.groupby('target')['target'].count().unique())
        balanced_df = df
        
    print(balanced_df.shape)
    return balanced_df

In [4]:
def shuffle_targets(df):
    shuffled_df = pd.DataFrame()
    for letter in df['target'].unique():
        sub_df = df[df['target']==letter]
        sub_df = sub_df.sample(frac=1)
        shuffled_df = pd.concat([shuffled_df,sub_df])
        #shuffled_df = shuffled_df.reset_index(drop=True)

    return shuffled_df

In [16]:
def train_test_df(df, test_size=0.3, random_state=42):

    df = df.reset_index(drop=True)

    X_train_df = pd.DataFrame()
    X_test_df = pd.DataFrame()
    y_train_df = pd.DataFrame()
    y_test_df = pd.DataFrame()

    for letter in df['target'].unique():
        sub_df = df[df['target']==letter]
        X = sub_df.drop('target', axis=1)
        y = sub_df['target']
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
        X_train_df = pd.concat([X_train_df, X_train])
        X_test_df = pd.concat([X_test_df, X_test])
        y_train_df = pd.concat([y_train_df, y_train])
        y_test_df = pd.concat([y_test_df, y_test])

    return X_train_df, X_test_df, y_train_df, y_test_df

In [6]:
# Remove duplicates
df = df.drop_duplicates()
print(df.shape)

(37, 64)


In [7]:
# Remove the first three columns
if 'x_0' in df.columns:
    df = df.drop(['x_0','y_0','z_0'], axis=1)
    print(df.shape)

(37, 61)


In [8]:
# check balance and balance it 
balanced_df = balance_df(df)

Not balanced
[17 10]
(30, 61)


In [9]:
balanced_df

Unnamed: 0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,...,x_18,y_18,z_18,x_19,y_19,z_19,x_20,y_20,z_20,target
7,0.062639,-0.037009,-0.016209,0.099324,-0.119881,-0.014263,0.121819,-0.188735,-0.012125,0.145586,...,-0.048469,-0.289281,-0.015422,-0.051347,-0.335995,-0.025008,-0.047849,-0.381171,-0.031341,a
6,-0.049541,-0.030856,-0.019681,-0.08199,-0.123745,-0.025648,-0.093822,-0.211419,-0.029608,-0.110827,...,0.050224,-0.286058,-0.032939,0.055183,-0.33404,-0.038313,0.057958,-0.378785,-0.041894,a
1,-0.049619,-0.039207,-0.016461,-0.079816,-0.132695,-0.02094,-0.092047,-0.216808,-0.024711,-0.10886,...,0.05229,-0.301836,-0.02995,0.056986,-0.350153,-0.033708,0.057934,-0.39344,-0.036319,a
5,0.067253,-0.027447,-0.019123,0.111159,-0.102696,-0.016308,0.138633,-0.170112,-0.012628,0.164115,...,-0.037555,-0.292538,0.000143,-0.039772,-0.338443,-0.00832,-0.036571,-0.383866,-0.014294,a
2,-0.050531,-0.03617,-0.018123,-0.080477,-0.132437,-0.022754,-0.093391,-0.218379,-0.026873,-0.110968,...,0.051522,-0.301766,-0.028751,0.056262,-0.350737,-0.03295,0.057547,-0.396092,-0.035761,a
16,-0.049085,-0.037036,-0.017655,-0.080145,-0.127733,-0.022473,-0.092397,-0.212739,-0.026175,-0.110506,...,0.048849,-0.296028,-0.028402,0.052984,-0.343024,-0.033156,0.053729,-0.384991,-0.036517,a
11,-0.049097,-0.037633,-0.017112,-0.079656,-0.127592,-0.021847,-0.091187,-0.211819,-0.025449,-0.110208,...,0.048195,-0.29274,-0.026551,0.052002,-0.338992,-0.031152,0.052806,-0.380997,-0.034599,a
0,-0.049028,-0.035093,-0.017194,-0.079343,-0.131463,-0.021312,-0.092285,-0.217714,-0.024704,-0.110099,...,0.051761,-0.303027,-0.030095,0.056748,-0.351968,-0.033856,0.05814,-0.39666,-0.036296,a
3,0.071293,-0.033885,-0.014923,0.111072,-0.111365,-0.009179,0.137204,-0.178955,-0.003783,0.161801,...,-0.041646,-0.294241,0.002959,-0.043938,-0.339732,-0.008451,-0.039347,-0.385925,-0.016816,a
9,0.071637,-0.052767,-0.017243,0.111176,-0.136743,-0.011655,0.136585,-0.206969,-0.006931,0.160339,...,-0.054779,-0.319011,0.015786,-0.055563,-0.365024,0.006382,-0.047483,-0.405366,-0.00088,a


In [10]:
# shuffle par target
shuffled_df = shuffle_targets(balanced_df)



In [11]:
shuffled_df

Unnamed: 0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,...,x_18,y_18,z_18,x_19,y_19,z_19,x_20,y_20,z_20,target
11,-0.049097,-0.037633,-0.017112,-0.079656,-0.127592,-0.021847,-0.091187,-0.211819,-0.025449,-0.110208,...,0.048195,-0.29274,-0.026551,0.052002,-0.338992,-0.031152,0.052806,-0.380997,-0.034599,a
16,-0.049085,-0.037036,-0.017655,-0.080145,-0.127733,-0.022473,-0.092397,-0.212739,-0.026175,-0.110506,...,0.048849,-0.296028,-0.028402,0.052984,-0.343024,-0.033156,0.053729,-0.384991,-0.036517,a
3,0.071293,-0.033885,-0.014923,0.111072,-0.111365,-0.009179,0.137204,-0.178955,-0.003783,0.161801,...,-0.041646,-0.294241,0.002959,-0.043938,-0.339732,-0.008451,-0.039347,-0.385925,-0.016816,a
0,-0.049028,-0.035093,-0.017194,-0.079343,-0.131463,-0.021312,-0.092285,-0.217714,-0.024704,-0.110099,...,0.051761,-0.303027,-0.030095,0.056748,-0.351968,-0.033856,0.05814,-0.39666,-0.036296,a
9,0.071637,-0.052767,-0.017243,0.111176,-0.136743,-0.011655,0.136585,-0.206969,-0.006931,0.160339,...,-0.054779,-0.319011,0.015786,-0.055563,-0.365024,0.006382,-0.047483,-0.405366,-0.00088,a
5,0.067253,-0.027447,-0.019123,0.111159,-0.102696,-0.016308,0.138633,-0.170112,-0.012628,0.164115,...,-0.037555,-0.292538,0.000143,-0.039772,-0.338443,-0.00832,-0.036571,-0.383866,-0.014294,a
7,0.062639,-0.037009,-0.016209,0.099324,-0.119881,-0.014263,0.121819,-0.188735,-0.012125,0.145586,...,-0.048469,-0.289281,-0.015422,-0.051347,-0.335995,-0.025008,-0.047849,-0.381171,-0.031341,a
1,-0.049619,-0.039207,-0.016461,-0.079816,-0.132695,-0.02094,-0.092047,-0.216808,-0.024711,-0.10886,...,0.05229,-0.301836,-0.02995,0.056986,-0.350153,-0.033708,0.057934,-0.39344,-0.036319,a
6,-0.049541,-0.030856,-0.019681,-0.08199,-0.123745,-0.025648,-0.093822,-0.211419,-0.029608,-0.110827,...,0.050224,-0.286058,-0.032939,0.055183,-0.33404,-0.038313,0.057958,-0.378785,-0.041894,a
2,-0.050531,-0.03617,-0.018123,-0.080477,-0.132437,-0.022754,-0.093391,-0.218379,-0.026873,-0.110968,...,0.051522,-0.301766,-0.028751,0.056262,-0.350737,-0.03295,0.057547,-0.396092,-0.035761,a


In [17]:
X_train_df, X_test_df, y_train_df, y_test_df = train_test_df(shuffled_df, test_size=0.3, random_state=42)




        



In [22]:
X_train_df

Unnamed: 0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,...,z_17,x_18,y_18,z_18,x_19,y_19,z_19,x_20,y_20,z_20
0,-0.049097,-0.037633,-0.017112,-0.079656,-0.127592,-0.021847,-0.091187,-0.211819,-0.025449,-0.110208,...,-0.016976,0.048195,-0.29274,-0.026551,0.052002,-0.338992,-0.031152,0.052806,-0.380997,-0.034599
7,-0.049619,-0.039207,-0.016461,-0.079816,-0.132695,-0.02094,-0.092047,-0.216808,-0.024711,-0.10886,...,-0.019911,0.05229,-0.301836,-0.02995,0.056986,-0.350153,-0.033708,0.057934,-0.39344,-0.036319
2,0.071293,-0.033885,-0.014923,0.111072,-0.111365,-0.009179,0.137204,-0.178955,-0.003783,0.161801,...,0.013775,-0.041646,-0.294241,0.002959,-0.043938,-0.339732,-0.008451,-0.039347,-0.385925,-0.016816
9,-0.050531,-0.03617,-0.018123,-0.080477,-0.132437,-0.022754,-0.093391,-0.218379,-0.026873,-0.110968,...,-0.018001,0.051522,-0.301766,-0.028751,0.056262,-0.350737,-0.03295,0.057547,-0.396092,-0.035761
4,0.071637,-0.052767,-0.017243,0.111176,-0.136743,-0.011655,0.136585,-0.206969,-0.006931,0.160339,...,0.022535,-0.054779,-0.319011,0.015786,-0.055563,-0.365024,0.006382,-0.047483,-0.405366,-0.00088
3,-0.049028,-0.035093,-0.017194,-0.079343,-0.131463,-0.021312,-0.092285,-0.217714,-0.024704,-0.110099,...,-0.019461,0.051761,-0.303027,-0.030095,0.056748,-0.351968,-0.033856,0.05814,-0.39666,-0.036296
6,0.062639,-0.037009,-0.016209,0.099324,-0.119881,-0.014263,0.121819,-0.188735,-0.012125,0.145586,...,-0.002306,-0.048469,-0.289281,-0.015422,-0.051347,-0.335995,-0.025008,-0.047849,-0.381171,-0.031341
10,-0.049728,-0.036259,-0.016343,-0.082994,-0.134079,-0.020282,-0.09857,-0.221583,-0.023712,-0.118385,...,-0.021382,0.045353,-0.318519,-0.031348,0.048905,-0.369389,-0.03553,0.048849,-0.414946,-0.0388
17,-0.048415,-0.030434,-0.015897,-0.080123,-0.124027,-0.019937,-0.093251,-0.209721,-0.022818,-0.109656,...,-0.022406,0.049852,-0.300091,-0.033648,0.052944,-0.351798,-0.038543,0.052676,-0.399404,-0.042013
12,-0.050639,-0.033898,-0.016353,-0.083746,-0.128384,-0.0203,-0.098241,-0.216281,-0.02358,-0.117459,...,-0.021458,0.046082,-0.31387,-0.031981,0.050448,-0.364233,-0.036625,0.051364,-0.409605,-0.040137


In [None]:
test

Unnamed: 0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,...,x_18,y_18,z_18,x_19,y_19,z_19,x_20,y_20,z_20,target
8,-0.049085,-0.037036,-0.017655,-0.080145,-0.127733,-0.022473,-0.092397,-0.212739,-0.026175,-0.110506,...,0.048849,-0.296028,-0.028402,0.052984,-0.343024,-0.033156,0.053729,-0.384991,-0.036517,a
1,0.063205,-0.055324,-0.000398,0.087114,-0.142139,0.007835,0.101945,-0.205468,0.012348,0.120142,...,-0.05696,-0.274834,-0.030179,-0.059204,-0.319377,-0.042986,-0.053639,-0.365472,-0.052119,a
5,-0.049097,-0.037633,-0.017112,-0.079656,-0.127592,-0.021847,-0.091187,-0.211819,-0.025449,-0.110208,...,0.048195,-0.29274,-0.026551,0.052002,-0.338992,-0.031152,0.052806,-0.380997,-0.034599,a
18,-0.050879,-0.035351,-0.017704,-0.082635,-0.131142,-0.021942,-0.0962,-0.218409,-0.025283,-0.11449,...,0.047041,-0.312352,-0.030783,0.051316,-0.36386,-0.036123,0.051744,-0.410618,-0.040095,c
11,-0.049728,-0.036259,-0.016343,-0.082994,-0.134079,-0.020282,-0.09857,-0.221583,-0.023712,-0.118385,...,0.045353,-0.318519,-0.031348,0.048905,-0.369389,-0.03553,0.048849,-0.414946,-0.0388,c
15,-0.048609,-0.030919,-0.016804,-0.079674,-0.125278,-0.020944,-0.09279,-0.211202,-0.023822,-0.108702,...,0.049955,-0.300477,-0.031811,0.053471,-0.351542,-0.036818,0.053519,-0.398499,-0.04034,c
28,-0.048397,-0.031743,-0.012817,-0.078327,-0.122976,-0.01404,-0.091781,-0.20347,-0.014933,-0.109049,...,0.037428,-0.309549,-0.024357,0.039342,-0.361975,-0.028158,0.037908,-0.409833,-0.030862,b
21,-0.049043,-0.035603,-0.011989,-0.080057,-0.126099,-0.012767,-0.096054,-0.20406,-0.013981,-0.115474,...,0.034121,-0.313649,-0.020422,0.035086,-0.362652,-0.02325,0.032787,-0.407238,-0.025407,b
25,-0.050319,-0.032419,-0.013835,-0.080634,-0.129581,-0.015496,-0.092086,-0.216501,-0.016835,-0.10609,...,0.041028,-0.311462,-0.028637,0.0438,-0.364134,-0.03301,0.043082,-0.412032,-0.036068,b
