In [9]:
'''
generate sbatch scripts for mayo testing
Additional averaging test
'''

'\ngenerate sbatch scripts for mayo testing\nAdditional averaging test\n'

In [10]:
import copy
import glob
import os
import pathlib
import shlex
from datetime import datetime

import numpy as np

In [11]:
# Job identity and device assignment for the generated sbatch script.
job_name = 'mayo2d_avg'
devices = ['0']            # one concurrent job slot per listed device id
nprocesses = len(devices)  # jobs launched per batch before a `wait`
log_name = './outputs/mayo2d_avg'  # relative to the dir reached by the script's `cd ..`

# SLURM directives written at the top of the script (--time=0: no time limit).
slurm_header = '\n'.join([
    '#!/bin/bash',
    '#SBATCH --partition=defq',
    '#SBATCH --job-name=%s' % job_name,
    '#SBATCH --nodelist=gpu-node008',
    '#SBATCH --cpus-per-task=16',
    '#SBATCH --time=0',
]) + '\n'

# Accumulates one argument dict per job; filled by the command-building cells.
cmds = []

In [12]:
# Input sinograms, result destination and the cases to evaluate.
# Every location lives under one shared root directory.
_data_root = '/home/local/PARTNERS/dw640/mnt/women_health_internal/dufan.wu/'
input_dir = _data_root + 'lowdoseCTsets/'
output_dir = _data_root + 'deep_denoiser_ensemble/test/mayo_2d_3_layer_mean/'
prj_names = ['L291', 'L143', 'L067']  # scan ids used to build the input file names
dose_rates = [2, 4, 6, 8]
# For a quick single-case run, narrow the sweep, e.g. prj_names = ['L291'], dose_rates = [6].

# Arguments shared by every generated command. Values are kept as strings
# (or lists of strings) because they are emitted as command-line arguments.
base_args = dict(
    geometry='../geometry_mayo.cfg',
    train_dir=_data_root + 'deep_denoiser_ensemble/train/mayo_2d_3_layer_mean/',
    nslice_mean='3',
    img_norm='0.019',
    islices=['0', '-1'],
    filter='hann',
    margin='96',
    vmin='-0.16',
    vmax='0.24',
)

In [13]:
# Ensemble-average test on the real quarter-dose sinograms.
# N0 = '-1' is forwarded as-is (presumably "use the measured data, no simulated
# noise" -- TODO confirm in test2d_average_mayo.py); dose_rate is fixed at '4'
# (presumably the quarter-dose factor -- confirm).
# NOTE: this cell appends to `cmds`; re-running it without re-running the
# config cell (which resets `cmds = []`) queues duplicate jobs.
additional_args = {
    'py': 'test2d_average_mayo',
    'N0': '-1',
    'dose_rate': '4',
    'checkpoint': '25',
    'tags': ','.join('l2_depth_3/dose_rate_%d' % rate for rate in (2, 4, 8, 16)),
    'checkpoint_smooth': 'l2_depth_3/dose_rate_16/25.h5',
    'N0_ref': '2.5e4',
}

# Shared arguments first, then the test-specific ones (same key order as update()).
args = {**copy.deepcopy(base_args), **additional_args}

# Devices are assigned round-robin over the global job index.
first_slot = len(cmds)
cmds.extend(
    {
        **copy.deepcopy(args),
        'prj': os.path.join(input_dir, '%s_quarter_sino.mat' % patient),
        'output': os.path.join(output_dir, 'quarter/average/%s' % patient),
        'device': devices[(first_slot + offset) % len(devices)],
    }
    for offset, patient in enumerate(prj_names)
)

In [14]:
# Ensemble-average test on simulated low dose: every full-dose sinogram is
# queued once per entry of `dose_rates`, with N0 = '1e5' forwarded to the script.
# NOTE: this cell appends to `cmds`; re-run the config cell first to avoid
# duplicate jobs.
additional_args = {
    'py': 'test2d_average_mayo',
    'N0': '1e5',
    'checkpoint': '25',
    'tags': ','.join('l2_depth_3/dose_rate_%d' % rate for rate in (2, 4, 8, 16)),
    'checkpoint_smooth': 'l2_depth_3/dose_rate_16/25.h5',
    'N0_ref': '2.5e4',
}

# Shared arguments first, then the test-specific ones (same key order as update()).
args = {**copy.deepcopy(base_args), **additional_args}

for dose in dose_rates:
    for patient in prj_names:
        # The dict is fully built (including the round-robin device pick based
        # on the current job count) before it is appended.
        cmds.append({
            **copy.deepcopy(args),
            'prj': os.path.join(input_dir, '%s_full_sino.mat' % patient),
            'output': os.path.join(output_dir, 'dose_rate_%d/average/%s' % (dose, patient)),
            'dose_rate': '%d' % dose,
            'device': devices[len(cmds) % len(devices)],
        })

In [15]:
def _format_cli_args(cmd):
    """Render one job dict as a shell-safe ``--key value ...`` argument string.

    The ``'py'`` entry names the script to run and is skipped. A list value
    expands to ``--key v1 v2 ...``. Every value goes through ``shlex.quote`` so
    characters such as ``$``, backticks, ``"`` or spaces reach the Python
    script literally instead of being interpreted by bash.
    """
    parts = []
    for key, value in cmd.items():
        if key == 'py':
            continue
        values = value if isinstance(value, list) else [value]
        parts.append(' '.join(['--%s' % key] + [shlex.quote(str(v)) for v in values]))
    return ' '.join(parts)


# Write the sbatch script: jobs run as background processes in batches of
# `nprocesses`, job k appending to per-slot log `<log_name>_<k % nprocesses>.log`;
# the per-slot logs are merged into `<log_name>.log` at the end.
n_log_slots = min(nprocesses, len(cmds))  # only these slots ever get a log file
log_dir = os.path.dirname(log_name)

with open('%s.sh' % job_name, 'w') as f:
    # slurm
    f.write(slurm_header + '\n\n')
    # Paths below are relative to the parent of the directory the job starts in.
    f.write('cd ..\n')
    # `&>>` into a missing directory makes bash skip the command entirely.
    if log_dir:
        f.write('mkdir -p %s\n' % shlex.quote(log_dir))

    for k, cmd in enumerate(cmds):
        slot_log = shlex.quote('%s_%d.log' % (log_name, k % nprocesses))
        f.write('python3 %s.py %s &>> %s &\n' % (cmd['py'], _format_cli_args(cmd), slot_log))
        # Block after each full batch so at most `nprocesses` jobs run at once.
        if (k + 1) % nprocesses == 0:
            f.write('wait\n')
            f.write('echo "%d/%d"\n' % (k + 1, len(cmds)))
    f.write('wait\n')

    # cat logs together -- only for slots that actually received a job, so
    # `cat`/`rm` never reference files that were never created.
    slot_logs = [shlex.quote('%s_%d.log' % (log_name, slot)) for slot in range(n_log_slots)]
    if slot_logs:
        f.write('cat %s > %s\n' % (' '.join(slot_logs), shlex.quote(log_name + '.log')))
        for path in slot_logs:
            f.write('rm %s\n' % path)