[IDEA] Figuring out how to use this library with TensorFlow multi-class and binary classification #20

Open
jordancaraballo opened this issue Apr 25, 2022 · 3 comments · May be fixed by #21
Labels
bug Something isn't working

Comments

@jordancaraballo

Is your feature request related to a problem? Please describe.

I have been trying to use this library to run inference with TensorFlow binary and multi-class segmentation models. I can use the Tiler object to generate tiles and run the predictions, but I have not been able to figure out how to use the Merger for the following cases:

data_shape = (5000, 3000, 4)
tile_shape = (256, 256, 4)
channel_dimension = 2

The output of the model is a batch of either (N x 256 x 256 x 1) or (N x 256 x 256 x 6), where 6 is the number of classes.

Passing these predictions to the merger raises:

ValueError: Passed data shape ([256 256   1]) does not fit expected tile shape ((256, 256, 4)).

Describe the solution you'd like
It would be great to have additional examples covering similar use cases, such as merging TensorFlow or PyTorch predictions.

Here is an example of what I have been trying:

import numpy as np
import rioxarray as rxr
import tensorflow as tf
from tiler import Tiler, Merger

model = tf.keras.models.load_model("model.hdf5")

image = rxr.open_rasterio(filename)
# Reorder to channels last: (rows, columns, bands)
image = image.transpose("y", "x", "band")
print(image.shape)

tiler = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 4),
    channel_dimension=2,
    # overlap=0.50,
)

# Calculate and apply extra padding, as well as adjust tiling parameters
# new_shape, padding = tiler.calculate_padding()
# tiler.recalculate(data_shape=new_shape)
# padded_image = np.pad(image, padding, mode="reflect")

merger = Merger(tiler=tiler)  # , window="overlap-tile")
print(tiler)

for batch_id, batch in tiler(image, batch_size=512):
    batch = model.predict(batch)
    # Fails here: predictions are (N, 256, 256, 1) or (N, 256, 256, 6), not (N, 256, 256, 4)
    merger.add_batch(batch_id, 512, batch)

I am probably missing something, but it would be nice to have this documented. Also, the argmax option seems to be hardcoded for channels-first images, which adds extra computation when working with channels-last images. Any help would be appreciated.
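For channels-last predictions, NumPy's argmax can reduce over the last axis directly, so the class axis never has to be moved to the front. A minimal sketch, assuming a merged prediction of shape (H, W, n_classes) like the 6-class output above; the random array is a stand-in for real model output, and this is plain NumPy rather than the library's argmax option:

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes = 6
# Stand-in for a merged channels-last prediction of shape (H, W, n_classes)
pred_last = rng.random((64, 48, n_classes)).astype(np.float32)

# Channels-first route: move the class axis to the front, then argmax over axis 0
pred_first = np.moveaxis(pred_last, -1, 0)  # (n_classes, H, W)
labels_first = np.argmax(pred_first, axis=0)

# Channels-last route: argmax directly over the last axis
labels_last = np.argmax(pred_last, axis=-1)

print(labels_last.shape)  # (64, 48)
print(np.array_equal(labels_first, labels_last))  # True
```

Both routes give the same label map; np.moveaxis returns a view rather than a copy, so the channels-first route costs no extra memory, only an axis reshuffle.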

@jordancaraballo jordancaraballo added the enhancement New feature or request label Apr 25, 2022
@the-lay
Owner

the-lay commented Apr 26, 2022

Hi Jordan!

As of v0.5.7:

The current logits/argmax functionality is not flexible and can definitely be improved. Moreover, your example highlights a limitation that I overlooked completely. I also use the library to tile images, feed the tiles to a semantic segmentation network and merge the outputs back into the full result, but in my case the images don't have a channel dimension, since there is always just one value per pixel. So I never used Merger's logits/argmax functionality and Tiler's channel_dimension at the same time...

Merger's add expects data with the same shape as the tile_shape of the original Tiler. Similarly, add_batch expects data of shape [batch, *tile_shape]. In your example, the batch variable is therefore expected to have shape (batch, 256, 256, 4).

If you specify logits for the Merger, the expected shape changes to [logits, *tile_shape] (or [batch, logits, *tile_shape] for add_batch). If we specify logits in your example (e.g. merger = Merger(tiler, logits=6)), the data expected by add_batch would become (batch, 6, 256, 256, 4), which is also not what you want.
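The shape rule described above can be written out as a tiny function to make the mismatch concrete. This is a hypothetical helper (expected_add_batch_shape is not part of the library) that only encodes the rule as stated, not the library's actual validation code:

```python
def expected_add_batch_shape(batch_size, tile_shape, logits=0):
    """Hypothetical helper: shape add_batch expects, per the rule above.

    Without logits: (batch, *tile_shape); with logits: (batch, logits, *tile_shape).
    """
    if logits:
        return (batch_size, logits, *tile_shape)
    return (batch_size, *tile_shape)


tile_shape = (256, 256, 4)
print(expected_add_batch_shape(512, tile_shape))            # (512, 256, 256, 4)
print(expected_add_batch_shape(512, tile_shape, logits=6))  # (512, 6, 256, 256, 4)

# What the 6-class channels-last model actually returns for a full batch
model_output_shape = (512, 256, 256, 6)
print(model_output_shape == expected_add_batch_shape(512, tile_shape, logits=6))  # False
```

Neither variant matches a channels-last network output, because the input's 4 bands stay in the expected shape in both cases.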

I will try to find time to implement this soon, and sorry for not supporting your use case yet!

@the-lay the-lay added bug Something isn't working and removed enhancement New feature or request labels Apr 26, 2022
@the-lay the-lay self-assigned this Apr 26, 2022
@jordancaraballo
Author

jordancaraballo commented Apr 26, 2022

Hi,

Thanks for your response! While this might not be ideal, I was able to work around the channel_dimension constraint by creating a second Tiler object whose tiles have the N channels that the network outputs, and passing that Tiler to the Merger. The following is an example implementation. If you are okay with this, I can open a pull request with a similar example so that other users I point to this library can use it.

Binary segmentation problem where the output is N x 256 x 256 x 1:

import numpy as np
from tiler import Tiler, Merger

mode = 'constant'
batch_size = 512

# Tiler for the 4-band input image
tiler_image = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 4),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

# Second Tiler with the network's single output channel, used only by the Merger
tiler_mask = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 1),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

new_shape, padding = tiler_image.calculate_padding()
tiler_image.recalculate(data_shape=new_shape)
tiler_mask.recalculate(data_shape=new_shape)
# Pad the image, filling the new border pixels with the constant value 1200
padded_image = np.pad(image, padding, mode=mode, constant_values=1200)

merger = Merger(tiler=tiler_mask, window="overlap-tile")

for batch_id, batch in tiler_image(padded_image, batch_size=batch_size):
    batch = model.predict(batch)
    merger.add_batch(batch_id, batch_size, batch)

prediction = merger.merge(extra_padding=padding, dtype=image.dtype)
# Threshold the predictions at 0.5 to get a binary mask
prediction = np.squeeze(np.where(prediction > 0.5, 1, 0).astype(np.int16))
print(prediction.shape, prediction.min(), prediction.max())
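The idea behind the workaround (tile the input with all of its bands, but accumulate into an output buffer sized for the network's output channels) can be sketched in plain NumPy without the library. This is a simplified illustration, not how tiler is implemented: the hypothetical tile_starts helper, the 32-pixel tiles, the per-pixel fake_model standing in for model.predict, and the plain averaging of overlaps are all assumptions.

```python
import numpy as np


def tile_starts(length, tile, step):
    """Hypothetical helper: start offsets covering [0, length); assumes length >= tile."""
    starts = list(range(0, length - tile + 1, step))
    if starts[-1] + tile < length:
        starts.append(length - tile)  # last tile flush with the end
    return starts


def fake_model(batch):
    """Stand-in for model.predict: maps (N, h, w, 4) -> (N, h, w, 1) per pixel."""
    return batch.mean(axis=-1, keepdims=True)


def tile_predict_merge(image, tile=32, overlap=0.5, out_channels=1):
    h, w, _ = image.shape
    step = int(tile * (1 - overlap))
    # Buffers use the *output* channel count, not the input's 4 bands
    acc = np.zeros((h, w, out_channels))
    weight = np.zeros((h, w, 1))
    coords = [(y, x) for y in tile_starts(h, tile, step) for x in tile_starts(w, tile, step)]
    batch = np.stack([image[y:y + tile, x:x + tile, :] for y, x in coords])
    preds = fake_model(batch)
    for (y, x), p in zip(coords, preds):
        acc[y:y + tile, x:x + tile] += p
        weight[y:y + tile, x:x + tile] += 1.0
    return acc / weight  # average wherever tiles overlap


rng = np.random.default_rng(0)
image = rng.random((100, 70, 4))
merged = tile_predict_merge(image)
print(merged.shape)  # (100, 70, 1)
```

Because fake_model works pixel by pixel, the merged result matches applying it to the whole image at once, which is a handy sanity check for any tile-and-merge pipeline. A real CNN sees less context near tile borders, which is where padding and weighting windows start to matter.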

The only remaining challenge is the presence of artifacts at the borders of non-uniform images (e.g. an image of size 90538x9148x4 with a tile size of 256x256 and a batch size of 512). Is this something you have worked around with this library? I can also open a new issue on this topic. An example is shown below; the vertical lines at the left border of the image are not expected.

[Screenshot (2022-04-26): prediction with unexpected vertical line artifacts at the left border of the image]

@the-lay
Owner

the-lay commented Apr 28, 2022

Nice workaround! Hopefully in the near future it will not be needed anymore!

Not sure how helpful these suggestions are, but:

  • You can try another numpy padding mode (in my data, reflect gives the best results)
  • You can try other merger windows, for example hamming, which should give more weight to the center of each tile instead of weighting all tile pixels equally
  • If you have time, you could try implementing something like what is described in Vooban/Smoothly-Blend-Image-Patches
  • If inference is cheap, you can try increasing the overlap. In theory, this should improve the results.
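The window suggestion can be illustrated with a small 1-D NumPy sketch of weighted overlap blending: each tile's prediction is multiplied by a window, accumulated, and divided by the summed weights. The setup is an assumption for illustration only: a zero-valued signal, 20-sample tiles with 50% overlap, and an artificial error of 1.0 on the two outermost samples at each end of every tile to mimic border artifacts. np.hamming is NumPy's window function, not the library's implementation.

```python
import numpy as np


def blend(signal_len, tile, step, tile_preds, window):
    """Weighted overlap-add: sum(window * pred) / sum(window) at each sample."""
    acc = np.zeros(signal_len)
    wsum = np.zeros(signal_len)
    for i, pred in enumerate(tile_preds):
        start = i * step
        acc[start:start + tile] += window * pred
        wsum[start:start + tile] += window
    return acc / wsum


signal_len, tile, step = 100, 20, 10        # 50% overlap, tiles start at 0, 10, ..., 80
truth = np.zeros(signal_len)
n_tiles = (signal_len - tile) // step + 1    # 9 tiles

# Each tile "prediction" is correct except at its two outermost samples on each side
edge_error = np.zeros(tile)
edge_error[:2] = edge_error[-2:] = 1.0
tile_preds = [truth[i * step:i * step + tile] + edge_error for i in range(n_tiles)]

uniform = blend(signal_len, tile, step, tile_preds, np.ones(tile))
hamming = blend(signal_len, tile, step, tile_preds, np.hamming(tile))

err_uniform = np.abs(uniform - truth).mean()
err_hamming = np.abs(hamming - truth).mean()
print(f"uniform: {err_uniform:.3f}  hamming: {err_hamming:.3f}")
```

With uniform weights, a border sample from one tile and a clean center sample from its neighbor count equally; the hamming window down-weights tile borders, so in the overlapping region the cleaner center samples dominate. Samples covered by only one tile (the outer image border) keep their error either way, which is why padding the image still matters.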

I'm curious to hear if you manage to fix this :)

@the-lay the-lay linked a pull request Jun 12, 2022 that will close this issue

2 participants