In [14]:
#import necessary libraries
import pandas as pd
from collections import Counter
import numpy as np


In [15]:
#Read the iris dataset into a pandas DataFrame. The raw file has no header row,
#so header=None keeps pandas from treating the first sample as column names and
#`names` assigns the feature/label column names in the same call.
#HTTPS is used so the downloaded data cannot be tampered with in transit.
IRIS_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
iris = pd.read_csv(IRIS_URL, header=None,
                   names=['sepal_l', 'sepal_w', 'petal_l', 'petal_w', 'type'])

#Fail fast if the download came back incomplete or with missing values
assert iris.notna().all().all(), 'iris data contains missing values'

#Ensure data is loaded properly
iris.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   sepal_l  150 non-null    float64
 1   sepal_w  150 non-null    float64
 2   petal_l  150 non-null    float64
 3   petal_w  150 non-null    float64
 4   type     150 non-null    object 
dtypes: float64(4), object(1)
memory usage: 6.0+ KB


In [16]:
#Actual chimerge algorithm
def _pair_chi2(count_0, count_1):
    """Return the chi-square statistic of two adjacent intervals.

    ``count_0`` and ``count_1`` are integer arrays holding the per-label row
    counts of each interval, in the same label order. Each interval is
    expected to contain at least one row.
    """
    count_total = count_0 + count_1
    n_0 = count_0.sum()
    n_1 = count_1.sum()
    total = n_0 + n_1

    # A label absent from both intervals has expected count 0 and yields 0/0;
    # silence that warning locally instead of changing numpy's global state.
    with np.errstate(divide='ignore', invalid='ignore'):
        expected_0 = count_total*n_0/total
        expected_1 = count_total*n_1/total
        chi_ = (count_0 - expected_0)**2/expected_0 + (count_1 - expected_1)**2/expected_1

    # The 0/0 (NaN) terms contribute nothing to the statistic
    chi_ = np.nan_to_num(chi_)
    return sum(chi_)


def chimerge(data, attr, label, max_intervals):
    """Discretise a numeric column with the ChiMerge algorithm and print the result.

    Every distinct value of ``data[attr]`` starts as its own interval. The
    adjacent pair with the smallest chi-square statistic (the first such pair
    on ties) is merged repeatedly until at most ``max_intervals`` intervals
    remain. Rows with a missing ``attr`` or ``label`` value are ignored.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the feature column and the class-label column.
    attr : str
        Name of the column to discretise.
    label : str
        Name of the class-label column.
    max_intervals : int
        Stopping criterion: the maximum number of intervals to keep (>= 1).

    Returns
    -------
    list of [low, high]
        The final intervals in ascending order. Both bounds are inclusive and
        are values that occur in ``data[attr]``.

    Raises
    ------
    ValueError
        If ``max_intervals`` is less than 1.
    """
    if max_intervals < 1:
        raise ValueError(f'max_intervals must be at least 1, got {max_intervals!r}')

    # NaN can neither be ordered nor placed in an interval, so drop those rows
    rows = data[[attr, label]].dropna()

    #Get the distinct sorted values of the feature and the distinct sorted labels
    distinct_vals = sorted(set(rows[attr]))
    labels = sorted(set(rows[label]))
    label_index = {lab: k for k, lab in enumerate(labels)}

    # Tally the per-label counts of every distinct value once. Merging two
    # intervals is then just adding their count vectors, with no re-filtering of `data`.
    value_counts = {v: np.zeros(len(labels), dtype=int) for v in distinct_vals}
    for v, lab in zip(rows[attr], rows[label]):
        value_counts[v][label_index[lab]] += 1

    # Start with one interval per distinct value; counts[i] belongs to intervals[i]
    intervals = [[v, v] for v in distinct_vals]
    counts = [value_counts[v] for v in distinct_vals]

    # Keep merging until the max_intervals condition is met
    while len(intervals) > max_intervals:
        chi = [_pair_chi2(counts[i], counts[i + 1]) for i in range(len(intervals) - 1)]

        # list.index returns the first occurrence, so ties merge the lowest pair
        k = chi.index(min(chi))

        # Replace the pair (k, k+1) by their union, keeping the counts in step
        intervals[k:k + 2] = [[intervals[k][0], intervals[k + 1][1]]]
        counts[k:k + 2] = [counts[k] + counts[k + 1]]

    #Print split points for the given attribute
    print('\nSplit points for',attr)
    for low, _ in intervals:
        print(low)

    #Print intervals for the given attribute
    print('Intervals for', attr)
    for low, high in intervals:
        print('[', low, ',', high, ']', sep='')

    return intervals
        

In [17]:
#Discretise every iris feature with ChiMerge; merging stops once a feature has at most 6 intervals
for attr in ('sepal_l', 'sepal_w', 'petal_l', 'petal_w'):
    chimerge(iris, attr, 'type', 6)


Split points for sepal_l
4.3
4.9
5.0
5.5
5.8
7.1
Intervals for sepal_l
[4.3,4.8]
[4.9,4.9]
[5.0,5.4]
[5.5,5.7]
[5.8,7.0]
[7.1,7.9]

Split points for sepal_w
2.0
2.3
2.5
2.9
3.0
3.4
Intervals for sepal_w
[2.0,2.2]
[2.3,2.4]
[2.5,2.8]
[2.9,2.9]
[3.0,3.3]
[3.4,4.4]

Split points for petal_l
1.0
3.0
4.5
4.8
5.0
5.2
Intervals for petal_l
[1.0,1.9]
[3.0,4.4]
[4.5,4.7]
[4.8,4.9]
[5.0,5.1]
[5.2,6.9]

Split points for petal_w
0.1
1.0
1.4
1.7
1.8
1.9
Intervals for petal_w
[0.1,0.6]
[1.0,1.3]
[1.4,1.6]
[1.7,1.7]
[1.8,1.8]
[1.9,2.5]
