Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions PanoramAI/VAEorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,34 @@ def __init__(self, dataset,
latent_dim = 100):
super().__init__(dataset, BATCH_SIZE, test_size, latent_dim)

def _generate_random_vector(self, n_samples):
self.n_samples_to_generate = n_samples
return tf.random.normal(
shape=[n_samples, self.latent_dim])

def generate_samples(self, n_samples):
#if n_samples > self.n_samples_to_generate:
# print("Regenerating sample vector.")
return self.model.sample(self._generate_random_vector(n_samples))

self.n_samples_to_generate = n_samples
return self.model.sample(
tf.random.normal(shape=[n_samples, self.latent_dim]))
def create_model(self):
M, N = self.dimensions
self.model = _CVAE(M, N, self.latent_dim)
return

def log_normal_pdf(self, sample, mean, logvar, raxis=1):
log2pi = tf.math.log(2. * np.pi)
return tf.reduce_sum(
-.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
axis=raxis)

@tf.function
def compute_loss(self, x):
mean, logvar = self.model.encode(x)
z = self.model.reparameterize(mean, logvar)
x_predicted = self.model.decode(z)
MSE = tf.losses.MSE(x, x_predicted)
logpx_z = -tf.reduce_sum(MSE)
logpz = self.log_normal_pdf(z, 0., 0.)
logqz_x = self.log_normal_pdf(z, mean, logvar)
#Log-normal distributions for z
#The prior has mean 0 with no variance
#The variational distribution has a mean and logvar
log2pi = tf.math.log(2. * np.pi)
logpz = tf.reduce_sum(
-0.5 * (z ** 2. + log2pi), axis=1)
logqz_x = tf.reduce_sum(
-.5 * ((z - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
axis=1)
#logpz = self.log_normal_pdf(z, 0., 0.)
#logqz_x = self.log_normal_pdf(z, mean, logvar)
return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
Expand Down
28 changes: 27 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,30 @@ These can all be installed together using the `requirements.txt` file as shown a

## Usage

This section is still being finalized.
PanoramAI contains models to generate panoramic images. Models provided in this package are untrained. This section documents how to instantiate and train a model, and then how to generate new panoramic images.

First, import and create a model. As an example here, we consider a convolutional variational autoencoder (`VAEorama`).

```python
import PanoramAI

model = PanoramAI.VAEorama(dataset)
```

In this example, `dataset` is a `numpy.ndarray` containing the input panoramic images for training. It should have `N` RGB images in total, meaning its shape must be (`N`,`Height`,`Width`,`3`) where the height and width are in pixels.

Once created, we train the model for some number of epochs.
```python
epochs = 100
model.train(epochs)
```

Predictions can be made in batches
```python
#Create 10 sample panoramas
samples - model.generate_samples(10)
```

### Models

At present, PanoramAI includes two fully generative models: a DCGAN and a convolutional variational autoencoder (VAE). A conditional VAE is in development.