Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeister committed May 7, 2018
0 parents commit 235a420
Show file tree
Hide file tree
Showing 24 changed files with 2,288 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
/datasets
/log
/checkpoints
/downloads
*.pyc
__pycache__
octave-workspace
/laina_models
30 changes: 30 additions & 0 deletions README.md
@@ -0,0 +1,30 @@
# Depth Prediction #

## Setup (Python 3)
* install pytorch (see pytorch.org)
* install tensorflow (no gpu support needed)
* install python packages: `scipy matplotlib h5py`

### Prepare datasets
* `python nyud_test_to_npy.py` (modify the paths in that file to point to correct dirs)
* download NYU Depth v2 raw dataset (~400GB) toolbox
* generate training dataset with matlab - see process_raw.m
* `python nyud_raw_train_to_npy.py` (modify the paths in that file to point to correct dirs)
* modify raw_root in train.py and test.py to point to correct dir


## Usage examples

### Train and view results
* `python train.py --ex my_test`
* `tensorboard logdir=log/my_test`
* open 'localhost:6006' in browser

### Continue training from checkpoint
Checkpoints are stored after each epoch.

* `python train.py --ex my_test --epochs 80 --lr 0.01`
* `python train.py --ex my_test --epochs 50 --lr 0.003`

### View all training options
* `python train.py --help`
Empty file added dense_estimation/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions dense_estimation/app/experiment.py
@@ -0,0 +1,23 @@
import os
import shutil


def get_experiment(name, overwrite, epoch=None):
log_dir = os.path.join('./log', name)
save_dir = os.path.join('/media/data/depth-prediction/checkpoints', name)
if overwrite: # or (os.path.isdir(log_dir) and not os.path.isdir(save_dir)):
shutil.rmtree(log_dir, ignore_errors=True)
shutil.rmtree(save_dir, ignore_errors=True)
if not os.path.isdir(save_dir):
os.makedirs(log_dir)
os.makedirs(save_dir)
save_paths = sorted(os.listdir(save_dir),
key=lambda s: int(s.split('.')[0].split('_')[1]))
if len(save_paths) > 0:
save_path = save_paths[-1] if epoch is None else 'model_{}.pth'.format(epoch)
restore_path = os.path.join(save_dir, save_path)
starting_epoch = int(save_path.split('.')[0].split('_')[1]) + 1
else:
restore_path = None
starting_epoch = 0
return log_dir, save_dir, restore_path, starting_epoch
85 changes: 85 additions & 0 deletions dense_estimation/app/gui.py
@@ -0,0 +1,85 @@
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button


def display(results, image_names, title="Flow eval"):
image_grids = []
num_images = len(results[0])
num_rows = len(results[0][0])
num_cols = len(results)

image_grids = []
for i in range(num_images):
image_grid = []
for image_lists in results:
image_grid.append(image_lists[i])
image_grids.append(image_grid)

fig = plt.figure(facecolor='grey')
fig.suptitle(title)
mng = plt.get_current_fig_manager()
mng.resize(*mng.window.maxsize())
imshow_images = []
plt.subplots_adjust(wspace=0, hspace=0.0)

imshow_image_lists = []
for j, image_col in enumerate(image_grids[0]):
imshow_images = []
for i, t in enumerate(zip(image_names, image_col)):
title, image = t
ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
if j == 0:
ax.set_ylabel(title)
ax.set_yticks([])
ax.set_xticks([])
if np.size(image, 3) == 1:
imshow_images.append(ax.imshow(image[0, :, :, 0], "gray"))
else:
imshow_images.append(ax.imshow(image[0, :, :, :]))
imshow_image_lists.append(imshow_images)

def display_example(index):
for j, image_col in enumerate(image_grids[int(index)]):
imshow_images = imshow_image_lists[j]
for im, image in zip(imshow_images, image_col):
if np.size(image, 3) == 1:
im.set_data(image[0, :, :, 0])
else:
im.set_data(image[0, :, :, :])
plt.draw()

current_index = 0

next_button_ax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
next_button = Button(next_button_ax, 'next')
prev_button_ax = fig.add_axes([0.7, 0.025, 0.1, 0.04])
prev_button = Button(prev_button_ax, 'previous')
slider_ax = fig.add_axes([0.1, 0.025, 0.55, 0.04])
slider = Slider(slider_ax, 'Page', 0, num_images - 1,
valinit=1, valfmt='%0.0f')

def next_button_on_clicked(mouse_event):
nonlocal current_index
if current_index < num_images - 1:
current_index += 1
slider.set_val(current_index)

def prev_button_on_clicked(mouse_event):
nonlocal current_index
if current_index > 0:
current_index -= 1
slider.set_val(current_index)

def sliders_on_changed(val):
nonlocal current_index
current_index = val
display_example(val)

slider.on_changed(sliders_on_changed)
prev_button.on_clicked(prev_button_on_clicked)
next_button.on_clicked(next_button_on_clicked)

plt.draw()
plt.show()
16 changes: 16 additions & 0 deletions dense_estimation/data.py
@@ -0,0 +1,16 @@
from torch.utils.data import DataLoader


def get_training_loader(dset_class, root, batch_size, out_size,
num_threads=1, limit=None, debug=False, shuffle=True):
dset = dset_class(root, split='train', transform=dset_class.get_transform(True, size=out_size),
limit=limit, debug=debug)
return DataLoader(dset, shuffle=shuffle, batch_size=batch_size, pin_memory=True,
num_workers=num_threads)


def get_testing_loader(dset_class, root, batch_size, out_size,
num_threads=1, limit=None, debug=False, training=False, shuffle=False):
dset = dset_class(root, split='test', transform=dset_class.get_transform(training, out_size),
limit=limit, debug=debug)
return DataLoader(dset, shuffle=shuffle, batch_size=batch_size, num_workers=num_threads)
Empty file.

0 comments on commit 235a420

Please sign in to comment.