<a href="https://colab.research.google.com/github/wandb/nb_helpers/blob/main/nbs/02_utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
#| default_exp utils

# Utilities
> A bunch of helper functions

In [None]:
#| export
import io, json, sys, re, csv, logging
import git
from types import SimpleNamespace
from logging import Formatter
from logging.handlers import RotatingFileHandler
from fastcore.foundation import L
from datetime import datetime
from rich import box
from rich.table import Table
from rich.console import Console
from rich.logging import RichHandler
from fastcore.basics import ifnone, listify, store_attr
from fastcore.xtras import run
from pathlib import Path
from execnb.nbio import read_nb

In [None]:
this_nb = Path("02_utils.ipynb")
notebook = read_nb(this_nb)

## Logger
We will create a logger based on `rich.table.Table`, so that we get a nice summary table at the end of execution.

In [None]:
#| export
# Record format passed as `format=` to `logging.basicConfig` in `RichLogger`
LOGFORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

In [None]:
#| export
# Message-only format set on the `RichHandler` formatter in `RichLogger`
LOGFORMAT_RICH = "%(message)s"

In [None]:
#| export
def create_table(columns=("Notebook Path", "Status", "Run Time", "Colab"), xtra_cols=None) -> Table:
    """Build an empty, styled `rich` table with one column per header.

    `columns` holds the header names (list, tuple or a single name) and
    `xtra_cols` (a single name, a list, or `None`) is appended after them.
    Returns the `Table`, ready for `add_row` calls.
    """
    table = Table(show_header=True, header_style="bold magenta")
    table.box = box.SQUARE

    table.border_style = "bright_yellow"
    table.row_styles = ["none", "dim"]
    # `listify` on both sides lets tuples work and keeps the (immutable) default untouched
    for col in listify(columns) + listify(xtra_cols):
        table.add_column(col)
    # Highlight the second column, but only when the table actually has one
    if len(table.columns) > 1:
        table.columns[1].style = "cyan"
    return table

In [None]:
t = create_table(["a", "b"])
t

In [None]:
#| export
def remove_rich_format(text):
    "Drop rich markup: keep what sits between the first `]` and the next `[`, else the text unchanged"
    plain = str(text)
    match = re.search(r"\](.*?)\[", plain)
    return plain if match is None else match.group(1)

In [None]:
s = "[green]Ok[/green]:heavy_check_mark:"
assert remove_rich_format(s) == "Ok"

In [None]:
#| export
def _csv_to_md(csv_file_path, delimiter=";"):
    
    
    csv_dict = csv.DictReader(open(csv_file_path, encoding="UTF-8"), delimiter=delimiter)
    list_of_rows = [dict_row for dict_row in csv_dict]
    headers = list(list_of_rows[0].keys())
    md_string = " | "
    for header in headers:
        md_string += header + " |"

    md_string += "\n |"
    for i in range(len(headers)):
        md_string += "--- | "

    md_string += "\n"
    for row in list_of_rows:
        md_string += " | "
        for header in headers:
            md_string += row[header] + " | "
        md_string += "\n"
    return md_string

def csv_to_md(csv_file_path, delimiter=";"):
    """From csv file to markdown table, useful for github posting.

    The table is written next to the input file, using the same name with a
    `.md` suffix. Nothing is returned.
    """
    md_string = _csv_to_md(csv_file_path, delimiter)
    output_file = Path(csv_file_path).with_suffix(".md")
    # Context manager closes the file even if the write fails
    with open(output_file, "w", encoding="UTF-8") as md_file:
        md_file.write(md_string)

In [None]:
print(_csv_to_md("test_data/file.csv", delimiter=","))

 | name |number |value |
 |--- | --- | --- | 
 | a | 1 | -100 | 
 | b | 2 | 0 | 
 | c | 3 | 2 | 



some fancy rich coloring

In [None]:
#| export
# Rich-markup labels for each run status; `_format_row` looks them up by lower-cased status name
STATUS = SimpleNamespace(
    ok="[green]Ok[/green]:heavy_check_mark:", fail="[red]Fail[/red]", skip="[green]Skipped[/green]:heavy_check_mark:"
)

In [None]:
#| exporti
def _format_row(fname: Path, status: str, time: str, xtra_col=None, fname_only: bool = True) -> tuple:
    "Build the `(file, colored status, run time)` cells of one rich.Table row, plus `xtra_col` when given"
    status_markup = getattr(STATUS, status.lower())
    shown_name = fname.name if fname_only else fname
    cells = (str(shown_name), status_markup, f"{int(time)}s")
    # No extra column requested: the three base cells are the whole row
    if len(listify(xtra_col)) == 0:
        return cells
    return cells + (str(xtra_col),)

In [None]:
_format_row(Path("/path/to/file"), "Ok", "1", fname_only=False)

