<h2>Encrypted DNA ancestry using Concrete ML by Horaizon27 team</h2>


<h3>Imports</h3>

In [24]:
import os
import shutil
import subprocess
import time
import urllib.request

import numpy as np
import pandas as pd
import yaml
from scipy import stats

from src.utils import read_vcf, save_dict, vcf_to_npy, read_genetic_map
from src.laidataset import LAIDataset
from src.model import Gnomix
from concrete_models import ConcreteGnomix
from training_utils import get_data

<h3>Constants</h3>

In [7]:
# Chromosome handled by this notebook (passed to the dataset, genetic map and VCF readers).
CHM = "22"

# Scratch directory; the generated training data is written below it.
WORKING_DIR = "tmp"
TRAINING_DATA_DIR = os.path.join(WORKING_DIR, "training_data")

# Inputs shipped in (or created in) the data directory.
SAMPLE_MAP_FILE = "data/1000g.smap"
GENETIC_MAP_FILE = "data/allchrs.b37.gmap"
# Written in the "Creating reference file" step from the "#Sample" column of the sample map.
SINGLE_ANCESTRY_SAMPLES_FILE = "data/samples_1000g.tsv"
# Produced by bcftools in the "Creating reference file" step.
REFERENCE_FILE = "data/reference_1000g.vcf"
# Fetched in the "Downloading query file" step.
QUERY_FILE = "data/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5b.20130502.genotypes.vcf.gz"

<h3>Data preparation</h3>

Almost all of the necessary files are already stored in the **/data** directory.
However, a few steps still need to be completed:
1) Download the **query file** and put it in the **/data** directory
2) Create the **reference file** and put it in the **/data** directory
3) Load the config file
4) Generate the training data


Downloading query file

In [None]:
# Location of the chr22 genotypes archive that is saved locally as QUERY_FILE.
QUERY_FILE_URL = "https://ftp.1000genomes.ebi.ac.uk/vol1/ftp/release/20130502/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5b.20130502.genotypes.vcf.gz"

# Skip the download when a previous run already fetched the file, so that
# "Restart & Run All" does not transfer the archive again.
if not os.path.exists(QUERY_FILE):
    # Download to a temporary name first: an interrupted transfer must not leave a
    # truncated file at QUERY_FILE that the existence check above would later accept.
    partial_download = QUERY_FILE + ".part"
    urllib.request.urlretrieve(QUERY_FILE_URL, partial_download)
    os.replace(partial_download, QUERY_FILE)

Creating reference file

In [None]:
# Write the sample names from the sample map to a one-column file that bcftools
# can consume through its -S option.
sample_map = pd.read_csv(SAMPLE_MAP_FILE, sep="\t")
np.savetxt(
    SINGLE_ANCESTRY_SAMPLES_FILE, list(sample_map["#Sample"]), delimiter="\t", fmt="%s"
)

# Subset the query VCF to those samples. The command is passed as an argument list
# (no shell), so paths containing spaces or shell metacharacters are handled
# correctly, and check=True raises CalledProcessError if bcftools fails instead of
# silently leaving a missing or incomplete reference file.
cmd = ["bcftools", "view", "-S", SINGLE_ANCESTRY_SAMPLES_FILE, "-o", REFERENCE_FILE, QUERY_FILE]
subprocess.run(cmd, check=True)

Loading models config

In [8]:
# Load the settings that are later passed to data generation and model training.
# NOTE(review): yaml.UnsafeLoader can construct arbitrary Python objects, so this is
# only acceptable while config.yaml is a trusted local file. If the file contains
# nothing but plain scalars, lists and mappings, yaml.safe_load(file) would be the
# safer call — TODO confirm against the actual config.yaml.
with open("config.yaml", "r") as file:
    config = yaml.load(file, Loader=yaml.UnsafeLoader)

Generating training data

