In [1]:
import os
import math

import torch
import numpy as np
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.heat.heat_diffusion import HeatDiffusion
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import (
    gen_goals, overlay_goal, overlay_multiple, combine_objects, overlay_images,
    overlay_goal_agent, overlay_goals_agent, log_num_check,
    draw_obstacles_pil, convert_to_obstacle_masks,
    randgen_obstacle_masks, draw_obstacles_pixel,
    vector_field, clip_vectors
)

import matplotlib.pyplot as plt

import imageio
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.figure_factory as ff


# ---------------------------------------------------------------------------
# Configuration: everything below is read from the heat-diffusion YAML file.
# ---------------------------------------------------------------------------
config_dir = "./scorefield/configs/heat_diffusion.yaml"
args = load_config(config_dir)
device = args['device']

# Scene assets: one background image plus the list of goal ("waste") sprites.
# Additional sprites (assets/toy_exp/waste4.png, assets/toy_exp/waste5.png)
# can be appended to `wastes` to get more goals.
bg = Image.open('assets/toy_exp/background0.png')
wastes = [Image.open('assets/toy_exp/waste0.png')]

# Scene geometry.
img_size = args['image_size']
goal_bounds = args['goal_bounds']
goal_num = len(wastes)  # one goal per loaded sprite
agent_bounds = args['agent_bounds']
obstacle_pos = args['obstacles']

model_path = os.path.join(args['log_path'], args['model_path'])

# Diffusion process parameters (keys are read in the order listed).
u0, heat_steps, noise_steps, sample_num, time_type = (
    args[key]
    for key in ('u0', 'heat_steps', 'noise_steps', 'sample_num', 'time_type')
)

# Training parameters (keys are read in the order listed).
epochs, train_lr, iterations, random_goals = (
    args[key]
    for key in ('epochs', 'train_lr', 'iterations', 'random_goals')
)
# Derived here rather than taken from args['batch_size'].
batch_size = noise_steps * sample_num


diffusion = HeatDiffusion(
    image_size=img_size,
    u0=u0,
    noise_steps=noise_steps,
    heat_steps=heat_steps,
    time_type=time_type,
    device=device,
)

# Rasterise the configured obstacles and paint them onto the background.
# (Use `background = [bg]` instead for an obstacle-free scene.)
obstacle_masks = convert_to_obstacle_masks(noise_steps, bg.size, img_size, obstacle_pos)
background = draw_obstacles_pixel(bg, obstacle_masks)

# The same fixed goal coordinate, repeated once per noise step.
goal = torch.tensor(
    [[[-0.7, -0.]] for _ in range(noise_steps)],
    device=device,
)
obs = overlay_goal(background, img_size, wastes, goal)

# Time indices 1..noise_steps, taken with a stride of `sample_num`.
t = torch.tensor(list(range(1, noise_steps + 1, sample_num)), device=device)
heat, score, score_field, xt = diffusion.forward_diffusion(t, goal, sample_num, obstacle_masks)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Heat visualization: animate the heat maps in `heat`, one frame per entry,
# with Play/Pause buttons and a slider to scrub through the frames.

# One animation frame per entry along the first axis of `heat`; the frame
# name is the entry index as a string and is what the slider steps refer to.
frames = [
    go.Frame(
        data=[go.Heatmap(z=batch.cpu().numpy(), colorscale='Viridis')],
        name=str(i),
    )
    for i, batch in enumerate(heat)
]

# The figure initially displays the first heat map.
heat_fig = go.Figure(
    data=[go.Heatmap(z=heat[0].cpu().numpy(), colorscale='Viridis')],
    frames=frames,
)

# "Play" animates through all frames starting from the current one.
play_button = {
    'label': 'Play',
    'method': 'animate',
    'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}],
}

# "Pause" animates to an empty frame list in immediate mode, which halts
# the running animation.
pause_button = {
    'label': 'Pause',
    'method': 'animate',
    'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
}

# One slider step per frame; selecting a step jumps straight to that frame.
slider_steps = [
    {
        'label': frame.name,
        'method': 'animate',
        'args': [
            [frame.name],
            {
                'frame': {'duration': 300, 'redraw': True},
                'mode': 'immediate',
            },
        ],
    }
    for frame in frames
]

