In [None]:
from pathlib import Path
import os
import datetime
import logging
import pprint

import click
import yaml
import h5py
from dotenv import find_dotenv, load_dotenv
import tensorflow as tf

from src.models.models import get_compiled_model
from src.data.tf_data import TFDataCreator
from src.models.utils import config_gpu
from src.data.utils import get_split

# Logging: timestamped INFO-level messages for this notebook.
log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=log_fmt)
logger = logging.getLogger(__name__)

# Load variables from the project's .env file (if one is found) into os.environ,
# so the DATAPATH / SPLITPATH lookups further down work on a fresh kernel without
# exporting them in the launching shell. Variables already present in the
# environment are not overridden (python-dotenv's default).
load_dotenv(find_dotenv())

# Resolved relative to the current working directory: assumes the notebook is run
# from a direct subdirectory of the project root (e.g. notebooks/) — TODO confirm.
project_dir = Path("../").resolve()

config_path = project_dir / "configs/config.yaml"

# GPU to run on (passed to config_gpu) and index of the data split to load
# (passed to get_split).
gpu_id = "0"
split_id = 0

In [None]:
# Load the experiment configuration. The encoding is pinned so parsing does not
# depend on the platform's default locale encoding.
# NOTE(review): FullLoader is kept because this is presumably a trusted,
# repo-local file (configs/config.yaml); yaml.safe_load would suffice if the
# file uses no Python-specific tags.
with open(config_path, encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
# Configure TensorFlow to use GPU `gpu_id` with a memory limit of 60.
# NOTE(review): the unit of `memory_limit` is defined by config_gpu (not visible
# here) — TODO confirm. Presumably this must run before any op touches the GPU.
config_gpu(gpu_id, memory_limit=60)

In [None]:
# Enable Keras mixed precision (float16 compute, float32 variables). The global
# policy applies to layers created afterwards, so this must run before the model
# is built.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

In [None]:
# Open the HDF5 data file read-only. The handle is deliberately left open for the
# rest of the session: the tf.data pipelines built below are presumably backed by
# it and read from it lazily — TODO confirm in TFDataCreator.
file = h5py.File(os.environ["DATAPATH"], 'r')

# Load the split definition once and take both partitions from the same result,
# instead of reading the split file twice.
split_ids = get_split(split_id, os.environ["SPLITPATH"])
ids_train = split_ids["training"]
ids_val = split_ids["validation"]

# "Task04" presumably selects a dataset-specific creator class. It is constructed
# with the training ids but is also reused below to build the validation pipeline.
tf_data_creator = TFDataCreator.get("Task04")(
    file,
    image_ids=ids_train,
    # NOTE(review): patch_size is disabled here; confirm whether inputs should be
    # cropped to config["data"]["patch_size"].
    # patch_size=config["data"]["patch_size"],
    num_parallel_calls=tf.data.AUTOTUNE,
    params_augmentation=config["data"]["augmentation"],
)

# Un-augmented pipelines with one example per batch, used to inspect model I/O.
ds_train = tf_data_creator.get_tf_data(
    ids_train,
    data_augmentation=False,
).batch(1)

ds_val = tf_data_creator.get_tf_data(
    ids_val,
    data_augmentation=False,
).batch(1)

In [None]:
# Build and compile the model from the "model" section of the config.
# run_eagerly=False presumably keeps Keras' compiled (graph) execution — TODO
# confirm that get_compiled_model forwards it to Model.compile.
model = get_compiled_model(config["model"], run_eagerly=False)

In [None]:
# Pull the first (input, target) batch from the training pipeline for inspection.
x, y = next(iter(ds_train))

In [None]:
# Input batch shape; the leading dimension is the batch size of 1 set by .batch(1).
x.shape

In [None]:
# Forward pass on the sample batch. training=False is passed explicitly so layers
# such as Dropout / BatchNormalization run in inference mode, rather than relying
# on Keras' implicit resolution (which can fall back to the default declared in
# the model's call() signature).
y_pred = model(x, training=False)

In [None]:
# Prediction shape; presumably expected to match the target batch `y` — TODO confirm.
y_pred.shape

In [None]:
# Largest predicted value: a quick sanity check of the model's output range.
y_pred.numpy().max()