# Apply classification model

In [None]:
# Load IPython's autoreload extension in mode 2, so imported modules are
# reloaded before code is executed
%reload_ext autoreload
%autoreload 2
# Interactive figures; the commented alternative below gives static inline figures
%matplotlib notebook
# %matplotlib inline

In [None]:
import sys

from catboost import CatBoostClassifier, Pool
from scipy.ndimage import binary_dilation
from sklearn import clone
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, BaggingClassifier
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold, RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score, cross_val_predict, cross_validate
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from pathlib import Path
import numpy as np
import pandas as pd
from definitions import ROOT_DIR
import matplotlib.pyplot as plt
import seaborn as sns
from metaspace.sm_annotation_utils import SMInstance
from metaspace.image_processing import clip_hotspots

import getpass
from metaspace import SMInstance
from datetime import datetime

from matplotlib.colors import Normalize, LogNorm

In [None]:
# Suppress warnings, because many models spam them during feature selection
# as some subsets of features just don't have enough information to make
# a good model.
# Only scikit-learn's ConvergenceWarning category is silenced here; every
# other warning category is still reported.
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.simplefilter('ignore', ConvergenceWarning)

## Utility functions

In [None]:
def colorize_image_with_mask(image, mask):
    """
    Plotting function for combining a colorized ion image with a spot mask.

    The image is hotspot-clipped and scaled to the 0-1 range, then colorized
    twice: pixels where ``mask`` is True take their color from the cividis
    colormap, all other pixels from the magma colormap.

    Args:
        image: 2D array of ion intensities. It is never modified.
        mask: 2D boolean array with the same shape as ``image``; True marks
            the pixels belonging to the spot.

    Returns:
        Array with one color tuple per pixel, as produced by the matplotlib
        colormaps (the color channels are the last axis).
    """
    # Work on a float copy: clip_hotspots is not guaranteed to return a new
    # array, so dividing in place could write into the caller's image, and
    # in-place true division fails outright for integer images.
    image = np.array(clip_hotspots(image), dtype=float)

    # An all-zero image has nothing to scale; dividing would give 0/0 = NaN
    peak = np.max(image)
    if peak > 0:
        image = image / peak

    on_spot_colorized = plt.cm.cividis(image)
    off_spot_colorized = plt.cm.magma(image)
    return np.where(mask[:, :, np.newaxis], on_spot_colorized, off_spot_colorized)
    
def save_image_with_mask(image, mask, fname):
    """Colorize `image` with its spot mask highlighted and write the result to `fname`."""
    colorized = colorize_image_with_mask(image, mask)
    plt.imsave(fname, colorized)

In [None]:
def crop_zeros(img):
    """
    Return the smallest sub-image that still contains every non-zero pixel.

    Rows and columns at the outer edges that hold only zeros are trimmed away;
    empty rows/columns between non-zero pixels are kept. An image without any
    non-zero pixel raises IndexError, because there is no first occupied
    row/column to anchor the crop.
    """
    occupied_cols = np.flatnonzero(np.any(img, axis=0))
    occupied_rows = np.flatnonzero(np.any(img, axis=1))

    row_span = slice(occupied_rows[0], occupied_rows[-1] + 1)
    col_span = slice(occupied_cols[0], occupied_cols[-1] + 1)
    return img[row_span, col_span]

