In [None]:
import os
import sys
# Make the notebook's parent directory importable (presumably so that the
# project-local `match_utils` package imported below resolves).
# realpath() of a bare filename resolves against the current working directory,
# so this assumes the kernel was started in the notebook's own directory.
current = os.path.dirname(os.path.realpath("shift_zoom_copypaste.ipynb"))
parent = os.path.dirname(current)
sys.path.append(parent)
from transformers import CLIPProcessor, CLIPModel
import torch
import torchvision
from torchvision.models import resnet50
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import clip
from PIL import Image
import requests
import torch.hub
import time
import pickle
import math
import torch.nn.functional as F
from match_utils import matching, models, stats, nethook, loading, plotting, layers



In [None]:
# Compute device. The default GPU index is machine-specific: override it with
# the ROSETTA_DEVICE environment variable (e.g. "cuda:0" or "cpu") instead of
# editing this cell. Falls back to the CPU when CUDA is unavailable.
device = torch.device(
    os.environ.get("ROSETTA_DEVICE", "cuda:4" if torch.cuda.is_available() else "cpu")
)

# Seed the RNGs so sampled latents are reproducible. scipy's truncnorm (used by
# truncate_noise below) draws from NumPy's global RNG when no random_state is given.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Load Models

In [None]:
from pytorch_pretrained_biggan import (
    BigGAN,
    one_hot_from_names,
    truncated_noise_sample,
    save_as_images,
    display_in_terminal,
)

# Pretrained 'biggan-deep-256' generator, moved to the working device.
gan = BigGAN.from_pretrained('biggan-deep-256').to(device)

# Force every parameter tensor to float32.
for param in gan.parameters():
    param.data = param.data.float()

# Qualified names of every submodule whose name contains "conv"; these are the
# layers whose activations are retained and matched in the cells below.
gan_layers = [module_name for module_name, _module in gan.named_modules() if "conv" in module_name]
        


In [None]:
# Directory with the precomputed match table and activation statistics.
# Override via the ROSETTA_MATCHES_DIR environment variable rather than editing
# the machine-specific default path.
MATCHES_DIR = os.environ.get("ROSETTA_MATCHES_DIR", "/home/amil/Rosetta/matches")
table, gan_stats, dino_stats = loading.load_stats(MATCHES_DIR, device)

### Best Buddies

In [None]:
# Best match score in each row of the match table (row = GAN unit).
match_scores = table.max(dim=1).values

In [None]:
# Highest-scoring column for every row, and highest-scoring row for every column.
gan_matches = table.argmax(dim=1)   # best column index per row (GAN unit)
dino_matches = table.argmax(dim=0)  # best row index per column (DINO unit)

In [None]:
# "Best buddies": row i is kept only when its best column's best row is i
# again, i.e. the two units are mutual nearest neighbours in the match table.
gan_unit_ids = torch.arange(table.shape[0], device=gan_matches.device)
is_mutual = dino_matches[gan_matches] == gan_unit_ids

perfect_matches = torch.nonzero(is_mutual).flatten().tolist()
dino_perfect_matches = [gan_matches[row].item() for row in perfect_matches]
perfect_match_scores = [match_scores[row] for row in perfect_matches]
num_perfect_matches = len(perfect_matches)

print(num_perfect_matches)

In [None]:
# Wrap the generator so activations of the layers in `gan_layers` are retained
# on each forward pass. detach=False presumably keeps them attached to the
# autograd graph, which the latent optimization below backpropagates through.
# The isinstance guard keeps this cell safe to re-run (otherwise the model
# would be wrapped a second time).
if not isinstance(gan, nethook.InstrumentedModel):
    gan = nethook.InstrumentedModel(gan)
gan.retain_layers(gan_layers, detach = False)


In [None]:
# Convert each flat unit index in `perfect_matches` into a per-layer index via
# layers.find_act. Later cells read idx[0] as the position in the activation
# list and idx[1] as the channel within that layer.
# `all_gan_layers` was never defined in this notebook (NameError); the layer
# list collected at model load time is `gan_layers`.
# NOTE: not idempotent -- re-run the best-buddies cell before re-running this.
perfect_matches = [layers.find_act(unit, gan_layers) for unit in perfect_matches]

