# Examples of how to query the database for an overview of studies and observations

In [None]:
from assrunit.visualizations import experimental_qualitative_overview,get_studies,get_observations, get_meta













## Select studies

In [None]:

# Studies whose title contains 'gaba' or 'ana', OR whose authors include 'Jordan'.
get_studies(titles=['gaba', 'ana'], authors=['Jordan'])  # titles OR authors

# Equivalent peewee query:
# Studies.select().where(
#     Studies.title.contains('gaba') |
#     Studies.title.contains('ana') |
#     Studies.authors.contains('Jordan')
# )


In [None]:

# Studies selected by their primary key.
get_studies(ids=[1, 2, 7])

# Equivalent peewee query:
# Studies.select().where(Studies.ID << [1, 2, 7])


In [None]:

# Fetch the studies without the helper's own printout, then format them ourselves.
my_studies = get_studies(titles=['Modeling'], print_output=False)

# Equivalent peewee query:
# Studies.select().where(
#     Studies.title.contains('Modeling')
# )


print('\n Printing studies data my way ... \n\n')
for study in my_studies:
    print(f'{study.title}. {study.authors} {study.year} \n')
    

## Select observations

In [None]:

# Observations recorded at power 20 AND drive 40.
get_observations(power=[20], drive=[40])  # power AND drive

# Equivalent peewee query:
# StudiesExperiments.select().where(
#     (StudiesExperiments.power << [20]) &
#     (StudiesExperiments.drive << [40])
# )


In [None]:

# Observations recorded at power 20 or 40, AND drive 40.
get_observations(power=[20, 40], drive=[40])  # power AND drive

# Equivalent peewee query:
# StudiesExperiments.select().where(
#     (StudiesExperiments.power << [20, 40]) &
#     (StudiesExperiments.drive << [40])
# )


In [None]:

# Observations recorded at drive 40, OR belonging to study 2.
get_observations(drive=[40], study_id=[2])  # drive OR study_id

# Equivalent peewee query (drive value matches the call above):
# StudiesExperiments.select().where(
#     (StudiesExperiments.drive << [40]) |
#     (StudiesExperiments.study_id << [2])
# )


In [None]:

# Fetch the power-40 / drive-40 observations silently and print a custom one-line summary each.
experiments = get_observations(power=[40], drive=[40], print_output=False)

print('\n Printing observations data my way ... \n\n')
for obs in experiments:
    summary = (
        f'{obs.value} has been obtained from {obs.power}'
        f' Hz power at {obs.drive} Hz drive and p_value {obs.p_value}\n'
    )
    print(summary)
    

## Qualitative representation of experiments

In [None]:

# Qualitative table of all experiments run at power 20 or 30.
ex = experimental_qualitative_overview(power=[20, 30])

# Equivalent peewee query (power values match the call above):
# StudiesExperiments.select().join(Studies).where(StudiesExperiments.power << [20, 30])


In [None]:

# Title filter is OR-ed with the combined (power AND drive) experiment filter.
ex = experimental_qualitative_overview(study_title=['gamma'], power=[30], drive=[20, 40, 30])

# Equivalent peewee query:
# StudiesExperiments.select().join(Studies).where(
#     Studies.title.contains('gamma') |
#     (
#         (StudiesExperiments.power << [30]) &
#         (StudiesExperiments.drive << [20, 40, 30])
#     )
# )


In [None]:

# Title, author and power filters are all OR-ed together.
ex = experimental_qualitative_overview(study_title=['gaba'], study_authors=['Jordan'], power=[30])

# Equivalent peewee query:
# StudiesExperiments.select().join(Studies).where(
#     (
#         Studies.title.contains('gaba') |
#         Studies.authors.contains('Jordan')
#     ) |
#     (StudiesExperiments.power << [30])
# )


In [None]:

# Qualitative table restricted to studies whose authors include 'Jordan'.
ex = experimental_qualitative_overview(study_authors=['Jordan'])

# Equivalent peewee query:
# StudiesExperiments.select().join(Studies).where(Studies.authors.contains('Jordan'))


## Study metadata

In [None]:

get_meta(study_id=6)  # metadata of the study whose ID is 6


## Plot Observations

In [None]:





import numpy as np
import seaborn as sns
# pyplot is matplotlib's supported state-based interface; matplotlib.pylab is discouraged.
import matplotlib.pyplot as plt

# Large tick-label and title fonts so the 40x20 inch figure below stays readable.
plt.rc(['xtick', 'ytick'], labelsize=30)
plt.rc(['axes'], titlesize=40)

# Fetch the qualitative table as data (plot_table=False) so it can be drawn as a heatmap here.
observations, row_labels, col_labels = experimental_qualitative_overview(plot_table=False)