In [None]:
def get_mispredictions(model, X, y):
    """
    Find which values would be mispredicted, returning two index arrays:
        * indexes of items that would be falsely predicted as positives
        * indexes of items that would be falsely predicted as negatives

    Predictions come from cross_val_predict with its default ``cv`` setting, so
    every item is predicted by a model that was fitted on folds that do not
    contain it.

    NOTE(review): per the scikit-learn documentation the default is a 5-fold
    split that is stratified for classifiers but NOT shuffled - shuffle the
    data beforehand if the row order is meaningful.

    Args:
        model: estimator accepted by cross_val_predict.
        X: feature matrix.
        y: binary labels (booleans or 0/1 values), one per row of X.

    Returns:
        Tuple ``(fpos_idxs, fneg_idxs)`` of positional index arrays.
    """
    preds = cross_val_predict(model, X, y)
    mispreds = np.asarray(preds != y)
    # Coerce the labels to booleans so `~` is a logical NOT for 0/1 integer
    # or float labels too (bitwise NOT is undefined for float arrays).
    is_positive = np.asarray(y).astype(bool)
    fpos_idxs = np.flatnonzero(mispreds & ~is_positive)
    fneg_idxs = np.flatnonzero(mispreds & is_positive)

    return fpos_idxs, fneg_idxs

## Paths

In [None]:
p_root_dir = Path(ROOT_DIR)
p_proj_root_dir = p_root_dir.parents[0]
p_matrix_root_dir = p_proj_root_dir / "matrix_comparison"
p_matrix_data = p_matrix_root_dir / "5_data_analysis"

p_data = p_root_dir / "5_data_analysis"
p_analysis = p_root_dir / "4_model_evaluation"
p_grids = p_root_dir / "2_grid_calibration" / "grid_masks" / "20labs_all"
p_wellmap = p_matrix_data / "wellmap.csv"

# Paths for evaluation
p_eval = p_analysis / "model_application"
p_pickles = p_eval / "pickles"
p_metrics = p_eval / "metrics"
p_apply = p_analysis / "model_application_best_replicates"
p_images = p_eval / "images.hdf5"
# Built from separate path components: backslashes inside a normal string
# literal are escape sequences and only act as separators on Windows.
p_model = p_matrix_root_dir / "3_model_evaluation" / "model_evaluation" / "model.json"
p_datasets = p_root_dir / "5_data_analysis" / "Datasets_modified.csv"
p_metadata = p_matrix_data / "Datasets_14Jul2022.csv"
p_chem_class = p_matrix_data / "custom_classification_v2.csv"

# Output file names carry the date of the run, e.g. 18-Feb-2022
timestamp = datetime.now().strftime("%d-%b-%Y")
p_predictions = p_eval / f"all_predictions_{timestamp}.csv"
p_predictions_curated = p_apply / f"all_predictions_curated_{timestamp}.csv"

# False positives/negatives - preview output from model prediction for molecules with known labels
# NOTE(review): the cleanup step of this notebook has these four directories commented out,
# so their files are currently NOT cleared before a prediction run
p_eval_fpos = p_eval / 'false_positives'
p_eval_fneg = p_eval / 'false_negatives'
p_eval_tpos = p_eval / 'true_positives'
p_eval_tneg = p_eval / 'true_negatives'
# Unknown positives/negatives - preview output from model prediction for molecules with no label
# Note that all files in these directories are cleared before a prediction run
p_eval_upos = p_eval / 'unknown_positives'
p_eval_uneg = p_eval / 'unknown_negatives'
# Manually labeled positives/negatives - Move preview files from any of the above directories into
# these directories to add to the labelled data. Make sure to re-run the appropriate steps
# in "Input data" to detect the changes
p_eval_lpos = p_eval / 'manual_label_positives'
p_eval_lneg = p_eval / 'manual_label_negatives'
# Same manual-label directories, but located under the "best replicates" application directory
p_apply_lpos = p_apply / 'manual_label_positives'
p_apply_lneg = p_apply / 'manual_label_negatives'
# Directories for three-state positive/unsure/negative classification
p_tri_pos = p_eval / 'three-state' / 'positive'
p_tri_unk = p_eval / 'three-state' / 'unsure'
p_tri_neg = p_eval / 'three-state' / 'negative'

# METASPACE
database = ('Spotting_project_compounds-v9', 'feb2021')  # (name, version) of the annotation database
fdr = 0.5  # FDR value passed to the annotation image query

print(timestamp)

In [None]:
# Connect to METASPACE; an API key is only requested when the client
# does not already report an active login
sm = SMInstance(host='https://metaspace2020.eu')