('/path/to/file', '[green]Ok[/green]:heavy_check_mark:', '1s')

In [None]:
#| export
class RichLogger:
    """A simple logger that logs to a file and the rich console.

    Rows are accumulated in memory with `writerow` / `writerow_incolor` and can
    then be rendered as a rich table (`to_table`), a CSV file (`to_csv`) or a
    markdown table (`to_md`).
    """

    def __init__(self, columns, out_file="summary_table.csv", width=180):
        # Stores `columns`, `out_file` and `width` as attributes of the same name
        store_attr()
        self.data = []
        self.links = []
        self.console = Console(width=width, record=True)
        rh = RichHandler(console=self.console)
        rh.setFormatter(Formatter(LOGFORMAT_RICH))
        # NOTE(review): the root level is ERROR, so the `info` calls below are
        # presumably filtered out -- confirm this is intended.
        logging.basicConfig(
            level=logging.ERROR,
            format=LOGFORMAT,
            handlers=[
                rh,
                RotatingFileHandler("log.txt", maxBytes=1024 * 1024 * 10, backupCount=10),  # 10Mb
            ],
        )
        self.logger = logging.getLogger("rich")
        self.info(f"CONSOLE.is_terminal(): {self.console.is_terminal}")
        self.info(f"Writing output to {out_file}")

    def writerow(self, row, colab_link=None):
        "Store `row` (first item is the file name) together with its optional colab link"
        self.data.append(row)
        self.links.append(colab_link)

    def writerow_incolor(self, fname, status, time, colab_link):
        "Same as write row, but color status"
        row = _format_row(fname, status, time)
        self.writerow(row, colab_link)

    def to_csv(self, out_file, delimiter=";", format_link=False):
        """Write the header and all stored rows to `out_file` as CSV.

        Rich markup is stripped from every cell but the first; with
        `format_link=True` the first cell becomes a markdown link to the colab.
        """
        # UTF-8 matches what `csv_to_md` reads back; `with` closes the file on errors too
        with open(out_file, "w", newline="", encoding="UTF-8") as csv_file:
            writer = csv.writer(csv_file, delimiter=delimiter)
            # Kept as attributes only for backward compatibility with older callers
            self.csv_file, self.csv_writer = csv_file, writer
            # write header
            writer.writerow(self.columns)
            for row, link in zip(self.data, self.links):
                fname = self._format_colab_link_md(link, row[0]) if format_link else row[0]
                writer.writerow([fname] + [remove_rich_format(e) for e in row[1:]])

    def to_table(self, enum=True):
        "Print the stored rows as a rich table; `enum` adds a leading `#` index column"
        columns = list(self.columns)
        if enum:
            columns = ["#"] + columns
        table = create_table(columns=columns)
        for i, (row, link) in enumerate(zip(self.data, self.links)):
            cells = (self._format_colab_link(link, row[0]), *row[1:])
            # Only emit the index cell when the `#` column exists, so cells stay aligned
            if enum:
                cells = (f"{i}", *cells)
            table.add_row(*cells)
        self.console.print(table)

    def to_md(self, out_file):
        "Write the rows to a `.csv` next to `out_file` and convert that file to markdown"
        csv_file = Path(out_file).with_suffix(".csv")
        self.to_csv(csv_file)
        csv_to_md(csv_file)
        self.info(f"Output table saved to [red]{out_file}[/red]")

    @property
    def info(self):
        return self.logger.info

    @property
    def warning(self):
        return self.logger.warning

    @property
    def exception(self):
        return self.logger.exception

    @property
    def error(self):
        return self.logger.error

    @staticmethod
    def _format_colab_link(colab_link, fname):
        "Rich markup hyperlink to `colab_link`, or the plain name when there is no link"
        if colab_link is None:
            return f"{fname}"
        # Rich closes a link with `[/link]`
        return f"[link={colab_link}]{fname}[/link]"

    @staticmethod
    def _format_colab_link_md(colab_link, fname):
        "Markdown hyperlink to `colab_link`, or the plain name when there is no link"
        if colab_link is None:
            return f"{fname}"
        return f"[{fname}]({colab_link})"

In [None]:
l = RichLogger(["a", "b"])

In [None]:
l.error("An Error!")

## Functions to make my life easier!
> A bunch of random functions to deal with notebooks, git, colab, wandb, etc...

In [None]:
#| export
def is_nb(fname: Path):
    "True for `.ipynb` files whose name has no `_` prefix and whose path does not mention `checkpoint`"
    if fname.suffix != ".ipynb":
        return False
    if fname.name.startswith("_"):
        return False
    return "checkpoint" not in str(fname)

In [None]:
assert is_nb(Path("02_utils.ipynb"))
assert not is_nb(Path("file.csv"))

