# Random Shaking Denoising using Farneback3D (Corbel)

In [None]:
# When True, the `information_theory` package is taken from a local source
# checkout (symlinked into the working directory) instead of GitHub.
local_debug: bool = True

In [None]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    print("Running in Colab")
    !pip install cupy-cuda12x
    !pip install opticalflow3D
    !apt install libcudart11.0
    !apt install libcublas11
    !apt install libcufft10
    !apt install libcusparse11
    !apt install libnvrtc11.2
    from google.colab import drive
    drive.mount('/content/drive')
    !cp drive/Shareddrives/TomogramDenoising/tomograms/{vol_name}.tif .
else:
    print("Running in locahost")
    !cp ~/Downloads/{vol_name}.tif .

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

In [None]:
import numpy as np
import logging

In [None]:
if local_debug:
    !ln -sf ../../information_theory/src/information_theory/ .
else:
    !pip install "information_theory @ git+https://github.com/vicente-gonzalez-ruiz/information_theory"
import information_theory  # pip install "information_theory @ git+https://github.com/vicente-gonzalez-ruiz/information_theory"

In [None]:
#import skimage.io

In [None]:
#import RSIVD

In [None]:
from collections import namedtuple

# `input` is the volume name without extension (".mrc" is appended when the
# file is opened); `output` is the stem of the denoised output files.
Args = namedtuple("args", ["input", "output"])
fn = "Corbel2301_block2_June2019_crop_ali_crop"
args = Args(input=fn, output=f"{fn}_denoised")

In [None]:
import mrcfile

In [None]:
%%bash -s "$args.input"
# Download the noisy volume from Google Drive unless it is already present.
# $1 is the volume name without extension. The notebook later opens "$1.mrc",
# so both the existence check and the download use exactly that file name
# (previously the check looked for "$1" and never matched, forcing a download
# on every run, and gdown chose the saved file name itself).
set -x
OUTPUT_FILENAME="$1.mrc"
if test ! -f "$OUTPUT_FILENAME" ; then
    # Shared file: https://drive.google.com/file/d/1Uqa6ywi8bllhyHxrODD5yjuesUkNO3O0/view?usp=sharing
    FILEID="1Uqa6ywi8bllhyHxrODD5yjuesUkNO3O0"
    gdown "https://drive.google.com/uc?id=$FILEID" -O "$OUTPUT_FILENAME"
fi
set +x

In [None]:
# Load the noisy volume stored in "<args.input>.mrc". The context manager
# closes the file handle (the original left it open). `.copy()` makes `noisy`
# an independent array that does not rely on the file staying open.
with mrcfile.open(args.input + ".mrc") as stack_MRC:
    noisy = stack_MRC.data.copy()

In [None]:
# Linearly rescale the volume to the 8-bit range [0, 255].
# Compute in float32: with integer MRC data (e.g. int16) the expression
# 255*(noisy - min) would otherwise be evaluated in the integer dtype and
# overflow. A constant volume (max == min) maps to zeros instead of dividing
# by zero.
noisy = noisy.astype(np.float32)
noisy_min = noisy.min()
noisy_range = noisy.max() - noisy_min
if noisy_range > 0:
    noisy = 255 * (noisy - noisy_min) / noisy_range
else:
    noisy = np.zeros_like(noisy)
noisy = noisy.astype(np.uint8)

In [None]:
# Show the dimensions of the loaded volume.
np.shape(noisy)

In [None]:
# Preview slice 50 (first axis) of the noisy volume, with the rows flipped.
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(noisy[50, ::-1, :], cmap="gray")
ax.set_title("Noisy")
fig.tight_layout()
plt.show()

In [None]:
# Drop the first 50 indices of axis 0 and the first 200 of axes 1 and 2.
noisy = noisy[(slice(50, None), slice(200, None), slice(200, None))]

In [None]:
import logging
import threading
import time
import numpy as np
# pip install "motion_estimation @ git+https://github.com/vicente-gonzalez-ruiz/motion_estimation"
from motion_estimation._3D.farneback_opticalflow3d import Farneback_Estimator as _3D_OF_Estimation 
from motion_estimation._3D.project_opticalflow3d import Volume_Projection

# Default optical-flow parameters, used as defaults of `filter_volume`.
PYRAMID_LEVELS, WINDOW_SIDE, ITERATIONS, N_POLY = 3, 5, 5, 11