# Colour code for each qualitative label. The codes only choose the cell colours;
# the labels themselves are written into the cells through `annot` below.
# Labels not listed here keep code 0, i.e. they are coloured like 'equal'.
label_codes = {'higher': 1, 'equal': 0, 'lower': -1, 'Not tested': -2}

numeric_observations = np.zeros(observations.shape)
for row_index, row in enumerate(observations):
    for col_index, value in enumerate(row):
        numeric_observations[row_index, col_index] = label_codes.get(value, 0)

# Plot heatmap: colour from the numeric codes, text from the original labels.
fig, ax = plt.subplots(figsize=(40, 20))
sns.heatmap(
    numeric_observations, ax=ax, cmap="Blues", cbar=False, annot_kws={"size": 22},
    annot=observations, fmt='', xticklabels=col_labels, yticklabels=row_labels,
    linewidths=1,  # seaborn's documented cell-border parameter
)
ax.set_title('Qualitative observations of Schizophrenic patients VS Healthy controls')

plt.show()


## Exercise

## Compare HC (healthy control) and SZ (schizophrenia) values for participant groups with a mean age above 10, at power = 40 and drive = 40.

In [None]:







# This cell imports everything it uses so it can be run on its own.
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from assrunit.db import Disorders, StudiesExperiments, SubjectsGroups


def _disorder_id(name_fragment):
    """Return the ID of the first disorder whose name contains *name_fragment*.

    Raises:
        LookupError: if no disorder name contains *name_fragment*.
    """
    query = Disorders.select(Disorders.ID).where(Disorders.name.contains(name_fragment))
    matching_ids = [disorder.ID for disorder in query]
    if not matching_ids:
        raise LookupError(f'No disorder whose name contains {name_fragment!r}')
    return matching_ids[0]


def _observations_at_40_40(disorder_id, study_ids):
    """Return the power-40 / drive-40 observations of *disorder_id* within *study_ids*.

    Rows are ordered by study so the HC and SZ results can be paired column by column.
    NOTE(review): pairing assumes one observation per group per study at these settings.
    """
    return (
        StudiesExperiments.select()
        .where(
            (StudiesExperiments.power == 40) &
            (StudiesExperiments.drive == 40) &
            (StudiesExperiments.study_id << study_ids) &
            (StudiesExperiments.disorder_id == disorder_id)
        )
        .order_by(StudiesExperiments.study_id)
    )


hc_id = _disorder_id('controls')       # healthy controls
sz_id = _disorder_id('schizophrenia')  # schizophrenia

# Participant groups matching the age filter.
# NOTE(review): only the schizophrenia groups' mean age is filtered; the HC groups of the
# selected studies are assumed comparable — confirm this is the intended reading.
participants = SubjectsGroups.select().where(
    (SubjectsGroups.mean_age > 10) &
    (SubjectsGroups.disorder_id == sz_id)
)

# IDs of the studies the filtered participant groups belong to
studies_ids = [group.study_id for group in participants]

hc_observations = _observations_at_40_40(hc_id, studies_ids)
sz_observations = _observations_at_40_40(sz_id, studies_ids)

hc_values = [observation.value for observation in hc_observations]
sz_values = [observation.value for observation in sz_observations]

# Column i of the plot pairs the i-th SZ value with the i-th HC value, which only makes
# sense when both groups have the same number of observations.
if len(sz_values) != len(hc_values):
    raise ValueError(
        f'Cannot pair observations: {len(sz_values)} SZ vs {len(hc_values)} HC values'
    )

if not sz_values:
    print('No observations match the filters.')
else:
    # Row 0 = schizophrenia, row 1 = healthy controls; values rendered as text (max 10 chars).
    annotation_matrix = np.array([sz_values, hc_values], dtype='U10')

    # Report-Appendix1 explains the values convention:
    #    Both rows equal to 0 means the study reported both values as equal.
    #    One row equal to 0 and the other equal to -1 means the study reported that
    #    one value was lower than the other.
    both_zero = (annotation_matrix[0] == '0.0') & (annotation_matrix[1] == '0.0')
    annotation_matrix[:, both_zero] = 'equal'
    annotation_matrix[annotation_matrix == '-1.0'] = 'lower'

    # Heatmap: colour from the raw values, text from the annotation matrix.
    fig, ax = plt.subplots(figsize=(40, 20))
    sns.heatmap(
        [sz_values, hc_values], ax=ax, cmap="Blues", cbar=False, annot_kws={"size": 22},
        annot=annotation_matrix, fmt='', xticklabels=[], yticklabels=['Schizophrenia', 'HC'],
        linewidths=1,
    )
    ax.set_title('Values for Healthy controls and Schizophrenic patients under Filters: 40Power40Drive and mean_age > 10')
    plt.show()