<a href="https://colab.research.google.com/github/cedro3/stylized-neural-painting/blob/main/image_to_paint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Githubからコードをコピー

In [None]:
# Clone the stylized-neural-painting repository from GitHub into the working directory.
!git clone https://github.com/cedro3/stylized-neural-painting.git

In [None]:
cd stylized-neural-painting

# 学習済みモデルのダウンロード

In [None]:
import requests

def download_file_from_google_drive(id, destination):
    """Download a Google Drive file to a local path.

    The file is requested from the Drive download endpoint. If the first
    response carries a confirmation token (see ``get_confirm_token``), the
    request is repeated with that token before the body is written to disk
    by ``save_response_content``.

    Args:
        id: Google Drive file id. The name shadows the ``id`` builtin but is
            kept so that existing keyword callers keep working.
        destination: Path of the local file to write.

    Raises:
        requests.HTTPError: If the final response has an error status, so
            that an error page is not silently saved as the downloaded file.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The session and the streamed responses hold open connections; make sure
    # they are released even when the download fails part-way.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        try:
            token = get_confirm_token(response)

            if token:
                # The first response is not needed any more; release its
                # connection before issuing the confirmed request.
                response.close()
                params = {'id': id, 'confirm': token}
                response = session.get(URL, params=params, stream=True)

            response.raise_for_status()
            save_response_content(response, destination)
        finally:
            response.close()

def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object whose ``cookies.items()`` yields ``(name, value)`` pairs.

    Returns:
        The value of the first cookie whose name starts with
        ``'download_warning'``, or ``None`` when no such cookie is present.
    """
    return next(
        (
            cookie_value
            for cookie_name, cookie_value in response.cookies.items()
            if cookie_name.startswith('download_warning')
        ),
        None,
    )

def save_response_content(response, destination):
    """Stream the body of ``response`` into the file at ``destination``.

    The body is pulled through ``response.iter_content`` in 32768-byte pieces
    and written in binary mode; empty pieces are skipped.

    Args:
        response: Object exposing ``iter_content(chunk_size)``.
        destination: Path of the file to (over)write.
    """
    chunk_size = 32 * 1024

    with open(destination, "wb") as out_file:
        # Empty (keep-alive) chunks carry no data, so they are dropped.
        out_file.writelines(
            piece for piece in response.iter_content(chunk_size) if piece
        )

In [None]:
# Download the oil-paintbrush checkpoint archive from Google Drive
# (the archive is unpacked by the next cell).
destination = './checkpoints_G_oilpaintbrush.zip'
file_id = '1sqWhgBKqaBJggl2A8sD1bLSq2_B1ScMG'
download_file_from_google_drive(id=file_id, destination=destination)

In [None]:
!unzip checkpoints_G_oilpaintbrush.zip

# コード本体

In [None]:
import argparse
import torch
# Query the current CUDA device right after importing torch, but only when a GPU is
# usable: calling current_device() unconditionally raises on CPU-only machines, which
# would defeat the CPU fallback used when `device` is chosen below.
if torch.cuda.is_available():
    torch.cuda.current_device()
import torch.optim as optim
from painter import *

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Settings: start from an empty argparse namespace (no command line is parsed)
# and fill in every option explicitly.
parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
args = parser.parse_args(args=[])
vars(args).update(
    img_path='./test_images/kasumi.png',  # path to input photo
    renderer='oilpaintbrush',  # one of: watercolor, markerpen, oilpaintbrush, rectangle
    canvas_color='black',  # one of: black, white
    canvas_size=512,  # size of the canvas for stroke rendering
    max_m_strokes=500,  # max number of strokes
    max_divide=5,  # divide an image into up to max_divide x max_divide patches
    beta_L1=1.0,  # weight for L1 loss
    with_ot_loss=False,  # True improves convergence via optimal transportation loss, at the cost of speed
    beta_ot=0.1,  # weight for optimal transportation loss
    net_G='zou-fusion-net',  # renderer architecture
    renderer_checkpoint_dir='./checkpoints_G_oilpaintbrush',  # dir to load the pretrained neural renderer from
    lr=0.005,  # learning rate for stroke searching
    output_dir='./output',  # dir to save painting results
)


def _drawing_step_states(pt):
    acc = pt._compute_acc().item()
    print('iteration step %d, G_loss: %.5f, step_acc: %.5f, grid_scale: %d / %d, strokes: %d / %d'
          % (pt.step_id, pt.G_loss.item(), acc,
              pt.m_grid, pt.max_divide,
              pt.anchor_id, pt.m_strokes_per_block))
    vis2 = utils.patches2img(pt.G_final_pred_canvas, pt.m_grid).clip(min=0, max=1)