class Random_Shaking_Denoising_old(_3D_OF_Estimation, Volume_Projection):
    """Earlier random-shaking denoiser built on the opticalflow3d-based classes.

    Inherits `Farneback_Estimator` (imported as `_3D_OF_Estimation`) and
    `Volume_Projection` from `motion_estimation`. This class calls
    `pyramid_get_flow` and `remap` but does not define them; they are expected
    to come from those base classes. `filter_volume` is the entry point.
    """
    def __init__(
        self,
        logging_level=logging.INFO,
        #estimator="opticalflow3d"
    ):
        """Initialise the bases, print the progress-table header and start the
        daemon thread that prints one table row per denoising iteration.

        Args:
            logging_level: passed to both base-class constructors and set on
                this module's logger.
        """
        #self.estimator = estimator
        _3D_OF_Estimation.__init__(self, logging_level)
        Volume_Projection.__init__(self, logging_level)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging_level)

        # NOTE(review): `max`/`min` are initialised here but never read in this class.
        if self.logger.getEffectiveLevel() <= logging.INFO:
            self.max = 0
            self.min = 0
        # Header of the table printed by `show_log` (printed at every logging level).
        print(f"{'iter':>5s}", end='')
        print(f"{'min_shaking':>15s}", end='')
        print(f"{'max_shaking':>15s}", end='')
        print(f"{'min_flow':>15s}", end='')
        print(f"{'avg_abs_flow':>15s}", end='')
        print(f"{'max_flow':>15s}", end='')
        print(f"{'time':>15s}", end='')
        print()

        # `filter_volume` sets `stop_event` after each iteration; the daemon
        # thread then prints one row and clears it again.
        self.stop_event = threading.Event()
        self.logger_daemon = threading.Thread(target=self.show_log)
        self.logger_daemon.daemon = True
        self.time_0 = time.perf_counter()
        self.logger_daemon.start()

    def show_log(self):
        """Print one progress row each time `stop_event` is set (thread body).

        `Event.wait()` has no timeout here, so it blocks until the event is set
        and then returns True. The loop therefore never ends on its own; the
        thread is a daemon. A row shows the iteration number, the extremes of
        the most recent shaking displacement vector, statistics of the most
        recent flow, and the seconds elapsed since the previous row.
        NOTE(review): the main thread continues right after `set()`, so
        `iter`, `displacements` and `flow` may already belong to the next
        iteration when the row is printed.
        """
        #while not self.stop_event.is_set():
        while self.stop_event.wait():
            time_1 = time.perf_counter()
            running_time = time_1 - self.time_0
            print(f"{self.iter:>5d}", end='')
            print(f"{np.min(self.displacements):>15.2f}", end='')
            print(f"{np.max(self.displacements):>15.2f}", end='')
            print(f"{np.min(self.flow):>15.2f}", end='')
            print(f"{np.average(np.abs(self.flow)):>15.2f}", end='')
            print(f"{np.max(self.flow):>15.2f}", end='')
            print(f"{running_time:>15.2f}", end='')
            print()
            self.stop_event.clear()
            self.time_0 = time.perf_counter()

    def shake_vector(self, x, mean=0.0, std_dev=1.0):
        """Jitter the positions of the entries of `x` with Gaussian noise.

        Returns an array of shape (len(x), 2). Column 0 is each entry's index
        plus N(mean, std_dev) noise; column 1 is the entry of `x` itself. The
        noise is also kept in `self.displacements` for `show_log`.
        """
        y = np.arange(len(x))
        self.displacements = np.random.normal(mean, std_dev, len(x))
        return np.stack((y + self.displacements, x), axis=1)

    def shake_volume(self, volume, mean=0.0, std_dev=1.0):
        """Return a new array containing `volume` randomly shaken along Z, Y and X.

        For every 1-D line along an axis, the line's indices are perturbed with
        Gaussian noise (`shake_vector`), truncated to int16 and sorted. The line
        is then reordered by that sort. The axes are processed in turn, each
        pass reading the result of the previous one. `volume` itself is not
        modified.
        """
        shaked_volume = np.empty_like(volume)

        # Shaking in Z
        values = np.arange(volume.shape[0]).astype(np.int16)
        for y in range(volume.shape[1]):
            for x in range(volume.shape[2]):
                pairs = self.shake_vector(x=values, mean=mean, std_dev=std_dev).astype(np.int16)
                # argsort's default kind is not stable: equal truncated keys
                # end up in an unspecified order.
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[values, y, x] = volume[pairs[:, 1], y , x]
        # From here on source and destination are the same buffer. This is safe
        # because the fancy-indexed right-hand side copies the whole line
        # before it is written back.
        volume = shaked_volume
    
        # Shaking in Y
        values = np.arange(volume.shape[1]).astype(np.int16)
        for z in range(volume.shape[0]):
            for x in range(volume.shape[2]):
                pairs = self.shake_vector(values, mean=mean, std_dev=std_dev).astype(np.int16)
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[z, values, x] = volume[z, pairs[:, 1], x]
        volume = shaked_volume

        # Shaking in X
        values = np.arange(volume.shape[2]).astype(np.int16)
        for z in range(volume.shape[0]):
            for y in range(volume.shape[1]):
                pairs = self.shake_vector(values, mean=mean, std_dev=std_dev).astype(np.int16)
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[z, y, values] = volume[z, y, pairs[:, 1]]
                
        return shaked_volume

    def project_volume_reference_to_target(self, reference, target, pyramid_levels, window_side, iterations, N_poly, block_size, overlap, threads_per_block):
        """Estimate the flow between `target` and `reference` and warp `reference`.

        All parameters, including `block_size`, `overlap` and
        `threads_per_block`, are forwarded to `pyramid_get_flow`. The flow is
        stored in `self.flow` (read by `show_log`). Returns
        `remap(reference, flow)`.
        """
        self.flow = self.pyramid_get_flow(
            target=target,
            reference=reference,
            flow=None,
            pyramid_levels=pyramid_levels,
            window_side=window_side,
            iterations=iterations,
            N_poly=N_poly,
            block_size=block_size,
            overlap=overlap,
            threads_per_block=threads_per_block)
        projection = self.remap(reference, self.flow)
        return projection

    def filter_volume(
        self,
        noisy_volume,
        N_iters=25,
        mean=0.0,
        std_dev=1.0,
        pyramid_levels=PYRAMID_LEVELS,
        window_side=WINDOW_SIDE,
        iterations=ITERATIONS,
        N_poly=N_POLY,
        block_size=(256, 256, 256),
        overlap=(64, 64, 64),
        threads_per_block=(8, 8, 8)
    ):
        """Denoise `noisy_volume` by averaging flow-compensated random shakings.

        A float32 accumulator starts as the noisy volume. In iteration i the
        current estimate is acc / (i + 1). The noisy volume is shaken, the
        estimate (reference) is projected against the shaken volume (target),
        the projection is added to the accumulator, and the logger thread is
        signalled. Returns acc / (N_iters + 1).

        NOTE(review): what is accumulated is `remap(denoised_volume, flow)`,
        i.e. a warp of the current estimate rather than of the shaken noisy
        volume. The name `shaked_and_compensated_noisy_volume` suggests the
        reference/target roles may be swapped — confirm.
        """
        acc_volume = np.zeros_like(noisy_volume, dtype=np.float32)
        acc_volume[...] = noisy_volume
        for i in range(N_iters):
            self.iter = i
            denoised_volume = acc_volume/(i+1)
            shaked_noisy_volume = self.shake_volume(noisy_volume, mean=mean, std_dev=std_dev)
            shaked_and_compensated_noisy_volume = self.project_volume_reference_to_target(
                reference=denoised_volume,
                target=shaked_noisy_volume,
                pyramid_levels=pyramid_levels,
                window_side=window_side,
                iterations=iterations,
                N_poly=N_poly,
                block_size=block_size,
                overlap=overlap,
                threads_per_block=threads_per_block)
            acc_volume += shaked_and_compensated_noisy_volume
            # Wake the logger thread so it prints this iteration's row.
            self.stop_event.set()
        # N_iters projections plus the initial noisy volume.
        denoised_volume = acc_volume/(N_iters + 1)

        return denoised_volume

