
RandLA-Net example #5117

Merged: 77 commits into pyg-team:master on Dec 2, 2022

Conversation

CharlesGaydon
Contributor

@CharlesGaydon CharlesGaydon commented Aug 2, 2022

The paper: RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds

Context

A good PyTorch implementation of RandLA-Net that leverages PyTorch Geometric standards and modules is still missing.
In torch-points3d, the current modules are outdated, which causes confusion among users.

The implementation with the most stars on GitHub is aRI0U/RandLA-Net-pytorch, which has nasty dependencies (torch_points or torch_points_kernels), performs slow back-and-forth transfers between CPU and GPU when calling KNNs, and only accepts fixed-size point clouds.

Proposal

I would like to implement RandLA-Net as part of PyG's examples. For now I would tackle the ModelNet classification task, following the structure of the other examples (pointnet2_classification in particular).

The RandLA-Net paper focuses on segmentation, but for classification I would simply add an MLP + global max pooling after the first DilatedResidualBlocks.
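A minimal sketch of that pooling step (plain Python, illustrative only; it mirrors the semantics of PyG's global_max_pool, where batch[i] is the index of the sample that point i belongs to):

```python
def global_max_pool(x, batch):
    # x: per-point feature vectors; batch[i]: sample index of point i.
    num_graphs = max(batch) + 1
    out = [[float("-inf")] * len(x[0]) for _ in range(num_graphs)]
    for feats, b in zip(x, batch):
        # Element-wise maximum over all points of the same sample.
        out[b] = [max(cur, f) for cur, f in zip(out[b], feats)]
    return out
```

For example, global_max_pool([[1., 5.], [3., 2.], [0., 7.]], [0, 0, 1]) returns [[3.0, 5.0], [0.0, 7.0]]; an MLP would then map each pooled vector to class logits.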

The RandLA-Net architecture is conceptually close to PointNet++, augmented with several tricks to speed things up (random sampling instead of farthest point sampling), use more context (a sort of dilated KNN), and encode local information better (by explicitly computing absolute positions, relative positions, and Euclidean distances between points in a neighborhood, and by applying self-attention to these features).
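For the local encoding, the paper concatenates, for each neighbor pair, the two positions, their relative position, and their Euclidean distance before feeding an MLP. A minimal sketch (plain Python; relative_pos_encoding is a hypothetical name, and edge_index is taken as (target, source) pairs as in message passing):

```python
import math

def relative_pos_encoding(pos, edge_index):
    # pos: 3D coordinates per point; edge_index: (i, j) pairs, j a neighbor of i.
    feats = []
    for i, j in edge_index:
        pi, pj = pos[i], pos[j]
        diff = [a - b for a, b in zip(pi, pj)]      # relative position p_i - p_j
        dist = math.sqrt(sum(d * d for d in diff))  # Euclidean distance
        feats.append(pi + pj + diff + [dist])       # 3 + 3 + 3 + 1 = 10 features
    return feats
```

In RandLA-Net these raw geometric features are passed through a shared MLP and then reweighted by attention scores.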

If I have some success, I will take on the segmentation task as well (which is what interests me for my own project anyway).

Where I am at

I have a working implementation at examples/randlanet_classification.py. I still have to review it to make sure that I am following the paper as closely as possible, but I think I am on the right track.

I would love some guidance on how to move forward. In particular:

  • Am I using the MessagePassing modules correctly?
  • What should I aim for in terms of accuracy on ModelNet?
  • Should I stick strictly to the paper, or adapt the architecture to ModelNet?

Indeed, the hyperparameters were chosen by the authors for large-scale Lidar data rather than for small objects, which could make convergence take much longer than needed.

With 4 DilatedResidualBlocks (as in the paper), we reach ~57% accuracy at epoch 200.

With 3 DilatedResidualBlocks, we reach up to 75% accuracy by epoch 20.

With only 2 DilatedResidualBlocks, we reach 90% accuracy at epoch 81, getting closer to the leaderboard for the ModelNet10 challenge.

@CharlesGaydon CharlesGaydon changed the title [WIP] Implementation of RandLa-Net in pytorch geometric's examples [WIP] RandLa-Net in pytorch geometric's examples Aug 2, 2022
@codecov

codecov bot commented Aug 2, 2022

Codecov Report

Merging #5117 (545b2cb) into master (07ba384) will decrease coverage by 1.86%.
The diff coverage is 100.00%.

❗ Current head 545b2cb differs from pull request most recent head cb66f4b. Consider uploading reports for the commit cb66f4b to get more accurate results

@@            Coverage Diff             @@
##           master    #5117      +/-   ##
==========================================
- Coverage   86.20%   84.34%   -1.87%     
==========================================
  Files         362      363       +1     
  Lines       20477    20487      +10     
==========================================
- Hits        17653    17279     -374     
- Misses       2824     3208     +384     
Impacted Files Coverage Δ
torch_geometric/nn/pool/decimation.py 100.00% <100.00%> (ø)
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) ⬇️
torch_geometric/nn/models/dimenet.py 14.90% <0.00%> (-52.76%) ⬇️
torch_geometric/profile/profile.py 36.27% <0.00%> (-26.48%) ⬇️
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) ⬇️
torch_geometric/nn/pool/asap.py 92.10% <0.00%> (-7.90%) ⬇️
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) ⬇️
torch_geometric/nn/dense/linear.py 87.40% <0.00%> (-5.93%) ⬇️
torch_geometric/transforms/add_self_loops.py 94.44% <0.00%> (-5.56%) ⬇️
torch_geometric/nn/models/attentive_fp.py 95.83% <0.00%> (-4.17%) ⬇️
... and 13 more


@CharlesGaydon
Contributor Author

CharlesGaydon commented Aug 4, 2022

I implemented RandLA-Net for segmentation as well, and did some small refactoring.
The model seems to learn quite well, reaching 70% accuracy after 3 epochs. It takes ~1s to run on CPU.

Unfortunately, I am not able to fully test it on ShapeNet's Airplane task due to PyTorch conflicts that prevent me from using CUDA :/

I need to install the master branch of PyG to run the segmentation training. When I follow the instructions to build PyTorch Geometric from source, I see that I am working on a machine with CUDA 11.4, which means I have to build PyTorch against CUDA 11.4 before installing the dependencies (torch_scatter, etc.) and then PyTorch Geometric from the master branch. However, PyTorch + CUDA 11.4 does not seem to be really supported: I could not find out how to build it from source, and using cudatoolkit=11.4 in PyTorch's conda install does not work.

Maybe I am missing something here... Any help would be appreciated :)

@rusty1s
Member

rusty1s commented Aug 5, 2022

I suggest simply installing the wheels built for CUDA 11.3; they should work even with CUDA 11.4.
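Concretely, that would look something like the following (versions are illustrative; TORCH must match the installed PyTorch release, per the PyG installation instructions of that time):

```shell
# Install the CUDA 11.3 wheels of the compiled dependencies
# (they run fine on a CUDA 11.4 driver), then PyG from master.
TORCH=1.12.0
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+cu113.html
pip install git+https://github.com/pyg-team/pytorch_geometric.git
```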

@CharlesGaydon
Contributor Author

Thank you, this worked like a charm. I think this is ready for review. :)

Right now both scripts follow the paper's architecture (in terms of hyperparameters, depth, and number of channels in the MLPs). Those were chosen by the authors for large-scale aerial Lidar, not for ModelNet and ShapeNet. For ModelNet, removing a few layers enables good accuracy. I was not able to replicate this for ShapeNet, which quickly plateaus around 70% train accuracy (vs. 90% train accuracy / 79% test IoU for PointNet++).

I think it is cleaner to keep everything as it is, following the paper. We could also change the benchmark for this model (to e.g. S3DIS), but I am not sure that is worth the extra work for an example.

@CharlesGaydon CharlesGaydon marked this pull request as ready for review August 6, 2022 10:46
@CharlesGaydon CharlesGaydon marked this pull request as draft August 25, 2022 16:31
@CharlesGaydon
Contributor Author

I identified some differences between my implementation and the original paper (in particular in terms of batch norms, activations, and number of channels). Will come back with fixes and modifications!

@rusty1s
Member

rusty1s commented Aug 26, 2022

Thanks! Sorry for the delay in review. Please keep me posted.

@CharlesGaydon
Contributor Author

@rusty1s Hi!
Back at the office! I moved the function that computes decimation indices to its own decimation module under torch_geometric.nn.pool, along with a few simple tests.
On a side note, I always wonder what ptr stands for (pointer, maybe?). Feel free to edit its docstring if needed (currently: ptr (LongTensor): indices of samples in the batch.).
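For what it's worth, ptr is a CSR-style pointer: the points of sample i occupy positions ptr[i]:ptr[i+1] of the batched tensor. This is not PyG's actual implementation, but per-sample random decimation driven by ptr can be sketched like this (plain Python, hypothetical helper):

```python
import random

def decimation_indices(ptr, decimation_factor):
    # ptr[i]:ptr[i+1] delimits the points of sample i in the batched tensor.
    idx = []
    for start, end in zip(ptr[:-1], ptr[1:]):
        keep = max(1, (end - start) // decimation_factor)  # keep at least one point
        idx.extend(sorted(random.sample(range(start, end), keep)))
    return idx
```

Each sample keeps 1/decimation_factor of its points, so variable-size clouds are decimated consistently within one batch.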

@CharlesGaydon
Contributor Author

@rusty1s @saedrna The gentlest bump on this :)

@rusty1s
Member

rusty1s commented Oct 21, 2022

Yes, I will merge this over the weekend. Sorry for the delay.

CharlesGaydon added a commit to IGNF/myria3d that referenced this pull request Oct 21, 2022
* Update with pyg-team/pytorch_geometric#5117

* Bump minor version to indicate no-model-compatibility

* Update signature for pyg randlanet

* Fix old randlanet signature

* Get rid of legacy implementation of RandLA-Net

* Disable example run until release of a model that is compatible

* Fix misleading batch_size indication for multi-GPUs setting.

* Pyg RandLaNet XP with min/max num_nodes and gradient accumulation.

* Rename XP.

* NoRS XP inherits from base XP.

* 5 epochs of cooldown before reducing lr

* 20 epochs of patience  before reducing lr

* Flake8 corrections.

* Correct model version name in CICD workflow

* Correct config  name in CICD workflow
@CharlesGaydon
Contributor Author

@rusty1s Resolved the conflict in Changelog :)

@CharlesGaydon
Contributor Author

@rusty1s Small bump :)

Member

@rusty1s rusty1s left a comment


Looks all good to me, just the remaining two comments.

examples/randlanet_classification.py — 2 resolved review threads
@rusty1s rusty1s changed the title RandLA-Net in pytorch geometric's examples RandLA-Net example Dec 2, 2022
@rusty1s rusty1s enabled auto-merge (squash) December 2, 2022 06:54
@rusty1s rusty1s merged commit 11c8cbd into pyg-team:master Dec 2, 2022
@CharlesGaydon CharlesGaydon deleted the randlanet branch December 23, 2022 12:43
CharlesGaydon added a commit to IGNF/myria3d that referenced this pull request Jan 23, 2023
* WIP V3.*.* with torch-geometric RandLA-Net implementation (#39)

Development of PyG-RandLA-Net

Co-authored-by: Michel Daab <michel.daab@ign.fr>

Co-authored-by: Charles Gaydon <11660435+CharlesGaydon@users.noreply.github.com>
Co-authored-by: CharlesGaydon <charles.gaydon@gmail.com>