<a href="https://colab.research.google.com/github/stalgiag/Waifu2x/blob/master/Waifu2x.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
# Print the runtime's Python version. The torch wheel installed later in this
# notebook is tagged cp36, i.e. it targets CPython 3.6.
!python --version

Python 3.6.7


Check for CUDA

In [0]:
# Print the CUDA compiler (nvcc) version. The torch wheel installed later in
# this notebook is a cu100 build, i.e. it targets CUDA 10.0.
!nvcc --version

Install PyTorch

In [0]:
# A specific torch build is needed: the wheel below is torch 1.0.0 for
# CUDA 10.0 (cu100), CPython 3.6 (cp36), 64-bit Linux.
# Optional: uncomment to remove an already-installed torch first.
# !pip3 uninstall torch
!pip3 install -U https://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl

Check that the PyTorch install works with CUDA

In [0]:
# Run in a fresh interpreter: prints the installed torch version, then
# True/False for whether torch can see a CUDA device.
!python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"

In [0]:
# Confirm that a GPU is attached to this notebook runtime.
# Expected output: CUDA is available!  Training on GPU ...
#
# Any other message means the GPU is switched off; enable it from
# Menu > Runtime > Change Runtime Type > Hardware Accelerator > GPU

import torch
import numpy as np

# True when torch reports a usable CUDA device
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
else:
    print('CUDA is not available.  Training on CPU ...')

In [0]:
# Check that gcc is installed (presumably required to build the apex C++/CUDA
# extensions in a later cell -- confirm).
# `gcc -v` should print the gcc version number near the bottom of its output.
!gcc -v

In [0]:
# Show which GPU is attached to this runtime (nvidia-smi status table).
!nvidia-smi

Clone my modified fork

In [0]:
# Clone the Waifu2x fork into ./Waifu2x. Per the author, this fork carries a
# few small modifications that make the project work on Colab.
!git clone https://github.com/stalgiag/Waifu2x.git

Setup

In [0]:
%%shell
# Extract the CRAN_V2 archive next to itself inside the cloned repo
# (presumably this yields the CARN_model_checkpoint.pt loaded later -- confirm).
cd Waifu2x
# -o: overwrite existing files without prompting, so re-running this cell
# does not stop at unzip's interactive "replace?" question.
unzip -o model_check_points/CRAN_V2/CRAN_V2_02_28_2019.zip -d model_check_points/CRAN_V2

In [0]:
%%shell
# Build and install NVIDIA apex from source with its C++ and CUDA extensions.
# `~/../content/` is the parent of the home directory plus `content`; on this
# runtime a later `%cd ~/../content/Waifu2x` resolves to /content/Waifu2x.
cd ~/../content/
# NOTE(review): the clone is unpinned (tracks apex's default branch) and
# `git clone` reports an error if ./apex already exists from a previous run.
git clone https://github.com/NVIDIA/apex
cd apex
# --cpp_ext / --cuda_ext are forwarded to apex's setup.py to enable the
# compiled extensions; -v prints the build log.
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .

In [21]:
# Make the cloned repo the working directory: the cells below use paths and
# imports relative to it (model_check_points/, utils, Models).
%cd ~/../content/Waifu2x

/content/Waifu2x


In [0]:
## Smoke test: run one image through the pretrained CARN_V2 model (scale=2).
# NOTE(review): the star imports are kept as-is; names such as CARN_V2,
# network_to_half, ImageSplitter, Image and nn are presumably provided by
# them -- confirm before replacing with explicit imports.
from utils.prepare_images import *
from Models import *
from torchvision.utils import save_image

model_cran_v2 = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d,
                        single_conv_size=3, single_conv_group=1,
                        scale=2, activation=nn.LeakyReLU(0.1),
                        SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1))

# The model is wrapped with network_to_half() before the weights are loaded
# (presumably so the state-dict keys match the saved checkpoint -- confirm).
model_cran_v2 = network_to_half(model_cran_v2)
checkpoint = "model_check_points/CRAN_V2/CARN_model_checkpoint.pt"
# The second argument is map_location: tensors are loaded onto the CPU.
model_cran_v2.load_state_dict(torch.load(checkpoint, 'cpu'))
# On CPU the model is cast back to fp32.
# If running on GPU, comment out the next line so the model stays in fp16.
model_cran_v2 = model_cran_v2.float()

demo_img = "out.png"

