In [None]:
import os
import pickle
from glob import glob
import requests
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import numpy as np
from showit import image
from skimage import img_as_float32
import seaborn as sns
from skimage.feature import peak_local_max
from sympy import Point, Line, Segment
import pandas as pd

from starfish import ImageStack
from starfish.types import Indices
from starfish.image import Filter
from starfish.types._spot_attributes import SpotAttributes
from starfish.spots import SpotFinder

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

# S3 prefix holding the osmFISH data, and its 'images' / 'results' sub-prefixes.
# The underscore-prefixed names are the remote locations used by the `aws s3 cp` cells.
aws_data_path = 's3://czi.starfish.data.public/browse/raw/20180912/osmFISH/'
_im_path = os.path.join(aws_data_path, 'images')
_res_path = os.path.join(aws_data_path, 'results')
# Field of view used throughout the notebook (download filter and loader lookups).
fov_num = 33
# Local directories the loader functions read from; they match the hard-coded
# ./images and ./results download targets of the `aws s3 cp` cells.
im_path = 'images'
res_path = 'results'

## Load pysmFISH results

In [None]:
# Recursively copy from the S3 prefix in `_im_path` into ./images, excluding everything
# and then re-including only keys that contain `fov_num`.
# NOTE(review): `2>&1 > /dev/null` discards stdout only; stderr still reaches the cell
# output. If the intent was to silence everything, the order should be `> /dev/null 2>&1`.
!aws s3 cp $_im_path/ ./images --exclude "*" --include "*$fov_num*" --recursive 2>&1 > /dev/null

In [None]:
# Recursively copy from the S3 prefix in `_res_path` into ./results, excluding everything
# and then re-including only keys that contain `fov_num`.
# NOTE(review): `2>&1 > /dev/null` discards stdout only; stderr still reaches the cell
# output. If the intent was to silence everything, the order should be `> /dev/null 2>&1`.
!aws s3 cp $_res_path/ ./results --exclude "*" --include "*$fov_num*" --recursive 2>&1 > /dev/null

In [None]:
def load_image_stack(fov_num, im_dir=None):
    """Load the image volume of one field of view as an ImageStack.

    Parameters
    ----------
    fov_num : int or str
        Field-of-view identifier. The first (in sorted order) ``.npy`` file
        whose path contains ``str(fov_num)`` is loaded.
        NOTE(review): this is a plain substring match, so e.g. 33 would also
        match a file numbered 133 -- confirm against the real file names.
    im_dir : str, optional
        Directory to search for ``.npy`` files. Defaults to the notebook-level
        ``im_path``.

    Returns
    -------
    ImageStack
        Built from a float32 array of shape ``(1, 1) + im.shape``
        (presumably round, channel, z, y, x -- confirm against starfish).

    Raises
    ------
    FileNotFoundError
        If no ``.npy`` file in ``im_dir`` matches ``fov_num``.
    ValueError
        If the loaded array is not 3-dimensional.
    """
    if im_dir is None:
        im_dir = im_path

    token = str(fov_num)
    # sorted() makes the choice deterministic when more than one file matches.
    matches = sorted(f for f in glob(os.path.join(im_dir, '*.npy')) if token in f)
    if not matches:
        raise FileNotFoundError(
            "no .npy file containing '{}' found in '{}'".format(token, im_dir))

    im = np.load(matches[0])
    if im.ndim != 3:
        # A 2-D array would previously have been silently broadcast over every plane.
        raise ValueError(
            "expected a 3-dimensional image volume, got shape {}".format(im.shape))

    # Convert once and prepend two singleton axes. This yields the same float32
    # values as filling a preallocated float64 array and converting it again,
    # without the large temporary or the hard-coded volume shape.
    stack = img_as_float32(im)[np.newaxis, np.newaxis, ...]

    return ImageStack.from_numpy_array(stack)