In [None]:
#| export
def find_nbs(path: Path):
    "Resolved notebook paths under `path` (searched recursively, sorted), or `[path]` when it is itself a notebook"
    root = Path(path).resolve()
    if is_nb(root):
        return [root.resolve()]
    notebooks = [candidate.resolve() for candidate in root.rglob("*.ipynb") if is_nb(candidate)]
    return L(notebooks).sorted()

There should be 2 notebooks in `test_data`.

In [None]:
assert len(find_nbs(Path("test_data"))) == 2

In [None]:
#| export
def print_output(notebook):  # pragma: no cover
    """Print `notebook` in stdout for git things.

    The notebook is serialized as JSON (sorted keys, 1-space indent, non-ASCII
    characters kept as is), followed by a newline, and written as UTF-8
    straight to the binary buffer behind `sys.stdout`.
    """
    output_stream = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
    try:
        output_stream.write(json.dumps(notebook, sort_keys=True, indent=1, ensure_ascii=False))
        output_stream.write("\n")
        output_stream.flush()
    finally:
        # Detach so that dropping this wrapper does not close the buffer behind
        # `sys.stdout`, which would break any later write to stdout
        output_stream.detach()

In [None]:
# only works in terminal
# print_output(notebook)

In [None]:
#| export
# Cell type names, compared against a cell's `cell_type` field in `search_cells`
CellType = SimpleNamespace(code="code", md="markdown")

In [None]:
#| export
def search_cell(cell, string) -> bool:
    "Whether `string` occurs in the cell source (a string or a list of strings, joined before searching)"
    joined_source = "".join(listify(cell["source"]))
    return string in joined_source

In [None]:
cells = notebook["cells"]

In [None]:
#check that we import rich
imports_cell = cells[3]
assert search_cell(imports_cell, "rich")

In [None]:
#| export
def search_cells(nb, string: str = None, cell_type=CellType.code):
    "Sources of the `cell_type` cells containing any of the comma separated strings (spaces are ignored)"
    needles = ifnone(string, "").replace(" ", "").split(",")
    return [
        cell["source"]
        for cell in nb["cells"]
        if cell["cell_type"] == cell_type and any(search_cell(cell, needle) for needle in needles)
    ]

return the cells that contain the string in question

In [None]:
search_cells(notebook, "search_cells")