In [None]:
import threading
import time
import numpy as np
# pip install "motion_estimation @ git+https://github.com/vicente-gonzalez-ruiz/motion_estimation"
from motion_estimation._3D.farneback import OF_Estimation as _3D_OF_Estimation 
from motion_estimation._3D.project import Projection

import logging
import inspect

# Default optical-flow parameters, used as defaults of `filter_volume`.
PYRAMID_LEVELS, WINDOW_SIDE, ITERATIONS, N_POLY = 3, 5, 5, 11

class Random_Shaking_Denoising(_3D_OF_Estimation, Volume_Projection):
    """Random-shaking volume denoiser driven by 3D Farneback optical flow.

    `filter_volume` is the entry point. This class calls `pyramid_get_flow`
    and `remap` but does not define them; they are expected to come from the
    base classes.

    NOTE(review): this cell imports `Projection` from
    `motion_estimation._3D.project`, but the class still inherits
    `Volume_Projection` (imported by the previous cell from
    `motion_estimation._3D.project_opticalflow3d`) — confirm which projector
    is intended.
    """

    def __init__(self, logging_level=logging.INFO):
        """Initialise the bases, print the progress-table header and start the
        daemon thread that prints one table row per denoising iteration.

        Args:
            logging_level: passed to both base-class constructors. At INFO or
                lower (e.g. DEBUG), `shake_volume`,
                `project_volume_reference_to_target` and `filter_volume` print
                their name and arguments on every call.
        """
        _3D_OF_Estimation.__init__(self, logging_level)
        Volume_Projection.__init__(self, logging_level)
        self.logging_level = logging_level

        if self.logging_level <= logging.INFO:
            print(f"\nFunction: {inspect.currentframe().f_code.co_name}")
            # NOTE(review): `max`/`min` are initialised but never read in this class.
            self.max = 0
            self.min = 0

        # Header of the table printed by `show_log` (printed at every logging level).
        print(f"{'iter':>5s}", end='')
        print(f"{'min_shaking':>15s}", end='')
        print(f"{'max_shaking':>15s}", end='')
        print(f"{'min_flow':>15s}", end='')
        print(f"{'avg_abs_flow':>15s}", end='')
        print(f"{'max_flow':>15s}", end='')
        print(f"{'time':>15s}", end='')
        print()

        # `filter_volume` sets `stop_event` after each iteration; the daemon
        # thread then prints one row and clears it again.
        self.stop_event = threading.Event()
        self.logger_daemon = threading.Thread(target=self.show_log)
        self.logger_daemon.daemon = True
        self.time_0 = time.perf_counter()
        self.logger_daemon.start()

    @staticmethod
    def _print_call_args(frame):
        """Print the function name and arguments of the call running in `frame`.

        NumPy arrays are summarised as shape, min, mean and max; any other
        argument is printed as-is. Replaces the identical inline blocks that
        each method used to carry.
        """
        print(f"\nFunction: {frame.f_code.co_name}")
        arg_names, _, _, arg_values = inspect.getargvalues(frame)
        for name in arg_names:
            value = arg_values[name]
            if isinstance(value, np.ndarray):
                print(f"{name}.shape: {value.shape}", end=' ')
                print(f"{np.min(value)} {np.average(value)} {np.max(value)}")
            else:
                print(f"{name}: {value}")

    def show_log(self):
        """Print one progress row each time `stop_event` is set (thread body).

        `Event.wait()` has no timeout here, so it blocks until the event is set
        and then returns True. The loop therefore never ends on its own; the
        thread is a daemon.
        NOTE(review): the main thread continues right after `set()`, so
        `iter`, `displacements` and `flow` may already belong to the next
        iteration when the row is printed.
        """
        while self.stop_event.wait():
            time_1 = time.perf_counter()
            running_time = time_1 - self.time_0
            print(f"{self.iter:>5d}", end='')
            print(f"{np.min(self.displacements):>15.2f}", end='')
            print(f"{np.max(self.displacements):>15.2f}", end='')
            print(f"{np.min(self.flow):>15.2f}", end='')
            print(f"{np.average(np.abs(self.flow)):>15.2f}", end='')
            print(f"{np.max(self.flow):>15.2f}", end='')
            print(f"{running_time:>15.2f}", end='')
            print()
            self.stop_event.clear()
            self.time_0 = time.perf_counter()

    def shake_vector(self, x, mean=0.0, std_dev=1.0):
        """Jitter the positions of the entries of `x` with Gaussian noise.

        Returns an array of shape (len(x), 2). Column 0 is each entry's index
        plus N(mean, std_dev) noise; column 1 is the entry of `x` itself. The
        noise is also kept in `self.displacements` for `show_log`.
        """
        y = np.arange(len(x))
        self.displacements = np.random.normal(mean, std_dev, len(x))
        return np.stack((y + self.displacements, x), axis=1)

    def shake_volume(self, volume, mean=0.0, std_dev=1.0):
        """Return a new array containing `volume` randomly shaken along Z, Y and X.

        For every 1-D line along an axis, the line's indices are perturbed with
        Gaussian noise (`shake_vector`), truncated to int16 and sorted. The line
        is then reordered by that sort. The axes are processed in turn, each
        pass reading the result of the previous one. `volume` itself is not
        modified.
        """
        if self.logging_level <= logging.INFO:
            self._print_call_args(inspect.currentframe())

        shaked_volume = np.empty_like(volume)

        # Shaking in Z
        values = np.arange(volume.shape[0]).astype(np.int16)
        for y in range(volume.shape[1]):
            for x in range(volume.shape[2]):
                pairs = self.shake_vector(x=values, mean=mean, std_dev=std_dev).astype(np.int16)
                # argsort's default kind is not stable: equal truncated keys
                # end up in an unspecified order.
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[values, y, x] = volume[pairs[:, 1], y, x]
        # From here on source and destination are the same buffer. This is safe
        # because the fancy-indexed right-hand side copies the whole line
        # before it is written back.
        volume = shaked_volume

        # Shaking in Y
        values = np.arange(volume.shape[1]).astype(np.int16)
        for z in range(volume.shape[0]):
            for x in range(volume.shape[2]):
                pairs = self.shake_vector(values, mean=mean, std_dev=std_dev).astype(np.int16)
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[z, values, x] = volume[z, pairs[:, 1], x]
        volume = shaked_volume

        # Shaking in X
        values = np.arange(volume.shape[2]).astype(np.int16)
        for z in range(volume.shape[0]):
            for y in range(volume.shape[1]):
                pairs = self.shake_vector(values, mean=mean, std_dev=std_dev).astype(np.int16)
                pairs = pairs[pairs[:, 0].argsort()]
                shaked_volume[z, y, values] = volume[z, y, pairs[:, 1]]

        return shaked_volume

    def project_volume_reference_to_target(self, reference, target, pyramid_levels, window_side, iterations, N_poly, block_size, overlap, threads_per_block):
        """Estimate the flow between `target` and `reference` and warp `reference`.

        The flow is stored in `self.flow` (read by `show_log`). Returns
        `remap(reference, flow)`.

        NOTE(review): `block_size`, `overlap` and `threads_per_block` are
        accepted (and logged) but NOT forwarded to `pyramid_get_flow`, unlike
        the `_old` variant, so they currently have no effect.
        """
        if self.logging_level <= logging.INFO:
            self._print_call_args(inspect.currentframe())

        self.flow = self.pyramid_get_flow(
            target=target,
            reference=reference,
            flow=None,
            pyramid_levels=pyramid_levels,
            window_side=window_side,
            iterations=iterations,
            N_poly=N_poly)
        projection = self.remap(reference, self.flow)
        return projection

    def filter_volume(
        self,
        noisy_volume,
        N_iters=25,
        mean=0.0,
        std_dev=1.0,
        pyramid_levels=PYRAMID_LEVELS,
        window_side=WINDOW_SIDE,
        iterations=ITERATIONS,
        N_poly=N_POLY,
        block_size=(256, 256, 256),
        overlap=(8, 8, 8),
        threads_per_block=(8, 8, 8)
    ):
        """Denoise `noisy_volume` by averaging flow-compensated random shakings.

        A float32 accumulator starts as the noisy volume. In iteration i the
        current estimate is acc / (i + 1). The noisy volume is shaken, the
        estimate (reference) is projected against the shaken volume (target),
        the projection is added to the accumulator, and the logger thread is
        signalled. Returns acc / (N_iters + 1).

        NOTE(review): what is accumulated is `remap(denoised_volume, flow)`,
        i.e. a warp of the current estimate rather than of the shaken noisy
        volume. The name `shaked_and_compensated_noisy_volume` suggests the
        reference/target roles may be swapped — confirm.
        """
        if self.logging_level <= logging.INFO:
            self._print_call_args(inspect.currentframe())

        acc_volume = np.zeros_like(noisy_volume, dtype=np.float32)
        acc_volume[...] = noisy_volume
        for i in range(N_iters):
            self.iter = i
            denoised_volume = acc_volume/(i+1)
            shaked_noisy_volume = self.shake_volume(noisy_volume, mean=mean, std_dev=std_dev)
            shaked_and_compensated_noisy_volume = self.project_volume_reference_to_target(
                reference=denoised_volume,
                target=shaked_noisy_volume,
                pyramid_levels=pyramid_levels,
                window_side=window_side,
                iterations=iterations,
                N_poly=N_poly,
                block_size=block_size,
                overlap=overlap,
                threads_per_block=threads_per_block)
            acc_volume += shaked_and_compensated_noisy_volume
            # Wake the logger thread so it prints this iteration's row.
            self.stop_event.set()
        # N_iters projections plus the initial noisy volume.
        denoised_volume = acc_volume/(N_iters + 1)

        return denoised_volume