def load_results(fov_num, res_dir=None):
    """Load the pickled results dictionary of one field of view.

    Every entry whose value is exactly a ``np.float64``, ``np.int64`` or
    builtin ``int`` is printed as ``key value`` for a quick overview.

    Parameters
    ----------
    fov_num : int or str
        Field-of-view identifier. The first (in sorted order) ``.pkl`` file
        whose path contains ``str(fov_num)`` is loaded.
    res_dir : str, optional
        Directory to search for ``.pkl`` files. Defaults to the notebook-level
        ``res_path``.

    Returns
    -------
    dict
        The unpickled results object (``.items()`` is called on it).

    Raises
    ------
    FileNotFoundError
        If no ``.pkl`` file in ``res_dir`` matches ``fov_num``.
    """
    if res_dir is None:
        res_dir = res_path

    token = str(fov_num)
    # sorted() makes the choice deterministic when more than one file matches.
    matches = sorted(p for p in glob(os.path.join(res_dir, '*.pkl')) if token in p)
    if not matches:
        raise FileNotFoundError(
            "no .pkl file containing '{}' found in '{}'".format(token, res_dir))

    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # load result files that come from a trusted source.
    with open(matches[0], 'rb') as f:
        res = pickle.load(f)

    # `np.int` was merely an alias of the builtin `int` and was removed in
    # NumPy 1.24; referencing it raises AttributeError there. The builtin is
    # the exact equivalent.
    scalar_types = (np.float64, np.int64, int)
    for k, v in res.items():
        if type(v) in scalar_types:
            print(k, v)

    return res

def selected_peaks(res, redo_flag = False):
    """Return the selected spot coordinates as a DataFrame.

    Parameters
    ----------
    res : dict
        Results dictionary. The default mode reads ``'selected_peaks'`` and
        ``'selected_peaks_int'``; the redo mode reads ``'thr_array'``,
        ``'peaks_coords'`` and ``'selected_thr'``.
    redo_flag : bool, optional
        If False (default), use the stored ``res['selected_peaks']``.
        If True, look the coordinates up again: take the ``peaks_coords``
        entry whose ``thr_array`` value equals ``res['selected_thr']``.

    Returns
    -------
    pandas.DataFrame
        Columns ``y`` and ``x`` (plus ``selected_peaks_int`` in the default
        mode). Column 0 of the coordinate array is ``y`` and column 1 is ``x``
        in both modes.

    Raises
    ------
    IndexError
        In redo mode, if no ``thr_array`` entry equals ``selected_thr``.
    """
    if not redo_flag:
        coords = res['selected_peaks']
        return pd.DataFrame({'y': coords[:, 0],
                             'x': coords[:, 1],
                             'selected_peaks_int': res['selected_peaks_int']
                            })

    target = res['selected_thr']
    for thr, coords in zip(res['thr_array'], res['peaks_coords']):
        if thr == target:
            # Same (y, x) column order as the default mode, so the two modes
            # are interchangeable for plotting and comparison.
            return pd.DataFrame({'y': coords[:, 0], 'x': coords[:, 1]})

    raise IndexError(
        "no entry of res['thr_array'] equals res['selected_thr'] ({})".format(target))

def peaks(res):
    """Collect the per-threshold entries of ``res`` into one DataFrame.

    The frame has the columns ``thr_array``, ``peaks_coords`` and
    ``total_peaks`` (in that order), each taken from the entry of the same
    name in ``res``.
    """
    column_names = ('thr_array', 'peaks_coords', 'total_peaks')
    return pd.DataFrame({name: res[name] for name in column_names})

# Load the results dictionary for the configured field of view
# (load_results also prints its scalar entries).
res = load_results(fov_num)
# Spot coordinates as stored under res['selected_peaks'] (not re-derived from thr_array).
sp = selected_peaks(res, redo_flag=False)
# Table of thresholds, peak coordinates and peak counts taken from `res`.
p = peaks(res)

In [None]:
# Threshold stored under 'selected_thr' in the loaded results.
# NOTE(review): the name is spelled "psymFISH" (not "pysmFISH") and is not
# referenced by any later cell visible here.
psymFISH_thresh = res['selected_thr']

In [None]:
# Load the image volume of the configured field of view as an ImageStack.
stack = load_image_stack(fov_num)

In [None]:
# Show the stack, selecting round 0, with rescale=True.
# NOTE(review): what `rescale` does is defined by starfish's ImageStack.show_stack -- confirm there.
stack.show_stack({Indices.ROUND: 0}, rescale=True)