already_logged_in = sm.logged_in()
if not already_logged_in:
    # getpass keeps the key out of the notebook output, so it cannot be
    # saved together with the notebook by accident.
    sm.login(api_key=getpass.getpass(prompt='API key: ', stream=None))

## Input data

In [None]:
# Get dataset IDs based on grid files
datasets = pd.read_csv(p_datasets)
# Scan the grid directory once and derive names/stems from that single listing,
# so the three lists are guaranteed to line up index-for-index (stems and paths
# are zipped together when the grids are loaded).
dataset_paths = list(p_grids.glob("*.npy"))
dataset_names = [grid_path.stem for grid_path in dataset_paths]
# The last 20 characters of the file name are used as the grid key
# (presumably a METASPACE dataset ID such as 2022-02-18_23h32m33s - TODO confirm)
dataset_stems = [name[-20:] for name in dataset_names]
dataset_ids = datasets['Clone ID']
dataset_new_ids = datasets['20 Labs ID']

In [None]:
# Work out which datasets still lack a pickle file. A pickle whose name ends
# with the dataset ID is trusted to be complete and correct - no content check!
pickles = [pkl_path.stem[-20:] for pkl_path in p_pickles.glob("*.pkl")]
to_download = [ds_id for ds_id in dataset_new_ids if ds_id not in pickles]
for ds_id in to_download:
    print(ds_id)

In [None]:
# Images from METASPACE
# Ignore any warnings about connection pools in this step

# to_pickle does not create missing directories, so make sure the pickle
# directory exists (not only its parent)
p_eval.mkdir(parents=True, exist_ok=True)
p_pickles.mkdir(parents=True, exist_ok=True)

pending_ids = set(to_download)  # O(1) membership test inside the loop
for i, ds_id in enumerate(dataset_new_ids):
    if ds_id not in pending_ids:
        continue

    print(f'Downloading images for {ds_id} ({i}/{len(dataset_ids)-1})')
    dataset = sm.dataset(id=ds_id)
    ds_tic_image = dataset.tic_image()

    # One list (and one pickle file) per dataset
    images = []
    for img in dataset.all_annotation_images(
        fdr=fdr,
        database=database,
        only_first_isotope=True,
        scale_intensity=True,
        hotspot_clipping=False
    ):
        # Exclude annotations with no first-isotopic-image
        if img[0] is None:
            continue

        # Pixels with a zero TIC divide by zero. The resulting warnings are
        # expected, so they are silenced here; nan_to_num then replaces NaN
        # (0/0) with 0.0. Note that +/-inf (x/0) is replaced with the
        # largest finite float values, not with 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            tic_norm_image = np.nan_to_num(img[0] / ds_tic_image)

        images.append({
            'dataset_id': ds_id,
            'formula': img.formula,
            'adduct': img.adduct,
            'neutral_loss': img.neutral_loss or '',
            'image': img[0],
            'tic_norm_image': tic_norm_image,
        })

    images_df = pd.DataFrame(images)
    images_df.to_pickle(p_pickles / f"images_{ds_id}.pkl")
    print(f'Images for {ds_id} saved')

In [None]:
# Read the well map and load every grid mask array, keyed by its dataset stem
wellmap = pd.read_csv(p_wellmap)
grids = dict(zip(dataset_stems, map(np.load, dataset_paths)))

In [None]:
# Sanity check - do we have all the data that we have grids for?
# NOTE(review): `grids` is built from `dataset_stems`, so this comparison looks
# like it can only fail if the stem and path lists differ in length - confirm
# whether it was meant to compare against the dataset IDs instead.
set(grids) == set(dataset_stems)

## Calculate metrics (or load pre-calculated)