In [None]:
#farneback = opticalflow3D.Farneback3D(iters=5, num_levels=3, scale=0.5, spatial_size=7, presmoothing=3, filter_type="gaussian", filter_size=7,); RS_sigma = 1.0
#farneback = opticalflow3D.Farneback3D(iters=5, num_levels=2, scale=0.5, spatial_size=5, sigma_k=1.0, filter_type="gaussian", filter_size=9, presmoothing=None, device_id=0); RS_sigma = 1.25; N_iters=100
# Create the denoiser with DEBUG verbosity. DEBUG is below INFO, so every
# shaking/projection/filter call prints its arguments.
denoiser = Random_Shaking_Denoising(logging.DEBUG)

In [None]:
# Standard deviation of the random shaking and number of denoising iterations.
RS_sigma, N_iters = 3.0, 2
denoised = denoiser.filter_volume(noisy, N_iters=N_iters, std_dev=RS_sigma)

In [None]:
# Compare slice 25 (first axis) of the noisy and denoised volumes. The quality
# index is computed on the same slice that is displayed; previously it was
# computed on slice 16 while slice 25 was shown.
slice_idx = 25
noisy_slice = noisy[slice_idx]
denoised_slice = denoised[slice_idx]
fig, axs = plt.subplots(1, 2, figsize=(16, 32))
axs[0].imshow(noisy_slice, cmap="gray")
axs[0].set_title("Noisy")
axs[1].imshow(denoised_slice, cmap="gray")
dqi = information_theory.information.compute_quality_index(noisy_slice, denoised_slice)
axs[1].set_title(f"Denoised (DQI={dqi})")
fig.tight_layout()
plt.show()

