In [0]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import LSUN
import os

In [0]:
!pip install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl



In [0]:
# Colab-only step: mount the user's Google Drive at /content/gdrive. The call
# is interactive (it prompts for an authorization code, see the cell output).
# NOTE(review): later cells use the relative path './gdrive/My Drive/...', which
# presumably relies on the working directory being /content - confirm.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
from torch import nn
import torch
import torch.nn.functional as F


class VAE(nn.Module):
  """Conditional VAE that predicts an RGB image from a greyscale one.

  Three sub-networks are defined:

  * ``encoder``      - maps a (N, 3, 64, 64) colour image to the mean and
                       log-variance of a ``hidden_size``-dimensional Gaussian.
  * ``cond_encoder`` - maps a (N, 1, 64, 64) greyscale image to four feature
                       maps (32x32, 16x16, 8x8, 4x4) used as skip connections.
  * ``decoder``      - combines a latent code with those feature maps and
                       produces a (N, 3, 64, 64) image squashed by ``tanh``.

  The 64x64 input size is fixed by the architecture: four stride-2
  convolutions must end at 4x4 to match ``enc_fc1`` and the decoder's first
  4x upsampling of the 1x1 latent.

  NOTE(review): the output range is [-1, 1] because of the final ``tanh``;
  confirm this matches the value range of the training targets.
  """

  #define layers
  def __init__(self):
    super(VAE, self).__init__()
    self.hidden_size = 256

    #Encoder layers (colour image -> mu / logvar)
    self.enc_conv1 = nn.Conv2d(3, 128, 5, stride=2, padding=2) # rgb
    self.enc_bn1 = nn.BatchNorm2d(128)
    self.enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
    self.enc_bn2 = nn.BatchNorm2d(256)
    self.enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2)
    self.enc_bn3 = nn.BatchNorm2d(512)
    self.enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1)
    self.enc_bn4 = nn.BatchNorm2d(1024)
    # One linear layer emits mu and logvar side by side.
    self.enc_fc1 = nn.Linear(4*4*1024, self.hidden_size*2)
    # Not applied anywhere in this class; kept so the attribute still exists.
    self.enc_dropout1 = nn.Dropout(p=.7)

    #Cond encoder layers (greyscale image -> skip features)
    self.cond_enc_conv1 = nn.Conv2d(1, 128, 5, stride=2, padding=2)
    self.cond_enc_bn1 = nn.BatchNorm2d(128)
    self.cond_enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
    self.cond_enc_bn2 = nn.BatchNorm2d(256)
    self.cond_enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2)
    self.cond_enc_bn3 = nn.BatchNorm2d(512)
    self.cond_enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1)
    self.cond_enc_bn4 = nn.BatchNorm2d(1024)

    #Decoder layers
    # align_corners=False is what bilinear Upsample already does when the
    # argument is omitted; passing it explicitly only silences the warning.
    self.dec_upsamp1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
    self.dec_conv1 = nn.Conv2d(1024+self.hidden_size, 512, 3, stride=1, padding=1)
    self.dec_bn1 = nn.BatchNorm2d(512)
    self.dec_upsamp2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    self.dec_conv2 = nn.Conv2d(512*2, 256, 5, stride=1, padding=2)
    self.dec_bn2 = nn.BatchNorm2d(256)
    self.dec_upsamp3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    self.dec_conv3 = nn.Conv2d(256*2, 128, 5, stride=1, padding=2)
    self.dec_bn3 = nn.BatchNorm2d(128)
    self.dec_upsamp4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    self.dec_conv4 = nn.Conv2d(128*2, 64, 5, stride=1, padding=2)
    self.dec_bn4 = nn.BatchNorm2d(64)
    self.dec_upsamp5 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    self.dec_conv5 = nn.Conv2d(64, 3, 5, stride=1, padding=2) #rgb

  def encoder(self, x):
    """Encode a (N, 3, 64, 64) colour batch.

    Returns:
      (mu, logvar), each of shape (N, hidden_size).
    """
    x = self.enc_bn1(F.relu(self.enc_conv1(x)))
    x = self.enc_bn2(F.relu(self.enc_conv2(x)))
    x = self.enc_bn3(F.relu(self.enc_conv3(x)))
    x = self.enc_bn4(F.relu(self.enc_conv4(x)))
    # Flatten per sample so a wrong input size fails at enc_fc1 with a shape
    # error instead of silently changing the batch dimension.
    x = x.view(x.size(0), -1)

    x = self.enc_fc1(x)
    mu = x[..., :self.hidden_size]
    logvar = x[..., self.hidden_size:]
    return mu, logvar

  def cond_encoder(self, x):
    """Encode a (N, 1, 64, 64) greyscale batch into skip features.

    Returns:
      (sc_feat32, sc_feat16, sc_feat8, sc_feat4) with 128, 256, 512 and 1024
      channels at 32x32, 16x16, 8x8 and 4x4 resolution respectively.
    """
    sc_feat32 = self.cond_enc_bn1(F.relu(self.cond_enc_conv1(x)))
    sc_feat16 = self.cond_enc_bn2(F.relu(self.cond_enc_conv2(sc_feat32)))
    sc_feat8 = self.cond_enc_bn3(F.relu(self.cond_enc_conv3(sc_feat16)))
    sc_feat4 = self.cond_enc_bn4(F.relu(self.cond_enc_conv4(sc_feat8)))
    return sc_feat32, sc_feat16, sc_feat8, sc_feat4

  def decoder(self, z, sc_feat32, sc_feat16, sc_feat8, sc_feat4):
    """Decode latent codes into colour images.

    Args:
      z: latent codes, reshaped to (N, hidden_size, 1, 1).
      sc_feat32, sc_feat16, sc_feat8, sc_feat4: outputs of ``cond_encoder``
        for the same N samples.

    Returns:
      Tensor of shape (N, 3, 64, 64) with values in [-1, 1].
    """
    x = z.view(-1, self.hidden_size, 1, 1)
    x = self.dec_upsamp1(x)                      # 1x1 -> 4x4
    x = torch.cat([x, sc_feat4], 1)
    x = self.dec_bn1(F.relu(self.dec_conv1(x)))
    x = self.dec_upsamp2(x)                      # 4x4 -> 8x8
    x = torch.cat([x, sc_feat8], 1)
    x = self.dec_bn2(F.relu(self.dec_conv2(x)))
    x = self.dec_upsamp3(x)                      # 8x8 -> 16x16
    x = torch.cat([x, sc_feat16], 1)
    x = self.dec_bn3(F.relu(self.dec_conv3(x)))
    x = self.dec_upsamp4(x)                      # 16x16 -> 32x32
    x = torch.cat([x, sc_feat32], 1)
    x = self.dec_bn4(F.relu(self.dec_conv4(x)))
    x = self.dec_upsamp5(x)                      # 32x32 -> 64x64
    # torch.tanh replaces the deprecated F.tanh; same result.
    x = torch.tanh(self.dec_conv5(x))
    return x

  def _reparameterize(self, mu, logvar):
    """Sample z = mu + eps * sigma with eps ~ N(0, I)."""
    # exp(0.5 * logvar) equals sqrt(exp(logvar)) without the intermediate
    # exp() that overflows for large logvar.
    stddev = torch.exp(0.5 * logvar)
    # randn_like draws the noise on stddev's device and dtype, so the model
    # works on CPU as well as GPU.
    eps = torch.randn_like(stddev)
    return mu + eps * stddev

  #define forward pass
  def forward(self, color, greylevel):
    """Run the full conditional VAE.

    Args:
      color: (N, 3, 64, 64) colour images fed to the encoder.
      greylevel: (N, 1, 64, 64) greyscale images fed to the cond encoder.

    Returns:
      (mu, logvar, color_out): the posterior parameters, each
      (N, hidden_size), and the decoded (N, 3, 64, 64) image.
    """
    sc_feat32, sc_feat16, sc_feat8, sc_feat4 = self.cond_encoder(greylevel)
    mu, logvar = self.encoder(color)
    z = self._reparameterize(mu, logvar)
    color_out = self.decoder(z, sc_feat32, sc_feat16, sc_feat8, sc_feat4)
    return mu, logvar, color_out