heat_fig.update_layout(
    width=600,
    height=600,
    margin=dict(t=40, b=40, l=40, r=40),
    updatemenus=[{
        'type': 'buttons',
        'buttons': [play_button, pause_button],
        'direction': 'left',
        'showactive': False,
        'pad': {'r': 10, 't': 87},
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top',
    }],
    sliders=[{
        'active': 0,
        'steps': slider_steps,
        'currentvalue': {
            'font': {'size': 20},
            'prefix': 'Batch:',
            'visible': True,
            'xanchor': 'right',
        },
        'transition': {'duration': 300, 'easing': 'cubic-in-out'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9,
        'x': 0.1,
        'xanchor': 'left',
        'y': 0,
        'yanchor': 'top',
    }],
)


heat_fig.show()

In [3]:
# Score-field visualization: render one quiver plot per entry of
# `score_field` to a PNG on disk, then show the heat maps and the rendered
# quiver images side by side in a single animated plotly figure.
scale = 0.5  # passed to quiver as `scale` (with scale_units='xy')

B = heat.shape[0]  # number of animation frames

# savefig does not create missing parent directories, so make sure the
# output directory exists before the first image is written.
quiver_dir = './logs/visualize'
os.makedirs(quiver_dir, exist_ok=True)

# Pixel-coordinate grid for the arrows; identical for every frame, so it is
# built once instead of once per loop iteration.
x, y = np.meshgrid(np.linspace(0, img_size-1, img_size), np.linspace(0, img_size-1, img_size))

score_fields_png = []  # paths of the rendered quiver images, in frame order
for i in range(B):
    data = score_field[i]
    # The last axis holds the two vector components: index 0 is plotted as
    # the vertical (V) component, index 1 as the horizontal (U) component.
    V = data[..., 0]
    U = data[..., 1]

    # V_clip, U_clip = clip_vectors(V, U, 0.01)
    V_clip, U_clip = V, U

    quiver_fig, quiver_ax = plt.subplots(figsize=(10, 10))
    quiver_ax.quiver(x, y, U_clip.cpu().numpy(), V_clip.cpu().numpy(),
                     angles='xy', scale_units='xy', scale=scale)
    quiver_ax.invert_yaxis()  # row 0 at the top, matching image orientation
    quiver_ax.grid(False)
    image_path = f'{quiver_dir}/quiver_{i}.png'
    quiver_fig.savefig(image_path)
    # Close explicitly by handle so figures do not accumulate in memory.
    plt.close(quiver_fig)
    score_fields_png.append(image_path)

# Convert / decode every frame's inputs exactly once; they are reused for
# both the initial traces and the animation frames.
heat_arrays = [heat[k].cpu().numpy() for k in range(B)]
score_images = [imageio.imread(path) for path in score_fields_png]

fig = make_subplots(rows=1, cols=2, subplot_titles=("Heat Distribution", "Score Field"))
fig.add_trace(go.Heatmap(z=heat_arrays[0], colorscale='Viridis', showscale=False), row=1, col=1)
fig.add_trace(go.Image(z=score_images[0]), row=1, col=2)

# Each frame supplies replacement data for both traces, in trace order.
frames = []
for k in range(B):
    frame_data = [go.Heatmap(z=heat_arrays[k], colorscale='Viridis', showscale=False),
                  go.Image(z=score_images[k])]

    frame = go.Frame(data=frame_data, name=str(k))
    frames.append(frame)

fig.frames = frames

fig.update_layout(
    updatemenus=[{
        'buttons': [
            {
                # Play through all frames, starting from the current one.
                'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}],
                'label': 'Play',
                'method': 'animate'
            },
            {
                # Animating to an empty frame list in immediate mode halts
                # the running animation.
                'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}],
                'label': 'Pause',
                'method': 'animate'
            }
        ],
        'direction': 'left',
        'pad': {'r': 10, 't': 87},
        'showactive': False,
        'type': 'buttons',
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top'
    }],
    sliders=[{
        'active': 0,
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {
            'font': {'size': 20},
            'prefix': 'Batch:',
            'visible': True,
            'xanchor': 'right'
        },
        'transition': {'duration': 300, 'easing': 'cubic-in-out'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9,
        'x': 0.1,
        'y': 0,
        # One slider step per frame; selecting it jumps to that frame.
        'steps': [
            {
                'args': [
                    [frame['name']],
                    {
                        'frame': {'duration': 300, 'redraw': True},
                        'mode': 'immediate'
                    }
                ],
                'label': frame['name'],
                'method': 'animate'
            } for frame in frames
        ]
    }]
)

fig.update_layout(
    width=1600,
    height=800,
)

# Keep the heat map square and draw it with row 0 at the top.
fig.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1)
fig.update_yaxes(autorange="reversed", row=1, col=1)

fig.show()






