In [None]:
'''
Generate the SLURM sbatch script (<job_name>.sh) for the Mayo 2D test jobs.

The script covers the ensemble, FBP, single L2 network and single WGAN runs,
on both the real quarter/full-dose sinograms and the simulated dose rates.
Each cell below appends command dicts to `cmds`, and the last cell writes them
out as batched `python3 ... &` lines. Run the cells in order from a fresh
kernel, because re-running a command cell appends its commands again.
'''

In [2]:
import copy
import glob
import os
import pathlib
import shlex
from datetime import datetime

import numpy as np

In [3]:
# Job-level settings: the SLURM job name, the GPU ids that commands are
# spread over, and the prefix of the per-slot log files.
job_name = 'mayo2d'
devices = ['2', '3']
nprocesses = len(devices)  # commands launched concurrently per batch
log_name = './outputs/mayo2d'

# SLURM directives written at the top of the generated .sh file.
slurm_header = '\n'.join([
    '#!/bin/bash',
    '#SBATCH --partition=defq',
    '#SBATCH --job-name=' + job_name,
    '#SBATCH --nodelist=gpu-node008',
    '#SBATCH --cpus-per-task=16',
    '#SBATCH --time=0',
]) + '\n'

# Every following cell appends its command dicts here. Restart from this cell
# to start over; re-running a later cell on its own appends duplicates.
cmds = []

In [4]:
# Location of the input sinograms and root directory of all test outputs.
input_dir = '/home/dwu/data/lowdoseCTsets/'
output_dir = '/home/dwu/trainData/deep_denoiser_ensemble/test/mayo_2d_3_layer_mean/'

# Patients and simulated dose rates to sweep.
# For a quick smoke test, narrow these to e.g. ['L291'] and [6].
prj_names = ['L291', 'L143', 'L067']
dose_rates = [2, 4, 6, 8]

# Arguments shared by every generated command. Key order matters: the
# command-line flags are emitted in the dict's iteration order.
base_args = dict(
    geometry='/home/dwu/trainData/deep_denoiser_ensemble/data/mayo/geometry.cfg',
    train_dir='/home/dwu/trainData/deep_denoiser_ensemble/train/mayo_2d_3_layer_mean/',
    nslice_mean='3',
    img_norm='0.019',
    islices=['0', '-1'],
    filter='hann',
    margin='96',
    vmin='-0.16',
    vmax='0.24',
)

In [5]:
# ensemble with real dose data
# Ensemble of the per-dose-rate networks applied to the real quarter-dose sinograms.
additional_args = {
    'py': 'test2d_ensemble_mayo',
    'N0': '-1',
    'dose_rate': '4',
    'checkpoint': '25',
    'tags': 'l2_depth_3/dose_rate_2,l2_depth_3/dose_rate_4,l2_depth_3/dose_rate_8,l2_depth_3/dose_rate_16',
    'checkpoint_smooth': 'l2_depth_3/dose_rate_16/25.h5',
    'N0_ref': '2.5e4',
}

# Shared args first, then this cell's overrides (same key order as copy + update).
args = copy.deepcopy({**base_args, **additional_args})

for patient in prj_names:
    cmd = copy.deepcopy(args)
    cmd.update(
        prj=os.path.join(input_dir, '%s_quarter_sino.mat' % patient),
        output=os.path.join(output_dir, 'quarter/ensemble/%s' % patient),
        # round-robin over the GPUs by position in the global command list
        device=devices[len(cmds) % len(devices)],
    )
    cmds.append(cmd)

In [6]:
# ensemble with simulated dose
# Ensemble applied to the full-dose sinograms at each simulated dose rate.
additional_args = {
    'py': 'test2d_ensemble_mayo',
    'N0': '1e5',
    'checkpoint': '25',
    'tags': 'l2_depth_3/dose_rate_2,l2_depth_3/dose_rate_4,l2_depth_3/dose_rate_8,l2_depth_3/dose_rate_16',
    'checkpoint_smooth': 'l2_depth_3/dose_rate_16/25.h5',
    'N0_ref': '2.5e4',
}

args = copy.deepcopy({**base_args, **additional_args})

for dose in dose_rates:
    for patient in prj_names:
        cmd = copy.deepcopy(args)
        cmd.update(
            prj=os.path.join(input_dir, '%s_full_sino.mat' % patient),
            output=os.path.join(output_dir, 'dose_rate_%d/ensemble/%s' % (dose, patient)),
            dose_rate='%d' % dose,
            device=devices[len(cmds) % len(devices)],
        )
        cmds.append(cmd)

In [7]:
# fbp 
# FBP baseline on the real full- and quarter-dose sinograms.
# train_dir is not passed to the FBP script.
additional_args = {'py': 'test2d_fbp_mayo', 'N0': '-1', 'dose_rate': '4'}

