# In [1]:
# SLURM class
class SLURM:
    """Minimal wrapper around the SLURM command-line tools.

    Jobs submitted with :meth:`start` are recorded per job name in
    ``self.job_list`` so that :meth:`killname` can cancel them together.
    Every method shells out to ``sbatch`` / ``scancel`` / ``squeue`` and
    prints the command output.
    """

    def __init__(self):
        from collections import defaultdict

        # job name -> list of integer job ids submitted under that name
        self.job_list = defaultdict(list)

    @staticmethod
    def _batch_script(ncpus, time_str, mempercpu, jname, name,
                      partition='defq'):
        """Return the text of an sbatch script that runs ``ipcluster start``.

        ``name`` is the IPython profile passed to ``ipcluster``.
        """
        return (
            '#!/bin/sh\n'
            '#SBATCH -J %s\n'
            '#SBATCH -n %d\n'
            '#SBATCH --time=%s\n'
            '#SBATCH --mem-per-cpu=%s\n'
            '#SBATCH --partition=%s\n'
            '\n'
            "ipcluster start --ip='*' --profile='%s'\n"
        ) % (jname, ncpus, time_str, mempercpu, partition, name)

    @staticmethod
    def _parse_job_id(output):
        """Return the job id: the last whitespace-separated token of sbatch's output."""
        return int(output.split()[-1])

    def start(self, ncpus=2, time_str='06:00:00',
              mempercpu='2G', jname='ipengines', name='slurm',
              partition='defq'):
        """Submit a batch job that starts an ipcluster with ``ncpus`` tasks.

        Args:
            ncpus: number of SLURM tasks (``-n``).
            time_str: wall-clock limit passed to ``--time``.
            mempercpu: value for ``--mem-per-cpu``.
            jname: SLURM job name; also the key used in ``self.job_list``.
            name: IPython profile given to ``ipcluster --profile``.
            partition: SLURM partition (``--partition``).

        Returns:
            The integer job id parsed from sbatch's output.

        Raises:
            subprocess.CalledProcessError: if ``sbatch`` exits non-zero.
        """
        import subprocess

        script = self._batch_script(ncpus, time_str, mempercpu, jname, name,
                                    partition)
        # Feed the script on stdin rather than `echo "..." | sbatch` through a
        # shell, so none of the parameters undergo shell expansion.
        result = subprocess.check_output(
            ['sbatch'], input=script.encode(), stderr=subprocess.STDOUT)
        print(result)
        jnum = self._parse_job_id(result)
        self.job_list[jname].append(jnum)
        return jnum

    def cancel(self, jid):
        """Cancel job ``jid`` with ``scancel`` and print its output.

        Raises:
            subprocess.CalledProcessError: if ``scancel`` exits non-zero.
        """
        import subprocess

        result = subprocess.check_output(['scancel', '%d' % jid],
                                         stderr=subprocess.STDOUT)
        print(result)

    def show(self, user='ksung'):
        """Print the ``squeue`` listing for ``user``."""
        import subprocess

        result = subprocess.check_output(
            ['squeue', '-o', '%.8i %.5P %.8j %.8u %.2t %.8M %.3D %.3c %.5m %R',
             '-u', user])
        print(result.decode('utf-8'))

    def killname(self, jname='ipengines'):
        """Cancel every job recorded under ``jname`` and forget them.

        A failing ``scancel`` is reported and skipped, so the remaining jobs
        are still cancelled and the list is always cleared.
        """
        import subprocess

        for jid in self.job_list[jname]:
            try:
                self.cancel(jid)
            except subprocess.CalledProcessError as err:
                # presumably e.g. the job has already ended -- keep going so
                # the other jobs are not left running
                print('scancel %s failed (exit %d): %r'
                      % (jid, err.returncode, err.output))
        self.job_list[jname] = []

