In [5]:
import time
import os
import sys
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
from PIL import Image
from collections import OrderedDict
from matplotlib.pyplot import imshow
import numpy as np
from scipy import interpolate
from scipy import misc
import glob
from shutil import copyfile

# Directory layout. model_dir is anchored at the current working directory;
# the image folders stay relative. Every entry keeps a trailing '/' because
# later cells build file paths by plain string concatenation.
model_dir = '%s/Models/' % os.getcwd()
content_dir, style_dir = './content/', './50styles/'


# VGG-19 convolutional trunk that conveniently lets you grab the outputs of
# any layer by name in a single forward pass.
class VGG(nn.Module):
    """VGG-19 style feature extractor (conv layers only, no classifier head).

    Args:
        pool: 'max' for max pooling or 'avg' for average pooling between the
            conv stages. Any other value raises ValueError.

    The attribute names (conv1_1 ... conv5_4, pool1 ... pool5) are the keys
    used by load_state_dict, so they must match the checkpoint being loaded.
    """
    def __init__(self, pool='max'):
        super(VGG, self).__init__()
        # Choose the pooling type up front: previously an unrecognised value
        # silently left pool1..pool5 undefined and forward() crashed later
        # with an AttributeError.
        if pool == 'max':
            pool_layer = nn.MaxPool2d
        elif pool == 'avg':
            pool_layer = nn.AvgPool2d
        else:
            raise ValueError("pool must be 'max' or 'avg', got %r" % (pool,))
        # vgg modules: 3x3 convolutions with padding 1 (spatial size preserved)
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # 2x2 / stride-2 pooling halves the spatial size after each stage
        self.pool1 = pool_layer(kernel_size=2, stride=2)
        self.pool2 = pool_layer(kernel_size=2, stride=2)
        self.pool3 = pool_layer(kernel_size=2, stride=2)
        self.pool4 = pool_layer(kernel_size=2, stride=2)
        self.pool5 = pool_layer(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        """Run the whole network and return the activations named in out_keys.

        Keys: 'rXY' is the ReLU output of convX_Y and 'pX' is the output of
        poolX. The result is a list in the same order as out_keys; an unknown
        key raises KeyError. Every layer is always evaluated, so the input
        must survive five 2x2 poolings (at least 32 pixels per side).
        """
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return [out[key] for key in out_keys]
    
    
    

class Cov_Mean(nn.Module):
    """Channel covariance and channel mean of a feature map.

    forward(input) flattens the spatial dims of a (b, c, h, w) tensor,
    subtracts the per-channel spatial mean and returns a pair:

    * G    - (feats - mean) @ (feats - mean)^T / (h*w); shape (c, c) when
             b == 1 (the batch dim is squeezed), otherwise (b, c, c).
    * mean - the per-channel spatial mean with all singleton dims squeezed.

    Both results are taken through ``.data``, i.e. detached from autograd.
    """
    def forward(self, input):
        b, c, h, w = input.size()
        # Named `feats` (not `F`) so torch.nn.functional is not shadowed.
        feats = input.view(b, c, h * w)
        mean_ = feats.mean(dim=2, keepdim=True).detach()
        # Broadcasting (b, c, h*w) - (b, c, 1) gives the same values as the
        # old torch.cat of h*w copies of the mean, without building an
        # h*w-element Python list and an extra full-size tensor.
        centred = feats - mean_
        G = torch.bmm(centred, centred.transpose(1, 2))
        G.div_(h * w)
        return G.squeeze(0).data, mean_.squeeze().data



    
# pre and post processing for images
img_size = 512
prep = transforms.Compose([
    # Force 3 channels so the BGR channel swap below also works for
    # greyscale / palette inputs (RGBA simply drops alpha, as before).
    transforms.Lambda(lambda img: img.convert('RGB')),
    # Resize replaces the removed transforms.Scale; an int resizes the
    # shorter side to img_size and keeps the aspect ratio.
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
    transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],  # subtract imagenet mean
                         std=[1, 1, 1]),
    transforms.Lambda(lambda x: x.mul_(255)),
])
postpa = transforms.Compose([
    # Out-of-place so postp() does not rescale the caller's tensor.
    transforms.Lambda(lambda x: x.mul(1. / 255)),
    transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961],  # add imagenet mean
                         std=[1, 1, 1]),
    transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to RGB
])
postpb = transforms.Compose([transforms.ToPILImage()])


