In [2]:
# Switched to SQLAlchemy to accommodate the DataFrame.to_sql function.
# Not perfect yet: error handling and closing connections still need work.
import os
from urllib.parse import quote

import numpy as np
import pandas as pd
import pymysql
from sqlalchemy import create_engine, text

In [30]:
# Load data from the database.
class db_manager(object):
    """Thin wrapper around a SQLAlchemy engine bound to a single MySQL schema.

    Connection credentials come from environment variables instead of being
    hard-coded:
    - DB_USER and DB_PASSWORD are required.
    - DB_HOST is optional and defaults to DEFAULT_HOST.

    The password previously embedded in this notebook should be treated as
    leaked and rotated.

    Schema, table and column names are interpolated into SQL text verbatim and
    must come from trusted code. Only WHERE *values* are sent as bound
    parameters.
    """

    # MySQL host used when the DB_HOST environment variable is not set.
    DEFAULT_HOST = 'rm-j6cluj4576jdi6n6oo.mysql.rds.aliyuncs.com'

    def __init__(self, schema_name='null'):
        """Create the manager.

        Connects immediately unless schema_name is the sentinel 'null' (or None).
        """
        self.connected = False
        self.__schema_name = ''
        self.__engine = None
        if schema_name not in ('null', None):
            self.connect(schema_name)

    def close(self):
        """Dispose of the engine's connection pool; a no-op when not connected."""
        # getattr guards against __del__ running on a partially initialised object.
        if getattr(self, 'connected', False):
            self.__engine.dispose()
            self.connected = False

    def __del__(self):
        """Close the engine when the object is garbage-collected (best effort)."""
        try:
            self.close()
        except Exception:
            # Exceptions cannot propagate out of __del__ anyway. During interpreter
            # shutdown the modules the engine relies on may already be torn down.
            pass
        print('db close in destructor')

    def connect(self, schema_name=''):
        """Connect to schema_name, or reconnect to the previous schema if it is ''.

        Prints a message and returns without connecting when no schema has ever
        been given. Any existing connection is closed first.

        Raises:
            Exception(3): the engine could not be created. This includes the
                case of missing credentials in the environment. The underlying
                error is chained as __cause__.
        """
        if schema_name == '':
            if self.__schema_name == '':
                print('Specify name of the schema/db first')
                return
            print('Use the previous schema_name %s' % self.__schema_name)
        else:
            if self.__schema_name == schema_name:
                print('Seem to be the same engine')
            self.__schema_name = schema_name

        if self.connected:
            self.close()
        try:
            self.__engine = create_engine(self._connection_url())
        except Exception as e:
            raise Exception(3) from e
        self.connected = True

    def _connection_url(self):
        """Build the SQLAlchemy URL for the current schema from environment variables.

        Raises:
            ValueError: DB_USER or DB_PASSWORD is not set.
        """
        user = os.environ.get('DB_USER')
        password = os.environ.get('DB_PASSWORD')
        if not user or password is None:
            raise ValueError('Set the DB_USER and DB_PASSWORD environment variables')
        host = os.environ.get('DB_HOST', self.DEFAULT_HOST)
        # Percent-encode the password. Characters such as '@' would otherwise
        # be read as the credentials/host separator.
        return 'mysql+pymysql://%s:%s@%s/%s' % (user, quote(password, safe=''),
                                                host, self.__schema_name)

    @staticmethod
    def _to_python_scalar(value):
        """Convert numpy scalars (e.g. np.int64 ids) to plain Python values for the DB driver."""
        return value.item() if isinstance(value, np.generic) else value

    def get_sql_str_select(self, table_name, field_names=('*',), where_ind_name=None,
                           where_ind_value=None, bind_param=None):
        """Build 'SELECT fields FROM schema.table [WHERE name = value]'.

        Args:
            table_name: table in the current schema.
            field_names: a single column name, or an iterable of column names.
            where_ind_name, where_ind_value: the WHERE clause is added only
                when both are given.
            bind_param: if given, emit the placeholder ':<bind_param>' instead
                of the literal value. Pass the value separately as a bound
                parameter.

        Returns:
            The SQL string.
        """
        # A bare string would otherwise be joined character by character.
        if isinstance(field_names, str):
            field_names = [field_names]
        sql_str = 'SELECT %s FROM %s.%s' % (','.join(str(x) for x in field_names),
                                           self.__schema_name, table_name)
        if where_ind_name is not None and where_ind_value is not None:
            value = ':%s' % bind_param if bind_param else where_ind_value
            sql_str += ' WHERE %s = %s' % (where_ind_name, value)
        return sql_str

    def fetch_table(self, table_name, field_names=('*',), where_ind_name=None,
                    where_ind_value=None, primary_key='Id'):
        """Fetch a table (optionally filtered) into a DataFrame indexed by primary_key.

        Args:
            table_name: table in the current schema.
            field_names: a column name or a list of column names.
            where_ind_name, where_ind_value: optional equality filter. The
                value is sent as a bound parameter, so strings are quoted
                correctly.
            primary_key: column to use as the DataFrame index.

        Returns:
            The DataFrame, or None (after printing a warning) when not connected.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return None
        bind = None
        params = None
        if where_ind_name is not None and where_ind_value is not None:
            bind = 'where_value'
            params = {bind: self._to_python_scalar(where_ind_value)}
        sql = self.get_sql_str_select(table_name, field_names, where_ind_name,
                                      where_ind_value, bind_param=bind)
        return pd.read_sql_query(sql=text(sql), con=self.__engine,
                                 params=params, index_col=primary_key)

    def sql_query_fetch_list(self, sql):
        """Run a SELECT and return all rows as an np.array.

        Returns None (after printing a warning) when not connected.
        Tokens of the form ':name' in sql are treated as bind parameters.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return None
        # The context manager returns the connection to the pool even on error.
        with self.__engine.connect() as conn:
            rows = conn.execute(text(sql)).fetchall()
        return np.asarray(rows)

    def sql_query_fetch_df(self, sql, index_col='id'):
        """Run a SELECT and return the result as a DataFrame indexed by index_col.

        Returns None (after printing a warning) when not connected.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return None
        return pd.read_sql_query(sql, self.__engine, index_col=index_col)

    def sql_nofetch(self, sql):
        """Execute a statement that returns no rows (e.g. DELETE, TRUNCATE) and commit it.

        Tokens of the form ':name' in sql are treated as bind parameters.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return
        with self.__engine.begin() as conn:  # commits when the block exits
            conn.execute(text(sql))

    def check_existed(self, table_name, field_name, ids):
        """Count the rows of table_name whose field_name equals each id.

        Args:
            table_name: table in the current schema.
            field_name: column to compare.
            ids: iterable of values to look up.

        Returns:
            (existed, counts): a bool array and an int array aligned with ids.
            Returns None (after printing a warning) when not connected.

        Raises:
            Exception(4): any database error. The underlying error is chained.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return None
        sql = text('SELECT COUNT(*) FROM %s.%s WHERE %s = :value'
                   % (self.__schema_name, table_name, field_name))
        existed_count = []
        try:
            with self.__engine.connect() as conn:
                for ind in ids:
                    count = conn.execute(sql, {'value': self._to_python_scalar(ind)}).scalar()
                    existed_count.append(count)
        except Exception as e:
            raise Exception(4) from e
        # Explicit dtypes keep empty results usable as (boolean) index arrays.
        counts = np.asarray(existed_count, dtype=int)
        return counts != 0, counts

    def insert_table(self, df, table_name, index_name='CaseId', del_row_if_exist=True):
        """Append df to an existing table, handling ids that are already present.

        The table must already exist with columns matching df's columns. df's
        own index is not written.

        Args:
            df: rows to insert.
            table_name: target table in the current schema.
            index_name: column identifying a logical record, e.g. 'CaseId'.
            del_row_if_exist: if True, first delete every existing row whose
                index_name value occurs in df. If False, drop those rows from
                df and insert only the rest.
        """
        if not self.connected:
            print('db not connected yet. Do connect first')
            return
        ind_list = np.asarray(pd.unique(df[index_name]))
        exist_list, exist_count_list = self.check_existed(table_name, index_name, ind_list)
        if del_row_if_exist:
            if exist_list.any():
                delete_sql = text('DELETE FROM %s.%s WHERE %s = :value'
                                  % (self.__schema_name, table_name, index_name))
                # All deletes run in one transaction, committed when the block exits.
                with self.__engine.begin() as conn:
                    for ind, exist_count in zip(ind_list[exist_list], exist_count_list[exist_list]):
                        conn.execute(delete_sql, {'value': self._to_python_scalar(ind)})
                        print('delete %d rows for %s %s' % (exist_count, index_name, ind))
        else:
            # Keep only the rows whose id is not in the table yet.
            df = df[~df[index_name].isin(ind_list[exist_list])]

        df.to_sql(name=table_name, con=self.__engine,
                  schema=self.__schema_name, if_exists='append', index=False)
        print('%d lines insertion done' % df.shape[0])

    def get_engine(self):
        """Return the underlying SQLAlchemy engine (None before the first connect)."""
        return self.__engine
       
    

# Examples 

## Copy tables between two databases

In [9]:
# NOTE: this import shadows the db_manager class defined in the cell above with
# the copy from the loadData module.
from loadData import db_manager

SOURCE_SCHEMA = 'rnd_test'
TARGET_SCHEMA = 'webtest'
TABLE_NAME = 'trial_data'
CASE_ID_FIELD = 'CasdId'  # column name as spelled in the trial_data table
CaseIds = [40]

db_rnd = db_manager(SOURCE_SCHEMA)
db_web = db_manager(TARGET_SCHEMA)
try:
    for cid in CaseIds:
        df = db_rnd.fetch_table(table_name=TABLE_NAME, where_ind_name=CASE_ID_FIELD, where_ind_value=cid)
        db_web.insert_table(df, TABLE_NAME, CASE_ID_FIELD)
        print(str(cid) + ' done')
finally:
    # Release both engines even if a fetch or insert fails part-way.
    db_rnd.close()
    db_web.close()

db close in destructor
db close in destructor
delete 540 rows for CasdId 40
40 done


## Load data with `fetch_table`

In [31]:
mydb = db_manager('rnd_test')

# Selected columns of hmd_data for CasdId 68, indexed by Id.
hmd_data = mydb.fetch_table(
    table_name='hmd_data',
    field_names=['Id', 'PosX', 'PosY', 'PosZ', 'RotX', 'RotY', 'RotZ'],
    where_ind_name='CasdId',
    where_ind_value=68,
    primary_key='Id',
)
print(hmd_data.head())

# The whole head_features_new table, indexed by its lowercase id column.
head_features_new = mydb.fetch_table(table_name='head_features_new', primary_key='id')
print(head_features_new.head())

             PosX     PosY     PosZ      RotX     RotY     RotZ
Id                                                             
2573001  0.069218  1.07777 -1.00768  0.994100  355.827  355.629
2573002  0.069047  1.07770 -1.00766  0.983457  355.815  355.726
2573003  0.069009  1.07770 -1.00769  0.980946  355.803  355.752
2573004  0.068926  1.07773 -1.00790  0.970070  355.760  355.803
2573005  0.068846  1.07774 -1.00779  0.967915  355.729  355.802
    CaseId  LargeMovements  SmallMovements  PathLength  PercentageActivePos  \
id                                                                            
1        2               0              59           6             0.206108   
2        3              13             962          25             0.557630   
3        4               0             303           9             0.372984   
4        5               3             498          13             0.493705   
5        6              32            1303          42             0.747891   

## Insertion

In [32]:
# Dummy round-trip: re-insert the CaseId 68 rows fetched above, replacing the existing ones.
case_68_rows = head_features_new[head_features_new['CaseId'] == 68]
mydb.insert_table(case_68_rows, 'head_features_new', index_name='CaseId', del_row_if_exist=True)

delete 1 rows for CaseId 68
1 lines insertion done


## Other helpers

In [33]:
mydb.check_existed(
    'head_features_new',   # table_name
    'CaseId',              # field_name
    [1, 2, 3, 40, 100],    # ids
)

(array([ True,  True,  True,  True, False], dtype=bool),
 array([1, 1, 1, 1, 0]))

In [34]:
case_40_sql = 'SELECT * FROM rnd_test.head_features_new WHERE CaseId = 40'
mydb.sql_query_fetch_df(case_40_sql)

Unnamed: 0_level_0,CaseId,LargeMovements,SmallMovements,PathLength,PercentageActivePos,PercentageActiveRot,PercentageDistracted,TimeDistracted,IsMoveRangeLarge
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
60,40,1,967,18,0.717659,0.231359,0.071846,47,0


In [24]:
ids_to_check = [1, 2, 3, 40, 100]
print(mydb.connected)  # check connectivity; a bare expression mid-cell is never displayed
mydb.close()  # close engine
# Not connected any more: this call only prints a warning and returns None.
mydb.check_existed(field_name='CaseId', ids=ids_to_check, table_name='head_features_new')
mydb.connect()  # no argument: reconnect to the previously used schema
print(mydb.check_existed(field_name='CaseId', ids=ids_to_check, table_name='head_features_new'))  # try again
mydb.close()

db not connected yet. Do connect first
Use the previous schema_name rnd_test
(array([ True,  True,  True,  True, False], dtype=bool), array([1, 1, 1, 1, 0]))


# Old stuff (obsolete, kept for reference)

In [2]:
# Load CSVs from folders (obsolete).
def loadFiles(root="data/TAIWAN_RAW_DATA/ADHD"):
    """Load the realtime and trial CSVs from every subject folder under root. (Obsolete.)

    Expected layout of each subfolder:
    - A2RealTime_<suffix>.csv and A2TrialData_<suffix>.csv, where <suffix> is
      the folder name without its first three characters.
    - The case id is the second '_'-separated field of the folder name,
      parsed as int.

    Folders whose name contains "pass" are skipped. Folders that cannot be
    loaded are reported and skipped.

    Returns:
        (data_rt, data_trial, data_id, folder_list). The first three lists are
        index-aligned, with one entry per successfully loaded folder.
        folder_list is the raw directory listing, including skipped folders.
    """
    data_rt = []     # realtime.csv per loaded folder
    data_trial = []  # trialdata.csv per loaded folder
    data_id = []     # caseid/subjectid per loaded folder
    real_time_prefix = "A2RealTime_"
    trial_prefix = "A2TrialData_"
    folder_list = os.listdir(root)  # list of subfolders in the root
    for folder in folder_list:
        if "pass" in folder:
            continue
        folder_path = os.path.join(root, folder)
        suffix = folder[3:]
        try:
            # Read everything before appending, so a failure part-way cannot
            # leave the three result lists misaligned.
            rt = pd.read_csv(os.path.join(folder_path, real_time_prefix + suffix + ".csv"))
            trial = pd.read_csv(os.path.join(folder_path, trial_prefix + suffix + ".csv"))
            case_id = int(folder.split('_')[1])
        except (OSError, ValueError, IndexError) as e:
            # Missing/unreadable CSVs (OSError, pandas parse errors are ValueErrors)
            # or a folder name without a parsable id.
            print('skipping %s: %s' % (folder_path, e))
            continue
        data_rt.append(rt)
        data_trial.append(trial)
        data_id.append(case_id)

    return data_rt, data_trial, data_id, folder_list

In [134]:
def extractUsersInTestLoc(testLoc, cur=None):
    """Return the case ids recorded at a test location and their ADHD diagnoses. (Obsolete.)

    Args:
        testLoc: TestLocationId to filter vrclassroom.case on.
        cur: DB-API cursor to use. Defaults to the module-level `cursor`
            global, which is not created anywhere in this notebook.

    Returns:
        (caseIds, diagnoses): case ids ordered by Id. For each, the first
        ADHDDiagnose value found, or None when no patient row matches.
    """
    if cur is None:
        cur = cursor
    # Bound parameters (pymysql "%s" paramstyle) instead of string formatting.
    cur.execute("SELECT Id FROM vrclassroom.case WHERE TestLocationId = %s ORDER BY Id", (int(testLoc),))
    caseIds = [row[0] for row in cur.fetchall()]
    diagnoses = []
    for caseid in caseIds:
        # NOTE(review): the patient is looked up by the *case* Id. Confirm that
        # patient.Id really corresponds to case.Id in this schema.
        cur.execute("SELECT ADHDDiagnose FROM vrclassroom.patient WHERE Id = %s", (caseid,))
        if cur.rowcount > 1:
            print("more than one entry in patient table found, only fetch the first one")
        row = cur.fetchone()
        # fetchone() returns None when no patient row matches.
        diagnoses.append(row[0] if row is not None else None)
    return caseIds, diagnoses

def extractRealTime(caseIds, max_num=0, con=None):
    """Fetch realtime_data rows for each case id, ordered by time. (Obsolete.)

    Args:
        caseIds: iterable of case ids (coerced to int before being formatted
            into the SQL, so no arbitrary text reaches the query).
        max_num: if non-zero, keep only the first max_num rows per case.
        con: connection/engine accepted by pandas.read_sql_query. Defaults to
            the module-level `db` global, which is not created anywhere in
            this notebook.

    Returns:
        A list of DataFrames, one per case id, in input order.
    """
    if con is None:
        con = db
    base_sql = ("SELECT Block_num, Stim, DistractorPosX, DistractorPosY, DistractorPosZ, DistractorId, CasdId "
                "FROM realtime_data WHERE CasdId = %d ORDER BY TimeLog, TimeLogMillisecond")
    limit = " LIMIT 0,%d" % int(max_num) if max_num else ""
    return [pd.read_sql_query(base_sql % int(caseid) + limit, con) for caseid in caseIds]


    

In [89]:
head_features_new.loc[head_features_new['CaseId'] == 68]

Unnamed: 0,CaseId,LargeMovements,SmallMovements,PathLength,PercentageActivePos,PercentageActiveRot,PercentageDistracted,TimeDistracted,IsMoveRangeLarge,id
57,68,0,34,5,0.140901,0.012026,0.008821,3,0,61


In [11]:
# FIXME: connectdb() and disconnect() are not defined anywhere in this notebook
# (this cell raises NameError, as its output shows). The legacy helpers called
# here also read module-level `cursor` / `db` objects that no visible cell
# creates. This cell fails on a fresh kernel until those pieces are restored.
connectdb()
testLoc = 2  # TestLocationId to extract cases for
caseIds,diagnoses = extractUsersInTestLoc(testLoc)
rt = extractRealTime(caseIds,100)  # at most 100 realtime rows per case
disconnect()

NameError: name 'connectdb' is not defined