In [None]:
from typing import Dict
import tempfile
from pathlib import Path
import numpy as np
from mlflow.tracking import MlflowClient

FPS = 10  # default GIF frame rate (frames per second)
# Dream GIFs play the first B dreamed sequences back to back,
# keeping only the first T frames of each sequence.
B = 5
T = 50

def download_artifact_npz(run_id, artifact_path) -> Dict[str, np.ndarray]:
    """Fetch an ``.npz`` artifact from an MLflow run and load every array in it.

    The file is downloaded into a temporary directory that is deleted before
    returning, so all arrays are read into memory while it still exists.

    Args:
        run_id: MLflow run identifier.
        artifact_path: Path of the ``.npz`` file inside the run's artifacts.

    Returns:
        Mapping from array name to the loaded array.
    """
    mlflow_client = MlflowClient()
    with tempfile.TemporaryDirectory() as scratch_dir:
        local_file = mlflow_client.download_artifacts(run_id, artifact_path, scratch_dir)
        with Path(local_file).open('rb') as handle:
            archive = np.load(handle)
            arrays = {name: archive[name] for name in archive.keys()}  # type: ignore
    return arrays

def encode_gif(frames, fps):
    """Encode a sequence of frames as an animated GIF by piping them through ffmpeg.

    Adapted from code by Danijar Hafner. Needs an ``ffmpeg`` executable on PATH.

    Args:
        frames: Indexable, non-empty sequence (list or array) of frames that all
            share one shape: (H, W, C) with C in {1, 3}, or (H, W) for grayscale.
            Each frame is sent as raw bytes with ffmpeg pixel format ``gray`` or
            ``rgb24``, which read one byte per channel, so frames should be uint8.
        fps: Frame rate used for both the raw input and the GIF output.

    Returns:
        The encoded GIF as ``bytes``.

    Raises:
        ValueError: If ``frames`` is empty or has an unsupported channel count.
        IOError: If ffmpeg exits with a non-zero status (the message contains
            the command line and ffmpeg's stderr).
    """
    # Copyright Danijar
    from subprocess import PIPE, Popen

    if len(frames) == 0:
        raise ValueError('encode_gif requires at least one frame')
    frame_shape = tuple(frames[0].shape)
    if len(frame_shape) == 2:  # grayscale frame without an explicit channel axis
        frame_shape += (1,)
    h, w, c = frame_shape
    pix_fmts = {1: 'gray', 3: 'rgb24'}
    if c not in pix_fmts:
        raise ValueError(f'encode_gif supports 1 or 3 channels, got {c}')

    rate = f'{fps:.02f}'
    # Build argv as a list (rather than splitting a string) so the error
    # message below can show the real command line.
    cmd = [
        'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
        '-r', rate, '-s', f'{w}x{h}', '-pix_fmt', pix_fmts[c], '-i', '-',
        '-filter_complex',
        '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
        '-r', rate, '-f', 'gif', '-',
    ]
    raw_video = b''.join(frame.tobytes() for frame in frames)
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
    # communicate() feeds stdin while draining stdout/stderr; writing every
    # frame to stdin first can deadlock once ffmpeg fills its stderr pipe.
    out, err = proc.communicate(input=raw_video)
    if proc.returncode:
        raise IOError('\n'.join([' '.join(cmd), err.decode('utf8', errors='replace')]))
    return out

