In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import backend as K
from tqdm.notebook import tqdm

from fastmri_recon.models.functional_models.unet import unet
from fastmri_recon.models.subclassed_models.denoisers.dncnn import DnCNN
from fastmri_recon.models.subclassed_models.denoisers.focnet import FocNet, DEFAULT_COMMUNICATION_BETWEEN_SCALES
from fastmri_recon.models.subclassed_models.denoisers.focnet import DEFAULT_N_CONVS_PER_SCALE as default_n_convs_focnet
from fastmri_recon.models.subclassed_models.denoisers.mwcnn import MWCNN, DEFAULT_N_FILTERS_PER_SCALE
from fastmri_recon.models.subclassed_models.denoisers.mwcnn import DEFAULT_N_CONVS_PER_SCALE as default_n_convs_mwcnn
from fastmri_recon.models.subclassed_models.xpdnet import XPDNet
from fastmri_recon.models.training.compile import default_model_compile

In [2]:
# Forwarded to XPDNet as `n_primal`; the denoisers below are built with
# 2 * (n_primal + 1) input channels and 2 * n_primal output channels.
n_primal = 5
# When True, each denoiser is additionally trained for one epoch inside an
# XPDNet on a dummy example, to check whether it fits in memory.
test_memory_fit = False
# When True, each (model, size, parameter count) triple is recorded in `df_params`.
write_to_csv = True

# Create the table unconditionally: the summary cell at the end of the notebook
# displays `df_params` regardless of `write_to_csv`, and would otherwise raise
# NameError when recording is disabled (it now shows an empty table instead).
df_params = pd.DataFrame(columns=['model_name', 'model_size', 'n_params'])

In [3]:
def test_works_in_xpdnet_train(model, n_scales, res):
    """Check that `model` can be trained as the denoiser of an XPDNet.

    Wraps `model` in a single-coil XPDNet with 10 iterations (using the
    notebook-level `n_primal`), compiles it with an MAE loss at lr=1e-3, and
    fits it for one epoch on a single all-zero example. Nothing is returned;
    callers use a raised exception as the signal that training did not work
    (e.g. did not fit in memory).

    Args:
        model: denoiser network to embed in the XPDNet.
        n_scales: forwarded to XPDNet as `n_scales`.
        res: forwarded to XPDNet as `res`.
    """
    xpdnet = XPDNet(
        model,
        n_primal=n_primal,
        multicoil=False,
        n_scales=n_scales,
        n_iter=10,
        refine_smaps=False,
        res=res,
    )
    default_model_compile(xpdnet, lr=1e-3, loss='mae')
    # Two complex64 inputs and a real-valued target; presumably (k-space, mask)
    # -> image for single-coil XPDNet — TODO confirm against XPDNet's signature.
    first_input = tf.zeros([1, 640, 640, 1], dtype=tf.complex64)
    second_input = tf.zeros([1, 640, 640], dtype=tf.complex64)
    target = tf.zeros([1, 320, 320, 1])
    xpdnet.fit(x=[first_input, second_input], y=target, epochs=1)

In [4]:
# DnCNN: count trainable parameters for each configuration.
params = {}
params['big'] = dict(
    n_convs=20,
    n_filters=64,
)
params['medium'] = dict(
    n_convs=10,
    n_filters=32,
)
params['small'] = dict(
    n_convs=5,
    n_filters=16,
)

for param_name, param_values in tqdm(params.items(), 'Dncnn'):
    print('DnCNN', param_name)
    model = DnCNN(n_outputs=2*n_primal, res=False, **param_values)
    # Dummy forward pass so the subclassed model creates its weights.
    model(tf.zeros([1, 32, 32, 2*(n_primal + 1)]))
    # np.sum can return a float (e.g. np.sum([]) == 0.0); cast so counts are ints.
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    print(trainable_count)
    if test_memory_fit:
        try:
            test_works_in_xpdnet_train(model, n_scales=0, res=True)
        except tf.errors.ResourceExhaustedError:
            # TensorFlow's out-of-memory error.
            print('Does not fit in memory for xpdnet')
        except Exception as e:
            # Any other failure is a real problem, not an OOM: report it.
            print(f'XPDNet training test failed: {e!r}')
    if write_to_csv:
        # DataFrame.append was removed in pandas 2.0; add the row by enlargement.
        df_params.loc[len(df_params)] = dict(
            model_name='DnCNN',
            model_size=param_name,
            n_params=trainable_count,
        )