In [None]:
# Calculate metrics
def calc_far_bg(mask, bg):
    """
    Return the mask of background pixels that lie well away from the spot.

    The spot mask is grown by dilating it three times with its own cropped
    shape as the structuring element (intended to reach about 4x the spot
    radius: the spot itself plus 3 dilations). Background pixels inside the
    grown region are excluded from the result.
    """
    spot_shape = crop_zeros(mask)
    grown_spot = binary_dilation(mask, structure=spot_shape, iterations=3)
    return ~grown_spot & bg

def occ(px):
    """
    Calculate the fraction (0.0 - 1.0) of non-zero elements in the given array.

    Args:
        px: array (or array-like) of pixel values or booleans.

    Returns:
        Number of non-zero elements divided by the total number of elements.
        An empty selection has no defined occupancy and yields NaN, in line
        with what np.mean returns for an empty array, instead of raising
        ZeroDivisionError.
    """
    px = np.asarray(px)
    if px.size == 0:
        return float('nan')
    return np.count_nonzero(px) / px.size

def calculate_metrics(merged_df, grids, dataset_ids, dataset_new_ids, path):
    """
    Calculate spot metrics for every (annotation, well) row and save them as CSV.

    Args:
        merged_df: DataFrame with one row per annotation image and well. The
            columns dataset_id, name_short, formula, adduct, neutral_loss,
            well, image and tic_norm_image are read; the index becomes the
            `row_id` of the output.
        grids: mapping of grid key -> 2D array in which 0 marks background
            pixels and any other value is the well number of that pixel.
        dataset_ids: grid keys, positionally paired with `dataset_new_ids`.
        dataset_new_ids: dataset IDs as they appear in `merged_df.dataset_id`.
            IDs that are not listed here are used as grid keys directly.
        path: destination of the CSV file.

    Rows whose well does not occur in the grid are reported and skipped. If
    no row produces metrics, nothing is written.
    """
    # Resolve dataset ID -> grid key once instead of scanning the ID list for
    # every row. setdefault keeps the first pairing if an ID is listed twice.
    grid_key_by_new_id = {}
    for new_id, old_id in zip(dataset_new_ids, dataset_ids):
        grid_key_by_new_id.setdefault(new_id, old_id)

    lasterror = ""
    metrics = []

    for progress, row in enumerate(merged_df.itertuples(), start=1):
        if progress % 1000 == 0:
            print(progress)

        grid = grids[grid_key_by_new_id.get(row.dataset_id, row.dataset_id)]

        mask = grid == row.well
        bg = grid == 0

        # Catch missing wells: without a single mask pixel there is no spot to
        # measure. Consecutive identical messages are printed only once.
        if not np.any(mask):
            error = f"Missing well: {row.dataset_id} #{row.well}"
            if error != lasterror:
                print(error)
                lasterror = error
            continue

        far_bg = calc_far_bg(mask, bg)
        other_spots = ~bg & ~mask

        in_mask = row.image[mask]
        in_bg = row.image[bg]
        in_far_bg = row.image[far_bg]
        in_other_spots = row.image[other_spots]

        # tic image
        in_mask_tic_norm = row.tic_norm_image[mask]
        in_bg_tic_norm = row.tic_norm_image[bg]
        in_far_bg_tic_norm = row.tic_norm_image[far_bg]
        in_other_spots_tic_norm = row.tic_norm_image[other_spots]

        # Threshold = 1% of the image maximum.
        # NOTE(review): this was meant to be 0.01 * 99th percentile, assuming a
        # hotspot-removed image. The download step of this notebook requests
        # images with hotspot_clipping=False, so the maximum may be a hotspot -
        # confirm which threshold is intended.
        threshold = np.max(row.image) * 0.01

        # Values that feed several metrics are calculated once
        spot_occ = occ(in_mask)
        bg_occ = occ(in_bg)
        far_bg_occ = occ(in_far_bg)

        spot_mean = np.mean(in_mask)
        bg_mean = np.mean(in_bg)
        far_bg_mean = np.mean(in_far_bg)
        other_spots_mean = np.mean(in_other_spots)

        spot_mean_tic = np.mean(in_mask_tic_norm)
        far_bg_mean_tic = np.mean(in_far_bg_tic_norm)
        other_spots_mean_tic = np.mean(in_other_spots_tic_norm)

        metrics.append({
            'row_id': row[0],   # with .itertuples(), item[0] is the index
            'dataset_id' : row.dataset_id,
            'name_short' : row.name_short,
            'formula' : row.formula,
            'adduct' : row.adduct,
            'neutral_loss' : row.neutral_loss,
            'well' : row.well,
            # Original metrics
            # NOTE: The constant in the denominator of `on_off_ratio` was changed to
            # 0.001 as it seemed to produce slightly better results
            'occupancy_ratio': (spot_occ * 100) / (bg_occ * 100 + 1),
            'on_off_ratio': spot_mean / (bg_mean + 0.001),

            # Single-spot occupancy (fraction of non-zero pixels)
            'spot_occupancy': spot_occ,
            'spot_occupancy_thresholded': occ(in_mask > threshold),
            # Other occupancy metrics
            'image_occupancy': occ(row.image),
            'other_spots_occupancy': occ(in_other_spots),
            'bg_occupancy': bg_occ,
            # Fixed: this used to repeat the plain background occupancy
            'far_bg_occupancy': far_bg_occ,
            'occupancy_vs_far_bg_ratio' : (spot_occ * 100) / (far_bg_occ * 100 + 1),

            # How many spots have at least one pixel above the threshold
            'in_n_spots': len(np.unique(grid[(grid != 0) & (row.image > threshold)])),

            # Intensity ratios
            'spot_intensity' : spot_mean,
            'spot_intensity_bgr_corrected' : spot_mean - far_bg_mean,
            'spot_intensity_sum' : np.sum(in_mask),
            'spot_intensity_std' : np.std(in_mask),
            'other_spot_intensity': other_spots_mean,
            'bg_intensity' : bg_mean,
            'far_bg_intensity' : far_bg_mean,
            'intensity_vs_far_bg_ratio': spot_mean / (far_bg_mean + 0.001),
            'intensity_vs_other_spots_ratio': spot_mean / (other_spots_mean + 0.001),

            # Intensity ratios for TIC normalised
            'spot_intensity_tic_norm': spot_mean_tic,
            'spot_intensity_bgr_corrected_tic_norm' : spot_mean_tic - far_bg_mean_tic,
            'spot_intensity_sum_tic_norm' : np.sum(in_mask_tic_norm),
            'spot_intensity_std_tic_norm' : np.std(in_mask_tic_norm),
            'other_spot_intensity_tic_norm': other_spots_mean_tic,
            'bg_intensity_tic' : np.mean(in_bg_tic_norm),
            'far_bg_intensity_tic' : far_bg_mean_tic,
            'intensity_vs_far_bg_ratio_tic': spot_mean_tic / (far_bg_mean_tic + 0.001),
            'intensity_vs_other_spots_ratio_tic': spot_mean_tic / (other_spots_mean_tic + 0.001),
        })

    # Without any rows there is no 'row_id' column to index by
    if not metrics:
        print(f"No metrics calculated, nothing written to {path}")
        return

    metrics_df = pd.DataFrame(metrics).set_index('row_id')
    metrics_df.to_csv(path)