In [9]:
def generate_training_data(
        training_data_dir, reference, genetic_map, sample_map, chm, config, force_regeneration=False
):
    """Simulate the per-split training data and write it under ``training_data_dir``.

    Adapted almost as is from the Gnomix repo (``simulate_splits`` in gnomix.py).

    If ``training_data_dir`` already exists it is kept and the function returns
    immediately, unless ``force_regeneration`` is true, in which case the directory
    is deleted and rebuilt. Split ratios and generations come from
    ``config["simulation"]["splits"]``. Note that the "val" ratio is removed from
    ``config`` in place when the dataset holds 25 samples or fewer.

    Returns None; all results are written to disk.
    """
    admixed_ratio = config["simulation"]["r_admixed"]
    print (
        "Generating training data in {} with r_admixed: {}".format(training_data_dir, admixed_ratio)
    )

    # Reuse existing data unless the caller explicitly asks for a rebuild.
    dir_exists = os.path.exists(training_data_dir)
    if dir_exists and not force_regeneration:
        print ("Training data already exists", training_data_dir)
        return
    if dir_exists:
        shutil.rmtree(training_data_dir)
    os.makedirs(training_data_dir)

    dataset = LAIDataset(chm, reference, genetic_map, seed=config["seed"])
    dataset.buildDataset(sample_map)

    sample_maps_dir = os.path.join(training_data_dir, "sample_maps")
    os.makedirs(sample_maps_dir)

    # Split the sample map and write it; validation is dropped for tiny datasets.
    split_cfg = config["simulation"]["splits"]
    ratios = split_cfg["ratios"]
    too_small_for_val = len(dataset) <= 25
    if too_small_for_val and ratios.get("val"):
        print("WARNING: Too few samples to run validation.")
        del ratios["val"]  # same dict object as config["simulation"]["splits"]["ratios"]
    dataset.create_splits(ratios, sample_maps_dir)

    save_dict(dataset.metadata(), os.path.join(training_data_dir, "metadata.pkl"))

    # Number of simulation outputs per generation: the larger of
    # (split size * r_admixed) and a per-split floor, divided across the generations.
    gens_by_split = split_cfg["gens"]
    floor_by_split = {"train1": 400, "train2": 75, "val": 25}
    num_outs = {
        name: int(
            max(len(dataset.return_split(name)) * admixed_ratio, floor_by_split[name])
            / len(gens_by_split[name])
        )
        for name in ratios
    }

    # One output directory per split, one "gen_<g>" sub-directory per generation.
    for name in ratios:
        split_dir = os.path.join(training_data_dir, name)
        if not os.path.exists(split_dir):
            os.makedirs(split_dir)
        for gen in gens_by_split[name]:
            dataset.simulate(
                num_outs[name],
                split=name,
                gen=gen,
                outdir=os.path.join(split_dir, "gen_{}".format(gen)),
                return_out=False
            )

    print ("Generated {} splits: {}".format(len(ratios), ratios))

In [10]:
# Rebuild the simulated training set from scratch (any existing directory is deleted).
generate_training_data(
    training_data_dir=TRAINING_DATA_DIR,
    reference=REFERENCE_FILE,
    genetic_map=GENETIC_MAP_FILE,
    sample_map=SAMPLE_MAP_FILE,
    chm=CHM,
    config=config,
    force_regeneration=True,
)

Generating training data in tmp/training_data with r_admixed: 0.6
Reading vcf file...
Getting genetic map info...
Getting sample map info...
Building founders...
Splitting sample map...
Generated 3 splits: {'train1': 0.8, 'train2': 0.15, 'val': 0.05}


<h2>Models</h2>

We are going to train two models: Gnomix (Non-FHE) and ConcreteGnomix (FHE)<br>
- Both models will be trained using the same **config.yaml** file that was downloaded from Gnomix repository
- Both models will use the same generated training data

You can find more information about the models in **README.md**.

<h3>Models training</h3>

