# HiRo training
This is the main notebook for ML-based super sampling. The basic goal is to test cloud-based training providers such as Google Colab, so every execution has to rely on databases.

## Utilities

In [15]:
import os

def get_secret(secret_name, secret_dir='./secrets/'):
    """
    Read a secret from the specified location as a string
    :param secret_name: name of the secret file to read
    :param secret_dir: folder where secrets are kept
    :return: first line of the secret file, stripped of surrounding whitespace
    """
    # os.path.join works whether or not secret_dir ends with a path separator
    with open(os.path.join(secret_dir, secret_name), 'r', encoding="utf8") as secret_file:
        return secret_file.readline().strip()

In [16]:
import jax.numpy as jnp

def deserialize_jarray_from_mongo(serialized_jarray: dict) -> jnp.ndarray:
    """
    Deserialize a MongoDB document into a jax array, restoring its content, dtype and shape
    :param serialized_jarray: a document returned by find_one (or yielded by a find cursor) with 'content', 'dtype' and 'shape' fields
    :return: deserialized jax array
    """
    return jnp.frombuffer(serialized_jarray['content'],
                          dtype=serialized_jarray['dtype']).reshape(serialized_jarray['shape'])
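
The deserializer above expects a document with `content`, `dtype` and `shape` fields. The code that writes those documents is not part of this notebook, so the following is only a minimal sketch of what a matching serializer could look like. `serialize_array_for_mongo` and `deserialize_array` are hypothetical names, and NumPy stands in for JAX so the round trip can be checked without a JAX installation.

```python
import numpy as np

def serialize_array_for_mongo(array, doc_id: str) -> dict:
    """Hypothetical counterpart of the deserializer: pack an array into a dict
    with the 'content', 'dtype' and 'shape' fields that the reader expects."""
    host_array = np.asarray(array)  # np.asarray also accepts jax arrays
    return {
        '_id': doc_id,
        'content': host_array.tobytes(),   # raw bytes in C (row-major) order
        'dtype': str(host_array.dtype),    # e.g. 'float32'
        'shape': list(host_array.shape),   # plain list, storable in a document
    }

def deserialize_array(doc: dict) -> np.ndarray:
    """NumPy version of the read path, used here only to check the round trip."""
    return np.frombuffer(doc['content'], dtype=doc['dtype']).reshape(doc['shape'])

# Round trip on a small example array (assumed data, not the real dataset)
original = np.arange(12, dtype=np.float32).reshape(3, 4)
doc = serialize_array_for_mongo(original, 'train_1')
restored = deserialize_array(doc)
```

Storing the dtype and shape next to the raw bytes is what lets the reader rebuild the array without any other metadata. Note that a single BSON document is limited to 16 MB, so large arrays would need to be split across several documents or stored through GridFS.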

Load the preprocessed dataset from the MongoDB database. Preprocessing is implemented separately. The database holds either the full train and valid datasets or prepared batches.

In [17]:
from pymongo import MongoClient

DATASET_VERSION = 1
mongo = MongoClient(get_secret('mongo_connection_string'))

train_imgs = deserialize_jarray_from_mongo(
    mongo.hiro.preprocessed_imgs.find_one({'_id': f'train_{DATASET_VERSION}'}))
valid_imgs = deserialize_jarray_from_mongo(
    mongo.hiro.preprocessed_imgs.find_one({'_id': f'valid_{DATASET_VERSION}'}))
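
`find_one` returns `None` when no document matches the filter, in which case the deserializer would fail with a `TypeError` when it subscripts the result. Below is a minimal sketch of a guard that reports the missing `_id` instead. `fetch_split` is a hypothetical helper, and `FakeCollection` is an in-memory stand-in that mimics only `find_one`, so the example runs without a MongoDB server.

```python
class FakeCollection:
    """In-memory stand-in for a pymongo collection (illustration only)."""

    def __init__(self, docs):
        self._docs = {doc['_id']: doc for doc in docs}

    def find_one(self, query):
        # Like pymongo, return None when nothing matches the given _id
        return self._docs.get(query['_id'])


def fetch_split(collection, split: str, version: int) -> dict:
    """Hypothetical helper: fetch '<split>_<version>' or raise a clear error."""
    doc_id = f'{split}_{version}'
    doc = collection.find_one({'_id': doc_id})
    if doc is None:
        raise KeyError(f'no preprocessed dataset with _id={doc_id!r}')
    return doc


# Assumed example content: only the train split is present
collection = FakeCollection(
    [{'_id': 'train_1', 'content': b'', 'dtype': 'float32', 'shape': [0]}])

train_doc = fetch_split(collection, 'train', 1)

try:
    fetch_split(collection, 'valid', 1)
    missing_error = None
except KeyError as err:
    missing_error = err
```

With the real database, the same helper could be called as `fetch_split(mongo.hiro.preprocessed_imgs, 'train', DATASET_VERSION)` and its result passed to `deserialize_jarray_from_mongo`.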