In [None]:
# Load pre-saved individual images_df and generate metrics
# Start and end points allow running a subset of pickles (for example when adding more data to the project)
start = 0
end = 200

# to_csv does not create missing directories
p_metrics.mkdir(parents=True, exist_ok=True)

# NOTE(review): `pickles` is the list that was built when checking for missing
# downloads. Pickles downloaded after that cell ran are not part of it until
# the cell is re-run - confirm that skipping them here is intended.
#
# `count` is the position of the file in the directory listing. enumerate
# advances it for every file, including files that are skipped or fail to load.
for count, fpath in enumerate(p_eval.rglob("*.pkl")):
    if fpath.stem[-20:] not in pickles or not (start <= count <= end):
        continue

    print(f"Loading {fpath.name}")
    try:
        # Only load pickle files written by this notebook: unpickling can
        # execute arbitrary code.
        f = pd.read_pickle(fpath)
    except Exception as err:
        # Best effort: report the file and carry on with the remaining ones
        print(f"Failed to load {fpath.name}: {err!r}")
        continue

    # One row per (annotation image, well containing that formula)
    merged_df = f.merge(wellmap[['well', 'formula', 'name_short']], on=['formula']).reset_index()
    merged_df['row_id'] = [f'{row.dataset_id}_{row.formula}_{row.adduct}_{row.neutral_loss}_{row.well}' for row in merged_df.itertuples()]
    merged_df = merged_df.set_index('row_id')
    print(merged_df['dataset_id'].unique())
    calculate_metrics(merged_df, grids, dataset_ids, dataset_new_ids, p_metrics / f"Metrics_{timestamp}_{count}.csv")