args = {key: value for key, value in copy.deepcopy(base_args).items() if key != 'train_dir'}
args.update(additional_args)

for level in ['full', 'quarter']:
    for patient in prj_names:
        cmd = copy.deepcopy(args)
        cmd.update(
            prj=os.path.join(input_dir, '%s_%s_sino.mat' % (patient, level)),
            output=os.path.join(output_dir, '%s/fbp/%s' % (level, patient)),
            device=devices[len(cmds) % len(devices)],
        )
        cmds.append(cmd)

In [8]:
# fbp for simulated dose levels
# FBP baseline on the full-dose sinograms at each simulated dose rate.
# train_dir is not passed to the FBP script.
additional_args = {'py': 'test2d_fbp_mayo', 'N0': '1e5'}

args = {key: value for key, value in copy.deepcopy(base_args).items() if key != 'train_dir'}
args.update(additional_args)

for dose in dose_rates:
    for patient in prj_names:
        cmd = copy.deepcopy(args)
        cmd.update(
            prj=os.path.join(input_dir, '%s_full_sino.mat' % patient),
            output=os.path.join(output_dir, 'dose_rate_%d/fbp/%s' % (dose, patient)),
            dose_rate='%d' % dose,
            device=devices[len(cmds) % len(devices)],
        )
        cmds.append(cmd)

