# Apply classification model

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib notebook
# %matplotlib inline

In [2]:
import sys

from catboost import CatBoostClassifier, Pool
from scipy.ndimage import binary_dilation
from sklearn import clone
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, BaggingClassifier
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold, RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score, cross_val_predict, cross_validate
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from pathlib import Path
import numpy as np
import pandas as pd
from definitions import ROOT_DIR
import matplotlib.pyplot as plt
import seaborn as sns
from metaspace.sm_annotation_utils import SMInstance
from metaspace.image_processing import clip_hotspots

import getpass
from metaspace import SMInstance
from datetime import datetime

In [3]:
# Suppress warnings, because many models spam them during feature selection
# as some subsets of features just don't have enough information to make
# a good model.
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.simplefilter('ignore', ConvergenceWarning)

## Utility functions

In [4]:
def colorize_image_with_mask(image, mask):
    """Plotting function for combining a colorized ion image with a spot mask"""
    
    image = clip_hotspots(image)
    image /= np.max(image)
    
    on_spot_colorized = plt.cm.cividis(image)
    off_spot_colorized = plt.cm.magma(image)
    return np.where(mask[:,:,np.newaxis], on_spot_colorized, off_spot_colorized)
    
def save_image_with_mask(image, mask, fname):
    plt.imsave(fname, colorize_image_with_mask(image, mask))

In [5]:
def crop_zeros(img):
    """Crop an image, removing all empty outer rows/columns"""
    cols = np.flatnonzero(np.count_nonzero(img, axis=0) != 0)
    rows = np.flatnonzero(np.count_nonzero(img, axis=1) != 0)
    top = rows[0]
    bottom = rows[-1] + 1
    left = cols[0]
    right = cols[-1] + 1

    return img[top:bottom, left:right]

In [6]:
def get_mispredictions(model, X, y):
    """
    Find which values would be mispredicted, returning two lists:
        * indexes of items that would be falsely predicted as positives
        * indexes of items that would be falsely predicted as negatives
        
    cross_val_predict uses a shuffled 5-fold test-train split so that each chunk of 
    20% of the input data gets its own model that was trained on the other 80%, 
    ensuring that the items being predicted aren't included in the training data.
    """
    preds = cross_val_predict(model, X, y)
    mispreds = preds != y
    fpos_idxs = np.flatnonzero(mispreds & ~y)
    fneg_idxs = np.flatnonzero(mispreds & y)
        
    return fpos_idxs, fneg_idxs

## Paths

In [7]:
p_root_dir = Path(ROOT_DIR)

p_analysis = p_root_dir  / "4_model_evaluation"
p_grids = p_root_dir / r"2_grid_calibration\grid_masks"
p_wellmap = p_root_dir / "5_data_analysis/wellmap.csv"

# Paths for evaluation
p_eval = p_analysis / "model_application"
p_metrics = p_eval / "metrics.csv"
p_images = p_eval / "images.hdf5"
p_model = p_analysis / 'model_evaluation' / 'model.json'

timestamp = datetime.now().strftime("%d-%b-%Y") 
p_predictions = p_root_dir / f"5_data_analysis/all_predictions_{timestamp}.csv"

# False positives/negatives - preview output from model prediction for molecules with known labels
# Note that all files in these directories are cleared before a prediction run
p_eval_fpos = p_eval / 'false_positives'
p_eval_fneg = p_eval / 'false_negatives'
p_eval_tpos = p_eval / 'true_positives'
p_eval_tneg = p_eval / 'true_negatives'
# Unknown positives/negatives - preview output from model prediction for molecules with no label
# Note that all files in these directories are cleared before a prediction run
p_eval_upos = p_eval / 'unknown_positives'
p_eval_uneg = p_eval / 'unknown_negatives'
# Manually labeled positives/negatives - Move preview files from any of the above directories into 
# these directories to add to the labelled data. Make sure to re-run the appropriate steps 
# in "Input data" to detect the changes
p_eval_lpos = p_eval / 'manual_label_positives'
p_eval_lneg = p_eval / 'manual_label_negatives'

