# In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from matplotlib.animation import FuncAnimation
from functools import reduce
from PIL import Image
import clip
import h5py
import io
import argparse
from os.path import isfile, join
from os import listdir
import os
import json
from copy import deepcopy


# Kalman Filter Implementation
class KalmanFilter:
    """Linear Kalman filter whose state is a vector of per-class values.

    The transition matrix ``F`` and the measurement matrix ``H`` are both the
    identity, so the filter acts as a recursive smoother of the measurements
    passed to :meth:`update`.

    Attributes are kept under their conventional single-letter names:
    ``x`` (state), ``P`` (state covariance), ``F`` (transition), ``Q``
    (process noise), ``H`` (measurement) and ``R`` (measurement noise).
    """

    def __init__(self, num_classes, process_noise=1e-5, measurement_noise=1e-3):
        """Set up a filter for ``num_classes`` values.

        Args:
            num_classes: Length of the state / measurement vector.
            process_noise: Scale of the diagonal process-noise matrix ``Q``.
            measurement_noise: Scale of the diagonal measurement-noise
                matrix ``R``.
        """
        self.num_classes = num_classes

        # Every matrix gets its own identity so none of them alias each other.
        def identity():
            return np.eye(num_classes)

        # State estimate starts at zero with unit uncertainty.
        self.x = np.zeros(num_classes)
        self.P = identity()

        # Identity dynamics: the state is assumed to stay where it is.
        self.F = identity()
        self.Q = identity() * process_noise

        # Identity observation: each state entry is measured directly.
        self.H = identity()
        self.R = identity() * measurement_noise

    def predict(self):
        """Propagate the state and its covariance one step forward."""
        transition = self.F
        self.x = transition @ self.x
        self.P = transition @ self.P @ transition.T + self.Q

    def update(self, z):
        """Fold the measurement ``z`` into the current estimate.

        Args:
            z: Measurement vector with ``num_classes`` entries.
        """
        observe = self.H

        # Innovation covariance and the resulting Kalman gain.
        innovation_cov = observe @ self.P @ observe.T + self.R
        gain = self.P @ observe.T @ np.linalg.inv(innovation_cov)

        # Correct the state by the gain-weighted residual, then shrink P.
        residual = z - observe @ self.x
        self.x = self.x + gain @ residual
        self.P = (np.eye(self.num_classes) - gain @ observe) @ self.P


def mkdir(path):
    """Ensure the directory ``path`` exists and return ``path``.

    Missing parent directories are created as well, and an already existing
    directory is not an error. Using ``exist_ok=True`` instead of a separate
    existence check avoids a race when several processes create the same
    directory at once.

    Args:
        path: Directory to create.

    Returns:
        The ``path`` argument, unchanged, so the call can be used inline.
    """
    os.makedirs(path, exist_ok=True)
    return path

def get_mask_area(seg_img, colors):
    """Return a boolean mask of the pixels whose color is one of ``colors``.

    A pixel matches a color only when every channel along the last axis is
    equal. The masks of the individual colors are OR-ed together.

    Args:
        seg_img: Array whose last axis holds the color channels, e.g. a
            segmentation image of shape ``(H, W, C)``. Height and width may
            differ.
        colors: Iterable of colors, each comparable against the last axis of
            ``seg_img``. An empty iterable yields an all-``False`` mask.

    Returns:
        Boolean array covering the spatial axes of ``seg_img``.
    """
    seg_img = np.asarray(seg_img)
    mask = np.zeros(seg_img.shape[:-1], dtype=bool)
    for color in colors:
        mask |= (seg_img == color).all(axis=-1)
    # Keep only the last two (spatial) axes so a leading singleton axis is
    # dropped, as before; unlike the previous (W, W) reshape this does not
    # require the image to be square.
    return mask.reshape(mask.shape[-2:])

def blackout_image(depth_map, area):
    """Keep the pixels of ``depth_map`` selected by ``area``; set the rest to 255.

    Args:
        depth_map: Source array (its ``shape`` defines the output shape).
        area: Index or boolean mask selecting the entries to copy over.

    Returns:
        A new array of NumPy's default integer dtype, filled with 255
        everywhere except at ``area``, where it holds the source values.
        ``depth_map`` itself is not modified.
    """
    background_value = 255
    masked = np.full(shape=depth_map.shape, fill_value=background_value)
    kept_pixels = depth_map[area]
    masked[area] = kept_pixels
    return masked

def color_image(depth_map, area, color):
    """Return a copy of ``depth_map`` with the entries at ``area`` set to ``color``.

    Args:
        depth_map: Source array; it is deep-copied and left untouched.
        area: Index or boolean mask selecting the entries to overwrite.
        color: Value assigned to the selected entries.

    Returns:
        The recolored copy.
    """
    recolored = deepcopy(depth_map)
    recolored[area] = color
    return recolored

