In [1]:
# default_exp utils

# Utilities

> Helper functions used throughout the library not related to timeseries data.

In [2]:
#hide
from nbdev.showdoc import *
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [3]:
#export
from timeseries.imports import *

In [4]:
#export
def ToTensor(o):
    if isinstance(o, torch.Tensor): return o
    if isinstance(o, np.ndarray):  return torch.from_numpy(o)
    else: print(f"Can't convert {type(o)} to torch.Tensor")


def ToArray(o):
    if isinstance(o, np.ndarray): return o
    if isinstance(o, torch.Tensor): o = o.cpu().numpy()
    else: print(f"Can't convert {type(o)} to np.array")


def To3DTensor(o):
    o = ToTensor(o)
    if o.ndim == 1: o = o[None, None]
    elif o.ndim == 2: o = o[:, None]
    assert o.ndim == 3, f'Please, review input dimensions {o.ndim}'
    return o


def To2DTensor(o):
    o = ToTensor(o)
    if o.ndim == 1: o = o[None]
    elif o.ndim == 3: o = torch.squeeze(o, 0)
    assert o.ndim == 2, f'Please, review input dimensions {o.ndim}'
    return o


def To1DTensor(o):
    o = ToTensor(o)
    if o.ndim == 3: o = torch.squeeze(o, 1)
    if o.ndim == 2: o = torch.squeeze(o, 0)
    assert o.ndim == 1, f'Please, review input dimensions {o.ndim}'
    return o


def To3DArray(o):
    o = ToArray(o)
    if o.ndim == 1: o = o[None, None]
    elif o.ndim == 2: o = o[:, None]
    elif o.ndim == 4: o = o[0]
    assert o.ndim == 3, f'Please, review input dimensions {o.ndim}'
    return o


def To2DArray(o):
    o = ToArray(o)
    if o.ndim == 1: o = o[None]
    if o.ndim == 3: o = np.squeeze(o, 0)
    assert o.ndim == 2, f'Please, review input dimensions {o.ndim}'
    return o


def To1DArray(o):
    o = ToArray(o)
    if o.ndim == 3: o = np.squeeze(o, 1)
    if o.ndim == 2: o = np.squeeze(o, 0)
    assert o.ndim == 1, f'Please, review input dimensions {o.ndim}'
    return o
    
    
def To3D(o):
    if isinstance(o, np.ndarray): return To3DArray(o)
    if isinstance(o, torch.Tensor): return To3DTensor(o)
    
    
def To2D(o):
    if isinstance(o, np.ndarray): return To2DArray(o)
    if isinstance(o, torch.Tensor): return To2DTensor(o)
    
    
def To1D(o):
    if isinstance(o, np.ndarray): return To1DArray(o)
    if isinstance(o, torch.Tensor): return To1DTensor(o)
    
    
def To2DPlus(o):
    if o.ndim >= 2: return o
    if isinstance(o, np.ndarray): return To2DArray(o)
    elif isinstance(o, torch.Tensor): return To2DTensor(o)
    
    
def To3DPlus(o):
    if o.ndim >= 3: return o
    if isinstance(o, np.ndarray): return To3DArray(o)
    elif isinstance(o, torch.Tensor): return To3DTensor(o)
    
    
def To2DPlusTensor(o):
    return To2DPlus(ToTensor(o))


def To2DPlusArray(o):
    return To2DPlus(ToArray(o))


def To3DPlusTensor(o):
    return To3DPlus(ToTensor(o))


def To3DPlusArray(o):
    return To3DPlus(ToArray(o))


def ToType(dtype):
    def _to_type(o, dtype=dtype):
        if isinstance(o, np.ndarray) and o.dtype != dtype: o = o.astype(dtype)
        elif isinstance(o, torch.Tensor) and o.dtype != dtype: o = o.to(dtype=dtype)
        return o
    return _to_type

In [5]:
a = np.random.rand(100).astype(np.float32)
b = torch.from_numpy(a).float()
test_eq(ToTensor(a), b)
test_eq(a, ToArray(b))
test_eq(To3DTensor(a).ndim, 3)
test_eq(To2DTensor(a).ndim, 2)
test_eq(To1DTensor(a).ndim, 1)
test_eq(To3DArray(b).ndim, 3)
test_eq(To2DArray(b).ndim, 2)
test_eq(To1DArray(b).ndim, 1)

AssertionError: ==:
[0.72938675 0.15663046 0.5438222  0.22088297 0.46454814 0.5141533
 0.16714936 0.5135949  0.3351119  0.11289433 0.4153638  0.978432
 0.8224443  0.2878643  0.2742256  0.6790875  0.9918491  0.68498325
 0.02218606 0.95189553 0.8369518  0.7353786  0.8527941  0.3206877
 0.799821   0.9442978  0.25944415 0.73609793 0.89060056 0.40679598
 0.43710482 0.4497905  0.92010623 0.74176794 0.88327956 0.28482413
 0.04484292 0.87331396 0.39690018 0.2659735  0.00203312 0.6913985
 0.5547523  0.03704533 0.52407575 0.40796095 0.8088609  0.03368099
 0.75177205 0.8601545  0.939189   0.66926676 0.82130593 0.44646013
 0.09590904 0.2727818  0.05220867 0.66724277 0.06346308 0.3227858
 0.66742307 0.7041728  0.9012571  0.01001183 0.18869549 0.29003328
 0.7351315  0.018137   0.9058066  0.32126606 0.41499957 0.57107013
 0.50783235 0.7069229  0.4474577  0.9663113  0.21725594 0.68476665
 0.15556687 0.9888876  0.28927672 0.4848473  0.01050607 0.85284954
 0.20404243 0.5904083  0.15032993 0.3306254  0.4182569  0.35691574
 0.90094036 0.48238122 0.18380497 0.28157395 0.69944143 0.51608443
 0.5514344  0.14634123 0.69277966 0.17949983]
None

In [None]:
#export
import math
def bytes2size(size_bytes):
    if size_bytes == 0: return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])

def bytes2GB(bytes):
    return round(bytes / math.pow(1024, 3), 2)

In [None]:
#export
def delete_all_in_dir(tgt_dir:str, exception:Union[str, tuple]=None):
    if exception is not None and len(L(exception)) > 1: exception = tuple(exception)
    for file in os.listdir(tgt_dir):
        if exception is not None and file.endswith(exception): continue
        file_path = os.path.join(tgt_dir, file)
        if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path)
        elif os.path.isdir(file_path): shutil.rmtree(file_path)

In [None]:
#export
def reverse_dict(dictionary): 
    return {v: k for k, v in dictionary.items()}

In [None]:
#export
def is_tuple(o): return isinstance(o, tuple)

In [None]:
#export
def itemify(*o): return L(*o).zip()

In [None]:
#export
def ifnotnone(a, b):
    "`a` if `a` is None else `b`"
    return a if a is None else b

In [None]:
#hide
test_eq(ifnotnone(None, 2), None)
test_eq(ifnotnone(1, 2), 2)

In [None]:
#export
def ifnoneelse(a, b, c):
    "`b` if `a` is None else `c`"
    return b if a is None else c

In [None]:
#hide
test_eq(ifnoneelse(None, 1, 2), 1)
test_eq(ifnoneelse(1, 2, 3), 3)

In [None]:
#hide
from save_nb import *
from nbdev.export import notebook2script
save_nb()
notebook2script()
test_eq(last_saved() < 10, True)