def make_gif(env_name, run_id, step, fps=FPS):
    """Render the world-model dream saved at ``step`` of an MLflow run as a GIF.

    Loads ``d2_wm_dream/{step}.npz`` from the run and keeps the first ``B``
    sequences and the first ``T`` frames of each from its ``image_pred`` array.
    It plays those sequences back to back and writes
    ``figures/dream_{env_name}_{step}.gif``, creating ``figures/`` if needed.
    """
    dest_path = Path(f'figures/dream_{env_name}_{step}.gif')
    artifact = f'd2_wm_dream/{step}.npz'
    data = download_artifact_npz(run_id, artifact)
    img = data['image_pred']
    print(f'{artifact}: keys={list(data.keys())}, image_pred shape={img.shape}')
    # Axes 0 and 1 are sliced as (sequence, time); the trailing three axes are
    # kept as one frame (previously hard-coded to 64x64x3). Flattening the
    # leading axes in C order places the selected sequences one after another.
    frames = img[:B, :T].reshape((-1,) + img.shape[-3:])
    gif = encode_gif(frames, fps)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    dest_path.write_bytes(gif)


def _make_episode_gif(env_name, run_id, step, fps, artifact_dir):
    """Shared implementation of make_gif_episode / make_gif_episode_eval.

    Loads ``{artifact_dir}/0/{step}.npz`` from the run, moves the time axis of
    its ``image_t`` array to the front and writes the frames to
    ``figures/episode_{env_name}_{step}.gif``, creating ``figures/`` if needed.
    """
    dest_path = Path(f'figures/episode_{env_name}_{step}.gif')
    artifact = f'{artifact_dir}/0/{step}.npz'
    data = download_artifact_npz(run_id, artifact)
    # image_t is stored time-last (HWCT); transposing to THWC makes each img[t]
    # one (H, W, C) frame. The former reshape(-1, 64, 64, 3, order='F') was a
    # no-op for 64x64x3 frames and scrambled any other size, so it is gone.
    img = data['image_t'].transpose(3, 0, 1, 2)
    print(f'{artifact}: keys={list(data.keys())}, frames shape={img.shape}')
    gif = encode_gif(img, fps)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    dest_path.write_bytes(gif)


def make_gif_episode(env_name, run_id, step, fps=FPS):
    """Render training episode ``step`` (artifact ``episodes/0/{step}.npz``) as a GIF."""
    _make_episode_gif(env_name, run_id, step, fps, 'episodes')


def make_gif_episode_eval(env_name, run_id, step, fps=FPS):
    """Render evaluation episode ``step`` (artifact ``episodes_eval/0/{step}.npz``) as a GIF.

    NOTE(review): the output path ``figures/episode_{env_name}_{step}.gif`` is
    the same one make_gif_episode uses, so a training and an evaluation
    episode with the same step name overwrite each other.
    """
    _make_episode_gif(env_name, run_id, step, fps, 'episodes_eval')



def make_gif_minigrid(env_name, run_id, step, fps=FPS):
    """Render a MiniGrid world-model dream from an MLflow run as a GIF.

    Loads ``d2_wm_dream/{step}.npz`` and keeps the first ``B`` sequences and the
    first ``T`` frames of each from ``image_pred``. The first 4 channels of every
    frame are treated as RGBA and alpha-composited onto rgba2rgb's default
    (white) background. The result is written to
    ``figures/dream_{env_name}_{step}.gif``, creating ``figures/`` if needed.

    NOTE(review): assumes those channels hold colour/alpha values on a 0-255
    scale, as rgba2rgb expects -- confirm against how MiniGrid observations
    are stored.
    """
    dest_path = Path(f'figures/dream_{env_name}_{step}.gif')
    artifact = f'd2_wm_dream/{step}.npz'
    data = download_artifact_npz(run_id, artifact)
    img = data['image_pred']
    print(f'{artifact}: keys={list(data.keys())}, image_pred shape={img.shape}')
    rgba = img[:B, :T, ..., :4]
    # Flatten (sequence, time) into one frame axis, keeping each (H, W, C) frame.
    rgba = rgba.reshape((-1,) + rgba.shape[-3:])
    # Convert every frame; the old code encoded an unfilled np.empty float64
    # buffer. encode_gif sends raw rgb24 bytes, so frames must be 3-channel uint8.
    rgb = np.stack([rgba2rgb(frame) for frame in rgba]).astype(np.uint8, copy=False)
    print(f'GIF frames shape={rgb.shape}')
    gif = encode_gif(rgb, fps)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    dest_path.write_bytes(gif)