# Reproduce pysmFISH results

## Filtering code

In [None]:
# `Filter` is already imported in the notebook's import cell, so it is not re-imported here.

# Two filters applied to the whole volume (is_volume=True): a Gaussian high-pass
# followed by a Laplace filter.
# NOTE(review): the sigma tuples are presumably ordered (z, y, x) -- confirm against starfish.
ghp = Filter.GaussianHighPass(sigma=(1,8,8), is_volume=True)
lp = Filter.Laplace(sigma=(0.2, 0.5, 0.5), is_volume=True)

# in_place=False is passed so each step returns a new stack; `stack_hp_lap` is the
# fully filtered result used by the following cells.
stack_hp = ghp.run(stack, in_place=False)
stack_hp_lap = lp.run(stack_hp, in_place=False)

In [None]:
# Max projection of the filtered stack along Z, reduced to a single 2-D image.
mp = stack_hp_lap.max_proj(Indices.Z)[0,0,:,:]

# Clip the display range to the 98th-99.9th percentile of the projection.
vmin, vmax = np.percentile(mp, [98, 99.9])

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mp, cmap='gray', vmin=vmin, vmax=vmax)
ax.set_title('Filtered max projection')
ax.axis('off');

### Spot Finding

In [None]:
# Parameters handed to LocalMaxPeakFinder.
min_distance = 6
stringency = 0
min_obj_area = 6
max_obj_area = 600

# TODO this will go away once ImageStack.max_proj returns an ImageStack
# Wrap the 2-D max projection in three leading singleton axes so it can be turned
# into an ImageStack. It gets its own name instead of rebinding `stack`, so the
# original volume stays available and the filtering cell can be re-run safely.
mp_stack = ImageStack.from_numpy_array(mp[np.newaxis, np.newaxis, np.newaxis, ...])

lmp = SpotFinder.LocalMaxPeakFinder(
    min_distance=min_distance,
    stringency=stringency,
    min_obj_area=min_obj_area,
    max_obj_area=max_obj_area
)
lmp_res = lmp.run(mp_stack)

### Spot finding QA

In [None]:
# Display the result returned by LocalMaxPeakFinder.run for inspection.
lmp_res

In [None]:
# Histogram of the values in lmp_res.data[:, 0, 0] (labelled as spot intensity),
# drawn with a logarithmic count axis.
fig, ax = plt.subplots()
ax.hist(lmp_res.data[:,0,0], bins=20)
sns.despine(offset=2)
ax.set_yscale('log')
ax.set_xlabel('Intensity')
ax.set_ylabel('Number of spots');

In [None]:
# Recompute the Z max projection of the filtered stack and overlay the found spots.
mp = stack_hp_lap.max_proj(Indices.Z)[0,0,:,:]

# Clip the display range to the 98th-99.9th percentile of the projection.
vmin, vmax = np.percentile(mp, [98, 99.9])

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mp, cmap='gray', vmin=vmin, vmax=vmax)
# Red circles at the (x, y) position of every spot in lmp_res.
ax.plot(lmp_res.x, lmp_res.y, 'or')
ax.axis('off');

## Compare to pysmFISH peak calls

In [None]:
# Spot counts: benchmark calls loaded from the results file vs. starfish's calls.
num_spots_simone = len(sp)
num_spots_starfish = len(lmp_res)

plt.figure(figsize=(10,10))
# y is negated for both point sets -- presumably so the scatter has the same
# orientation as the image plots (row 0 at the top); confirm.
plt.plot(sp.x, -sp.y, 'o')
sns.despine(offset=20)
plt.plot(lmp_res.x, -lmp_res.y, 'x')

plt.legend(['Benchmark: {} spots'.format(num_spots_simone),
            'Starfish: {} spots'.format(num_spots_starfish)])
plt.title('osmFISH spot calls');

# State the direction of the difference explicitly instead of printing a
# negative number of "fewer" spots when starfish finds more than the benchmark.
spot_diff = num_spots_simone - num_spots_starfish
print("Starfish finds {} {} spots".format(abs(spot_diff),
                                          "fewer" if spot_diff >= 0 else "more"))