In [None]:
from scipy.stats import truncnorm
def truncate_noise(size, truncation, random_state=None):
    '''
    Sample a tensor from a standard normal truncated to [-truncation, truncation].

    Parameters:
        size: shape of the returned tensor, e.g. (n_samples, z_dim)
        truncation: the truncation bound, a positive scalar
        random_state: optional seed, numpy Generator or RandomState forwarded to
            scipy.stats.truncnorm.rvs; None uses NumPy's global RNG
    Returns:
        a torch.Tensor of shape `size` (default float dtype, float32 unless changed)
    '''
    truncated_noise = truncnorm.rvs(-1 * truncation, truncation, size=size,
                                    random_state=random_state)
    return torch.Tensor(truncated_noise)

### Generate Image

In [None]:
# Base latent: a (1, 128) truncated-normal noise tensor (bound 1), on the working device.
z1 = truncate_noise(size=(1, 128), truncation=1).to(device)

In [None]:
# One-hot class vector over the 1000 ImageNet-1k classes; index 207 is
# "golden retriever" in the standard ImageNet-1k ordering.
c = torch.zeros((1, 1000), device=device)
c[:, 207] = 1

# `Variable` is deprecated (tensors carry autograd state directly since
# PyTorch 0.4). The import is kept because other cells in this notebook may
# still use the name.
from torch.autograd import Variable
# Optimizable copy of the base latent; z1 itself stays fixed as the reference.
z = z1.detach().clone().requires_grad_(True)


In [None]:
def show_gan_im(gan_im):
    """Display the first image of a GAN output batch.

    Parameters:
        gan_im: image batch expected in [-1, 1], laid out as (N, C, H, W)
            (assumed from the permute below); only gan_im[0] is shown.
    """
    # Map [-1, 1] -> [0, 1]; clamping avoids matplotlib's out-of-range clipping warning.
    im = ((gan_im + 1) / 2).clamp(0, 1)
    im = torch.permute(im[0], (1, 2, 0)).detach().cpu()
    plt.imshow(im)
    plt.show()

show_gan_im(gan(z,c,1))

In [None]:
def shift_activ(input, shift_w, shift_h):
    """Translate maps by zero-padding one side and cropping the opposite side.

    Padding is applied to the last two dimensions of ``input``:
    ``shift_h`` moves content along the LAST dimension and ``shift_w`` along
    the second-to-last one (positive values move toward higher indices).
    Vacated positions become zero and the output keeps the input's shape.
    """
    # F.pad order is (last_lo, last_hi, second_last_lo, second_last_hi);
    # a negative amount crops instead of padding.
    pad_spec = (shift_h, -shift_h, shift_w, -shift_w)
    with_lead_axis = input[None, :, :, :]
    moved = torch.nn.functional.pad(with_lead_axis, pad=pad_spec)
    return moved[0]

### Collect GAN Activations

In [None]:
# Retained activations of every instrumented layer from the most recent forward
# pass (presumably the base image from z1 shown above). These are NOT
# normalized with gan_stats here.
gan_activs0 = matching.store_activs(gan, gan_layers)
# Channel idx[1] of layer idx[0] for every matched unit.
gan_perfect_activs = [gan_activs0[idx[0]][:, idx[1], :, :] for idx in perfect_matches]

### Shift GAN Rosetta Neuron Activations

In [None]:
# Reference maps for the first shift run: each matched unit's base-image
# activation (float64, detached, with a leading axis) moved by
# +int(0.25 * size) along its last dimension.
# `ref` intentionally stays bound after the loop (other cells may read it).
refs = []
for unit_idx in perfect_matches:
    layer_pos, channel = unit_idx[0], unit_idx[1]
    ref = gan_activs0[layer_pos][:, channel, :, :].clone().double().unsqueeze(0).detach()
    offset = int(0.25 * ref.shape[2])
    ref = shift_activ(ref, 0, offset)
    refs.append(ref)


### Optimize for Shift

In [None]:
# Optimization hyperparameters shared by the latent-optimization cells below.
initial_learning_rate = 1e-3  # peak Adam learning rate
num_steps = 500               # optimization steps per run
lr_rampup_length = 0.05       # fraction of steps used for linear warm-up
lr_rampdown_length = 0.25     # fraction of steps used for cosine ramp-down

In [None]:
# Adam over the latent vector only; the GAN weights are not optimized.
optimizer = torch.optim.Adam(params=[z], lr=initial_learning_rate, betas=(0.9, 0.999))

