In [1]:
'''
generate sbatch scripts for dect testing
'''

'\ngenerate sbatch scripts for dect testing\n'

In [2]:
import copy
import glob
import os
import pathlib
import shlex
from datetime import datetime

import numpy as np

In [3]:
# --- Job / cluster configuration ---------------------------------------------
job_name = 'dect2d'

# Device ids passed to each test script via --device; the generated script
# runs one background process per entry at a time.
devices = ['2', '3']
nprocesses = len(devices)

# Per-process logs are written to <log_name>_<slot>.log and merged into
# <log_name>.log at the end of the generated script.
log_name = './outputs/dect2d'

# SLURM directives written at the top of the generated sbatch script.
slurm_header = '\n'.join([
    '#!/bin/bash',
    '#SBATCH --partition=defq',
    '#SBATCH --job-name={}'.format(job_name),
    '#SBATCH --nodelist=gpu-node008',
    '#SBATCH --cpus-per-task=16',
    '#SBATCH --time=0',
    '',
])

# Command specs (dicts of CLI arguments) appended to by the cells below.
cmds = []

In [4]:
# --- Data locations ------------------------------------------------------------
input_dir = '/home/dwu/data/DECT/sinogram/'
output_dir = '/home/dwu/trainData/deep_denoiser_ensemble/test/dect_2d_3_layer_mean/'

# Projection ids to process (full list: ['35', '54', '56']).
prj_names = ['35']

# Arguments shared by every generated command. Values are strings because
# they are written verbatim onto the command line; a list value expands to
# several arguments after a single flag. Key order here becomes flag order.
base_args = dict(
    geometry='/home/dwu/trainData/deep_denoiser_ensemble/data/dect_2d_3_layer_mean/geometry.cfg',
    train_dir='/home/dwu/trainData/deep_denoiser_ensemble/train/mayo_2d_3_layer_mean/',

    nslice_mean='3',
    img_norm='0.019',
    islices=['0', '-1'],

    filter='hann',
    margin='96',
    vmin='-0.16',
    vmax='0.24',
)

