# MiGenPro Notebook Workflow
#### Set your configuration below and run each step as needed


In [None]:

import os
import json
import pandas as pd
from types import SimpleNamespace

from migenpro.ml.parameter_optimisation import ParameterOptimisation
from migenpro.querying.query_executor import QueryExecutor
from migenpro.querying.genome_annotation import GenomeAnnotationWorkflow
from migenpro.querying.query_executor import QueryExecutor
from migenpro.ml.machine_learning_main import FeatureMatrix, PhenotypeMatrix, MachineLearningModels
from migenpro.querying.query_parser import QueryParser
import os


# Workflow configuration: adjust the values in this cell before running the
# steps below. Later cells also look up a few optional keys that are not set
# here and fall back to their own defaults when absent: 'rel_frequency',
# 'ml_output', 'train', 'predict' and 'models_dir'.
config = dict(
    # --- General settings ---
    output_dir="./output/",  # main output directory
    debug=True,  # enable debug mode
    dataset_bin="../binaries/datasets",  # path to ncbi-datasets-cli
    threads=2,  # passed to the annotation workflow and to the ML dataset setup

    # --- Data formatting parameters ---
    phenotype_query_file="sparql_phenotype:motility.sparql",
    # NOTE(review): this path starts with "./" while the other resource paths
    # in this cell start with "../" -- confirm it resolves from the notebook
    # directory.
    sapp_jar="./binaries/SAPP-2.0.jar",
    phenotype_hdt_file="../data/bacdive.hdt.gz",
    abs_frequency=200,
    species_frequency=10,

    # --- Genome querying parameters ---
    genome_query_file="sparql_genome:DomainCopyNumber.sparql",

    # --- Machine learning parameters ---
    feature_matrix="./output/feature_matrix.tsv",
    phenotype_matrix_ml="./output/phenotype_matrix.tsv",
    load_model=None,  # path handed to load_model() when training is disabled
    param_grids="../tests/resources/param_grids.json",
    dt_depth=5,
    rf_depth=10,
    gb_depth=5,
    num_trees=100,
    max_iter=1000,
    proportion_train=0.8,
    rf_min_leaf=1,
    rf_min_split=2,
    gb_min_samples=2,
    gb_learning_rate=0.1,
    sampling_type="under_sample",
)


## ----------------------------
### Data Formatting Workflow
## ----------------------------

In [None]:
# Data formatting step: run the phenotype SPARQL query against the phenotype
# HDT file, filter the parsed results and write them out as a phenotype matrix.
print("Starting data formatting workflow...")

# The query output is written into the output directory, so create it here
# rather than relying on a later cell having been run first.
os.makedirs(config["output_dir"], exist_ok=True)

# Create QueryExecutor instance
qe = QueryExecutor(
    config.get('phenotype_query_file'),
    config.get('sapp_jar'),
    debug=config["debug"]
)

# Execute query and process results
phenotype_file_path = os.path.join(config["output_dir"], "phenotype.tsv")
qe.execute_sapp_locally_file(
    hdt_file=config.get('phenotype_hdt_file'),
    output_file=phenotype_file_path
)

# Parse and filter results
phenotype_output = QueryParser(file_path=phenotype_file_path)

# Both frequency filters are optional: each runs only when its config value
# is set to something truthy.
if config.get('rel_frequency'):
    phenotype_output.filter_by_relative_frequency(config.get('rel_frequency'))

if config.get('abs_frequency'):
    phenotype_output.filter_by_absolute_frequency(config.get('abs_frequency'))

# The species-frequency filter always runs, with 10 as the fallback threshold.
phenotype_output.filter_by_species_frequency(config.get('species_frequency', 10))
phenotype_output.convert_to_phenotype_matrix()

# Save phenotype matrix. The path is built from the output directory instead
# of string-replacing inside the full path, which could also rewrite a
# directory component that happens to contain "phenotype.tsv".
output_path = os.path.join(config["output_dir"], "phenotype_matrix.tsv")
phenotype_output.write_phenotype_matrix_to_file(output_path)
print(f"Phenotype matrix created at: {output_path}")

## ----------------------------
### Annotation Workflow
## ----------------------------

In [None]:
# Annotation step, part 1: read the genome identifiers from the phenotype
# matrix and resolve them to FASTA sources (local downloads via
# ncbi-datasets-cli when the binary is available, ENA download URLs otherwise).
os.makedirs(config["output_dir"], exist_ok=True)
print("Starting genome annotation workflow...")

workflow = GenomeAnnotationWorkflow(
    output_dir=config["output_dir"],
    threads=config.get('threads'),
    debug=config["debug"],
    ncbi_dataset_bin=config.get('dataset_bin')
)