In [None]:
# Optimize z so each matched unit's activation map agrees with its shifted
# reference in `refs`. The loss is the negated sum over units of the
# (uncentered) cosine similarity between the current map and the reference.
# NOTE(review): `refs` is built from gan_activs0, which this notebook does not
# normalize, while the maps below are normalized with gan_stats. Cosine
# similarity ignores scale but not offset -- confirm the mismatch is intended.
assert len(perfect_matches) > 0, "no matched units to optimize"
eps = 0.00001  # keeps the stats normalization finite when a std is 0
coeff = 0.5    # weight of l_reg; reported only, not added to the loss here
# Reference norms do not depend on z, so compute them once.
ref_norms_sq = [torch.sum(r ** 2) for r in refs]

all_images = []
for step in range(num_steps):
    # Learning-rate schedule: linear warm-up over the first lr_rampup_length
    # fraction of steps, cosine ramp-down over the last lr_rampdown_length.
    t = step / num_steps
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
    lr = initial_learning_rate * lr_ramp
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Forward pass; the instrumented GAN retains the conv-layer activations.
    synth_images = gan(z, c, 1)

    # Keep a uint8 HWC snapshot of the first image ([-1, 1] -> [0, 255]).
    synth_images = (synth_images + 1) * (255 / 2)
    synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    all_images.append(synth_images_np)

    # Normalize every retained activation with the per-layer stats
    # (presumably (mean, std) pairs from loading.load_stats).
    gan_activs1 = matching.store_activs(gan, gan_layers)
    for i in range(len(gan_activs1)):
        gan_activs1[i] = (gan_activs1[i] - gan_stats[i][0]) / (gan_stats[i][1] + eps)

    # Channel idx[1] of layer idx[0] for every matched unit.
    gan_perfect_activs1 = [gan_activs1[idx[0]][:, idx[1], :, :] for idx in perfect_matches]

    a_loss = 0
    for i in range(len(perfect_matches)):
        map_size = gan_perfect_activs1[i].shape[1]
        # Resize to map_size x map_size (an identity resize for square maps).
        gan_activ_new = F.interpolate(gan_perfect_activs1[i].unsqueeze(0),
                                      size=(map_size, map_size), mode='bilinear').double()
        prod = torch.einsum('aixy,ajxy->ij', gan_activ_new, refs[i])
        div1 = torch.sum(gan_activ_new ** 2)
        # Normalize by THIS unit's reference norm; the previous code used the
        # leftover global `ref` (the last unit's reference) for every unit.
        corr = prod / torch.sqrt(div1 * ref_norms_sq[i])
        a_loss += corr

    a_loss *= -1
    l_reg = torch.mean((z - z1) ** 2)
    loss = a_loss  # add `coeff * l_reg` to keep z close to z1

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log/preview every 10 steps instead of flooding the output every step.
    if step % 10 == 0 or step == num_steps - 1:
        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ a_loss: {float(a_loss):5.2f} loss_reg: {coeff * float(l_reg):5.2f}] '
        print(msg)
    if step % 10 == 0:
        plt.imshow(synth_images_np)
        plt.show()


### Shift Other Way and Optimize

In [None]:
# Reference maps for the second shift run: each matched unit's base-image
# activation (float64, detached, with a leading axis) moved by
# -int(0.25 * size) along its last dimension, i.e. opposite to `refs`.
# `ref` intentionally stays bound after the loop (other cells may read it).
refs2 = []
for unit_idx in perfect_matches:
    layer_pos, channel = unit_idx[0], unit_idx[1]
    ref = gan_activs0[layer_pos][:, channel, :, :].clone().double().unsqueeze(0).detach()
    offset = -int(0.25 * ref.shape[2])
    ref = shift_activ(ref, 0, offset)
    refs2.append(ref)


In [None]:
# Restart from the base latent z1 with a fresh optimizer (resets Adam's moment
# estimates). `Variable` is deprecated; requires_grad_ makes the leaf directly.
z = z1.detach().clone().requires_grad_(True)
optimizer = torch.optim.Adam([z], betas=(0.9, 0.999), lr=initial_learning_rate)

