# In [27]:
import pandas as pd
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

class DataDistributionVisualizer:
    """Plot one distribution chart per DataFrame column, side by side in one row.

    Numerical columns get a histogram. Binary (0/1) and categorical columns get
    a donut pie chart of their value counts, or a text annotation giving the
    number of unique values when there are too many categories to chart.
    """

    # Non-numerical columns with more distinct values than this get a text
    # annotation instead of a pie chart.
    MAX_PIE_CATEGORIES = 20

    def __init__(self, data):
        """Store *data* and classify each of its columns.

        Args:
            data: the pandas DataFrame to visualize.

        Sets ``self.column_types``, a dict mapping each column name to
        ``'numerical'``, ``'binary'`` or ``'categorical'``. Dicts keep insertion
        order, so iteration follows the DataFrame's column order.
        """
        self.data = data
        self.column_types = {
            col: self._classify(self.data[col]) for col in self.data.columns
        }

    @staticmethod
    def _classify(series):
        """Return ``'binary'``, ``'numerical'`` or ``'categorical'`` for one column."""
        # Any numeric dtype (int32, float32, uint8, nullable Int64, ...) counts;
        # bool columns are excluded so they stay categorical with True/False labels.
        if (pd.api.types.is_numeric_dtype(series)
                and not pd.api.types.is_bool_dtype(series)):
            unique_vals = set(series.dropna().unique())
            # A non-empty subset of {0, 1} (0.0 == 0 in Python) is a 0/1 flag,
            # including a column that only ever holds 0 or only ever holds 1.
            if unique_vals and unique_vals <= {0, 1}:
                return 'binary'
            return 'numerical'
        return 'categorical'

    @staticmethod
    def _pie_labels_values(column_type, category_counts):
        """Return ``(labels, values)`` for a pie chart of *category_counts*.

        Args:
            column_type: ``'binary'`` or ``'categorical'``.
            category_counts: the column's ``value_counts()`` Series.

        Binary columns get readable ``'False (0)'`` / ``'True (1)'`` labels,
        listing only the values that actually occur. Other columns use the
        counted values themselves as labels.
        """
        if column_type != 'binary':
            return category_counts.index, category_counts.values
        labels, values = [], []
        for value, label in ((0, 'False (0)'), (1, 'True (1)')):
            count = category_counts.get(value, 0)
            if count:
                labels.append(label)
                values.append(count)
        return labels, values

    @staticmethod
    def _unique_count_text(column, n_unique):
        """Return the annotation text used when a column is not charted."""
        # Plotly text uses HTML-style <br> for line breaks; '\n' is not rendered.
        return f'{column}<br>Unique Values: {n_unique}'

    @classmethod
    def _add_unique_count_annotation(cls, fig, column, n_unique, col, num_cols):
        """Write the unique-value count roughly over subplot (1, *col*) of *fig*."""
        fig.add_annotation(
            xref='paper', yref='paper',
            # Approximate centre of the col-th of num_cols equal-width cells
            # (ignores the small horizontal spacing between subplots).
            x=(col - 0.5) / num_cols,
            y=0.5,
            text=cls._unique_count_text(column, n_unique),
            showarrow=False,
            font=dict(size=12),
        )

    def _add_histogram(self, fig, column, col):
        """Add a histogram of *column* to subplot (1, *col*) of *fig*."""
        histogram_fig = px.histogram(self.data, x=column,
                                     color_discrete_sequence=["skyblue"],
                                     opacity=0.7)
        for trace in histogram_fig.data:
            fig.add_trace(trace, row=1, col=col)
        fig.update_yaxes(title_text='Frequency', row=1, col=col)

    def distribution_grid(self):
        """Build a one-row subplot grid with one chart per column.

        Returns:
            A plotly ``go.Figure`` that is 300 px wide per column. When the
            DataFrame has no columns, a figure with only the title is returned,
            because ``make_subplots`` rejects ``cols=0``.
        """
        columns = list(self.data.columns)
        num_cols = len(columns)
        layout = dict(
            title_text='Data Distribution',
            title_x=0.5,
            height=400,  # fixed height: there is only one row
            showlegend=False,
        )

        if num_cols == 0:
            fig = go.Figure()
            fig.update_layout(**layout)
            return fig

        # Histograms need Cartesian axes; pie charts need a "domain" cell.
        specs = [[{"type": "xy" if self.column_types[c] == 'numerical' else "domain"}
                  for c in columns]]

        fig = make_subplots(
            rows=1,
            cols=num_cols,
            # Plotly expects strings and skips falsy titles, so a column
            # named 0 would otherwise lose its title.
            subplot_titles=[str(c) for c in columns],
            specs=specs,
        )

        for col, column in enumerate(columns, start=1):  # plotly is 1-based
            column_type = self.column_types[column]
            if column_type == 'numerical':
                self._add_histogram(fig, column, col)
                continue

            category_counts = self.data[column].value_counts()
            if len(category_counts) > self.MAX_PIE_CATEGORIES:
                self._add_unique_count_annotation(
                    fig, column, len(category_counts), col, num_cols)
                continue

            labels, values = self._pie_labels_values(column_type, category_counts)
            try:
                fig.add_trace(go.Pie(
                    labels=labels,
                    values=values,
                    hole=0.3,
                    textinfo='label+percent',
                    insidetextorientation='radial'
                ), row=1, col=col)
            except ValueError:
                # Best-effort fallback: if plotly rejects the pie data, show
                # the unique-value count instead of failing the whole grid.
                self._add_unique_count_annotation(
                    fig, column, len(category_counts), col, num_cols)

        fig.update_layout(width=300 * num_cols, **layout)
        return fig

# Example usage: runs when executed as a script or notebook cell, but not when
# this module is imported.
if __name__ == "__main__":
    df = pd.read_csv('models/titanic.csv')  # Adjust the path to your dataset
    visualizer = DataDistributionVisualizer(df)
    fig = visualizer.distribution_grid()
    fig.show()