# core

> Render per-batch time-series data as animated GIF plots overlaid on median reference profiles.

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import pandas as pd
import numpy as np
import plotly.io as pio
from PIL import Image
import io
import imageio
from IPython.display import display
import plotly.express as px
import plotly.graph_objects as go



class GIF_PLOT:
    """Animate per-batch time series against median reference profiles as a GIF.

    Each group in ``group_column`` is one batch. :meth:`get_gif_plot` draws one
    median profile of ``ideal_column`` per value of ``label_column`` (x axis:
    elapsed ``batch_time``), then renders one frame per batch with that batch's
    ``analysis_column`` trace overlaid on the profiles.

    Requirements:
      * a ``batch_time`` column must exist before calling :meth:`get_gif_plot`
        (create it with :meth:`add_batch_time`);
      * ``label_column`` must be castable to ``int`` (it is cast in place).
    """

    def __init__(self, data, analysis_column, ideal_column, group_column, time_column, label_column):
        self.data = data                        # source DataFrame (mutated by add_batch_time / get_gif_plot)
        self.analysis_column = analysis_column  # column animated one batch per frame
        self.ideal_column = ideal_column        # column whose per-label median forms the reference profile
        self.group_column = group_column        # column identifying a batch
        self.time_column = time_column          # column used to compute elapsed batch time
        self.label_column = label_column        # column splitting the reference profiles

    def mad(self, df_, factor=1):
        """Return ``factor`` times the NaN-ignoring median absolute deviation of ``df_``.

        For 1-D input (e.g. the Series pandas passes during ``agg``) a scalar
        is returned; otherwise the MAD is computed along axis 0.
        """
        if df_.ndim == 1:
            return factor * np.nanmedian(abs(df_ - np.nanmedian(df_)))
        return factor * np.nanmedian(abs(df_ - np.nanmedian(df_, axis=0)), axis=0)

    def add_batch_time(self):
        """Add a ``batch_time`` column: time elapsed since the start of each batch.

        Elapsed time is computed from ``time_column`` within each
        ``group_column`` group. Datetime columns give float seconds; numeric
        columns give the plain numeric difference.

        Returns:
            ``self.data``, mutated in place with the new column.
        """
        elapsed = self.data.groupby(self.group_column)[self.time_column].transform(
            lambda s: s - s.min()
        )
        # Differences of datetimes are timedeltas (dtype kind "m"); express them in seconds.
        if elapsed.dtype.kind == "m":
            elapsed = elapsed.dt.total_seconds()
        self.data["batch_time"] = elapsed
        return self.data

    def update_visibility(self, index, fig, start_index=1):
        """Show only the trace at ``index`` among the traces after ``start_index``.

        Traces at positions ``<= start_index`` (the reference profiles) keep
        their current visibility.
        """
        for i, trace in enumerate(fig.data):
            if i > start_index:
                trace.visible = i == index

    @staticmethod
    def _trace_color(label, palette):
        """Pick a palette color for an integer label, cycling if labels exceed the palette."""
        return palette[int(label) % len(palette)]

    def get_median_profile(self, df):
        """Return a figure with one median-profile line per label.

        ``df`` is expected to have ``label_column`` as a column, the x values
        (elapsed batch time) as its index and a ``median`` column, as
        produced inside :meth:`get_gif_plot`.
        """
        fig = go.Figure()
        palette = px.colors.qualitative.Dark24
        for label, group in df.groupby(self.label_column):
            fig.add_trace(go.Scatter(x=group.index,
                                     y=group['median'],
                                     line_color=self._trace_color(label, palette),
                                     name='median profiles of good parts'))
        return fig

    def get_update_layout(self, fig, x_title, y_title1):
        """Apply the shared size, background, axis styling and axis titles to ``fig``."""
        axis_color = 'rgb(100,100,100)'
        # 'white' replaces the out-of-range 'rgb(256,256,256)'.
        fig.update_layout(width=900, height=550, plot_bgcolor='white')
        fig.update_xaxes(title_text=x_title, showline=True, linewidth=0.2,
                         linecolor=axis_color, ticks="outside")
        fig.update_yaxes(showline=True, linewidth=0.2, linecolor=axis_color, ticks="inside")
        # title.font replaces the deprecated (removed in Plotly 6) `titlefont` property.
        fig.update_layout(
            xaxis=dict(domain=[0, 1]),
            yaxis=dict(title=dict(text=y_title1, font=dict(color=axis_color)),
                       tickfont=dict(color=axis_color)),
        )
        return fig

    def get_gif_plot(self, save_address, plot_type='lines+markers'):
        """Render one frame per batch over the median profiles and save them as a GIF.

        Args:
            save_address: directory under which ``fluctions_gif_plots/`` is created.
            plot_type: Plotly scatter ``mode`` for the batch traces.

        Returns:
            Path of the written GIF file.
        """
        self.data[self.label_column] = self.data[self.label_column].astype(int)
        profile = (
            self.data.groupby([self.label_column, 'batch_time'])[self.ideal_column]
            .agg([self.mad, 'median'])
            .reset_index(level=0)  # keep batch_time as the index (x values)
        )
        fig = self.get_median_profile(profile)
        fig = self.get_update_layout(fig, 'time', self.analysis_column)

        # The reference traces come first; every trace added after them is one batch.
        n_reference = len(fig.data)
        for id_, batch in self.data.groupby(self.group_column):
            fig.add_trace(go.Scatter(x=batch['batch_time'],  # same x axis as the median profiles
                                     y=batch[self.analysis_column],
                                     mode=plot_type,
                                     marker=dict(color="black"),
                                     showlegend=True,
                                     name=f"Batch_ids:{id_}",
                                     visible=False))

        images = []
        for i in range(n_reference, len(fig.data)):
            self.update_visibility(i, fig, start_index=n_reference - 1)
            png = pio.to_image(fig, format="png")
            images.append(np.asarray(Image.open(io.BytesIO(png))))

        save_gif = os.path.join(save_address, 'fluctions_gif_plots')
        os.makedirs(save_gif, exist_ok=True)
        gif_path = os.path.join(save_gif, f"animated_{self.analysis_column}.gif")
        imageio.mimsave(gif_path, images, fps=20)
        return gif_path
            


In [None]:
# file = '../../data/sample_data_for_gif.feather'
# df = pd.read_feather(file)

In [None]:
# gif = GIF_PLOT(data=df, 
#                analysis_column='temperature_metal_temp_furnace_temp_1_pv',
#                ideal_column='temperature_metal_temp_furnace_temp_1_pv',
#                group_column='id_wheel', 
#                time_column='timestamp',
#                label_column='label')
# gif.add_batch_time()  # creates the 'batch_time' column that get_gif_plot requires
# gif.get_gif_plot('../../data/')