## True Class Ratio

### References:
- This notebook is mostly a rewording of information from the [Neo4j docs on Class Imbalance](https://neo4j.com/docs/graph-data-science/current/machine-learning/linkprediction-pipelines/theory/)


### Negative samples in training and precision-recall in evaluation

- Link prediction on graphs usually has a *class imbalance problem* because there are substantially fewer relationships or edges (positive examples) than missing edges (negative examples)
- To account for this, a training loader exposes the model to positive and negative samples at 1-1 ratio by default
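- Moreover, the AUC-PR (area under the precision-recall curve) metric is used for evaluation; precision depends on how many negative examples are wrongly predicted as positive, so it is sensitive to the balance between the two classes (see [Precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall))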
- During training we can tune the negative sampling ratio (`negativeSamplingRatio`)
- During validation we can tune the weight of false positives when computing precision (`negativeClassWeight`)
- Note: increasing the number of negative samples increases training time, while increasing the negative class weight makes the evaluation metric account for more negatives without adding training samples. The two parameters can be adjusted in tandem to trade off evaluation accuracy against training speed.
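
The re-weighting of false positives described above can be sketched in a few lines. This is a minimal illustration, not the Neo4j implementation: the function names and the example counts are hypothetical. The suggested weight (true class ratio divided by the negative sampling ratio) follows my reading of the Neo4j docs linked in the references and should be checked against them.

```python
def weighted_precision(true_pos: int, false_pos: int,
                       negative_class_weight: float = 1.0) -> float:
    """Precision where each false positive counts `negative_class_weight` times."""
    denominator = true_pos + negative_class_weight * false_pos
    return true_pos / denominator if denominator > 0 else 0.0


def suggested_negative_class_weight(true_class_ratio: float,
                                    negative_sampling_ratio: float) -> float:
    """Assumed rule of thumb: weight = true class ratio / negative sampling ratio."""
    if negative_sampling_ratio <= 0:
        raise ValueError("negative_sampling_ratio must be positive")
    return true_class_ratio / negative_sampling_ratio


# Hypothetical counts at a single decision threshold.
tp, fp = 80, 20
print(weighted_precision(tp, fp))                             # unweighted
print(weighted_precision(tp, fp, negative_class_weight=3.0))  # negatives count 3x
```

With a weight above 1, the same predictions yield a lower precision, which mimics evaluating on a test set that contains proportionally more negative examples than were actually sampled.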

### Bias in dataset

- There are likely several sampling biases present in our dataset, e.g. a focus on organisms that impact humans, sampling preferences for certain geographies, preferences for sampling certain taxa, etc.
- To deal with this, we can consider resampling methods: oversample the minority classes or undersample the majority classes
- Oversampling methods:
    - Network Imputation: substituting missing data with values according to some criteria (density, degree distribution)
    - Synthetic Minority Over-sampling Technique (SMOTE)
    - GANs for generating synthetic samples 
- Undersampling methods:
    - Random walks with restarts to undersample majority classes and create a more uniform distribution
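
To make the idea of undersampling concrete, here is a minimal sketch using plain random undersampling. This is a simpler stand-in for the random-walk approach listed above, not an implementation of it. The function name, the `(item, label)` input format, and the toy labels are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def undersample_majority(samples, seed=0):
    """Randomly downsample every class to the size of the smallest class.

    `samples` is a list of (item, label) pairs (hypothetical format).
    """
    if not samples:
        return []
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item, label in samples:
        by_label[label].append((item, label))
    # Every class is reduced to the size of the rarest class.
    target = min(len(group) for group in by_label.values())
    balanced = []
    for group in by_label.values():
        balanced.extend(rng.sample(group, target))
    rng.shuffle(balanced)
    return balanced


# Toy data: 8 samples from a majority host class, 2 from a minority one.
toy = [(f"p{i}", "human") for i in range(8)] + [(f"q{i}", "bat") for i in range(2)]
balanced = undersample_majority(toy)
print(len(balanced))
```

Plain random undersampling discards information from the majority class, which is one reason structure-aware methods such as random walks may be preferable on graphs.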

### Tuning true class ratio using existing bias

- Because our training loader can control the positive-negative ratio, we may not need to resolve the bias in the dataset directly. Instead, we can tune the positive-negative sample ratio the model is trained on and re-weight false positives when computing precision for the AUC-PR evaluation metric
- Unfortunately, we don't know the desired ratio of the total probability mass of negative versus positive examples. That is, even under perfect sampling, we know neither the expected number of virus-host associations nor the total number of viruses.
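- This notebook attempts to approximate this value using humans, since they are a majority class. We do this by computing the true-class ratio: `(q - r) / r` where `q` is the number of possible undirected relationships and `r` is the number of actual undirected relationships. In the cells below, `q` is taken to be the total count of viral palmprints and `r` the count of viral palmprints associated with a human host.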
- This approach likely rests on many assumptions, but it seems worthwhile to try. After training a model, I will begin interpreting the top-k predictions per palmprint using visualizations
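
As a small worked example of the `(q - r) / r` formula, the sketch below computes the ratio for a toy graph. The helper names are hypothetical, and it assumes `q` is the number of unordered node pairs, `n(n-1)/2`, for a graph with `n` nodes. The notebook itself approximates `q` differently, from palmprint counts.

```python
def possible_undirected_relationships(num_nodes: int) -> int:
    """Number of unordered node pairs: n * (n - 1) / 2."""
    return num_nodes * (num_nodes - 1) // 2


def true_class_ratio(possible: int, actual: int) -> float:
    """Ratio of negative (missing) to positive (existing) relationships."""
    if actual <= 0:
        raise ValueError("actual must be a positive count")
    return (possible - actual) / actual


# Toy graph: 5 nodes and 4 existing undirected relationships.
q = possible_undirected_relationships(5)  # 10 possible pairs
r = 4
print(true_class_ratio(q, r))  # 1.5
```

A ratio of 1.5 means there are 1.5 missing relationships for every existing one; real graphs are usually far sparser, so the ratio is typically much larger.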


In [19]:
# Notebook config
import sys
if '../' not in sys.path:
    sys.path.append("../")
%load_ext dotenv
%dotenv

# Actual imports
from queries import gds_queries

The dotenv extension is already loaded. To reload it, use:
  %reload_ext dotenv


In [15]:
# count viral palmprints using potential taxon values
total_viral_palmprints_count = gds_queries.run_query('''
    MATCH (p:Palmprint)-[r:HAS_POTENTIAL_TAXON]->(t:Taxon)
    WHERE t.taxKingdom = 'Viruses'
    RETURN COUNT(p) as count
''')['count'][0]
print(total_viral_palmprints_count)

937845


In [16]:
# count number of viral palmprint associations to humans (taxId: 9606)
human_assoc_viral_palmprints_count = gds_queries.run_query('''
    MATCH (s:Palmprint)<-[:HAS_PALMPRINT]-(:SRA)-[:HAS_HOST]->(t:Taxon)
    WHERE ((t)-[:HAS_PARENT*]->(:Taxon {taxId: '9606'})
    OR t.taxId = '9606')
    AND s.taxKingdom = 'Viruses'
    RETURN COUNT(s) as count
''')['count'][0]
print(human_assoc_viral_palmprints_count)

399996


In [18]:
true_class_ratio = (total_viral_palmprints_count - human_assoc_viral_palmprints_count) / human_assoc_viral_palmprints_count
print(true_class_ratio)

1.3446359463594635


In [29]:
run_id_with_class_ratio = '1685646457'
sample_ratio = 0.1
dir_prefix = f'/mnt/graphdata/results/link_prediction/{sample_ratio}/{run_id_with_class_ratio}'

for artifact_file in ['eval.csv', 'dataset.txt', 'model.txt', 'pipeline.txt']:
    with open(f'{dir_prefix}/{artifact_file}', 'r') as f:
        print(f"\n{artifact_file}\n")
        print(f.read())



eval.csv

0
"{'modelCandidates': [{'metrics': {'AUCPR': {'validation': {'avg': 0.2855628751695121, 'min': 0.2841136484768501, 'max': 0.2866516512101441}, 'train': {'avg': 0.2855612414795886, 'min': 0.2854407327576802, 'max': 0.28572274813292087}}}, 'parameters': {'maxEpochs': 100, 'minEpochs': 1, 'classWeights': [], 'penalty': 1.0, 'patience': 2, 'methodName': 'MultilayerPerceptron', 'focusWeight': 0.0, 'hiddenLayerSizes': [64, 16, 4], 'batchSize': 100, 'tolerance': 0.001, 'learningRate': 0.001}}, {'metrics': {'AUCPR': {'validation': {'avg': 0.7318097434484135, 'min': 0.7309953050924338, 'max': 0.7329281574676022}, 'train': {'avg': 0.7318139688407923, 'min': 0.7316862942169846, 'max': 0.7319015512869126}}}, 'parameters': {'maxEpochs': 100, 'minEpochs': 1, 'classWeights': [], 'penalty': 0.16145378017027948, 'patience': 1, 'methodName': 'LogisticRegression', 'focusWeight': 0.0, 'batchSize': 100, 'tolerance': 0.001, 'learningRate': 0.001}}, {'metrics': {'AUCPR': {'validation': {'avg': 0.