Trying to create a UNet with descent, but a node is not connected #2
Hi, happy to see that the code is useful for experimentation! I've not tried debugging the above, but I think the Using instances of a
As such I'd expect the implementation of a
There is a bit more of an explanation of this API at https://github.com/sjb3d/descent/tree/main/examples/array_api, and hopefully the existing implementations of e.g.
Ah, that makes more sense than what I did! I've rewritten cropping and upsampling.

In cropping, for the reverse path, we need to pad the incoming gradient with zeroes to make it fit with the existing one.

```rust
pub fn crop(self, left: usize, top: usize, right: usize, bottom: usize) -> Self {
    let (a, da) = self.into_inner();
    let input_shape = a.shape();
    assert_eq!(input_shape.len(), 4);
    // Compute crop shape
    let mut input_offsets: TinyVec<[isize; MAX_DIM]> =
        std::iter::repeat(0).take(input_shape.len()).collect();
    input_offsets[1] = top as isize;
    input_offsets[2] = left as isize;
    let mut output_shape = input_shape;
    output_shape[1] -= top + bottom;
    output_shape[2] -= left + right;
    // Crop input to shape
    let view = View {
        input_shape: a.shape(),
        input_offsets,
        output_mapping: (0..input_shape.len())
            .map(|i| input_shape.identity_mapping(Axis::from_index(i)))
            .collect(),
        output_shape,
    };
    let (b, db) = a.view(view).with_empty_grad();
    // We also need to pad the gradient by the crop we just did
    let padded = db.pad(1, top, bottom).pad(2, left, right);
    da.accumulate(padded);
    (b, db).into()
}
```

In upsampling, we need to shrink the gradient in X and Y by just summing the incoming window of the same size as our upsampling factor. We could do this through a convolution with the channels set exactly, but I found the

```rust
pub fn upsample(self, x_grow_factor: usize, y_grow_factor: usize) -> Self {
    let (a, da) = self.into_inner();
    let input_shape = a.shape();
    assert_eq!(input_shape.len(), 4);
    let a_reshaped = a.reshape([
        input_shape[0], input_shape[1], 1, input_shape[2], 1, input_shape[3],
    ]);
    let a_broadcasted = a_reshaped.broadcast([
        input_shape[0], input_shape[1], y_grow_factor,
        input_shape[2], x_grow_factor, input_shape[3],
    ]);
    let mut output_shape = input_shape;
    output_shape[2] *= x_grow_factor;
    output_shape[1] *= y_grow_factor;
    let a_backshaped = a_broadcasted.reshape(output_shape);
    let (b, db) = a_backshaped.with_empty_grad();
    // We need to add all pixels we upsampled into the pixel they came from
    // We can do this through sum-pooling with stride
    // Following code basically copied from the max-pooling implementation
    let windows = db.image_to_windows(
        (y_grow_factor, x_grow_factor),
        (y_grow_factor, x_grow_factor),
        1,
    );
    let [m, output_h, output_w, groups, filter_h, filter_w, group_nc]: [usize; 7] =
        windows.shape().try_into().unwrap();
    let summed = windows
        .reshape([
            m * output_h * output_w * groups,
            filter_h * filter_w,
            group_nc,
        ])
        .reduce_sum(1, true)
        .reshape([m, output_h, output_w, groups * group_nc]);
    da.accumulate(summed);
    (b, db).into()
}
```

I think the gradients look better now.
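[Editor's note] The crop forward pass and its zero-padded gradient can be checked independently of descent with plain loops over flat, row-major buffers. The helper names `crop2d` and `pad2d` below are made up for illustration; only the H/W axes of the NHWC layout are modeled:

```rust
// Forward: crop an H x W image (row-major) by (left, top, right, bottom).
fn crop2d(img: &[f32], h: usize, w: usize,
          left: usize, top: usize, right: usize, bottom: usize) -> Vec<f32> {
    let _ = h; // height only constrains the valid crop amounts
    let (oh, ow) = (h - top - bottom, w - left - right);
    let mut out = Vec::with_capacity(oh * ow);
    for y in 0..oh {
        for x in 0..ow {
            out.push(img[(y + top) * w + (x + left)]);
        }
    }
    out
}

// Backward: zero-pad the output gradient back to the input shape,
// mirroring the `db.pad(1, top, bottom).pad(2, left, right)` step above.
fn pad2d(grad: &[f32], oh: usize, ow: usize,
         left: usize, top: usize, right: usize, bottom: usize) -> Vec<f32> {
    let (h, w) = (oh + top + bottom, ow + left + right);
    let mut out = vec![0.0; h * w];
    for y in 0..oh {
        for x in 0..ow {
            out[(y + top) * w + (x + left)] = grad[y * ow + x];
        }
    }
    out
}

fn main() {
    // 3x3 image, crop 1 from left and top -> 2x2 bottom-right corner.
    let img: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let cropped = crop2d(&img, 3, 3, 1, 1, 0, 0);
    assert_eq!(cropped, vec![5.0, 6.0, 8.0, 9.0]);
    // Pad a gradient of ones back: cropped-away pixels get zero gradient.
    let grad = pad2d(&[1.0; 4], 2, 2, 1, 1, 0, 0);
    assert_eq!(grad, vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]);
    println!("crop/pad roundtrip ok");
}
```

Note this sketch pads with literal zeros; as discussed further down the thread, a clamping ("same") pad would behave differently at the edges.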
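[Editor's note] Similarly, nearest-neighbour upsampling and its sum-pooled gradient can be sketched with plain loops. The names `upsample2d` and `sum_pool2d` are invented for this sketch; descent's reshape → broadcast → reshape trick and the `image_to_windows` + `reduce_sum` pooling should produce the same values:

```rust
// Forward: repeat each pixel fy times in Y and fx times in X
// (what the reshape -> broadcast -> reshape sequence computes).
fn upsample2d(img: &[f32], h: usize, w: usize, fy: usize, fx: usize) -> Vec<f32> {
    let _ = h; // height is implied by img.len() / w
    let (oh, ow) = (h * fy, w * fx);
    let mut out = vec![0.0; oh * ow];
    for y in 0..oh {
        for x in 0..ow {
            out[y * ow + x] = img[(y / fy) * w + (x / fx)];
        }
    }
    out
}

// Backward: sum each fy x fx window of the output gradient back into
// the source pixel (sum-pooling with stride equal to the window size).
fn sum_pool2d(grad: &[f32], oh: usize, ow: usize, fy: usize, fx: usize) -> Vec<f32> {
    let (h, w) = (oh / fy, ow / fx);
    let _ = h;
    let mut out = vec![0.0; (oh / fy) * (ow / fx)];
    for y in 0..oh {
        for x in 0..ow {
            out[(y / fy) * w + (x / fx)] += grad[y * ow + x];
        }
    }
    out
}

fn main() {
    // 1x2 image upsampled by 2x2 -> 2x4.
    let up = upsample2d(&[1.0, 2.0], 1, 2, 2, 2);
    assert_eq!(up, vec![1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0]);
    // Each source pixel collects the gradients of the 4 outputs it fed.
    let grad = sum_pool2d(&[1.0; 8], 2, 4, 2, 2);
    assert_eq!(grad, vec![4.0, 4.0]);
    println!("upsample/sum-pool ok");
}
```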
Nice, these updated functions look good to me, glad that the resulting graph compiles to working compute shaders! Good luck for the rest of your project! :)
Thanks so much for your help!
There was still one problem with the gradient: I hadn't checked whether the existing padding was same padding, and that caused problems when the gradient got small. I wrote my own, probably horribly inefficient, zero-padding method and it seems to be more stable: apexys@4073bc6. At the same time, I'm getting these weird spikes when plotting loss over training generation, so maybe something is still broken?
Hi, the existing padding code is implementing same padding. This is really just for simplicity, since it means just clamping the input indices when reading values from arrays. With same padding, the gradient of each input value will accumulate the gradients of all the outputs that it gets broadcast to (since same padding causes edge values to get broadcast to multiple outputs). There are a few tests (search for "#[test]") that check values and gradients for simple examples. That loss graph looks a bit odd for sure! This stuff is unfortunately quite tricky to debug; personally I found just testing each
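[Editor's note] The clamped-index behaviour described above can be made concrete with a 1D sketch. The assumption here, matching the description, is that same padding reads `input[clamp(i - pad, 0, n - 1)]`, so edge values are broadcast and their gradients accumulate; the helper names are made up:

```rust
// Forward: "same"-style padding by clamping the read index,
// broadcasting edge values instead of inserting zeros.
fn pad_clamped(input: &[f32], pad: usize) -> Vec<f32> {
    let n = input.len() as isize;
    (0..input.len() + 2 * pad)
        .map(|i| {
            let src = (i as isize - pad as isize).clamp(0, n - 1);
            input[src as usize]
        })
        .collect()
}

// Backward: each input accumulates the gradient of every output it fed.
fn pad_clamped_grad(out_grad: &[f32], n: usize, pad: usize) -> Vec<f32> {
    let mut grad = vec![0.0; n];
    for (i, g) in out_grad.iter().enumerate() {
        let src = (i as isize - pad as isize).clamp(0, n as isize - 1);
        grad[src as usize] += g;
    }
    grad
}

fn main() {
    // [1, 2, 3] padded by 1: edges are repeated, not zeroed.
    assert_eq!(pad_clamped(&[1.0, 2.0, 3.0], 1), vec![1.0, 1.0, 2.0, 3.0, 3.0]);
    // With a uniform output gradient of 1, edge inputs accumulate twice.
    assert_eq!(pad_clamped_grad(&[1.0; 5], 3, 1), vec![2.0, 1.0, 2.0]);
    println!("same-padding gradient ok");
}
```

This is why a cropping backward pass that expects zero padding behaves differently when the underlying pad clamps: the edge rows of the gradient pick up extra contributions instead of zeros.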
Hi,
I've been trying to extend descent to be able to create a U-Net (as in https://arxiv.org/abs/1505.04597) with it.
A U-Net is basically a bunch of convolutional layers that get concatenated with previous versions of their input.
I implemented the following two new operations in array.rs:
Upsampling: https://github.com/apexys/descent/blob/61415476ffaeda734b841e11b533092691c569b4/src/array.rs#L833-L848
Cropping: https://github.com/apexys/descent/blob/61415476ffaeda734b841e11b533092691c569b4/src/array.rs#L850-L874
The U-Net is then created as a recursive struct, where each layer just applies two Conv-Operations and if it's not the innermost layer, also a cropped and upsampled version of the inner layer, which is then concatenated to the output and goes through another double convolution.
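[Editor's note] The recursion described above can be sketched at the shape level, without any descent calls, just tracking how valid 3x3 convolutions, pooling, upsampling, and cropping change the spatial size at each level (function name invented; assumes 2x2 pooling and two 3x3 valid convs per double-conv, as in the paper):

```rust
// Spatial bookkeeping for one recursive U-Net level: double conv, then
// (if not innermost) recurse at half resolution, upsample the inner
// result by 2, crop the skip connection to match, concatenate, and run
// a second double conv.
fn level_output_size(depth: usize, mut hw: usize) -> usize {
    hw -= 4; // first double conv: two 3x3 valid convs trim 2 pixels each
    if depth > 0 {
        let inner = level_output_size(depth - 1, hw / 2); // 2x2 max-pool
        let up = inner * 2; // upsample inner result by 2
        assert!(up <= hw);  // skip connection is cropped down to `up`
        hw = up - 4;        // second double conv after concatenation
    }
    hw
}

fn main() {
    // Reproduces the classic 572 -> 388 sizing from the U-Net paper
    // (4 down/up levels around the bottleneck).
    assert_eq!(level_output_size(4, 572), 388);
    println!("U-Net spatial bookkeeping ok");
}
```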
Here's just the execution part, full code is at https://github.com/apexys/descent_unet_example/blob/b828ce8c4a034d9288f316c108a72229b163f493/src/main.rs
When I create the graph, I run into the following problem:
Somewhere in the creation of the graph, a Mov-Operation is created, but no inputs are connected to it.
I've added logging after every stage of the optimization process, but the error seems to be there from the start (see the upper right Node n295 in https://github.com/apexys/descent_unet_example/blob/main/svgs/after_build_clusters.svg).
I've also tried isolating the problem and it seems that my upsample and crop operations are "fine", at least I couldn't get the error to show up with just them, but they might still cause problems further down the line.
Do you have an idea of what might cause this?
Thank you for your help!