In [582]:
import numpy as np
import pandas as pd
import sys
import matplotlib.pyplot as plt

In [583]:
L = 10
stride = 2
W = [[[1, -1,0],[0,0,0],[0,0,0]], # 1 => 1.0
    [[0,1, -1],[0,0,0],[0,0,0]],
    [[1,0,0],[-1,0,0],[0,0,0]],
    [[0,1,0],[0,-1,0],[0,0,0]],
    [[0,0,0],[1,-1,0],[0,0,0]],
    [[0,0,0],[0,1,-1],[0,0,0]],
    [[0,0,0],[1,0,0],[-1,0,0]],
    [[0,0,0],[0,1,0],[0,-1,0]],
    [[0,0,0],[0,0,0],[0,0,0]],
    [[0,0,0],[0,0,0],[0,0,0]]]

bias_hidden = 0     # b
bias_visible = 0   # c

V = np.matrix([[-1 for _ in range(L)] for _ in range(L)])
H = [np.matrix([[0 for _ in range(L//2)] for _ in range(L//2)])  for _ in range(len(W)) ]

num_samples = 10000
samples = []
min_energy = float('inf')
min_energy_sample = V.copy()

In [584]:
def sample_with_prob(arr):
  random = np.random.uniform(-1,1, arr.shape)
  array = np.ones(arr.shape)
  array[np.where(arr<random)] = -1
  return np.array(array)

# Convolution convolution(W[i],V,L,stride)
def convolution_v_to_h(Wi,V,L,stride):
  res = [[] for _ in range(L//2)]
  #print("W \n", Wi)
  V_pad = np.pad(V, (0,stride), 'wrap')
  #print("V_pad \n", V_pad)
  for i in range(0, L, stride): # row
    for j in range(0, L, stride): # col
      conv_res = np.multiply(np.array(Wi), V_pad[i:i+3,j:j+3])
      res[i//stride].append(conv_res.sum())
  #print("res", res)
  return np.array(res)

def convolution_h_to_v(Wi, H_zp, L):
  #print("W", Wi)
  res = [[] for _ in range(L)]
  for i in range(0, L, 1): # row
    for j in range(0, L, 1): # col
      conv_res = np.multiply(Wi, H_zp[i:i+3,j:j+3])
      res[i].append(conv_res.sum())
  return np.array(res)

def zero_pad(Hi):
  res = np.dstack((np.zeros_like(Hi),Hi)).reshape(Hi.shape[0],-1)
  for i in range(0,len(res)*2,2):
    res = np.insert(res,i,0,axis=0)
  res = np.pad(res, (1,1), 'wrap')
  return np.array(res)


# Energy Function. 2d array??
def energy_function(W,h,v,L,stride,b,c):
  return (  -sum([np.multiply(h[k].T, convolution_v_to_h(W[k],v,L, stride)).sum() for k in range(len(W))]) 
            -b*sum([h_k.sum() for h_k in h]) - c*sum([sum(v_i) for v_i in v])   )

In [585]:
print("H", len(H), len(H[0]))
for iter in range(200):
    if iter%2 == 0: # V to H
        #print("V", V)
        for i in range(len(W)):
            W_flipped = np.flip(W[i], axis=1)
            W_flipped = np.flip(W[i], axis=0)
            H_conv = convolution_v_to_h(W_flipped, V, L, stride)
            #print("W flipped \n", W_flipped)
            #print("H",i,", \n", H_conv)
            tanh_vec = np.array(np.tanh(H_conv + bias_hidden))
            H[i] = sample_with_prob(tanh_vec)
            #print("H sampled",i,", \n", H[i])
    else: # H to V
        V_conv_sum = np.zeros(V.shape)
        for i in range(len(W)):
            H_zp = zero_pad(H[i])
            #print("H",i,", \n", H_zp)

            V_conv = convolution_h_to_v(W[i] , H_zp, L)
            #print("Vconv" ,V_conv)
            
            V_conv_sum = V_conv_sum + V_conv
        #print("V conv sum\n" , V_conv_sum)
        tanh_vec = np.array(np.tanh(V_conv_sum + bias_visible))
        #print("tanh vec", tanh_vec)
        V = sample_with_prob(tanh_vec)
        #print(V)
        samples.append(V.copy())
        #istr = [str(i) for i in V]
        #samples.append(",".join(istr))
        energy = energy_function(W,H,V,L,stride,bias_visible,bias_hidden)
        if energy < min_energy:
            min_energy_sample = V.copy()
            min_energy = energy

H 10 5


In [586]:

freq = 0
for item in samples:
    if np.array_equal(item, min_energy_sample):
        freq += 1
print(freq)
print("\nMin Energy : ", min_energy, " is at \n", min_energy_sample)

1

Min Energy :  -88.0  is at 
 [[-1.  1. -1.  1. -1. -1.  1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.  1.  1. -1. -1.]
 [ 1. -1.  1.  1.  1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.  1. -1.  1. -1.]
 [-1.  1. -1.  1. -1.  1. -1.  1. -1.  1.]
 [ 1. -1.  1. -1.  1. -1.  1. -1. -1.  1.]
 [ 1. -1.  1.  1. -1.  1. -1.  1.  1. -1.]]
