# Utils

> Common utilities (device selection, seeding, timing) so we don't reinvent the wheel

In [None]:
#| default_exp utils

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## Device

In [None]:
#| export

import torch
import lightning
import logging
from rich.logging import RichHandler
from time import time, sleep
from functools import wraps

In [None]:
#| export
FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"

# Root-logger setup at import time: INFO and above, rendered as
# "[HH:MM:SS] LEVEL - message".
# NOTE(review): basicConfig is a no-op if the root logger already has handlers,
# and because it runs on import it affects logging for the whole process.
logging.basicConfig(
    format=FORMAT,
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Module-level logger used by the helpers in this module.
logger = logging.getLogger(__name__)

In [None]:
#| export

def get_device() -> str:
    """Return the name of the preferred available torch device.

    Preference order is Apple MPS, then CUDA, then CPU. The chosen device
    is logged at INFO level.

    Returns:
        One of ``'mps'``, ``'cuda'`` or ``'cpu'``, suitable for ``torch.device``
        or ``tensor.to(...)``.
    """
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    # Deferred %-style args: the message is only formatted if INFO is enabled.
    logger.info("Using device: %s", device)
    return device

### Usage

In [None]:
# Pick the best available accelerator and report whether MPS support is compiled in.
device = get_device()
mps_built = torch.backends.mps.is_built()
print(device)
print(f"Is MPS (Metal Performance Shader) built? {mps_built}")

mps
Is MPS (Metal Performance Shader) built? True


## Seeding

In [None]:
#| export
def set_seed(seed: int = 42, *, workers: bool = True) -> None:
    """Seed the random number generators for reproducible runs.

    Delegates to ``lightning.seed_everything``. Per the Lightning docs, it
    seeds Python's ``random``, NumPy and torch.

    Args:
        seed: The seed value to apply.
        workers: Forwarded to ``seed_everything``. When True, Lightning also
            configures DataLoader workers with derived seeds. Defaults to True,
            matching the previous hard-coded behavior.
    """
    # NOTE: this does not toggle cuDNN determinism (cudnn.deterministic /
    # cudnn.benchmark); enable that separately if bit-exact GPU runs are needed.
    lightning.seed_everything(seed, workers=workers)

### Usage

In [None]:
set_seed(seed=42)  # same as the default; pass another int for a different seed

Seed set to 42


## Timing

In [None]:
#| export
def time_it(func):
    """Decorate ``func`` so each call prints how long it took.

    The wrapper keeps ``func``'s name and docstring (via ``functools.wraps``)
    and returns whatever ``func`` returns. After each successful call it
    prints ``"<name>: <seconds> seconds"`` with three decimal places. If
    ``func`` raises, the exception propagates and no timing line is printed.
    """
    # perf_counter is monotonic and high-resolution. time() follows the system
    # clock, which can be adjusted mid-call and yield wrong or negative durations.
    from time import perf_counter

    @wraps(func)
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        elapsed = perf_counter() - start
        print(f"{func.__name__}: {elapsed:.3f} seconds")
        return result
    return wrapper



### Usage

In [None]:
@time_it
def hello(name):
    """Says hello to someone"""
    sleep(0.001)
    greeting = f"Hello {name}!"
    return greeting

# functools.wraps inside time_it keeps the wrapped function's metadata:
print(hello.__name__)   # -> hello
print(hello.__doc__)    # -> Says hello to someone
print(hello('sylain'))  # prints the timing line first, then the greeting

hello
Says hello to someone
hello: 0.105 seconds
Hello sylain!


In [None]:
#| hide
import nbdev

nbdev.nbdev_export()  # regenerate the library modules from the notebooks