In [5]:
# --- Ensemble with real dose data: sino_<name>_2.mat -> half/ensemble/<name> ---
# NOTE: appends to the shared `cmds` list; re-running this cell on its own
# queues the same jobs a second time.
additional_args = {
    'py': 'test2d_ensemble_dect',
    'checkpoint': '25',
    'tags': ','.join('l2_depth_3/dose_rate_%d' % rate for rate in (2, 4, 8, 16)),
    'checkpoint_smooth': 'l2_depth_3/dose_rate_16/25.h5',
    'N0_ref': '1.25e4',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

for prj_name in prj_names:
    prj_path = os.path.join(input_dir, 'sino_%s_2.mat' % prj_name)
    out_path = os.path.join(output_dir, 'half/ensemble/%s' % prj_name)
    for energy_tag in ('a', 'b'):
        job = copy.deepcopy(args)
        # Insertion order fixes the flag order in the generated script.
        job.update(prj=prj_path, output=out_path, energy=energy_tag)
        # Round-robin over devices in job order.
        job['device'] = devices[len(cmds) % len(devices)]
        cmds.append(job)

In [6]:
# --- FBP: sino_<name>_1.mat -> full/fbp/<name>, sino_<name>_2.mat -> half/fbp/<name>
# NOTE: appends to the shared `cmds` list; re-running this cell on its own
# queues the same jobs a second time.
additional_args = {
    'py': 'test2d_fbp_dect',
}

args = copy.deepcopy(base_args)
args.pop('train_dir', None)  # train_dir is not passed to the FBP script
args.update(additional_args)

for dose in ('1', '2'):
    dose_dir = 'full' if dose == '1' else 'half'
    for prj_name in prj_names:
        prj_path = os.path.join(input_dir, 'sino_%s_%s.mat' % (prj_name, dose))
        out_path = os.path.join(output_dir, '%s/fbp/%s' % (dose_dir, prj_name))
        for energy_tag in ('a', 'b'):
            job = copy.deepcopy(args)
            job['prj'] = prj_path
            job['output'] = out_path
            # Round-robin over devices in job order.
            job['device'] = devices[len(cmds) % len(devices)]
            job['energy'] = energy_tag
            cmds.append(job)

In [7]:
# --- Single network with real dose data: sino_<name>_2.mat -> half/l2/ ---------
# NOTE: appends to the shared `cmds` list; re-running this cell on its own
# queues the same jobs a second time.
additional_args = {
    'py': 'test2d_dect',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

for checkpoint in ['l2_depth_3/all/5.h5', 'l2_depth_3/dose_rate_4/25.h5']:
    # Checkpoints stored under an "all" directory write to "<name>_all".
    # Compare whole path components; a plain substring test would also
    # match unrelated names such as "small" or "overall".
    suffix = '_all' if 'all' in checkpoint.split('/') else ''
    for name in prj_names:
        for energy in ['a', 'b']:
            cmd = copy.deepcopy(args)
            cmd['prj'] = os.path.join(input_dir, 'sino_%s_2.mat'%name)
            cmd['output'] = os.path.join(output_dir, 'half/l2/%s%s'%(name, suffix))
            cmd['checkpoint'] = checkpoint
            # Round-robin over devices in job order.
            cmd['device'] = devices[len(cmds) % len(devices)]
            cmd['energy'] = energy

            cmds.append(cmd)

In [8]:
# --- Single WGAN with real dose data: sino_<name>_2.mat -> half/wgan/ ----------
# NOTE: appends to the shared `cmds` list; re-running this cell on its own
# queues the same jobs a second time.
additional_args = {
    'py': 'test2d_dect',
}

args = copy.deepcopy(base_args)
args.update(additional_args)

for checkpoint in ['l2_depth_3_wgan/all/5.h5', 'l2_depth_3_wgan/dose_rate_4/25.h5']:
    # Checkpoints stored under an "all" directory write to "<name>_all".
    # Compare whole path components; a plain substring test would also
    # match unrelated names such as "small" or "overall".
    suffix = '_all' if 'all' in checkpoint.split('/') else ''
    for name in prj_names:
        for energy in ['a', 'b']:
            cmd = copy.deepcopy(args)
            cmd['prj'] = os.path.join(input_dir, 'sino_%s_2.mat'%name)
            cmd['output'] = os.path.join(output_dir, 'half/wgan/%s%s'%(name, suffix))
            cmd['checkpoint'] = checkpoint
            # Round-robin over devices in job order.
            cmd['device'] = devices[len(cmds) % len(devices)]
            cmd['energy'] = energy

            cmds.append(cmd)

In [9]:
# --- Write the sbatch script ---------------------------------------------------
# Jobs are launched as background processes in batches of `nprocesses`.
# Job k appends its output to <log_name>_<k % nprocesses>.log, and the
# per-slot logs are merged into <log_name>.log at the end.


def _format_cli_args(spec):
    """Render a command spec as ``--flag value`` shell arguments.

    Every key except ``'py'`` becomes a ``--key`` flag; a list value expands
    to several arguments after a single flag. Values are escaped with
    ``shlex.quote`` so characters such as ``"``, ``$`` or backticks reach the
    Python script literally instead of being interpreted by the shell.
    """
    parts = []
    for key, value in spec.items():
        if key == 'py':
            continue
        values = value if isinstance(value, list) else [value]
        parts.append(' '.join(['--%s' % key] + [shlex.quote(str(v)) for v in values]))
    return ' '.join(parts)


n_jobs = len(cmds)
# Only slots that received at least one job ever create a log file.
n_logs = min(nprocesses, n_jobs)
log_dir = os.path.dirname(log_name)

with open('%s.sh' % job_name, 'w') as f:
    # slurm
    f.write(slurm_header + '\n\n')
    # The script changes to the parent of its working directory, so the
    # python files and `log_name` are resolved relative to that directory.
    f.write('cd ..\n')
    if log_dir:
        # Redirection into a missing directory fails, so create it up front.
        f.write('mkdir -p %s\n' % shlex.quote(log_dir))

    for k, cmd in enumerate(cmds):
        logstr = '&>> %s_%d.log' % (log_name, k % nprocesses)
        f.write('python3 %s.py %s %s &\n' % (cmd['py'], _format_cli_args(cmd), logstr))
        # Wait after every full batch and after the final (possibly partial)
        # batch, reporting progress each time.
        if (k + 1) % nprocesses == 0 or k + 1 == n_jobs:
            f.write('wait\n')
            f.write('echo "%d/%d"\n' % (k + 1, n_jobs))

    # cat logs together (only the slots that were actually used)
    if n_logs > 0:
        logs = ' '.join('%s_%d.log' % (log_name, i) for i in range(n_logs))
        f.write('cat %s > %s.log\n' % (logs, log_name))
        f.write('rm -f %s\n' % logs)