# Inference (after training a model)
- Use this notebook to predict fluorescence (FL) images from brightfield (BF) images for a chosen organelle.
- Requirements: a trained model and a folder of BF patches.
- Set `main_path` to the root directory of the GitHub project (with a trailing "/").
- Choose the organelle to predict (`organelle`).
- Choose how many patches from the BF folder to predict (`Nimgs`).


In [4]:
main_path = "/.../" ## change to main directory of github project
organelle = 'NucEnv' # # NucEnv , Nuclioli , DNAmito , ER , AF , Mito , Membrane , Micro , TJ
Nimgs = 5 ## how many patches to load

In [5]:
import skimage
# from skimage import measure
from skimage.feature import hog
import sklearn
import scipy
import scipy.stats as stats
from scipy.spatial.distance import jensenshannon
from glrlm import GLRLM
import torch
from torch.optim import AdamW
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import requests
import cv2 
import random
import sys

sys.path.append(main_path)
from generative.metrics import FIDMetric, MMDMetric, MultiScaleSSIMMetric, SSIMMetric
from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet # Adapted from https://github.com/huggingface/diffusers
from generative.networks.schedulers.ddpm import DDPMScheduler
from generative.networks.schedulers import DDIMScheduler
from torch.cuda.amp import GradScaler, autocast

sys.path.append(main_path+'src')
from src.ProcessingFunctions import segmentation_pipeline, minmax_norm
from src.Params import SegmentationParams
from src.DisplayFunctions import display_images, volumetric2sequence
from src.LoadSaveFunctions import load_patches, LoadModel, save_patches
from src.PredictFL import predict_FL

In [None]:
# Initialize the compute device and the organelle-specific segmentation settings.
# Note: `device` is not referenced in the visible cells of this notebook.
device = torch.device("cuda")
seg_params = SegmentationParams(organelle)

In [None]:
## load BF patches
# load_patches yields (Nimgs, 16, 64, 64) volumes scaled to 0-1.
BF_images = torch.tensor(load_patches(main_path, organelle, 'BF', Nimgs), dtype=torch.float32)
# Add a channel axis, then move the 16-slice axis last: (Nimgs, 1, 64, 64, 16).
BF_images = BF_images[:, None].movedim(2, -1)


In [None]:
# load model
# NOTE(review): timesteps=1000 presumably must match the schedule length the
# model was trained with — confirm against LoadModel.
DiffModel = LoadModel(
    main_path,
    organelle,
    load_model=1,
    timesteps=1000,
)

In [None]:
### run predictions - outputs DDPM, DDPMavg, var and all intermediate predictions between 2 timesteps

t_low, t_high = 200 , 800
pred, pred_avg, pred_var, all_x0_pred = predict_FL(DiffModel, BF_images, t_low, t_high, seed=42) # (5, 1, 64, 64, 16) (5, 1, 64, 64, 16) (5, 1, 64, 64, 16) (5, 600, 1, 64, 64, 16)

### create std image and seg std image from the same inference process
# For each image, take the per-voxel std of the x0 predictions along the
# trajectory (axis 0), threshold it into a binary "uncertain" mask, and segment
# that mask.
# Results are collected in lists and stacked once. An earlier version concatenated
# onto a torch.randn_like placeholder row and sliced it off afterwards. That advanced
# the global torch RNG for no reason and relied on torch.cat bool->float promotion.
# NOTE(review): the seeds cell transposes its std volume to (16, 64, 64) before
# calling segmentation_pipeline, whereas here it is passed as (64, 64, 16) —
# confirm which axis order segmentation_pipeline expects.
trajectory_std_threshold = 0.06

std_list, std_seg_list = [], []
for x0_trajectory in all_x0_pred:  # one entry per predicted image
    std_ = x0_trajectory.std(axis=0)[0] > trajectory_std_threshold  # (64, 64, 16) bool
    std_seg_ = segmentation_pipeline(
        std_,
        filter_type=seg_params.filter_type,
        k1=seg_params.k1,
        k2=seg_params.k2,
        k3=seg_params.k3,
        filter_kernel=seg_params.filter_kernel,
        sigma=seg_params.sigma,
        organelle_th=seg_params.organelle_th,
        do_erode_dilate=seg_params.do_erode_dilate,
        do_remove_small_objects=seg_params.do_remove_small_objects,
        do_fill_holes=seg_params.do_fill_holes,
        do_fill_holes_boarders=seg_params.do_fill_holes_boarders,
    )[1]  # second output is the segmentation; the first (threshold) is unused
    std_list.append(np.asarray(std_))
    std_seg_list.append(np.asarray(std_seg_))

