<a href="https://colab.research.google.com/github/reveondivad/ExData_Plotting1/blob/master/Copy_of_StyleGAN_Explorer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Q: How do I run this notebook?

A:

1. Runtime -> Change runtime type -> Hardware accelerator -> GPU.
2. Runtime -> Factory reset runtime.
3. Click "Reconnect" in the top right corner.
4. Runtime -> Run all.
5. Scroll down and wait until you see the GUI.

Pro tip: to load a different model, change `curr_model` in the third code cell before pressing "Run all".

# Troubleshooting
If you get an "Unpickling error" or "403 Forbidden", the model file has been downloaded too many times (most likely the shared Google Drive file has hit its download quota). Colab does not keep files between sessions, so every new runtime has to download the model again. This puts a lot of pressure on whoever hosts the file, because serving such large files to many people costs money.

**Consider copying the model to your own Google Drive before this happens to you. Use "Get shareable link" to get the ID of your copy, then replace the corresponding ID in the** `entity_to_url` **dictionary with yours.**

In [None]:
# Environment setup (Colab): install the extra Python packages, then fetch a
# fresh copy of the StyleGAN code and make it the working directory.
# NOTE(review): package versions are not pinned.
%cd /content
!pip install typeguard;
!pip install psutil
!pip install humanize
!pip install tqdm
# Remove any previous checkout first so re-running this cell starts from a clean clone.
!rm -rf stylegan && git clone https://github.com/lucidrains/stylegan.git;
%cd /content/stylegan

from IPython.display import Image
from google.colab import files
import sys
import pickle
import numpy as np
import PIL
import psutil
import humanize
import os
import time
from tqdm import tqdm

from scipy import ndimage

# Colab-only magic selecting the TensorFlow 1.x runtime; it must run before
# TensorFlow is imported (presumably pulled in by dnnlib.tflib below).
# NOTE(review): newer Colab runtimes may no longer offer TF 1.x — confirm
# this magic still works on the runtime in use.
%tensorflow_version 1.x
# Add the cloned dnnlib directory to the import path.
sys.path.append('/content/stylegan/dnnlib')
import dnnlib
import dnnlib.tflib as tflib
# One-time tflib initialisation before any network is loaded or run
# (assumed from the name; see dnnlib.tflib.init_tf).
dnnlib.tflib.init_tf()

# Google Drive download links for the pretrained StyleGAN pickles, keyed by
# the names accepted in `curr_model`. If a shared file hits its download
# quota, swap its ID for the ID of your own Drive copy.
entity_to_url = dict(
    faces='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ',
    celebs='https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf',
    bedrooms='https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF',
    cars='https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3',
    cats='https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ',
    anime='https://drive.google.com/uc?id=1z8N_-xZW9AU45rHYGj1_tDHkIkbnMW-R',
)

# Generators already loaded by fetch_model(), keyed by entity name.
model_cache = dict()

# Keyword arguments presumably intended for Gs.run(): convert the output to
# uint8 images in NHWC order and process 20 latents per minibatch.
# NOTE(review): not referenced by any visible cell; gen_pil_image builds its
# own output_transform.
synthesis_kwargs = {
    'output_transform': {'func': tflib.convert_images_to_uint8, 'nchw_to_nhwc': True},
    'minibatch_size': 20,
}

