In [None]:
%matplotlib inline
import sys, time
import pyopencl as cl
import tables
import numpy as np
import scipy.misc as scp
import matplotlib.pyplot as plt

In [None]:
def init_data(population_size, resolution, seed=None):
    """Create the initial population as an ``int32`` array of 16 fields per unit.

    Column layout of the returned array:

    * 0-2  -- position, drawn uniformly from ``[0, resolution[k])`` per axis
    * 3-5  -- velocity, initialised to zero
    * 6    -- power, drawn uniformly from ``[max_power // 4, max_power)``
    * 7    -- mass, drawn uniformly from ``[max_mass // 2, max_mass)``
    * 8-15 -- genomic weights (to be used as bytestrings), initialised to zero

    Args:
        population_size: number of rows (units) to generate.
        resolution: sequence ``(res_x, res_y, res_z)`` giving the exclusive
            upper bound of each position coordinate.
        seed: optional seed for a private ``np.random.RandomState``. When
            ``None`` (the default) the global ``np.random`` state is used,
            exactly as before.

    Returns:
        ``numpy.ndarray`` of shape ``(population_size, 16)`` and dtype int32.
    """
    max_power = 128
    max_mass = 128

    # Both objects expose the same ``randint`` signature.
    rng = np.random if seed is None else np.random.RandomState(seed)

    population = np.zeros((population_size, 16), dtype=np.int32)

    # initial position
    for axis in range(3):
        population[:, axis] = rng.randint(resolution[axis], size=population_size)

    # velocity + position provides vector - i.e. initial debt
    population[:, 3:6] = 0

    # mass, power randomized? should be part of genome - eventually?
    # Floor division keeps the bounds integral on Python 3 as well as Python 2.
    population[:, 6] = rng.randint(max_power // 4, max_power, size=population_size)
    population[:, 7] = rng.randint(max_mass // 2, max_mass, size=population_size)

    # genomic weights to be used as bytestrings
    population[:, 8:16] = 0

    return population

In [None]:
def im(population, resolution, col=None):
    """Paint every unit of ``population`` onto an RGB image.

    Args:
        population: 2-D integer array; column 0 is used as the first image
            index and column 1 as the second.
        resolution: sequence whose first two entries give the image shape.
        col: indexable with three channel values; defaults to white
            ``[255, 255, 255]``.

    Returns:
        ``uint8`` array of shape ``(resolution[0], resolution[1], 3)`` that is
        zero everywhere except at occupied cells, which carry ``col``.
    """
    colour = np.asarray([255, 255, 255]) if col is None else col

    size_x, size_y = resolution[0], resolution[1]
    canvas = np.zeros((size_x, size_y, 3), dtype=np.uint8)

    xs = population[:, 0]
    ys = population[:, 1]
    # Write one colour channel at a time for every occupied cell.
    for channel in range(3):
        canvas[xs, ys, channel] = colour[channel]

    return canvas

In [None]:
def draw(population, resolution, count):
    """Write one frame of the simulation to ``./out/image/frame_NNNNN.png``.

    The frame is a boolean occupancy mask of shape
    ``(resolution[0], resolution[1])`` that is True wherever a unit sits
    (column 0 / column 1 of ``population`` are the two indices).

    Args:
        population: 2-D integer array of units.
        resolution: sequence whose first two entries give the mask shape.
        count: frame number, zero-padded to five digits in the file name.

    NOTE(review): ``scipy.misc.imsave`` is not available in recent SciPy
    releases -- confirm the SciPy version this notebook is pinned to.
    """
    import os

    out_path = "./out/image/frame_{0:05d}.png".format(count)

    # Create the output directory on demand so the first frame does not fail
    # with an IOError on a fresh checkout.
    out_dir = os.path.dirname(out_path)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    flat_world = np.zeros((resolution[0], resolution[1]), dtype=bool)
    flat_world[population[:, 0], population[:, 1]] = True
    scp.imsave(out_path, flat_world)

In [None]:
def form_world(population, resolution):
    """Build the 3-D occupancy grid for the current population.

    Args:
        population: 2-D integer array; columns 0, 1 and 2 are used as the
            three grid indices of each unit.
        resolution: sequence ``(res_x, res_y, res_z)`` giving the grid shape.

    Returns:
        ``int32`` array of shape ``resolution`` holding 1 in every occupied
        cell and 0 elsewhere.
    """
    grid_shape = (resolution[0], resolution[1], resolution[2])
    occupancy = np.zeros(grid_shape, dtype=np.int32)

    xs, ys, zs = population[:, 0], population[:, 1], population[:, 2]
    occupancy[xs, ys, zs] = 1
    return occupancy

In [None]:
def save_as_hd5(population, resolution, filepath=None):
    """Dump ``population`` and ``resolution`` into an HDF5 file via PyTables.

    The file is opened in write mode (replacing any existing file) with the
    title "Starlings" and receives two root-level arrays named
    ``population`` and ``resolution``.

    Args:
        population: array-like stored as ``/population``.
        resolution: array-like stored as ``/resolution``.
        filepath: destination path; a falsy value selects ``./starlings.h5``.
    """
    h5_out = tables.open_file(filepath or './starlings.h5', mode='w', title="Starlings")
    try:
        root = h5_out.root
        h5_out.create_array(root, "population", population)
        h5_out.create_array(root, "resolution", resolution)
    finally:
        # Always close the handle, even when writing one of the arrays fails.
        h5_out.close()

In [None]:
def read_from_hd5(filepath=None):
    """Load the ``/population`` and ``/resolution`` arrays from an HDF5 file.

    Args:
        filepath: source path; a falsy value selects ``./starlings.h5``.

    Returns:
        Tuple ``(population, resolution)`` as read from the two nodes.
    """
    h5_in = tables.open_file(filepath or './starlings.h5', mode='r')
    try:
        population = h5_in.get_node("/population").read()
        resolution = h5_in.get_node("/resolution").read()
    finally:
        # Always close the handle, even when a node is missing or unreadable.
        h5_in.close()
    return population, resolution

In [None]:
class OpenCl(object):
    """Owns the PyOpenCL context, command queue and the program built from
    ``./kernal.cl``, and runs that program's ``knn`` kernel on a population.
    """

    def __init__(self):
        """Create the context/queue and build the kernel program."""
        self.ctx, self.queue = self.cl_init()
        self.program = self.cl_load_program("./kernal.cl")

    def cl_init(self):
        """Return a ``(context, command_queue)`` pair.

        The context comes from ``cl.create_some_context()``, so device
        selection is left to PyOpenCL.
        """
        ctx = cl.create_some_context()
        queue = cl.CommandQueue(ctx)
        return ctx, queue

    def cl_load_program(self, filepath):
        """Read the OpenCL source at ``filepath`` and return the built program."""
        # ``with`` closes the source file even if reading fails.
        with open(filepath, 'r') as f:
            fstr = f.read()
        return cl.Program(self.ctx, fstr).build()

    def cl_load_data(self, population, world):
        """Allocate the device-side objects for one kernel launch.

        Args:
            population: host array copied into a read-only buffer.
            world: host array turned into a read-only image.

        Returns:
            Tuple ``(population_cl, world_cl, out)`` where ``out`` is a
            write-only buffer of ``population.nbytes`` bytes.
        """
        mf = cl.mem_flags
        out = cl.Buffer(self.ctx, mf.WRITE_ONLY, population.nbytes)
        population_cl = cl.Buffer(self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=population)
        world_cl = cl.image_from_array(self.ctx, world, mode="r")
        return population_cl, world_cl, out

    def execute(self, num, population, world, inner_rad, outer_rad):
        """Run the ``knn`` kernel once and return the resulting population.

        Args:
            num: global work size (one work item per value of ``num``).
            population: host array uploaded as the kernel's input buffer.
            world: 3-D host array uploaded as an image; its three dimensions
                are also passed to the kernel as ``int32`` scalars.
            inner_rad, outer_rad: passed through to the kernel unchanged.
                NOTE(review): callers pass ``np.int32`` values; the kernel
                signature is not visible here -- confirm the expected type.

        Returns:
            A new array shaped like ``population`` filled from the kernel's
            output buffer.
        """
        population_cl, world_cl, out = self.cl_load_data(population, world)
        try:
            world_x = np.int32(world.shape[0])
            world_y = np.int32(world.shape[1])
            world_z = np.int32(world.shape[2])

            ret = np.zeros_like(population)

            global_size = (num,)
            # No explicit work-group size is requested.
            local_size = None

            kernalargs = (
                population_cl,
                out,
                world_cl,
                world_x, world_y, world_z,
                inner_rad, outer_rad,
            )

            self.program.knn(self.queue, global_size, local_size, *kernalargs).wait()
            cl.enqueue_copy(self.queue, ret, out)
            return ret
        finally:
            # Explicit release of the device objects (the original author
            # flagged this as possibly superstitious); done in ``finally`` so
            # a failing launch or copy does not leave them allocated.
            out.release()
            population_cl.release()
            world_cl.release()

In [None]:
def plot_(starlings, resolution, count,
          tracked=((1066, 'red'), (1067, 'green'), (1068, 'magenta'))):
    """Plot the whole population and highlight the motion of a few units.

    For every tracked unit three things are drawn on top of the population
    image: a line from its current position to its last position, a
    translucent diamond at the last position, and a solid dot at the
    current position. The last position is computed as current position
    (columns 0, 1) minus columns 3, 4. Column 1 is plotted on the horizontal
    axis and column 0 on the vertical axis.

    A previously disabled overlay drew the "intended" line/position using
    columns 11 (horizontal) and 10 (vertical) as offsets from the last
    position; it is not drawn here.

    Args:
        starlings: 2-D integer population array.
        resolution: passed to ``im`` to size the background image.
        count: value shown as the figure title.
        tracked: iterable of ``(row_index, colour)`` pairs to highlight.
            Defaults to the three units the notebook has always tracked.
            Indices outside the population are skipped instead of raising.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can chain ``.show()``.
    """
    plt.figure(figsize=(20, 20))

    plt.imshow(im(starlings, resolution), interpolation='nearest')

    # Collect (current, last, colour) for every tracked unit that exists.
    highlights = []
    for idx, colour in tracked:
        if not 0 <= idx < len(starlings):
            continue
        current = (starlings[idx, 0], starlings[idx, 1])
        last = np.asarray([starlings[idx, 0] - starlings[idx, 3],
                           starlings[idx, 1] - starlings[idx, 4]])
        highlights.append((current, last, colour))

    # Three separate passes keep the artist insertion order of the original
    # (all lines, then all last-position markers, then all current markers).

    ## delta pos
    for current, last, colour in highlights:
        plt.plot([current[1], last[1]],
                 [current[0], last[0]],
                 lw=3, c=colour)

    ## last position
    for current, last, colour in highlights:
        plt.scatter(last[1], last[0],
                    s=600, c=colour, alpha=.5, marker='d')

    ## current position
    for current, last, colour in highlights:
        plt.scatter(current[1], current[0], s=600, c=colour)

    plt.title(str(count), size=64)

    return plt

In [None]:
def filter_velocity(starlings, limit):
    """Keep only the units whose column-3 and column-4 values are both
    strictly inside ``(-limit, limit)``.

    Args:
        starlings: 2-D numeric population array.
        limit: exclusive bound on the absolute value of columns 3 and 4.

    Returns:
        The selected rows of ``starlings``, in their original order.
    """
    vx = starlings[:, 3]
    vy = starlings[:, 4]
    # Use the ``limit`` argument; the previous code read the notebook-global
    # ``fi`` and silently ignored the parameter.
    keep = (vx > -limit) & (vx < limit) & (vy > -limit) & (vy < limit)
    return starlings[keep]

In [None]:
# --- simulation parameters -------------------------------------------------
num = 640 * 480
resolution = [480, 640, 1]

# num = 1920 * 1080
# resolution = [1080, 1920, 1]

starlings = init_data(num, resolution)
world = form_world(starlings, resolution)

opcl = OpenCl()

count = 0

# Passed straight through to the kernel.
# NOTE(review): presumably neighbourhood radii -- confirm against kernal.cl.
inner_rad = np.int32(16)
outer_rad = np.int32(32)

# Number of simulation steps to run.
stop = 256

# One population snapshot per step, consumed by the analysis cells below.
frames = []

fi = 10

# Where to dump the population if the kernel fails; None selects the default
# path of save_as_hd5. (The handler used to read an undefined ``filepath``,
# which raised NameError and hid the real OpenCL error.)
crash_dump_path = None

# --- main loop: record, render, advance -----------------------------------
while count < stop:

    frames.append(starlings)

    draw(starlings, resolution, count)

    plot_(starlings, resolution, count).show()

    try:
        starlings = opcl.execute(num, starlings, world, inner_rad, outer_rad)
        #starlings = filter_velocity(starlings, fi)
    except cl.RuntimeError as e:
        print(str(e))
        # sys.maxsize disables summarisation; NaN is rejected by current NumPy.
        np.set_printoptions(threshold=sys.maxsize, linewidth=512)
        print(starlings[:, 0:15])
        save_as_hd5(starlings, resolution, crash_dump_path)
        raise

    # Rebuild the occupancy grid from the updated positions.
    world = form_world(starlings, resolution)

    count += 1
    

In [None]:
frames = np.asarray(frames)


def plot_frame(i):
    """Show a 2-D histogram of columns 4 and 3 of recorded frame ``i``.

    Column 4 is binned along the horizontal axis and column 3 along the
    vertical axis, using 141 bins per axis. Bins with a count of zero are
    masked out so they are not coloured.

    A disabled variant restricted the histogram to units whose column-3 and
    column-4 values both lie strictly inside (-50, 50), and fixed the axes
    to that range.
    """
    horizontal = frames[i, :, 4]
    vertical = frames[i, :, 3]

    counts, xedges, yedges = np.histogram2d(horizontal, vertical, bins=141)

    # histogram2d indexes its first argument along axis 0; transpose so that
    # rows correspond to the vertical axis when plotted.
    counts = counts.T

    # Mask pixels with a value of zero
    visible = np.ma.masked_where(counts == 0, counts)

    plt.figure(figsize=(20, 20))

    # Plot 2D histogram using pcolor
    plt.pcolormesh(xedges, yedges, visible, antialiased=True)
    plt.xlabel('x')
    plt.ylabel('y')
    colour_bar = plt.colorbar()
    colour_bar.ax.set_ylabel('Counts')

    plt.title(str(i), size=64)
    plt.show()


# One histogram per recorded frame.
n_frames = frames[:, 0, 0].shape[0]
for i in np.arange(n_frames):
    plot_frame(i)

In [None]:
# Split every recorded frame into "slow" and "fast" units and show each group
# as its own image. A unit is slow when columns 3 and 4 are both strictly
# inside (-fi, fi); every other unit is fast.
fi = 60
frames = np.asarray(frames)

slow_c = np.asarray([200,200,0])
fast_c = np.asarray([0,200,200])

for i in np.arange(frames[:,0,0].shape[0]):
    vx = frames[i,:,3]
    vy = frames[i,:,4]

    slow_mask = (vx > (-1 * fi)) & (vx < fi) & (vy > (-1 * fi)) & (vy < fi)

    s_ind = np.where(slow_mask)[0]
    # Take the complement of the slow set. The previous strict "<" / ">" test
    # dropped units with a component exactly equal to +/-fi from both images.
    f_ind = np.where(~slow_mask)[0]

    slow_star = frames[i,s_ind,:]
    plt.imshow(im(slow_star,resolution,col=slow_c), interpolation='nearest')
    plt.show()

    fast_star = frames[i,f_ind,:]
    plt.imshow(im(fast_star,resolution,col=fast_c), interpolation='nearest')
    plt.show()

## TODO - debugging & making a better separation/cohesion algo
1. use a filter to keep speeds down
   1. want to avoid velocities greater than the frame of reference
   1. i.e. cohesion reduces to smallest possible localities
1. get cohesion and separation working separately
   1. cohesion already works - with a speed filter.
   1. separation does NOT work - it's really just a flavor of cohesion that rewards distance. Not all the way there yet
1. translate the discrete algorithms into a continuity that may look something like:
$$\sum_{0, size.retina} reaction.given(distance.to.observation, relation.to.velocity.vector)$$