In [0]:
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import STL10
import os
import cv2
import numpy as np

# Create the folder that receives the training previews. makedirs with
# exist_ok=True also creates missing parent folders and does not fail when the
# directory already exists, so no separate existence check is needed.
os.makedirs('./gdrive/My Drive/Colab Notebooks/colorize', exist_ok=True)

    
def rgb2gray(img):
  """Convert a batch of RGB images to single-channel luminance.

  The batch size and spatial size are taken from ``img`` itself, so a short
  final batch and resolutions other than 64x64 are handled. The weights are
  the standard 0.299 / 0.587 / 0.114 RGB-to-grey coefficients and the
  arithmetic is done in float32.

  Args:
    img: tensor of shape (N, 3, H, W) with channels in R, G, B order.

  Returns:
    float64 tensor of shape (N, 1, H, W) on the same device as ``img``.

  Raises:
    ValueError: if ``img`` is not a 4-D tensor with 3 channels.
  """
  if img.dim() != 4 or img.size(1) != 3:
    raise ValueError(
        'rgb2gray expects a (N, 3, H, W) tensor, got shape {}'.format(tuple(img.shape)))

  rgb = img.float()
  weights = torch.tensor([0.299, 0.587, 0.114],
                         dtype=rgb.dtype, device=rgb.device).view(1, 3, 1, 1)
  # Weighted sum over the channel axis; keepdim keeps the singleton channel.
  gray = (rgb * weights).sum(dim=1, keepdim=True)
  # float64 return dtype is kept so existing callers see the same type.
  return gray.double()


