In [None]:
# This notebook was adapted from the original version by L. Stuart

# Apply classification model

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib notebook

In [None]:
import sys
from catboost import CatBoostClassifier, Pool
from scipy.ndimage import binary_dilation
from sklearn import clone
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, BaggingClassifier
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold, RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score, cross_val_predict, cross_validate
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from metaspace.sm_annotation_utils import SMInstance
from metaspace.image_processing import clip_hotspots
import getpass
from metaspace import SMInstance
from datetime import datetime
from matplotlib.colors import Normalize, LogNorm

In [None]:
# Silence ConvergenceWarning: many models emit it repeatedly during feature selection,
# because some subsets of features just don't carry enough information to make a good model.
import warnings
from sklearn.exceptions import ConvergenceWarning

warnings.simplefilter(action='ignore', category=ConvergenceWarning)

## Utility functions

In [None]:
def colorize_image_with_mask(image, mask):
    """Colorize an ion image, using one colormap inside a spot mask and another outside it.

    The image is hotspot-clipped with ``clip_hotspots`` and scaled so its maximum is 1.
    Pixels where ``mask`` is True are colored with ``cividis``; all others with ``magma``.

    Args:
        image: 2D ion image. It is not modified.
        mask: boolean array with the same 2D shape as ``image``; True marks the spot.

    Returns:
        Array of shape ``image.shape + (4,)`` holding the RGBA colors.
    """
    clipped = clip_hotspots(image)
    # Divide out-of-place: clip_hotspots may hand back its input array unchanged, in which case
    # an in-place ``/=`` would overwrite the caller's image (it also fails on integer arrays).
    peak = np.max(clipped)
    # A blank (non-positive maximum) image would otherwise become all-NaN (0 / 0); keep it at zero.
    scaled = clipped / peak if peak > 0 else np.zeros(np.shape(clipped), dtype=float)

    on_spot_colorized = plt.cm.cividis(scaled)
    off_spot_colorized = plt.cm.magma(scaled)
    # Broadcast the 2D mask across the trailing RGBA channel axis.
    return np.where(mask[:, :, np.newaxis], on_spot_colorized, off_spot_colorized)

def make_a_panel(image, tic_image, mask):
    """Build a side-by-side RGBA preview of one ion image.

    Four tiles of the image's size are laid out left to right, separated by
    one-pixel all-zero gutters:
      1. ``colorize_image_with_mask`` output (cividis on the spot, magma elsewhere)
      2. hotspot-clipped image, linearly scaled to 0-1, cividis
      3. ``image + 1``, log-scaled to 0-1, cividis
      4. hotspot-clipped ``tic_image``, linearly scaled to 0-1, cividis

    Returns:
        Float array of shape (height, 4 * width + 3, 4); the last axis is RGBA.
    """
    height, width = image.shape[:2]
    gutter = 1

    tiles = [
        colorize_image_with_mask(image, mask),
        # Normalize: matplotlib tool that scales everything to the 0-1 range
        plt.cm.cividis(Normalize()(clip_hotspots(image))),
        # LogNorm: same, but applies a log transform before scaling to 0-1
        plt.cm.cividis(LogNorm()(image + 1)),
        plt.cm.cividis(Normalize()(clip_hotspots(tic_image))),
    ]

    panel_width = width * len(tiles) + gutter * (len(tiles) - 1)
    panel = np.zeros(shape=(height, panel_width, 4))  # trailing 4 = per-pixel RGBA
    for tile_idx, tile in enumerate(tiles):
        left = tile_idx * (width + gutter)
        panel[:, left:left + width] = tile
    return panel
    
    
def save_image_with_mask(image, mask, fname, tic_image):
    """Render the four-tile preview panel for ``image`` and write it to ``fname``."""
    panel = make_a_panel(image, tic_image, mask)
    plt.imsave(fname, panel)

