In [12]:
%load_ext autoreload
%autoreload 2

import os 
from pathlib import Path

# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def mkdir(path: Path) -> Path:
    """Ensure a directory exists and return it.

    If `path` has a file suffix (e.g. ``out/log.txt``) it is treated as a file
    and its parent directory is used instead. Missing parent directories are
    created; an existing directory is left untouched and a notice is printed.
    """
    target = Path(path)
    if target.suffix:
        # Looks like a file name: operate on the directory that contains it.
        target = target.parent
    if not target.exists():
        target.mkdir(parents=True)
        return target
    print('path exists, leaving alone')
    return target

# `Path('').parent` evaluates to Path('.'), i.e. it is relative to the current working directory.
this_dir = Path('').parent
# Create ./tmp/out (relative to the working directory) if it is missing.
# TMP is used further down as the location of the Slurm output/error files.
TMP = mkdir('./tmp/out')

import paramiko
import sys
import subprocess
import wandb
from time import sleep
from functools import partial, reduce
from itertools import product
from simple_slurm import Slurm
import random
from typing import Any, Iterable
import re
from ast import literal_eval

def get_cartesian_product(*args):
    """Return every ordered combination taking one element from each input iterable.

    The result is a list of tuples, in the order produced by `itertools.product`.
    """
    combos = [combo for combo in product(*args)]
    return combos

def zip_in_n_chunks(arg: Iterable[Any], n: int) -> zip:
    """Group consecutive items of `arg` into tuples of length `n`.

    All `n` positions of the zip draw from the same iterator, so each tuple
    consumes the next `n` items. A trailing incomplete group is dropped.
    """
    shared = iter(arg)
    return zip(*(shared for _ in range(n)))

def flat_list(lst_of_lst):
    """Flatten one level of nesting: concatenate the items of each inner iterable."""
    flattened = []
    for inner in lst_of_lst:
        flattened.extend(inner)
    return flattened

def gen_alphanum(n: int = 7, test=False):
    """Generate a random alphanumeric string of length `n`.

    Args:
        n: Number of characters to generate.
        test: If truthy, used as the RNG seed so the result is reproducible.
            If falsy (the default), the generator is seeded from system entropy.

    Returns:
        A string drawn from ``A-Z``, ``a-z`` and ``0-9``.
    """
    from string import ascii_lowercase, ascii_uppercase, digits
    # Use a private generator: calling random.seed() here would reset the
    # module-level RNG and silently destroy any seed the caller had set.
    rng = random.Random(test if test else None)
    characters = ascii_uppercase + ascii_lowercase + digits
    return ''.join(rng.choice(characters) for _ in range(n))

def add_to_Path(path: Path, string: str | Path):
        return Path(str(path) + str(string))

def iterate_folder(folder: Path, iter_exp_dir):
    """Return a non-clashing variant of `folder` by appending ``-<i>``.

    Args:
        folder: Desired directory path.
        iter_exp_dir: When falsy, `folder` is returned unchanged.

    Returns:
        `folder` itself if iteration is disabled or it does not exist yet;
        otherwise the first of ``folder-0`` ... ``folder-99`` that does not
        exist on disk. If all 100 candidates exist, ``folder-0`` is returned.
    """
    if not (iter_exp_dir and folder.exists()):
        return folder
    for i in range(100):
        # The suffix is appended to the full path text (same as add_to_Path).
        candidate = Path(f'{folder}-{i}')
        # Check the filesystem directly. The previous regex test had its
        # pattern and string arguments swapped, so it never matched and the
        # first candidate was always chosen even when it already existed.
        if not candidate.exists():
            return candidate
    # Every candidate is taken: keep the historical fallback.
    return Path(f'{folder}-0')

