# Noise2Void - 2D Example for SEM data

__Note:__ This notebook expects a trained model. It only works after you have run `01_training.ipynb`, and the folder paths configured below must point to where that notebook saved the model.

In [27]:
# We import all our dependencies.
from n2v.models import N2V
import numpy as np
from matplotlib import pyplot as plt
from tifffile import imread
from csbdeep.io import save_tiff_imagej_compatible
import os
import glob

## Load the Network

In [34]:
# Folders and file type used throughout this notebook.
# The trained model is loaded from `<output_folder>/model` (see `model_folder`), so
# output_folder presumably has to be the same folder the training notebook saved to.
input_folder = '/dodrio/scratch/projects/2024_300/training/n2v/'
# TO CHANGE: replace <YOUR_NAME>.
# NOTE(review): 'nv2' may be a typo for 'n2v' -- keep it identical to the training notebook.
output_folder = '/dodrio/scratch/projects/2024_300/<YOUR_NAME>/nv2'
model_folder = os.path.join(output_folder, 'model')
file_extension = 'tif'  # extension (without the dot) of the images to denoise

In [None]:
# A previously trained model is loaded by creating a new N2V-object without providing a 'config'.
# The model is looked up by `model_name` inside `model_folder`; the name presumably has to
# match the one used in 01_training.ipynb.
model_name = 'n2v_2D_sem'
model = N2V(config=None, name=model_name, basedir=model_folder)

In [None]:
# In case you do not want to load the weights that led to the lowest validation loss during
# training but the latest computed weights, you can execute the following line:

# model.load_weights('weights_last.h5')

## Prediction
Here we simply use the same data as during training and denoise it with our network. Each denoised image is saved as `<name>_denoised.tif` in `output_folder`.

In [None]:
# Denoise every input image, save it as '<name>_denoised.tif' in output_folder, and
# track the intensity range over the whole stack (used by the 16-bit conversion cell).
files = sorted(glob.glob(os.path.join(input_folder, '*.' + file_extension)))
if not files:
    # Fail early with a clear message instead of crashing later on img=None in the plot cell.
    raise FileNotFoundError(f"No '*.{file_extension}' files found in {input_folder}")

os.makedirs(output_folder, exist_ok=True)

# Tiling passed to model.predict; raise the tile counts if images do not fit in GPU memory.
n_tiles = (2, 1)

img = None
pred = None

# Start from +/-inf so any real value (including negative predictions) replaces them.
# np.minimum/np.maximum are used instead of the builtins min/max, which an earlier
# version of this cell shadowed with plain variables.
stack_min = np.inf
stack_max = -np.inf

for file in files:
    print(file)
    stem = os.path.splitext(os.path.basename(file))[0]
    img = imread(file)

    # Here we process the data.
    # The 'n_tiles' parameter can be used if images are too big for the GPU memory.
    # If we do not provide the 'n_tiles' parameter the system will automatically try to find an appropriate tiling.
    pred = model.predict(img, axes='YX', n_tiles=n_tiles)

    stack_min = np.minimum(stack_min, pred.min())
    stack_max = np.maximum(stack_max, pred.max())

    save_tiff_imagej_compatible(os.path.join(output_folder, stem + '_denoised.tif'), pred, 'YX')

In [None]:
# Convert the denoised float images to 16 bits, using ONE intensity range for the whole
# stack (stack_min/stack_max from the prediction cell) so relative brightness between
# images is preserved. Results go to '<output_folder>/16bits'.
print('Max:' + str(stack_max))
print('Min:' + str(stack_min))

rescale_folder = os.path.join(output_folder, '16bits')
os.makedirs(rescale_folder, exist_ok=True)

denoised_files = sorted(glob.glob(os.path.join(output_folder, '*.' + file_extension)))

# Largest value a uint16 can hold (2**16 - 1). Scaling to 2**16 would make the
# brightest pixel overflow and wrap around to 0.
uint16_max = np.iinfo(np.uint16).max
stack_range = stack_max - stack_min

for file in denoised_files:
    print(file)
    stem = os.path.splitext(os.path.basename(file))[0]
    # Separate name on purpose: `pred` from the prediction cell must stay paired
    # with `img` for the comparison plot below.
    denoised = imread(file)
    if stack_range > 0:
        scaled = (denoised - stack_min) / stack_range * uint16_max
    else:
        # Constant stack: there is no contrast to stretch.
        scaled = np.zeros_like(denoised)
    # Round, then clip so no value can wrap around when cast to uint16.
    pred_16 = np.clip(np.rint(scaled), 0, uint16_max).astype(np.uint16)
    save_tiff_imagej_compatible(os.path.join(rescale_folder, stem + '.tif'), pred_16, 'YX')

### Show results on data

In [None]:
# Let's look at the results.
# Side-by-side view of the top-left 500x500 region of the last processed image:
# raw input on the left, network output on the right.
crop = (slice(None, 500), slice(None, 500))
fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize=(16, 8))
for ax, image, title in ((ax_input, img, 'Input'), (ax_pred, pred, 'Prediction')):
    ax.imshow(image[crop], cmap="gray")
    ax.set_title(title);