In [None]:
def crop_zeros(img):
    """Crop a 2D image to the bounding box of its non-zero pixels.

    All outer rows/columns that contain only zeros are removed; empty rows/columns
    in the interior of the bounding box are kept.

    Args:
        img: 2D array.

    Returns:
        A view of ``img`` restricted to its non-zero bounding box. If ``img`` has no
        non-zero pixels at all, an empty ``(0, 0)`` view is returned (instead of
        raising IndexError).
    """
    rows = np.flatnonzero(np.count_nonzero(img, axis=1))
    cols = np.flatnonzero(np.count_nonzero(img, axis=0))
    if rows.size == 0:
        return img[:0, :0]
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

In [None]:
def get_mispredictions(model, X, y, cv=None):
    """
    Find which items would be mispredicted by out-of-fold predictions.

    Returns two arrays of positional indexes into ``X``/``y``:
        * items that would be falsely predicted as positives
        * items that would be falsely predicted as negatives

    ``cross_val_predict`` splits the data into folds (by default 5, stratified for
    classifiers and NOT shuffled) and predicts each fold with a clone of ``model``
    trained on the remaining folds, so no item is predicted by a model that saw it.

    Args:
        model: unfitted scikit-learn compatible classifier.
        X: feature matrix.
        y: binary labels (bool, 0/1 ints or 0.0/1.0 floats).
        cv: optional cross-validation strategy forwarded to ``cross_val_predict``.
            ``None`` keeps scikit-learn's default; pass e.g.
            ``StratifiedKFold(5, shuffle=True, random_state=0)`` for a shuffled split.
    """
    preds = cross_val_predict(model, X, y, cv=cv)
    # Compare as booleans: ``~`` is a bitwise NOT on integer labels and raises
    # TypeError on float labels (e.g. a 0.0/1.0 score column).
    y_true = np.asarray(y).astype(bool)
    y_pred = np.asarray(preds).astype(bool)
    mispreds = y_pred != y_true
    fpos_idxs = np.flatnonzero(mispreds & ~y_true)
    fneg_idxs = np.flatnonzero(mispreds & y_true)

    return fpos_idxs, fneg_idxs

## Paths

In [None]:
# Project root: the parent of the directory this notebook is run from
p_root_dir = Path.cwd().parents[0]

p_analysis = p_root_dir / "4_apply_classifier"
p_grids = p_root_dir.joinpath("2_grid_calibration", "grid_masks")
p_wellmap = p_root_dir.joinpath("5_data", "metadata", "wellmap.csv")

# Paths for model application
p_model = p_root_dir.joinpath("3_train_classifier", "model_evaluation", "model.json")
p_apply = p_analysis / "model_application_best_replicates"
p_images = p_apply / "images.hdf5"
p_datasets = p_apply / "manual_dataset_qc.csv"

# Output CSVs carry the run date (e.g. "08-Dec-2021"), so a run only overwrites outputs from the same day
timestamp = datetime.now().strftime("%d-%b-%Y")
p_predictions, p_predictions_curated, p_metrics = (
    p_apply / f"{stem}_{timestamp}.csv"
    for stem in ("all_predictions", "all_predictions_curated", "metrics")
)

# Two-state preview output for molecules with known labels (true/false positives/negatives).
# Note that all files in these directories are cleared before a prediction run
p_apply_fpos = p_apply.joinpath('false_positives')
p_apply_fneg = p_apply.joinpath('false_negatives')
p_apply_tpos = p_apply.joinpath('true_positives')
p_apply_tneg = p_apply.joinpath('true_negatives')
# Two-state preview output for molecules with no label.
# Note that all files in these directories are cleared before a prediction run
p_apply_upos = p_apply.joinpath('unknown_positives')
p_apply_uneg = p_apply.joinpath('unknown_negatives')
# Manually labelled positives/negatives: move preview files from any of the directories above
# into these directories to add them to the labelled data, then re-run the appropriate steps
# in "Input data" so the changes are picked up
p_apply_lpos = p_apply.joinpath('manual_label_positives')
p_apply_lneg = p_apply.joinpath('manual_label_negatives')

# Directories for three-state positive/unsure/negative classification
p_tri_pos = p_apply.joinpath('three-state', 'positive')
p_tri_unk = p_apply.joinpath('three-state', 'unsure')
p_tri_neg = p_apply.joinpath('three-state', 'negative')