# The previous version also called load_single_image(...) to build a bilinear
# 2x copy and concatenated it with the model output into `final`; that result
# was never used, so the extra load/resize/concat has been removed.
img_b = Image.open(demo_img).convert("RGB")

# Split the image into patches, run each patch through the model without
# tracking gradients, then merge the outputs back into a single tensor.
img_splitter = ImageSplitter(seg_size=64, scale_factor=2, boarder_pad_size=3)
img_patches = img_splitter.split_img_tensor(img_b, scale_method=None, img_pad=0)
with torch.no_grad():
    out = [model_cran_v2(i) for i in img_patches]
img_upscale = img_splitter.merge_img_tensor(out)

# NOTE(review): this writes over the input file (demo_img is also "out.png"),
# so every re-run of the cell upscales the previous result again -- confirm
# this is intended or pick a separate output name.
save_image(img_upscale, 'out.png')

In [0]:
# Working directories for the batch cell: inputs, final outputs, and
# between-pass scratch space. -p keeps the cell re-runnable (no error when a
# directory already exists).
!mkdir -p to_convert
!mkdir -p results
!mkdir -p intermediate

In [86]:
import os
# NOTE(review): the star imports are kept as-is; names such as CARN_V2,
# network_to_half, ImageSplitter, Image and nn are presumably provided by
# them -- confirm before replacing with explicit imports.
from utils.prepare_images import *
from Models import *
from torchvision.utils import save_image

model_cran_v2 = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d,
                        single_conv_size=3, single_conv_group=1,
                        scale=2, activation=nn.LeakyReLU(0.1),
                        SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1))

# The model is wrapped with network_to_half() before the weights are loaded
# (presumably so the state-dict keys match the saved checkpoint -- confirm).
model_cran_v2 = network_to_half(model_cran_v2)
checkpoint = "model_check_points/CRAN_V2/CARN_model_checkpoint.pt"
# The second argument is map_location: tensors are loaded onto the CPU.
model_cran_v2.load_state_dict(torch.load(checkpoint, 'cpu'))
# On CPU the model is cast back to fp32.
# If running on GPU, comment out the next line so the model stays in fp16.
model_cran_v2 = model_cran_v2.float()

# Number of upscaling passes applied to every image. The model is built with
# scale=2, so each pass doubles width and height: 1 = 2x, 2 = 4x, 3 = 8x, ...
passes = 2


def upscale_file(model, src_path, dst_path):
    """Upscale one image file with `model` and write the result to `dst_path`.

    The image at `src_path` is opened as RGB, split into patches by
    ImageSplitter, each patch is run through `model` with gradient tracking
    disabled, and the merged result is saved with torchvision's save_image.

    Args:
        model: callable applied to every patch (here the loaded CARN_V2 model).
        src_path: path of the image to read.
        dst_path: path the upscaled image is written to; may equal `src_path`,
            because the source is fully read by convert() before saving.
    """
    source = Image.open(src_path).convert("RGB")
    # A fresh splitter per image: merge_img_tensor() is called on the same
    # instance that did the split.
    splitter = ImageSplitter(seg_size=64, scale_factor=2, boarder_pad_size=3)
    patches = splitter.split_img_tensor(source, scale_method=None, img_pad=0)
    with torch.no_grad():
        upscaled = [model(patch) for patch in patches]
    save_image(splitter.merge_img_tensor(upscaled), dst_path)


# Take the work list from to_convert/ once and reuse it for every pass.
# Listing intermediate/ on later passes (as before) would also pick up stale
# PNGs left there by earlier runs and copy them into results/.
# sorted() makes the processing order deterministic (os.listdir order is not).
filenames = sorted(name for name in os.listdir("to_convert/")
                   if name.endswith(".png"))

for x in range(passes):
    print("pass " + str(x + 1))
    # First pass reads the originals; later passes read the previous output.
    str_dir = "to_convert/" if x == 0 else "intermediate/"
    # Last pass writes the final images; earlier passes write scratch files.
    str_res = "intermediate/" if x < passes - 1 else "results/"

    for filename in filenames:
        demo_img = os.path.join(str_dir, filename)
        print(demo_img)
        result_path = os.path.join(str_res, filename)
        upscale_file(model_cran_v2, demo_img, result_path)
        print(" out = > " + str(result_path))

pass 1
to_convert/frame07.png
 out = > intermediate/frame07.png
to_convert/frame08.png
 out = > intermediate/frame08.png
pass 2
intermediate/frame07.png
 out = > results/frame07.png
intermediate/frame08.png
 out = > results/frame08.png