class Sub:
    """Base class for config groups declared as (nested) classes.

    On construction, plain class-level values are copied onto the instance and
    every nested class is replaced by an instance of itself whose `parent`
    points back at this object.
    """

    def __init__(_i, parent=None):
        _i.parent = parent
        _i.__safe_init__()

    def __safe_init__(_i):
        """Copy class attributes to the instance and instantiate nested classes."""
        members = dict(_i.__class__.__dict__)
        # Nested classes are handled separately from plain values.
        nested = {name: obj for name, obj in members.items() if isinstance(obj, type)}
        for name in nested:
            members.pop(name)
        for name, value in members.items():
            skip_name = name.startswith('__') or name in Pyfig._ignore_attr
            # Methods and properties stay on the class; only data is copied.
            if skip_name or callable(value) or isinstance(value, property):
                continue
            setattr(_i, name, value)
        for name, nested_cls in nested.items():
            setattr(_i, name, nested_cls(parent=_i))

    @property
    def dict(_i,):
        """Dictionary view of this config; Sub-typed values are converted to dicts as well."""
        out = cls_to_dict(_i, Pyfig._ignore_attr)
        converted = {
            key: cls_to_dict(value, Pyfig._ignore_attr)
            for key, value in out.items()
            if isinstance(value, Sub)
        }
        # Only existing keys are replaced, so the key order is unchanged.
        out.update(converted)
        return out

def cls_to_dict(cls, ignore:list)->dict:
    """Collect the public, non-callable entries of `cls.__dict__` into a dict.

    Names starting with ``_`` or listed in `ignore` are skipped, and values
    that are `Sub` instances are converted recursively.

    NOTE(review): a second function with the same name is defined further down
    in this file and replaces this one at module level, so this version is
    not the one that runs once the whole cell has executed.
    """
    out = {}
    for name, value in cls.__dict__.items():
        hidden = name.startswith('_') or name in ignore
        if hidden or callable(value):
            continue
        out[name] = cls_to_dict(value, ignore) if issubclass(type(value), Sub) else getattr(cls, name)
    return out

def cmd_to_dict(cmd:str|list,ref:dict,_d={},delim:str=' --'):
    """
    fmt: [--flag, arg, --true_flag, --flag, arg1]
    # all flags double dash because of negative numbers duh """
    booleans = ['True', 'true', 't', 'False', 'false', 'f']
    
    cmd = ' '.join(cmd) if isinstance(cmd, list) else cmd
    cmd = [x.lstrip().lstrip('--').rstrip() for x in cmd.split(delim)]
    cmd = [x.split(' ', maxsplit=1) for x in cmd if ' ' in x]
    [x.append('True') for x in cmd if len(x) == 1]
    cmd = flat_list(cmd)
    cmd = iter([x.strip() for x in cmd])

    for k,v in zip(cmd, cmd):
        if v in booleans: 
            v=booleans.index(v)<3  # 0-2 True 3-5 False
        if k in ref:
            _d[k] = type(ref[k])(v)
        else:
            try:
                _d[k] = literal_eval(v)
            except:
                _d[k] = str(v)
            print(f'Guessing type: {k} as {type(v)}')
    return _d

def update_cls_with_dict(cls: Any, d:dict):
    cls_all = [v for v in cls.__dict__.values() if issubclass(type(v), Sub)]
    cls_all.extend([cls])
    n_remain = len(d)
    for k,v in d.items():
        for _cls_assign in cls_all:            
            if not hasattr(_cls_assign, k):
                continue
            else:
                if isinstance(cls.__class__.__dict__[k], property):
                    print('Tried to assign property, consider your life choices')
                    continue
                v = type(cls.__dict__)(v)
                setattr(_cls_assign, k, v)
                n_remain -= 1
    return n_remain