def postp(tensor):
    """Convert a prep-style tensor back to a PIL image.

    Undoes prep's scaling, mean subtraction and BGR ordering (not the
    resize), clips the result to [0, 1] and converts it with ToPILImage.
    Expects a 3-channel (C, H, W) tensor; the input is left unmodified.
    """
    t = postpa(tensor).clamp(0, 1)
    return postpb(t)

# get network: weights from model_dir/vgg_conv.pth, frozen, on the GPU if present
vgg = VGG()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; the model is moved to the GPU below when one exists.
vgg.load_state_dict(torch.load(model_dir + 'vgg_conv.pth', map_location='cpu'))
# The network is only used as a fixed feature extractor.
for param in vgg.parameters():
    param.requires_grad = False
if torch.cuda.is_available():
    vgg.cuda()

# Reference images whose per-layer feature covariances are summed (and later
# averaged) to build the PCA basis for each style layer.
Reference_dir = './content/'
# sorted(): glob order is filesystem-dependent; a fixed order makes the float
# accumulation reproducible.
references = sorted(glob.glob(Reference_dir + "*.jpg"))
if not references:
    raise FileNotFoundError('no *.jpg reference images found in %r' % Reference_dir)

style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
# One running sum of covariance matrices per style layer.
Total_Covs = [0] * len(style_layers)

for sample in references:
    # Close the file handle once prep has loaded/resized the pixels.
    with Image.open(sample) as img:
        img_torch = prep(img).unsqueeze(0)  # add batch dimension
    if torch.cuda.is_available():
        img_torch = img_torch.cuda()
    img_torch = Variable(img_torch)
    Covs_Means = [Cov_Mean()(A) for A in vgg(img_torch, style_layers)]
    # Only the covariance (element 0 of each pair) is accumulated.
    Total_Covs = [total + cov_mean[0] for total, cov_mean in zip(Total_Covs, Covs_Means)]



In [6]:
# Average the accumulated covariances into one matrix per style layer.
if not references:
    raise ValueError('no reference images were processed; cannot average covariances')
AVG_Covs = [x / len(references) for x in Total_Covs]


def PCA(A):
    """Return the SVD of A as a plain (U, S, V) tuple.

    For the symmetric covariance matrices averaged above, the columns of U
    are the principal directions and S holds the matching variances in
    descending order.
    """
    U, S, V = torch.svd(A)
    return U, S, V


PCA_basis = [PCA(data) for data in AVG_Covs]


In [7]:

# Principal components kept per style layer (same order as style_layers).
Ks = [32, 48, 128, 256, 256]
style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
assert len(Ks) == len(style_layers), 'need one k per style layer'
source_dir = './sample/'


def PCA_Proj(A, P, k):
    """Project a (covariance, mean) pair onto the top-k principal directions.

    A is a (cov, mean) pair from Cov_Mean and P the (U, S, V) triple from
    PCA(). Returns (U_k^T cov V_k, mean U_k); for a batch-1 Cov_Mean result
    these have shapes (k, k) and (1, k).
    """
    U_k = P[0][:, :k]
    V_k = P[2][:, :k]
    return torch.mm(torch.mm(U_k.t(), A[0]), V_k), torch.mm(A[1].unsqueeze(0), U_k)


def Det(A, B):
    """Return sum(log(s_B / s_A)) over the singular values of A and B.

    This equals log|det B| - log|det A|, i.e. log(det B / det A) for the
    positive-definite projected covariances compared here.
    """
    _, S_A, _ = torch.svd(A)
    _, S_B, _ = torch.svd(B)
    return torch.log(S_B / S_A).sum()