In [None]:
# Optimize z so each matched unit's activation map agrees with its reference in
# `refs2` (shifted opposite to `refs`), using the negated sum of per-unit
# (uncentered) cosine similarities as the loss.
# NOTE(review): `refs2` is built from gan_activs0, which this notebook does not
# normalize, while the maps below are normalized with gan_stats. Cosine
# similarity ignores scale but not offset -- confirm the mismatch is intended.
assert len(perfect_matches) > 0, "no matched units to optimize"
eps = 0.00001  # keeps the stats normalization finite when a std is 0
coeff = 0.5    # weight of l_reg; reported only, not added to the loss here
# Reference norms do not depend on z, so compute them once.
ref_norms_sq = [torch.sum(r ** 2) for r in refs2]

all_images = []
for step in range(num_steps):
    # Learning-rate schedule: linear warm-up over the first lr_rampup_length
    # fraction of steps, cosine ramp-down over the last lr_rampdown_length.
    t = step / num_steps
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
    lr = initial_learning_rate * lr_ramp
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Forward pass; the instrumented GAN retains the conv-layer activations.
    synth_images = gan(z, c, 1)

    # Keep a uint8 HWC snapshot of the first image ([-1, 1] -> [0, 255]).
    synth_images = (synth_images + 1) * (255 / 2)
    synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    all_images.append(synth_images_np)

    # Normalize every retained activation with the per-layer stats
    # (presumably (mean, std) pairs from loading.load_stats).
    gan_activs2 = matching.store_activs(gan, gan_layers)
    for i in range(len(gan_activs2)):
        gan_activs2[i] = (gan_activs2[i] - gan_stats[i][0]) / (gan_stats[i][1] + eps)

    # Channel idx[1] of layer idx[0] for every matched unit.
    gan_perfect_activs2 = [gan_activs2[idx[0]][:, idx[1], :, :] for idx in perfect_matches]

    a_loss = 0
    for i in range(len(perfect_matches)):
        map_size = gan_perfect_activs2[i].shape[1]
        # Resize to map_size x map_size (an identity resize for square maps).
        gan_activ_new = F.interpolate(gan_perfect_activs2[i].unsqueeze(0),
                                      size=(map_size, map_size), mode='bilinear').double()
        prod = torch.einsum('aixy,ajxy->ij', gan_activ_new, refs2[i])
        div1 = torch.sum(gan_activ_new ** 2)
        # Normalize by THIS unit's reference norm; the previous code used the
        # leftover global `ref` (the last unit's reference) for every unit.
        corr = prod / torch.sqrt(div1 * ref_norms_sq[i])
        a_loss += corr

    a_loss *= -1
    l_reg = torch.mean((z - z1) ** 2)
    loss = a_loss  # add `coeff * l_reg` to keep z close to z1

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log/preview every 10 steps instead of flooding the output every step.
    if step % 10 == 0 or step == num_steps - 1:
        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ a_loss: {float(a_loss):5.2f} loss_reg: {coeff * float(l_reg):5.2f}] '
        print(msg)
    if step % 10 == 0:
        plt.imshow(synth_images_np)
        plt.show()


### Merge Activations to Copy and Paste

In [None]:
# Copy-and-paste references: for every matched unit, take the left half (along
# the last dimension) from the second run's final maps and the right half from
# the first run's final maps, so both shifted copies appear in one map.
# Slices are detached, so the merged maps carry no autograd history.
refs3 = []
for right_src, left_src in zip(gan_perfect_activs1, gan_perfect_activs2):
    split_at = left_src.shape[-1] // 2
    merged = torch.zeros_like(left_src)
    merged[:, :, :split_at] = left_src[:, :, :split_at].detach()
    merged[:, :, split_at:] = right_src[:, :, split_at:].detach()
    refs3.append(merged)

In [None]:
# Copy-and-paste run: start from a fresh random latent (not z1) with a larger
# peak learning rate. This overwrites the global initial_learning_rate.
# `Variable` is deprecated; requires_grad_ makes the optimizable leaf directly.
z = truncate_noise((1, 128), 1).to(device).requires_grad_(True)
initial_learning_rate = 0.01
optimizer = torch.optim.Adam([z], betas=(0.9, 0.999), lr=initial_learning_rate)
# Same one-hot class vector (ImageNet-1k index 207) as the base image.
c = torch.zeros((1, 1000), device=device)
c[:, 207] = 1

