In [1]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as la

from src.structured_random_features.src.models.weights import V1_weights

# Packages for fft and fitting data
from scipy import fftpack as fft
from sklearn.linear_model import Lasso

# Package for importing image representation
from PIL import Image, ImageOps

from src.V1_Compress import generate_Y, compress
import pandas as pd
import itertools
import dask
from dask.distributed import Client, progress
import seaborn as sns
import time
import os.path

In [2]:
def opt_hyperparams(data):
    """Rank hyperparameter runs by error and return the best one.

    ``data`` is anything ``pd.DataFrame`` accepts, and it must contain an
    ``'error'`` column. The five lowest-error rows are printed, and the
    single lowest-error row is returned as a one-row DataFrame.

    TODO: average the error over repetitions (e.g. with a pandas groupby)
    before ranking; currently every run is ranked individually.
    """
    ranked = pd.DataFrame(data).sort_values(by='error', ascending=True)

    top_five = ranked.head(5)
    print(top_five)

    best_row = ranked.head(1)
    return best_row

In [3]:
def run_sim(rep, alp, num, sz, freq, img_arr):
    """Reconstruct ``img_arr`` from V1-weighted measurements and score the result.

    Parameters
    ----------
    rep : int
        Repetition index. It is not used in the computation; it only lets the
        caller schedule the same hyperparameters several times.
    alp : float
        Forwarded to ``compress`` as its third argument
        (presumably the Lasso penalty -- confirm against ``compress``).
    num : int or float
        Number of V1 weight vectors to generate; cast to ``int`` here because
        the caller may pass it as a float.
    sz, freq
        Forwarded unchanged to ``V1_weights`` as its third and fourth arguments.
    img_arr : array-like
        The image to compress. It must be 2-D once singleton axes are removed.

    Returns
    -------
    tuple
        ``(error, theta, reform, s)`` where ``theta``, ``reform`` and ``s`` are
        the three values returned by ``compress`` and ``error`` is the
        Frobenius norm of ``img_arr - reform`` divided by ``sqrt(n * m)``.

    Raises
    ------
    ValueError
        If ``img_arr`` is not 2-D after squeezing.
    """
    num = int(num)
    img_arr = np.squeeze(np.asarray(img_arr))
    dim = img_arr.shape
    if len(dim) != 2:
        raise ValueError(
            "img_arr must be 2-D after squeezing, got shape {}".format(dim)
        )
    n, m = dim

    # Generate V1 weights and the corresponding measurements y
    W = V1_weights(num, dim, sz, freq)
    y = generate_Y(W, img_arr)
    W_model = W.reshape(num, n, m)

    # Reconstruct, then score against the image that was actually passed in.
    # The earlier version subtracted from the notebook-level global `img`,
    # so the function depended on hidden kernel state. The cast to float
    # keeps an unsigned-integer image from wrapping around in the subtraction.
    theta, reform, s = compress(W_model, y, alp)
    residual = np.asarray(img_arr, dtype=float) - reform
    error = np.linalg.norm(residual, 'fro') / np.sqrt(m * n)

    return error, theta, reform, s

In [5]:
image_path_list = ['image/city_part2.png', 'image/city_part3.png']

# Hyperparameters that affect the results. The grid does not depend on the
# image, so it is built once rather than once per loop iteration.
alpha = np.logspace(-3, 3, 7)
rep = np.arange(10)
num_cell = [100, 200, 500]
cell_sz = [2, 5, 7]
sparse_freq = [1, 2, 5]
search_list = [rep, alpha, num_cell, cell_sz, sparse_freq]

# All combinations of hyperparameters to try
search = list(itertools.product(*search_list))
search_df = pd.DataFrame(search, columns=['rep', 'alp', 'num_cell', 'cell_sz', 'sparse_freq'])

# Use a single Dask client for the whole sweep and close it when done.
# Creating a new Client() per image without closing the previous one leaves
# the earlier cluster running (hence the "Perhaps you already have a cluster
# running?" warning in the recorded output).
with Client() as client:
    for image_path in image_path_list:
        # Per-image task list; `params` is not used in this cell but is kept
        # in case other cells rely on the name.
        delay_list = []
        params = []

        # Load the image as grayscale. The context manager closes the file
        # handle; `img` stays a module-level name because other cells may
        # read it.
        image_nm = os.path.splitext(os.path.basename(image_path))[0]
        with Image.open(image_path) as raw_img:
            img = ImageOps.grayscale(raw_img)
        img_arr = np.asarray(img)

        # Make sure the output directory exists before anything is computed,
        # so a long run cannot fail at the very last step.
        save_path = os.path.join("result", image_nm, "V1")
        os.makedirs(save_path, exist_ok=True)

        print(search_df.head())

        # Build one lazy task per hyperparameter combination
        for p in search_df.values:
            delay = dask.delayed(run_sim)(*p, img_arr)
            delay_list.append(delay)

        print('running dask completed')

        futures = dask.persist(*delay_list)
        print('futures completed')
        progress(futures)
        print('progressing futures')

        # Compute the result
        results = dask.compute(*futures)
        print('result computed')
        results_df = pd.DataFrame(results, columns=['error', 'theta', 'reform', 's'])

        # Add error onto parameter
        params_result_df = search_df.join(results_df['error'])

        # Save parameter/error data and the full results. The timestamp is
        # taken once so both files of a run share the same suffix.
        # NOTE(review): 'theta' and 'reform' hold arrays; to_csv writes their
        # string form -- verify these files can be loaded back as needed.
        stamp = "_".join(time.ctime().replace(":", "_").split())
        params_result_df.to_csv(os.path.join(save_path, "param_" + stamp + ".csv"))
        results_df.to_csv(os.path.join(save_path, "result_" + stamp + ".csv"))




   rep    alp  num_cell  cell_sz  sparse_freq
0    0  0.001       100        2            1
1    0  0.001       100        2            2
2    0  0.001       100        2            5
3    0  0.001       100        5            1
4    0  0.001       100        5            2


Perhaps you already have a cluster running?
Hosting the HTTP server on port 52027 instead


running dask completed
futures completed
progressing futures
result computed
   rep    alp  num_cell  cell_sz  sparse_freq
0    0  0.001       100        2            1
1    0  0.001       100        2            2
2    0  0.001       100        2            5
3    0  0.001       100        5            1
4    0  0.001       100        5            2


Perhaps you already have a cluster running?
Hosting the HTTP server on port 52115 instead


running dask completed
futures completed
progressing futures
result computed


In [None]:
# Rebuild the results table from `results`. The sweep loop overwrites
# `results` on every iteration, so this shows only the last image processed.
results_df = pd.DataFrame(results, columns=['error', 'theta', 'reform', 's'])
results_df

In [None]:
# Attach each run's error to its hyperparameter row, then print the five
# lowest-error rows and return the single best one. Runs are ranked
# individually; errors are not averaged over `rep`.
temp = search_df.join(results_df['error'])
opt_hyperparams(temp)

In [None]:
# Path of another image to work with (this one is a .jpg; the images in the
# sweep above are .png). No visible cell uses this value yet.
image_path = 'image/tree_part1.jpg'