# METASPACE annotation query settings (used when downloading annotation images)
database = ('Spotting_project_compounds-v9', 'feb2021')  # molecular database as (name, version)
fdr = 0.5  # FDR cutoff for annotations

In [None]:
# Log into METASPACE (the API key is only requested when no session is active)
sm = SMInstance(host='https://metaspace2020.eu')

needs_login = not sm.logged_in()
if needs_login:
    # getpass keeps the API key out of the notebook's saved source and outputs.
    api_key = getpass.getpass(prompt='API key: ', stream=None)
    sm.login(api_key=api_key)

## Input data

In [None]:
# Get the ids of the datasets of interest from the dataset QC spreadsheet
# (a grid mask .npy file whose name contains the dataset id must exist under p_grids)
datasets = pd.read_csv(p_datasets)
dataset_subset = datasets[datasets.Selection == 'Yes']

dataset_ids = dataset_subset['Clone ID']


def _find_grid_path(ds_id):
    """Return the grid mask file for ``ds_id``, failing with a clear message if there is none."""
    # sorted() makes the pick deterministic if several files match the same id
    matches = sorted(p_grids.rglob(f"*{ds_id}*.npy"))
    if not matches:
        raise FileNotFoundError(f"No grid mask matching '*{ds_id}*.npy' found under {p_grids}")
    return matches[0]


dataset_paths = [_find_grid_path(ds_id) for ds_id in dataset_ids]
dataset_names = dataset_subset['Matrix short'] + '_' + dataset_subset['Polarity'] + '_' + dataset_subset['Slide code']
dataset_names.values

In [None]:
# Download images from METASPACE and save them as one HDF5 file per dataset
# Ignore any warnings about connection pools in this step

p_apply.mkdir(parents=True, exist_ok=True)

n_datasets = len(dataset_ids)
for i, ds_id in enumerate(dataset_ids, start=1):  # 1-based so the last dataset reads "(n/n)"
    images = []
    print(f'Downloading images for {ds_id} ({i}/{n_datasets})')
    dataset = sm.dataset(id=ds_id)
    ds_tic_image = dataset.tic_image()
    for img in dataset.all_annotation_images(
        fdr=fdr,
        database=database,
        only_first_isotope=True,
        scale_intensity=True,
        hotspot_clipping=False
    ):
        ion_image = img[0]
        # Exclude annotations with no first-isotopic-image
        if ion_image is None:
            continue
        # TIC-normalise, writing 0 wherever the TIC is 0. A plain division gives NaN (0/0) or
        # inf (x/0) there, and np.nan_to_num would turn inf into the largest finite float
        # rather than 0. nan_to_num is kept to replace any NaN coming from the inputs themselves.
        tic_norm_image = np.nan_to_num(np.divide(
            ion_image,
            ds_tic_image,
            out=np.zeros(np.broadcast(ion_image, ds_tic_image).shape),
            where=(ds_tic_image != 0),
        ))
        images.append({
            'dataset_id': ds_id,
            'formula': img.formula,
            'adduct': img.adduct,
            'neutral_loss': img.neutral_loss or '',
            'image': ion_image,
            'tic_norm_image': tic_norm_image,
        })
    images_df = pd.DataFrame(images)
    images_df.to_hdf(p_apply / f"images_{ds_id}.hdf5", key="df")
    print(f'Images for {ds_id} saved')

In [None]:
# Load the per-dataset images saved above, for the currently selected datasets only.
# Globbing every *.hdf5 under p_apply would also pick up files from datasets that are no
# longer selected (which have no entry in `grids`) or any combined images.hdf5 file,
# duplicating rows.
list_of_dfs = []
for ds_id in dataset_ids:
    fpath = p_apply / f"images_{ds_id}.hdf5"
    print(f"Loading {fpath.name}")
    list_of_dfs.append(pd.read_hdf(fpath))

images_df = pd.concat(list_of_dfs, ignore_index=True)

In [None]:
# Wellmap (provides the well/formula/name_short columns merged in below) and one grid mask per dataset id
wellmap = pd.read_csv(p_wellmap)
grids = dict(zip(dataset_ids, map(np.load, dataset_paths)))