def gen_pil_image(latents, zoom=1, psi=0.7):
    """Render one image with the global generator ``Gs`` and return it as a PIL image.

    Args:
        latents: latent array passed straight to ``Gs.run``; callers in this
            notebook pass shape (1, 512). Only the first image of the batch
            is returned.
        zoom: upscale factor for height and width, applied with
            ``scipy.ndimage.zoom``; 1 means no resampling.
        psi: truncation psi forwarded to ``Gs.run``.

    Returns:
        PIL.Image.Image built from the first generated image.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    import PIL.Image

    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, randomize_noise=True, output_transform=fmt, truncation_psi=psi)
    first = images[0]  # height x width x channels, per nchw_to_nhwc=True above
    if zoom != 1:
        # Scale rows and columns only; the trailing 1 leaves the channel axis alone.
        first = ndimage.zoom(first, (zoom, zoom, 1))
    return PIL.Image.fromarray(first)

import google.colab.output
import random
import io
import base64

def gen(l=None, psi=1):
    """Colab callback: render latents and return the image as a PNG data URL.

    Args:
        l: sequence of 512 latent values; when None, each value is drawn
           uniformly from [-1, 1).
        psi: truncation psi forwarded to gen_pil_image.

    Returns:
        A string of the form 'data:image/png;base64,<...>'.
    """
    latent_values = l
    if latent_values is None:
        latent_values = [random.random() * 2 - 1 for _ in range(512)]
    latent_batch = np.array(latent_values).reshape(1, 512)
    image = gen_pil_image(latent_batch, psi=psi)

    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return 'data:image/png;base64,' + encoded

# Make gen() callable from the page via google.colab.kernel.invokeFunction("gen", ...).
google.colab.output.register_callback('gen', gen)

In [None]:
def fetch_model(name):
  """Return the generator ``Gs`` for a named pretrained model, downloading it once.

  Args:
    name: a key of ``entity_to_url`` (e.g. 'faces', 'anime').

  Returns:
    The generator (third element of the pickled ``(G, D, Gs)`` tuple). It is
    stored in ``model_cache`` so later calls with the same name skip the download.

  Raises:
    KeyError: if ``name`` is not a key of ``entity_to_url``.
  """
  # Membership test rather than truthiness: a cached generator is reused
  # regardless of how the object evaluates in a boolean context.
  if name in model_cache:
    return model_cache[name]
  if name not in entity_to_url:
    raise KeyError('Unknown model %r; choose one of: %s'
                   % (name, ', '.join(entity_to_url)))
  url = entity_to_url[name]
  # SECURITY: pickle.load executes code embedded in the file — only point
  # entity_to_url at pickles you trust.
  with dnnlib.util.open_url(url, cache_dir='cache') as f:
    _G, _D, Gs = pickle.load(f)
  model_cache[name] = Gs
  return Gs

def fetch_file(filename):
  """Load a generator from a local StyleGAN pickle.

  StyleGAN pickles hold a ``(G, D, Gs)`` tuple (see fetch_model), while the
  rest of the notebook expects just the generator ``Gs``. A 3-tuple is
  therefore unpacked and its last element returned, so ``Gs = fetch_file(...)``
  works as a drop-in replacement for ``fetch_model``. Any other unpickled
  object is returned unchanged.

  Args:
    filename: path to the ``.pkl`` file.

  Returns:
    The generator ``Gs``, or the unpickled object as-is.
  """
  # SECURITY: unpickling runs arbitrary code — only load files you trust.
  with open(filename, 'rb') as f:
    loaded = pickle.load(f)
  if isinstance(loaded, tuple) and len(loaded) == 3:
    return loaded[2]
  return loaded

In [None]:
# Choose the model here; the default, "faces", is FFHQ.
# Valid names are the keys of entity_to_url:
#   faces, celebs, bedrooms, cars, cats, anime
# To use a pickle you uploaded yourself, load it with
# fetch_file('path/to/file.pkl') instead of fetch_model(...), and check that
# the result is the generator Gs rather than the whole (G, D, Gs) tuple.
curr_model = "faces"
Gs = fetch_model(curr_model)

In [None]:
from IPython.display import HTML

def get_latent_html(i):
    """Return the labelled number input for latent dimension ``i``.

    The input gets id ``l<i>`` and is pre-filled with a uniform random value
    in [-1, 1), shown with two decimals.
    """
    initial_value = random.random() * 2 - 1
    template = (
        '<div class="pure-control-group">\n'
        '            <label for="l%(idx)i">L%(idx)03i:</label>\n'
        '            <input type="number" min="-999.99" max="999.99" step="0.01" '
        'id="l%(idx)i" value="%(val).2f" style="background-color: white;">\n'
        '    </div>'
    )
    return template % {'idx': i, 'val': initial_value}

def get_latents_html():
    """Return the HTML inputs for all 512 latent dimensions, newline-separated."""
    rows = (get_latent_html(idx) for idx in range(512))
    return '\n'.join(rows)

# Control-panel markup. The two %s placeholders are filled with the loaded
# model's name and the 512 latent <input> rows from get_latents_html();
# '%%' in the template becomes a literal '%' after %-formatting. The onclick
# handlers call functions defined in the `javascript` string further down.
input_form = """
<link rel="stylesheet" href="https://necolas.github.io/normalize.css/8.0.1/normalize.css">
<link rel="stylesheet" href="https://unpkg.com/purecss@1.0.1/build/pure-min.css" integrity="sha384-oAOxQR6DkCoMliIh8yFnu25d7Eq/PHS21PClpwjOTeU2jRSq11vu66rf90/cZr47" crossorigin="anonymous">