In [19]:
def train_model(config, training_data_path, genetic_map_path, chm, model_type="concrete", evaluate=False):
    """Create and train a model of the requested type from ``config``.

    Supported ``model_type`` values:
    'gnomix': original Gnomix model with LogisticRegression and XGBClassifier
    'concrete': similar to Gnomix model with concrete versions of LogisticRegression and XGBClassifier

    Parameters:
        config: dict with a "model" section (window_size_cM, smooth_size, n_cores,
            retrain_base, calibrate, context_ratio; missing keys default to None),
            ``config["simulation"]["splits"]["gens"]`` and ``config["seed"]``.
        training_data_path: directory passed to ``get_data``.
        genetic_map_path: genetic map file passed to ``read_genetic_map``.
        chm: chromosome identifier passed to ``read_genetic_map``.
        model_type: "concrete" or "gnomix".
        evaluate: forwarded to the model's ``train`` method.

    Returns:
        The trained model, with the genetic map attached via ``write_gen_map_df``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported values. The check
            runs before any training data is read, so a typo fails immediately.
    """
    # Fail fast: reading the training data is expensive and pointless for a bad type.
    if model_type not in ("concrete", "gnomix"):
        raise ValueError("Unknown model type: {}".format(model_type))

    model_cfg = config["model"]
    window_size_cM = model_cfg.get("window_size_cM")
    smooth_window_size = model_cfg.get("smooth_size")
    n_cores = model_cfg.get("n_cores", None)
    retrain_base = model_cfg.get("retrain_base")
    calibrate = model_cfg.get("calibrate")
    context_ratio = model_cfg.get("context_ratio")
    generations = config["simulation"]["splits"]["gens"]

    print("Reading training data...")
    data, meta = get_data(training_data_path, generations, window_size_cM, model_type)

    # Both model classes are constructed from exactly the same arguments.
    model_kwargs = dict(
        C=meta["C"], M=meta["M"], A=meta["A"], S=smooth_window_size,
        snp_pos=meta["snp_pos"], snp_ref=meta["snp_ref"], snp_alt=meta["snp_alt"],
        population_order=meta["pop_order"], calibrate=calibrate,
        n_jobs=n_cores, context_ratio=context_ratio, seed=config["seed"],
    )

    if model_type == "concrete":
        print("Training Concrete model...")
        model = ConcreteGnomix(**model_kwargs)
        model.train(data=data, retrain_base=retrain_base, evaluate=evaluate, compile=False)
    else:
        print("Training Gnomix model...")
        model = Gnomix(**model_kwargs)
        model.train(data=data, retrain_base=retrain_base, evaluate=evaluate)

    # write genetic map df to model variable
    model.write_gen_map_df(
        read_genetic_map(genetic_map_path, chm)
    )

    return model

Gnomix model training

In [12]:
# Non-FHE baseline: the original Gnomix model, evaluated during training.
gnomix_model = train_model(
    config=config,
    training_data_path=TRAINING_DATA_DIR,
    genetic_map_path=GENETIC_MAP_FILE,
    chm=CHM,
    model_type="gnomix",
    evaluate=True,
)

Reading training data...
Training Gnomix model...
Training base models...
Training smoother...
Evaluating model...
training accuracy
val accuracy
Re-training base models...


Concrete model training

In [20]:
# Concrete counterpart, trained from the same config and training data.
concrete_model = train_model(
    config=config,
    training_data_path=TRAINING_DATA_DIR,
    genetic_map_path=GENETIC_MAP_FILE,
    chm=CHM,
    model_type="concrete",
    evaluate=True,
)

Reading training data...
Training Concrete model...
Training base model...
Training smoother...
Evaluating model...
training accuracy
val accuracy
Re-training base models...
Base model compile
Smooth model compile


<h2>Inference time comparison</h2>

For comparison, we will try three options:

1) Default Gnomix model (non-FHE)
2) ConcreteGnomix model, which uses **FHE simulation** at both stages (sim-FHE)
3) ConcreteGnomix model, which uses **FHE** only at the first stage (half-FHE)