In [None]:
# Merge images and metadata: each image is paired with every wellmap row sharing its formula
# (inner join, so images whose formula is not in the wellmap are dropped)
merged_df = (images_df.merge(wellmap[['well', 'formula', 'name_short']], on=['formula'])
).reset_index()

merged_df['score'] = np.nan
# Identifier of a scored image+well. You may want to customize this and add any other fields
# you feel are necessary to uniquely identify a scored image+well
merged_df['row_id'] = [f'{row.formula}_{row.adduct}_{row.neutral_loss}_{row.well}_{row.dataset_id}' for row in merged_df.itertuples()]
# Preview image filename: the row_id plus extension
merged_df['filename'] = [f'{row_id}.png' for row_id in merged_df['row_id']]

# Later cells join metrics and predictions on row_id, so duplicates would silently multiply rows
duplicated_ids = merged_df.loc[merged_df['row_id'].duplicated(), 'row_id'].unique()
assert len(duplicated_ids) == 0, f'row_id is not unique, e.g. {list(duplicated_ids[:5])}'
merged_df = merged_df.set_index('row_id')

## Calculate metrics (or load pre-calculated)

In [None]:
# Calculate metrics
def calc_far_bg(mask, bg):
    """Return the background pixels that lie well away from the spot.

    The spot ``mask`` is dilated 3 times, using its own bounding-box crop as the
    structuring element. For a roughly round spot of radius r this grows it to a
    radius of about (1 + 3 =) 4r. Background pixels inside that expanded region
    are excluded.

    Args:
        mask: boolean array, True on the spot's pixels.
        bg: boolean array of the same shape, True on background pixels.

    Returns:
        Boolean array: ``bg`` minus the expanded spot. If ``mask`` has no pixels there
        is nothing to expand around, so a copy of ``bg`` is returned.
    """
    if not np.any(mask):
        # crop_zeros needs at least one spot pixel to build a structuring element
        return bg.copy()
    expanded_spot = binary_dilation(mask, crop_zeros(mask), iterations=3)
    return bg & ~expanded_spot

def occ(px):
    """Return the fraction (0-1) of non-zero entries in the array ``px``.

    An empty array yields NaN instead of raising ZeroDivisionError, so an empty
    region (e.g. a well with no pixels in the grid) produces NaN metrics.
    """
    if px.size == 0:
        return np.nan
    return np.count_nonzero(px) / px.size


metrics = []
for row in merged_df.itertuples():
    grid = grids[row.dataset_id]

    mask = grid == row.well
    bg = grid == 0
    other_spots = ~bg & ~mask
    far_bg = calc_far_bg(mask, bg)

    # Raw image pixel values per region
    in_mask = row.image[mask]
    in_bg = row.image[bg]
    in_far_bg = row.image[far_bg]
    in_other_spots = row.image[other_spots]
    # TIC-normalised pixel values inside the spot
    in_mask_tic_norm = row.tic_norm_image[mask]

    # Presence threshold: 1% of the image maximum (used by 'spot_occupancy_thresholded' and 'in_n_spots').
    # NOTE(review): the download step requests hotspot_clipping=False, so unless the images were clipped
    # elsewhere this is 1% of the raw maximum rather than of the 99th percentile.
    threshold = np.max(row.image) * 0.01

    metrics.append({
        'row_id': row.Index,  # merged_df is indexed by row_id
        # Original metrics
        # NOTE: The constant in the denominator of `on_off_ratio` was changed to
        # 0.001 as it seemed to produce slightly better results
        'occupancy_ratio': (occ(in_mask) * 100) / (occ(in_bg) * 100 + 1),
        'on_off_ratio': (np.mean(in_mask)) / (np.mean(in_bg) + 0.001),

        # Single-spot occupancy (fraction of non-zero pixels)
        'spot_occupancy': occ(in_mask),
        'spot_occupancy_thresholded': occ(in_mask > threshold),
        # Other occupancy metrics
        'image_occupancy': occ(row.image),
        'other_spots_occupancy': occ(in_other_spots),
        'bg_occupancy': occ(in_bg),
        'far_bg_occupancy': occ(in_far_bg),
        'occupancy_vs_far_bg_ratio': (occ(in_mask) * 100) / (occ(in_far_bg) * 100 + 1),

        # How many spots have a pixel above the threshold
        'in_n_spots': len(np.unique(grid[(grid != 0) & (row.image > threshold)])),

        # Intensities
        'spot_intensity': np.mean(in_mask),
        'spot_intensity_tic_norm': np.mean(in_mask_tic_norm),
        'spot_intensity_bgr_corrected': np.mean(in_mask) - np.mean(in_far_bg),
        'spot_intensity_sum': np.sum(in_mask),
        'spot_intensity_std': np.std(in_mask),
        'other_spot_intensity': np.mean(in_other_spots),
        'bg_intensity': np.mean(in_bg),
        'far_bg_intensity': np.mean(in_far_bg),
        # Intensity ratios
        'intensity_vs_far_bg_ratio': np.mean(in_mask) / (np.mean(in_far_bg) + 0.001),
        'intensity_vs_other_spots_ratio': np.mean(in_mask) / (np.mean(in_other_spots) + 0.001),
    })

