<a href="https://colab.research.google.com/github/un1tz3r0/stylegan3/blob/main/stylegan3_training_and_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <big><big><big><big><big>Training StyleGAN3</big></big></big></big></big>

This is my setup for training NVIDIA's [StyleGAN3](https://github.com/NVlabs/stylegan3) \(aka [Alias-Free GAN](https://nvlabs.github.io/stylegan3/)\) on Google's [Colab Pro+ GPUs](https://colab.research.google.com/). For a measly $50/month you can train your very own translation- and rotation-equivariant alias-free GAN, either from scratch or starting from a pretrained network, for all kinds of fun applications. All you need is some images to train with.

This work was inspired by many projects, most directly the work of [Max Braun](https://braun.design/), whose [eBoyGAN](https://onezero.medium.com/how-i-accidentally-created-an-infinite-pixel-hellscape-fe070551365f) I have attempted here to replicate using the more recent alias-free StyleGAN3. This notebook contains everything you will need to reproduce my results, and hopefully will act as a starting point for similar projects.

One of the biggest motivations for this work, and for publishing this notebook and the GitHub repos below, was the lack of pretrained models available for experimenting with inference, projection, semantic editing and other applications that manipulate images in the GAN latent space. There are many examples that use older models such as StyleGAN2, but they all suffer from poor image quality and artifacts due to the issues addressed by StyleGAN3.

\- [Victor Condino](https://twitter.com/un1tz3r0/)

## GitHub repos used by this notebook:

- [un1tz3r0/stylegan3](https://github.com/un1tz3r0/stylegan3.git)
- [un1tz3r0/pixelscapes-dataset](https://github.com/un1tz3r0/pixelscapes-dataset.git)


<big><big>*If you wish to support my work, please donate generously to:*

BTC: [3MzZseGSqXFo6GmJxthVRvdCre8CR2F1QJ](https://www.blockchain.com/btc/address/3MzZseGSqXFo6GmJxthVRvdCre8CR2F1QJ)

Ethereum/Polygon/ERC20 Tokens: [0x0480409E69c4c89EeB4cDb84111B63976E56c389](https://www.blockchain.com/eth/address/0x0480409E69c4c89EeB4cDb84111B63976E56c389)





---



In [None]:
#@title <big><big><big><big>Notebook Mode</big></big></big></big> { display-mode: "form" }
#@markdown This notebook has some common setup cells, followed by one section of cells specific to training and one specific to inference. The setting below disables either the training section or the inference section, so that you can pick a mode and then do "Run All Cells".

notebook_mode = "Inference" #@param ["Training", "Inference"]




# <big><big><big><big>Setup</big></big></big></big>


In [None]:
#@title <big><big><big>Connect with Google Drive</big></big></big> { vertical-output: true, display-mode: "form" }
#@markdown Run this cell to authorize the runtime instance to access files on your Google Drive.
#@markdown
#@markdown The files will be placed in a folder called 'stylegan3-training' in the top-level [My Drive] folder, which is created if it does not already exist. You can specify a different folder to use below:

folderprefix = "/stylegan3-training" #@param {type:"string"}

default_rclone_config_path = "/content/rclone.config"

from shutil import make_archive

def configure_rclone(configfile=default_rclone_config_path, folderprefix=folderprefix, overwrite=True, clearauth=True):
  import pathlib
  if (not pathlib.Path(configfile).exists()) or overwrite:
    # Import PyDrive and associated libraries.
    # This only needs to be done once per notebook.
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive
    from google.colab import auth
    from oauth2client.client import GoogleCredentials
    import json, datetime, tzlocal
    import httplib2
    import google.colab.drive
    from pytz import timezone

    if clearauth:
      google.colab.drive.flush_and_unmount()

    # Authenticate and create the PyDrive client.
    # This only needs to be done once per notebook.
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    drive = GoogleDrive(gauth)
    #print(drive.auth.credentials.get_access_token())

    # we have authenticated, write the credentials to rclone.config
    rcloneconfig = "\n".join([
      '[driveapi]',
      'type = alias',
      f'remote = driveroot:{folderprefix}',
      '',
      '[driveroot]',
      'type = drive',
      f'client_id = {drive.auth.credentials.client_id}',
      f'client_secret = {drive.auth.credentials.client_secret}',
      'scope = drive.file',
      'token = {}'.format(json.dumps({
        "access_token":drive.auth.credentials.get_access_token().access_token,
        "token_type":"Bearer",
        "refresh_token":drive.auth.credentials.refresh_token,
        "expiry":drive.auth.credentials.token_expiry.astimezone(tzlocal.get_localzone()).astimezone(datetime.timezone.utc).isoformat().replace("Z", "+")
      }))
    ])

    with open(configfile, "wt") as fout:
      print(f"Writing rclone remote configuration with Google Drive auth token to {configfile}...")
      wrsz = fout.write(rcloneconfig)
      print(f"... wrote {wrsz} bytes")

# ------------------------------------------------------------------------------
# install rclone
# ------------------------------------------------------------------------------

def install_rclone_from_github():
  #!sudo apt install golang
  %cd /content
  !wget https://go.dev/dl/go1.17.5.linux-amd64.tar.gz
  !rm -rf /usr/local/go && tar -C /usr/local -xzf go1.17.5.linux-amd64.tar.gz
  
  def extend_path(searchdir):
    
    # add the directory to the python interpreter's PATH env var (avoiding duplicates), 
    # so it takes effect immediately
    import os, shlex
    os.environ['PATH']=":".join([*[p for p in os.environ['PATH'].split(":") if p != searchdir], searchdir])
    
    # add the directory to the current user's .profile, which is sourced by shells
    #  on startup. also avoiding duplicates
    profilepath = os.path.expanduser("~/.profile")
    lines = []
    if os.path.exists(profilepath):
      with open(profilepath, "rt") as fin:
        # splitlines() drops the trailing newlines so they are not doubled when rewriting
        lines = fin.read().splitlines()
    addline = 'export PATH="${PATH}:"' + shlex.quote(searchdir)
    if not any(line.strip() == addline for line in lines):
      lines.append(addline)
      with open(profilepath, "wt") as fout:
        fout.write("\n".join(lines) + "\n")
  

  #!echo 'export PATH="${PATH}:/usr/local/go/bin"' >> ~/.profile
  
  extend_path("/usr/local/go/bin")
  !go get github.com/rclone/rclone
  # %cd /content
  # !git clone https://github.com/rclone/rclone.git
  # %cd rclone
  # !make
  # !sudo make install


def install_rclone_via_shell(quiet=True):
  import pathlib, shutil
  from google.colab import output
  import json, binascii

  # check if rclone is in $PATH
  if shutil.which("rclone") == None:
    # no rclone, install it!
    print("Downloading and running rclone install.sh...")
    if not quiet:
      !bash -c 'cd /content; curl https://rclone.org/install.sh | sudo bash'
    else: 
      !bash -c 'cd /content; curl https://rclone.org/install.sh 2>/dev/null | sudo bash >/dev/null 2>&1'
    assert(shutil.which("rclone") != None)
  else:
    if not quiet:
      print("It appears rclone is already installed!")

print("Authorizing notebook to use google drive...")
configure_rclone(default_rclone_config_path)

print("Downloading and installing rclone...")
try:
  import shutil
  install_rclone_via_shell(quiet=True)
  assert(shutil.which("rclone") != None)
except:
  install_rclone_from_github()
  assert(shutil.which("rclone") != None)

print("Testing rclone google drive remote...")
!rclone --config=$default_rclone_config_path touch driveapi:.timestamp
print("... success!")

# -------------------------------------------------------------------
# Python subprocess wrappers for programmatically using rclone cli
# -------------------------------------------------------------------

import os, pathlib
import subprocess, json

if 'default_rclone_config_path' not in vars().keys():
  default_rclone_config_path = "/content/rclone.config"

def rclone(*args, output="pass", check=True, config=None):
  global default_rclone_config_path
  if config == None:
    config = default_rclone_config_path
  if config != None:
    args = list([f"--config={config}", *args])
  
  if output == "pass":
    p = subprocess.run(["rclone", *args], check=check, stderr=subprocess.STDOUT)
    if not check:
      return p.returncode
  else:
    p = subprocess.run(["rclone", *args], check=check, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    if output == "raw":
      if not check:
        return p.stdout, p.returncode
      else:
        return p.stdout
    elif output == "json":
      try:
        jsonout = json.loads(p.stdout)
      except json.JSONDecodeError as err:
        jsonout = None
      if not check:
        return jsonout, p.returncode
      else:
        return jsonout

def rclonels(*args, max_depth = None):
  files, retcode = rclone("lsjson", *args, *([f"--max-depth={max_depth}"] if max_depth != None else []), output="json", check=False)
  if retcode != 0:
    return []
  if files == None:
    return []
  return files

def syncnewestmatchingfile(searchdirs, pattern, max_depth = None, download_to = None):
  import fnmatch
  allfiles = []
  for searchdir in searchdirs:
    dirfiles = rclonels(searchdir, max_depth=max_depth)
    allfiles = list(allfiles) + list([(pathlib.Path(searchdir)/df['Path'], df['ModTime']) for df in dirfiles if fnmatch.fnmatch(df['Name'], pattern)])
  if len(allfiles) < 1:
    return None # no matching files were found
  newestfile = str(list([g[0] for g in list(sorted(allfiles, key=lambda j: j[1]))])[-1])
  if download_to != None:
    downloadedfile = str(pathlib.Path(download_to) / pathlib.Path(newestfile).name)
    if str(pathlib.Path(newestfile)).startswith(str(pathlib.Path(download_to))):
      return str(pathlib.Path(newestfile))
    print(f"Downloading newest file matching pattern {repr(pattern)} in search dirs {searchdirs} to {download_to}: {newestfile}")
    rclone("copy", "--progress", newestfile, download_to, output="pass", check=False)
    print(f"... done downloading {downloadedfile}")
    return downloadedfile
  else:
    return str(newestfile)

def localnewestmatchingfile(dirpath, pattern):
  import pathlib
  sortedmatches = list([str(f) for f in sorted(pathlib.Path(dirpath).glob(pattern), key=lambda f: f.stat().st_mtime)])
  if len(sortedmatches) > 0:
    return sortedmatches[-1]
  else:
    return None


In [None]:
#@title <big><big><big>Install **StyleGAN3** Fork</big></big></big>
#@markdown Clones the fork [un1tz3r0/stylegan3](https://github.com/un1tz3r0/stylegan3.git) and installs the various Python dependencies it needs to run.

!pip install einops ninja gdown aiohttp

import os
%cd /content/

#!rm -rf /content/stylegan3
if not os.path.isdir('/content/stylegan3/'):
  !git clone https://github.com/un1tz3r0/stylegan3.git /content/stylegan3/
else:
  %cd /content/stylegan3/
  !git pull


# <big><big><big>Prepare Model and Data</big></big></big>
If this is your first time running this notebook, the first two cells of this section won't do much, because you don't have any model snapshots on your Google Drive yet. The dropdown that appears after running the first cell, *Select a Model Snapshot*, will only contain the value `none` and the URL(s) of the pretrained models distributed by the NVIDIA team with the official StyleGAN3 (Alias-Free GAN) paper. Once you have successfully started training a model, the snapshots produced will appear in this dropdown as starting points for future training runs.
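
As a rough sketch of what the dropdown holds and how a snapshot maps to its preview image, the snippet below mirrors the string handling used by the cells in this section. The function names `snapshot_options` and `preview_name` and the example paths are hypothetical, chosen only for illustration.

```python
def snapshot_options(remote_snapshots, pretrained_urls):
    """Dropdown options: 'none', then pretrained URLs, then Drive snapshots (oldest first)."""
    return ["none", *pretrained_urls, *remote_snapshots]

def preview_name(snapshot_path):
    """Map network-snapshot-NNNNNN.pkl to the fakesNNNNNN.png written alongside it."""
    return snapshot_path.replace("network-snapshot-", "fakes").replace(".pkl", ".png")

pretrained = ["https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl"]

# First run: no snapshots on Drive yet, so the last option is the pretrained URL.
first_run = snapshot_options([], pretrained)

# Later runs: the newest snapshot is last in the list and becomes the default selection.
later_run = snapshot_options(["driveapi:run1/network-snapshot-000200.pkl"], pretrained)

print(first_run[-1])
print(later_run[-1])
print(preview_name(later_run[-1]))
```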

In [None]:
#@title <big><big><big>Select a **Model Snapshot**</big></big></big> { vertical-output: true, display-mode: "form" }

#@markdown Run this cell to display a list of all `network-snapshot-*.pkl` files on your Google Drive, sorted in ascending chronological order by modification time.
#@markdown 
#@markdown *The newest model snapshot found on Google Drive will be selected automatically.*
#@markdown
#@markdown When you select a model, if there is a matching `fakes*.png` in the same folder, it will be downloaded and shown below the dropdown to preview the selected model's generator output.

import ipywidgets as widgets
from IPython import display
from fnmatch import fnmatch
import pathlib

#if "modelselect" not in globals().keys():
modelselect = None
#if "modellayout" not in globals().keys():
modellayout = None
#if "modelpreview" not in globals().keys():
modelpreview = None
nodownloadpreview = False

def modelselectchanged(evt):
  global modelselection
  global modelselect
  global modelpreview
  global nodownloadpreview

  newvalue = evt['new']
  if isinstance(newvalue, int):
    newvalue = modelselect.options[evt['new']]
  modelselection = newvalue

  if modelpreview != None and not nodownloadpreview:
    pngpath = pathlib.Path(str(newvalue.replace('network-snapshot-', 'fakes').replace('.pkl', '.png')))
    if str(pngpath).startswith("driveapi:"):
      localpngpath = pathlib.Path("/content") / pngpath.name
      if not localpngpath.exists():
        print(f"Downloading preview image {pngpath} for model {newvalue} to {localpngpath}...")
        output, resultcode = rclone("copyto", "-P", str(pngpath), str(localpngpath), check=False, output="raw")
        print(f"download {str(pngpath)} to {str(localpngpath)} finished with result code {resultcode}: output={repr(output)}")
      pngpath = localpngpath
    if pngpath.exists():
      print(f"found preview fakes png {repr(str(pngpath))} for model snapshot {repr(evt['new'])}, loading image in browser...")
      modelpreview.set_value_from_file(str(pngpath))
      print(f"loaded preview fakes png {repr(str(pngpath))}!")
    else:
      print(f"missing preview fakes png {repr(str(pngpath))} for model snapshot {repr(evt['new'])}")
      modelpreview.value = bytes()
  elif nodownloadpreview:
    print("Not downloading preview (choose a model from the dropdown to show the associated fakesN.png preview...)")
    nodownloadpreview = False

def updatemodelselect():
  global modelselect
  global modelpreview
  global modellayout
  global nodownloadpreview
  global modelselection

  nodownloadpreview = True
  if modelpreview == None:
    modelpreview = widgets.Image()
  print("Getting list of models from google drive...")
  models = [str(pathlib.Path("driveapi:") / f['Path']) for f in sorted(rclonels("driveapi:", max_depth=3), key=lambda f: f['ModTime']) if fnmatch(f['Name'], "network-snapshot-*.pkl")]
  models = ['none', 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl'] + models
  if modelselect == None:
    modelselect = widgets.Dropdown(options=models)
    modelselection = models[-1]
  else:
    modelselection = modelselect.value
    modelselect.options = models
  modelindex = list([-1, *[n for n, m in enumerate(models) if modelselection == m]])[-1]
  modelselect.unobserve_all()
  modelselect.observe(modelselectchanged, ['value', 'index'])
  modelselect.index = modelindex
  if modellayout == None:
    modellayout = widgets.VBox(children=(modelselect, modelpreview), layout=widgets.Layout(height='auto', width='auto'))

nodownloadpreview = True
updatemodelselect()
display.display(modellayout)
#nodownloadpreview = False




In [None]:
#@title <big><big><big>**Download** the selected **Model Snapshot**</big></big></big> { vertical-output: true, display-mode: "form" }

#@markdown Run this cell to download the model snapshot selected in the previous cell and set the downloaded file to be used for training and inference in the following sections.

#latestpkl = syncnewestmatchingfile(["driveapi:/"], pattern="network-snapshot-*.pkl", download_to="/content")
modelpath = modelselection
if modelpath.startswith("driveapi:"):
    newmodelpath = str(pathlib.Path("/content") / pathlib.Path(modelpath).name)
    print(f"Downloading model {repr(modelpath)} to local file {repr(newmodelpath)}...")
    out, code = rclone("copyto", modelpath, newmodelpath, check=False, output="raw")
    print(f"rclone_exit_code={repr(code)}, rclone_output={repr(out)}")
    modelpath = newmodelpath
print(f"\n\nUsing network snapshot: {repr(modelpath)}...")

In [None]:
#@title <big><big><big>Prepare **Dataset**</big></big></big> { vertical-output: true, display-mode: "form" }

#@markdown Prepare the training image set by downloading it or, optionally, generating new images from the GitHub repo of large source images. This cell first checks Google Drive for dataset.zip and uses that if found. If not, it can optionally generate a new dataset and upload it for future runs.

#@markdown Options:
generate_missing_dataset = False #@param {type:"boolean"}
#@markdown > Check this box to enable generating a new dataset from the images and script in my [pixelscapes-dataset repo](https://github.com/un1tz3r0/pixelscapes-dataset.git)
force_regenerate_dataset = False #@param {type:"boolean"}
#@markdown > Check this box to skip checking Google Drive for dataset.zip and rebuild a new dataset from the pixelscapes-dataset repo's source images and randomcrops.py script. When done, the new dataset will be uploaded via rclone to Google Drive as dataset.zip; an existing dataset.zip, if present, will be backed up first.
generate_dataset_count = 200000 #@param {type:"integer"}
#@markdown > Size of the dataset to generate, in number of training images. When generating a new dataset from source images, this many randomly cropped squares are output. The crops are weighted by relative source image size so all pixels contribute equally to the training.
upscale_factor =  2.0#@param {type:"number", min:1.0, max:4.0}
#@markdown > Upscale the original images by this factor using a pretrained Real-ESRGAN super-resolution model before randomly cropping.
weighting_amount =  0.25 #@param {type:"number", min:0.0, max:1.0}
#@markdown > Amount of weighting based on source image size to use when sampling source images. 1.0=probability is proportional to ${width} \times {height}$, 0.0 = even probability
unzip_dataset = True #@param {type: "boolean"}
#@markdown > Extract the dataset.zip to /content/dataset (needs patched StyleGAN3 train.py, which is used by this notebook already.)

# -----------------------------------------------------------------------------------------------
if 0: #notebook_mode != "Training":
  # never get here, we use the training images for projection targets in the inference section below, so prepare the dataset either way!
  print("No action taken, training is not enabled for this notebook_mode.")
else:
  # ---------------------------------------------------------------------------------------------

  datasetpath = None #"/content/drive/dataset.zip"

  if force_regenerate_dataset:
    print("Forcing regeneration of dataset.zip, will back up existing dataset.zip on drive first...")
    if len(rclonels("driveapi:/dataset.zip")) > 0:
      generate_missing_dataset = True
      import fnmatch, re
      datasets = [f['Name'] for f in rclonels("driveapi:/") if fnmatch.fnmatch(f['Name'], 'dataset-*.zip')]
      datasetnumbers = [int(m.group(1)) for m in [re.match(r"^dataset-([0-9]+)\.zip$", f) for f in datasets] if m != None]
      if len(datasetnumbers) < 1:
        nextdatasetbackupnumber = 1
      else:
        nextdatasetbackupnumber = max(datasetnumbers) + 1
      print(f"Renaming existing dataset.zip on drive to dataset-{nextdatasetbackupnumber}.zip before generating new dataset...")
      rclone("moveto", "driveapi:/dataset.zip", f"driveapi:/dataset-{nextdatasetbackupnumber}.zip", check=True)
    if pathlib.Path("/content/dataset.zip").exists():
      print("Removing existing dataset.zip since we are about to generate a new one!")
      pathlib.Path("/content/dataset.zip").unlink()
      print("... removed /content/dataset.zip!")
  elif not os.path.exists("/content/dataset.zip"):
    # okay, first we check if there is a dataset.zip on google drive
    print("Checking drive for existing dataset.zip to download...")
    if len(rclonels("driveapi:/dataset.zip")) > 0:
      # if it exists, download dataset from drive to local storage for speed
      print("Copying dataset.zip from drive to /content...")
      resultcode = rclone("copyto", "--progress", "--stats=2s", "driveapi:/dataset.zip", "/content/dataset.zip", output="pass", check=False)
      datasetpath = "/content/dataset.zip"
      datasetbytes = pathlib.Path(datasetpath).stat().st_size
      print(f"... downloaded {(datasetbytes//1024//1024//10.24)/100} GB {datasetpath} from drive.")
    else:
      # looks like the dataset.zip is missing from google drive, oops
      print("No dataset.zip found on Google Drive! Nothing synced.")

  # check if we need to generate a dataset from sources
  if ((not os.path.exists("/content/dataset.zip")) and \
      generate_missing_dataset) or force_regenerate_dataset:
    print("Generating missing dataset.zip from github repo now!")
    # create a new dataset.zip from the source images and randomcrop script in our github repo
    %cd '/content/'

    def install_realesrgan():
      %cd '/content/'
      # Clone Real-ESRGAN and enter the Real-ESRGAN
      !git clone https://github.com/xinntao/Real-ESRGAN.git
      %cd Real-ESRGAN
      !pip3 install -r requirements.txt
      !python3 setup.py develop --user
      # Download the pre-trained model
      !wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models

    def upscale_all(inputdir, outputdir, factor = 2.0):
      import pathlib
      outpath = pathlib.Path(outputdir).absolute()
      if not (outpath.exists() and outpath.is_dir()):
          outpath.mkdir()
      infiles = [f.absolute() for f in pathlib.Path(inputdir).glob("*.png")]
      for f in infiles:
        fin = str(f)
        fo = outpath / f.name
        if fo.exists():
            print(f"Skipping existing output file: {fo}")
            continue
            #fo = outpath / f"{fo.stem}-out.{fo.suffix}"
        fout = str(fo.parent.absolute())
        print(f"Upscaling x{factor}: {fin} -> {fout}")
        %cd /content/Real-ESRGAN
        !python3 inference_realesrgan.py  -i $fin --outscale $factor -o $fout -n RealESRGAN_x4plus

    import os, pathlib
    if not pathlib.Path("pixelscapes-dataset").exists():
      !git clone https://github.com/un1tz3r0/pixelscapes-dataset.git
    else:
      %cd /content/pixelscapes-dataset
      !git diff --no-ext-diff --quiet --exit-code || rm -Rf cropped ../dataset.zip
      !git pull
    
    %cd /content
    if not os.path.exists("/content/pixelscapes-dataset/scaled/"):
    #if True:
      print(">>> Installing upscaler network to zoom 2x source images")
      install_realesrgan()
      print(">>> Upscaling raw dataset images...")
      upscale_all("/content/pixelscapes-dataset/pixelscapes/", \
                  "/content/pixelscapes-dataset/scaled", upscale_factor)
      print(">>> Done, now cropping from upscaled images...")
    
    %cd /content
    if not os.path.exists("/content/pixelscapes-dataset/cropped"):
      !python3 pixelscapes-dataset/randomcrops.py \
        pixelscapes-dataset/scaled \
        pixelscapes-dataset/cropped \
        --jsonout pixelscapes-dataset/cropped/dataset.json \
        --count $generate_dataset_count \
        --size 256 --weighting $weighting_amount

    !python3 /content/stylegan3/dataset_tool.py \
      --source=pixelscapes-dataset/cropped \
      --dest=dataset.zip \
      --resolution='256x256'

    datasetpath = "/content/dataset.zip"

    # upload the newly created dataset
    print("Syncing dataset.zip to drive...")
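    # rclone has no "syncto" command; "copyto" uploads a single file to an explicit destination path
    resultcode = rclone("copyto", "--progress", "/content/dataset.zip", "driveapi:/dataset.zip", output="pass", check=False)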
    if resultcode != 0:
      print(f"... not synced, result code is {resultcode}")
    else:
      print("ok")

  !rm -rf /content/dataset
  # unzip the dataset.zip if needed and we have one
  if unzip_dataset and pathlib.Path("/content/dataset.zip").exists() and not (pathlib.Path("/content/dataset").exists() and pathlib.Path("/content/dataset").is_dir()):
    print("Unzipping dataset.zip to /content/dataset/...")
    import ipywidgets as widgets
    from IPython import display
    outw = widgets.Output()
    display.display(outw)
    import subprocess
    # stream unzip's output line by line, showing every 100th line as a progress indicator
    p = subprocess.Popen(["unzip", "-d", "/content/dataset", "/content/dataset.zip"], stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    lineno = 0
    for line in p.stdout:
      lineno = lineno + 1
      if lineno >= 100:
        lineno = 0
        outw.clear_output(wait=True)
        with outw:
          print(line.strip().decode("utf8", errors="replace"), flush=True)
    p.wait()
    print("done!")
    datasetpath = "/content/dataset"
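
When a regeneration is forced, the cell above renames the existing `dataset.zip` on Drive to `dataset-N.zip`, where `N` is one more than the highest backup number already present. A small standalone sketch of that numbering rule follows. The helper name `next_backup_number` and the example file names are illustrative only; the regular expression is the one used in the cell.

```python
import re

def next_backup_number(names):
    """Return the next free N for dataset-N.zip given the file names already on Drive."""
    numbers = []
    for name in names:
        m = re.match(r"^dataset-([0-9]+)\.zip$", name)
        if m:
            numbers.append(int(m.group(1)))
    # Start at 1 when no numbered backups exist yet
    return max(numbers) + 1 if numbers else 1

print(next_backup_number(["dataset.zip", "dataset-1.zip", "dataset-3.zip"]))  # 4
print(next_backup_number(["dataset.zip"]))                                    # 1
```

Gaps in the sequence are not reused, so a newer backup always gets a higher number than an older one.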


# <big><big><big><big><big>**Training**</big></big></big></big></big>
- From scratch
- Resume training a model snapshot
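
When resuming, the training cell below reads the kimg count already trained from the snapshot file name, rounds it up to a multiple of `tick`, and adds the requested `kimg` on top. The sketch below reproduces that arithmetic in isolation. The function name `plan_resume` is hypothetical; the fallback of 500 kimg for names without a number matches the value the cell uses.

```python
import re

def plan_resume(snapshot_name, extra_kimg, tick, fallback_kimg=500):
    """Return (resume_kimg, total_kimg), both rounded up to a multiple of tick."""
    m = re.match(r".*-(\d+)\..*?$", snapshot_name)
    resume_kimg = int(m.group(1)) if m else fallback_kimg
    # Round the starting point up to the next multiple of tick
    resume_kimg = ((resume_kimg + tick - 1) // tick) * tick
    # Total kimg to run until: starting point plus the additional kimg, rounded up as well
    total_kimg = ((resume_kimg + extra_kimg + tick - 1) // tick) * tick
    return resume_kimg, total_kimg

print(plan_resume("network-snapshot-000123.pkl", 2000, 5))    # (125, 2125)
print(plan_resume("stylegan3-t-ffhqu-256x256.pkl", 2000, 5))  # (500, 2500)
```

The pretrained pickle has no kimg number in its name, so it falls back to 500 kimg as the assumed starting point.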

In [None]:
#@title # <big><big><big>Run train.py</big></big></big> { vertical-output: true, display-mode: "form" }
#@markdown ### `train.py` Options (see `train.py --help`):


upload_to_subdir = "fromscratchconditionalone" #@param {type: "string"}
#@markdown > when copying output files to google drive, place them in this subdirectory, so as not to overwrite other training runs. allows substitutions of the form *\{varname\}*, where *varname* is one of:
#@markdown > - `{gamma}` the R1 regularisation parameter, $\gamma$

cfg = "stylegan3-t" #@param ["stylegan2", "stylegan3-t", "stylegan3-r"]

kimg =    2000#@param {type:"integer"}
#@markdown > Train for an additional ${kimg} \times 10^3$ training images. When resuming, this is added to the kimg count of the selected snapshot.

tick =  5#@param {type:"integer"}
#@markdown > Log status after every ${tick}$ kimgs during training

snap =    10 #@param {type:"integer"}
#@markdown > Save a model snapshot every ${snap}$ ticks during training

img_snap = 1 #@param {type:"integer"}
#@markdown > Save a fakes png every ${img\_snap}$ ticks during training

gamma_values = "6" #@param {type: "string"}
#@markdown > critical hyperparameter ${\gamma}$ is the R1 regularisation weight applied to the discriminator. use -1 for the default heuristic based on $gpus$.
#@markdown > 
#@markdown > **for hyperparameter searching** give multiple, space separated ${\gamma}$ values, and use `{gamma}` in your `upload_to_subdir` folder name, and training will be run with each ${\gamma}$ value and the results uploaded to separate folders for analysis.

ema_factor_values = "1" #@param {type: "string"}
#@markdown > critical hyperparameter ${ema}_{factor}$ scales the default heuristic value for the generator weights exponential moving average, which smooths large gradients during the training process and stabilizes things significantly. <1 

freezed =   -1#@param {type: "integer", min: 0}
#@markdown > freeze first ${freezed}$ layers of generator network (mostly useful for transfer learning)
batch =  32#@param {type: "integer"}
#@markdown > ${batch}$ is the training batch size 

batch_gpu =  16 #@param {type: "integer"}
#@markdown > ${batch}_{GPU}$ is the per-gpu training batch size. if this is less than ${batch} \div {gpus}$, gradients are accumulated over several passes to make up the full batch.

mbstd_group =  -1#@param {type: "integer"}
#@markdown > ${mbstd}_{group}$ is the group size of the discriminator's minibatch standard deviation layer. reduce it along with ${batch}_{GPU}$ if you run out of GPU memory.

half_cbase = True #@param {type: "boolean"}
#@markdown > Set ${cbase}$ to half the usual channel count (affects the size of the neural network layers), to reduce GPU memory use.

mirror = False #@param {type: "boolean"}
#@markdown > Augment the dataset by randomly flipping the images about the y axis centerline.

image_snap_res = "4k" #@param ["1080p", "4k", "8k"]
#@markdown > the size of the fakes.png image grids written every ${image\_ticks} \times {tick\_kimg}$ during training

aug = "ada" #@param ["noaug", "ada", "fixed"]

augpipe = "blit" #@param ["none", "bgc", "bc", "blit"] {allow-input: true}

metrics = "none" #@param ["none", "fid50k_full"]


seed =      1983#@param {type: "integer"}

#@markdown > seed for random values, use a consistent seed for deterministic, reproducible training runs
#@markdown ---

# -----------------------------------------------------------------------------------------------
if notebook_mode != "Training":
  print("No action taken, training is not enabled for this notebook_mode.")
else:
  # ---------------------------------------------------------------------------------------------

  # if gamma <= 0:
  #   gamma = None
  # if mbstd_group <= 0:
  #   mbstd_group = None
  # if batch_gpu <= 0:
  #   batch_gpu = None
  # if batch <= 0:
  #   batch = None

  import sys, os, pathlib, glob, re

  # resume training with the network-snapshot-######.pkl we downloaded above

  if modelpath != None and modelpath != 'none':
    resume = modelpath
    # determine the kimg that the model we are fine-tuning has already been trained for
    # by extracting it from the filename. TODO: get this and other parameters from the
    # training_options.json
    try:
      resume_kimg = int(re.match(r".*-(\d+)\..*?$", pathlib.Path(resume).name).group(1))
    except (AttributeError, ValueError):
      resume_kimg = 500
  else:
    resume = -1 #"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl"
    resume_kimg = -1 #500

  # hpar fast orig
  if (not isinstance(resume, int)) or resume != -1:
    resume_kimg = int(((resume_kimg + tick-1) // tick) * tick)
    kimg = int((int(resume_kimg + kimg + tick-1)//tick)*tick) # 500

    print(f'resuming from network snapshot pkl: {resume}')
    print(f'resuming from kimg count: {resume_kimg}')
  else:
    print(f'training from scratch!')
  print(f"running until kimg: {kimg}")

  %cd /content

  if False:
    %cd /content/stylegan3

    if 'train' in sys.modules.keys():
      del sys.modules['train']
    if 'train' in locals().keys():
      del train
    import train

  def runtraining(outdirname="training-runs", gamma=-1, ema_factor=-1, **overrides):
    import shlex


    def kwargstodict(*dicts, **kwargs):
      import shlex
      outdict = {}
      for d in dicts:
        for k,v in d.items():
          outdict[k] = v
      for k,v in kwargs.items():
        if callable(v):
          if k in outdict.keys():
            outdict[k] = v(outdict[k])
          else:
            outdict[k] = v()
        else:
          outdict[k] = v
      return outdict
    
    def quoteargs(args):
      outargs = []
      for k,v in args.items():
        if v == None or v == '' or (isinstance(v, (int, float)) and v < 0):
          print(f"quoteargs(): ignoring deleted or empty or negative-integer long-option {repr(k)} with value {repr(v)}!")
          continue
        outargs.append(f"--{k.replace('_','-')}={shlex.quote(str(v))}")
      return " ".join(outargs)

    args=kwargstodict(
        outdir='/content/training-runs', 
        data="/content/dataset", 
        resume=resume,
        resume_kimg=resume_kimg, 
        cfg=cfg, 
        cond=True,
        gpus=1, 
        workers=1,
        kimg=kimg, 
        gamma=gamma,
        ema_factor=ema_factor,
        batch=batch, 
        batch_gpu=batch_gpu, 
        mbstd_group=mbstd_group,
        aug=aug,
        augpipe=augpipe, 
        mirror=mirror, 
        tick=tick, 
        snap=snap, 
        img_snap=img_snap, 
        snap_res=image_snap_res, 
        freezed=freezed if freezed > 0 else None,
        cbase=16384 if half_cbase else None,
        seed=seed,
        metrics=metrics
    )

    # apply any overrides over defaults
    args = kwargstodict(args, **overrides)  
    
    # add args for uploading output files to google drive using outdirname
    args=kwargstodict(
        args,
        img_cmd='rclone --config=/content/rclone.config copy "$1" driveapi:/'+shlex.quote(outdirname)+'/training-run-"$DESC"/; echo "$1"',
        snap_cmd='rclone --config=/content/rclone.config copy "$1" driveapi:/'+shlex.quote(outdirname)+'/training-run-"$DESC"/; echo "$1"'
    )
    
    # produce a shell-quoted string with long-option-style args from the dict
    qargs = quoteargs(args)

    print(f"/content/stylegan3/train.py {repr(qargs)}")  
    
    # Fine-tune StyleGAN3-T for pixelscapes-256-50k using 1 GPU, starting from the pre-trained FFHQ-U pickle.
    from IPython.display import display
    from ipywidgets.widgets import Box, Output, Image
    oimage = Image(value=bytes(),layout={'border': '1px solid black'})
    #olog = Output(layout={'border': '1px solid black'})
    #obox = Box(children=(oimage, olog))
    #display(olog)
    display(oimage)

    import subprocess, select, re
    p = subprocess.Popen(f"python3 /content/stylegan3/train.py {qargs}", shell=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    done = False
    
    #with olog:
    if True:
      print("running train.py script...")

    outdir = None
    lastpng = None
    pngname = None

    try:
      while not done:
        rfds, wfds, xfds = select.select([p.stdout.fileno()], [], [])
        if p.poll() is not None:
          # process has exited
          done = True
        else:
        #with olog:
          #if out == '' and err == '':
          #  resultcode = p.wait()
          #  done = True
          #  continue
          if p.stdout.fileno() in rfds:
            line = p.stdout.readline().strip()
            m = re.match(r"^(/content/.*\.png);.*$", line)
            if m is not None:
              pngname = m.group(1)
              oimg = Image(value=bytes())
              oimg.set_value_from_file(str(pngname))
              display(oimg)
            #if lastpng != pngname:
            #  lastpng = pngname
            #  pngpath = pathlib.Path(lastpng)
            #  oimage.set_value_from_file(str(pngpath))
            #  display(oimage)
            else:
              print(line)

          #if p.stderr.fileno() in rfds:
          #  print(f"{repr(p.stderr.read(1))}", end='')
        if len(rfds) == 0 or p.poll() is not None:
          done=True
    except KeyboardInterrupt:
      print("*** keyboardinterrupt in runtraining()")
      pass
    except Exception as err:
      import traceback as tb
      print(f"exception thrown in runtraining(): {err}")
      tb.print_exc()
    finally:
      try:
        if p.poll() is None:
          p.kill()
        else:
          p.wait(timeout=2.0)
      except Exception:
        pass

  try:
    import re
    gammas = [float(word) for word in re.split("[, \t;]+", gamma_values)]
  except Exception as err:
    print(f"Error parsing list of gamma values to run training on:  {err}")
    exit(1)

  try:
    import re
    ema_factors = [float(word) for word in re.split("[, \t;]+", ema_factor_values)]
  except Exception as err:
    print(f"Error parsing list of ema_factor values to run training on:  {err}")
    exit(1)

  try:
    for ema_factor in ema_factors:
      for gamma in gammas:
        print(f"*** training run with gamma={gamma} and ema_factor={ema_factor}***")
        runtraining(
          outdirname = upload_to_subdir,
          gamma = float(gamma),
          ema_factor = float(ema_factor)
          #tick=1, kimg=2, gamma=lambda x: x*2
        )
        print("*** training run ended ***")
  except Exception as err:
    import traceback as tb
    print(f"exception thrown in training main hyperparameter search loop: {err}")
    tb.print_exc()
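
The polling loop in `runtraining()` can also be written as a plain iteration over the child's stdout, which ends by itself once the process exits. Below is a minimal sketch of that idea, not the code this notebook runs: `stream_lines` and `split_output` are hypothetical helper names, and the child process is a stand-in that prints two made-up lines rather than real `train.py` output. The regex is the same one the cell uses to spot uploaded snapshot paths.

```python
import re
import subprocess
import sys

# Same pattern as in the training cell: a /content/...png path followed by ';'.
PNG_LINE = re.compile(r"^(/content/.*\.png);.*$")

def stream_lines(cmd):
    """Yield each stripped output line of cmd until the child closes stdout."""
    with subprocess.Popen(
        cmd,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True,
        bufsize=1,                 # line-buffered on the reading side
    ) as proc:
        try:
            for line in proc.stdout:  # iteration stops at EOF
                yield line.strip()
        except BaseException:
            # interrupted or abandoned early: do not leave the child running
            proc.kill()
            raise

def split_output(lines):
    """Separate snapshot PNG paths from ordinary log lines."""
    pngs, logs = [], []
    for line in lines:
        m = PNG_LINE.match(line)
        if m is not None:
            pngs.append(m.group(1))
        else:
            logs.append(line)
    return pngs, logs

# Stand-in child process with hypothetical output (not real train.py logs).
demo_cmd = [
    sys.executable, "-c",
    "print('/content/out/fakes000000.png; uploaded'); print('tick 0')",
]
pngs, logs = split_output(stream_lines(demo_cmd))
```

In the notebook, each path in `pngs` would be handed to `Image.set_value_from_file()` and each log line printed as it arrives; the sketch collects them into lists only to keep the example small.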

# <big><big><big><big><big>Inference</big></big></big></big></big>


These cells are a work-in-progress interactive latent space explorer. It allows interpolating between two latent space points, each given by a seed and a class, projecting real samples from the dataset or uploaded images into the latent space, and interpolating between the projected coordinates.
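
The two kinds of interpolation used here can be sketched in plain NumPy. This is an illustration only: the notebook itself calls `gen_utils.slerp` from the repo, whose implementation is not shown here, and the `slerp` and `lerp_class` functions and the `z_dim`/`c_dim` sizes below are assumptions made for the example. Latent vectors are blended along an arc (spherical interpolation), while one-hot class vectors are blended linearly.

```python
import numpy as np

def slerp(t, v0, v1, eps=1e-8):
    """Spherical interpolation between v0 and v1 at fraction t in [0, 1]."""
    v0 = np.asarray(v0, dtype=np.float64)
    v1 = np.asarray(v1, dtype=np.float64)
    # angle between the two directions
    u0 = v0 / np.linalg.norm(v0)
    u1 = v1 / np.linalg.norm(v1)
    omega = np.arccos(np.clip(np.dot(u0, u1), -1.0, 1.0))
    s = np.sin(omega)
    if s < eps:
        # (anti)parallel vectors: the arc is undefined, fall back to a straight line
        return (1.0 - t) * v0 + t * v1
    return (np.sin((1.0 - t) * omega) / s) * v0 + (np.sin(t * omega) / s) * v1

def lerp_class(t, c0, c1):
    """Linear blend of two class-conditioning vectors."""
    return (1.0 - t) * c0 + t * c1

z_dim, c_dim = 512, 4                      # hypothetical sizes
z0 = np.random.RandomState(0).randn(z_dim)  # latent for "seed A"
z1 = np.random.RandomState(1).randn(z_dim)  # latent for "seed B"
c0 = np.eye(c_dim)[0]                       # one-hot for class 0
c1 = np.eye(c_dim)[2]                       # one-hot for class 2

z_mid = slerp(0.5, z0, z1)
c_mid = lerp_class(0.5, c0, c1)
```

At `t = 0` both functions return the first point and at `t = 1` the second, which is what the "Fade A->B" sliders rely on.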

In [None]:
#@title Real-time Inference { vertical-output: true, display-mode: "form" }

%cd /content/stylegan3/

import os
from typing import List, Optional, Union, Tuple
import click

import dnnlib
from torch_utils import gen_utils

import scipy
import numpy as np
import PIL.Image
import torch

import legacy
import projector
from training.dataset import ImageFolderDataset

dataset = ImageFolderDataset(datasetpath)

def init_model(network_pkl):
	device = torch.device('cuda')
	with dnnlib.util.open_url(network_pkl) as f:
		G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
	gen_utils.anchor_latent_space(G)
	return device, G

def pil_to_jpeg(im):
  from io import BytesIO
  with BytesIO() as o:
    im.save(o, "JPEG")
    return o.getvalue()

def latent_from_seed(device, G, seed, class_idx):
	if G.c_dim != 0:
		c = np.zeros([G.c_dim])
		c[class_idx] = 1
	else:
		c = None
	z = np.random.RandomState(seed).randn(G.z_dim)
	return z, c	

def latent_from_image(device, G, img, class_idx):
	return projector.project(device, G, img)[-1]

def interpolate_z(device, G, z0, z1, interp):
	return gen_utils.slerp(interp, z0, z1)

def interpolate_c(device, G, c0, c1, interp):
	if c0 is not None and c1 is not None:
		return c1 * interp + c0 * (1.0-interp)
	return None

def render_image(device, G, 
		seeds,
		class_idxes,
		truncation_psi = 1.0,
		noise_mode = 'const',
		grid_rows = 1,
		grid_cols = 1,
		interpz = 0.0,
		interpl = 0.0
		):
	try:
		
		if len(seeds) == 1:
			seed0 = seeds[0]
			seed1 = seeds[0]
		else:
			seed0 = seeds[0]
			seed1 = seeds[1]

		if class_idxes is not None:
			if G.c_dim == 0:
				raise RuntimeError("Error, cannot specify class for unconditional network")
			if len(class_idxes) == 1:
				class_idx0 = class_idxes[0]
				class_idx1 = class_idxes[0]
			else:
				class_idx0 = class_idxes[0]
				class_idx1 = class_idxes[1]
		else:
			if G.c_dim != 0:
				raise ValueError("Error, must specify class for conditional network")
			class_idx0, class_idx1 = None, None

		# generate latent z interpolation
		z0, c0 = latent_from_seed(device, G, seed0, class_idx0)
		z1, c1 = latent_from_seed(device, G, seed1, class_idx1)

		if G.c_dim != 0:
			cs = np.zeros([grid_rows*grid_cols, G.c_dim])
			for col in range(0,grid_cols):
				a = (col/max(grid_cols-1, 1)) + interpl
				ci = interpolate_c(device, G, c0, c1, a)
				for row in range(0,grid_rows):
					cs[row*grid_cols+col, :] = ci
		else:
			cs = None
		
		c = None
		if cs is not None:
			c = torch.from_numpy(cs).to(device)

		zs = np.zeros([grid_rows * grid_cols, G.z_dim])
		for row in range(0, grid_rows):
			a = row/max(1,(grid_rows-1)) + interpz
			zi = interpolate_z(device, G, z0, z1, a)
			for col in range(0, grid_cols):
				zs[row * grid_cols + col, :] = zi
		z = torch.from_numpy(zs).to(device)
		
		imgs = gen_utils.z_to_img(G, z, c, truncation_psi, noise_mode)
		img = gen_utils.create_image_grid(imgs, (grid_rows, grid_cols))
		im = PIL.Image.fromarray(img, 'RGB')
		# from io import BytesIO
		# with BytesIO() as stream:
		# 	im.save(stream, "JPEG")
		# 	return stream.getvalue()
		return im
	except Exception as err:
		raise err


import IPython
import ipywidgets as widgets

def label_on_left(label="", widget=None, parent=None):
	child = widgets.HBox(children=[widgets.Label(value=label), widget])
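	if parent is not None:
		# widget children are stored as a tuple, so assign a new tuple instead of appending
		parent.children = (*parent.children, child)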
	else:
		return child
	return widget

# seed/class inputs get their visible labels from label_on_left() below
seed_a_input = widgets.IntText(value=0)
seed_b_input = widgets.IntText(value=0)
class_a_input = widgets.IntSlider(value=0, min=0, max=70)
class_b_input = widgets.IntSlider(value=0, min=0, max=70)

# the fade sliders are shown without a wrapper, so label them via `description`
interp_seed_input = widgets.FloatSlider(value=0, min=0, max=1, description="Seed")
interp_class_input = widgets.FloatSlider(value=0, min=0, max=1, description="Class")
inputs_a = widgets.VBox(children=[
	  label_on_left("Seed", seed_a_input), 
		label_on_left("Class", class_a_input)
	], layout=widgets.Layout(border="1px solid white"))

inputs_b = widgets.VBox(children=[
	  label_on_left("Seed", seed_b_input), 
		label_on_left("Class", class_b_input)
	], layout=widgets.Layout(border="1px solid white"))

inputs_i = widgets.VBox(children=[interp_seed_input, interp_class_input], layout=widgets.Layout(border="1px solid white"))

inputs = widgets.VBox(children=[
	  label_on_left("Input A", inputs_a), 
		label_on_left("Fade A->B", inputs_i), 
		label_on_left("Input B", inputs_b)
])
outputim = widgets.Image(value=bytes(), format="jpeg")

projectedim0 = widgets.Image(value=bytes(), format="jpeg")
realindex0 = widgets.IntText(value=0)
realclass0 = widgets.Label(value="0")
uploadim0 = widgets.FileUpload(multiple=False, accept="image/*")
projected0box = widgets.VBox(children=[projectedim0, uploadim0, realindex0, realclass0], layout=widgets.Layout(border="1px solid white"))
def onrealindex0changed(evt):
	global dataset
	global realindex0
	global projectedim0
	singim, cls = dataset[realindex0.value]
	# assumes the dataset yields CHW uint8 arrays (as upstream StyleGAN3 does); PIL wants HWC
	img = np.ascontiguousarray(singim.transpose(1, 2, 0))
	pilim = PIL.Image.fromarray(img, 'RGB')
	projectedim0.value = pil_to_jpeg(pilim)
	realclass0.value = str(cls)
realindex0.observe(onrealindex0changed, ['value'])

projectedim1 = widgets.Image(value=bytes(), format="jpeg")
realindex1 = widgets.IntText(value=0)
realclass1 = widgets.Label(value="0")
uploadim1 = widgets.FileUpload(multiple=False, accept="image/*")
projected1box = widgets.VBox(children=[projectedim1, uploadim1, realindex1, realclass1], layout=widgets.Layout(border="1px solid white"))
def onrealindex1changed(evt):
	global dataset
	global realindex1
	global projectedim1
	singim, cls = dataset[realindex1.value]
	# assumes the dataset yields CHW uint8 arrays (as upstream StyleGAN3 does); PIL wants HWC
	img = np.ascontiguousarray(singim.transpose(1, 2, 0))
	pilim = PIL.Image.fromarray(img, 'RGB')
	projectedim1.value = pil_to_jpeg(pilim)
	realclass1.value = str(cls)
realindex1.observe(onrealindex1changed, ['value'])

mainproj = widgets.HBox(children=[projected0box, projected1box])
mainio = widgets.HBox(children=[inputs, outputim])
main = widgets.VBox(children=[mainio, mainproj])

print("loading model...")
device, G = init_model(modelpath)


def refresh(evt,*args, **kwargs):
  global outputim
  result = render_image(device, G, [seed_a_input.value, seed_b_input.value],
                 [class_a_input.value, class_b_input.value],
                 interpz = interp_seed_input.value,
                 interpl = interp_class_input.value)
  #print(result)
  outputim.value=pil_to_jpeg(result)
  #main.children = [inputs, output]
  

def hookupinputs(widget):
  if 'children' in widget.keys:
    for child in widget.children:
      hookupinputs(child)
  else:
    widget.unobserve_all()
    widget.observe(refresh, ['value'])

hookupinputs(inputs)
IPython.display.display(main)
refresh(None)
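
`render_image()` stores its grid row-major, so the cell at `(row, col)` lives at index `row * grid_cols + col`, with the seed fraction stepping down the rows and the class fraction stepping across the columns. The sketch below shows only that bookkeeping, with no model involved; `grid_fractions` is a hypothetical helper written for this example and is not part of the notebook.

```python
import numpy as np

def grid_fractions(grid_rows, grid_cols, interpz=0.0, interpl=0.0):
    """Return per-cell seed and class interpolation fractions, row-major."""
    z_frac = np.zeros(grid_rows * grid_cols)
    c_frac = np.zeros(grid_rows * grid_cols)
    for row in range(grid_rows):
        for col in range(grid_cols):
            idx = row * grid_cols + col  # row-major flat index
            # the slider offsets are added unclamped, as in the cell above
            z_frac[idx] = row / max(grid_rows - 1, 1) + interpz
            c_frac[idx] = col / max(grid_cols - 1, 1) + interpl
    return z_frac, c_frac

z_frac, c_frac = grid_fractions(2, 3)
```

For a 2x3 grid every cell in the first row uses seed fraction 0 and every cell in the second row uses 1, while the class fraction runs 0, 0.5, 1 across each row. With the default 1x1 grid both fractions reduce to the slider values.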

In [None]:
#!curl -X GET --unix-socket /content/server.sock http://localhost/
#!curl -X GET --unix-socket /content/server.sock http://localhost/load?network_pkl=$modelpath

In [None]:
'''
# -----------------------------------------------------------------------------------------------
if notebook_mode != "Inference":
  print("No action taken, inference is not enabled for this notebook_mode.")
else:
  # ---------------------------------------------------------------------------------------------
  
  #!python3 /content/stylegan3/generate.py random-video -anchor --seeds=1 \
  #  --network=$modelpath --class-interp=30 --class-seed=1
  #   --out=lerpout --trunc=1 --seeds=0-31 --grid=4x2 \
  #   --network=$modelpath --anchor-latent-space --class=1
  %pip install aiohttp
  %cd /content/stylegan3
  import subprocess
  if 'server_proc' in globals().keys():
    server_proc.kill()
    !rm /content/server.sock
  #server_proc = subprocess.Popen(["python3", "-m", "aiohttp.web", "-U", "/content/server.sock", "gen_server:init_app"])
  server_proc = subprocess.Popen(["python3", "gen_server.py"])
  print(server_proc.pid)
  #!python3 -m aiohttp.web -U /content/server.sock gen_server:init_app &

  #!python3 /content/stylegan3/gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --labels=0 \
  #      --network=$modelpath --stabilize-video
'''