In [None]:
# Compare a zoomed region (rows and columns from 300 on) of slice 1.
# The quality index is computed on exactly the region shown; previously it
# used slice 16, and the title was missing its closing parenthesis.
slice_idx = 1
noisy_roi = noisy[slice_idx][300:, 300:]
denoised_roi = denoised[slice_idx][300:, 300:]
fig, axs = plt.subplots(1, 2, figsize=(16, 32))
axs[0].imshow(noisy_roi, cmap="gray")
axs[0].set_title("Noisy")
axs[1].imshow(denoised_roi, cmap="gray")
dqi = information_theory.information.compute_quality_index(noisy_roi, denoised_roi)
axs[1].set_title(f"Denoised (DQI={dqi})")
fig.tight_layout()
plt.show()

In [None]:
# Compare the section at index 100 of the second axis in both volumes, with its
# quality index in the title.
noisy_section = noisy[:, 100]
denoised_section = denoised[:, 100]
fig, axs = plt.subplots(1, 2, figsize=(16, 32))
axs[0].imshow(noisy_section, cmap="gray")
axs[0].set_title("Noisy")
axs[1].imshow(denoised_section, cmap="gray")
axs[1].set_title(
    f"Denoised (DQI={information_theory.information.compute_quality_index(noisy_section, denoised_section)})"
)
fig.tight_layout()
plt.show()

