Skip to content

Commit

Permalink
* Added option to specify .npy file with dlatents to be used instead …
Browse files Browse the repository at this point in the history
…of dlatent_avg for truncation
  • Loading branch information
oneiroid committed Jun 11, 2019
1 parent c6bd591 commit 1d429da
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 3 additions & 0 deletions encode_images.py
Expand Up @@ -23,6 +23,7 @@ def main():
parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
parser.add_argument('--model_url', default='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', help='Fetch a StyleGAN model to train on from this URL') # karras2019stylegan-ffhq-1024x1024.pkl
parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
Expand Down Expand Up @@ -90,6 +91,8 @@ def main():
generator_network, discriminator_network, Gs_network = pickle.load(f)

generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
if (args.dlatent_avg != ''):
generator.set_dlatent_avg(np.load(args.dlatent_avg))

perc_model = None
if (args.use_lpips_loss > 0.00000001):
Expand Down
9 changes: 8 additions & 1 deletion encoder/generator_model.py
Expand Up @@ -44,7 +44,8 @@ def __init__(self, model, batch_size, clipping_threshold=2, tiled_dlatent=False,
partial(create_stub, batch_size=batch_size)],
structure='fixed')

self.dlatent_avg = model.get_var('dlatent_avg')
self.dlatent_avg_def = model.get_var('dlatent_avg')
self.reset_dlatent_avg()
self.sess = tf.get_default_session()
self.graph = tf.get_default_graph()

Expand Down Expand Up @@ -93,6 +94,12 @@ def get_dlatents(self):
def get_dlatent_avg(self):
return self.dlatent_avg

def set_dlatent_avg(self, dlatent_avg):
self.dlatent_avg = dlatent_avg

def reset_dlatent_avg(self):
self.dlatent_avg = self.dlatent_avg_def

def generate_images(self, dlatents=None):
if dlatents:
self.set_dlatents(dlatents)
Expand Down

0 comments on commit 1d429da

Please sign in to comment.