# Stack to (N, 1, 64, 64, 16) float32, matching the previous output layout/dtype.
stds = np.stack(std_list)[:, np.newaxis].astype(np.float32)
stds_segs = np.stack(std_seg_list)[:, np.newaxis].astype(np.float32)

In [None]:
### create std image and seg std image from different seeds
# Re-run inference with Nseeds different seeds and collect the per-seed
# predictions; their spread across seeds is used as an uncertainty estimate.

t_low, t_high = 200 , 800

# Initial placeholder rows (random values). Any code that concatenates onto
# these must drop row 0 afterwards.
stds_seeds = torch.randn_like(BF_images)[0:1]
stds_segs_seeds = torch.randn_like(BF_images)[0:1]

k_pred = []
k_pred_avg = []
Nseeds = 5
for seed in range(Nseeds):
    # Loop-local names on purpose. The previous version reassigned the globals
    # `pred`/`pred_avg`, so the patches saved as FL_pred/FLavg_pred came from the
    # last seed rather than from the seed=42 run whose trajectory produced `stds`.
    seed_pred, seed_pred_avg = predict_FL(DiffModel, BF_images, t_low, t_high, seed=seed)[:2]
    k_pred.append(seed_pred)
    k_pred_avg.append(seed_pred_avg)
k_pred = np.array(k_pred)         # (Nseeds, Nimgs, 1, 64, 64, 16) X0 at t=0
k_pred_avg = np.array(k_pred_avg) # (Nseeds, Nimgs, 1, 64, 64, 16) X0avg



In [None]:
# For each image, take the per-voxel std of the averaged prediction across seeds,
# threshold it into a binary mask, and segment it in (16, 64, 64) layout.
# Results are collected in lists and stacked once. This replaces concatenation
# onto the placeholder rows created in the previous cell, so the result no longer
# depends on those rows or on torch.cat bool->float promotion.
seed_std_threshold = 0.1

seed_std_list, seed_std_seg_list = [], []
for i in range(k_pred_avg.shape[1]):  # one iteration per image
    # (Nseeds, 1, 64, 64, 16) -> std over seeds -> (64, 64, 16) -> (16, 64, 64)
    std_ = k_pred_avg[:, i, ...].std(axis=0)[0].transpose(2, 0, 1) > seed_std_threshold
    std_seg_ = segmentation_pipeline(
        std_,
        filter_type=seg_params.filter_type,
        k1=seg_params.k1,
        k2=seg_params.k2,
        k3=seg_params.k3,
        filter_kernel=seg_params.filter_kernel,
        sigma=seg_params.sigma,
        organelle_th=seg_params.organelle_th,
        do_erode_dilate=seg_params.do_erode_dilate,
        do_remove_small_objects=seg_params.do_remove_small_objects,
        do_fill_holes=seg_params.do_fill_holes,
        do_fill_holes_boarders=seg_params.do_fill_holes_boarders,
    )[1]  # second output is the segmentation; the first (threshold) is unused
    # Move the 16-slice axis back to the end: (64, 64, 16).
    seed_std_list.append(np.transpose(np.asarray(std_), (1, 2, 0)))
    seed_std_seg_list.append(np.transpose(np.asarray(std_seg_), (1, 2, 0)))

stds_seeds = np.stack(seed_std_list)[:, np.newaxis].astype(np.float32)           # (Nimgs, 1, 64, 64, 16)
stds_segs_seeds = np.stack(seed_std_seg_list)[:, np.newaxis].astype(np.float32)  # (Nimgs, 1, 64, 64, 16)


In [None]:
### saving FL patches
# Write every result volume with save_patches, tagged by output type, in this order.
outputs_to_save = (
    (pred, 'FL_pred'),
    (pred_avg, 'FLavg_pred'),
    (stds, 'FLavg_std'),
    (stds_segs, 'FLavg_std_seg'),
    (stds_seeds, 'FLavg_std_seeds'),
    (stds_segs_seeds, 'FLavg_std_seg_seeds'),
)
for patches, tag in outputs_to_save:
    save_patches(main_path, organelle, patches, tag)