# Video processing
def _model_indices(model_names, names):
    """Return the positions in ``model_names`` equal to any entry of ``names``.

    ``names`` may be ``None``, empty, a 0-d array or a 1-D array; the result
    is a plain list of integer indices (empty when there is nothing to match).
    """
    if names is None or np.size(names) == 0:
        return []
    matches = [np.where(model_names == name)[0] for name in np.atleast_1d(names)]
    return np.concatenate(matches, axis=0).tolist()


def _composite_model_name(o_index, fixed_joint_ids, model_names):
    """Return the ``mapping`` key describing the object at ``o_index``.

    Objects that belong to a fixed-joint group of two or three parts are
    reported under a composite key; everything else uses its own model name.
    """
    if o_index in fixed_joint_ids:
        if len(fixed_joint_ids) == 3:
            return "cyn_cyn_cyn"
        if len(fixed_joint_ids) == 2:
            part_names = [model_names[idd].decode("utf-8") for idd in fixed_joint_ids]
            return "cyn_cone" if "cone" in part_names else "cyn_cyn"
    return model_names[o_index].decode("utf-8")


def _save_probability_animation(ims, bar_data, tick_names, true_index,
                                first_frame, out_path):
    """Write a video showing each masked frame next to its probability bars.

    The tick label at ``true_index`` is drawn in red and, in every frame, the
    tallest bar is drawn in orange. The figure is closed even if saving fails.
    """
    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(10, 5))
    try:
        img_display = ax_img.imshow(ims[0])
        ax_img.axis('off')  # no axes around the image panel

        positions = range(len(bar_data[0]))
        bars = ax_bar.bar(positions, bar_data[0], color="skyblue")
        ax_bar.set_ylim(0, 1)  # fixed y-range keeps frames comparable
        ax_bar.set_xticks(positions)
        ax_bar.set_xticklabels(tick_names, rotation=90)
        ax_bar.get_xticklabels()[true_index].set_color('red')
        title = ax_bar.set_title(f"frame {first_frame}")

        def update(frame):
            img_display.set_data(ims[frame])
            heights = bar_data[frame]
            tallest = max(heights)
            for bar, height in zip(bars, heights):
                bar.set_height(height)
                bar.set_color('orange' if height == tallest else 'skyblue')
            title.set_text(f"frame {first_frame + frame}")
            return img_display, bars, title

        ani = FuncAnimation(fig, update, frames=len(ims), blit=False, interval=100)
        ani.save(out_path, writer=animation.FFMpegWriter(fps=10))
    finally:
        plt.close(fig)