def rgba2rgb(rgba, background=(255, 255, 255)):
    """Alpha-composite an RGBA image onto a solid background colour.

    Args:
        rgba: Array shaped (..., H, W, C); any leading batch axes are allowed.
            With C == 3 the input is returned unchanged. With C == 4 the last
            channel is alpha on a 0-255 scale.
        background: (R, G, B) colour that shows through where alpha < 255.

    Returns:
        uint8 array shaped (..., H, W, 3), or ``rgba`` itself when C == 3.

    Raises:
        ValueError: If the channel count is neither 3 nor 4.
    """
    channels = rgba.shape[-1]
    if channels == 3:
        return rgba
    if channels != 4:
        raise ValueError(f'RGBA image must have 4 channels, got {channels}.')

    alpha = np.asarray(rgba[..., 3:4], dtype=np.float64) / 255.0
    colour = np.asarray(rgba[..., :3], dtype=np.float64)
    bg_colour = np.asarray(background, dtype=np.float64)
    blended = colour * alpha + (1.0 - alpha) * bg_colour
    # Round instead of truncating so float error (e.g. 221.99998) maps to 222;
    # clip guards against alpha values above 255.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

In [None]:
#Minigrid

# make_gif_minigrid('minigrid', '2fdd91da643b4b20a6f06d398f5c554f', '0001001')
# make_gif_minigrid('minigrid', '342ccaea0b0b4812929cb5433bac3510', '0002001')


In [None]:
# Adventure: render one training episode and one evaluation episode.
# Dream GIFs for the same run are disabled; to render them, uncomment:
# for dream_step in ['0000001', '0001001', '0086001', '0150001', '0201001',
#                    '0250001', '0300001', '0314001', '0308001']:
#     make_gif('adventure', '261d3a26b2b842ec990a8d0a5d6111ac', dream_step)
make_gif_episode(
    env_name='adventure',
    run_id='261d3a26b2b842ec990a8d0a5d6111ac',
    step='ep000257_000257-0-r0-1001',
)
make_gif_episode_eval(
    env_name='adventure',
    run_id='261d3a26b2b842ec990a8d0a5d6111ac',
    step='ep000283_000283-5-r0-1142',
)



In [None]:
# Montezuma: dream GIF at step 0500001.
make_gif(env_name='montezuma', run_id='599e69d178ca4f65a10423d272f9f45d', step='0500001')

In [None]:
# Breakout: dream GIF at step 0500001.
make_gif(env_name='breakout', run_id='83e5def4975242ccbf16a3ca8f62a674', step='0500001')

In [None]:
# Space Invaders: dream GIF at step 0500001.
make_gif(env_name='invaders', run_id='6d57d49ab844475cbb83b606816b01fe', step='0500001')

In [None]:
# DMC quadruped: dream GIF at step 0300001.
make_gif(env_name='quadruped', run_id='ff6cb24c04de4e6b821bb811c855d207', step='0300001')

In [None]:
# DMLab goals small: dream GIF at step 0161001, played at 8 fps instead of FPS.
make_gif(
    env_name='dmlab',
    run_id='6f78cce067464e8aa4bcb6f35a1a4386',
    step='0161001',
    fps=8,
)

In [None]:
# MiniWorld ScavengerHunt: dream GIF at step 0900001.
make_gif(env_name='scavenger', run_id='123b575400874f5db75ac7887f4e61c0', step='0900001')

In [None]:
# Pong: dream GIFs at three successive checkpoints of one run.
# Disabled alternative run: make_gif('pong', '6e7cd15f26854e42a458c358d21b65c9', '0000001')
for pong_step in ('0002001', '0004001', '0006001'):
    make_gif('pong', 'a4efeae409604aa4a0f8455488dae462', pong_step)