metrics_df = pd.DataFrame(metrics).set_index('row_id')
metrics_df.to_csv(p_metrics)
metrics_df.head()

In [None]:
# # Or import pre-calculated metrics
# p_metrics = p_apply / "metrics_08-Dec-2021.csv"
# metrics_df = pd.read_csv(p_metrics, index_col=0)

In [None]:
# Attach the manual score (NaN = unlabelled). Any existing 'score' column is dropped first so
# this cell is idempotent; re-running it would otherwise produce score_x/score_y columns.
metrics_df = (
    metrics_df
    .drop(columns='score', errors='ignore')
    .merge(merged_df[['score']], left_index=True, right_index=True, how='left')
)
metrics_df.head()

## Perform classification

### Load pretrained model and make predictions

In [None]:
# Feature columns fed to the model. The values are passed positionally (.values), so this order
# presumably has to match the one used when the model was trained.
features = ['spot_intensity_tic_norm', 'spot_occupancy', 'occupancy_vs_far_bg_ratio', 'intensity_vs_far_bg_ratio', 'intensity_vs_other_spots_ratio']

model = CatBoostClassifier(verbose=False)
model.load_model(p_model, format='json')

# Make predictions for all data, keeping the probability of class 1 (column 1 of predict_proba)
feature_matrix = metrics_df[features].values
class_probabilities = model.predict_proba(feature_matrix)
predictions_df = pd.DataFrame({'pred_val': class_probabilities[:, 1]}, index=metrics_df.index)

### Assign labels to predictions

In [None]:
# Make combined DF: images + metadata, metrics (minus the duplicate 'score' column) and predictions
output_df = merged_df.join(
    metrics_df.drop(columns='score')
).join(
    predictions_df
)

pred_vals = output_df.pred_val
# Two-state class: 0 below 0.5, otherwise 1
output_df['pred_twostate'] = np.where(pred_vals < 0.5, 0, 1)
# Three-state class via np.digitize against unsure_range = [low, high]:
# 0 = negative (< 0.2), 1 = unsure (0.2 <= pred_val < 0.8), 2 = positive (>= 0.8)
unsure_range = [0.2, 0.8]
output_df['pred_threestate'] = np.digitize(pred_vals, unsure_range)

### Write predictions CSV files

In [None]:
# The image arrays don't belong in a CSV; everything else is written out
csv_df = output_df.drop(columns=['image', 'tic_norm_image'])
csv_df.to_csv(p_predictions)

# Additionally write one predictions file per dataset
for ds_key, ds_rows in csv_df.groupby('dataset_id'):
    ds_rows.to_csv(p_apply / f'{ds_key}_predictions.csv')

### Write preview images into the true/false positive/negative and three-state directories

In [None]:
# Clean output directories (the manual_label_* directories are not touched here)
for output_path in [
    p_apply_fpos, p_apply_fneg, p_apply_tpos, p_apply_tneg, p_apply_upos, p_apply_uneg,
    p_tri_pos, p_tri_unk, p_tri_neg
]:
    output_path.mkdir(parents=True, exist_ok=True)
    for f in output_path.glob('*.png'):
        f.unlink()  # Delete existing files