In [None]:
# Slice 7 with rows flipped: original, denoised, and their difference.
original_img = noisy[7][::-1, :]
denoised_img = denoised[7][::-1, :]
figure(figsize=(32, 32))
plt.subplot(1, 3, 1)
plt.title("original")
plt.imshow(original_img, cmap="gray")
plt.subplot(1, 3, 2)
# Raw string: "\s" and "\m" are invalid escape sequences in a normal string
# literal (SyntaxWarning since Python 3.12). The resulting text is unchanged.
plt.title(r"$\sigma_\mathrm{RS}=$" + f"{RS_sigma}")
plt.imshow(denoised_img, cmap="gray")
plt.subplot(1, 3, 3)
plt.title("difference")
plt.imshow(original_img - denoised_img, cmap="gray")

In [None]:
# Save the denoised volume as float32 MRC. The file name records the shaking
# sigma and the number of iterations, and is echoed as the cell output.
output_filename = f"{args.output}_{RS_sigma}_{N_iters}.mrc"
with mrcfile.new(output_filename, overwrite=True) as mrc:
    mrc.set_data(denoised.astype(np.float32))
output_filename

In [None]:
# Wait here until the user presses Enter.
# NOTE(review): presumably a barrier so "Run all" stops before the legacy cells
# below. Those cells use names that are not defined or imported in the visible
# notebook (opticalflow3D, RSIVD, noisy_vol, block_size, vol_name, skimage).
input()

