<a href="https://colab.research.google.com/github/patty1997/Deep_Learning_Projects/blob/GANs_Pytorch_Implementation/Tomato_late_blight_projected_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Training Projected GAN
This is a self-contained notebook for training Projected GAN.

## Setup

Make sure you're running a GPU runtime; if not, select "GPU" as the hardware accelerator in Runtime > Change Runtime Type in the menu. 
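
A quick way to confirm that the runtime actually exposes a GPU is sketched below. It relies only on the standard library and on the `nvidia-smi` tool being on the path, which is an assumption about the Colab image; the helper name `gpu_visible` is made up for this example.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA driver tools on this machine, so no usable GPU runtime.
        return False
    try:
        # `nvidia-smi -L` prints one line per device, e.g. "GPU 0: ...".
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_visible():
    print("GPU runtime detected")
else:
    print("No GPU found: change the runtime type before training")
```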

Now, get the repo and install missing dependencies.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%%capture
%%bash
# clone repo
git clone https://github.com/autonomousvision/projected_gan
pip install timm dill

In [3]:
%cd projected_gan

/content/projected_gan


## Data Preparation
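We need to download and prepare the data. The original example uses the few-shot datasets of the [FastGAN repo](https://github.com/odegeasslbc/FastGAN-pytorch); those download cells are commented out below. This notebook instead converts a tomato late blight image folder stored on Google Drive into a 256x256 dataset archive.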

In [4]:
# !gdown https://drive.google.com/u/0/uc?id=1aAJCZbXNHyraJ6Mi13dSbe7pTyfPXha0&export=download

In [5]:
# %%capture
# !unzip few-shot-image-datasets.zip
# !mv few-shot-images data

In [6]:
%%bash
python dataset_tool.py --source=/content/drive/MyDrive/Plant_Dataset/Tomato_Late_blight --dest=/content/tomato_late_blight256.zip --resolution=256x256

  1%|          | 19/1932 [00:04<06:56,  4.59it/s]
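
Once `dataset_tool.py` has finished, it can be worth checking that the archive holds the expected number of images (the progress bar above counts 1932 source files). The following is a minimal sketch using only the standard library; the helper `count_images` and the set of extensions are assumptions made for illustration, not part of the Projected GAN repo.

```python
import os
import zipfile

# Extensions treated as images in this sketch (an assumption).
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def count_images(zip_path: str) -> int:
    """Count archive members whose extension looks like an image."""
    with zipfile.ZipFile(zip_path) as zf:
        return sum(
            1
            for name in zf.namelist()
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS
        )


zip_path = "/content/tomato_late_blight256.zip"
if os.path.exists(zip_path):
    print(f"{count_images(zip_path)} images in {zip_path}")
else:
    print(f"{zip_path} not found: run dataset_tool.py first")
```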

## Training

Now that the data is prepared, we can start training! The training loop can track FID, but the computation seems to cause problems in Colab, so it is disabled by default (```metrics=[]```). The loop also generates samples from fixed noise after a defined number of ticks, e.g. ```snap=1``` below.
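
To make the tick and batch settings concrete, the sketch below works through the arithmetic for the values passed to `train(...)` later in this notebook. The variable names are illustrative only. The gradient accumulation reading of `batch` versus `batch_gpu` is an assumption, suggested by the sanity check that `--batch` must be a multiple of `--gpus` times `--batch-gpu`.

```python
# Values copied from the train(...) call in this notebook.
batch, batch_gpu, gpus = 64, 8, 1
kimg, tick, snap = 10000, 4, 1

# Assumed: each optimizer step is split into rounds of batch_gpu images per GPU.
accum_rounds = batch // (batch_gpu * gpus)

# One tick covers `tick` thousand images; snapshots are written every `snap` ticks.
total_ticks = kimg // tick
kimg_per_snapshot = tick * snap

# Same formula as in train(): c.ema_kimg = c.batch_size * 10 / 32
ema_kimg = batch * 10 / 32

print(f"accumulation rounds per step: {accum_rounds}")
print(f"ticks until training ends:    {total_ticks}")
print(f"kimg between snapshots:       {kimg_per_snapshot}")
print(f"ema_kimg:                     {ema_kimg}")
```

With `snap=1` a sample grid and a network snapshot are written every 4 kimg, which is frequent; a larger `snap` reduces the amount written to Drive.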

In [None]:
import os
import json
import re

import torch  # needed by launch_training for torch.device
import dnnlib

from training import training_loop
from torch_utils import training_stats
from train import init_dataset_kwargs
from metrics import metric_main

In [None]:
def launch_training(c, desc, outdir, rank=0):
    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]

    matching_dirs = [re.fullmatch(r'\d{5}' + f'-{desc}', x) for x in prev_run_dirs if re.fullmatch(r'\d{5}' + f'-{desc}', x) is not None]
    if c.restart_every > 0 and len(matching_dirs) > 0:  # expect unique desc, continue in this directory
        assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
        c.run_dir = os.path.join(outdir, matching_dirs[0].group())
    else:                     # fallback to standard
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)


    # Print options.
    print()
    print('Training options:')
    print(json.dumps(c, indent=2))
    print()
    print(f'Output directory:    {c.run_dir}')
    print(f'Number of GPUs:      {c.num_gpus}')
    print(f'Batch size:          {c.batch_size} images')
    print(f'Training duration:   {c.total_kimg} kimg')
    print(f'Dataset path:        {c.training_set_kwargs.path}')
    print(f'Dataset size:        {c.training_set_kwargs.max_size} images')
    print(f'Dataset resolution:  {c.training_set_kwargs.resolution}')
    print(f'Dataset labels:      {c.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips:     {c.training_set_kwargs.xflip}')
    print()

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(c.run_dir, exist_ok=c.restart_every > 0)
    with open(os.path.join(c.run_dir, 'training_options.json'), 'wt+') as f:
        json.dump(c, f, indent=2)

    # Start training
    dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=False)
    sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    training_loop.training_loop(rank=rank, **c)

In [None]:
def train(**kwargs):
    # Initialize config.
    opts = dnnlib.EasyDict(kwargs) # Command line arguments.
    c = dnnlib.EasyDict() # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=64, w_dim=128, mapping_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise ValueError('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    c.G_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = 2
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise ValueError('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise ValueError('--batch must be a multiple of --gpus times --batch-gpu')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise ValueError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32
    if opts.cfg == 'stylegan2':
        c.G_kwargs.class_name = 'pg_modules.networks_stylegan2.Generator'
        c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
        use_separable_discs = True

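    elif opts.cfg == 'fastgan':
        c.G_kwargs = dnnlib.EasyDict(class_name='pg_modules.networks_fastgan.Generator', cond=opts.cond)
        c.G_opt_kwargs.lr = c.D_opt_kwargs.lr = 0.0002  # overrides any glr/dlr passed in
        use_separable_discs = False

    else:
        # Without this branch an unknown cfg fails later with a NameError on use_separable_discs.
        raise ValueError(f"cfg must be 'stylegan2' or 'fastgan', got {opts.cfg!r}")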

    # Restart.
    c.restart_every = opts.restart_every

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Projected and Multi-Scale Discriminators
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.ProjectedGANLoss')
    c.D_kwargs = dnnlib.EasyDict(
        class_name='pg_modules.discriminator.ProjectedDiscriminator',
        diffaug=True,
        interp224=(c.training_set_kwargs.resolution < 224),
        backbone_kwargs=dnnlib.EasyDict(),
    )

    c.D_kwargs.backbone_kwargs.cout = 64
    c.D_kwargs.backbone_kwargs.expand = True
    c.D_kwargs.backbone_kwargs.proj_type = 2
    c.D_kwargs.backbone_kwargs.num_discs = 4
    c.D_kwargs.backbone_kwargs.separable = use_separable_discs
    c.D_kwargs.backbone_kwargs.cond = opts.cond

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir)

In [None]:
# start training!

train(
    outdir='/content/drive/MyDrive/ProjectedGANResults', 
    cfg='fastgan',
    data='/content/tomato_late_blight256.zip', 
    gpus=1, 
    batch=64, 
    cond=False, 
    mirror=1, 
    batch_gpu=8, 
    cbase=32768, 
    cmax=512, 
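    glr=None,   # ignored for cfg='fastgan': train() sets both learning rates to 0.0002
    dlr=0.002,  # likewise overridden, see "lr" under D_opt_kwargs in the printed options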
    desc='', 
    metrics=[],
    kimg=10000, 
    tick=4, 
    snap=1, 
    seed=0, 
    workers=0,
    restart_every=999999,
)


Training options:
{
  "G_kwargs": {
    "class_name": "pg_modules.networks_fastgan.Generator",
    "cond": false
  },
  "G_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.0002
  },
  "D_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.0002
  },
  "data_loader_kwargs": {
    "pin_memory": true,
    "prefetch_factor": 2,
    "num_workers": 0
  },
  "training_set_kwargs": {
    "class_name": "training.dataset.ImageFolderDataset",
    "path": "/content/tomato_late_blight256.zip",
    "use_labels": false,
    "max_size": 1932,
    "xflip": 1,
    "resolution": 256,
    "random_seed": 0
  },
  "num_gpus": 1,
  "batch_size": 64,
  "batch_gpu": 8,
  "metrics": [],
  "total_kimg": 10000,
  "kimg_per_tick": 4,
  "image_snapshot_ticks": 1,
  "network_snapshot_ticks": 1,
  "random_seed": 0,
  "ema_kimg": 20.0,
  "restart_every": 999999,
  "loss_k

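To inspect the samples, click on the folder symbol on the left and navigate to the run directory inside the ```outdir``` passed to ```train```, here

```/content/drive/MyDrive/ProjectedGANResults/YOUR_RUN```

The files ```fakesXXXXXX.png``` are the samples for a fixed noise vector at point XXXXXX in training.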
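
The same lookup can be done from a cell. The sketch below finds the most recent run directory and lists its sample grids. It assumes the `NNNNN-desc` directory naming produced by `launch_training` and the `outdir` used in this notebook; `latest_run_dir` and `list_samples` are hypothetical helpers written for this example.

```python
import os
import re


def latest_run_dir(outdir: str):
    """Return the run directory with the highest 5-digit prefix, or None."""
    if not os.path.isdir(outdir):
        return None
    runs = [
        d
        for d in os.listdir(outdir)
        if re.match(r"\d{5}-", d) and os.path.isdir(os.path.join(outdir, d))
    ]
    # Zero-padded prefixes sort correctly as plain strings.
    return os.path.join(outdir, max(runs)) if runs else None


def list_samples(run_dir: str):
    """Return the fakesXXXXXX.png files in run_dir, oldest first."""
    return sorted(
        f for f in os.listdir(run_dir) if re.fullmatch(r"fakes\d+\.png", f)
    )


outdir = "/content/drive/MyDrive/ProjectedGANResults"
run_dir = latest_run_dir(outdir)
if run_dir is None:
    print(f"No runs found in {outdir}")
else:
    print(run_dir)
    print(list_samples(run_dir)[-5:])  # the five most recent sample grids
```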

## Results

In [11]:
%%bash
python gen_video.py --output=lerp.mp4 --trunc=1.0 --seeds=0-31 --grid=4x2 --network=/content/drive/MyDrive/ProjectedGANResults/00002-fastgan-tomato_late_blight256-gpus1-batch64-/network-snapshot.pkl

Loading networks from "/content/drive/MyDrive/ProjectedGANResults/00002-fastgan-tomato_late_blight256-gpus1-batch64-/network-snapshot.pkl"...


  4%|▍         | 19/480 [00:05<02:07,  3.60it/s]

In [None]:
%%bash
python calc_metrics.py --metrics=fid50k_full --network=/content/drive/MyDrive/ProjectedGANResults/00002-fastgan-tomato_late_blight256-gpus1-batch64-/network-snapshot.pkl

In [None]:
!python calc_metrics.py --metrics=kid50k_full --network=/content/drive/MyDrive/ProjectedGANResults/00002-fastgan-tomato_late_blight256-gpus1-batch64-/network-snapshot.pkl

In [None]:
!python calc_metrics.py --metrics=is50k --network=/content/drive/MyDrive/ProjectedGANResults/00002-fastgan-tomato_late_blight256-gpus1-batch64-/network-snapshot.pkl

In [None]:
!python calc_metrics.py --metrics=pr50k3_full --network=/content/drive/MyDrive/ProjectedGANResults/00000-fastgan-powdery_mildew256-gpus1-batch64-/network-snapshot.pkl

Loading network from "/content/drive/MyDrive/ProjectedGANResults/00000-fastgan-powdery_mildew256-gpus1-batch64-/network-snapshot.pkl"...
Dataset options:
{
  "class_name": "training.dataset.ImageFolderDataset",
  "path": "/content/powdery_mildew256.zip",
  "use_labels": false,
  "max_size": 997,
  "xflip": 1,
  "resolution": 256,
  "random_seed": 0
}
Launching processes...

Generator              Parameters  Buffers  Output shape        Datatype
---                    ---         ---      ---                 ---     
mapping                -           -        [1, 1, 256]         float32 
synthesis.init.init    16785408    16385    [1, 2048, 4, 4]     float32 
synthesis.feat_8.0     -           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.1     37748736    20480    [1, 2048, 8, 8]     float32 
synthesis.feat_8.2     1           -        [1, 2048, 8, 8]     float32 
synthesis.feat_8.3     4096        4097     [1, 2048, 8, 8]     float32 
synthesis.feat_8.4     -           -   