Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reworked Merger that supports multiclass #21

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Conversation

the-lay
Copy link
Owner

@the-lay the-lay commented Jun 12, 2022

Resolves #20

I have pushed a pretty big rework of Merger and it has three new/updated keywords now (ignore_channels: bool = False, logits_n: Optional[int] = None, logits_dim: int = 0).

Here's an example of how I imagine it all can be used. It's significantly more flexible, but maybe the API is a bit too complex now.

@jordancaraballo, please take a look, what do you think? Am I missing anything in your opinion? Otherwise in the next commits I will fix tests and make sure I didn't break anything else.

import numpy as np
from tiler import Tiler, Merger


# Let's say you have an image of size 5000x3000 pixels and 4 channels in the last dimension
image_shape = (5000, 3000, 4)
image_channel_dimension = -1
# and you want to tile them into tiles of 256x256 pixels and 4 channels in the last dimension
tile_shape = (256, 256, 4)
tile_overlap = 0.5
# to feed into a segmentation network with 10 output classes (in the last dimension) and batches of 128 tiles
# (so the network output has shape of (128, 256, 256, 10))
output_classes = 10
output_classes_dim = -1
batch_size = 128

image = np.random.rand(*image_shape)
tiler = Tiler(
    data_shape=image_shape,
    tile_shape=tile_shape,
    channel_dimension=image_channel_dimension,
)
merger = Merger(
    tiler,
    ignore_channels=True,  # this allows to "turn off" channels from Tiler
    logits_n=output_classes,  # this specifies how many logits/segmentation classes there will be
    logits_dim=output_classes_dim,  # and in which dimension
)

print("Processing batches...")
for batch_id, batch in tiler(image, batch_size=batch_size):
    print(f"\tBatch: #{batch_id}, with data of shape {batch.shape}")

    # simulating network output of shape (128, 256, 256, 10)
    output = np.random.rand(batch_size, *tile_shape[:-1], output_classes)
    print(f"\tWe simulate NN output with shape of {output.shape} and add it to Merger")

    # adding output into Merger
    merger.add_batch(batch_id, batch_size, output)

print("Processing finished.")

raw_merge_result = merger.merge(argmax=None, unpad=False)
print(f"Shape of the raw merge result: {raw_merge_result.shape}")  # (5120, 3072, 10)

unpad_merge_result = merger.merge(argmax=None, unpad=True)
print(f"Shape of the unpad merge result: {unpad_merge_result.shape}")  # (5000, 3000, 10)

argmaxed_merge_result = merger.merge(argmax=output_classes_dim, unpad=True)
print(f"Shape of the argmaxed merge result: {argmaxed_merge_result.shape}")  # (5000, 3000)

@the-lay the-lay added bug Something isn't working enhancement New feature or request labels Jun 12, 2022
@the-lay the-lay self-assigned this Jun 12, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[IDEA] Figuring out how to use this library with TensorFlow multi-class and binary classification
1 participant