def cls_to_dict(cls, ignore:list)->dict:
    """Return the public, non-callable attributes of `cls` as a dict.

    Attributes are discovered with `dir()` and read with `getattr`, so
    properties are evaluated and inherited attributes are included. Names
    starting with ``_`` or listed in `ignore` are skipped.

    Args:
        cls: Object to inspect.
        ignore: Attribute names to leave out.

    Returns:
        Mapping of attribute name to value, in `dir()` (alphabetical) order.
    """
    d = {}
    for k in dir(cls):
        if k.startswith('_') or k in ignore:
            continue
        v = getattr(cls, k)
        # Methods are behaviour, not configuration; without this check bound
        # methods end up in the config dict and in the generated command line.
        if callable(v):
            continue
        d[k] = v
    return d

def flat_dict(d:dict,items:list=None):
    """Flatten a nested dict into a single level, keeping only leaf keys.

    Nested dictionaries are walked depth-first; their own key is dropped and
    their leaf entries are merged into the result. If the same leaf key occurs
    more than once, the last value wins.

    Args:
        d: Dictionary to flatten.
        items: Optional list that collects ``(key, value)`` pairs; it is
            extended in place. A fresh list is used when omitted.

    Returns:
        A flat dict built from the collected pairs.
    """
    # A literal [] default is created once and would accumulate the entries
    # of every previous call.
    if items is None:
        items = []
    for k,v in d.items():
        if isinstance(v, dict):
            flat_dict(v, items=items)
        else:
            items.append((k, v))
    return dict(items)

def dict_to_wandb(d:dict,parent_key:str='',sep:str ='.',items:list=None)->dict:
    """Flatten a nested dict into ``parent<sep>child`` keys.

    `Path` values are converted to strings; all other leaf values are kept.

    Args:
        d: Dictionary to flatten.
        parent_key: Prefix for the keys of `d` (empty for the top level).
        sep: Separator placed between nesting levels.
        items: Optional list that collects ``(key, value)`` pairs; it is
            extended in place. A fresh list is used when omitted.

    Returns:
        A flat dict with joined keys.
    """
    # A literal [] default is created once and would accumulate the entries
    # of every previous call.
    if items is None:
        items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            # Forward `sep` so nested levels use the caller's separator too.
            dict_to_wandb(v, new_key, sep=sep, items=items)
        else:
            if isinstance(v, Path):
                v = str(v)
            items.append((new_key, v))
    return dict(items)

