In [None]:
import os

# Notebook paths below ("data/...", "configs/...") are relative to the entrack
# repository root. Set the ENTRACK_ROOT environment variable to point at a
# different checkout instead of editing this hard-coded default.
ENTRACK_ROOT = os.environ.get("ENTRACK_ROOT", "/home/mhoerold/entrack")
os.chdir(ENTRACK_ROOT)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import yaml
import nibabel as nib
import copy
import tensorflow as tf
import pydoc

from src.test_retest.mri.supervised_features import SliceClassification
from src.data.streaming.base import Group

## Load model

In [None]:
smt_label = "mci_ad_direct_conv_clf"
model_dir = os.path.join("data", smt_label)
config_path = os.path.join("configs", "single_clf", "clf_direct_conversion.yaml")

# yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and a
# TypeError in PyYAML >= 6; FullLoader is what the bare call defaulted to.
with open(config_path, 'r') as f:
    model_config = yaml.load(f, Loader=yaml.FullLoader)

# Change the streamer to stream conversion pairs. The class is given as a
# dotted path and resolved to the class object. pydoc.locate returns None
# (rather than raising) for an unknown path, so fail fast here.
streamer_class_path = "src.data.streaming.mri_streaming.MRIConversionSingleStream"
streamer_class = pydoc.locate(streamer_class_path)
if streamer_class is None:
    raise ImportError("Could not locate streamer class {}".format(streamer_class_path))
model_config["params"]["streamer"]["class"] = streamer_class

obj = SliceClassification(**model_config["params"])

est = tf.estimator.Estimator(
    model_fn=obj.model_fn,
    model_dir=model_dir,
    params=model_config["params"]["params"]
)

## Prepare t0 and t1 streams

In [None]:
def fids_to_groups(fids):
    """Wrap every file id in its own single-member ``Group``.

    Returns a list holding one ``Group([fid])`` per entry of ``fids``,
    preserving the input order.
    """
    groups = []
    for fid in fids:
        groups.append(Group([fid]))
    return groups
    
streamer = obj.streamer

# One input function per time point; every scan is fed as its own group.
t0_groups = fids_to_groups(streamer.t0_fids)
t0_input_fn = streamer.get_input_fn_for_groups(t0_groups)

t1_groups = fids_to_groups(streamer.t1_fids)
t1_input_fn = streamer.get_input_fn_for_groups(t1_groups)

## Make some predictions

In [None]:
def make_predictions(input_fn, estimator=None):
    """Run ``estimator.predict`` over ``input_fn`` and collect per-example results.

    Args:
        input_fn: input function passed unchanged to ``estimator.predict``.
        estimator: object exposing ``predict(input_fn)``; defaults to the
            notebook-level ``est`` so existing calls keep working.

    Returns:
        A pair ``(y, image_label)`` of equal-length lists: the ``"classes"``
        entry of each prediction, and the first element of its
        ``"image_label"`` entry decoded from UTF-8 bytes to ``str``.
    """
    if estimator is None:
        estimator = est

    y = []
    image_label = []
    for pred in estimator.predict(input_fn):
        y.append(pred["classes"])
        # image_label arrives as a sequence of bytes; keep only the first entry
        image_label.append(pred["image_label"][0].decode('utf-8'))

    return y, image_label

In [None]:
# Predict on the t0 scans first, then on their t1 counterparts.
t0_predictions = make_predictions(t0_input_fn)
t0_y, t0_labels = t0_predictions
t1_predictions = make_predictions(t1_input_fn)
t1_y, t1_labels = t1_predictions

In [None]:
# t0 and t1 results are paired by position. zip() would silently drop any
# unmatched tail entries while the accuracy below still divides by len(t0_y),
# so insist that all four lists line up.
assert len(t0_y) == len(t0_labels) == len(t1_y) == len(t1_labels), (
    "t0/t1 prediction streams differ in length: {} vs {}".format(len(t0_y), len(t1_y))
)

correct = 0
for t0, lab0, t1, lab1 in zip(t0_y, t0_labels, t1_y, t1_labels):
    print("{} {}".format(lab0, lab1))
    # both scans of a pair must belong to the same patient
    assert streamer.get_patient_id(lab0) == streamer.get_patient_id(lab1)
    # ground truth: conversion means the diagnosis differs between t0 and t1
    converts = streamer.get_diagnose(lab0) != streamer.get_diagnose(lab1)
    # prediction: conversion means the predicted class differs between t0 and t1
    pred = bool(t0 != t1)
    print("Expected {}, got {}".format(converts, pred))
    if pred == converts:
        correct += 1

# fraction of pairs whose conversion status was predicted correctly
correct / len(t0_y) if t0_y else float("nan")