In [1]:
import os
import sys
import git
import pathlib

# Quiet TensorFlow's native (C++) log output; set here, before any later cell
# imports modules that may load TensorFlow.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Resolve the root of the enclosing git checkout and make it importable, so the
# project-local `lib` package used by later cells can be found.
PROJ_ROOT_PATH = pathlib.Path(
    git.Repo('.', search_parent_directories=True).working_tree_dir
)
PROJ_ROOT = os.fspath(PROJ_ROOT_PATH)
if PROJ_ROOT not in sys.path:
    sys.path.append(PROJ_ROOT)

print("Project Root Directory: " + PROJ_ROOT)

Project Root Directory: /repos/drl_csense


In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import scipy.fftpack as spfft
from sklearn.linear_model import Lasso

In [4]:
import gymnasium as gym

In [5]:
from lib.folder_paths import makeget_logging_dir
from lib.env_utils import AtariWrapper_NoisyFrame, AtariWrapper_Compressed, make_atari_env_Custom_VecFrameStack

In [6]:
from lib.folder_paths import get_exp_name_tag, deconstruct_exp_name

In [7]:
import imageio
import ipyplot

In [8]:
def idct2(x):
    """Orthonormal 2-D inverse DCT of a 2-D array.

    Applies the 1-D inverse DCT (``norm='ortho'``) along axis 1, then along
    axis 0, i.e. returns ``D @ x @ D.T`` for the orthonormal IDCT matrix ``D``.
    """
    along_rows = spfft.idct(x, norm='ortho', axis=1)
    return spfft.idct(along_rows, norm='ortho', axis=0)


# Frame size the DCT basis below is built for.
nx = 84
ny = 84

# Dense inverse-DCT synthesis matrix acting on column-major flattened frames:
#   A @ S.flatten(order="F") == idct2(S).flatten(order="F")
# Shape is (nx*ny, nx*ny); at 84x84 that is 7056**2 float64 values (~400 MB).
A = np.kron(
    spfft.idct(np.eye(nx), norm='ortho', axis=0),
    spfft.idct(np.eye(ny), norm='ortho', axis=0),
)

In [9]:
def reconstruct_frame(frame, alpha, max_iter, random_state=0):
    """Reconstruct a sparsely sampled frame via Lasso in the 2-D DCT basis.

    Non-zero pixels of ``frame`` are treated as the measured samples and zero
    pixels as missing. A Lasso fit finds sparse DCT coefficients (plus an
    intercept) that reproduce the measured pixels. The returned frame is the
    inverse 2-D DCT of those coefficients plus that intercept.

    Parameters
    ----------
    frame : 2-D array of shape (ny, nx)
        Frame to reconstruct; its size must match the module-level basis ``A``.
    alpha : float
        Lasso regularisation strength.
    max_iter : int or float
        Maximum number of coordinate-descent iterations (cast to ``int``).
    random_state : int, RandomState instance or None, default 0
        Seed for Lasso's random coordinate selection. Fixed by default so that
        re-running the notebook reproduces the same frames; pass ``None`` for
        unseeded behaviour.

    Returns
    -------
    ndarray of shape (ny, nx)
        The reconstructed frame. All zeros if ``frame`` has no non-zero pixels.

    Raises
    ------
    ValueError
        If ``frame`` does not have shape ``(ny, nx)``.
    """
    frame = np.asarray(frame)
    if frame.shape != (ny, nx):
        raise ValueError(
            f"frame must have shape {(ny, nx)} to match the DCT basis A, got {frame.shape}"
        )

    # Column-major flattening matches A's Kronecker layout:
    # A @ vec_F(S) == vec_F(idct2(S)).
    flat_frame = frame.flatten(order="F")
    # Zero-valued pixels are treated as unobserved; only non-zero pixels are measurements.
    idx_nonzero = np.flatnonzero(flat_frame)
    if idx_nonzero.size == 0:
        # Nothing was measured: Lasso would have no samples to fit, so return a blank frame.
        return np.zeros(frame.shape)

    # Measured samples and the matching rows of (sampling matrix @ inverse-DCT matrix).
    b = flat_frame[idx_nonzero]
    Ac = A[idx_nonzero, :]

    # 1-D target => coef_ is 1-D and intercept_ is a scalar.
    lasso = Lasso(
        alpha=alpha,
        max_iter=int(max_iter),
        selection="random",
        random_state=random_state,
    )
    lasso.fit(Ac, b)

    # coef_ is vec_F(S); reshape(nx, ny).T undoes the column stacking to give S (ny, nx).
    dct_coeffs = lasso.coef_.reshape(nx, ny).T

    # With fit_intercept=True (the Lasso default) the constant DC column is centred away,
    # so the frame mean lives in intercept_. A pixel's prediction is A_row @ coef + intercept,
    # so the intercept must be added back to every pixel of the reconstruction.
    return idct2(dct_coeffs) + lasso.intercept_