def process_video(hdf5_file_path, labels, mapping,
                  output_root='/ccn2/u/haw027/b3d_ipe/vis_original'):
    """Classify every object of one HDF5 trial with CLIP and save the results.

    For each object that is neither a distractor nor an occluder, every frame
    whose integer key is at least 5 is masked down to that object, scored
    against ``labels`` with CLIP, and the per-frame probabilities are rendered
    as ``<trial>_<i>.mp4``. Parts connected by a fixed joint are recolored to
    one segmentation color and rendered once, for the first part only. All
    probabilities are also written to ``<trial>.json`` as a list of
    ``(model_name, per_frame_probabilities)`` entries.

    Relies on the module-level globals ``device``, ``model``, ``preprocess``
    and ``scenario``.

    Args:
        hdf5_file_path: Path of the trial's HDF5 file.
        labels: Text prompts passed to ``clip.tokenize``.
        mapping: Dict whose keys are model names; the keys are used as bar
            tick labels and are assumed to be in the same order as ``labels``.
        output_root: Directory under which ``<scenario>/`` is created for the
            output files.

    Raises:
        ValueError: If an object's model name is not a key of ``mapping``.
        AssertionError: If ``base_id`` or ``attachment_id`` is not a single
            value.
    """
    thr = 5  # frames with a key below this are skipped
    trial_name = os.path.splitext(os.path.basename(hdf5_file_path))[0]
    text = clip.tokenize(labels).to(device)
    tick_names = list(mapping.keys())
    # Created up front so the JSON can be written even if nothing is rendered.
    out_dir = mkdir(join(output_root, scenario))

    all_data = []
    with h5py.File(hdf5_file_path, "r") as f:
        static = f['static']
        frames = f['frames']
        object_ids = np.array(static['object_ids'])
        object_segmentation_colors = np.array(static['object_segmentation_colors'])
        model_names = np.array(static["model_names"])

        # Object ids that are rigidly attached to each other, if any.
        fixed_joints = []
        if "base_id" in static and "attachment_id" in static:
            base_id = np.array(static['base_id'])
            attachment_id = np.array(static['attachment_id'])
            use_cap = np.array(static['use_cap'])
            assert attachment_id.size==1
            assert base_id.size==1
            attachment_id = attachment_id.item()
            base_id = base_id.item()
            fixed_joints.append(base_id)
            fixed_joints.append(attachment_id)
            if use_cap:
                fixed_joints.append(attachment_id + 1)
        fixed_joint_ids = sorted(
            np.concatenate(
                [np.where(object_ids == fixed_joint)[0] for fixed_joint in fixed_joints],
                axis=0,
            ).tolist()
        ) if fixed_joints else []

        excluded_model_ids = set(
            _model_indices(model_names, np.array(static['distractors']))
            + _model_indices(model_names, np.array(static['occluders']))
        )
        included_model_ids = [
            idx for idx in range(len(object_ids)) if idx not in excluded_model_ids
        ]

        frame_keys = [key for key in frames.keys() if int(key) >= thr]

        for i, o_index in enumerate(included_model_ids):
            color = object_segmentation_colors[o_index]
            print('\t\t', f"object {i}")

            is_fixed_joint = o_index in fixed_joint_ids
            if is_fixed_joint and o_index != fixed_joint_ids[0]:
                # Already covered by the group's first part; skip before
                # decoding any frame.
                continue

            model_name = _composite_model_name(o_index, fixed_joint_ids, model_names)
            true_index = tick_names.index(model_name)

            ims = []
            bar_data = []
            for key in frame_keys:
                frame_images = frames[key]['images']
                im_seg = np.array(Image.open(io.BytesIO(frame_images['_id_cam0'][:])))
                if is_fixed_joint:
                    # Paint the other parts in this object's color so one mask
                    # covers the whole group.
                    for o_id in fixed_joint_ids[1:]:
                        part_area = get_mask_area(
                            im_seg, [object_segmentation_colors[o_id]]
                        )
                        im_seg = color_image(im_seg, part_area, color)

                image = np.array(Image.open(io.BytesIO(frame_images["_img_cam0"][:])))
                area = get_mask_area(im_seg, [color])
                image_masked = Image.fromarray(np.uint8(blackout_image(image, area)))

                clip_input = preprocess(image_masked).unsqueeze(0).to(device)
                with torch.no_grad():
                    logits_per_image, _ = model(clip_input, text)
                    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

                # NOTE: probabilities are stored unsmoothed; a per-frame
                # KalmanFilter predict/update step could be inserted here.
                ims.append(image_masked)
                bar_data.append(probs[0].tolist())

            if not ims:
                print('\t\t', f"object {i}: no frames with key >= {thr}, skipped")
                continue

            _save_probability_animation(
                ims, bar_data, tick_names, true_index, thr,
                join(out_dir, f'{trial_name}_{i}.mp4'),
            )
            all_data.append((model_name, bar_data))

    with open(join(out_dir, f'{trial_name}.json'), "w") as out_file:
        json.dump(all_data, out_file)

# Load CLIP once; process_video reads `device`, `model` and `preprocess`
# from module scope.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Model name -> text prompt given to CLIP. `labels` is derived from this dict
# so the prompt order always matches the order of the mapping keys.
mapping = {
    "bowl": "a bowl",
    "cone": "a cone",
    "cube": "a cuboid",
    "cylinder": "a cylinder",
    "octahedron": "an octahedron",
    "pentagon": "a pentagonal prism",
    "pipe": "a tube",
    "platonic": "an icosahedron",
    "pyramid": "a pyramid",
    "ramp_with_platform_30": "a ramp",
    "sphere": "a sphere",
    "torus": "a ring",
    "triangular_prism": "a triangular prism",
    "cyn_cyn": "a combination of two cylinders stacking on top of each other",
    "cyn_cyn_cyn": "a combination of three cylinders stacking on top of each other",
    "cyn_cone": "a combination of a cylinder and a cone stacking on top of each other",
}
labels = list(mapping.values())

parser = argparse.ArgumentParser(
    description="Visualize per-frame CLIP shape predictions for every object "
                "in a scenario's HDF5 trials."
)
parser.add_argument('--scenario', type=str, default='Dominoes', help='name of the scenarios')
args = parser.parse_args()
scenario = args.scenario  # also read by process_video for its output directory
print(scenario)

source_path = '/ccn2/u/rmvenkat/data/testing_physion/regenerate_from_old_commit/test_humans_consolidated/lf_0/'
save_path = '/ccn2/u/haw027/b3d_ipe/num_obj'  # not used in the code visible here
scenario_path = join(source_path, scenario+'_all_movies')
# Sorted so trials are processed in a reproducible order.
onlyhdf5 = sorted(
    name for name in listdir(scenario_path)
    if name.endswith('.hdf5') and isfile(join(scenario_path, name))
)

for hdf5_file in onlyhdf5:
    trial_name = hdf5_file[:-5]  # strip the ".hdf5" suffix
    print('\t', trial_name)
    hdf5_file_path = join(scenario_path, hdf5_file)
    process_video(hdf5_file_path, labels, mapping)