In [9]:
# single network with real dose data
# Single L2 network on the real quarter-dose sinograms. It runs once with the
# 'all' checkpoint (output suffixed "_all") and once with the dose_rate_4 checkpoint.
additional_args = {
    'py': 'test2d_mayo',
    'N0': '-1',
    'dose_rate': '4',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

for checkpoint in ['l2_depth_3/all/5.h5', 'l2_depth_3/dose_rate_4/25.h5']:
    # Match the 'all' directory component exactly. A bare substring test would
    # also fire for any tag that merely contains "all" (e.g. "l2_small").
    suffix = '_all' if 'all' in checkpoint.split('/') else ''
    for name in prj_names:
        cmd = copy.deepcopy(args)
        cmd['prj'] = os.path.join(input_dir, '%s_quarter_sino.mat' % name)
        cmd['output'] = os.path.join(output_dir, 'quarter/l2/%s%s' % (name, suffix))
        cmd['checkpoint'] = checkpoint
        cmd['device'] = devices[len(cmds) % len(devices)]

        cmds.append(cmd)

In [10]:
# single network with simulated dose data
# Single L2 network on the full-dose sinograms at each simulated dose rate.
# It runs with the 'all' checkpoint and with the per-dose-rate checkpoint
# closest to the target rate.
additional_args = {
    'py': 'test2d_mayo',
    'N0': '1e5',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

# Dose rates that have a dose_rate_<r>/25.h5 checkpoint. This list is
# loop-invariant, so it is built once.
# NOTE(review): confirm this matches the trained models.
dose_rate_avail = np.array([2, 4, 8, 12, 16], int)

for dose in dose_rates:
    # Closest available dose rate. argmin returns the first minimum, so ties
    # resolve to the lower rate (e.g. 6 -> 4, not 8).
    dose_select = dose_rate_avail[np.abs(dose_rate_avail - dose).argmin()]
    checkpoints = ['l2_depth_3/all/5.h5', 'l2_depth_3/dose_rate_%d/25.h5' % dose_select]

    for checkpoint in checkpoints:
        # Match the 'all' directory component exactly, not as a substring.
        suffix = '_all' if 'all' in checkpoint.split('/') else ''
        for name in prj_names:
            cmd = copy.deepcopy(args)
            cmd['prj'] = os.path.join(input_dir, '%s_full_sino.mat' % name)
            cmd['output'] = os.path.join(output_dir, 'dose_rate_%d/l2/%s%s' % (dose, name, suffix))
            cmd['checkpoint'] = checkpoint
            cmd['dose_rate'] = '%d' % dose
            cmd['device'] = devices[len(cmds) % len(devices)]

            cmds.append(cmd)

In [11]:
# single wgan with real dose data
# Single WGAN network on the real quarter-dose sinograms. It runs once with the
# 'all' checkpoint (output suffixed "_all") and once with the dose_rate_4 checkpoint.
additional_args = {
    'py': 'test2d_mayo',
    'N0': '-1',
    'dose_rate': '4',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

for checkpoint in ['l2_depth_3_wgan/all/5.h5', 'l2_depth_3_wgan/dose_rate_4/25.h5']:
    # Match the 'all' directory component exactly. A bare substring test would
    # also fire for any tag that merely contains "all" (e.g. "l2_small").
    suffix = '_all' if 'all' in checkpoint.split('/') else ''
    for name in prj_names:
        cmd = copy.deepcopy(args)
        cmd['prj'] = os.path.join(input_dir, '%s_quarter_sino.mat' % name)
        cmd['output'] = os.path.join(output_dir, 'quarter/wgan/%s%s' % (name, suffix))
        cmd['checkpoint'] = checkpoint
        cmd['device'] = devices[len(cmds) % len(devices)]

        cmds.append(cmd)

In [12]:
# single wgan with simulated dose data
# Single WGAN network on the full-dose sinograms at each simulated dose rate.
# It runs with the 'all' checkpoint and with the per-dose-rate checkpoint
# closest to the target rate.
additional_args = {
    'py': 'test2d_mayo',
    'N0': '1e5',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

# Dose rates that have a dose_rate_<r>/25.h5 checkpoint. This list is
# loop-invariant, so it is built once.
# NOTE(review): confirm this matches the trained models.
dose_rate_avail = np.array([2, 4, 8, 12, 16], int)

for dose in dose_rates:
    # Closest available dose rate. argmin returns the first minimum, so ties
    # resolve to the lower rate (e.g. 6 -> 4, not 8).
    dose_select = dose_rate_avail[np.abs(dose_rate_avail - dose).argmin()]
    checkpoints = ['l2_depth_3_wgan/all/5.h5', 'l2_depth_3_wgan/dose_rate_%d/25.h5' % dose_select]

    for checkpoint in checkpoints:
        # Match the 'all' directory component exactly, not as a substring.
        suffix = '_all' if 'all' in checkpoint.split('/') else ''
        for name in prj_names:
            cmd = copy.deepcopy(args)
            cmd['prj'] = os.path.join(input_dir, '%s_full_sino.mat' % name)
            cmd['output'] = os.path.join(output_dir, 'dose_rate_%d/wgan/%s%s' % (dose, name, suffix))
            cmd['checkpoint'] = checkpoint
            cmd['dose_rate'] = '%d' % dose
            cmd['device'] = devices[len(cmds) % len(devices)]

            cmds.append(cmd)

In [13]:
def _format_args(cmd):
    """Render every entry of ``cmd`` except ``'py'`` as ``--key value`` flags.

    A list value expands to ``--key v1 v2 ...``, and any other value becomes a
    single argument. Each value goes through ``shlex.quote``. That way quotes,
    ``$`` or backticks in a path are passed literally instead of being
    interpreted by the shell, which plain ``"..."`` wrapping does not ensure.
    """
    flags = []
    for key, value in cmd.items():
        if key == 'py':
            continue
        values = value if isinstance(value, list) else [value]
        flags.append(' '.join(['--%s' % key] + [shlex.quote(str(v)) for v in values]))
    return ' '.join(flags)


# Re-running a command cell appends its commands to `cmds` a second time. Keep
# one job per output directory (the latest definition, at its first position)
# so duplicated jobs never write to the same output concurrently.
jobs = list({cmd['output']: cmd for cmd in cmds}.values())

with open('%s.sh' % job_name, 'w') as f:
    # slurm
    f.write(slurm_header + '\n\n')
    f.write('cd ..\n')
    # If the log directory is missing, the `&>>` redirection fails and bash
    # does not run the command.
    log_dir = os.path.dirname(log_name)
    if log_dir:
        f.write('mkdir -p %s\n' % shlex.quote(log_dir))

    for k, cmd in enumerate(jobs):
        # Jobs run in batches of `nprocesses`, and each slot k % nprocesses owns
        # one log file. The device is re-derived from k using the same
        # round-robin rule as the cells, because de-duplication can shift positions.
        slot = k % nprocesses
        cmd = dict(cmd, device=devices[k % len(devices)])
        logstr = '&>> %s' % shlex.quote('%s_%d.log' % (log_name, slot))

        f.write('python3 %s.py %s %s &\n' % (cmd['py'], _format_args(cmd), logstr))
        if (k + 1) % nprocesses == 0:
            f.write('wait\n')
            f.write('echo "%d/%d"\n' % (k + 1, len(jobs)))
    f.write('wait\n')

    # Concatenate the per-slot logs. Only slots that received a job have a log
    # file, and an empty file list would make `cat` block on stdin.
    slot_logs = ['%s_%d.log' % (log_name, slot) for slot in range(min(nprocesses, len(jobs)))]
    if slot_logs:
        f.write('cat ' + ' '.join(shlex.quote(path) for path in slot_logs)
                + ' > ' + shlex.quote(log_name + '.log') + '\n')
        for path in slot_logs:
            f.write('rm -f %s\n' % shlex.quote(path))