In [None]:
# Load multiple metrics files and join
metrics_list = []
for fpath in p_metrics.rglob("*.csv"):
    try:
        metrics_list.append(pd.read_csv(fpath, index_col=0))
    except Exception as err:
        # Best effort: report the unreadable file (with the reason) and continue
        print(f"Failed to load {fpath.name}: {err!r}")

# pd.concat fails with a generic message on an empty list; name the directory instead
if not metrics_list:
    raise ValueError(f"No metrics CSV files could be loaded from {p_metrics}")

# The first CSV column (row_id) stays the index of the combined frame
metrics_df = pd.concat(metrics_list)
# No known labels at this point; the column is filled with NaN
metrics_df['score'] = np.nan

## Examine results for a specific model

In [None]:
# Metric columns handed to the classifier, in this order
features = ['spot_intensity_tic_norm', 'spot_occupancy', 'occupancy_vs_far_bg_ratio', 'intensity_vs_far_bg_ratio', 'intensity_vs_other_spots_ratio']

# Restore the previously saved CatBoost model from its JSON export
model = CatBoostClassifier(verbose=False)
model.load_model(p_model, format='json')

# Make predictions for all data: column 1 of predict_proba is used as the prediction value
feature_matrix = metrics_df[features].values
positive_proba = model.predict_proba(feature_matrix)[:, 1]
predictions_df = pd.DataFrame({'pred_val': positive_proba}, index=metrics_df.index)

### Both options: Assign labels to predictions

In [None]:
 # Make combined DF
output_df = metrics_df.join(predictions_df)
# Keep only rows that received a prediction. pred_val is a probability, so
# dropping NaN is what the former `pred_val > -1` filter did. dropna returns a
# new frame, so the column assignments below do not write into a slice of the
# joined frame (which triggers pandas' SettingWithCopyWarning).
output_df = output_df.dropna(subset=['pred_val'])

# Add two-state and three-state classes
output_df['pred_twostate'] = np.where(output_df.pred_val < 0.5, 0, 1)
# Bin edges for np.digitize: the "unsure" class covers 0.2 <= pred_val < 0.8
unsure_range = [0.2, 0.8]
# This assigns 0 = negative, 1 = unsure, 2 = positive
output_df['pred_threestate'] = np.digitize(output_df.pred_val, unsure_range)

### Write predictions CSV files

In [None]:
# Manually merging results from partially overlapping images.
# Simple way by reassigning all ids to the primary image, then removing any duplicates (from the overlap)
overlapping_to_primary = [
    ('2022-02-18_23h32m33s', '2022-02-18_23h32m31s'),
    ('2022-02-18_23h32m37s', '2022-02-18_23h32m35s'),
    ('2022-02-18_23h32m46s', '2022-02-18_23h32m44s'),
]
csv_df = output_df
for overlapping_id, primary_id in overlapping_to_primary:
    # Replaces every cell that equals the overlapping ID, in any column
    csv_df = csv_df.replace(overlapping_id, primary_id)