# Two-state directories per known score, indexed by pred_twostate (0 = negative, 1 = positive)
twostate_dirs = {
    0: [p_apply_tneg, p_apply_fpos],  # labelled negative
    1: [p_apply_fneg, p_apply_tpos],  # labelled positive
}
# Three-state directories indexed by pred_threestate (0 = negative, 1 = unsure, 2 = positive)
threestate_dirs = [p_tri_neg, p_tri_unk, p_tri_pos]

# Single pass: the spot mask is computed once per row and reused for both previews
for row in output_df.itertuples():
    mask = grids[row.dataset_id] == row.well

    # Two-state preview only for labelled rows (score 0 or 1); a NaN score finds no entry.
    # NOTE(review): unknown_positives/unknown_negatives are cleaned above but never written;
    # confirm that skipping unlabelled rows here is intended.
    labelled_dirs = twostate_dirs.get(row.score)
    if labelled_dirs is not None:
        twostate_path = labelled_dirs[row.pred_twostate]
        save_image_with_mask(row.image, mask, twostate_path / row.filename, row.tic_norm_image)

    threestate_path = threestate_dirs[row.pred_threestate]
    save_image_with_mask(row.image, mask, threestate_path / row.filename, row.tic_norm_image)

### Curate predictions by moving preview images into the change_to_* folders

In [None]:
# Get images for which the prediction needs to be changed: place their preview PNGs in the
# matching change_to_* folder; the next cell applies the overrides
p_new_pos = p_apply / 'change_to_pos'
p_new_unsure = p_apply / 'change_to_unsure'
p_new_neg = p_apply / 'change_to_neg'

# Override values are chosen to land in the intended class given the thresholds used above
# (two-state cut at 0.5, unsure_range = [0.2, 0.8]):
# 0.81 -> positive, 0.49 -> unsure (two-state negative), 0.01 -> negative
changed_labels = []
for score, labels_path in [(0.81, p_new_pos), (0.49, p_new_unsure), (0.01, p_new_neg)]:
    labels_path.mkdir(parents=True, exist_ok=True)
    for f in labels_path.glob('*.png'):
        changed_labels.append({
            'filename': f.name,
            'pred_val_override': score,
        })
if changed_labels:
    pred_override = pd.DataFrame(changed_labels)
else:
    # Keep the expected columns (with a float override, like the values above) when no files were moved
    pred_override = pd.DataFrame({'filename': pd.Series(dtype=str), 'pred_val_override': pd.Series(dtype=float)})

# A file present in several folders would duplicate its row when merged into the predictions.
# Keep the first folder in the order above (pos, unsure, neg) and report the conflict.
is_dupe = pred_override['filename'].duplicated(keep='first')
if is_dupe.any():
    print(f"Warning: files found in more than one change_to folder (first folder wins): "
          f"{sorted(set(pred_override.loc[is_dupe, 'filename']))}")
pred_override = pred_override[~is_dupe]

In [None]:
# Update prediction values with the manual overrides (rows without an override keep their prediction).
# DataFrame.merge on a column discards the index, so row_id is moved into a column for the merge and
# restored afterwards; the curated CSV is then indexed by row_id like the main predictions CSV.
csv_df_curated = (
    csv_df
    .rename_axis('row_id')
    .reset_index()
    .merge(pred_override, on='filename', how='left')
    .set_index('row_id')
)
csv_df_curated['pred_val'] = csv_df_curated.pred_val_override.fillna(csv_df_curated.pred_val)

# Recompute two-state and three-state classes from the (possibly overridden) values
csv_df_curated['pred_twostate'] = np.where(csv_df_curated.pred_val < 0.5, 0, 1)
unsure_range = [0.2, 0.8]  # "unsure" class covers 0.2 <= pred_val < 0.8
# This assigns 0 = negative, 1 = unsure, 2 = positive
csv_df_curated['pred_threestate'] = np.digitize(csv_df_curated.pred_val, unsure_range)

csv_df_curated = csv_df_curated.drop(columns=['pred_val_override'])

# Save as separate file
csv_df_curated.to_csv(p_predictions_curated)