In [None]:
import os
import re
import sys
import torch
import scipy.io
import mat73
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch.nn as nn
import torch.fft as fft
import torch.nn.functional as F
import cv2 as cv
from unet import UNet
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib inline

In [None]:
# Record which PyTorch build this notebook is being run with.
print(f"tested on Torch version: {torch.__version__}")

### Load PSFs

In [None]:
# Read the PSF stack with mat73, keep the last 25 entries along the 4th axis and
# take every second sample along the first two axes.
data_dict_psf = mat73.loadmat('data/psfs_save_magfs.mat')
psfs = data_dict_psf['psfs'][::2, ::2, :, -25:]


### Results on real captured data

In [None]:
# Run the pretrained 2D U-Net on a real captured measurement and show the
# measurement, the predicted all-in-focus (AIF) image and the predicted depth map.
# Uses `data_dict_psf` from the PSF-loading cell above.

# Fall back to CPU so the cell can still run on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

N_DEPTHS = 25                     # must match the `-25:` slice used when loading the PSFs
AIF_MASK_THRESHOLD = 0.15         # depth is hidden where the normalized AIF green channel is below this
DEPTH_MIN, DEPTH_MAX = 3.6, 20.0  # fixed clipping / colour range for the depth map

drng = torch.from_numpy(data_dict_psf['drng'][-N_DEPTHS:])
model = UNet(in_channels=N_DEPTHS * 3,
             out_channels=4,  # channels 0-2: AIF image; channel 3: inverted below to get depth
             in_layer='filter',
             device=device,
             batch_size=7,
             n_blocks=5,
             start_filts=64,
             attention=True,
             activation=nn.ELU(),
             normalization='batch',
             conv_mode='same',
             out_layer='linear',
             dim=2).to(device)

real_data_path = 'data/avgCap30_1.mat'
# mat73 reads MATLAB v7.3 files; data/avgCap10_4.mat has to be read with
# scipy.io.loadmat(real_data_path) instead.
data_dict = mat73.loadmat(real_data_path)

PATH = "checkpoint/model_2dunet_50dB_reg_48.pt"
# torch.load unpickles the file: only load checkpoints from a trusted source
# (recent PyTorch versions also accept weights_only=True to restrict this).
state_dict = torch.load(PATH, map_location='cpu')
# load_state_dict copies into the parameters already living on `device`.
model.load_state_dict(state_dict)
model.eval()

img = data_dict['avgCap']
# Downsample 2x and keep channels 0, 1 and the last one as the three input channels.
# The shape comes from the data (no hard-coded 456x684); float64 makes the in-place
# normalisation below valid even if the capture is integer-typed.
measurement = img[::2, ::2][..., [0, 1, -1]].astype(np.float64)
peak = measurement.max()
if peak > 0:
    measurement /= peak
# (1, 3, H, W) float32 tensor scaled to [0, 255].
x = (torch.from_numpy(255 * measurement)
     .permute(2, 0, 1)
     .unsqueeze(0)
     .to(device=device, dtype=torch.float32))

# One forward pass; both outputs are sliced from the same prediction.
with torch.no_grad():
    pred = model(x)[0].cpu().numpy()

# AIF image rescaled to [0, 1] for display (guarded against a constant output).
out_aif = pred[:3] - pred[:3].min()
aif_range = out_aif.max()
if aif_range > 0:
    out_aif /= aif_range

# Channel 3 is inverted to obtain depth; a zero prediction yields +/-inf,
# which np.clip then maps onto [DEPTH_MIN, DEPTH_MAX].
with np.errstate(divide='ignore'):
    out_d = 1.0 / pred[3]
out_d[out_aif[1] < AIF_MASK_THRESHOLD] = np.nan  # hide depth in dark AIF regions
out_d = np.clip(out_d, DEPTH_MIN, DEPTH_MAX)

fig, ax = plt.subplots(1, 3, gridspec_kw={'width_ratios': [1.5, 1, 1.08]}, figsize=(10, 5))
fig.tight_layout()

ax[0].imshow(measurement)
ax[0].set_title('measurement')
ax[1].imshow(out_aif.transpose(1, 2, 0))
ax[1].set_title('aif')
divider = make_axes_locatable(ax[2])
cax = divider.append_axes('right', size='5%', pad=0.05)
# vmin/vmax pin the colour scale without overwriting any depth pixels.
c = ax[2].imshow(out_d, cmap='jet', vmin=DEPTH_MIN, vmax=DEPTH_MAX)
ax[2].set_title('depth')
fig.colorbar(c, cax=cax)
plt.show()