# Directories for three-state positive/unsure/negative classification
p_tri_pos = p_eval / 'three-state' / 'positive'
p_tri_unk = p_eval / 'three-state' / 'unsure'
p_tri_neg = p_eval / 'three-state' / 'negative'

# METASPACE
database = ('Spotting_project_compounds-v9', 'feb2021')
fdr = 0.5

In [12]:
# Log into metaspace
sm = SMInstance(host='https://metaspace2020.eu')

if not sm.logged_in():
    # Using getpass here prevents the API key from being accidentally saved with this notebook.
    api_key = getpass.getpass(prompt='API key: ', stream=None)
    sm.login(api_key=api_key)

API key: ········


## Input data

In [8]:
# Get dataset IDs based on grid files 
dataset_ids = [x.stem[-20:] for x in p_grids.rglob("*.npy")]
dataset_paths = [x for x in p_grids.rglob("*.npy")]
dataset_names = [x.stem for x in p_grids.rglob("*.npy")]

In [14]:
# Images from METASPACE
# Ignore any warnings about connection pools in this step

p_eval.mkdir(parents=True, exist_ok=True)

# images = []
for i, ds_id in enumerate(dataset_ids):
    images = []
    print(f'Downloading images for {ds_id} ({i}/{len(dataset_ids)})')
    dataset = sm.dataset(id=ds_id)
    ds_tic_image = dataset.tic_image()
    for img in dataset.all_annotation_images(
        fdr=fdr, 
        database=database, 
        only_first_isotope=True, 
        scale_intensity=True, 
        hotspot_clipping=False
    ):
        # Exclude annotations with no first-isotopic-image
        if img[0] is not None:
            images.append({
                'dataset_id': ds_id,
                'formula': img.formula,
                'adduct': img.adduct,
                'neutral_loss': img.neutral_loss or '',
                'image': img[0],
                'tic_norm_image': np.nan_to_num(img[0] / ds_tic_image),  # nan_to_num replaces nan values with 0.0. This line will probably complain about division by zero but it can be ignored as it's fixed by the nan_to_num
            })
    images_df = pd.DataFrame(images)
    images_df.to_hdf(p_eval / f"images_{ds_id}.hdf5", key="df")
    print(f'Images for {ds_id} saved')
            
# images_df = pd.DataFrame(images)

