Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 235a420
Showing
24 changed files
with
2,288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
/datasets | ||
/log | ||
/checkpoints | ||
/downloads | ||
*.pyc | ||
__pycache__ | ||
octave-workspace | ||
/laina_models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Depth Prediction # | ||
|
||
## Setup (Python 3) | ||
* install pytorch (see pytorch.org) | ||
* install tensorflow (no gpu support needed) | ||
* install python packages: `scipy matplotlib h5py` | ||
|
||
### Prepare datasets | ||
* `python nyud_test_to_npy.py` (modify the paths in that file to point to correct dirs) | ||
* download NYU Depth v2 raw dataset (~400GB) toolbox | ||
* generate training dataset with matlab - see process_raw.m | ||
* `python nyud_raw_train_to_npy.py` (modify the paths in that file to point to correct dirs) | ||
* modify raw_root in train.py and test.py to point to correct dir | ||
|
||
|
||
## Usage examples | ||
|
||
### Train and view results | ||
* `python train.py --ex my_test` | ||
* `tensorboard logdir=log/my_test` | ||
* open 'localhost:6006' in browser | ||
|
||
### Continue training from checkpoint | ||
Checkpoints are stored after each epoch. | ||
|
||
* `python train.py --ex my_test --epochs 80 --lr 0.01` | ||
* `python train.py --ex my_test --epochs 50 --lr 0.003` | ||
|
||
### View all training options | ||
* `python train.py --help` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
import shutil | ||
|
||
|
||
def get_experiment(name, overwrite, epoch=None): | ||
log_dir = os.path.join('./log', name) | ||
save_dir = os.path.join('/media/data/depth-prediction/checkpoints', name) | ||
if overwrite: # or (os.path.isdir(log_dir) and not os.path.isdir(save_dir)): | ||
shutil.rmtree(log_dir, ignore_errors=True) | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
if not os.path.isdir(save_dir): | ||
os.makedirs(log_dir) | ||
os.makedirs(save_dir) | ||
save_paths = sorted(os.listdir(save_dir), | ||
key=lambda s: int(s.split('.')[0].split('_')[1])) | ||
if len(save_paths) > 0: | ||
save_path = save_paths[-1] if epoch is None else 'model_{}.pth'.format(epoch) | ||
restore_path = os.path.join(save_dir, save_path) | ||
starting_epoch = int(save_path.split('.')[0].split('_')[1]) + 1 | ||
else: | ||
restore_path = None | ||
starting_epoch = 0 | ||
return log_dir, save_dir, restore_path, starting_epoch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
from matplotlib.widgets import Slider, Button | ||
|
||
|
||
def display(results, image_names, title="Flow eval"): | ||
image_grids = [] | ||
num_images = len(results[0]) | ||
num_rows = len(results[0][0]) | ||
num_cols = len(results) | ||
|
||
image_grids = [] | ||
for i in range(num_images): | ||
image_grid = [] | ||
for image_lists in results: | ||
image_grid.append(image_lists[i]) | ||
image_grids.append(image_grid) | ||
|
||
fig = plt.figure(facecolor='grey') | ||
fig.suptitle(title) | ||
mng = plt.get_current_fig_manager() | ||
mng.resize(*mng.window.maxsize()) | ||
imshow_images = [] | ||
plt.subplots_adjust(wspace=0, hspace=0.0) | ||
|
||
imshow_image_lists = [] | ||
for j, image_col in enumerate(image_grids[0]): | ||
imshow_images = [] | ||
for i, t in enumerate(zip(image_names, image_col)): | ||
title, image = t | ||
ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1) | ||
if j == 0: | ||
ax.set_ylabel(title) | ||
ax.set_yticks([]) | ||
ax.set_xticks([]) | ||
if np.size(image, 3) == 1: | ||
imshow_images.append(ax.imshow(image[0, :, :, 0], "gray")) | ||
else: | ||
imshow_images.append(ax.imshow(image[0, :, :, :])) | ||
imshow_image_lists.append(imshow_images) | ||
|
||
def display_example(index): | ||
for j, image_col in enumerate(image_grids[int(index)]): | ||
imshow_images = imshow_image_lists[j] | ||
for im, image in zip(imshow_images, image_col): | ||
if np.size(image, 3) == 1: | ||
im.set_data(image[0, :, :, 0]) | ||
else: | ||
im.set_data(image[0, :, :, :]) | ||
plt.draw() | ||
|
||
current_index = 0 | ||
|
||
next_button_ax = fig.add_axes([0.8, 0.025, 0.1, 0.04]) | ||
next_button = Button(next_button_ax, 'next') | ||
prev_button_ax = fig.add_axes([0.7, 0.025, 0.1, 0.04]) | ||
prev_button = Button(prev_button_ax, 'previous') | ||
slider_ax = fig.add_axes([0.1, 0.025, 0.55, 0.04]) | ||
slider = Slider(slider_ax, 'Page', 0, num_images - 1, | ||
valinit=1, valfmt='%0.0f') | ||
|
||
def next_button_on_clicked(mouse_event): | ||
nonlocal current_index | ||
if current_index < num_images - 1: | ||
current_index += 1 | ||
slider.set_val(current_index) | ||
|
||
def prev_button_on_clicked(mouse_event): | ||
nonlocal current_index | ||
if current_index > 0: | ||
current_index -= 1 | ||
slider.set_val(current_index) | ||
|
||
def sliders_on_changed(val): | ||
nonlocal current_index | ||
current_index = val | ||
display_example(val) | ||
|
||
slider.on_changed(sliders_on_changed) | ||
prev_button.on_clicked(prev_button_on_clicked) | ||
next_button.on_clicked(next_button_on_clicked) | ||
|
||
plt.draw() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def get_training_loader(dset_class, root, batch_size, out_size, | ||
num_threads=1, limit=None, debug=False, shuffle=True): | ||
dset = dset_class(root, split='train', transform=dset_class.get_transform(True, size=out_size), | ||
limit=limit, debug=debug) | ||
return DataLoader(dset, shuffle=shuffle, batch_size=batch_size, pin_memory=True, | ||
num_workers=num_threads) | ||
|
||
|
||
def get_testing_loader(dset_class, root, batch_size, out_size, | ||
num_threads=1, limit=None, debug=False, training=False, shuffle=False): | ||
dset = dset_class(root, split='test', transform=dset_class.get_transform(training, out_size), | ||
limit=limit, debug=debug) | ||
return DataLoader(dset, shuffle=shuffle, batch_size=batch_size, num_workers=num_threads) |
Empty file.
Oops, something went wrong.