In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import backend as K

from fastmri_recon.models.subclassed_models.denoisers.proposed_params import get_models
from fastmri_recon.models.subclassed_models.xpdnet import XPDNet
from fastmri_recon.models.training.compile import default_model_compile

In [2]:
# Run configuration.
n_primal = 5  # passed to get_models(n_primal) and to XPDNet(n_primal=...)
test_memory_fit = False  # if True, also try one epoch of XPDNet training per denoiser
write_to_csv = True  # if True, parameter counts are recorded in `df_params` for export

# Always define the results table (not only when `write_to_csv` is True) so
# that later cells which display it do not fail with a NameError.
df_params = pd.DataFrame(columns=['model_name', 'model_size', 'n_params'])

In [3]:
def test_works_in_xpdnet_train(model, n_scales, res):
    """Wrap a denoiser in an XPDNet and run one epoch of ``fit`` on zeros.

    The XPDNet is built with ``multicoil=False``, ``n_iter=10`` and
    ``refine_smaps=False``. It is compiled via ``default_model_compile`` with
    ``lr=1e-3`` and an ``'mae'`` loss, then fitted for a single epoch on
    all-zero tensors. The notebook calls this inside a
    ``try/except tf.errors.ResourceExhaustedError`` to check whether a
    configuration fits in memory.

    Args:
        model: denoiser as yielded by ``get_models``.
        n_scales: forwarded to ``XPDNet(n_scales=...)``.
        res: forwarded to ``XPDNet(res=...)``.

    Note:
        ``n_primal`` is read from the notebook's global configuration cell.
    """
    xpdnet_kwargs = dict(
        n_primal=n_primal,
        multicoil=False,
        n_scales=n_scales,
        n_iter=10,
        refine_smaps=False,
        res=res,
    )
    xpdnet = XPDNet(model, **xpdnet_kwargs)
    default_model_compile(xpdnet, lr=1e-3, loss='mae')

    # NOTE(review): the two inputs are presumably k-space (1, 640, 640, 1) and
    # a mask (1, 640, 640), with a 320x320 target image. Confirm against
    # XPDNet's expected input signature.
    first_input = tf.zeros([1, 640, 640, 1], dtype=tf.complex64)
    second_input = tf.zeros([1, 640, 640], dtype=tf.complex64)
    target = tf.zeros([1, 320, 320, 1])
    xpdnet.fit(x=[first_input, second_input], y=target, epochs=1)

In [4]:
# For every denoiser configuration yielded by get_models, count its trainable
# parameters and optionally check that it trains inside an XPDNet without
# exhausting memory.
# Rows are gathered in a plain list and turned into a DataFrame once, for three
# reasons:
# - `DataFrame.append` was deprecated in pandas 1.4 and removed in 2.0.
# - Growing a frame row by row is quadratic.
# - Rebuilding `df_params` from scratch makes this cell safe to re-run
#   (no duplicated rows).
param_records = []
for model_name, model_size, model, n_scales, res in get_models(n_primal):
    # np.sum may return a float (e.g. 0.0 for an empty list); cast so counts are ints.
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    print(trainable_count)
    if test_memory_fit:
        try:
            test_works_in_xpdnet_train(model, n_scales=n_scales, res=res)
        except tf.errors.ResourceExhaustedError:
            print('Does not fit in memory for xpdnet')
    if write_to_csv:
        param_records.append(dict(
            model_name=model_name,
            model_size=model_size,
            n_params=trainable_count,
        ))
# Always defined, so display cells work even when `write_to_csv` is False.
df_params = pd.DataFrame(param_records, columns=['model_name', 'model_size', 'n_params'])

Models:   0%|          | 0/4 [00:00<?, ?it/s]
DnCNN:   0%|          | 0/3 [00:00<?, ?it/s][A

DnCNN big



DnCNN: 100%|██████████| 3/3 [00:00<00:00, 10.73it/s][A
Models:  25%|██▌       | 1/4 [00:00<00:00,  3.55it/s]
FocNet:   0%|          | 0/2 [00:00<?, ?it/s][A

677450
DnCNN medium
80362
DnCNN small
10154
FocNet medium



FocNet:  50%|█████     | 1/2 [00:00<00:00,  4.14it/s][A
FocNet: 100%|██████████| 2/2 [00:00<00:00,  4.89it/s][A
Models:  50%|█████     | 2/4 [00:00<00:00,  3.12it/s]
MWCNN:   0%|          | 0/3 [00:00<?, ?it/s][A

621698.0
FocNet small
455674.0
MWCNN big



MWCNN:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s][A
MWCNN:  67%|██████▋   | 2/3 [00:00<00:00,  5.09it/s][A

24932746
MWCNN medium
6217930
MWCNN small


MWCNN: 100%|██████████| 3/3 [00:00<00:00,  6.63it/s]
Models:  75%|███████▌  | 3/4 [00:01<00:00,  2.77it/s]
U-net:   0%|          | 0/3 [00:00<?, ?it/s][A
U-net:  33%|███▎      | 1/3 [00:00<00:00,  5.21it/s][A

338122
U-net big
1928600
U-net medium



U-net: 100%|██████████| 3/3 [00:00<00:00,  6.61it/s][A
Models: 100%|██████████| 4/4 [00:01<00:00,  2.49it/s]

483592
U-net small
58536





In [5]:
# Display the collected parameter counts: one row per (model_name, model_size) yielded by get_models.
df_params

Unnamed: 0,model_name,model_size,n_params
0,DnCNN,big,677450
1,DnCNN,medium,80362
2,DnCNN,small,10154
3,FocNet,medium,621698
4,FocNet,small,455674
5,MWCNN,big,24932746
6,MWCNN,medium,6217930
7,MWCNN,small,338122
8,U-net,big,1928600
9,U-net,medium,483592


In [7]:
# Persist the parameter counts only when the config flag asks for it.
if write_to_csv:
    df_params.to_csv('n_params_model.csv')