In [1]:
import functools
import os
import pathlib
import sys

import cv2
import git
import sympy
# Lower TensorFlow's C++ log verbosity. NOTE(review): TensorFlow presumably reads
# this at import time, so it must stay ahead of any TensorFlow import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Resolve the git repository root by searching upward from the current
# directory, then put it on sys.path so project-local modules are importable.
PROJ_ROOT_PATH = pathlib.Path(
    git.Repo('.', search_parent_directories=True).working_tree_dir
)
PROJ_ROOT = os.fspath(PROJ_ROOT_PATH)
if PROJ_ROOT not in sys.path:
    sys.path.append(PROJ_ROOT)

print(f"Project Root Directory: {PROJ_ROOT}")

Project Root Directory: /repos/drl_csense


In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from scipy.linalg import hadamard
from sklearn.linear_model import Lasso

In [4]:
import imageio
import ipyplot

In [5]:
# Lasso hyper-parameters shared by every frame reconstruction.
alpha = 1e-4       # weight of the L1 penalty on the Hadamard coefficients
# Iteration cap for Lasso's coordinate-descent solver; scikit-learn expects an int.
# NOTE(review): the recorded run output repeatedly points at
# cd_fast.enet_coordinate_descent (likely ConvergenceWarnings) — consider raising this.
max_iter = 1_000

In [6]:
def add_square_borders(image, new_dim):
    """Zero-pad ``image`` so its first two axes become ``new_dim`` x ``new_dim``.

    Padding is split as evenly as possible between opposite sides; when the
    total padding along an axis is odd, the extra row/column goes to the
    bottom/right. The input dtype is preserved and any trailing channel axis
    is left unpadded.

    Args:
        image: array of shape (H, W) or (H, W, C).
        new_dim: target side length; must be >= H and >= W.

    Returns:
        A new array of shape (new_dim, new_dim) or (new_dim, new_dim, C).

    Raises:
        ValueError: if ``image`` has fewer than 2 dimensions or does not fit
            inside a ``new_dim`` x ``new_dim`` square.
    """
    image = np.asarray(image)
    if image.ndim < 2:
        raise ValueError(f"expected an image with at least 2 dimensions, got shape {image.shape}")

    org_height, org_width = image.shape[:2]
    delta_h = new_dim - org_height
    delta_w = new_dim - org_width
    if delta_h < 0 or delta_w < 0:
        raise ValueError(
            f"image of size {org_height}x{org_width} does not fit in a {new_dim}x{new_dim} square"
        )

    top, bottom = delta_h // 2, delta_h - delta_h // 2
    left, right = delta_w // 2, delta_w - delta_w // 2
    # Only the spatial axes are padded; channel axes (if any) get (0, 0).
    pad_width = [(top, bottom), (left, right)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode="constant", constant_values=0)
    

In [7]:
@functools.lru_cache(maxsize=1)
def _hadamard_matrix(n):
    """Return the n x n Sylvester Hadamard matrix as int8, cached between calls.

    Entries are +/-1, so int8 is lossless and 8x smaller than the default
    int64 (for n = 128*128: ~268 MB instead of ~2.1 GB). Only the most recently
    requested size is kept. Raises ValueError if n is not a power of two.
    """
    return hadamard(n, dtype=np.int8)


def _fwht(values):
    """Unnormalised fast Walsh-Hadamard transform in natural (Sylvester) order.

    Computes ``hadamard(n) @ values`` in O(n log n) without building the
    matrix; ``n = values.size`` must be a power of two.
    """
    out = np.array(values, dtype=np.float64).ravel()  # private contiguous copy
    half = 1
    while half < out.size:
        # Each row of this view is one butterfly group [a_0..a_{h-1} | b_0..b_{h-1}].
        groups = out.reshape(-1, 2 * half)
        a = groups[:, :half].copy()
        b = groups[:, half:].copy()
        groups[:, :half] = a + b
        groups[:, half:] = a - b
        half *= 2
    return out