# In [3]:
# Swarm_Map class
class Swarm_Map:
    """Run ``func(*args)`` for every ``args`` in ``arr`` on a SLURM cluster.

    Three strategies are offered:

    * :meth:`vanilla_map` -- one ``srun`` job, launched over ssh, running a
      multiprocessing pool over the whole input.
    * :meth:`ipengine_map` -- start an ipyparallel cluster through
      :class:`SLURM` and map over a load-balanced view.
    * :meth:`bundle_map` -- split the input into bundles, submit one
      ``sbatch`` job per bundle and poll the filesystem for the results.

    The file-based strategies (vanilla/bundle) serialise ``func`` and its
    ``data`` with ``dill`` and write throwaway worker scripts into the current
    working directory; the compute nodes presumably share that directory
    with this process -- TODO confirm for your cluster.
    """

    def __init__(self, ncpu=2, mempercpu='2G'):
        # default CPUs per job and default SLURM --mem-per-cpu value
        self.mempercpu = mempercpu
        self.ncpu = ncpu

    @staticmethod
    def _as_namespace(data):
        """Return ``data`` as a new ``{name: value}`` dict.

        ``data`` may be falsy (-> empty dict), a list of names looked up in
        this module's globals, or a mapping, which is shallow-copied so the
        caller's object is never mutated.
        """
        if not data:
            return {}
        if isinstance(data, list):
            return {k: globals()[k] for k in data}
        return dict(data)

    @staticmethod
    def _worker_script(func_name, names, data_pkl, res_file, arr_pkl=None):
        """Return the source of a standalone worker script.

        The script loads the dill-pickled namespace ``data_pkl``, binds each
        key in ``names`` as a module global, runs
        ``Pool.starmap(<func_name>, __arr)`` and dill-dumps the result list to
        ``res_file``. ``__arr`` comes from ``arr_pkl`` when given, otherwise
        it must be one of ``names``.
        """
        # Script-internal names carry a ``__`` prefix so user data keys such
        # as ``os`` or ``dill`` cannot clobber them.
        lines = [
            '#!/usr/bin/env python3',
            'import dill as __dill',
            'import os as __os',
            'from multiprocessing import Pool as __Pool, cpu_count as __cpu_count',
            "with open(%r,'rb') as __f:" % data_pkl,
            '    __data = __dill.load(__f)',
        ]
        if arr_pkl is not None:
            lines += [
                "with open(%r,'rb') as __f:" % arr_pkl,
                '    __arr = __dill.load(__f)',
            ]
        lines += ['globals()[%r]=__data[%r]' % (k, k) for k in names]
        lines += [
            # cpu_count() reports every CPU on the node; sched_getaffinity
            # reports only the CPUs this process may actually run on.
            "__pool = __Pool(len(__os.sched_getaffinity(0)) "
            "if hasattr(__os, 'sched_getaffinity') else __cpu_count())",
            '__result = list(__pool.starmap(%s,__arr))' % func_name,
            # Write under a temporary name and rename atomically, so a poller
            # never observes a half-written result file.
            '__tmp = %r' % (res_file + '.part'),
            "with open(__tmp,'wb') as __f:",
            '    __dill.dump(__result,__f)',
            '__os.replace(__tmp,%r)' % res_file,
        ]
        return '\n'.join(lines) + '\n'

    @staticmethod
    def _bundle_job_script(jname, ncpu, time_str, mempercpu, py_file_name,
                           partition='defq'):
        """Return the sbatch script that runs one bundle's worker script."""
        return (
            '#!/bin/sh\n'
            '#SBATCH -J %s\n'
            '#SBATCH -n 1\n'
            '#SBATCH -c %d\n'
            '#SBATCH --time=%s\n'
            '#SBATCH --mem-per-cpu=%s\n'
            '#SBATCH --partition=%s\n'
            '\n'
            'srun -n 1 -c %d %s\n'
        ) % (jname, ncpu, time_str, mempercpu, partition, ncpu, py_file_name)

    def vanilla_map(self, func, arr=None, data=None, ncpu=None,
                    host='swarm2', remote_dir='k'):
        """Map ``func`` over ``arr`` inside a single ``srun`` job.

        Args:
            func: module-level function called as ``func(*args)``.
            arr: list of argument tuples.
            data: extra globals the worker needs -- a dict, or a list of names
                looked up in this module's globals. Not mutated.
            ncpu: CPUs requested from ``srun`` (default ``self.ncpu``).
            host: ssh host that runs ``srun``.
            remote_dir: directory ``cd``-ed into on ``host``; presumably the
                same shared directory as the local cwd -- TODO confirm.

        Returns:
            The list of results, in input order.

        Raises:
            ValueError: if ``data`` already contains the reserved key ``__arr``.
            FileNotFoundError: if the job produced no result file.
        """
        import contextlib
        import os
        import subprocess
        import time

        import dill

        if not arr:
            arr = []
        if not ncpu:
            ncpu = self.ncpu
        data = self._as_namespace(data)
        data.setdefault(func.__name__, func)
        if '__arr' in data:
            raise ValueError("'__arr' is reserved and may not be a data key")
        data['__arr'] = arr

        now = int(time.time() * 1e6)
        py_file_name = 'tmp_%d.py' % now
        pkl_file_name = 'tmp_%d.pkl' % now
        res_file_name = 'res_%d.pkl' % now
        try:
            with open(pkl_file_name, 'wb') as f:
                dill.dump(data, f)
            with open(py_file_name, 'w') as f:
                f.write(self._worker_script(func.__name__, data,
                                            pkl_file_name, res_file_name))
            os.chmod(py_file_name, 0o755)

            remote_cmd = 'cd %s && srun -n 1 -c %d %s' % (remote_dir, ncpu,
                                                          py_file_name)
            try:
                subprocess.check_output(['ssh', host, remote_cmd],
                                        stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as e:
                # Report, but still try to read results below: the job may
                # have written them before the non-zero exit.
                print(e.returncode)
                print(e.output)

            with open(res_file_name, 'rb') as f:
                return dill.load(f)
        finally:
            # Best-effort cleanup; a missing file must not mask the real error.
            for path in (py_file_name, pkl_file_name, res_file_name,
                         res_file_name + '.part'):
                with contextlib.suppress(FileNotFoundError):
                    os.unlink(path)

    def ipengine_map(self, func, arr=None, data=None, ncpu=None):
        """Map ``func`` over ``arr`` on an ipyparallel cluster started via SLURM.

        Submits an ipcluster job (IPython profile ``'slurm'``), waits until
        ``ncpu`` engines are registered, pushes ``data`` to all engines, runs a
        load-balanced ``map_async`` while printing progress, and cancels the
        cluster job afterwards.

        Returns:
            The list of results, in input order ([] for empty ``arr``).
        """
        import subprocess
        from time import sleep, time

        import ipyparallel as ipp
        from IPython import display

        if not arr:
            return []
        if not ncpu:
            ncpu = self.ncpu
        data = self._as_namespace(data)

        slurm = SLURM()
        c = None
        try:
            slurm.start(ncpu)
            slurm.show()
            # Reconnect until the controller answers and ncpu engines exist.
            while c is None or len(c.ids) < ncpu:
                if c is not None:
                    c.close()
                    c = None
                try:
                    c = ipp.Client(profile='slurm', timeout=1)
                except (ipp.TimeoutError, OSError):
                    # controller not reachable yet; OSError presumably means
                    # the connection file is not written yet -- TODO confirm
                    sleep(1)
                    continue
                sleep(1)
            print(len(c.ids))
            sleep(1)
            if data:
                c[:].push(data)
            lb = c.load_balanced_view()
            # Keep the AsyncResult (not .get()) so progress can be polled.
            ar = lb.map_async(func, *zip(*arr))
            startt = time()
            while ar.progress < len(arr):
                display.clear_output(wait=True)
                print('start:', startt)
                print('ar', ar.progress)
                print(int(time() - startt))
                sleep(1)
            print('end:', time())
            print('runtime:', time() - startt)
            return ar.get()
        finally:
            if c is not None:
                c.close()
            try:
                slurm.killname()
            except subprocess.CalledProcessError as err:
                print('failed to cancel cluster job: %r' % (err.output,))

    def compute_available(self):
        """Placeholder: not implemented (returns None)."""
        pass

    def profile(self, func, arr, data):
        """Placeholder: not implemented (returns None)."""
        # TODO: what's the max requestable resources?
        pass

    def bundle_map(self, func, arr=None, data=None, bundle_size=None,
                   time_str='06:00:00', mempercpu=None, jname='bundle',
                   partition='defq'):
        """Map ``func`` over ``arr`` using one sbatch job per bundle.

        ``arr`` is cut into consecutive bundles of ``bundle_size`` argument
        tuples (rounded up to an even number); each bundle runs as its own
        job with ``bundle_size`` CPUs. This call then polls the result
        directory once a second until every bundle's result file appears.

        Note: a job that dies without writing its result makes this wait
        forever, and jobs still running when this call is interrupted are not
        cancelled.

        Returns:
            The list of results, in input order.

        Raises:
            ValueError: if ``data`` contains the reserved key ``__arr``.
            subprocess.CalledProcessError: if an ``sbatch`` submission fails.
        """
        import contextlib
        import os
        import shutil
        import subprocess
        import time

        import dill

        if not mempercpu:
            mempercpu = self.mempercpu
        if not bundle_size:
            bundle_size = self.ncpu
        if not arr:
            arr = []
        data = self._as_namespace(data)
        data.setdefault(func.__name__, func)
        if '__arr' in data:
            raise ValueError("'__arr' is reserved and may not be a data key")

        # Round up to an even CPU count (kept as-is; reason not recorded).
        bundle_size = (bundle_size + 1) // 2 * 2

        now = int(time.time() * 1e6)
        pkl_file_name = 'tmp_%d.pkl' % now
        py_dir = 'py_%d/' % now
        tmp_dir = 'tmp_%d/' % now
        res_dir = 'res_%d/' % now
        created_dirs = []  # only remove directories this call created
        try:
            for d in (py_dir, tmp_dir, res_dir):
                os.mkdir(d)
                created_dirs.append(d)
            with open(pkl_file_name, 'wb') as f:
                dill.dump(data, f)

            # result file path -> (start, stop) slice of arr it covers
            pending = {}
            for start in range(0, len(arr), bundle_size):
                stop = start + bundle_size
                now_id = '%d_%d' % (int(time.time() * 1e6), start)
                py_file_name = py_dir + 'tmp_%s.py' % now_id
                bundle_pkl_file_name = tmp_dir + 'tmp_%s.pkl' % now_id
                res_file_name = res_dir + 'res_%s.pkl' % now_id

                with open(bundle_pkl_file_name, 'wb') as f:
                    dill.dump(arr[start:stop], f)
                with open(py_file_name, 'w') as f:
                    f.write(self._worker_script(
                        func.__name__, data, pkl_file_name, res_file_name,
                        arr_pkl=bundle_pkl_file_name))
                os.chmod(py_file_name, 0o755)

                job = self._bundle_job_script(jname, bundle_size, time_str,
                                              mempercpu, py_file_name,
                                              partition)
                subprocess.check_output(['sbatch'], input=job.encode(),
                                        stderr=subprocess.STDOUT)
                pending[res_file_name] = (start, stop)
                time.sleep(0.1)

            results = [None] * len(arr)
            while pending:
                ready = {res_dir + fn for fn in os.listdir(res_dir)}
                for res_file_name, (start, stop) in list(pending.items()):
                    if res_file_name not in ready:
                        continue
                    try:
                        with open(res_file_name, 'rb') as f:
                            results[start:stop] = dill.load(f)
                    except EOFError:
                        # Workers rename atomically, so this should not
                        # happen; if it does, retry on the next poll.
                        continue
                    del pending[res_file_name]
                if pending:
                    time.sleep(1)
            return results
        finally:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(pkl_file_name)
            for d in created_dirs:
                shutil.rmtree(d, ignore_errors=True)
            