['#| export\ndef search_cells(nb, string: str = None, cell_type=CellType.code):\n    "Get cells containing string, you can pass comma separated strings"\n    strings = ifnone(string, "").replace(" ", "").split(",")\n    cells = []\n    for cell in nb["cells"]:\n        if cell["cell_type"] == cell_type:\n            if any([search_cell(cell, string) for string in strings]):\n                cells.append(cell["source"])\n    return cells',
 'search_cells(notebook, "search_cells")',
 '#| export\ndef search_string_in_nb(nb, string: str = None, cell_type=CellType.code):\n    "Check if string is present in notebook cells, you can pass comma separated strings"\n    return len(search_cells(nb, string, cell_type)) > 0',
 'assert search_string_in_nb(notebook, "search_cells")',
 '#| export\ndef detect_imported_libs(notebook):\n    "Guess imported libs from notebook"\n    text_list = L(search_cells(notebook, "import,from")).concat()\n\n    # format lines\n    text_list = L([x.split("\\n") for x i

it is useful to filter out notebooks based on libraries or functions

In [None]:
#| export
def search_string_in_nb(nb, string: str = None, cell_type=CellType.code):
    "True when at least one `cell_type` cell of `nb` contains one of the comma separated strings"
    matching_cells = search_cells(nb, string, cell_type)
    return bool(matching_cells)

In [None]:
assert search_string_in_nb(notebook, "search_cells")

Used imports and libraries

In [None]:
#| export
def extract_libs(strings):
    """Automatically detect libraries imported in `strings`.

    Each string is one source line. Lines starting with `import ...` or
    `from ... import ...` contribute the top-level package of every module they
    name; any other line, and relative imports (`from . import x`), contribute
    nothing. Returns the unique names in order of first appearance.
    """
    from_import_regex = re.compile(r"^from\s+(\S+)\s+import\b")
    plain_import_regex = re.compile(r"^import\s+(.*)")

    libs = []
    for string in strings:
        # Ignore a trailing comment so it cannot leak into a module name
        statement = string.split("#", 1)[0]
        from_match = from_import_regex.search(statement)
        if from_match is not None:
            modules = [from_match.group(1)]
        else:
            import_match = plain_import_regex.search(statement)
            modules = import_match.group(1).split(",") if import_match is not None else []

        line_libs = []
        for module in modules:
            # "pkg.sub as alias" -> first token is the dotted path; the alias is dropped.
            # Splitting on whitespace (not searching for "as") keeps names such as
            # `pandas` or `fastcore`, which merely contain the letters "as".
            tokens = module.split()
            if not tokens:
                continue
            top_level = tokens[0].split(".")[0]
            # Relative imports start with "." and leave an empty top-level name
            if top_level:
                line_libs.append(top_level)
        if line_libs:
            libs.append(line_libs)
    return L(libs).concat().unique()

In [None]:
libs = ["import io, json, sys, re, csv, logging",
        "import git",
        "from types import SimpleNamespace",
        "from logging import Formatter"]

In [None]:
libs

['import io, json, sys, re, csv, logging',
 'import git',
 'from types import SimpleNamespace',
 'from logging import Formatter']

In [None]:
extract_libs(libs)

(#8) ['io','json','sys','re','csv','logging','git','types']

In [None]:
#| export
def detect_imported_libs(notebook):
    "Libraries imported by the code cells of `notebook`, guessed from their `import`/`from` lines"
    sources = L(search_cells(notebook, "import,from")).concat()

    # One entry per source line, keeping only lines that mention an import keyword
    lines = L([source.split("\n") for source in sources]).concat()
    import_lines = [line for line in lines if "from" in line or "import" in line]

    return extract_libs(import_lines)

In [None]:
text_list = search_cells(notebook, "import,from")

In [None]:
detect_imported_libs(notebook)

(#13) ['io','json','sys','re','csv','logging','git','types','fastcore','datetime'...]

## Git stuff
We do a lot of guessing about the repo we are in, so we use GitPython (the `git` module) for this sorcery.

In [None]:
#| export
def get_repo(fname) -> git.Repo:
    "Open the git repo containing `fname`, searching parent directories; raises `Exception` on failure"
    try:
        return git.Repo(fname, search_parent_directories=True)
    except Exception as err:
        raise Exception(f"Probably not in a git repo: {err}")

In [None]:
repo = get_repo(this_nb)
assert type(repo) == git.Repo

In [None]:
#| export
def git_current_branch(fname) -> str:
    "Name of the active branch of the repo containing `fname`, or `master` when it cannot be read"
    repo = get_repo(fname)
    try:
        branch = repo.active_branch
        return branch.name
    except Exception:
        # Best effort: any failure reading the active branch falls back to "master"
        return "master"

In [None]:
git_current_branch(this_nb)

'fix-ci'

In [None]:
#| export
def git_main_name(fname) -> str:
    "Return `main` when the repo containing `fname` has a branch with that name, otherwise `master`"
    repo = get_repo(fname)
    has_main = any(branch.name == "main" for branch in repo.branches)
    if has_main:
        return "main"
    return "master"

In [None]:
assert git_main_name(this_nb) == 'main'

In [None]:
#| export
def git_origin_repo(fname):
    """Get github repo name (`owner/name`) from `fname`.

    Reads the URL of the default remote of the repo containing `fname` and
    extracts the part after the `.com` host. Both `git@host.com:owner/name`
    and `https://host.com/owner/name` forms are accepted, with or without the
    `.git` suffix or a trailing slash.

    Raises `Exception` when the remote URL is empty or has no such path.
    """
    repo = get_repo(fname)
    repo_url = repo.remote().url.strip()

    # ":" follows the host in scp-style ssh urls, "/" in https and ssh:// urls.
    # The `.git` suffix is optional, so it is matched (escaped) only when present.
    match = re.search(r"\.com[:/](.+?)(?:\.git)?/?$", repo_url)
    if match is None:
        raise Exception(f"Not in a valid github repo: {fname=}")
    return match.group(1)

In [None]:
repo_name = git_origin_repo(this_nb)
assert repo_name == 'wandb/nb_helpers', f"The repo name is {repo_name}"

In [None]:
#| export
def git_local_repo(fname):
    "Resolved parent directory of the git dir of the repo containing `fname`"
    git_dir = Path(get_repo(fname).git_dir)
    return git_dir.parent.resolve()

In [None]:
git_local_repo(this_nb)

Path('/Users/tcapelle/wandb/nb_helpers')

In [None]:
#| export
def git_last_commit(fname):
    "Hex sha of the commit returned by `repo.commit()` for the repo containing `fname`"
    head_commit = get_repo(fname).commit()
    return head_commit.hexsha

In [None]:
git_last_commit(this_nb)

'2016c290419b94ac4afb4c7711c87c64ce992da6'

## Other random stuff

In [None]:
#| export
def today():
    "Current local date and time as a `dd/mm/YYYY HH:MM:SS` string"
    timestamp_format = "%d/%m/%Y %H:%M:%S"
    return datetime.now().strftime(timestamp_format)

Get today's date and time in a readable format.

In [None]:
today()

'12/09/2022 15:20:22'