csv_df = csv_df.drop_duplicates(subset=['dataset_id','well','formula','adduct','neutral_loss'])

# Trim overlapping annotations in unmerged dataset pairs, keep highest pred_val

datasets = pd.read_csv(p_metadata)
datasets_info = datasets.groupby('Dataset ID').first()[['Polarity', 'Participant lab', 'Slide code', 'All', 'EMBL', 'Interlab', 'Technology', 'Matrix short']] # 'Participant lab', 'Technology'
# Datasets that share slide, technology and matrix get the same sample name
datasets_info['sample_name'] = datasets_info['Slide code'] + ': ' + datasets_info['Technology'] + ': ' + datasets_info['Matrix short']
df = pd.merge(csv_df, datasets_info[['sample_name', 'Polarity']], left_on='dataset_id', right_on='Dataset ID', how='left')
# sort_values returns a new frame, so the result has to be assigned. Without
# the assignment keep='first' kept an arbitrary row instead of the one with the
# highest pred_val. The stable sort keeps tied rows in their original order.
df = df.sort_values(by='pred_val', ascending=False, kind='stable')
# NOTE(review): rows of datasets that are missing from the metadata file have a
# NaN sample_name/Polarity and are treated as one and the same sample here -
# confirm that this cannot drop wanted rows.
df = df.drop_duplicates(subset=['sample_name','well','formula','adduct','neutral_loss', 'Polarity'], keep='first')
# Remove the helper columns and restore the row order from before the sort
df = df.drop(columns=['sample_name', 'Polarity']).sort_index()

csv_df = df

In [None]:
#csv_df = output_df.drop(columns=['image', 'filename']) # Skip unwanted columns

# One combined file with every prediction ...
csv_df.to_csv(p_predictions)

# ... plus one file per dataset, named after the dataset ID
per_dataset = csv_df.groupby('dataset_id')
for dataset_id, results_df in per_dataset:
    results_df.to_csv(p_eval / f'{dataset_id}_predictions.csv')

### Write image files into false positives, false negatives, etc.

In [None]:
# The previews need the ion image and a target file name for every row. Check
# this BEFORE deleting anything, so that a run which cannot produce images
# does not wipe the previews of the previous run.
missing_columns = [col for col in ('image', 'filename') if col not in output_df.columns]
if missing_columns:
    raise KeyError(f"output_df is missing the columns needed to write images: {missing_columns}")

# Clean output directories
# (the true/false positive/negative directories are left untouched)
for output_path in [
    p_eval_upos, p_eval_uneg,
    p_tri_pos, p_tri_unk, p_tri_neg
]:
    output_path.mkdir(parents=True, exist_ok=True)
    for f in output_path.glob('*.png'):
        f.unlink()  # Delete existing files

# Resolve dataset ID -> grid key once instead of scanning the ID list for every
# row. setdefault keeps the first pairing if an ID is listed twice.
grid_key_by_new_id = {}
for new_id, old_id in zip(dataset_new_ids, dataset_ids):
    grid_key_by_new_id.setdefault(new_id, old_id)

# Target directories, indexed by predicted class. Every row goes to the
# "unknown" two-state directories; the score-based split into true/false
# positives/negatives is disabled.
twostate_dirs = [p_eval_uneg, p_eval_upos]
threestate_dirs = [p_tri_neg, p_tri_unk, p_tri_pos]

# Write each image once per classification scheme
for row in output_df.itertuples():
    grid = grids[grid_key_by_new_id.get(row.dataset_id, row.dataset_id)]
    mask = grid == row.well

    for target_dir in (twostate_dirs[row.pred_twostate], threestate_dirs[row.pred_threestate]):
        try:
            save_image_with_mask(row.image, mask, target_dir / row.filename)
        except Exception as err:
            # Best effort: report the failing image and continue with the rest.
            # A failure is never followed by a write with a stale mask.
            print(f"Error writing {row.filename} to {target_dir.name}: {err!r}")
