In [1]:
import numpy as np

def make_outliers(mean=1000, err=10, size=50):
    # Create 1 stack of pixel with 10% outliers. 
    
    std = np.sqrt(mean)
    n_outliers = np.random.randint(0, int(size/10))
    size -= n_outliers
    random_err = err * np.random.normal(loc=mean, scale=std, size=size) * np.random.choice((-1, 1), size)
    data = mean + random_err

    outlier_int = 50 * mean
    outlier_errs =  outlier_int * np.random.rand(n_outliers) * np.random.choice((-1, 1), n_outliers)
    
    data = np.concatenate((data, outlier_errs))
    np.random.shuffle(data)

    return data

def cube_outliers(size=(128, 128, 50)):
    # Make an image series with outliers with respect to the 3rd dimension. 
    cube = np.empty(size)
    for r in range(size[0]):
        for c in range(size[1]):
            cube[r,c,:] = make_outliers(size=size[-1])
    
    return cube

In [2]:
def sigma_clip(datacube):
    rejMask = np.zeros(datacube.shape, dtype=np.bool)
    n = 1
    n_outliers = []
    while n > 0:
        rejMask0 = rejMask.copy()
        med = np.nanmedian(datacube, axis=-1)
        sigma = np.nanstd(datacube, axis=-1)
        rejMask = (np.abs(datacube - med[...,np.newaxis]) > 5*sigma[...,np.newaxis])
        n = rejMask.sum()
        n_outliers.append(n)
#         print(n)
        rejMask = rejMask0 | rejMask
        datacube[rejMask] = np.nan
    return rejMask, n_outliers

In [3]:
images = cube_outliers(size=(1024, 1024, 100))

In [5]:
%time rej_mask1, n_outs1 = sigma_clip(images.copy())

  if __name__ == '__main__':


Wall time: 26.3 s


In [7]:
print(n_outs1, sum(n_outs1))

[22440, 43, 0] 22483


In [8]:
rej_mask2 = np.empty(images.shape)

In [10]:
%%time
images2 = images.copy()
nout_rows = []
for r in range(images.shape[0]):
    rej_mask2[r, ...], nout_r = sigma_clip(images2[r,...]) 
    nout_rows.append(nout_r)

  if __name__ == '__main__':


Wall time: 16.4 s


In [11]:
nloops = max([len(nouts) for nouts in nout_rows])
nout_sum = sum([sum(nouts) for nouts in nout_rows])
nloops_per_row = np.array([len(nouts) for nouts in nout_rows])

In [12]:
from collections import Counter
Counter(nloops_per_row)

Counter({2: 982, 3: 42})

In [None]:
np.array_equal(rej_mask1, rej_mask)