<a href="https://colab.research.google.com/github/saotomryo/Use_MMGeneration/blob/main/Use_MMGeneration_Train_CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from glob import glob
# Report the installed PyTorch version, then pick the GPU when CUDA is usable
# and fall back to the CPU otherwise.
print(torch.__version__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device = ", device)

# 環境準備：MMCV・MMGenerationのインストール

In [None]:
# Install MMCV: first upgrade/install the OpenMIM installer with pip,
# then install the mmcv-full package through the `mim` command.
!pip install -U openmim
!mim install mmcv-full

In [None]:
# Clone the MMGeneration repository and install it from source in editable mode.
!git clone https://github.com/open-mmlab/mmgeneration.git
# Change into the checkout: the install below targets "." and the config cell
# later loads './configs/...' relative to this directory.
%cd /content/mmgeneration
!pip install -v -e .  # or "python setup.py develop"

# コンフィグファイルの編集

CycleGANでは、変換を行いたい2種類の画像を使って学習を行うことになりますが、コンフィグファイルを確認しても、2種類の画像のフォルダパスを直接指定することは出来なさそうでした。
取りあえずは、サンプルのテストデータと同じフォルダ構成にする必要がありそうです。

画像のルートフォルダ  
|  
|  -  trainA  （グループAの学習用の画像フォルダ）  
|  
|  -  trainB　（グループBの学習用の画像フォルダ）  
|  
|  -  testA　 （グループAのテスト用の画像フォルダ）  
|  
|  -  testB　 （グループBのテスト用の画像フォルダ）  


In [None]:
from mmcv import Config

# Load the CycleGAN config that ships in the MMGeneration checkout.
# The path is relative, so the working directory must be the repository root
# (the install cell changes into /content/mmgeneration).
cfg = Config.fromfile(
    './configs/cyclegan/cyclegan_lsgan_id0_resnet_in_facades_b1x1_80k.py'
)

In [None]:
# Data paths.
# DATA_ROOT is the image root folder that contains the trainA / trainB /
# testA / testB sub-folders described in the markdown above this cell.
# The original cell left the right-hand side of the three assignments empty,
# which is a SyntaxError; set the folder once here instead.
# NOTE(review): the value below is inferred from the test folder used in the
# visualisation cell (".../kaggle/monet/testB") - confirm and change it to
# point at your own dataset.
DATA_ROOT = "/content/drive/MyDrive/kaggle/monet"

cfg.data.train.dataroot = DATA_ROOT
cfg.data.test.dataroot = DATA_ROOT
cfg.data.val.dataroot = DATA_ROOT

# Single-GPU run and a fixed seed stored in the config.
cfg.gpu_ids = range(0, 1)
cfg.seed = 123


# Dump the resolved config so the effective settings are visible in the output.
print(f'Config:\n{cfg.pretty_text}')

# 学習の実施

In [None]:
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings

import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

from mmgen import __version__
from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
from mmgen.utils import collect_env, get_root_logger

In [None]:

# Apply the seed stored in the config before anything random is created.
# set_random_seed is imported in this notebook but was never called, so the
# value in cfg.seed was not being applied by the notebook itself.
seed = cfg.get('seed', None)
if seed is not None:
    set_random_seed(seed)

# Build the model and the training dataset from the config.
model = build_model(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

datasets = [build_dataset(cfg.data.train)]

# Timestamp string passed through to train_model.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

# Collect run metadata: environment description, the full config text and
# the seed, all handed to train_model via `meta`.
meta = dict()
env_info_dict = collect_env()
env_info = '\n'.join(f'{k}: {v}' for k, v in env_info_dict.items())
# NOTE(review): dash_line is not used anywhere in the visible notebook; it is
# kept only so that the set of names this cell defines does not change.
dash_line = '-' * 60 + '\n'

meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
meta['seed'] = seed

# Non-distributed training run.
train_model(
    model,
    datasets,
    cfg,
    distributed=False,
    timestamp=timestamp,
    meta=meta)

  'Unnecessary conv bias before batch/instance norm')
  cpuset_checked))
2022-07-25 08:26:50,685 - mmgen - INFO - Start running, host: root@9d67c8fab200, work_dir: /content/mmgeneration/work_dirs/experiments/cyclegan_facades_id0
2022-07-25 08:26:50,687 - mmgen - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) LinearLrUpdaterHook                
(NORMAL      ) CheckpointHook                     
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) LinearLrUpdaterHook                
(LOW         ) IterTimerHook                      
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_iter:
(VERY_HIGH   ) LinearLrUpdaterHook                
(LOW         ) IterTimerHook                      
 -------------------- 