HBox(children=(FloatProgress(value=0.0, description='Dncnn', max=1.0, style=ProgressStyle(description_width='i…

DnCNN small
10154



In [5]:
# U-net: count trainable parameters for each configuration.
params = {}
# Larger configurations, currently disabled:
# params['big'] = dict(
#     n_layers=4,
#     layers_n_channels=[32, 64, 128, 256],
#     layers_n_non_lins=2,
# )
# params['medium'] = dict(
#     n_layers=4,
#     layers_n_channels=[16, 32, 64, 128],
#     layers_n_non_lins=2,
# )
params['small'] = dict(
    n_layers=3,
    layers_n_channels=[16, 32, 64],
    layers_n_non_lins=1,
)
for param_name, param_values in tqdm(params.items(), 'unet'):
    print('U-net', param_name)
    model = unet(
        input_size=(None, None, 2*(n_primal + 1)),
        compile=False,
        res=False,
        n_output_channels=2*n_primal,
        **param_values,
    )
    model(tf.zeros([1, 32, 32, 2*(n_primal + 1)]))
    # np.sum can return a float (e.g. np.sum([]) == 0.0); cast so counts are ints.
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    print(trainable_count)
    if test_memory_fit:
        try:
            test_works_in_xpdnet_train(model, n_scales=param_values['n_layers'], res=True)
        except tf.errors.ResourceExhaustedError:
            # TensorFlow's out-of-memory error.
            print('Does not fit in memory for xpdnet')
        except Exception as e:
            # Any other failure is a real problem, not an OOM: report it.
            print(f'XPDNet training test failed: {e!r}')
    if write_to_csv:
        # DataFrame.append was removed in pandas 2.0; add the row by enlargement.
        df_params.loc[len(df_params)] = dict(
            model_name='U-net',
            model_size=param_name,
            n_params=trainable_count,
        )

HBox(children=(FloatProgress(value=0.0, description='unet', max=1.0, style=ProgressStyle(description_width='in…

U-net small
58536



In [6]:
# MWCNN: count trainable parameters for each configuration.
params = {}
# Larger configurations, currently disabled:
# params['big'] = dict(
#     n_scales=3,
#     n_filters_per_scale=DEFAULT_N_FILTERS_PER_SCALE,
#     n_convs_per_scale=default_n_convs_mwcnn,
#     n_first_convs=3,
#     first_conv_n_filters=64,
# )
# params['medium'] = dict(
#     n_scales=3,
#     n_filters_per_scale=[64, 128, 256],
#     n_convs_per_scale=default_n_convs_mwcnn,
#     n_first_convs=2,
#     first_conv_n_filters=32,
# )
params['small'] = dict(
    n_scales=2,
    n_filters_per_scale=[32, 64],
    n_convs_per_scale=[2, 2],
    n_first_convs=2,
    first_conv_n_filters=32,
)
for param_name, param_values in tqdm(params.items(), 'mwcnn'):
    print('MWCNN', param_name)
    model = MWCNN(res=False, n_outputs=2*n_primal, **param_values)
    # Dummy forward pass so the subclassed model creates its weights.
    model(tf.zeros([1, 32, 32, 2*(n_primal + 1)]))
    # np.sum can return a float (e.g. np.sum([]) == 0.0); cast so counts are ints.
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    print(trainable_count)
    if test_memory_fit:
        try:
            test_works_in_xpdnet_train(model, n_scales=param_values['n_scales'], res=True)
        except tf.errors.ResourceExhaustedError:
            # TensorFlow's out-of-memory error.
            print('Does not fit in memory for xpdnet')
        except Exception as e:
            # Any other failure is a real problem, not an OOM: report it.
            print(f'XPDNet training test failed: {e!r}')
    if write_to_csv:
        # DataFrame.append was removed in pandas 2.0; add the row by enlargement.
        df_params.loc[len(df_params)] = dict(
            model_name='MWCNN',
            model_size=param_name,
            n_params=trainable_count,
        )

HBox(children=(FloatProgress(value=0.0, description='mwcnn', max=1.0, style=ProgressStyle(description_width='i…

MWCNN small
338122



In [7]:
# FocNet: count trainable parameters for each configuration.
params = {}
# Larger configurations, currently disabled:
# params['big'] = dict(
#     n_scales=4,
#     n_filters=128,
#     n_convs_per_scale=default_n_convs_focnet,
#     communications_between_scales=DEFAULT_COMMUNICATION_BETWEEN_SCALES,
# )
# params['medium'] = dict(
#     n_scales=4,
#     n_filters=64,
#     n_convs_per_scale=default_n_convs_focnet,
#     communications_between_scales=DEFAULT_COMMUNICATION_BETWEEN_SCALES,
# )
params['small'] = dict(
    n_scales=3,
    n_filters=32,
    # Drop the last scale's settings to match n_scales=3.
    n_convs_per_scale=default_n_convs_focnet[:-1],
    communications_between_scales=DEFAULT_COMMUNICATION_BETWEEN_SCALES[:-1],
)
for param_name, param_values in tqdm(params.items(), 'focnet'):
    print('FocNet', param_name)
    model = FocNet(n_outputs=2*n_primal, **param_values)
    # Dummy forward pass so the subclassed model creates its weights.
    model(tf.zeros([1, 32, 32, 2*(n_primal + 1)]))
    # The summed count came out as a float for FocNet (printed as 457402.0);
    # cast so `n_params` is an int like the other models.
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    print(trainable_count)
    if test_memory_fit:
        try:
            test_works_in_xpdnet_train(model, n_scales=param_values['n_scales'], res=False)
        except tf.errors.ResourceExhaustedError:
            # TensorFlow's out-of-memory error.
            print('Does not fit in memory for xpdnet')
        except Exception as e:
            # Any other failure is a real problem, not an OOM: report it.
            print(f'XPDNet training test failed: {e!r}')
    if write_to_csv:
        # DataFrame.append was removed in pandas 2.0; add the row by enlargement.
        df_params.loc[len(df_params)] = dict(
            model_name='FocNet',
            model_size=param_name,
            n_params=trainable_count,
        )

HBox(children=(FloatProgress(value=0.0, description='focnet', max=1.0, style=ProgressStyle(description_width='…

FocNet small
457402.0



In [8]:
# Summary of the trainable-parameter counts recorded by the model cells above
# (one row per model name and size).
df_params

Unnamed: 0,model_name,model_size,n_params
0,DnCNN,small,10154
1,U-net,small,58536
2,MWCNN,small,338122
3,FocNet,small,457402