# The row index of the phenotype matrix holds the genome identifiers.
phenotype_df = pd.read_csv(
    os.path.join(config.get('output_dir'), "phenotype_matrix.tsv"),
    index_col=0,
    sep="\t"
)
genome_identifiers = phenotype_df.index.to_list()

# Show one identifier as a sanity check; an empty matrix must not crash here.
if genome_identifiers:
    print(genome_identifiers[0])
else:
    print("Warning: the phenotype matrix contains no genome identifiers")

# Cap on the number of genomes fetched through ncbi-datasets-cli. Optional
# integer config key 'max_genomes'; defaults to the previously hard-coded 100.
max_genomes = config.get('max_genomes', 100)

dataset_bin = config.get('dataset_bin')
if dataset_bin and os.path.exists(dataset_bin):
    print("Downloading genomes using ncbi-datasets-cli")
    if len(genome_identifiers) > max_genomes:
        # Make the truncation visible instead of dropping genomes silently.
        print(
            f"Note: only the first {max_genomes} of "
            f"{len(genome_identifiers)} genomes will be downloaded"
        )
    fasta_genome_paths = workflow.download_genomes_from_genome_identifier(
        genome_identifiers[:max_genomes]
    )
else:
    if len(genome_identifiers) > 100:
        # The URLs built below point at ENA, so the warning names ENA.
        print("Warning: Downloading >100 genomes from ENA may be slow")
    # Keep only the part of each identifier before the first '.'.
    genome_identifiers = [
        gid.split('.')[0] for gid in genome_identifiers
    ]
    fasta_genome_paths = [
        f"http://www.ebi.ac.uk/ena/browser/api/fasta/{gid}?download=true&gzip=true"
        for gid in genome_identifiers
    ]


print(fasta_genome_paths[:2])

In [None]:
# Annotation step, part 2: hand the collected FASTA sources to the workflow.
# Skipped entirely when there are no genome identifiers.
if genome_identifiers:
    genome_hdt_files = workflow.process_batch(fasta_genome_paths)
    print(f"Annotation complete. Results in: {config['output_dir']}")

## ----------------------------
### Genome Querying Workflow
## ----------------------------

In [None]:
# Genome querying step: run the genome SPARQL query over the genome directory
# and combine the per-genome results into a single feature matrix file.
print("Starting genome querying workflow...")

# Same import as at the top of the notebook, repeated in this cell.
from migenpro.querying.query_executor import QueryExecutor

# Location of the combined feature matrix written at the end of this cell.
feature_matrix_path = os.path.join(config["output_dir"], "feature_matrix.tsv")

# Executor bound to the genome query and the SAPP jar.
qe = QueryExecutor(
    config.get('genome_query_file'),
    config.get('sapp_jar'),
    debug=config["debug"],
)

# Directory that is queried; created if it does not exist yet.
genome_dir = os.path.join(config["output_dir"], "genomes")
os.makedirs(genome_dir, exist_ok=True)

# Query the directory, then merge the individual result files.
individual_genome_feature_paths = qe.execute_sapp_locally_directory(genome_dir)
qe.summarise_feature_importance_files(
    individual_genome_feature_paths,
    feature_matrix_path,
)
print(f"Feature matrix created at: {feature_matrix_path}")


## ----------------------------
### Machine Learning Workflow
## ----------------------------

In [None]:
# Machine learning step: load both matrices, restrict them to the genomes they
# share, optionally tune hyperparameters, then train (or load) the models and
# optionally run predictions.
print("Starting machine learning workflow...")

# Load the feature and phenotype matrices from the configured paths.
fm = FeatureMatrix(config.get('feature_matrix'))
pm = PhenotypeMatrix(config.get('phenotype_matrix_ml'))
fm.load_matrix()
pm.load_matrix()

# Restrict both matrices to the genomes present in each of them.
intersect_genomes = pm.get_intersected_genomes(fm.file_df)
fm_subset = fm.create_subset(intersect_genomes)
pm_subset = pm.create_subset(intersect_genomes)

# Hyperparameters are read from config; the values below are the fallbacks
# used when a key is missing.
hyperparameter_defaults = {
    'dt_depth': 5,
    'rf_depth': 10,
    'gb_depth': 5,
    'num_trees': 100,
    'max_iter': 1000,
    'proportion_train': 0.8,
    'rf_min_leaf': 1,
    'rf_min_split': 2,
    'gb_min_samples': 2,
    'gb_learning_rate': 0.1,
}
hyperparameters = {
    name: config.get(name, fallback)
    for name, fallback in hyperparameter_defaults.items()
}
ml_models = ParameterOptimisation(
    output=config.get('ml_output', "./output/ml_results"),
    debug=config["debug"],
    **hyperparameters,
)

