<a href="https://colab.research.google.com/github/vijishmadhavan/Chehara-GAN/blob/master/%E0%A4%9A%E0%A5%87%E0%A4%B9%E0%A4%B0%E0%A4%BE_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **चेहरा-GAN**

In [None]:
#@title Install requirements
%%capture
# %%capture above hides this cell's (noisy) clone/pip output.

# Clone the project and make it the working directory: later cells use
# absolute paths under /content/Chehara-GAN/ and download the model file
# into the current directory.
!git clone https://github.com/vijishmadhavan/Chehara-GAN.git Chehara-GAN
%cd Chehara-GAN
# Install the project's Python dependencies.
!pip install -r requirements.txt

In [None]:
#@title Run only once!

from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from fastai.utils.mem import *
import dlib
import os
import cv2
import numpy as np 
from skimage import transform as trans
import urllib.request




def get_points(img, detector, shape_predictor, size_threshold=999):
    """Detect faces in ``img`` and return 5-point landmarks for each one.

    Args:
        img: Image passed unchanged to ``detector`` and ``shape_predictor``
            (in this notebook, the result of ``dlib.load_rgb_image``).
        detector: A dlib face detector. For
            ``dlib.cnn_face_detection_model_v1`` each detection wraps its box
            in ``.rect``; otherwise each detection is used as the box itself.
        shape_predictor: dlib shape predictor; parts 0-4 are read.
        size_threshold: Detections whose box width or height exceeds this
            many pixels are skipped.

    Returns:
        A list with one ``(5, 2)`` array of ``[x, y]`` landmark coordinates
        per accepted face, or ``None`` if nothing was detected or every
        detection was skipped.
    """
    dets = detector(img, 1)
    if len(dets) == 0:
        return None

    # The detector type does not change between detections; check it once.
    is_cnn = isinstance(detector, dlib.cnn_face_detection_model_v1)

    all_points = []
    for det in dets:
        rec = det.rect if is_cnn else det
        # Skip (rather than stop at) an oversized detection so the remaining
        # faces in the same image are still processed.
        if rec.width() > size_threshold or rec.height() > size_threshold:
            continue
        shape = shape_predictor(img, rec)
        all_points.append(
            np.array([[shape.part(i).x, shape.part(i).y] for i in range(5)])
        )
    return all_points or None

def align_and_save(img, save_path, src_points, template_path, template_scale=1,
                   out_size=(512, 512)):
    """Warp each face onto a reference landmark template and save the crop.

    For every landmark set in ``src_points`` a similarity transform is
    estimated that maps those points onto the template points (loaded from
    ``template_path`` and divided by ``template_scale``). ``img`` is warped
    with that transform into an ``out_size`` crop and written with
    ``dlib.save_image``.

    Args:
        img: Source image array.
        save_path: Output file path. When ``src_points`` holds more than one
            face, ``_<idx>`` is inserted before the extension for each crop.
        src_points: Sequence of per-face landmark arrays, e.g. the result of
            ``get_points``.
        template_path: ``.npy`` file with the reference landmark positions.
        template_scale: Divisor applied to the template coordinates
            (presumably to rescale a larger template to ``out_size`` —
            TODO confirm against the template file).
        out_size: Crop size, passed to ``cv2.warpAffine`` as ``dsize``.
    """
    reference = np.load(template_path) / template_scale

    root, ext = os.path.splitext(save_path)
    multiple_faces = len(src_points) > 1
    for idx, spoint in enumerate(src_points):
        tform = trans.SimilarityTransform()
        # estimate() reports failure through its return value; warping with
        # an unfitted transform would produce a garbage crop.
        if not tform.estimate(spoint, reference):
            print('Could not align face', idx, 'in', save_path)
            continue
        M = tform.params[0:2, :]

        crop_img = cv2.warpAffine(img, M, out_size)
        out_path = f'{root}_{idx}{ext}' if multiple_faces else save_path
        dlib.save_image(crop_img.astype(np.uint8), out_path)
        print('Saving image', out_path)

def align_and_save_dir(src_dir, save_dir, template_path='/content/Chehara-GAN/pretrain_models/FFHQ_template.npy', template_scale=2, use_cnn_detector=True):
    """Detect, align and crop every face found in the images of ``src_dir``.

    Each regular file in ``src_dir`` is loaded with ``dlib.load_rgb_image``.
    Its faces are located with ``get_points`` and saved into ``save_dir``
    under the same file name via ``align_and_save``, which adds an ``_<idx>``
    suffix when an image contains several faces. Sub-directories are ignored.

    Args:
        src_dir: Directory containing the input images.
        save_dir: Directory that receives the aligned crops; created if missing.
        template_path: ``.npy`` reference landmark template.
        template_scale: Divisor applied to the template coordinates.
        use_cnn_detector: Use dlib's CNN (MMOD) face detector instead of
            ``dlib.get_frontal_face_detector()``.
    """
    if use_cnn_detector:
        detector = dlib.cnn_face_detection_model_v1('/content/Chehara-GAN/pretrain_models/mmod_human_face_detector.dat')
    else:
        detector = dlib.get_frontal_face_detector()
    sp = dlib.shape_predictor('/content/Chehara-GAN/pretrain_models/shape_predictor_5_face_landmarks.dat')

    # Create the output directory up front in case it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)

    for name in sorted(os.listdir(src_dir)):
        img_path = os.path.join(src_dir, name)
        # Skip sub-directories such as Jupyter's .ipynb_checkpoints; they
        # cannot be loaded as images.
        if not os.path.isfile(img_path):
            continue
        img = dlib.load_rgb_image(img_path)

        points = get_points(img, detector, sp)
        if points is None:
            print('No face detected in', img_path)
            continue
        save_path = os.path.join(save_dir, name)
        align_and_save(img, save_path, points, template_path, template_scale)