Downloading images for 2021-06-18_10h37m54s (0/34)


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['dataset_id', 'formula', 'adduct', 'neutral_loss', 'image',
       'tic_norm_image'],
      dtype='object')]

  pytables.to_hdf(


Images for 2021-06-18_10h37m54s saved
Downloading images for 2021-06-21_12h29m16s (1/34)
Images for 2021-06-21_12h29m16s saved
Downloading images for 2021-09-13_21h08m40s (2/34)
Images for 2021-09-13_21h08m40s saved
Downloading images for 2021-09-13_21h02m20s (3/34)
Images for 2021-09-13_21h02m20s saved
Downloading images for 2021-06-18_10h41m59s (4/34)
Images for 2021-06-18_10h41m59s saved
Downloading images for 2021-06-21_12h32m53s (5/34)
Images for 2021-06-21_12h32m53s saved
Downloading images for 2021-06-18_10h45m03s (6/34)
Images for 2021-06-18_10h45m03s saved
Downloading images for 2021-06-21_12h34m54s (7/34)
Images for 2021-06-21_12h34m54s saved
Downloading images for 2021-07-10_00h16m39s (8/34)
Images for 2021-07-10_00h16m39s saved
Downloading images for 2021-07-10_00h13m11s (9/34)
Images for 2021-07-10_00h13m11s saved
Downloading images for 2021-06-18_10h49m47s (10/34)
Images for 2021-06-18_10h49m47s saved
Downloading images for 2021-06-21_12h38m55s (11/34)
Images for 2021-06-

In [9]:
# Load pre-saved individual images_df
list_of_dfs = []
for fpath in p_eval.rglob("*.hdf5"):
    print(f"Loading {fpath.name}")
    f = pd.read_hdf(fpath)
    list_of_dfs.append(f)
    
images_df = pd.concat(list_of_dfs)

Loading images_2021-06-18_10h37m54s.hdf5
Loading images_2021-06-18_10h41m59s.hdf5
Loading images_2021-06-18_10h45m03s.hdf5
Loading images_2021-06-18_10h49m47s.hdf5
Loading images_2021-06-18_10h52m02s.hdf5
Loading images_2021-06-18_10h54m38s.hdf5
Loading images_2021-06-18_10h58m25s.hdf5
Loading images_2021-06-18_10h59m48s.hdf5
Loading images_2021-06-18_11h04m19s.hdf5
Loading images_2021-06-18_11h09m13s.hdf5
Loading images_2021-06-21_12h29m16s.hdf5
Loading images_2021-06-21_12h32m53s.hdf5
Loading images_2021-06-21_12h34m54s.hdf5
Loading images_2021-06-21_12h38m55s.hdf5
Loading images_2021-06-21_12h41m08s.hdf5
Loading images_2021-06-21_12h45m04s.hdf5
Loading images_2021-06-21_12h48m21s.hdf5
Loading images_2021-06-21_12h50m44s.hdf5
Loading images_2021-06-21_12h54m04s.hdf5
Loading images_2021-06-21_12h59m59s.hdf5
Loading images_2021-06-21_15h10m30s.hdf5
Loading images_2021-06-23_23h19m02s.hdf5
Loading images_2021-07-10_00h13m11s.hdf5
Loading images_2021-07-10_00h16m39s.hdf5
Loading images_2

In [None]:
# list(zip(dataset_ids, dataset_paths))[0]

In [10]:
# Wellmap and grids
wellmap = pd.read_csv(p_wellmap)
grids = {
    ds_id: np.load(ds_p) 
    for ds_id, ds_p in zip(dataset_ids, dataset_paths)
}

In [None]:
# # Image labels from Quality_Labels.csv files
# labeled_anns = []
# for i in p_labelled_set.rglob("*Quality_Labels.csv"):
#     data = pd.read_csv(i)
#     data = data.loc[:, ['dataset_id', 'formula', 'adduct', 'neutral_loss', 'score', 'well']]
#     data.neutral_loss.fillna('', inplace=True)
#     labeled_anns.append(data)

# labeled_anns_df = pd.concat(labeled_anns)

### Import image labels from the manual_label directories

If you use these directories for labelling, re-run every cell from this point onwards

In [11]:
# Image labels from the "manual_label" directories
manual_labels = []
for score, labels_path in [(1, p_eval_lpos), (0, p_eval_lneg)]:
    labels_path.mkdir(parents=True, exist_ok=True)
    for f in labels_path.glob('*.png'):
        manual_labels.append({
            'filename': f.name,
            'manual_score': score,
        })
if manual_labels:
    manual_labels_df = pd.DataFrame(manual_labels)
else:
    manual_labels_df = pd.DataFrame({'filename': pd.Series(dtype=str), 'manual_score': pd.Series(dtype='i')})

In [12]:
# Combine them for easier access
merged_df = (
    images_df
    # Add `how=left` when merging with wellmap to include non-spotted formulas
    .merge(wellmap[['well', 'formula', 'name_short']], on=['formula'])
#     .merge(labeled_anns_df, on=['dataset_id', 'formula', 'adduct', 'neutral_loss', 'well'], how='left')
).reset_index()

merged_df['score'] = np.nan

merged_df['filename'] = [f'{row.dataset_id}_{row.formula}_{row.adduct}_{row.neutral_loss}_{row.well}.png' for row in merged_df.itertuples()]
merged_df = merged_df.merge(manual_labels_df, on='filename', how='left')

# Merge the "manual_score" column into "score"
merged_df['score'] = merged_df.manual_score.fillna(merged_df.score) # manual labels overwrite csv labels
del merged_df['manual_score']

merged_df['row_id'] = [f'{row.dataset_id}_{row.formula}_{row.adduct}_{row.neutral_loss}_{row.well}' for row in merged_df.itertuples()]  # You may want to customize this and add any other fields you feel are necessary to uniquely identify a scored image+well
merged_df = merged_df.set_index('row_id')

## Calculate metrics (or load pre-calculated)

In [13]:
#Calculate metrics
def calc_far_bg(mask, bg):
    """Gets mask for background pixels that are at least 4 radii away from the spot"""
    # 3 iterations = (1+3=)4x the spot radius
    expanded_spot = binary_dilation(mask, crop_zeros(mask), iterations=3)
    return bg & ~expanded_spot

def occ(px):
    """Calculates non-zero % of the given array"""
    return np.count_nonzero(px) / px.size


metrics = []
for row in merged_df.itertuples():
    grid = grids[row.dataset_id]
    
    mask = grid == row.well
    bg = grid == 0
    far_bg = calc_far_bg(mask, bg)
        
    in_mask = row.image[mask]
    in_mask_tic_norm = row.tic_norm_image[mask]
    in_bg = row.image[bg]
    in_far_bg = row.image[far_bg]
    in_other_spots = row.image[~bg & ~mask]
    
    # Calculate threshold (0.01 * 99th percentile) 
    # (note the image is already hotspot-removed, so the max is the 99th percentile)
    threshold = np.max(row.image) * 0.01

    metrics.append({
        'row_id': row[0],   # with .itertuples(), item[0] is the index
        # Original metrics
        # NOTE: The constant in the denominator of `on_off_ratio` was changed to
        # 0.001 as it seemed to produce slightly better results
        'occupancy_ratio': (occ(in_mask) * 100) / (occ(in_bg) * 100 + 1),
        'on_off_ratio': (np.mean(in_mask)) / (np.mean(in_bg) + 0.001),
        
        # Single-spot occupancy %
        'spot_occupancy': occ(in_mask),
        'spot_occupancy_thresholded': occ(in_mask > threshold),
        # Other occupancy metrics
        'image_occupancy': occ(row.image),
        'other_spots_occupancy': occ(in_other_spots),
        'bg_occupancy': occ(in_bg),
        'far_bg_occupancy': occ(in_bg),
        'occupancy_vs_far_bg_ratio' : (occ(in_mask) * 100) / (occ(in_far_bg) * 100 + 1),
        
        # How many spots have a non-zero pixel
        'in_n_spots': len(np.unique(grid[(grid != 0) & (row.image > threshold)])),
        
        # Intensity metrics
        'spot_intensity' : np.mean(in_mask),
        'spot_intensity_tic_norm': np.mean(in_mask_tic_norm),
        'spot_intensity_bgr_corrected' : np.mean(in_mask) - np.mean(in_far_bg),
        'spot_intensity_sum' : np.sum(in_mask),
        'spot_intensity_std' : np.std(in_mask),
        'other_spot_intensity': np.mean(in_other_spots),
        'bg_intensity' : np.mean(in_bg),
        'far_bg_intensity' : np.mean(in_far_bg),
         #Intensity ratios
        'intensity_vs_other_spots_ratio': np.mean(in_mask) / (np.mean(in_other_spots) + 0.001),
        'intensity_vs_far_bg_ratio': np.mean(in_mask) / (np.mean(in_far_bg) + 0.001),
    })

metrics_df = pd.DataFrame(metrics).set_index('row_id')
metrics_df.to_csv(p_metrics)
metrics_df.head()

Unnamed: 0_level_0,occupancy_ratio,on_off_ratio,spot_occupancy,spot_occupancy_thresholded,image_occupancy,other_spots_occupancy,bg_occupancy,far_bg_occupancy,occupancy_vs_far_bg_ratio,in_n_spots,spot_intensity,spot_intensity_tic_norm,spot_intensity_bgr_corrected,spot_intensity_sum,spot_intensity_std,other_spot_intensity,bg_intensity,far_bg_intensity,intensity_vs_other_spots_ratio,intensity_vs_far_bg_ratio
row_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
2021-06-18_10h37m54s_C4H6O5_-H__59,68.115942,237378900.0,0.681159,0.637681,0.002498,0.003392,0.0,0.0,68.115942,2,237378.921875,0.041146,237378.921875,16379150.0,289332.53125,23.36445,0.0,0.0,10159.398482,237378900.0
2021-06-18_10h37m54s_C4H6O5_-H_+H2_59,0.0,0.0,0.0,0.0,0.003129,0.008788,0.0,0.0,0.0,3,0.0,0.0,0.0,0.0,0.0,165.844162,0.0,0.0,0.0,0.0
2021-06-18_10h37m54s_C4H6O5_-H_-H2O_59,1.049917,3.872822,1.0,0.84058,0.892704,0.802421,0.942457,0.942457,1.05133,170,8511.894531,0.001464,6315.216309,587320.8,7861.288574,1859.231934,2197.852051,2196.678223,4.578175,3.874892
2021-06-18_10h37m54s_C4H6O5_-H_-CH2O2_59,0.99545,1.495195,1.0,0.884058,0.993852,0.992522,0.994571,0.994571,0.995353,189,1390.449829,0.000237,460.323975,95941.04,655.091492,1065.456055,929.944519,930.125854,1.305027,1.494903
2021-06-18_10h37m54s_C4H6O5_-H_-CO2_59,0.990267,1.027741,1.0,1.0,0.999808,0.999769,0.999829,0.999829,0.99027,189,10424.539062,0.001764,273.745117,719293.2,1288.304321,9895.628906,10143.155273,10150.793945,1.053449,1.026968


In [None]:
# Or import pre-calculated metrics
metrics_df = pd.read_csv(p_metrics, index_col=0)

In [14]:
metrics_df = metrics_df.merge(merged_df[['score']], left_index=True, right_index=True, how='left')

# metrics_df.drop(columns = ['occupancy_ratio', 'on_off_ratio', 
#        'other_spots_occupancy', 'bg_occupancy', 'far_bg_occupancy', 
#         'in_n_spots', 'intensity_vs_other_spots_ratio', 'spot_intensity'
#        ], inplace=True)
metrics_df.head()

Unnamed: 0_level_0,occupancy_ratio,on_off_ratio,spot_occupancy,spot_occupancy_thresholded,image_occupancy,other_spots_occupancy,bg_occupancy,far_bg_occupancy,occupancy_vs_far_bg_ratio,in_n_spots,...,spot_intensity_tic_norm,spot_intensity_bgr_corrected,spot_intensity_sum,spot_intensity_std,other_spot_intensity,bg_intensity,far_bg_intensity,intensity_vs_other_spots_ratio,intensity_vs_far_bg_ratio,score
row_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021-06-18_10h37m54s_C4H6O5_-H__59,68.115942,237378900.0,0.681159,0.637681,0.002498,0.003392,0.0,0.0,68.115942,2,...,0.041146,237378.921875,16379150.0,289332.53125,23.36445,0.0,0.0,10159.398482,237378900.0,
2021-06-18_10h37m54s_C4H6O5_-H_+H2_59,0.0,0.0,0.0,0.0,0.003129,0.008788,0.0,0.0,0.0,3,...,0.0,0.0,0.0,0.0,165.844162,0.0,0.0,0.0,0.0,
2021-06-18_10h37m54s_C4H6O5_-H_-H2O_59,1.049917,3.872822,1.0,0.84058,0.892704,0.802421,0.942457,0.942457,1.05133,170,...,0.001464,6315.216309,587320.8,7861.288574,1859.231934,2197.852051,2196.678223,4.578175,3.874892,
2021-06-18_10h37m54s_C4H6O5_-H_-CH2O2_59,0.99545,1.495195,1.0,0.884058,0.993852,0.992522,0.994571,0.994571,0.995353,189,...,0.000237,460.323975,95941.04,655.091492,1065.456055,929.944519,930.125854,1.305027,1.494903,
2021-06-18_10h37m54s_C4H6O5_-H_-CO2_59,0.990267,1.027741,1.0,1.0,0.999808,0.999769,0.999829,0.999829,0.99027,189,...,0.001764,273.745117,719293.2,1288.304321,9895.628906,10143.155273,10150.793945,1.053449,1.026968,


## Evaluate models

This section uses the calculated metrics and labeled data to train a set of models 
and find which features are best for predicting the labels. 
It uses two strategies for evaluation:

* Hold-out validation - this splits the labeled data into 80% for training, 20% for testing
* Cross-Validation - this uses the full labeled data, but trains 5 different models, each
    with a different combinations of inputs in the 80% training set, so that each input 
    can be tested by a model that didn't use that input as part of the training.
    This approach reports a much more numerically stable accuracy value it can use 
    the full input set for evaluation.
    However, it shouldn't be used for fine-tuning the model hyperparameters 
    (the input variables when constructing the model), as this can lead to overfitting.
    
   
The output is a DataFrame `eval_results_df` that shows for each model/# of features:
* Which combination of features worked best
* The accuracy/F1 scores
* The # of false positives & false negatives

In [None]:
# # Prepare input data
# input_df = metrics_df[~metrics_df.score.isna()]  # Exclude unlabeled rows
# input_df = input_df.sample(frac=1.0)  # Shuffle rows
# X = input_df.drop(columns=['score'])
# y = input_df.score.astype('i').values

# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

In [None]:
# # Models to try
# models_to_eval = [
#     CatBoostClassifier(verbose=False),
# #     LinearSVC(class_weight='balanced'),
# #     DecisionTreeClassifier(max_depth=3),
# #     BaggingClassifier(LinearSVC(), n_estimators=3, bootstrap_features=True),
# ]
# max_features_to_consider = 4

# eval_results = []

# for model in models_to_eval:
#     model_name = str(model)
#     for n_features in range(1, max_features_to_consider + 1):
#         print(model_name, n_features)
#         # SequentialFeatureSelector finds the set of N features that give the best scores
#         sfs = SequentialFeatureSelector(model, n_features_to_select=n_features, n_jobs=-1)
#         sfs.fit(X_train, y_train)
#         best_features = X.columns[sfs.support_]
        
#         # Evaluate using cross-validation
#         X_subset = X[best_features].values
#         fpos_idxs, fneg_idxs = get_mispredictions(model, X_subset, y)
#         # Use a repeating cross-validator so that results are averaged over ~50 runs
#         cv = RepeatedStratifiedKFold()
#         cv_scores = cross_validate(model, X_subset, y, cv=cv, scoring=['accuracy','f1'])
#         cv_accuracy = np.mean(cv_scores['test_accuracy'])
#         cv_f1 = np.mean(cv_scores['test_f1'])
        
#         # Evaluate using hold-out validation
#         trained_subset_model = clone(model).fit(X_train[best_features].values, y_train)
#         holdout_accuracy = trained_subset_model.score(X_test[best_features].values, y_test)
#         holdout_f1 = f1_score(y_test, trained_subset_model.predict(X_test[best_features].values))
        
#         eval_results.append({
#             'model': model_name,
#             'n_features': n_features,
#             'features': ', '.join(best_features),
#             'cv_accuracy': cv_accuracy,
#             'cv_f1': cv_f1,
#             'holdout_accuracy': holdout_accuracy,
#             'holdout_f1': holdout_f1,
#             'n_fpos': len(fpos_idxs),
#             'n_fneg': len(fneg_idxs),
#             # Uncomment to include the idxs of false positives/negatives to see which
#             # inputs are repeatedly mispredicted regardless of the model
#             # 'fpos_idxs': fpos_idxs,
#             # 'fneg_idxs': fneg_idxs,
#         })
        
# eval_results_df = pd.DataFrame(eval_results)

# eval_results_df

In [None]:
# # Show behavior of accuracy as number of features increases
# sns.lineplot(data=eval_results_df, x='n_features', y='cv_accuracy', hue='model')

In [None]:
# print(eval_results_df.features.iloc[0])
# print(eval_results_df.features.iloc[1])
# print(eval_results_df.features.iloc[2])
# print(eval_results_df.features.iloc[3])

## Examine results for a specific model

In [15]:
model = CatBoostClassifier(verbose=False)
features = ['spot_intensity_tic_norm', 'spot_occupancy', 'occupancy_vs_far_bg_ratio', 'intensity_vs_far_bg_ratio']

### Option A: Train a new model

In [None]:
# # Train the model on labeled data
# train_df = metrics_df[~metrics_df.score.isna()]  # Exclude unlabeled rows
# train_df = input_df.sample(frac=1.0)  # Shuffle rows
# X_df = train_df.drop(columns=['score'])[features]
# y = train_df.score.astype('i').values
# trained_model = clone(model).fit(X_df.values, y)

# # Make predictions for unlabeled data
# unlabeled_df = metrics_df[metrics_df.score.isna()][features]
# unlabeled_predictions_df = pd.DataFrame({
#     'pred_val': trained_model.predict_proba(unlabeled_df.values)[:, 1]
# }, index=unlabeled_df.index)

# # Make cross-validated predictions for labeled data
# labeled_predictions_df = pd.DataFrame({
#     'pred_val': cross_val_predict(model, X_df.values, y, method='predict_proba')[:, 1]
# }, index=X_df.index)

# # Combine predictions
# predictions_df = pd.concat([unlabeled_predictions_df, labeled_predictions_df])

### Option B: Load an existing model
Uses a saved model from the last step of this file

NOTE: This approach doesn't use cross-validated predictions for the labelled training data,
so it shouldn't be used for analyzing the model or refining the training set.

In [16]:
model.load_model(p_model, format='json')

# Make predictions for all data
predictions_df = pd.DataFrame({
    'pred_val': model.predict_proba(metrics_df[features].values)[:, 1]
}, index=metrics_df.index)

### Both options: Assign labels to predictions

In [17]:
# Make combined DF
output_df = merged_df.join(metrics_df.drop(columns='score')).join(predictions_df)

# Add two-state and three-state classes
output_df['pred_twostate'] = np.where(output_df.pred_val < 0.5, 0, 1)
unsure_range = [0.2, 0.8] # Lowest & highest values to include in the "unsure" class
# This assigns 0 = negative, 1 = unsure, 2 = positive
output_df['pred_threestate'] = np.digitize(output_df.pred_val, unsure_range)

### Write predictions CSV files

In [18]:
csv_df = output_df.drop(columns=['image', 'filename']) # Skip unwanted columns
csv_df.to_csv(p_predictions)

for dataset_id, results_df in csv_df.groupby('dataset_id'):
    output_path = p_eval / f'{dataset_id}_predictions.csv'
    results_df.to_csv(output_path)

### Write image files into false positives, false negatives, etc.

In [19]:
# Clean output directories
for output_path in [
    p_eval_fpos, p_eval_fneg, p_eval_tpos, p_eval_tneg, p_eval_upos, p_eval_uneg, 
    p_tri_pos, p_tri_unk, p_tri_neg
]:
    output_path.mkdir(parents=True, exist_ok=True)
    for f in output_path.glob('*.png'):
        f.unlink()  # Delete existing files

# Write images with two-state classification
for row in output_df.itertuples():
    mask = grids[row.dataset_id] == row.well
    
    # Figure out which directory to use
    if row.score == 0:
        twostate_path = [p_eval_tneg, p_eval_fpos][row.pred_twostate]
    elif row.score == 1:
        twostate_path = [p_eval_fneg, p_eval_tpos][row.pred_twostate]
    else:
        twostate_path = [p_eval_uneg, p_eval_upos][row.pred_twostate]
    
    save_image_with_mask(row.image, mask, twostate_path / row.filename)
    
# Write images with three-state classification
for row in output_df.itertuples():
    mask = grids[row.dataset_id] == row.well
    
    threestate_path = [p_tri_neg, p_tri_unk, p_tri_pos][row.pred_threestate]
    
    save_image_with_mask(row.image, mask, threestate_path / row.filename)