In [None]:
# | default_exp utils

# Utils

> Utility functions

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
#| hide
# Run nbdev's export step (`nbdev.nbdev_export`) from inside the notebook.
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
from typing import Dict, List

import yaml
from Bio import SeqIO
from tqdm import tqdm

In [None]:
#| export
def yaml2dict(path: str) -> dict:
    """Parse the YAML file at ``path`` and return the loaded content.

    The file is opened in text mode and parsed with ``yaml.safe_load``.

    Args:
        path (str): Location of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)

In [None]:
#| export
def fasta2dict(path: str) -> List[Dict[str, str]]:
    """Read a FASTA file into a list of per-record dictionaries.

    Despite the ``2dict`` in the name, the result is a *list*: one dict per
    record, in the order ``SeqIO.parse`` yields them. Progress is reported
    through ``tqdm`` while the records are consumed.

    Args:
        path (str): FASTA file to read; passed straight to ``SeqIO.parse``.

    Returns:
        A list of dicts, each with the keys ``"id"`` (``record.id``),
        ``"desc"`` (``record.description``) and ``"seq"`` (the record's
        sequence converted with ``str``). Empty when no records are parsed.
    """
    return [
        {
            "id": record.id,
            "desc": record.description,
            "seq": str(record.seq),
        }
        for record in tqdm(SeqIO.parse(path, "fasta"))
    ]

### Checkpoint

In [None]:
#| export
class ModelCheckpoint:
    """Download URLs of the ProGen2 checkpoint archives, one per model size.

    Each attribute is a plain string, so it can be passed directly wherever
    a checkpoint URL is expected.
    """

    # The literals are split only for line length; adjacent string literals
    # are concatenated at compile time, so the values are single URLs.
    SMALL = (
        "https://storage.googleapis.com/sfr-progen-research/checkpoints/"
        "progen2-small.tar.gz"
    )
    MEDIUM = (
        "https://storage.googleapis.com/sfr-progen-research/checkpoints/"
        "progen2-medium.tar.gz"
    )
    LARGE = (
        "https://storage.googleapis.com/sfr-progen-research/checkpoints/"
        "progen2-large.tar.gz"
    )

In [None]:
#| export
def download_checkpoint(checkpoint_url: str, path: str = "../data") -> None:
    """Download a checkpoint from a URL to a local path.

    The body is streamed to disk in small chunks while a ``tqdm`` progress
    bar tracks the number of bytes written.

    Args:
        checkpoint_url (str): URL of the checkpoint.
        path (str): Local path to save the checkpoint. If it names an
            existing directory, the file is saved inside it under the last
            path component of ``checkpoint_url``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Nothing is written to ``path`` in that case.
    """
    import os
    from urllib.parse import urlsplit

    import requests

    # Opening a directory for writing always fails, so treat a directory
    # target as "save it in here" and derive the file name from the URL.
    if os.path.isdir(path):
        filename = os.path.basename(urlsplit(checkpoint_url).path)
        if filename:
            path = os.path.join(path, filename)

    block_size = 1024  # 1 Kibibyte
    # The context managers release the connection, the file handle and the
    # progress bar even when the transfer is interrupted by an exception.
    with requests.get(checkpoint_url, stream=True) as response:
        # Fail before touching the file: otherwise an HTTP error page would
        # be saved as if it were the checkpoint archive.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        with open(path, "wb") as file, tqdm(
            total=total_size_in_bytes, unit="iB", unit_scale=True
        ) as progress_bar:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
            received = progress_bar.n

    # Only meaningful when the server announced a size (content-length).
    if total_size_in_bytes != 0 and received != total_size_in_bytes:
        print("ERROR, something went wrong")