class FeatureLoss(nn.Module):
    """Pixel loss plus per-layer feature and Gram-matrix losses.

    Activations of ``m_feat`` at the layers selected by ``layer_ids`` are
    captured with fastai forward hooks (``detach=False``, so they stay in the
    autograd graph). They are compared between the prediction and the target,
    each layer weighted by the matching entry of ``layer_wgts``.

    NOTE(review): ``base_loss`` and ``gram_matrix`` are not defined anywhere
    in this notebook, so ``forward`` would raise NameError if called here.
    Presumably they lived in the training notebook; inference through
    ``learn.predict`` does not appear to need them — TODO confirm.

    NOTE(review): presumably this class must exist under this name (with these
    attribute names) before ``load_learner`` runs, so the pickled learner in
    ``p6500.pkl`` can be restored — do not rename it or its attributes
    without re-exporting the model.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        # Feature extractor; indexed by layer id below, so it must support [].
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # Keep the hooked activations attached to the graph (detach=False).
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # Names reported for each loss term: one pixel term, then one
        # feature term and one Gram term per hooked layer.
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Run ``m_feat`` on ``x`` and return the hooked activations.

        The activations are cloned when ``clone`` is true.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        """Return the summed loss between ``input`` and ``target``.

        The individual terms are kept in ``self.feat_losses`` and, keyed by
        ``self.metric_names``, in ``self.metrics``.
        """
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        # Pixel-level term.
        self.feat_losses = [base_loss(input,target)]
        # Feature terms, each scaled by its layer weight.
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # Gram-matrix terms, scaled by the squared layer weight times 5e3.
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()  # detach the forward hooks from m_feat


MODEL_URL = "https://www.dropbox.com/s/iiqvfu58as8unz1/p6500.pkl?dl=1"
MODEL_FILE = "p6500.pkl"

# Download the exported learner only if it is not already present. Download
# to a temporary name first so an interrupted transfer never leaves a
# truncated model that a later run would try to load.
if not os.path.exists(MODEL_FILE):
    tmp_file = MODEL_FILE + ".part"
    urllib.request.urlretrieve(MODEL_URL, tmp_file)
    os.replace(tmp_file, MODEL_FILE)

path = Path(".")
# load_learner unpickles MODEL_FILE (arbitrary code can run on load): only
# point MODEL_URL at a source you trust.
learn = load_learner(path, MODEL_FILE)



Put your images inside the `input` folder and run the Main-code cells below; the enhanced images are written to the `output` folder.

Note: results may not be good for every image or head pose.

# Main-code

In [None]:
# Remove Jupyter checkpoint folders so the loops that list input/ and
# cropped/ do not try to open them as images. Using find's -exec avoids
# relying on IPython's `rm` automagic and on shell word-splitting of paths.
!find . -type d -name .ipynb_checkpoints -prune -exec rm -rf {} +


In [None]:
#@title

align_and_save_dir("/content/Chehara-GAN/input", "/content/Chehara-GAN/cropped", template_path='/content/Chehara-GAN/pretrain_models/FFHQ_template.npy', template_scale=2, use_cnn_detector=False)

from PIL import Image

# Enhance every aligned crop with the loaded learner and save it to output/.
crop_dir = '/content/Chehara-GAN/cropped/'
out_dir = '/content/Chehara-GAN/output/'
os.makedirs(out_dir, exist_ok=True)

for name in sorted(os.listdir(crop_dir)):
    img_path = os.path.join(crop_dir, name)
    # Skip non-files such as a leftover .ipynb_checkpoints directory.
    if not os.path.isfile(img_path):
        continue
    img_fast = open_image(img_path)
    _, img_hr, _ = learn.predict(img_fast)
    # Scale to 0-255 and clamp before the uint8 cast so out-of-range values
    # saturate instead of wrapping around.
    x = np.clip(image2np(img_hr.data * 255), 0, 255).astype(np.uint8)
    pil_image = Image.fromarray(x).convert('RGB')
    # fastai v1's Image.size is (height, width); PIL's resize wants (width, height).
    height, width = img_fast.size
    pil_image.resize((width, height)).save(os.path.join(out_dir, name), 'JPEG')



Saving image /content/Chehara-GAN/cropped/1_0.jpg
Saving image /content/Chehara-GAN/cropped/1_1.jpg
Saving image /content/Chehara-GAN/cropped/1_2.jpg
Saving image /content/Chehara-GAN/cropped/1_3.jpg
Saving image /content/Chehara-GAN/cropped/1_4.jpg
Saving image /content/Chehara-GAN/cropped/1_5.jpg
Saving image /content/Chehara-GAN/cropped/1_6.jpg
Saving image /content/Chehara-GAN/cropped/1_7.jpg
Saving image /content/Chehara-GAN/cropped/1_8.jpg
Saving image /content/Chehara-GAN/cropped/1_9.jpg
Saving image /content/Chehara-GAN/cropped/1_10.jpg