In [10]:
# Which saved run to read evaluation gifs from.
env_id = "BreakoutNoFrameskip-v4"
run_no, model_type = 0, "best"

# Experiment parameters used to look up the experiment name (see get_exp_name_tag below);
# the evaluation perturbation type is the same as the experiment's.
exp_param_type, exp_param_value = "noisy", 0.0
eval_param_type = exp_param_type

In [11]:
# Single-frame exploration (disabled): load one frame from an evaluation gif,
# show it, and compare reconstructions for several Lasso alpha values.
# Uncomment to re-run; the batch loop in a later cell fixes alpha = 1E-5.
# eval_param_value = 0.3

# # Get names and tags of experiment
# exp_name, exp_metaname, exp_tag = get_exp_name_tag(env_id, exp_param_type, exp_param_value)

# # Get directories
# models_dir, log_dir, gif_dir, image_dir = makeget_logging_dir(exp_name)

# # Load gif file
# obs_gif_file = f"{gif_dir}/{exp_name}-run_{run_no}--eval_{model_type}-{eval_param_type}_{eval_param_value}--obs.gif"

# obs_gif_frames = imageio.mimread(obs_gif_file)

# frame_no = 12
# gif_frame = obs_gif_frames[frame_no]

# frame = gif_frame[:,:,0]/255

# plt.imshow(frame, cmap="gray")
# plt.axis("off")

# for i,alpha in enumerate([1E-5, 1E-6, 1E-7]):
#     reconstructed_frame = reconstruct_frame(frame, alpha=alpha, max_iter=1E4)
#     plt.figure(i)
#     plt.title(f"{alpha=}")
#     plt.axis("off")
#     plt.imshow(reconstructed_frame,cmap="gray")

In [14]:
# Reconstruct the first `n_frames` observation frames of each evaluation gif and save
# them as a new gif in the current working directory (not in gif_dir).
alpha = 1E-5     # Lasso regularisation strength passed to reconstruct_frame
max_iter = 1E4   # Lasso iteration cap passed to reconstruct_frame
n_frames = 50    # frames reconstructed per gif (each frame is a full Lasso fit)
# File-name tag derived from alpha so the name cannot drift from the value used
# (1E-5 -> "1e-5", matching the previously hard-coded name).
alpha_tag = np.format_float_scientific(alpha, trim="-", exp_digits=1)


def _to_uint8_image(img):
    """Min-max rescale a 2-D array to 0..255 and round to uint8.

    A constant image (max == min) maps to all zeros instead of NaN.
    """
    img = np.asarray(img, dtype=float)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros(img.shape, dtype=np.uint8)
    return np.rint(255 * (img - lo) / (hi - lo)).astype(np.uint8)


for eval_param_value in reversed([0.0, 0.1, 0.2, 0.3, 0.4]):
    print()
    print(f"{eval_param_value=}")
    # Get names and tags of experiment
    exp_name, exp_metaname, exp_tag = get_exp_name_tag(env_id, exp_param_type, exp_param_value)

    # Get directories
    models_dir, log_dir, gif_dir, image_dir = makeget_logging_dir(exp_name)

    # Load gif file
    obs_gif_file = f"{gif_dir}/{exp_name}-run_{run_no}--eval_{model_type}-{eval_param_type}_{eval_param_value}--obs.gif"
    obs_gif_frames = imageio.mimread(obs_gif_file)

    reconstructed_frame_list = []
    for i, gif_frame in enumerate(obs_gif_frames[:n_frames]):
        # '\x1b[1K\r' clears the line and returns the cursor, giving an in-place counter.
        print(f"Processing Frame: {i}", end='\x1b[1K\r')
        # Take a single channel (frames may already be 2-D) and normalise to [0, 1].
        channel = gif_frame if gif_frame.ndim == 2 else gif_frame[:, :, 0]
        frame = channel / 255
        reconstructed_frame = reconstruct_frame(frame, alpha=alpha, max_iter=max_iter)
        reconstructed_frame_list.append(_to_uint8_image(reconstructed_frame))

    # Convert reconstructed frames to an animation
    reconstructed_gif_file = f"{eval_param_type}_{eval_param_value}--reconstructed-alpha_{alpha_tag}.gif"
    imageio.mimsave(reconstructed_gif_file, reconstructed_frame_list, duration=100)


eval_param_value=0.4
Processing Frame: 49[1K
eval_param_value=0.3
Processing Frame: 49[1K
eval_param_value=0.2
Processing Frame: 49[1K
eval_param_value=0.1
Processing Frame: 49[1K
eval_param_value=0.0
Processing Frame: 49[1K