-
Notifications
You must be signed in to change notification settings - Fork 105
/
file_samplers.py
66 lines (56 loc) · 2.2 KB
/
file_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright 2021 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar
import tensorflow as tf
from .memory_samplers import MultiShotMemorySampler
if TYPE_CHECKING:
    # Imports needed only by static type checkers. `typing.TYPE_CHECKING` is
    # False at runtime, so none of this executes when the module is imported.
    # That is safe because `from __future__ import annotations` (top of file)
    # keeps annotations as strings rather than evaluating them at def time.
    from collections.abc import Callable, Sequence
    from ..types import FloatTensor, IntTensor
    from .samplers import Augmenter
    # FloatTensor/IntTensor are bound only inside this guard, so the TypeVar
    # constrained to them must be created here as well. Building it at module
    # level would raise NameError at runtime.
    T = TypeVar("T", FloatTensor, IntTensor)
def load_image(path: str, target_size: tuple[int, int] | None = None) -> FloatTensor:
    """Read a JPEG file and return it as a float32, 3-channel image tensor.

    Args:
        path: Location of the image file. It is read with `tf.io.read_file`
            and decoded with `tf.image.decode_jpeg`.
        target_size: Optional `(height, width)` passed to `tf.image.resize`.
            When falsy (`None` or an empty tuple), the image keeps its
            decoded size.

    Returns:
        A float32 tensor with 3 channels. Values are produced by
        `tf.image.convert_image_dtype` and then clamped to [0.0, 1.0].
    """
    image_bytes = tf.io.read_file(path)
    # channels=3 forces a 3-channel (RGB) result regardless of how the JPEG
    # was stored.
    image = tf.image.decode_jpeg(image_bytes, channels=3)
    # Integer pixels are rescaled into [0, 1] as float32.
    image = tf.image.convert_image_dtype(image, tf.float32)
    if target_size:
        image = tf.image.resize(
            image,
            target_size,
            method=tf.image.ResizeMethod.LANCZOS3,
        )
    # Lanczos resampling can overshoot outside [0, 1] (ringing). Clamping here
    # guarantees the documented output range whether or not a resize happened.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image
class MultiShotFileSampler(MultiShotMemorySampler):
    """Multi-shot sampler whose examples are produced by a file loader.

    This is a thin specialization of `MultiShotMemorySampler`. Every
    constructor argument is forwarded unchanged. The only difference is that
    `load_example_fn` defaults to this module's `load_image`, which takes a
    file path. `x` is therefore presumably a collection of image file paths
    handed to `load_example_fn`; confirm against `MultiShotMemorySampler`.
    """

    def __init__(
        self,
        x,
        y,
        load_example_fn: Callable = load_image,
        classes_per_batch: int = 2,
        examples_per_class_per_batch: int = 2,
        steps_per_epoch: int = 1000,
        class_list: Sequence[int] | None = None,
        total_examples_per_class: int | None = None,
        augmenter: Augmenter | None = None,
        warmup: int = -1,
    ):
        """Forward all arguments to `MultiShotMemorySampler.__init__`.

        Args:
            x: Example inputs, passed through positionally.
            y: Example labels, passed through positionally.
            load_example_fn: Loader for individual examples; defaults to
                `load_image`.
            classes_per_batch: Forwarded as-is.
            examples_per_class_per_batch: Forwarded as-is.
            steps_per_epoch: Forwarded as-is.
            class_list: Forwarded as-is.
            total_examples_per_class: Forwarded as-is.
            augmenter: Forwarded as-is.
            warmup: Forwarded as-is.

        See `MultiShotMemorySampler` for the meaning of each sampling option.
        """
        # Collect the keyword options once, then hand them to the parent in a
        # single call. x and y stay positional, as in the parent signature.
        sampler_options = {
            "load_example_fn": load_example_fn,
            "classes_per_batch": classes_per_batch,
            "examples_per_class_per_batch": examples_per_class_per_batch,
            "steps_per_epoch": steps_per_epoch,
            "class_list": class_list,
            "total_examples_per_class": total_examples_per_class,
            "augmenter": augmenter,
            "warmup": warmup,
        }
        super().__init__(x, y, **sampler_options)