In [172]:
import pandas as pd
import numpy as np
from collections import Counter
from scipy.stats import chisquare

In [173]:
dataset = pd.read_csv('iris.data')
dataset

Unnamed: 0,5.1,3.5,1.4,0.2,Iris-setosa
0,4.9,3.0,1.4,0.2,Iris-setosa
1,4.7,3.2,1.3,0.2,Iris-setosa
2,4.6,3.1,1.5,0.2,Iris-setosa
3,5.0,3.6,1.4,0.2,Iris-setosa
4,5.4,3.9,1.7,0.4,Iris-setosa
...,...,...,...,...,...
144,6.7,3.0,5.2,2.3,Iris-virginica
145,6.3,2.5,5.0,1.9,Iris-virginica
146,6.5,3.0,5.2,2.0,Iris-virginica
147,6.2,3.4,5.4,2.3,Iris-virginica


In [174]:
dataset.columns = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w', 'type']

In [175]:
# Create the contingency tables based on the column and data.
def contingencyTable(obs0, obs1):
    #Count the values for each label for given attribute
    
    count_0 = np.array([(obs0['type']=='Iris-setosa').sum(),     (obs0['type']=='Iris-versicolor').sum(),
             (obs0['type']=='Iris-virginica').sum()]);

    count_1 = np.array([(obs1['type']=='Iris-setosa').sum(),     (obs1['type']=='Iris-versicolor').sum(),
             (obs1['type']=='Iris-virginica').sum()]);

    #Calculate expected values
    total = len(obs0) + len(obs1)
    count_total = count_0 + count_1
    expected_0 = count_total*sum(count_0)/total
    expected_1 = count_total*sum(count_1)/total
    
    return count_0, count_1, expected_0, expected_1

# Calculate the chi square values based on the observed and expected values
def calculatechisquare(count_0, count_1, expected_0, expected_1):
    chi_ = (count_0 - expected_0)**2/expected_0 + (count_1 - expected_1)**2/expected_1
    chi_ = np.nan_to_num(chi_) # Deal with the zero counts
    return sum(chi_)

# Merge the rows with the lowest chi square values
def mergerows(intervals, min_chi_index):
    # Prepare for the merged array
    new_intervals = [] 
    skip = False
    done = False

    #Merge the intervals found at min_chi_index with next interval
    for i in range(len(intervals)):
        if skip:
            skip = False
            continue
        if i == min_chi_index and not done: # Merge the intervals
            t = intervals[i] + intervals[i+1]
            new_intervals.append([min(t), max(t)])
            skip = True
            done = True
        else:
            new_intervals.append(intervals[i])
    return new_intervals

# Print the values of intervals and split points
def printintervalsandsplitpoints(intervals, attr):
    print('\nIntervals for',attr)
    print(intervals)

    print('\nSplit points for',attr)
    for i in intervals:
        print(i[0])
    print('***********************')
    

In [176]:
# Chimerge function

def chimerge(data, attr, label, max_intervals):
    # Sort the values in the column
    distinct_vals = sorted(set(data[attr])) 
    # define initial intervals
    intervals = [[distinct_vals[i], distinct_vals[i]] for i in range(len(distinct_vals))] 
    
    # Keep applying chimerge process as long as we reach the max_intervals condition
    while len(intervals) > 6:

        #Array to hold the chi values for this iteration
        chi = []

        #Calculate chi values for each consecutive intervals in this iteration
        for i in range(len(intervals)-1):

            # Get the observations of the attribute that falls between interval 
            obs0 = data[data[attr].between(intervals[i][0], intervals[i][1])]
            obs1 = data[data[attr].between(intervals[i+1][0], intervals[i+1][1])]
            
            count_0, count_1, expected_0, expected_1 = contingencyTable(obs0, obs1)
            chi_ = calculatechisquare(count_0, count_1, expected_0, expected_1)

            # Finally create a list of chi square values
            chi.append(chi_) 

        #Find the minimum chi for the current iteration
        min_chi = min(chi)  

        #Find the first index with minumum chi
        for i, v in enumerate(chi):
            if v == min_chi:
                min_chi_index = i # Find the index of the interval to be merged
                break
                
        new_intervals = mergerows(intervals, min_chi_index)

        #Start the chimerge with new set of merged intervals
        intervals = new_intervals
    return intervals

In [177]:
label='type'

for attr in ['sepal_l','sepal_w', 'petal_l', 'petal_w']:
    inter1 = chimerge(dataset, attr, label, 6)
    printintervalsandsplitpoints(inter1, attr)


Intervals for sepal_l
[[4.3, 4.8], [4.9, 5.4], [5.5, 5.7], [5.8, 6.2], [6.3, 7.0], [7.1, 7.9]]

Split points for sepal_l
4.3
4.9
5.5
5.8
6.3
7.1
***********************

Intervals for sepal_w
[[2.0, 2.2], [2.3, 2.4], [2.5, 2.8], [2.9, 2.9], [3.0, 3.3], [3.4, 4.4]]

Split points for sepal_w
2.0
2.3
2.5
2.9
3.0
3.4
***********************

Intervals for petal_l
[[1.0, 1.9], [3.0, 4.4], [4.5, 4.7], [4.8, 4.9], [5.0, 5.1], [5.2, 6.9]]

Split points for petal_l
1.0
3.0
4.5
4.8
5.0
5.2
***********************

Intervals for petal_w
[[0.1, 0.6], [1.0, 1.3], [1.4, 1.6], [1.7, 1.7], [1.8, 1.8], [1.9, 2.5]]

Split points for petal_w
0.1
1.0
1.4
1.7
1.8
1.9
***********************