In [14]:
def get_inference(model, query_vcf_data, fhe_data=None):
    """Predict one label per query sample and measure the average prediction time.

    Parameters:
        model: trained model exposing ``snp_pos``, ``snp_ref``, ``base.predict_proba``
            and ``smooth.predict_proba``.
        query_vcf_data: VCF data as accepted by ``vcf_to_npy``; its "samples" entry
            supplies the sample names.
        fhe_data: optional dict with "base" and "smooth" entries that are forwarded
            as the ``fhe=`` argument of the two ``predict_proba`` calls. When None,
            ``predict_proba`` is called without an ``fhe`` argument.

    Returns:
        dict with "predictions" (sample name -> int label) and "avg_inference_time"
        (seconds per sample). The timed section covers both ``predict_proba`` calls
        and the final vote, but not the ``vcf_to_npy`` conversion.
    """
    # preparing data
    X_query, _, _ = vcf_to_npy(
        query_vcf_data, model.snp_pos, model.snp_ref, return_idx=True, verbose=False
    )
    samples = query_vcf_data["samples"]

    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # durations; time.time() can jump if the system clock is adjusted.
    predictions_start = time.perf_counter()

    # making predictions
    if fhe_data is None:
        B_query = model.base.predict_proba(X_query)
        y_proba_query = model.smooth.predict_proba(B_query)
    else:
        B_query = model.base.predict_proba(X_query, fhe=fhe_data["base"])
        y_proba_query = model.smooth.predict_proba(B_query, fhe=fhe_data["smooth"])
    y_pred_query = np.argmax(y_proba_query, axis=-1)

    # getting final prediction: only every second row (0, 2, 4, ...) is used, and
    # the most frequent label along axis 1 of that row becomes the prediction.
    # NOTE(review): presumably each sample contributes two consecutive rows and
    # only the first one is used — TODO confirm against vcf_to_npy.
    ind_idx = np.arange(0, len(y_pred_query), 2)
    # Depending on the SciPy version, .mode is either 1-D or has a trailing axis of
    # length 1; flatten it so that each prediction below is a scalar.
    final_prediction = np.ravel(stats.mode(y_pred_query[ind_idx, :], axis=1).mode)

    predictions_end = time.perf_counter()
    avg_time = (predictions_end - predictions_start) / len(samples)
    print("Average inference time per sample: {}".format(avg_time))

    predictions = {
        sample_name: int(prediction)
        for sample_name, prediction in zip(samples, final_prediction)
    }

    return {"predictions": predictions, "avg_inference_time": avg_time}

Preparing test queries

In [15]:
QUERIES_NUM = 10

# Draw the query samples with a fixed random_state so the same samples are selected
# on every "Restart & Run All"; otherwise the reported predictions and timings
# would refer to a different query set on each run.
query_samples = (
    pd.read_csv(SAMPLE_MAP_FILE, sep="\t")
    .sample(QUERIES_NUM, random_state=config["seed"])
    .set_index("#Sample")["Panel"]
    .to_dict()
)
# Read only the selected samples from the query VCF.
query_vcf_data = read_vcf(QUERY_FILE, chm=CHM, fields="*", samples=query_samples.keys())

Gnomix model (Non-FHE) 

In [16]:
# No fhe_data: both stages call predict_proba without an fhe argument.
gnomix_inference_result = get_inference(model=gnomix_model, query_vcf_data=query_vcf_data)

Average inference time per sample: 0.9127118825912476


Concrete model in simulation mode (sim-FHE)

In [17]:
# Both stages receive fhe="simulate".
fhe_data = dict(base="simulate", smooth="simulate")
concrete_sim_inference_result = get_inference(
    model=concrete_model, query_vcf_data=query_vcf_data, fhe_data=fhe_data
)

Average inference time per sample: 19.036097359657287


Concrete model with FHE only at first stage   (half-FHE)

In [18]:
# The base stage receives fhe="execute"; the smoother receives fhe="disable".
fhe_data = dict(base="execute", smooth="disable")
concrete_half_inference_result = get_inference(
    model=concrete_model, query_vcf_data=query_vcf_data, fhe_data=fhe_data
)

Average inference time per sample: 826.38837920489296635


<h2>Results</h2>

In [23]:
# One column per inference mode, one row per metric.
# Both FHE columns report the same accuracy value: it is read from
# concrete_model.accuracies and is not re-measured per inference mode.
results = pd.DataFrame(
    {
        "Non-FHE": [
            gnomix_model.accuracies["smooth_val_acc_bal"],
            gnomix_inference_result['avg_inference_time'],
        ],
        "Sim-FHE": [
            concrete_model.accuracies["smooth_val_acc_bal"],
            concrete_sim_inference_result['avg_inference_time'],
        ],
        "Half-FHE": [
            concrete_model.accuracies["smooth_val_acc_bal"],
            concrete_half_inference_result['avg_inference_time'],
        ],
    },
    index=["Accuracy", "Inference time"],
)
results

Unnamed: 0,Non-FHE,Sim-FHE,Half-FHE
Accuracy,97.75,97.26,97.26
Inference time,0.912712,19.036097,826.388379