def reconstruct_frame_hadamard(noisyframe, lasso):
    """Reconstruct a frame from its nonzero pixels via sparse coding in a Hadamard basis.

    Zero-valued pixels are treated as unobserved. Pixels are flattened in
    column-major order and modelled as ``x = H @ c + intercept`` with ``H`` the
    Sylvester Hadamard matrix of size ``ny * nx``. ``lasso`` is fitted on the
    rows of ``H`` belonging to the observed pixels, and every pixel is then
    synthesised from the fitted coefficients.

    Args:
        noisyframe: 2-D array of shape (ny, nx); ``ny * nx`` must be a power of two.
        lasso: scikit-learn ``Lasso``-like estimator exposing ``fit``, ``coef_``
            and ``intercept_``; it is (re)fitted in place, so ``warm_start``
            carries state across calls.

    Returns:
        float32 array of shape (ny, nx).

    Raises:
        ValueError: if ``ny * nx`` is not a power of two (from ``hadamard``).
    """
    ny, nx = noisyframe.shape
    A = _hadamard_matrix(nx * ny)
    # Column-major flatten; same ordering as noisyframe.T.reshape(-1).
    flat_frame = np.asarray(noisyframe).ravel(order="F")

    idx_nonzero = np.flatnonzero(flat_frame)

    # Compressed sample: the observed pixel values ...
    b = flat_frame[idx_nonzero]
    # ... and the matching rows of (sampling matrix @ transform matrix).
    Ac = A[idx_nonzero, :]

    # LASSO optimization
    lasso.fit(Ac, b)

    # Synthesise all pixels: H @ coef plus the fitted intercept (which absorbs
    # the unpenalised DC offset when fit_intercept=True; 0.0 otherwise).
    coeff = np.asarray(lasso.coef_, dtype=np.float64).ravel()
    flat_reconstruction = _fwht(coeff) + lasso.intercept_

    # Undo the column-major flatten using this frame's own shape.
    return flat_reconstruction.reshape((ny, nx), order="F").astype(np.float32)

In [8]:
# Candidate environments and observation-noise levels; the reconstruction
# loop reads ./gifs/{env_id}--noise_{noise}--obs.gif for each pair it runs.
ENV_LIST = ["BreakoutNoFrameskip-v4", "BankHeistNoFrameskip-v4", "WizardOfWorNoFrameskip-v4"]

NOISE_LIST = [0.0, 0.1, 0.2, 0.3, 0.4]

# Frame size used for the Hadamard reconstruction. scipy's `hadamard` needs
# width * height to be a power of two (128 * 128 = 2**14).
width = 128
height = 128

In [9]:
# Reconstruct the first MAX_FRAMES frames of each noisy observation GIF with
# Hadamard-basis compressed sensing and save the result as a new GIF.
# Only a subset of ENV_LIST / NOISE_LIST is processed; widen these to run all.
ENVS_TO_RUN = ["BankHeistNoFrameskip-v4"]
NOISES_TO_RUN = [0.3, 0.2, 0.1]
MAX_FRAMES = 50
LASSO_SEED = 0  # selection="random" is stochastic; seed it for reproducible GIFs

for env_id in ENVS_TO_RUN:
    for noise in NOISES_TO_RUN:
        # Fresh solver per (env, noise); warm_start reuses the previous frame's
        # coefficients as the starting point within a single GIF.
        lasso = Lasso(alpha=alpha, max_iter=int(max_iter), warm_start=True,
                      selection="random", random_state=LASSO_SEED)
        print("--"*10)
        print(f"{env_id=}; {noise=}")
        obs_gif_file = f"./gifs/{env_id}--noise_{noise}--obs.gif"
        obs_gif_frames = imageio.mimread(obs_gif_file)

        reconstructed_frame_list = []
        for i, gif_frame in enumerate(obs_gif_frames[:MAX_FRAMES]):
            print(f"Processing Frame: {i}", end='\x1b[1K\r')
            # Keep a single channel (frames may be multi-channel or already 2-D)
            # and normalize to [0, 1].
            channel = gif_frame[:, :, 0] if gif_frame.ndim == 3 else gif_frame
            frame = add_square_borders(channel / 255, width)
            reconstructed_frame = reconstruct_frame_hadamard(frame, lasso)
            # Denormalize: min-max rescale to [0, 255]. A constant frame has zero
            # range (0/0 -> NaN), so map it to black instead.
            lo, hi = np.min(reconstructed_frame), np.max(reconstructed_frame)
            if hi > lo:
                scaled = 255 * (reconstructed_frame - lo) / (hi - lo)
            else:
                scaled = np.zeros_like(reconstructed_frame)
            # Explicit uint8 avoids imageio's lossy float->uint8 conversion path.
            reconstructed_frame_list.append(np.rint(scaled).astype(np.uint8))

        # Convert reconstructed frames to an animation.
        reconstructed_gif_file = f"./gifs/reconstructed/{env_id}--noise_{noise}--obs--reconstructed_hadamard.gif"
        os.makedirs(os.path.dirname(reconstructed_gif_file), exist_ok=True)
        imageio.mimsave(reconstructed_gif_file, reconstructed_frame_list, duration=100)

--------------------
env_id='BankHeistNoFrameskip-v4'; noise=0.3
Processing Frame: 0[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 5[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 7[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 8[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 13[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 14[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 18[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 19[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 20[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 22[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 28[1K

  model = cd_fast.enet_coordinate_descent(


--------------------[1K
env_id='BankHeistNoFrameskip-v4'; noise=0.2
Processing Frame: 0[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 5[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 14[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 18[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 19[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 22[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 24[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 39[1K

  model = cd_fast.enet_coordinate_descent(


--------------------[1K
env_id='BankHeistNoFrameskip-v4'; noise=0.1
Processing Frame: 0[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 5[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 18[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 20[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 24[1K

  model = cd_fast.enet_coordinate_descent(


Processing Frame: 49[1K