def _clamp_stroke_params(pt):
    """Clamp pt.x_ctt to [0.1, 0.9] and pt.x_color / pt.x_alpha to [0, 1]."""
    pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
    pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
    pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)


def optimize_x(pt, iters_per_stroke=80):
    """Run the progressive stroke search and save the results.

    For every grid scale from 1 to ``pt.max_divide`` the input image is split
    into patches, fresh stroke tensors are initialized and optimized with
    RMSprop (one stroke sampled at a time, ``iters_per_stroke`` optimizer
    steps per stroke), and the resulting strokes are appended to the
    accumulated parameters. The canvas rendered from all strokes so far is
    used as the starting canvas for the next scale. Finally the stroke
    parameters and rendered images are saved through ``pt``.

    Args:
        pt: Painter object providing the renderer network, the image, and the
            helper methods called below.
        iters_per_stroke: Number of optimizer steps run after each stroke is
            sampled. Defaults to 80, the value that was previously hard-coded.

    Side effects:
        Sets several attributes on ``pt`` (e.g. ``m_grid``, ``img_batch``,
        ``optimizer_x``, ``step_id``, ``anchor_id``, ``final_rendered_images``),
        prints a progress line per step, and relies on the module-level
        ``device``.
    """
    pt._load_checkpoint()
    pt.net_G.eval()
    print('begin drawing...')

    # Accumulated stroke parameters across all grid scales; starts with zero strokes.
    PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)

    # Starting canvas: all ones for a white canvas, all zeros otherwise.
    if pt.rderr.canvas_color == 'white':
        CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)
    else:
        CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)

    # The grid scale is stored directly on pt because its methods read pt.m_grid.
    for pt.m_grid in range(1, pt.max_divide + 1):

        pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)
        pt.G_final_pred_canvas = CANVAS_tmp

        # Only the stroke tensors are optimized; the renderer network stays frozen.
        pt.initialize_params()
        pt.x_ctt.requires_grad = True
        pt.x_color.requires_grad = True
        pt.x_alpha.requires_grad = True
        utils.set_requires_grad(pt.net_G, False)

        pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)

        pt.step_id = 0
        for pt.anchor_id in range(0, pt.m_strokes_per_block):
            pt.stroke_sampler(pt.anchor_id)
            for i in range(iters_per_stroke):
                pt.G_pred_canvas = CANVAS_tmp

                # update x
                pt.optimizer_x.zero_grad()

                _clamp_stroke_params(pt)

                pt._forward_pass()
                _drawing_step_states(pt)
                pt._backward_x()

                # Clamp again before the optimizer step, as in the original ordering.
                _clamp_stroke_params(pt)

                pt.optimizer_x.step()
                pt.step_id += 1

        # Append this scale's strokes and re-render; the last rendered frame,
        # re-split for the next grid scale, becomes the new starting canvas.
        v = pt._normalize_strokes(pt.x)
        PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)
        CANVAS_tmp = pt._render(PARAMS)[-1]
        CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)

    pt._save_stroke_params(PARAMS)
    pt.final_rendered_images = pt._render(PARAMS)
    pt._save_rendered_images()

# レンダリング

In [None]:
  # Build the painter from the settings in `args` and run the stroke search.
  # `pt` is kept as a global because the display cells below read its results.
  # NOTE(review): this cell is indented; presumably IPython tolerates the leading
  # indentation, but it would need dedenting if exported to a plain .py script.
  pt = ProgressivePainter(args=args)
  optimize_x(pt)

# 画像表示

In [None]:
# Show the input image next to the last rendered frame.
# The explicit Axes interface replaces the `plt.imshow(...), plt.title(...)` tuple
# expressions, which only worked as a way to squeeze two calls onto one line.
fig, (ax_input, ax_generated) = plt.subplots(1, 2, figsize=(8, 4))
ax_input.imshow(pt.img_)
ax_input.set_title('input')
ax_generated.imshow(pt.final_rendered_images[-1])
ax_generated.set_title('generated')
plt.show()

In [None]:
# Build an animation from every 10th rendered frame and embed it in the notebook.
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(8, 8))
plt.axis('off')

# ArtistAnimation expects one list of artists per frame.
ims = []
for img in pt.final_rendered_images[::10]:
    ims.append([plt.imshow(img, animated=True)])
ani = animation.ArtistAnimation(fig, ims, interval=100)

HTML(ani.to_jshtml())

In [None]:
# Save the animation built in the previous cell to anime.mp4 using the 'ffmpeg' writer.
# NOTE(review): presumably requires the ffmpeg executable to be installed — confirm for the target environment.
ani.save('anime.mp4', writer='ffmpeg')