def _two_kl(syn, tar, k):
    """Return 2 * KL(N(syn) || N(tar)) as a Python float.

    syn and tar are (cov, mean) pairs from PCA_Proj in a k-dim space:
    tr(C_t^-1 C_s) + (m_t - m_s) C_t^-1 (m_t - m_s)^T - k + log(det C_t / det C_s).
    """
    cov_s, mean_s = syn
    cov_t, mean_t = tar
    cov_t_inv = cov_t.inverse()
    diff = mean_t - mean_s
    trace_term = torch.trace(torch.mm(cov_t_inv, cov_s)).item()
    # (1, k) @ (k, k) @ (k, 1) -> (1, 1); .item() replaces `.squeeze()[0]`,
    # which raises IndexError on 0-dim tensors in PyTorch >= 0.4.
    mahalanobis = torch.mm(torch.mm(diff, cov_t_inv), diff.t()).item()
    return trace_term + mahalanobis - k + Det(cov_s, cov_t).item()


def _load_for_vgg(path):
    """Open an image, apply prep, add a batch dim and move it to the GPU if available."""
    with Image.open(path) as img:
        img_torch = prep(img).unsqueeze(0)
    if torch.cuda.is_available():
        img_torch = img_torch.cuda()
    return Variable(img_torch)


# Output name keeps the original format, e.g. 'EValue[32, 48, 128, 256, 256].txt'.
with open('sample.txt', 'r') as file1, open('EValue%s.txt' % (Ks,), 'w') as file2:
    columns = (["style", "content", "weight"]
               + ["E%d" % (i + 1) for i in range(len(Ks))]
               + ["\n"])
    file2.write('\t'.join(columns))

    for line in file1:
        # strip() instead of line[:-1]: the last line may lack a newline.
        filename = line.strip()
        if not filename:
            continue  # tolerate blank / trailing empty lines
        # e.g. 'style6_content344010_weight11.0_...png'
        sp = filename.split('_')
        style = int(sp[0][len('style'):])
        content = int(sp[1][len('content'):])
        print(style, content, filename)

        style_image = _load_for_vgg(style_dir + 'styles - %s.jpg' % style)
        syn_image = _load_for_vgg(source_dir + filename)

        style_targets = [Cov_Mean()(A) for A in vgg(style_image, style_layers)]
        syn_results = [Cov_Mean()(A) for A in vgg(syn_image, style_layers)]

        PCA_targets = [PCA_Proj(data, P, k) for data, P, k in zip(style_targets, PCA_basis, Ks)]
        PCA_syn_results = [PCA_Proj(data, P, k) for data, P, k in zip(syn_results, PCA_basis, Ks)]

        # 2 * KL(synthesised || style) per layer in the projected space.
        KLs = [_two_kl(syn, tar, k) for syn, tar, k in zip(PCA_syn_results, PCA_targets, Ks)]
        # E = -log(2KL) + log 2 = -log(KL)
        Es = [str(-np.log(x) + np.log(2)) for x in KLs]
        file2.write('\t'.join(sp[:3] + Es + ["\n"]))




6 344010 style6_content344010_weight11.0_weight100.0_iteration799.png
6 344010 style6_content344010_weight11.0_weight150.0_iteration799.png
10 141012 style10_content141012_weight11.0_weight0.1_iteration699.png
10 141012 style10_content141012_weight11.0_weight0.5_iteration799.png


In [13]:
# Sanity check: parse the style / content ids from each line of sample.txt.
with open('sample.txt', 'r') as file1:
    for line in file1:
        if not line.strip():
            continue  # skip blank lines instead of failing on int('')
        sp = line.split('_')
        style = int(sp[0][len('style'):])
        content = int(sp[1][len('content'):])
        print(style, content)

6 344010
6 344010
10 141012
10 141012
