
Code accompanying the paper "Predictive World Models from Real-World Partial Observations" (MOST 2023)



robin-karlsson0/predictive-world-models


Predictive World Models from Real-World Partial Observations

Code accompanying the paper "Predictive World Models from Real-World Partial Observations" (IEEE MOST 2023) 🎉 Best paper award 🎉

Paper link: Predictive World Models from Real-World Partial Observations

Video presentation link: TODO

Shared public data (incl. pretrained models): Google Drive directory

[Figure: Predictive world model inference]

Installation

Download all submodules

git submodule update --init --recursive

The submodules are used for the following tasks:

  1. pc-accumulation-lib: Semantic point cloud accumulation library for generating partial-observation bird's-eye-view (BEV) representations.
  2. lat_var_bev_pred_model: Code for generating complete pseudo-GT representations for training the predictive world model.
  3. vdvae: Code implementing the predictive world model. A fork of the original VDVAE repository, modified into a dual-encoder posterior-matching HVAE model.
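The accumulation step in pc-accumulation-lib is only named above, not specified. As a rough illustration of what a partial-observation BEV representation can look like, the minimal sketch below rasterizes already-accumulated semantic points into per-cell class probabilities plus an "observed" mask; cells that no point falls into stay unobserved. The function name `rasterize_semantic_bev`, the ego-centered square grid, and the `grid_size`, `resolution` and `num_classes` parameters are hypothetical and do not reflect the library's actual API or settings.

```python
import numpy as np


def rasterize_semantic_bev(points, labels, grid_size=256, resolution=0.5, num_classes=4):
    """Hypothetical sketch: rasterize semantic points into a BEV grid.

    points: (N, 2+) array of metric x, y coordinates in the ego frame.
    labels: (N,) integer semantic class per point.
    Returns (num_classes, grid_size, grid_size) per-cell class probabilities
    and a (grid_size, grid_size) boolean mask of observed cells.
    The ego-centered grid and all parameter values are assumptions.
    """
    half_extent = grid_size * resolution / 2.0
    # Convert metric x, y coordinates to integer cell indices (ego at the grid center).
    cols = np.floor((points[:, 0] + half_extent) / resolution).astype(int)
    rows = np.floor((points[:, 1] + half_extent) / resolution).astype(int)

    # Drop points that fall outside the grid.
    inside = (rows >= 0) & (rows < grid_size) & (cols >= 0) & (cols < grid_size)
    rows, cols, labels = rows[inside], cols[inside], labels[inside]

    # Count how many points of each class land in each cell (np.add.at handles repeats).
    counts = np.zeros((num_classes, grid_size, grid_size), dtype=np.int64)
    np.add.at(counts, (labels, rows, cols), 1)

    # A cell counts as observed if at least one point fell into it.
    observed = counts.sum(axis=0) > 0

    # Normalize counts to per-cell class probabilities; unobserved cells stay all-zero.
    probs = np.zeros_like(counts, dtype=np.float64)
    probs[:, observed] = counts[:, observed] / counts[:, observed].sum(axis=0)
    return probs, observed
```

The all-zero columns for unobserved cells, together with the mask, are what makes the representation "partial": a predictive world model would be asked to fill in exactly those cells.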

The paper's results can be reproduced by generating data and training models with the code in this repository and its submodules.
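The README does not spell out the training objective. VDVAE, which the vdvae submodule forks, computes per-layer KL terms between diagonal Gaussians parameterized by a mean and a log standard deviation. Assuming a dual-encoder posterior-matching model compares posteriors with the same kind of term, the minimal sketch below shows only that analytic KL formula, KL(q || p) = log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2, summed over dimensions. It is not the repository's actual loss, and the function name `gaussian_kl` is hypothetical.

```python
import math


def gaussian_kl(mu_q, logsigma_q, mu_p, logsigma_p):
    """Analytic KL(q || p) between diagonal Gaussians, summed over dimensions.

    Each argument is a sequence with one entry per latent dimension.
    Standard deviations are given as log-sigma, as in VDVAE-style HVAEs.
    Hypothetical helper for illustration only.
    """
    if not (len(mu_q) == len(logsigma_q) == len(mu_p) == len(logsigma_p)):
        raise ValueError("all arguments must have the same length")
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logsigma_q, mu_p, logsigma_p):
        # log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2
        total += (lp - lq) + (math.exp(2 * lq) + (mq - mp) ** 2) / (2 * math.exp(2 * lp)) - 0.5
    return total
```

For example, shifting a unit Gaussian's mean by 1 relative to a standard normal gives `gaussian_kl([1.0], [0.0], [0.0], [0.0])`, which equals 0.5.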

Install dependencies

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install matplotlib
pip install opencv-contrib-python
pip install pytorch-lightning
pip install scipy

Install the VDVAE dependencies by following the README instructions inside the vdvae directory. Then add the vdvae directory to PYTHONPATH:

export PYTHONPATH=/PATH/predictive-world-models/vdvae

replacing /PATH/predictive-world-models with the absolute path to your predictive-world-models directory.
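Before running the model, it can help to confirm that the packages installed above are importable and that PYTHONPATH points at the vdvae directory. The minimal sketch below does both using only the standard library. The import names (`cv2` for opencv-contrib-python, `pytorch_lightning` for pytorch-lightning) are the packages' usual module names; the helper names `missing_modules` and `vdvae_on_pythonpath` are hypothetical, and the PYTHONPATH check only looks at entry names, not whether the directory exists.

```python
import importlib.util
import os

# Import names of the pip packages listed above.
REQUIRED_MODULES = ["torch", "torchvision", "torchaudio", "matplotlib", "cv2",
                    "pytorch_lightning", "scipy"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def vdvae_on_pythonpath(pythonpath=None):
    """Return True if some PYTHONPATH entry is a path ending in a 'vdvae' directory.

    Only the entry name is checked; the directory itself is not required to exist.
    """
    if pythonpath is None:
        pythonpath = os.environ.get("PYTHONPATH", "")
    entries = [p for p in pythonpath.split(os.pathsep) if p]
    # normpath strips trailing separators so ".../vdvae/" also matches.
    return any(os.path.basename(os.path.normpath(p)) == "vdvae" for p in entries)


if __name__ == "__main__":
    print("Missing modules:", missing_modules() or "none")
    print("vdvae on PYTHONPATH:", vdvae_on_pythonpath())
```

Saved as a script and run from the same shell where PYTHONPATH was exported, it prints any missing packages and whether the vdvae entry was found.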

Run model

NuScenes example

Download the trained model and a test sample, either by running the download script or manually from the project Google Drive directory:

sh download_model_nuscenes.sh

https://drive.google.com/drive/folders/1bU6W0yeEz7TttEhS3Y3oDJgvtTd_9Oqu?usp=sharing

The following files will be placed in the project root directory:

predictive-world-models/pred_wm_model_ema_nuscenes.th
predictive-world-models/pred_wm_model_nuscenes.th
predictive-world-models/test_sample_nuscenes.pkl.gz
predictive-world-models/test_sample_nuscenes.png
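Judging by its `.pkl.gz` extension, the test sample is a gzip-compressed pickle. The README does not document what the pickle contains, so the minimal sketch below only loads it and summarizes its structure without assuming specific keys. The helper names `load_sample` and `summarize` are hypothetical. Unpickling may require the libraries whose objects are stored in the file (for example numpy or torch) to be installed, and pickles should only be loaded from trusted sources.

```python
import gzip
import pickle


def load_sample(path):
    """Load a gzip-compressed pickle such as test_sample_nuscenes.pkl.gz."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


def summarize(obj):
    """Describe an object's structure without assuming its contents.

    Dicts are summarized per key; lists/tuples by length; array-like objects
    (anything with a .shape attribute) by shape; everything else by type name.
    """
    if isinstance(obj, dict):
        return {key: summarize(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return f"{type(obj).__name__} of length {len(obj)}"
    shape = getattr(obj, "shape", None)
    if shape is not None:
        return f"{type(obj).__name__} with shape {tuple(shape)}"
    return type(obj).__name__
```

For example, `print(summarize(load_sample("test_sample_nuscenes.pkl.gz")))` run from the project root shows the top-level layout of the sample.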

Sample 36 plausible worlds consistent with the partial observation in the test sample by running:

sh sample_worlds_nuscenes.sh

Visualizations of the sampled plausible worlds are saved to the output directory (out_dir by default).
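The README does not say how the 36 sampled worlds are laid out in the visualization. Assuming they are rendered as images of equal size, a natural layout is a 6 x 6 grid; the minimal sketch below tiles a batch of images into one padded grid image with numpy. The function name `tile_samples` and its `pad`/`pad_value` parameters are hypothetical and are not taken from sample_worlds_nuscenes.py.

```python
import numpy as np


def tile_samples(samples, rows, cols, pad=2, pad_value=1.0):
    """Arrange rows * cols images of shape (H, W, C) into one padded grid image.

    pad_value should match the image value range (1.0 is white for floats in [0, 1]).
    Hypothetical helper for illustration only.
    """
    samples = np.asarray(samples)
    n, h, w, c = samples.shape
    if n != rows * cols:
        raise ValueError(f"expected {rows * cols} samples, got {n}")
    # Output canvas: each tile plus a padding border on every side.
    grid = np.full((rows * (h + pad) + pad, cols * (w + pad) + pad, c),
                   pad_value, dtype=samples.dtype)
    for idx, img in enumerate(samples):
        r, col = divmod(idx, cols)  # row-major placement
        top = pad + r * (h + pad)
        left = pad + col * (w + pad)
        grid[top:top + h, left:left + w] = img
    return grid
```

With 36 RGB samples of 64 x 64 pixels and the default padding, `tile_samples(samples, rows=6, cols=6)` returns a 398 x 398 x 3 array that could then be written to disk with, for example, `matplotlib.pyplot.imsave`.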

[Figure: Predictive world model output on NuScenes]

File structure

predictive-world-models
|
└───lat_var_bev_pred_model/     # Pseudo-GT sample generation
|   └─── ...
|
└───pc-accumulation-lib/        # Observation accumulation framework
|   └─── ...
|
└───vdvae/                      # Predictive world model implementation
|   └─── ...
|
|   datamodule.py               # Reads and pre-processes input samples
|   download_model_nuscenes.sh  # Downloads model and test sample files
|   sample_worlds_nuscenes.py   # Runs the world model and saves visualizations to disk
|   sample_worlds_nuscenes.sh   # Script including required environment variables
|   world_model.py              # World model inference interface