# Attach the data subsets together with the sampling and threading settings.
sampling_type = config.get('sampling_type', 'under_sample')
thread_count = config.get('threads', 4)
ml_models.set_datasets(
    observed_values=fm_subset,
    observed_results=pm_subset,
    sampling_type=sampling_type,
    threads=thread_count,
)

# Optional parameter tuning, driven by a JSON file of parameter grids.
param_grid_path = config.get('param_grids')
if param_grid_path:
    with open(param_grid_path, "r") as grid_file:
        param_grids = json.load(grid_file)
    optimized_params = ml_models.perform_halving_grid_search_search(param_grids=param_grids)
    # NOTE(review): `ml_models` is rebound to a new MachineLearningModels
    # instance here, while set_datasets() above was called on the
    # ParameterOptimisation object it replaces -- confirm the new instance
    # has the datasets and output location it needs for the steps below.
    ml_models = MachineLearningModels(parameter_dictionary=optimized_params)

# Train and save by default; a stored model is loaded only when training is
# switched off and 'load_model' is set.
if config.get('train', True):
    ml_models.train_models()
    ml_models.save_models()
elif config.get('load_model'):
    ml_models.load_model(config.get('load_model'))

# Predictions run only when 'predict' is set to a truthy value.
if config.get('predict'):
    ml_models.predict_models_test()
    ml_models.predict_models_train()

print(f"ML results saved to: {ml_models.output}")

## ----------------------------
### Feature Importance Analysis
## ----------------------------

In [None]:
# Feature importance step: for every saved model in the models directory,
# produce the Gini and/or permutation feature-importance plots the model
# supports, using the genomes shared by the feature and phenotype matrices.
print("Starting feature importance analysis...")

# Run feature importance
from migenpro.post_analysis.ml_model_analysis import LoadedMachineLearningModel, ModelAnalysis

# Load matrices (repeated here so this cell does not depend on variables
# left behind by the machine learning cell).
fm = FeatureMatrix(config.get('feature_matrix'))
pm = PhenotypeMatrix(config.get('phenotype_matrix_ml'))
fm.load_matrix()
pm.load_matrix()

# Get common genomes
intersect_genomes = pm.get_intersected_genomes(fm.file_df)
fm_subset = fm.create_subset(intersect_genomes)
pm_subset = pm.create_subset(intersect_genomes)

# Collect the model files. os.listdir() returns entries in arbitrary order,
# so sort them to analyse the models in the same order on every run.
# NOTE(review): the .pkl files are presumably unpickled by
# LoadedMachineLearningModel -- only point 'models_dir' at trusted files.
model_dir = config.get('models_dir', "./output/ml_results/models")
model_files = sorted(
    os.path.join(model_dir, file_name)
    for file_name in os.listdir(model_dir)
    if file_name.endswith('.pkl')
)
if not model_files:
    print(f"Warning: no .pkl model files found in {model_dir}")

# First argument passed to ModelAnalysis. Optional config key
# 'target_phenotype'; defaults to the previously hard-coded value.
# NOTE(review): presumably the name of the phenotype being analysed -- confirm
# against ModelAnalysis.
target_phenotype = config.get('target_phenotype', "target_phenotype")

# Analyze each model
for model_path in model_files:
    model = LoadedMachineLearningModel(model_path)
    print(f"Analyzing model: {model.model_name}")

    analysis = ModelAnalysis(
        target_phenotype,
        model,
        fm_subset,
        pm_subset
    )

    # Each plot is produced only when the loaded model flags it as available.
    if model.gini:
        analysis.plot_gini_feature_importance(
            save_path=os.path.join(config["output_dir"], f"gini_{model.model_name}.png")
        )

    if model.permutation:
        analysis.permutation_feature_importance(
            save_path=os.path.join(config["output_dir"], f"permutation_{model.model_name}.png")
        )

print("Feature importance analysis complete")

## ----------------------------
### ML Results summary
## ----------------------------

In [None]:
from migenpro.post_analysis.ml_summarise import MachineLearningData, SummaryGraphs

# Results summary step: gather the ML output from the output directory once,
# then build the classifier analysis, summary graphs and score table for the
# test and train scenarios in turn.
machine_learning_output_data = MachineLearningData(config["output_dir"])
for scenario in ("test", "train"):
    # A new SummaryGraphs object is created for each scenario.
    summary_graphs = SummaryGraphs(
        machine_learning_output_data,
        config['output_dir'],
        debug=config['debug'],
    )
    summary_graphs.analyse_classifiers(scenario=scenario)
    summary_graphs.make_method_summary_graphs()
    summary_graphs.output_scores_to_table()

print(f"Summarised results are located in {config['output_dir']}.")