after_train_iter:
(NORMAL      ) CheckpointHook                     
(NORMAL      ) VisualizationHook                  
(LOW         ) IterTimerHook       

# 学習結果の確認
MMGenerationの「sample_img2img_model」ではうまく動作しなかったため「sample_img2img_model2」を作成して動作させています。  
おそらくモデルの中にコンフィグを格納できていないことが原因だと思われます。

In [None]:
# Keep the training config reachable under a second global name:
# sample_img2img_model2 (defined below) reads `cfg2` for the test pipeline
# instead of `model._cfg`, and uses `cfg` as a local name internally.
cfg2 = cfg

from mmgen.datasets.pipelines import Compose
from mmgen.models import BaseTranslationModel

from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import is_list_of

def sample_img2img_model2(model, image_path, target_domain=None, **kwargs):
    """Translate one image file with a translation model.

    The test pipeline is taken from the notebook-level ``cfg2`` config rather
    than from ``model._cfg``.

    Args:
        model (nn.Module): The loaded model; must be a
            ``BaseTranslationModel``.
        image_path (str): File path of the input image.
        target_domain (str, optional): Domain to translate into. When
            ``None``, ``model._default_domain`` is used.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        Tensor: The ``'target'`` entry of the model output (the translated
        image batch).
    """
    assert isinstance(model, BaseTranslationModel)

    # Resolve the translation direction: destination first, then its
    # counterpart as the source domain.
    dst_domain = model._default_domain if target_domain is None else target_domain
    src_domain = model.get_other_domains(dst_domain)[0]

    # The pipeline definition comes from the global config, not model._cfg.
    pipeline_cfg = cfg2
    model_device = next(model.parameters()).device
    preprocess = Compose(pipeline_cfg.test_pipeline)

    # The test pipeline expects a path for both domains (and a pair path), so
    # the same file is registered under every key.
    sample = {
        'pair_path': image_path,
        f'img_{src_domain}_path': image_path,
        f'img_{dst_domain}_path': image_path,
    }
    sample = preprocess(sample)

    # Batch the single sample; on CPU only the meta entry is reset, otherwise
    # the batch is scattered to the model's device.
    batch = collate([sample], samples_per_gpu=1)
    if model_device.type == 'cpu':
        batch['meta'] = []
    else:
        batch = scatter(batch, [model_device])[0]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        results = model(
            batch[f'img_{src_domain}'],
            test_mode=True,
            target_domain=dst_domain,
            **kwargs)
    return results['target']

In [None]:
from mmgen.apis import sample_img2img_model

test_folder = "/content/drive/MyDrive/kaggle/monet/testB"

# Sort the matches so the row order is the same on every run.
test_images = sorted(glob(test_folder + "/*.jpg"))

m = len(test_images)
if m == 0:
    # Fail with a clear message instead of silently drawing an empty figure.
    raise FileNotFoundError(f"No .jpg images found in {test_folder}")

plt.figure(figsize=(24,120))

# One row per test image: the input on the left, the translation on the right.
for i,image_path in enumerate(test_images):
    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread returns None (it does not raise) for unreadable files.
        print(f"Skipping unreadable image: {image_path}")
        continue
    # OpenCV loads BGR; matplotlib expects RGB.
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    plt.subplot(m,2,(i * 2) + 1)
    plt.imshow(original_image)
    plt.title("input")

    # Translate the image into the 'mask' domain.
    # NOTE(review): the original comment says 'mask' is the style of the
    # trainA images; which folder maps to 'mask' depends on domain_a/domain_b
    # in the config - confirm.
    translated_image = sample_img2img_model2(model, image_path, target_domain='mask')
    # First item of the batch, channels-first -> channels-last for imshow.
    translate_image = translated_image.cpu().numpy()[0]
    translate_image = translate_image.transpose(1,2,0)
    # The model output is assumed to be BGR in [-1, 1] (the config's test
    # pipeline normalises with mean/std 0.5 and keeps OpenCV channel order -
    # TODO confirm against cfg.test_pipeline). imshow needs RGB floats in
    # [0, 1], so flip the channels and rescale; clip guards rounding error.
    translate_image = np.clip((translate_image[..., ::-1] + 1.0) / 2.0, 0.0, 1.0)
    plt.subplot(m,2,(i * 2) + 2)
    plt.imshow(translate_image)
    plt.title("translated")

plt.show()