def count_gpu() -> int:
    """Count the entries listed in the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Returns:
        The number of non-empty, comma-separated entries; 0 when the variable
        is unset or empty.
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    # Count entries rather than digit characters, so that multi-digit ids
    # such as '10,11' are counted as two devices.
    return sum(1 for entry in visible.split(',') if entry.strip())

def run_cmds(cmd:str|list,cwd:str|Path=None,input_req:str=None):
    _out = []
    for cmd_1 in (cmd if isinstance(cmd, list) else [cmd]): 
        cmd_1 = [c.strip() for c in cmd_1.split(' ')]
        _out += [subprocess.run(
            cmd_1,cwd=cwd,input=input_req, capture_output=True)]
        sleep(0.1)
    return _out

def run_cmds_server(server:str,user:str,cmd:str|list,cwd:str|Path=None):
    """Run one or more commands on a remote host over SSH.

    Args:
        server: Host name to connect to.
        user: User name for the SSH login.
        cmd: A command string or a list of command strings.
        cwd: Remote directory in which each command is run. The path is
            shell-quoted, so ``~`` is not expanded; pass an absolute path.

    Returns:
        A list with one ``(stdin, stdout, stderr)`` tuple per command, as
        returned by `SSHClient.exec_command`.

    NOTE(review): the connection is closed when this function returns, as
    before; callers that need the command output should confirm the returned
    streams are still readable at that point.
    """
    import shlex
    # Each exec_command call runs in its own channel, so a separate
    # `cd` command has no effect on later ones; prefix every command instead.
    prefix = f'cd {shlex.quote(str(cwd))} && ' if cwd is not None else ''
    _out = []
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without verification.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # if not known host
    with client as _r:
        # Connect inside the context manager so the client is closed even if
        # the connection attempt fails.
        _r.connect(hostname=server, username=user)
        for cmd_1 in (cmd if isinstance(cmd, list) else [cmd]):
            _out += [_r.exec_command(f'{prefix}{cmd_1}')] # in, out, err
            sleep(0.1)
    return _out

class Pyfig(Sub):
    """Experiment configuration with wandb logging and Slurm submission helpers.

    Plain class attributes are the configuration values, nested `Sub` classes
    group related options, and `property` attributes are derived from other
    values each time they are read. Creating an instance applies overrides,
    calls `wandb.init`, and, if `_submit_state` is positive, submits Slurm jobs.
    """

    seed:               int     = 808017424 # grr
    project_root:       Path    = Path().home()/'projects'

    project:            str     = 'hwat'
    project_path:       Path    = property(lambda _: _.project_root / _.project)
    server_project_path:Path    = property(lambda _: _.project_path)
    n_device:           int     = property(lambda _: count_gpu())

    exp_name:           str     = 'junk'
    run_path:           Path    = property(lambda _: _.project_path / 'run.py')
    data_dir:           Path    = project_root / 'data'
    
    half_precision:     bool    = True
    dtype:              str     = 'f32'
    n_step:             int     = 1000

    af:                 str     = 'tanh' # activation function
    n_layer:            int     = 3
    
    class data(Sub):
        dataset     = 'fashion_mnist'
        b_size      = 16
        cache       = False
        image_size  = 28
        channels    = 1

    class model(Sub):
        dim         = 64
        dim_mults   = (1, 2, 4)

    class opt(Sub):
        optimizer   = 'Adam'
        beta1       = 0.9
        beta2       = 0.99
        eps         = 1e-8
        lr          = 0.001
        loss        = 'l1'  # change this to loss table load? 

    class sweep(Sub):
        method      = 'random'
        name        = 'sweep'
        metrics = dict(
            goal    = 'minimize',
            name    = 'validation_loss',
        )
        parameters = dict(
            batch_size  = {'values' : [16, 32, 64]},
            epoch       = {'values' : [5, 10, 15]},
            lr          = {'max'    : 0.1, 'min': 0.0001},
        )
        # Product of the sizes of all discrete ('values') parameters, plus one.
        n_sweep = run_cap = reduce(
            lambda i,j:i*j,[len(v['values']) for k,v in parameters.items() if 'values' in v])+1
        sweep_id = ''

    class wandb(Sub):
        job_type:       str     = 'training'
        entity:         str     = 'xmax1'

    log_sample_step:    int     = 5
    log_metric_step:    int     = 5
    log_state_step:     int     = 10         # wandb entity
    n_epoch:            int     = 20

    class slurm(Sub):
        output          = TMP/'o-%j.out'
        error           = TMP/'e-%j.err'
        mail_type       = 'FAIL'
        partition       ='sm3090'
        nodes           = 1                # n_node
        ntasks          = 8                # n_cpu
        cpus_per_task   = 1     
        time            = '0-12:00:00'     # D-HH:MM:SS
        gres            = 'gpu:RTX3090:1'
        job_name        = property(lambda _: _.parent.exp_name)  # this does not call the instance it is in
        # NOTE(review): the last line of this script ("mv_cmd = f'mv ...'") is
        # emitted verbatim and is not valid shell (spaces around '=' and a
        # literal f prefix); it also refers to $out_dir before it is set.
        # Left unchanged here; confirm the intended command.
        sbatch          = property(lambda _: f""" 
            module purge 
            source ~/.bashrc 
            module load GCC 
            module load CUDA/11.4.1 
            module load cuDNN/8.2.2.26-CUDA-11.4.1 
            conda activate {_.parent.env} 
            export MKL_NUM_THREADS=1 
            export NUMEXPR_NUM_THREADS=1 
            export OMP_NUM_THREADS=1 
            export OPENBLAS_NUM_THREADS=1
            pwd
            nvidia-smi
            mv_cmd = f'mv {TMP}/o-$SLURM_JOB_ID.out {TMP}/e-$SLURM_JOB_ID.err $out_dir' 
    """)

    exp_id:             str     = gen_alphanum(n=7)
    
    iter_exp_dir:       bool    = True
    project_exp_dir:    Path    = property(lambda _: _.project_path / 'exp')
    project_cfg_dir:    Path    = property(lambda _: _.project_path / 'cfg')
    exp_path:           Path    = property(lambda _: iterate_folder(_.project_exp_dir/_.exp_name,_.iter_exp_dir)/_.exp_id)

    server:             str     = 'svol.fysik.dtu.dk'   # SERVER
    user:               str     = 'amawi'     # SERVER
    entity:             str     = 'xmax1'       # WANDB entity
    git_remote:         str     = 'origin'      
    git_branch:         str     = 'main'        
    env:                str     = 'dex'            # CONDA ENV
    # `commit_id` is provided by the property defined below (an earlier
    # placeholder here called a non-existent `get_commit_id` and was overridden).
    
    _sys_arg: list = sys.argv[1:]
    _submit_state:     int     = -1
    _ignore_attr = ['parent','protected','dict','cmd']

    def __init__(_i,args:dict=None,cap=40,wandb_mode='online',notebook=False):
        """Build the config, apply overrides, start wandb and optionally submit jobs.

        Args:
            args: Overrides applied on top of the class defaults.
            cap: Maximum number of pending/running Slurm jobs allowed before
                submission is refused.
            wandb_mode: Passed to `wandb.init` as `mode`.
            notebook: When True, command-line arguments are not parsed.
        """
        # Sub.__init__ already copies the class defaults onto the instance.
        super().__init__()

        update_cls_with_dict(_i, args or {})
        if not notebook:
            # The instance must be passed as the first argument; it was missing.
            update_cls_with_dict(_i, cmd_to_dict(sys.argv[1:],_i.dict))

        wandb.init(
            job_type    = _i.wandb.job_type,
            entity      = _i.wandb.entity,
            project     = _i.project,
            dir         = _i.exp_path,
            config      = dict_to_wandb(_i.dict),
            mode        = wandb_mode,
            settings=wandb.Settings(start_method='fork'), # idk y this is issue, don't change
        )

        if _i._submit_state > 0:
            # run_cmds executes without a shell, so a '| wc -l' pipe cannot be
            # used; count the lines of squeue's output here instead.
            queue = run_cmds(f'squeue -u {_i.user} -h -t pending,running -r')[0]
            n_job_running = len(queue.stdout.decode('utf-8').splitlines())
            if n_job_running > cap:
                sys.exit(f'There are {n_job_running} on the submit cap is {cap}')

            _slurm = Slurm(**_i.slurm.dict)

            n_run, _i._submit_state = _i._submit_state, 0            
            for _ in range(n_run):
                # NOTE(review): the text appended below does not invoke python
                # and puts `tee` and `date` on one line; confirm the intended
                # job script before relying on this path.
                _slurm.sbatch(_i.slurm.sbatch 
                + f'out_dir={(mkdir(_i.exp_path/"out"))} {_i.cmd} | tee $out_dir/py.out date "+%B %V %T.%3N" ')

    @property
    def cmd(_i,):
        """The flattened config rendered as ``--key  value`` command-line flags."""
        d = flat_dict(_i.dict)
        return ' '.join([f' --{k}  {str(v)} ' for k,v in d.items()])

    @property
    def commit_id(_i,)->str:
        """Short hash of the latest git commit in `project_path`."""
        process = run_cmds(['git log --pretty=format:%h -n 1'], cwd=_i.project_path)[0]
        return process.stdout.decode('utf-8') 

    def submit(_i, sweep=False, commit_msg=None, cap=40):
        """Commit and push the project, then start the run script on the server.

        Args:
            sweep: When True, register a wandb sweep first and scale the
                submit state by the sweep size.
            commit_msg: Git commit message; defaults to `exp_id`.
            cap: Unused here; kept for interface compatibility.
        """
        import shlex
        commit_msg = commit_msg or _i.exp_id
        _i._submit_state *= -1
        if _i._submit_state > 0:
            if sweep:
                # wandb.sweep only accepts (sweep, entity, project, prior_runs);
                # program, name and run_cap belong inside the sweep config, and
                # the config key for the objective is 'metric'.
                # NOTE(review): the former `env` argument has no counterpart in
                # wandb.sweep; the conda env is activated by the sbatch script.
                sweep_cfg = dict(
                    method      = _i.sweep.method,
                    name        = _i.exp_name,
                    metric      = _i.sweep.metrics,
                    parameters  = _i.sweep.parameters,
                    program     = str(_i.run_path),
                    run_cap     = _i.sweep.n_sweep,
                )
                # NOTE(review): this stores the id on the top-level config, not
                # on `_i.sweep.sweep_id`; confirm which one is intended.
                _i.sweep_id = wandb.sweep(
                    sweep   = sweep_cfg,
                    project = _i.project,
                )
                _i._submit_state *= _i.sweep.n_sweep
            # Quote the message so one containing spaces stays a single argument.
            local_out = run_cmds(['git add .', f'git commit -m {shlex.quote(commit_msg)}', 'git push'], cwd=_i.project_path)
            server_out = run_cmds_server(_i.server, _i.user, f'python -u {_i.run_path} ' +_i.cmd, cwd=_i.server_project_path)

def print_pyfig(d:dict):
    """Print a (possibly nested) dict, one entry per line.

    Leaf entries are printed as ``key value``; a nested dict prints its key
    followed by a colon and then its own entries.
    """
    for key in d:
        value = d[key]
        if not isinstance(value, dict):
            print(key, value)
            continue
        print(f'{key}: ')
        print_pyfig(value)

# Build the config in notebook mode (command-line arguments are not parsed);
# 'disabled' is forwarded to wandb.init as its `mode`.
c = Pyfig(wandb_mode='disabled', notebook=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
path exists, leaving alone


In [13]:
from pprint import pprint

# Show the resolved configuration; reading `c.dict` evaluates the property-based
# fields (e.g. paths and commit id) at this point.
pprint(c.dict)



{'af': 'tanh',
 'commit_id': 'b78c668',
 'data': {'b_size': 16,
          'cache': False,
          'channels': 1,
          'dataset': 'fashion_mnist',
          'image_size': 28},
 'data_dir': PosixPath('/home/amawi/projects/data'),
 'dtype': 'f32',
 'entity': 'xmax1',
 'env': 'dex',
 'exp_id': 'NCscEnx',
 'exp_name': 'junk',
 'exp_path': PosixPath('/home/amawi/projects/hwat/exp/junk/NCscEnx'),
 'git_branch': 'main',
 'git_remote': 'origin',
 'half_precision': True,
 'iter_exp_dir': True,
 'log_metric_step': 5,
 'log_sample_step': 5,
 'log_state_step': 10,
 'model': {'dim': 64, 'dim_mults': (1, 2, 4)},
 'n_device': 0,
 'n_epoch': 20,
 'n_layer': 3,
 'n_step': 1000,
 'opt': {'beta1': 0.9,
         'beta2': 0.99,
         'eps': 1e-08,
         'loss': 'l1',
         'lr': 0.001,
         'optimizer': 'Adam'},
 'project': 'hwat',
 'project_cfg_dir': PosixPath('/home/amawi/projects/hwat/cfg'),
 'project_exp_dir': PosixPath('/home/amawi/projects/hwat/exp'),
 'project_path': PosixPath('/h