In [None]:
# Legacy opticalflow3D Farneback estimator configuration.
# NOTE(review): `opticalflow3D` is pip-installed in the setup cell but never
# imported in the visible notebook, so this cell raises NameError as written.
farneback = opticalflow3D.Farneback3D(
    iters=5,
    num_levels=3,
    scale=0.5,
    spatial_size=5,
    presmoothing=4,
    filter_type="box",
    filter_size=5,
)

In [None]:
# Legacy run with shaking sigma 1.0 and N_iters=25.
# NOTE(review): RSIVD (import commented out), block_size and noisy_vol are not
# defined in the visible notebook.
RS_sigma = 1.0
denoised_vol = RSIVD.filter(farneback, block_size, noisy_vol, N_iters=25, RS_sigma=RS_sigma)

In [None]:
# Legacy comparison of slice 75 with rows flipped: original, denoised, difference.
original_img = noisy_vol[75][::-1, :]
denoised_img = denoised_vol[75][::-1, :]
figure(figsize=(32, 32))
plt.subplot(1, 3, 1)
plt.title("original")
plt.imshow(original_img, cmap="gray")
plt.subplot(1, 3, 2)
# Raw string: "\s" and "\m" are invalid escape sequences in a normal string
# literal (SyntaxWarning since Python 3.12). The resulting text is unchanged.
plt.title(r"$\sigma_\mathrm{RS}=$" + f"{RS_sigma}")
plt.imshow(denoised_img, cmap="gray")
plt.subplot(1, 3, 3)
plt.title("difference")
plt.imshow(original_img - denoised_img, cmap="gray")

In [None]:
# Legacy export of the denoised volume as an ImageJ TIFF.
# NOTE(review): `skimage` is only imported in a commented-out cell and
# `vol_name` is never defined in the visible notebook.
skimage.io.imsave(
    f"{vol_name}_denoised_{RS_sigma}.tif",
    denoised_vol,
    imagej=True,
)

In [None]:
# Legacy run with shaking sigma 2.0 and N_iters=25.
# NOTE(review): RSIVD, block_size and noisy_vol are not defined in the visible notebook.
RS_sigma = 2.0
denoised_vol = RSIVD.filter(farneback, block_size, noisy_vol, N_iters=25, RS_sigma=RS_sigma)

In [None]:
# Legacy comparison of slice 75 with rows flipped: original, denoised, difference.
original_img = noisy_vol[75][::-1, :]
denoised_img = denoised_vol[75][::-1, :]
figure(figsize=(32, 32))
plt.subplot(1, 3, 1)
plt.title("original")
plt.imshow(original_img, cmap="gray")
plt.subplot(1, 3, 2)
# Raw string: "\s" and "\m" are invalid escape sequences in a normal string
# literal (SyntaxWarning since Python 3.12). The resulting text is unchanged.
plt.title(r"$\sigma_\mathrm{RS}=$" + f"{RS_sigma}")
plt.imshow(denoised_img, cmap="gray")
plt.subplot(1, 3, 3)
plt.title("difference")
plt.imshow(original_img - denoised_img, cmap="gray")