# ---- Training configuration -------------------------------------------------
num_epochs = 100
batch_size = 32
learning_rate = 1e-4

# Resize to 64 pixels (the input size the VAE architecture is built for) and
# convert the PIL image to a float tensor.
img_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
])

dataset = STL10('./data', split='unlabeled', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = VAE()

from collections import OrderedDict

# Training resumes from the previously saved checkpoint. map_location='cpu'
# lets a checkpoint written on a GPU machine load on a CPU-only one; the
# weights are moved to the GPU below when one is available.
state_dict = torch.load('./gdrive/My Drive/Colab Notebooks/vae.pth', map_location='cpu')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Keys carry a "module." prefix when the checkpoint was saved from an
    # nn.DataParallel wrapper; strip it only when present.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

if torch.cuda.is_available():
    model.cuda()

# Absolute error summed over every element of the batch (no averaging).
# reduction='sum' is the supported spelling of the deprecated size_average=False.
reconstruction_function = nn.L1Loss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: summed L1 reconstruction error plus KL.

    Args:
        recon_x: generated images.
        x: original images, same shape as ``recon_x``.
        mu: latent mean.
        logvar: latent log variance.

    Returns:
        Scalar tensor ``L1(recon_x, x) + KL(N(mu, sigma^2) || N(0, I))``, both
        terms summed over the whole batch.
    """
    recon_loss = reconstruction_function(recon_x, x)  # summed L1, not MSE/BCE
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld


optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Inputs are moved to whatever device the model's weights live on, so the loop
# stays consistent with the placement done when the model was created.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data  # labels are not used
        # The greyscale copy is the conditioning input; rgb2gray returns
        # float64, the model expects float32.
        gray = rgb2gray(img).float()

        img = img.to(device)
        gray = gray.to(device)

        optimizer.zero_grad()
        mu, logvar, recon_batch = model(img, gray)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.item() / len(img)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 5 == 0:
        # Every fifth epoch: save the last batch (ground truth and
        # reconstruction) as image grids and overwrite the checkpoint.
        save = recon_batch.detach().cpu()
        save_image(img, './gdrive/My Drive/Colab Notebooks/colorize/GT_{}.png'.format(epoch))
        save_image(save, './gdrive/My Drive/Colab Notebooks/colorize/output_{}.png'.format(epoch))

        torch.save(model.state_dict(), './gdrive/My Drive/Colab Notebooks/vae.pth')

Files already downloaded and verified


  "See the documentation of nn.Upsample for details.".format(mode))


====> Epoch: 0 Average loss: 254.5351
====> Epoch: 1 Average loss: 253.2933
====> Epoch: 2 Average loss: 252.2893
====> Epoch: 3 Average loss: 250.9335
====> Epoch: 4 Average loss: 250.6099
====> Epoch: 5 Average loss: 251.4131
====> Epoch: 6 Average loss: 249.9644
====> Epoch: 7 Average loss: 247.8969
====> Epoch: 8 Average loss: 247.3102
====> Epoch: 9 Average loss: 246.4976
====> Epoch: 10 Average loss: 246.5582
====> Epoch: 11 Average loss: 257.1457
====> Epoch: 12 Average loss: 253.8229
====> Epoch: 13 Average loss: 244.8605
====> Epoch: 14 Average loss: 244.0773
====> Epoch: 15 Average loss: 243.3219
====> Epoch: 16 Average loss: 243.3581
====> Epoch: 17 Average loss: 241.5270
====> Epoch: 18 Average loss: 242.1476
====> Epoch: 19 Average loss: 241.6843
====> Epoch: 20 Average loss: 239.4605
====> Epoch: 21 Average loss: 238.7823
====> Epoch: 22 Average loss: 238.7773
====> Epoch: 23 Average loss: 237.7556
====> Epoch: 24 Average loss: 237.4936
====> Epoch: 25 Average loss: 237.0

In [0]:
import torch #==inference code==================================================
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import STL10
import os
import cv2
import numpy as np

# Create the folder that receives the inference images. makedirs with
# exist_ok=True also creates the parent 'colorize' folder when it is missing
# and does not fail when the directory already exists.
os.makedirs('./gdrive/My Drive/Colab Notebooks/colorize/test', exist_ok=True)

    
def rgb2gray(img):
  """Convert a batch of RGB images to single-channel luminance.

  The batch size and spatial size are taken from ``img`` itself, so a short
  final batch and resolutions other than 64x64 are handled. The weights are
  the standard 0.299 / 0.587 / 0.114 RGB-to-grey coefficients and the
  arithmetic is done in float32.

  Args:
    img: tensor of shape (N, 3, H, W) with channels in R, G, B order.

  Returns:
    float64 tensor of shape (N, 1, H, W) on the same device as ``img``.

  Raises:
    ValueError: if ``img`` is not a 4-D tensor with 3 channels.
  """
  if img.dim() != 4 or img.size(1) != 3:
    raise ValueError(
        'rgb2gray expects a (N, 3, H, W) tensor, got shape {}'.format(tuple(img.shape)))

  rgb = img.float()
  weights = torch.tensor([0.299, 0.587, 0.114],
                         dtype=rgb.dtype, device=rgb.device).view(1, 3, 1, 1)
  # Weighted sum over the channel axis; keepdim keeps the singleton channel.
  gray = (rgb * weights).sum(dim=1, keepdim=True)
  # float64 return dtype is kept so existing callers see the same type.
  return gray.double()


# ---- Inference configuration ------------------------------------------------
# num_epochs is used by the sampling loop below as the number of passes, each
# of which writes one output image grid.
num_epochs = 5
batch_size = 32
learning_rate = 1e-4

# Resize to 64 pixels (the input size the VAE architecture is built for) and
# convert the PIL image to a float tensor.
img_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
])

# shuffle=False so every pass sees the same images in the same order.
dataset = STL10('./data', split='train', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

model = VAE()

from collections import OrderedDict

# map_location='cpu' lets a checkpoint written on a GPU machine load on a
# CPU-only one; the weights are moved to the GPU below when one is available.
state_dict = torch.load('./gdrive/My Drive/Colab Notebooks/vae.pth', map_location='cpu')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Keys carry a "module." prefix when the checkpoint was saved from an
    # nn.DataParallel wrapper; strip it only when present.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

if torch.cuda.is_available():
    model.cuda()




# Colorize the first batch of the loader num_epochs times, each time with a
# freshly sampled latent code, and save one output grid per pass.
#
# Pairing the image loader with a one-batch loader of latent codes would stop
# after the first image batch anyway, so that batch is taken directly here.
model.eval()
device = next(model.parameters()).device

img, _ = next(iter(dataloader))  # labels are not used
# rgb2gray returns float64; the model expects float32.
gray = rgb2gray(img).float()
img = img.to(device)
gray = gray.to(device)

# No gradients are needed for sampling; no_grad avoids building the graph.
with torch.no_grad():
    # The greyscale features do not depend on z and the model is in eval
    # mode, so they are computed once and reused for every sample.
    sc_feat32, sc_feat16, sc_feat8, sc_feat4 = model.cond_encoder(gray)

    for epoch in range(num_epochs):
        # Latent codes come from the standard-normal prior, one per image.
        z = torch.randn(img.size(0), model.hidden_size, device=device)
        recon_batch = model.decoder(z, sc_feat32, sc_feat16, sc_feat8, sc_feat4)

        save = recon_batch.cpu()
        if epoch == 0:
            # Ground truth and greyscale input are identical on every pass,
            # so they are written only once.
            save_image(img, './gdrive/My Drive/Colab Notebooks/colorize/test/test_GT.png')
            save_image(gray, './gdrive/My Drive/Colab Notebooks/colorize/test/test_input.png')
        save_image(save, './gdrive/My Drive/Colab Notebooks/colorize/test/test_output_{}.png'.format(epoch))


Files already downloaded and verified


  "See the documentation of nn.Upsample for details.".format(mode))
