In [1]:
import os
from random import shuffle

# Pin the GPU *before* TensorFlow (or any project module that may import it)
# is loaded. CUDA reads CUDA_VISIBLE_DEVICES once, when it initializes; setting
# it afterwards is silently ignored. Doing it first guarantees it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf  # noqa: E402  (must follow the env var above)

from src.models.models import BispectUnetLight, UnetLight  # noqa: E402
from src.data.drive import (get_dataset, tf_random_crop, tf_random_rotate,  # noqa: E402
                            tf_random_flip)
from src.models.loss import dice_coe_loss, dice_coe_metric  # noqa: E402

In [2]:
import random  # cell-local: only this cell needs a seeded generator

SEED = 42          # fixes the train/val split and TF's global seed across runs
N_VAL_IMAGES = 5   # image ids 21-40 -> 5 validation, 15 training
VAL_SIZE = 592     # NOTE(review): presumably divisible by 16 so it survives the
                   # four 2x downsamplings visible in the model summary — confirm

# Global TF seed: ops created after this (e.g. the tf_random_* augmentations,
# assuming they use tf.random ops) get deterministic op-level seeds.
tf.random.set_seed(SEED)

# Reproducible split: a dedicated seeded generator, independent of whatever
# else has touched the global `random` state in this kernel.
image_ids = list(range(21, 41))
random.Random(SEED).shuffle(image_ids)
image_ids_val = image_ids[:N_VAL_IMAGES]
image_ids_train = image_ids[N_VAL_IMAGES:]


def to_val_size(image):
    """Centrally crop or zero-pad `image` to VAL_SIZE x VAL_SIZE pixels."""
    return tf.image.resize_with_crop_or_pad(image, VAL_SIZE, VAL_SIZE)


# Training: cache the output of get_dataset, then apply the random
# rotate/flip/crop maps *after* the cache so each of the 10 repeats is
# re-augmented. Each sample is a 3-tuple; the middle element is dropped.
ds_train = (
    get_dataset(id_list=image_ids_train)
    .cache()
    .repeat(10)
    .map(tf_random_rotate)
    .map(tf_random_flip)
    .map(tf_random_crop)
    .map(lambda x, y, z: (x, z))
    .batch(2)
)

# Validation: no augmentation. Bring first and last elements to a fixed
# size, cache the (deterministic) result, and evaluate one image per batch.
ds_val = (
    get_dataset(id_list=image_ids_val)
    .map(lambda x, y, z: (to_val_size(x), to_val_size(z)))
    .cache()
    .batch(1)
)


In [11]:
# Baseline U-Net with a single output channel.
# NOTE(review): earlier experiments passed n_harmonics=4 and
# radial_profile_type="disks" here; those options presumably belong to the
# imported BispectUnetLight — confirm before switching models.
model = UnetLight(output_channels=1)


In [12]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Dice loss for the model's single output.
    loss=[dice_coe_loss],
    metrics=[
        dice_coe_metric,
        tf.keras.metrics.AUC(),  # Keras default curve: ROC
    ],
    # Explicit: train/eval steps run as compiled tf.functions, not eagerly.
    run_eagerly=False,
)


In [13]:
# One augmented training batch (inputs, targets) as NumPy arrays.
# NOTE: ds_train is cached, so stopping after one element makes TF discard
# the partially filled cache (it logs a warning); harmless here.
(x, y) = next(iter(ds_train.as_numpy_iterator()))

In [14]:
# One forward pass on the sample batch. The model appears to be a subclassed
# Keras model, whose weights only exist after the first call; this builds them
# so model.summary() below can report shapes and parameter counts.
y_pred = model(x)

In [15]:
# Per-layer output shapes and parameter counts. The leading dimension is 2
# because the model was built by calling it on a batch of two samples.
model.summary()

Model: "unet_light"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_10 (Sequential)   (2, 256, 256, 16)         5280      
_________________________________________________________________
sequential_11 (Sequential)   (2, 128, 128, 32)         16000     
_________________________________________________________________
sequential_12 (Sequential)   (2, 64, 64, 64)           58304     
_________________________________________________________________
sequential_13 (Sequential)   (2, 32, 32, 128)          248320    
_________________________________________________________________
sequential_14 (Sequential)   (2, 16, 16, 256)          988160    
_________________________________________________________________
up_block_light (UpBlockLight multiple                  737664    
_________________________________________________________________
up_block_light_1 (UpBlockLig multiple                  1