# True Class Ratio

## Overview

### References

- This overview is mostly a rewording of info from [Neo4j docs on Class Imbalance](https://neo4j.com/docs/graph-data-science/current/machine-learning/linkprediction-pipelines/theory/)
- https://en.wikipedia.org/wiki/Precision_and_recall
- https://stats.stackexchange.com/a/302087


### Bias in dataset

- There are likely several sampling biases in our dataset, e.g. a bias toward samples with an impact on humans, preferential sampling of certain geographies, and preferential sampling of certain taxa.
- To deal with this, we can consider resampling methods: oversample the minority classes or undersample the majority classes
- Oversampling methods:
    - Network Imputation: substituting missing data with values according to some criteria (density, degree distribution)
    - Synthetic Minority Over-sampling Technique (SMOTE)
    - GANs for generating synthetic samples 
- Undersampling methods:
    - Random walks with restarts to undersample majority classes and create a more uniform class distribution
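SMOTE is the easiest of these to sketch. It operates on feature vectors, so using it here assumes each example (e.g. a palmprint-host pair) has already been turned into a numeric vector, such as concatenated node embeddings; that step is an assumption and is not something this notebook does. The minimal pure-Python sketch below places each synthetic point at a random position on the segment between a minority sample and one of its `k` nearest minority-class neighbours. The function name `smote` is hypothetical; in practice a library implementation such as `imblearn.over_sampling.SMOTE` from imbalanced-learn would be the usual choice.

```python
import math
import random


def smote(minority, n_synthetic, k=5, seed=0):
    """Hypothetical SMOTE sketch: return `n_synthetic` new points, each placed at a
    random position on the segment between a minority sample and one of its
    `k` nearest minority-class neighbours (Euclidean distance)."""
    if len(minority) < 2:
        raise ValueError("SMOTE needs at least two minority samples")
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_synthetic):
        i = rng.randrange(len(minority))
        x = minority[i]
        # k nearest neighbours of x among the *other* minority samples
        neighbours = sorted(
            (p for j, p in enumerate(minority) if j != i),
            key=lambda p: math.dist(x, p),
        )[:k]
        n = rng.choice(neighbours)
        gap = rng.random()  # uniform in [0, 1)
        synthetic.append(tuple(xi + gap * (ni - xi) for xi, ni in zip(x, n)))
    return synthetic


# Usage: four minority points on the unit square, six synthetic points
minority = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
new_points = smote(minority, n_synthetic=6, k=2)
print(len(new_points))  # 6
```

Because every synthetic point lies between two real minority samples, SMOTE densifies the minority region rather than simply duplicating records, which is its main advantage over naive random oversampling.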


### Negative samples in training and precision-recall in evaluation

- Link prediction on graphs usually has a *class imbalance problem* because there are far fewer relationships (positive examples) than missing relationships (negative examples)
- To account for this, the training loader exposes the model to positive and negative samples at a 1:1 ratio by default
- Moreover, the AUCPR metric (area under the precision-recall curve) summarizes precision and recall across all classification thresholds, so it reflects how the model scores both the positive and the negative class (see [img](https://en.wikipedia.org/wiki/Precision_and_recall))
- If we only care about identifying all minority-class records, we could emphasize recall, at the cost of accepting more false positives. Optimizing only for precision would let the model under-detect the minority class, since the easiest way to achieve high precision is to be overly cautious about predicting it.
- To evaluate our model's ability to generalize to both the majority and minority classes, we can use AUCPR and run validation on two datasets: one with the true class distribution and one containing only minority classes
- During training we can tune the negative sampling ratio (`negativeSamplingRatio`)
- During validation we can tune the weight of false positives when computing precision (`negativeClassWeight`)
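To make `negativeSamplingRatio` concrete: a negative example is a node pair with no relationship between its endpoints, and the ratio sets how many such pairs are drawn per positive relationship. The sketch below uses plain rejection sampling on a bipartite palmprint-host graph. GDS generates its negative samples internally, so this only illustrates the idea and is not its implementation; `sample_negative_pairs` and the toy node names are hypothetical.

```python
import random


def sample_negative_pairs(sources, targets, positive_edges,
                          negative_sampling_ratio=1.0, seed=42):
    """Hypothetical sketch: draw round(ratio * #positives) distinct (source, target)
    pairs that are not positive edges, by rejection sampling."""
    rng = random.Random(seed)
    sources, targets = sorted(set(sources)), sorted(set(targets))
    positives = set(positive_edges)
    n_negatives = round(negative_sampling_ratio * len(positives))
    # Assumes every positive edge is a (source, target) pair from the lists above
    available = len(sources) * len(targets) - len(positives)
    if n_negatives > available:
        raise ValueError("not enough non-edges for the requested ratio")
    negatives = set()
    while len(negatives) < n_negatives:
        pair = (rng.choice(sources), rng.choice(targets))
        if pair not in positives:  # reject existing relationships
            negatives.add(pair)
    return negatives


# Usage on a toy graph with 3 positive edges and ratio 1.34
palmprints = ["p1", "p2", "p3", "p4"]
hosts = ["human", "bat"]
edges = [("p1", "human"), ("p2", "human"), ("p3", "bat")]
negatives = sample_negative_pairs(palmprints, hosts, edges, negative_sampling_ratio=1.34)
print(len(negatives))  # 4, i.e. round(1.34 * 3)
```

On a real, sparse graph almost every random pair is a non-edge, so rejection sampling is cheap; the `available` check only matters on tiny or dense graphs like this toy one.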


### Approximating true class ratio using existing bias

- Because the training loader controls the positive-to-negative ratio, we may not need to remove the bias from the dataset. Instead, we can tune the positive-to-negative ratio the model is trained on, and calibrate the precision used in the AUCPR metric by reweighting false positives
- Unfortunately, we don't know the true ratio of total probability mass between negative and positive examples. Under perfect sampling, we know neither the expected number of virus-host associations nor the total number of viruses.
- This notebook approximates this ratio using humans as the host, since they are the majority class. We compute the true class ratio `(q - r) / r`, where `q` is the number of possible undirected relationships and `r` is the number of actual undirected relationships.
- This approach rests on several assumptions, but it seems worth trying. After training a model, I will start interpreting the top-k predictions per palmprint using visualizations
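The ratio itself is a one-liner. The sketch below also writes out `q` for a bipartite virus-host graph, assuming every virus could in principle link to every host once; restricting the hosts to humans, as this notebook does, makes `q` equal to the number of viral palmprints. The function names are hypothetical, and the numbers in the usage lines are the palmprint counts computed further down in this notebook.

```python
def possible_bipartite_pairs(n_viruses, n_hosts):
    """Hypothetical helper: q for an undirected bipartite graph in which
    each virus can link to each host at most once."""
    return n_viruses * n_hosts


def true_class_ratio(q, r):
    """(q - r) / r: number of absent (negative) pairs per actual (positive) pair."""
    if r <= 0:
        raise ValueError("need at least one actual relationship")
    return (q - r) / r


# Humans as the only host: q = number of viral palmprints, r = human associations
q = possible_bipartite_pairs(n_viruses=937845, n_hosts=1)
print(true_class_ratio(q, r=399996))  # ~1.3446
```

Including more host taxa grows `q` multiplicatively (by the number of hosts), while `r` only grows by the associations actually observed, which is worth keeping in mind when reading the human-only estimate.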

### Tuning negative sampling ratio and negative weight class

- The recommended value for `negativeSamplingRatio` is the true class ratio of the graph, i.e. not applying undersampling
- We can also tune `negativeClassWeight`. To be consistent with traditional evaluation, choose parameters so that `negativeSamplingRatio * negativeClassWeight = 1.0`.
- Alternatively, one can aim for the ratio of total probability weight between the classes to be close to the true class ratio, i.e. `negativeSamplingRatio * negativeClassWeight ~= true class ratio`. The reported metric (AUCPR) then better reflects the expected precision on unseen, highly imbalanced data. With this type of evaluation one has to adjust expectations, since the metric value becomes much smaller.
- Increasing `negativeSamplingRatio` adds more negative examples and therefore increases training and evaluation time, whereas increasing `negativeClassWeight` only reweights the existing negatives, so the metric is estimated from fewer samples. Both parameters can be adjusted in tandem to trade off evaluation accuracy against speed.
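The effect of `negativeClassWeight` on the metric can be illustrated with a weighted average precision: every negative ranked above a positive adds `negativeClassWeight` false positives instead of one. This is a minimal sketch of a step-wise AUCPR with no interpolation and ties broken by input order; GDS's own AUCPR computation may differ in those details, and `weighted_aucpr` is a hypothetical name.

```python
def weighted_aucpr(scores, labels, negative_class_weight=1.0):
    """Hypothetical sketch: step-wise area under the precision-recall curve
    (average precision) where each negative counts as `negative_class_weight`
    false positives. Ties in `scores` keep their input order."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("need at least one positive example")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0.0
    prev_recall = area = 0.0
    for _, label in ranked:
        if label:
            tp += 1
            recall = tp / total_pos
            precision = tp / (tp + fp)
            area += (recall - prev_recall) * precision  # rectangle under the step
            prev_recall = recall
        else:
            fp += negative_class_weight  # reweighted false positive
    return area


# Same ranking, two different negative class weights
scores = [0.9, 0.8, 0.7, 0.6]
labels = [1, 0, 1, 0]
print(weighted_aucpr(scores, labels, 1.0))  # ~0.833
print(weighted_aucpr(scores, labels, 2.0))  # 0.75
```

For an identical ranking, a larger weight lowers the score, which is why the alternative setting `negativeSamplingRatio * negativeClassWeight ~= true class ratio` reports smaller AUCPR values than traditional evaluation.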

In [1]:
# Notebook config
import sys
if '../' not in sys.path:
    sys.path.append("../")
%load_ext dotenv
%dotenv

# Actual imports
from queries import gds_queries


In [2]:
# count viral palmprints using potential taxon values
total_viral_palmprints_count = gds_queries.run_query('''
    MATCH (p:Palmprint)-[r:HAS_POTENTIAL_TAXON]->(t:Taxon)
    WHERE t.taxKingdom = 'Viruses'
    RETURN COUNT(p) as count
''')['count'][0]
print(total_viral_palmprints_count)

937845


In [5]:
# count viral SOTUs using potential taxon values
total_viral_sotu_count = gds_queries.run_query('''
    MATCH (p:SOTU)-[r:HAS_POTENTIAL_TAXON]->(t:Taxon)
    WHERE t.taxKingdom = 'Viruses'
    RETURN COUNT(p) as count
''')['count'][0]
print(total_viral_sotu_count)

431864


In [16]:
# count number of viral palmprint associations to humans (taxId: 9606)
human_assoc_viral_palmprints_count = gds_queries.run_query('''
    MATCH (s:Palmprint)<-[:HAS_PALMPRINT]-(:SRA)-[:HAS_HOST]->(t:Taxon)
    WHERE ((t)-[:HAS_PARENT*]->(:Taxon {taxId: '9606'})
    OR t.taxId = '9606')
    AND s.taxKingdom = 'Viruses'
    RETURN COUNT(s) as count
''')['count'][0]
print(human_assoc_viral_palmprints_count)

399996


In [6]:
# count number of viral SOTU associations to humans (taxId: 9606)
human_assoc_viral_sotu_count = gds_queries.run_query('''
    MATCH (s:SOTU)<-[:HAS_PALMPRINT]-(:SRA)-[:HAS_HOST]->(t:Taxon)
    WHERE ((t)-[:HAS_PARENT*]->(:Taxon {taxId: '9606'})
    OR t.taxId = '9606')
    AND s.taxKingdom = 'Viruses'
    RETURN COUNT(s) as count
''')['count'][0]
print(human_assoc_viral_sotu_count)

257779


In [18]:
# true_class_ratio = (q - r) / r
true_class_ratio = (total_viral_palmprints_count - human_assoc_viral_palmprints_count) / human_assoc_viral_palmprints_count
print(true_class_ratio)

1.3446359463594635


In [8]:
# true_class_ratio = (q - r) / r
true_class_ratio_sotu = (total_viral_sotu_count - human_assoc_viral_sotu_count) / human_assoc_viral_sotu_count
print(true_class_ratio_sotu)

0.6753265394000287


## Tuning experiments

Check the AUCPR under `validation` in the last line of `eval.csv`, and `bestParameters` for the selected model.
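Reading AUCPR values off the printed dict by eye is error-prone, so a small helper can parse the file instead. It assumes the layout seen in the `eval.csv` outputs printed in this section: a single-column CSV whose header is `0` and whose last row is a Python-literal dict with a `modelCandidates` list (each entry holding `metrics` → `AUCPR` → `validation` → `avg` and `parameters` → `methodName`) and possibly a `bestParameters` entry. The helper name is hypothetical.

```python
import ast
import csv
import io


def parse_eval_csv(text):
    """Hypothetical helper: rank the model candidates in an eval.csv artifact by
    mean validation AUCPR. Returns (ranked, best_parameters_or_None)."""
    rows = list(csv.reader(io.StringIO(text)))
    # The last row holds the result dict as a Python literal (single-quoted keys)
    result = ast.literal_eval(rows[-1][0])
    ranked = sorted(
        (
            (c["metrics"]["AUCPR"]["validation"]["avg"], c["parameters"]["methodName"])
            for c in result["modelCandidates"]
        ),
        reverse=True,
    )
    return ranked, result.get("bestParameters")


# Usage, assuming the same results directory as the runs below:
# with open(f"/mnt/graphdata/results/link_prediction/0.1/{run_id}/eval.csv") as f:
#     ranked, best = parse_eval_csv(f.read())
```

`ast.literal_eval` is used rather than `json.loads` because the stored dict uses Python's single-quoted repr, which is not valid JSON.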

In [38]:
def log_run_artifacts(run_id, sample_ratio=0.1):
    """Print every saved artifact of a link prediction run, or a note if one is missing."""
    dir_prefix = f'/mnt/graphdata/results/link_prediction/{sample_ratio}/{run_id}'

    for artifact_file in ['config.txt', 'eval.csv', 'dataset.txt', 'model.txt', 'pipeline.txt']:
        print(f"\n{artifact_file}\n")
        try:
            with open(f'{dir_prefix}/{artifact_file}', 'r') as f:
                print(f.read())
        except FileNotFoundError:
            print(f"File {artifact_file} not found")

### Recommended settings

In [39]:
# Run with negative sampling ratio = 1.34, negative class weight = 1, dataset sample ratio = 0.1
recommended_settings_run_id = '1685655386'
log_run_artifacts(recommended_settings_run_id)



config.txt

File config.txt not found

eval.csv

0
"{'modelCandidates': [{'metrics': {'AUCPR': {'validation': {'avg': 0.2855628751695121, 'min': 0.2841136484768501, 'max': 0.2866516512101441}, 'train': {'avg': 0.2855612414795886, 'min': 0.2854407327576802, 'max': 0.28572274813292087}}}, 'parameters': {'maxEpochs': 100, 'minEpochs': 1, 'classWeights': [], 'penalty': 1.0, 'patience': 2, 'methodName': 'MultilayerPerceptron', 'focusWeight': 0.0, 'hiddenLayerSizes': [64, 16, 4], 'batchSize': 100, 'tolerance': 0.001, 'learningRate': 0.001}}, {'metrics': {'AUCPR': {'validation': {'avg': 0.7318097434484135, 'min': 0.7309953050924338, 'max': 0.7329281574676022}, 'train': {'avg': 0.7318139688407923, 'min': 0.7316862942169846, 'max': 0.7319015512869126}}}, 'parameters': {'maxEpochs': 100, 'minEpochs': 1, 'classWeights': [], 'penalty': 0.16145378017027948, 'patience': 1, 'methodName': 'LogisticRegression', 'focusWeight': 0.0, 'batchSize': 100, 'tolerance': 0.001, 'learningRate': 0.001}}, {'metric

### Traditional settings

Settings consistent with traditional evaluation, i.e. `negativeSamplingRatio * negativeClassWeight = 1.0`.

In [31]:
# negative_sample_ratio = true_class_ratio = 1.34
# negative_sample_ratio * negative_class_weight = 1
# negative_class_weight = 1 / negative_sample_ratio
negative_class_weight_traditional = 1 / true_class_ratio
print(negative_class_weight_traditional)

0.7436957212898044


In [40]:
# Run with negative sampling ratio = 1.34, negative class weight = 0.75, dataset sample ratio = 0.1

traditional_settings_run_id = '1685651665'
log_run_artifacts(traditional_settings_run_id)


config.txt

{'PROJECTION_NAME': 'palmprint-host-dataset', 'PIPELINE_NAME': 'lp-pipeline', 'MODEL_NAME': 'lp-model', 'RANDOM_SEED': 42, 'SAMPLING_RATIO': 0.1, 'TEST_FRACTION': 0.3, 'TRAIN_FRACTION': 0.6, 'VALIDATION_FOLDS': 10, 'NEGATIVE_SAMPLING_RATIO': 1.34, 'NEGATIVE_CLASS_WEIGHT': 0.75, 'PREDICTION_THRESHOLD': 0.7}

eval.csv

0
"{'modelCandidates': [{'metrics': {'AUCPR': {'validation': {'avg': 0.3423030834666072, 'min': 0.34055368465933195, 'max': 0.34362130890285475}, 'train': {'avg': 0.3423019294745503, 'min': 0.34215593788262316, 'max': 0.3424967171626045}}}, 'parameters': {'maxEpochs': 100, 'minEpochs': 1, 'classWeights': [], 'penalty': 1.0, 'patience': 2, 'methodName': 'MultilayerPerceptron', 'focusWeight': 0.0, 'hiddenLayerSizes': [64, 16, 4], 'batchSize': 100, 'tolerance': 0.001, 'learningRate': 0.001}}, {'metrics': {'AUCPR': {'validation': {'avg': 0.771651091647092, 'min': 0.7708950411730429, 'max': 0.7726628620271948}, 'train': {'avg': 0.7716554055781006, 'min': 0.77153958

### Alternative settings

Settings where `negativeSamplingRatio * negativeClassWeight ~= true class ratio`. The reported metric (AUCPR) then better reflects the expected precision on unseen, highly imbalanced data, but one has to adjust expectations because the metric value becomes much smaller.

In [None]:
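# negative_sample_ratio * negative_class_weight ~= true_class_ratio (~1.34)
# WANT: max(negative_class_weight) and min(negative_sample_ratio) to keep runs fast
# e.g. keep the default negative_sample_ratio = 1.0 and move the rest into the weight
negative_sample_ratio_alternative = 1.0
negative_class_weight_alternative = true_class_ratio / negative_sample_ratio_alternative
print(negative_class_weight_alternative)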

In [None]:
alternative_settings_run_id = ''
log_run_artifacts(alternative_settings_run_id)

TODO:
- [ ] Plot evals for different settings