In [None]:
# Optimize z so each matched unit's activation map agrees with its merged
# copy-and-paste reference in `refs3`, using the negated sum of per-unit
# (uncentered) cosine similarities as the loss.
# NOTE(review): z was initialized from fresh noise, not z1, so l_reg (reported
# only) measures distance to an unrelated latent.
assert len(perfect_matches) > 0, "no matched units to optimize"
num_steps = 1000
eps = 0.00001  # keeps the stats normalization finite when a std is 0
coeff = 0.5    # weight of l_reg; reported only, not added to the loss here
# refs3 inherits whatever dtype the normalized activations had; cast to float64
# to match gan_activ_new in the einsum, and add the leading axis once.
refs3_batched = [r.unsqueeze(0).double() for r in refs3]
# Reference norms do not depend on z, so compute them once.
ref_norms_sq = [torch.sum(r ** 2) for r in refs3_batched]

all_images = []
for step in range(num_steps):
    # Learning-rate schedule: linear warm-up over the first lr_rampup_length
    # fraction of steps, cosine ramp-down over the last lr_rampdown_length.
    t = step / num_steps
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
    lr = initial_learning_rate * lr_ramp
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Forward pass; the instrumented GAN retains the conv-layer activations.
    synth_images = gan(z, c, 1)

    # Keep a uint8 HWC snapshot of the first image ([-1, 1] -> [0, 255]).
    synth_images = (synth_images + 1) * (255 / 2)
    synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    all_images.append(synth_images_np)

    # Normalize every retained activation with the per-layer stats
    # (presumably (mean, std) pairs from loading.load_stats).
    gan_activs3 = matching.store_activs(gan, gan_layers)
    for i in range(len(gan_activs3)):
        gan_activs3[i] = (gan_activs3[i] - gan_stats[i][0]) / (gan_stats[i][1] + eps)

    # Channel idx[1] of layer idx[0] for every matched unit.
    gan_perfect_activs3 = [gan_activs3[idx[0]][:, idx[1], :, :] for idx in perfect_matches]

    a_loss = 0
    for i in range(len(perfect_matches)):
        map_size = gan_perfect_activs3[i].shape[1]
        # Resize to map_size x map_size (an identity resize for square maps).
        gan_activ_new = F.interpolate(gan_perfect_activs3[i].unsqueeze(0),
                                      size=(map_size, map_size), mode='bilinear').double()
        prod = torch.einsum('aixy,ajxy->ij', gan_activ_new, refs3_batched[i])
        div1 = torch.sum(gan_activ_new ** 2)
        # Normalize by THIS unit's reference norm; the previous code used the
        # leftover global `ref` (an unrelated shift reference) for every unit.
        corr = prod / torch.sqrt(div1 * ref_norms_sq[i])
        a_loss += corr

    a_loss *= -1
    l_reg = torch.mean((z - z1) ** 2)
    loss = a_loss  # add `coeff * l_reg` to keep z close to z1

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log/preview every 10 steps instead of flooding the output every step.
    if step % 10 == 0 or step == num_steps - 1:
        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ a_loss: {float(a_loss):5.2f} loss_reg: {coeff * float(l_reg):5.2f}] '
        print(msg)
    if step % 10 == 0:
        plt.imshow(synth_images_np)
        plt.show()


### Zoom

