<a href="https://colab.research.google.com/github/prince-manish/C/blob/main/chi_merge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
from collections import Counter
import numpy as np

In [2]:
# Location of the raw UCI iris data file; it ships without a header row.
IRIS_URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
iris = pd.read_csv(IRIS_URL, header=None)

In [3]:
# Label the columns: four numeric measurements followed by the class column.
column_names = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w', 'type']
iris.columns = column_names


In [4]:
# Preview the first five rows as a quick sanity check of the loaded frame.
iris.head(n=5)

Unnamed: 0,sepal_l,sepal_w,petal_l,petal_w,type
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [5]:
def _pair_chi2(counts_a, counts_b):
    """Chi-square statistic of the 2 x n_classes table formed by two intervals.

    ``counts_a`` and ``counts_b`` are 1-D integer arrays holding the per-class
    observation counts of two adjacent intervals (same class order in both).
    Each interval must contain at least one observation.
    """
    class_totals = counts_a + counts_b
    # A class that occurs in neither interval has an expected count of zero and
    # contributes nothing to the statistic. Dropping it up front avoids the 0/0
    # division (and the RuntimeWarning that comes with it).
    present = class_totals > 0
    row_a = counts_a[present]
    row_b = counts_b[present]
    class_totals = class_totals[present]
    grand_total = class_totals.sum()
    expected_a = class_totals * row_a.sum() / grand_total
    expected_b = class_totals * row_b.sum() / grand_total
    terms = (row_a - expected_a) ** 2 / expected_a + (row_b - expected_b) ** 2 / expected_b
    return float(terms.sum())


def chimerge(data, attr, label, max_intervals):
    """Discretise a numeric column with the bottom-up ChiMerge procedure.

    Every distinct value of ``data[attr]`` starts as its own interval. The
    adjacent pair of intervals with the smallest chi-square statistic (i.e. the
    most similar class distributions) is merged repeatedly until at most
    ``max_intervals`` intervals remain. When several pairs tie for the minimum,
    the left-most pair is merged.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing both the attribute and the class-label column.
    attr : column label
        Column to discretise. Rows where this column or ``label`` is missing
        are ignored.
    label : column label
        Column holding the class of each row.
    max_intervals : int
        Upper bound on the number of intervals to keep; must be at least 1.

    Returns
    -------
    list of [low, high]
        The final closed intervals in ascending order. Each interval is also
        printed on its own line as ``[low,high]``.

    Raises
    ------
    ValueError
        If ``max_intervals`` is smaller than 1.
    """
    if max_intervals < 1:
        raise ValueError('max_intervals must be at least 1, got %r' % (max_intervals,))

    observed = data[[attr, label]].dropna()
    # Class counts per distinct attribute value, computed once: rows are the
    # distinct values in ascending order, columns are the class labels.
    table = pd.crosstab(observed[attr], observed[label])
    intervals = [[value, value] for value in table.index.tolist()]
    # One count vector per interval, kept in step with `intervals`.
    counts = list(table.to_numpy())

    while len(intervals) > max_intervals:
        chi = [_pair_chi2(counts[i], counts[i + 1]) for i in range(len(intervals) - 1)]
        merge_at = chi.index(min(chi))  # first occurrence wins ties
        # Replace the pair with a single interval spanning both; its class
        # counts are simply the sum of the two count vectors.
        intervals[merge_at:merge_at + 2] = [[intervals[merge_at][0], intervals[merge_at + 1][1]]]
        counts[merge_at:merge_at + 2] = [counts[merge_at] + counts[merge_at + 1]]

    for low, high in intervals:
        print('[', low, ',', high, ']', sep='')
    return intervals

In [6]:
# Discretise each measurement column against the 'type' class column,
# merging until at most six intervals remain per column.
feature_columns = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w']
for attr in feature_columns:
    print('Interval for', attr)
    chimerge(iris, attr, 'type', 6)


Interval for sepal_l




[4.3,4.8]
[4.9,4.9]
[5.0,5.4]
[5.5,5.7]
[5.8,7.0]
[7.1,7.9]
Interval for sepal_w
[2.0,2.2]
[2.3,2.4]
[2.5,2.8]
[2.9,2.9]
[3.0,3.3]
[3.4,4.4]
Interval for petal_l
[1.0,1.9]
[3.0,4.4]
[4.5,4.7]
[4.8,4.9]
[5.0,5.1]
[5.2,6.9]
Interval for petal_w
[0.1,0.6]
[1.0,1.3]
[1.4,1.6]
[1.7,1.7]
[1.8,1.8]
[1.9,2.5]