<div style="background-color:white; border:solid #ccc; width:1200px; padding:20px; color: black;">
<p>You have currently loaded %s model</p>
  <div class="pure-g" style="width:1200px; margin-bottom: 25px;">
    <div class="pure-u-2-3">
      <img id="stylegan" src="" style="height:512px; width:512px;">
    </div>
    <div class="pure-u-1-3">
      <div style="overflow-y:scroll; height:512px; width:300px" class="pure-form pure-form-aligned">
        %s
      </div>
    </div>
  </div>

  <div class="pure-g">
    <div class="pure-u-1-6">
      <button class="pure-button pure-button-primary" onclick="generate();">Generate from latents</button>
    </div>
    <div class="pure-u-1-6 pure-form">
      <div class="pure-control-group">
            <label for="psi">psi:</label>
            <input type="number" min="0" max="999.99" step="0.01" id="psi" value="0.7" style="background-color: white;">
      </div>
    </div>
    <div class="pure-u-1-6">
      <button class="pure-button pure-button-primary" onclick="mutate();">Mutate randomly</button>
    </div>
    <div class="pure-u-1-6 pure-form">
      <div class="pure-control-group">
            <label for="mut_str">Mutation strength:</label>
            <input type="number" min="0" max="999.99" step="0.01" id="mut_str" value="0.2" style="background-color: white;">
      </div>
    </div>
    <div class="pure-u-1-6">
      <button class="pure-button pure-button-primary" onclick="randomize();">Random image</button>
    </div>
    <div class="pure-u-1-6">
      <button class="pure-button pure-button-primary" onclick="nnormalize();">Normalize latents</button>
    </div>
  </div>

  <div class="pure-g">
    <div class="pure-u-1-3">
      <button class="pure-button pure-button-primary" onclick="save();">Save latents</button>
      <button class="pure-button pure-button-primary" onclick="load();">Load latents</button>
    </div>
     <div class="pure-u-2-3 pure-form">
      <div class="pure-control-group">
            <input type="text" id="save-input" style="width:100%%; background-color: white;" placeholder="Saved latents will show up here...">
      </div>
    </div>
  </div>

</div>
""" % (curr_model, get_latents_html())

# Client-side logic for the control panel. The <img onerror> trick renders a
# first image as soon as the output is shown. Every button talks to the
# Python `gen` callback through google.colab.kernel.invokeFunction.
javascript = """
<img src onerror='generate()'>
<script type="text/Javascript">
    // Must match the number of latent inputs (l0 ... l511) built by get_latents_html().
    var N_LATENTS = 512;

    // The callback result arrives as the repr() of a Python string: strip the
    // surrounding quotes and undo the escaping of newlines and quotes.
    function desanitize(text) {
        return text.slice(1,-1).replace(/\\\\n/g, "\\n").replace(/\\\\'/g, "'");
    };

    function set_img(text) {
        document.getElementById('stylegan').src = text;
    };

    // Read every latent input as a number.
    function read_latents() {
        var latents = [];
        for (var i=0;i<N_LATENTS;i++) {
            latents[i] = parseFloat(document.getElementById('l'+i).value);
        };
        return latents;
    };

    // Write latent values back into the inputs so the form mirrors what is rendered.
    function write_latents(latents) {
        for (var i=0;i<N_LATENTS;i++) {
            document.getElementById('l'+i).value = latents[i];
        };
    };

    // Ask the Python "gen" callback to render `latents` with the current psi
    // and show the returned data URL. Failures are reported on the console
    // instead of being dropped silently.
    function request_image(latents) {
        var psi = parseFloat(document.getElementById('psi').value);
        google.colab.kernel.invokeFunction("gen", [latents, psi]).then(
            function(value) {
                set_img(desanitize(value.data["text/plain"]));
            },
            function(err) {
                console.error("StyleGAN: image generation failed", err);
            });
    };

    function generate() {
        request_image(read_latents());
    };

    function mutate() {
        var mutationStrength = parseFloat(document.getElementById('mut_str').value);
        var latents = read_latents();
        for (var i=0;i<N_LATENTS;i++) {
            latents[i] += (Math.random()*2-1) * mutationStrength;
        };
        write_latents(latents);
        request_image(latents);
    };

    // Rescale the latents to unit root-mean-square. Dividing by the plain sum
    // (random latents here are drawn around zero, so their sum is near zero)
    // would flip every sign when the sum is negative and blow the values up
    // when it is close to zero. All-zero or non-numeric latents are left as-is.
    function nnormalize() {
        var latents = read_latents();
        var sumOfSquares = latents.reduce((acc, x) => acc + x*x, 0);
        var rms = Math.sqrt(sumOfSquares / latents.length);
        if (rms > 0 && isFinite(rms)) {
            for (var i=0;i<latents.length;i++) {
                latents[i] = latents[i] / rms;
            };
            write_latents(latents);
        };
        generate();
    };

    function randomize() {
        var latents = [];
        for (var i=0;i<N_LATENTS;i++) {
            latents[i] = Math.random()*2-1;
        };
        write_latents(latents);
        generate();
    };

    function save() {
        document.getElementById('save-input').value = JSON.stringify(read_latents());
    };

    // Restore latents produced by save(). Malformed input is reported and
    // ignored rather than leaving some inputs blank.
    function load() {
        var latents;
        try {
            latents = JSON.parse(document.getElementById('save-input').value);
        } catch (err) {
            console.error("StyleGAN: saved latents are not valid JSON", err);
            return;
        };
        if (!Array.isArray(latents) || latents.length !== N_LATENTS) {
            console.error("StyleGAN: expected a JSON array of " + N_LATENTS + " numbers");
            return;
        };
        write_latents(latents);
        generate();
    };

</script>
"""

# Render the control panel; as the cell's last expression, Jupyter displays it.
HTML(input_form + javascript)