In [None]:
def zoom(input, scale):
    """Rescale maps about their centre while keeping their spatial size.

    ``input`` must be 3-D (e.g. (channels, H, W)); a leading axis is added and
    the last two dimensions are resized by ``scale`` with F.interpolate's
    default nearest-neighbour mode. Then, per spatial dimension:

    * enlarged (scale >= 1, zoom in): the centred window of the original size
      is kept, with (enlarged - size) // 2 removed from the top/left;
    * shrunk (scale < 1, zoom out): the map is centred on a zero canvas of the
      original size.

    Height and width are handled separately, so non-square maps work too.
    The previous version used the width for both dimensions and, for scale < 1,
    produced a negative crop offset that wrapped around and returned a smaller,
    off-centre slice.

    Returns:
        tensor with the same shape as ``input``.
    """
    height, width = input.shape[-2], input.shape[-1]
    zoomed = F.interpolate(input[np.newaxis, :, :, :], scale_factor=scale)

    # Positive difference = pad with zeros (zoom out); negative = crop (zoom in).
    # F.pad treats negative amounts as cropping.
    diff_h = height - zoomed.shape[-2]
    diff_w = width - zoomed.shape[-1]
    # Amount added before the content (ceil of half the difference); for zoom
    # in this crops exactly (enlarged - size) // 2 from the top/left.
    top = -((-diff_h) // 2)
    left = -((-diff_w) // 2)
    zoomed = F.pad(zoomed, (left, diff_w - left, top, diff_h - top))
    return zoomed[0]

In [None]:
# Reference maps for the zoom run: each matched unit's base-image activation
# (float64, detached) zoomed 2x about its centre.
# `ref` intentionally stays bound after the loop (other cells may read it).
refs4 = []
for unit_idx in perfect_matches:
    layer_pos, channel = unit_idx[0], unit_idx[1]
    ref = zoom(gan_activs0[layer_pos][:, channel, :, :].clone().double().detach(), 2)
    refs4.append(ref)


In [None]:
# Zoom run: restore the peak learning rate and restart from the base latent z1
# with a fresh optimizer. `Variable` is deprecated; requires_grad_ makes the
# optimizable leaf directly.
initial_learning_rate = 0.001
z = z1.detach().clone().requires_grad_(True)
optimizer = torch.optim.Adam([z], betas=(0.9, 0.999), lr=initial_learning_rate)

In [None]:
# Optimize z so each matched unit's activation map agrees with its zoomed
# reference in `refs4`. The loss is the negated sum of per-unit (uncentered)
# cosine similarities plus coeff * l_reg, which keeps z close to z1.
# NOTE(review): `refs4` is built from gan_activs0, which this notebook does not
# normalize, while the maps below are normalized with gan_stats. Cosine
# similarity ignores scale but not offset -- confirm the mismatch is intended.
assert len(perfect_matches) > 0, "no matched units to optimize"
num_steps = 500
eps = 0.00001  # keeps the stats normalization finite when a std is 0
coeff = 10     # weight of the latent-distance regularizer l_reg
# Add the leading axis once; reference norms do not depend on z either.
refs4_batched = [r.unsqueeze(0) for r in refs4]
ref_norms_sq = [torch.sum(r ** 2) for r in refs4_batched]

all_images = []
for step in range(num_steps):
    # Learning-rate schedule: linear warm-up over the first lr_rampup_length
    # fraction of steps, cosine ramp-down over the last lr_rampdown_length.
    t = step / num_steps
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
    lr = initial_learning_rate * lr_ramp
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Forward pass; the instrumented GAN retains the conv-layer activations.
    synth_images = gan(z, c, 1)

    # Keep a uint8 HWC snapshot of the first image ([-1, 1] -> [0, 255]).
    synth_images = (synth_images + 1) * (255 / 2)
    synth_images_np = synth_images.clone().detach().permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    all_images.append(synth_images_np)

    # Normalize every retained activation with the per-layer stats
    # (presumably (mean, std) pairs from loading.load_stats).
    gan_activs4 = matching.store_activs(gan, gan_layers)
    for i in range(len(gan_activs4)):
        gan_activs4[i] = (gan_activs4[i] - gan_stats[i][0]) / (gan_stats[i][1] + eps)

    # Channel idx[1] of layer idx[0] for every matched unit.
    gan_perfect_activs4 = [gan_activs4[idx[0]][:, idx[1], :, :] for idx in perfect_matches]

    a_loss = 0
    for i in range(len(perfect_matches)):
        map_size = gan_perfect_activs4[i].shape[1]
        # Resize to map_size x map_size (an identity resize for square maps).
        gan_activ_new = F.interpolate(gan_perfect_activs4[i].unsqueeze(0),
                                      size=(map_size, map_size), mode='bilinear').double()
        prod = torch.einsum('aixy,ajxy->ij', gan_activ_new, refs4_batched[i])
        div1 = torch.sum(gan_activ_new ** 2)
        # Normalize by THIS unit's reference norm; the previous code used the
        # leftover global `ref` (the last unit's reference) for every unit.
        corr = prod / torch.sqrt(div1 * ref_norms_sq[i])
        a_loss += corr

    a_loss *= -1
    l_reg = torch.mean((z - z1) ** 2)
    loss = a_loss + coeff * l_reg

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log/preview every 10 steps instead of flooding the output every step.
    if step % 10 == 0 or step == num_steps - 1:
        msg  = f'[ step {step+1:>4d}/{num_steps}] '
        msg += f'[ a_loss: {float(a_loss):5.2f} loss_reg: {coeff * float(l_reg):5.2f}] '
        print(msg)
    if step % 10 == 0:
        plt.imshow(synth_images_np)
        plt.show()
