# AR Punks
Using an autoencoder to generate new punks from existing ones.

In [2]:
import sys
sys.path.append('..')
sys.path.append('/home/tnn1t1s/art/cpunks-10k')

import numpy as np
import pandas as pd
import pickle
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib.animation as animation

%matplotlib inline
plt.style.use('default')
from matplotlib.colors import rgb2hex
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import cpunks.cpunks10k as cpunks10k
import cpunks.utils as cputils

import os
#os.environ["CUDA_VISIBLE_DEVICES"]="1"    
import tensorflow as tf

In [4]:
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam

In [5]:
# Load the CryptoPunks dataset through the project-local cpunks10k helper.
cp = cpunks10k.cpunks10k()
# load_data() is unpacked into a train split, a test split and a label collection.
(X_train, Y_train), (X_test, Y_test), (labels) = cp.load_data()
# Stack the train and test images along the sample axis into a single array.
X = np.concatenate((X_train, X_test), axis=0)
# Table exposed by the loader.
# NOTE(review): presumably per-punk metadata, one row per image -- confirm against cpunks10k.
df = cp.punks_df

### meta parameters

In [6]:
class ContextManager(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``ctx.name`` reads ``ctx['name']``, ``ctx.name = value`` stores
    ``ctx['name'] = value`` and ``del ctx.name`` removes the key (raising
    ``KeyError`` if it is absent, exactly like ``del ctx['name']``).

    Reading an ordinary key that is missing returns ``None``, mirroring
    ``dict.get``.  Missing double-underscore names raise ``AttributeError``
    instead, so that protocol probes such as ``hasattr(ctx, '__deepcopy__')``
    or the ``__getstate__`` lookup done by pickle get an honest "not there"
    rather than a non-callable ``None``.
    """

    def __getattr__(self, name):
        # Only called when regular attribute lookup has already failed.
        try:
            return self[name]
        except KeyError:
            if name.startswith('__') and name.endswith('__'):
                # Special-method probe, not a configuration key.
                raise AttributeError(name) from None
            return None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


In [7]:
def r_loss(y_true, y_pred):
    """Reconstruction loss: squared error averaged over axes 1, 2 and 3.

    One value is produced per sample (axis 0 is not reduced).
    """
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error, axis=[1, 2, 3])
# All tunable settings of the autoencoder, gathered in one attribute-accessible dict.
ctx = ContextManager(
    r_loss=r_loss,
    learning_rate=0.0005,
    batch_size=32,
    initial_epoch=0,
    input_dim=(24, 24, 4),
    # Encoder: one list entry per Conv2D layer.
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    # Decoder: one list entry per Conv2DTranspose layer.
    decoder_conv_t_filters=[64, 64, 32, 4],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    # Size of the latent vector.
    z_dim=4,
)
# Layer counts are derived from the filter lists rather than typed in by hand.
ctx.n_layers_encoder = len(ctx.encoder_conv_filters)
ctx.n_layers_decoder = len(ctx.decoder_conv_t_filters)


In [9]:
# Enable memory growth on every GPU TensorFlow can see, so GPU memory is
# allocated as needed rather than reserved in full at start-up.
# With no GPU present the list is empty and the loop does nothing.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

In [11]:
# Encoder: a stack of Conv2D + LeakyReLU layers followed by a dense projection
# down to a z_dim-sized latent vector.
encoder_input = Input(shape=ctx.input_dim, name='encoder_input')

encoder_layer_specs = zip(ctx.encoder_conv_filters,
                          ctx.encoder_conv_kernel_size,
                          ctx.encoder_conv_strides)

x = encoder_input
for i, (n_filters, kernel_size, stride) in enumerate(encoder_layer_specs):
    x = Conv2D(filters=n_filters,
               kernel_size=kernel_size,
               strides=stride,
               padding='same',
               name='encoder_conv_' + str(i))(x)
    x = LeakyReLU()(x)

# Remember the feature-map shape (without the batch axis); the decoder
# reshapes its dense output back to this shape.
shape_before_flattening = K.int_shape(x)[1:]

x = Flatten()(x)
encoder_output = Dense(ctx.z_dim, name='encoder_output')(x)
encoder = Model(encoder_input, encoder_output)


### The Decoder

In [12]:
# Decoder: expand the latent vector back to the encoder's last feature-map
# shape, then apply the transposed convolutions.
decoder_input = Input(shape=(ctx.z_dim,), name='decoder_input')

x = Dense(np.prod(shape_before_flattening))(decoder_input)
x = Reshape(shape_before_flattening)(x)

last_layer_index = ctx.n_layers_decoder - 1
for i in range(ctx.n_layers_decoder):
    x = Conv2DTranspose(filters=ctx.decoder_conv_t_filters[i],
                        kernel_size=ctx.decoder_conv_t_kernel_size[i],
                        strides=ctx.decoder_conv_t_strides[i],
                        padding='same',
                        name='decoder_conv_t_' + str(i))(x)
    # Hidden layers use LeakyReLU; the final layer uses a sigmoid so the
    # output values lie in (0, 1).
    if i == last_layer_index:
        x = Activation('sigmoid')(x)
    else:
        x = LeakyReLU()(x)

decoder_output = x
decoder = Model(decoder_input, decoder_output)


### Combine to Build the Autoencoder

In [13]:
# The full autoencoder maps the encoder's input to the decoder applied to the
# encoder's latent output.  It reuses the `encoder` and `decoder` layers, so
# fitting `model` trains both sub-models.
model_input = encoder_input
model_output = decoder(encoder_output)

model = Model(model_input, model_output)


### Compile

In [14]:
# NOTE(review): this repeats the r_loss defined in the meta-parameters cell
# (also stored as ctx.r_loss); the later definition is the one that stays bound.
def r_loss(y_true, y_pred):
    """Reconstruction loss: squared error averaged over axes 1, 2 and 3."""
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error, axis=[1, 2, 3])

In [15]:
# Adam with the configured learning rate, minimising the reconstruction loss.
optimizer = Adam(learning_rate=ctx.learning_rate)
model.compile(optimizer=optimizer, loss = r_loss)

### Train

In [16]:
from tensorflow.keras.callbacks import Callback, LearningRateScheduler

In [17]:
def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    """Create a ``LearningRateScheduler`` callback with a step-decay schedule.

    The learning rate used for an epoch is
    ``initial_lr * decay_factor ** floor(epoch / step_size)``.

    Reference cited by the original author: https://arxiv.org/abs/1908.01878
    """
    def _lr_for_epoch(epoch):
        # Number of whole decay steps completed before this epoch.
        completed_steps = np.floor(epoch / step_size)
        return initial_lr * (decay_factor ** completed_steps)

    return LearningRateScheduler(_lr_for_epoch)




In [33]:
# Number of epochs to train for.
epochs = 500

# NOTE(review): ctx.initial_epoch holds the same value but is not used here.
initial_epoch = 0
# A decay factor of 1 makes the step-decay schedule return ctx.learning_rate
# for every epoch, i.e. the learning rate stays constant.
lr_decay = 1

lr_sched = step_decay_schedule(initial_lr=ctx.learning_rate,
                               decay_factor=lr_decay,
                               step_size=1)

callbacks_list = [lr_sched]

# Autoencoder training: the input images are also the targets.
# Only X_train is used; no validation data is passed.
model.fit(X_train,
          X_train,
          batch_size = ctx.batch_size,
          shuffle = True,
          epochs = epochs,
          initial_epoch = initial_epoch,
          callbacks = callbacks_list)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/500
Epoch 155

Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 232/500
Epoch 233/500
Epoch 

Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 311/500
Epoch 312/500
Epoch 313/500
Epoch 

Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 390/500
Epoch 391/500
Epoch 392/500
Epoch 393/500
Epoch 

Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 448/500
Epoch 449/500
Epoch 450/500
Epoch 451/500
Epoch 452/500
Epoch 453/500
Epoch 454/500
Epoch 455/500
Epoch 456/500
Epoch 457/500
Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 469/500
Epoch 470/500
Epoch 471/500
Epoch 472/500
Epoch 473/500
Epoch 

Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<keras.callbacks.History at 0x7f68f01481f0>

In [36]:
# Compare a random handful of test images (top row) with their
# reconstructions (bottom row).
n_to_show = 10
example_idx = np.random.choice(range(len(X_test)), n_to_show)
example_images = X_test[example_idx]

# Encode to latent space, then decode back to image space.
z_points = encoder.predict(example_images)
reconst_images = decoder.predict(z_points)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Top row: originals, each labelled with its rounded latent code.
for i, img in enumerate(example_images):
    ax = fig.add_subplot(2, n_to_show, i + 1)
    ax.axis('off')
    ax.text(0.5, -0.35, str(np.round(z_points[i], 1)), fontsize=10, ha='center', transform=ax.transAxes)
    ax.imshow(img)

# Bottom row: reconstructions, aligned under their originals.
for i, img in enumerate(reconst_images):
    ax = fig.add_subplot(2, n_to_show, n_to_show + i + 1)
    ax.axis('off')
    ax.imshow(img)

In [37]:
# Latent codes produced by the encoder for the sampled test images.
z_points

array([[-6.983255  , -0.17334947,  5.9279175 ,  2.1353912 ],
       [-7.2512293 , -0.5766503 , -1.1700253 ,  1.338737  ],
       [ 6.048606  ,  2.0829217 ,  3.240875  ,  5.7680807 ],
       [-3.9236457 ,  4.792412  ,  5.8934507 , -1.4251614 ],
       [ 5.4448347 , -0.17101339, -5.1700573 ,  2.0893226 ],
       [ 1.1456046 ,  0.89082515, -2.3790147 , -0.49088526],
       [ 2.1583476 , -6.03071   , -0.20910612,  1.7462301 ],
       [ 4.409732  , -0.7090642 , -3.4749818 , -3.2978754 ],
       [-3.2824917 , -1.7504213 , -1.6584668 ,  7.558395  ],
       [-2.1981242 , -4.5862494 , -1.2531925 ,  5.118972  ]],
      dtype=float32)

In [38]:
# Decode one hand-picked latent point and display the generated image.
z_s = np.array([[  7.350558  , -13.799404  ,  -2.4654112 ,  -6.299609  ]])
reconst_images = decoder.predict(z_s)
plt.imshow(reconst_images[0])

<matplotlib.image.AxesImage at 0x7f6910250a90>

### Pick two punks and morph between them (the classic face-morph trick)

In [39]:
# Source images of the two punks.  They are kept for reference only; the
# interpolation below works on latent coordinates.
punk_x_img = example_images[0]
punk_y_img = example_images[8]

# Latent end points of the morph.  Other pairs tried earlier:
#   [7.350558, -13.799404, -2.4654112, -6.299609] -> [-10.365253, -11.024603, 16.105639, 3.2864196]
#   [10.241212, 8.368749, 4.796343, -5.5694203]   -> [12.844572, 1.8600631, 10.4856, -6.2996225]
punk_x_coords = [-3.9236457 ,  4.792412  ,  5.8934507 , -1.4251614 ]
punk_y_coords = [ 6.048606  ,  2.0829217 ,  3.240875  ,  5.7680807 ]

# Number of interpolation steps (frames) between the two punks.
M = 20

# Linear interpolation of every latent dimension at once; one row per step.
z_s = np.linspace(punk_x_coords, punk_y_coords, M)
reconst_images = decoder.predict(z_s)

# Size the grid from M so that changing the number of steps cannot overflow it.
n_cols = 10
n_rows = -(-M // n_cols)  # ceiling division

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(M):
    img = reconst_images[i]
    ax = fig.add_subplot(n_rows, n_cols, i+1)
    ax.axis('off')
    ax.imshow(img)

In [40]:
# Frame sequence for the animation: the morph played forwards, then the same
# frames in reverse order so the loop ends where it started.
imgs_reversed = np.array(reconst_images[::-1])
imgs = np.concatenate((reconst_images, imgs_reversed))

In [41]:
import matplotlib.animation as animation
%matplotlib qt

imagelist = imgs
fig = plt.figure()

# AxesImage that every animation frame updates in place.
# NOTE(review): no cmap/vmin/vmax is passed; the frames appear to be
# multi-channel colour images, for which matplotlib ignores those settings.
im = plt.imshow(imagelist[0])


def updatefig(j):
    """Display frame ``j`` of ``imagelist`` and return the modified artists."""
    im.set_array(imagelist[j])
    return [im]


# One animation frame per image, so the frame count follows the data instead
# of being hard-coded.
ani = animation.FuncAnimation(fig, updatefig, frames=len(imagelist),
                              interval=48)
plt.show()

In [29]:
# Export the animation as a GIF at 25 frames per second using Pillow.
# NOTE(review): assumes the ../tmp directory already exists -- confirm.
writer = animation.PillowWriter(fps=25)  
ani.save("../tmp/demo.gif", writer=writer) 

## Reducing Dimensions and staying true to the colorpunx

In [123]:
import pickle

class CryptoPunksColorMap:
    """Two-way mapping between palette colours and small integer indices.

    The palette is read from a pickled mapping whose keys are colour strings.
    The keys are assumed to look like the ``str()`` of a pixel array, e.g.
    ``'[0. 0. 0. 1.]'``: the surrounding brackets are stripped and the rest is
    parsed as space-separated floats.

    Attributes:
        colors_count: the mapping loaded from the pickle file.
        colors: list of the colour-string keys; the list position is the index.
        color_d: dict from colour string to its index in ``colors``.
        np_colors: list of float arrays, one per entry of ``colors``.
    """

    def __init__(self, colors_path="../../cpunks-10k/data/_colors_count.pickle"):
        """Load the palette from ``colors_path`` (defaults to the project data file)."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(colors_path, 'rb') as f:
            self.colors_count = pickle.load(f)
        self.colors = list(self.colors_count.keys())
        self.color_d = {color: index for index, color in enumerate(self.colors)}
        self.np_colors = [np.fromstring(color[1:-1], float, sep=' ') for color in self.colors]

    def flatten(self, img):
        """Replace every pixel of ``img`` by its palette index.

        Each pixel is looked up by its ``str()`` form, so it must match a
        palette key exactly; an unknown colour raises ``KeyError``.
        Returns a 2-D ``uint8`` array.
        NOTE(review): uint8 wraps silently if the palette has more than 256 entries.
        """
        indices = [[self.color_d[str(pixel)] for pixel in row] for row in img]
        return np.array(indices).astype(np.uint8)

    def unflatten(self, img):
        """Inverse of ``flatten``: turn an index image back into colour values."""
        # Use the pre-parsed palette instead of re-parsing a string per pixel.
        return np.array([[self.np_colors[index] for index in row] for row in img])

# Shared colour-map instance; constructing it reads the palette pickle from disk.
cpcm = CryptoPunksColorMap()

The background color in the generated punks is close to, but not exactly equal to, that of the original CryptoPunks. Here we define a distance-based measure that snaps every pixel of a generated image to the nearest color in the CryptoPunks palette after image generation.

In [140]:
def distance(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` in any dimension."""
    delta = x - y
    return np.sqrt(np.sum(delta * delta))

def nearest_neighbor(color, colors):
    """Return the entry of ``colors`` with the smallest distance to ``color``.

    When several entries are equally close, the first one wins.
    """
    distances = [distance(candidate, color) for candidate in colors]
    closest_index = np.argmin(distances)
    return colors[closest_index]

# Spot check: snap the pixel at row 12, column 12 to its nearest palette colour.
# NOTE(review): `img` is not defined in this cell; it is whatever an earlier
# plotting loop left bound -- confirm which image is intended.
nearest_neighbor(img[12][12], cpcm.np_colors)

array([0.85882354, 0.69411767, 0.5019608 , 1.        ])

In [141]:
# Snap every pixel of `img` to its nearest palette colour (nested list of rows).
# NOTE(review): `img` comes from an earlier cell's loop variable -- confirm the source image.
normalized_img = [[nearest_neighbor(pixel, cpcm.np_colors) for pixel in row] for row in img]

In [142]:
# Display the palette-normalised image.
plt.imshow(normalized_img)

<matplotlib.image.AxesImage at 0x7f68bb2a1700>