# conseption et realisation d'un dashbord et d'un modele de detection de fraude des données de la Direction des Grandes Entreprise

## Preparation des données

### Importation des modules 

In [38]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

### parametrages des modules 

In [2]:
sns.set_theme()
#pd.set_option('display.max_rows', None)  # Show all rows
pd.set_option('display.max_columns', None)  # Show all columns

### selection, affichage et filtrage des fichiers

In [3]:
TVA = pd.read_excel('VraiTVA.xlsx')
#TVA

In [4]:
ERA = pd.read_excel('VraiERA.xlsx')
#ERA

In [5]:
jointure = pd.merge(ERA, TVA, on='BP', how='inner')
#jointure = jointure[['BP', 'Wilaya', 'Code CNRC', 'Code ONS', 'Chiffre d’affaire (C.A)', 'Total TVA anuelle']]  

In [6]:
jointure.rename(columns={'Chiffre d’affaire (C.A)': 'ChAff'}, inplace=True)
jointure.rename(columns={'Total TVA anuelle': 'Total TVA anunelle'}, inplace=True)
#jointure.head()

### créer la colonne feature et cible


In [7]:
jointure['feature'] =  jointure['ChAff']  - jointure['Total TVA anunelle'] 

In [8]:
def estime(row):
    if row['ChAff'] == 0 :
        return 'sus'
    elif abs(row['feature']) > 0.2 * row['ChAff']:
        return 'fraude'
    else:
        return 'bon'

jointure['cible'] = jointure.apply(estime, axis=1)

In [9]:
jointure.head()

Unnamed: 0,BP,Wilaya,Code CNRC,Code ONS,ChAff,Total TVA Janvier,Total TVA Février,Total TVA Mars,Total TVA Avril,Total TVA Mai,Total TVA Juin,Total TVA Juillet,Total TVA Août,Total TVA Septembre,Total TVA Octobre,Total TVA Novembre,Total TVA Décembre,Total TVA anunelle,feature,cible
0,2000000147,DIW Alger est,607047,/,0,4381490000.0,4710011000.0,5358351000.0,4811742000.0,4835182000.0,4467203000.0,3692317000.0,3713033000.0,5274394000.0,6056911000.0,4706919000.0,5852010000.0,57859560000.0,-57859560000.0,sus
1,2000000200,DIW Alger est,608001,/,93115698,644181.0,1478702.0,2304256.0,1226741.0,21477610.0,33439220.0,9184847.0,4261077.0,3033041.0,4966790.0,5425418.0,5673810.0,93115700.0,0.0,bon
2,2000007178,DIW Bordj Bou Arréridj,110202,/,155423274,4872747.0,4275042.0,3918504.0,6082352.0,7774514.0,27067300.0,22407960.0,8738202.0,5113770.0,5666626.0,26437940.0,33068320.0,155423300.0,0.0,bon
3,2000009570,DIW Alger est,405105,/,5798168362,423208600.0,396558100.0,411129300.0,628866000.0,313280200.0,410063900.0,333267400.0,561521000.0,713319000.0,495072400.0,596409800.0,527251500.0,5809947000.0,-11778630.0,bon
4,2000011509,DIW Alger est,409001,/,2062602782,160768700.0,293782400.0,332108000.0,123468900.0,105158600.0,233051700.0,130518600.0,172088600.0,159896900.0,143759600.0,74881360.0,132886200.0,2062370000.0,233138.0,bon


## Over sample de la donnée fraude dans le data set

In [10]:

fraud_count = jointure['cible'].value_counts()
print(fraud_count)

#definir la donnée maximal et minimal
minority_cible = fraud_count.idxmin()
majority_cible = fraud_count.idxmax()

# Calcule du nombre de rows à add
oversample_amount = fraud_count[majority_cible] - fraud_count[minority_cible]



cible
bon       669
fraude    371
sus       230
Name: count, dtype: int64


In [11]:
# Filter the minority class
minority_data = jointure[jointure['cible'] == minority_cible]

# Randomly sample from the minority data
oversampled_data = minority_data.sample(n=oversample_amount, replace=True, random_state=42)

# Append the oversampled data to the original DataFrame
jointure_oversampled = pd.concat([jointure, oversampled_data])

# Shuffle the dataset to mix the rows up
jointure_oversampled = jointure_oversampled.sample(frac=1, random_state=42).reset_index(drop=True)


### Verification du nouveau data set

In [12]:
print(jointure_oversampled['cible'].value_counts())

cible
sus       669
bon       669
fraude    371
Name: count, dtype: int64


In [13]:
print(jointure_oversampled.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1709 entries, 0 to 1708
Data columns (total 20 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   BP                   1709 non-null   int64  
 1   Wilaya               1709 non-null   object 
 2   Code CNRC            1709 non-null   object 
 3   Code ONS             1709 non-null   object 
 4   ChAff                1709 non-null   int64  
 5   Total TVA Janvier    1709 non-null   float64
 6   Total TVA Février    1709 non-null   float64
 7   Total TVA Mars       1709 non-null   float64
 8   Total TVA Avril      1709 non-null   float64
 9   Total TVA Mai        1709 non-null   float64
 10  Total TVA Juin       1709 non-null   float64
 11  Total TVA Juillet    1709 non-null   float64
 12  Total TVA Août       1709 non-null   float64
 13  Total TVA Septembre  1709 non-null   float64
 14  Total TVA Octobre    1709 non-null   float64
 15  Total TVA Novembre   1709 non-null   f

## Encodage categoriel & Normalisation/Standardisation 

### encodage categoriel des wilaya

In [14]:
wilaya_count = jointure_oversampled['Wilaya'].value_counts()
print(wilaya_count)

Wilaya
DIW Alger centre          729
DIW Alger est             235
DIW Alger ouest           184
DIW Oran Est               80
DIW Ouargla                68
DIW Blida                  58
DIW Constantine            44
DIW Sétif                  34
DIW Boumerdès              30
DIW Béjaïa                 29
DIW Sidi Bel Abbes         23
DIW Annaba                 20
DIW Bordj Bou Arréridj     15
DIW Skikda                 14
DIW Batna                  14
DIW Tlemcen                13
DIW Tizi Ouzou             10
DIW Mila                    9
DIW Tipaza                  9
DIW Bouira                  8
DIW Chlef                   8
DIW Relizane                8
DIW Oran Ouest              7
DIW Mostaganem              7
DIW M'Sila                  7
DIW Jijel                   6
DIW Biskra                  6
Non affecté                 6
DIW Oum el-Bouaghi          5
DIW Médéa                   3
DIW Aïn Témouchent          3
DIW Mascara                 3
DIW Tiaret                  2
DIW

In [15]:
# Clean and standardize the Wilaya names
jointure_oversampled['Standardized_Wilaya'] = jointure_oversampled['Wilaya'].str.replace('DIW ', '')

# Label Encoding
wilaya_mapping = {name: i + 1 for i, name in enumerate(jointure_oversampled['Standardized_Wilaya'].unique())}
jointure_oversampled['Wilaya_encoded'] = jointure_oversampled['Standardized_Wilaya'].map(wilaya_mapping)

# Display the DataFrame to check the new columns
print(jointure_oversampled[['Wilaya', 'Standardized_Wilaya', 'Wilaya_encoded']])

                  Wilaya Standardized_Wilaya  Wilaya_encoded
0     DIW Sidi Bel Abbes      Sidi Bel Abbes               1
1       DIW Alger centre        Alger centre               2
2        DIW Alger ouest         Alger ouest               3
3          DIW Alger est           Alger est               4
4          DIW Alger est           Alger est               4
...                  ...                 ...             ...
1704    DIW Alger centre        Alger centre               2
1705    DIW Alger centre        Alger centre               2
1706     DIW Alger ouest         Alger ouest               3
1707     DIW Alger ouest         Alger ouest               3
1708        DIW Relizane            Relizane               6

[1709 rows x 3 columns]


### encodage categoriel des Codes ONS et CNRC

In [18]:
# Label Encoding for 'code CNRC'
cnrc_unique = pd.unique(jointure_oversampled['Code CNRC'])
code_cnrc_mapping = {code: idx + 1 for idx, code in enumerate(cnrc_unique)}
jointure_oversampled['code CNRC_encoded'] = jointure_oversampled['Code CNRC'].map(code_cnrc_mapping)

# Label Encoding for 'code ONS'
ons_unique = pd.unique(jointure_oversampled['Code ONS'])
code_ons_mapping = {code: idx + 1 for idx, code in enumerate(ons_unique)}
jointure_oversampled['code ONS_encoded'] = jointure_oversampled['Code ONS'].map(code_ons_mapping)

# Display the DataFrame to check the new columns
print(jointure_oversampled)

              BP              Wilaya Code CNRC Code ONS      ChAff  \
0     2000045399  DIW Sidi Bel Abbes         /        /          0   
1     2000045867    DIW Alger centre    102102     4321          0   
2     2000046488     DIW Alger ouest         /        /          0   
3     2000044699       DIW Alger est    103103        /  838526717   
4     2000045234       DIW Alger est         /        /          0   
...          ...                 ...       ...      ...        ...   
1704  2000046633    DIW Alger centre    613125     4321   22988920   
1705  2000046541    DIW Alger centre    613125     4321          0   
1706  2000045829     DIW Alger ouest         /        /  113184009   
1707  2000046552     DIW Alger ouest    613203     4329          0   
1708  2000046628        DIW Relizane    613203     4329          0   

      Total TVA Janvier  Total TVA Février  Total TVA Mars  Total TVA Avril  \
0          1.958626e+09       1.813410e+09    1.867400e+09     2.031599e+09   
1

### encodage categoriel de target ENFIN

In [20]:
target_mapping = {
    'bon': 0,      # Typically, the "normal" class can be zero
    'fraude': 1,   # Positive class can be one, especially if it's a binary classification
    'sus': 2       # Additional class can take the next integer
}

# Apply the mapping to the target column
jointure_oversampled['target_encoded'] = jointure_oversampled['cible'].map(target_mapping)

# Display the DataFrame to check the new column
print(jointure_oversampled[['cible', 'target_encoded']])

       cible  target_encoded
0        sus               2
1        sus               2
2        sus               2
3        bon               0
4        sus               2
...      ...             ...
1704  fraude               1
1705     sus               2
1706  fraude               1
1707     sus               2
1708     sus               2

[1709 rows x 2 columns]


### verification du Data set :SOB: + supression des columns not needed 

In [22]:
jointure_oversampled

Unnamed: 0,BP,Wilaya,Code CNRC,Code ONS,ChAff,Total TVA Janvier,Total TVA Février,Total TVA Mars,Total TVA Avril,Total TVA Mai,Total TVA Juin,Total TVA Juillet,Total TVA Août,Total TVA Septembre,Total TVA Octobre,Total TVA Novembre,Total TVA Décembre,Total TVA anunelle,feature,cible,Standardized_Wilaya,Wilaya_encoded,code CNRC_encoded,code ONS_encoded,target_encoded
0,2000045399,DIW Sidi Bel Abbes,/,/,0,1.958626e+09,1.813410e+09,1.867400e+09,2.031599e+09,2.153525e+09,2.191206e+09,1.305699e+09,1.180612e+09,1.526498e+09,2.681394e+09,1.298003e+09,2.209720e+09,2.221769e+10,-2.221769e+10,sus,Sidi Bel Abbes,1,1,1,2
1,2000045867,DIW Alger centre,102102,4321,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,sus,Alger centre,2,2,2,2
2,2000046488,DIW Alger ouest,/,/,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,sus,Alger ouest,3,1,1,2
3,2000044699,DIW Alger est,103103,/,838526717,7.900464e+07,6.374161e+07,8.521569e+07,6.320110e+07,4.088590e+07,6.831483e+07,1.348191e+08,7.720119e+07,8.304980e+07,9.323792e+07,6.313974e+07,5.412466e+07,9.059362e+08,-6.740948e+07,bon,Alger est,4,3,1,0
4,2000045234,DIW Alger est,/,/,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,sus,Alger est,4,1,1,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1704,2000046633,DIW Alger centre,613125,4321,22988920,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,2.954700e+06,2.954700e+06,2.003422e+07,fraude,Alger centre,2,7,2,1
1705,2000046541,DIW Alger centre,613125,4321,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,sus,Alger centre,2,7,2,2
1706,2000045829,DIW Alger ouest,/,/,113184009,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,1.131840e+08,fraude,Alger ouest,3,1,1,1
1707,2000046552,DIW Alger ouest,613203,4329,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,sus,Alger ouest,3,5,3,2


In [23]:
print(jointure_oversampled.columns)

Index(['BP', 'Wilaya', 'Code CNRC', 'Code ONS', 'ChAff', 'Total TVA Janvier',
       'Total TVA Février', 'Total TVA Mars', 'Total TVA Avril',
       'Total TVA Mai', 'Total TVA Juin', 'Total TVA Juillet',
       'Total TVA Août', 'Total TVA Septembre', 'Total TVA Octobre',
       'Total TVA Novembre', 'Total TVA Décembre', 'Total TVA anunelle',
       'feature', 'cible', 'Standardized_Wilaya', 'Wilaya_encoded',
       'code CNRC_encoded', 'code ONS_encoded', 'target_encoded'],
      dtype='object')


In [24]:
jointure_oversampled = jointure_oversampled.drop(columns = ['Wilaya', 'Code CNRC', 'Code ONS','cible'])

In [26]:
jointure_oversampled = jointure_oversampled.drop(columns = ['Standardized_Wilaya'])

In [27]:
jointure_oversampled

Unnamed: 0,BP,ChAff,Total TVA Janvier,Total TVA Février,Total TVA Mars,Total TVA Avril,Total TVA Mai,Total TVA Juin,Total TVA Juillet,Total TVA Août,Total TVA Septembre,Total TVA Octobre,Total TVA Novembre,Total TVA Décembre,Total TVA anunelle,feature,Wilaya_encoded,code CNRC_encoded,code ONS_encoded,target_encoded
0,2000045399,0,1.958626e+09,1.813410e+09,1.867400e+09,2.031599e+09,2.153525e+09,2.191206e+09,1.305699e+09,1.180612e+09,1.526498e+09,2.681394e+09,1.298003e+09,2.209720e+09,2.221769e+10,-2.221769e+10,1,1,1,2
1,2000045867,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,2,2,2,2
2,2000046488,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,3,1,1,2
3,2000044699,838526717,7.900464e+07,6.374161e+07,8.521569e+07,6.320110e+07,4.088590e+07,6.831483e+07,1.348191e+08,7.720119e+07,8.304980e+07,9.323792e+07,6.313974e+07,5.412466e+07,9.059362e+08,-6.740948e+07,4,3,1,0
4,2000045234,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,4,1,1,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1704,2000046633,22988920,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,2.954700e+06,2.954700e+06,2.003422e+07,2,7,2,1
1705,2000046541,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,2,7,2,2
1706,2000045829,113184009,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,1.131840e+08,3,1,1,1
1707,2000046552,0,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,3,5,3,2


### normalisation et standardisation 
 vu que nous allons faire a random forest and/or a regression tree, il serait preferable de standardiser les données 

In [31]:
jointure_oversampled.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1709 entries, 0 to 1708
Data columns (total 20 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   BP                   1709 non-null   int64  
 1   ChAff                1709 non-null   int64  
 2   Total TVA Janvier    1709 non-null   float64
 3   Total TVA Février    1709 non-null   float64
 4   Total TVA Mars       1709 non-null   float64
 5   Total TVA Avril      1709 non-null   float64
 6   Total TVA Mai        1709 non-null   float64
 7   Total TVA Juin       1709 non-null   float64
 8   Total TVA Juillet    1709 non-null   float64
 9   Total TVA Août       1709 non-null   float64
 10  Total TVA Septembre  1709 non-null   float64
 11  Total TVA Octobre    1709 non-null   float64
 12  Total TVA Novembre   1709 non-null   float64
 13  Total TVA Décembre   1709 non-null   float64
 14  Total TVA anunelle   1709 non-null   float64
 15  feature              1709 non-null   f

In [32]:
columns_to_standardize = [
    'ChAff',
    'Total TVA Janvier', 'Total TVA Février', 'Total TVA Mars', 'Total TVA Avril',
    'Total TVA Mai', 'Total TVA Juin', 'Total TVA Juillet', 'Total TVA Août',
    'Total TVA Septembre', 'Total TVA Octobre', 'Total TVA Novembre', 'Total TVA Décembre',
    'Total TVA anunelle'
]

# Apply standardization
for column in columns_to_standardize:
    jointure_oversampled[column] = (jointure_oversampled[column] - jointure_oversampled[column].mean()) / jointure_oversampled[column].std()

# Now your columns are standardized
print(jointure_oversampled[columns_to_standardize].head())

      ChAff  Total TVA Janvier  Total TVA Février  Total TVA Mars  \
0 -0.055971           0.574705           0.645778        0.581130   
1 -0.055971          -0.168711          -0.183681       -0.193220   
2 -0.055971          -0.168711          -0.183681       -0.193220   
3 -0.052058          -0.138724          -0.154525       -0.157884   
4 -0.055971          -0.168711          -0.183681       -0.193220   

   Total TVA Avril  Total TVA Mai  Total TVA Juin  Total TVA Juillet  \
0         0.655551       0.652536        0.719614           0.390718   
1        -0.174768      -0.177589       -0.186745          -0.175976   
2        -0.174768      -0.177589       -0.186745          -0.175976   
3        -0.148938      -0.161828       -0.158487          -0.117463   
4        -0.174768      -0.177589       -0.186745          -0.175976   

   Total TVA Août  Total TVA Septembre  Total TVA Octobre  Total TVA Novembre  \
0        0.309813             0.455276           0.892869            0.

In [33]:
jointure_oversampled.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1709 entries, 0 to 1708
Data columns (total 20 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   BP                   1709 non-null   int64  
 1   ChAff                1709 non-null   float64
 2   Total TVA Janvier    1709 non-null   float64
 3   Total TVA Février    1709 non-null   float64
 4   Total TVA Mars       1709 non-null   float64
 5   Total TVA Avril      1709 non-null   float64
 6   Total TVA Mai        1709 non-null   float64
 7   Total TVA Juin       1709 non-null   float64
 8   Total TVA Juillet    1709 non-null   float64
 9   Total TVA Août       1709 non-null   float64
 10  Total TVA Septembre  1709 non-null   float64
 11  Total TVA Octobre    1709 non-null   float64
 12  Total TVA Novembre   1709 non-null   float64
 13  Total TVA Décembre   1709 non-null   float64
 14  Total TVA anunelle   1709 non-null   float64
 15  feature              1709 non-null   f

In [47]:
X = jointure_oversampled.drop(columns=['target_encoded'])
y = jointure_oversampled['target_encoded']

In [48]:
import numpy as np

class RegressionTreeNode:
    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value

    def is_leaf_node(self):
        return self.value is not None

class RegressionTree:
    def __init__(self, min_samples_split=10, max_depth=5):
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.root = None

    def fit(self, X, y):
        X = np.array(X)
        y = np.array(y)
        self.root = self._grow_tree(X, y)

    def _grow_tree(self, X, y, depth=0):
        num_samples, num_features = X.shape
        if depth >= self.max_depth or num_samples < self.min_samples_split or np.unique(y).size == 1:
            return RegressionTreeNode(value=np.mean(y))

        best_feature, best_threshold = self._best_split(X, y)
        if best_feature is None:
            return RegressionTreeNode(value=np.mean(y))

        left_idxs, right_idxs = self._split(X[:, best_feature], best_threshold)
        if len(left_idxs) == 0 or len(right_idxs) == 0:
            return RegressionTreeNode(value=np.mean(y))

        left = self._grow_tree(X[left_idxs, :], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs, :], y[right_idxs], depth + 1)
        return RegressionTreeNode(feature_index=best_feature, threshold=best_threshold, left=left, right=right)

    def _best_split(self, X, y):
        min_mse = np.inf
        best_feature, best_threshold = None, None
        num_samples, num_features = X.shape

        for feature_index in range(num_features):
            thresholds = np.unique(X[:, feature_index])
            for threshold in thresholds:
                mse = self._calculate_mse(X[:, feature_index], y, threshold)
                if mse < min_mse:
                    min_mse = mse
                    best_feature = feature_index
                    best_threshold = threshold

        return best_feature, best_threshold

    def _calculate_mse(self, X_feature, y, threshold):
        left_mask = X_feature < threshold
        right_mask = X_feature >= threshold
        if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
            return np.inf
        left_mse = np.var(y[left_mask]) * np.sum(left_mask)
        right_mse = np.var(y[right_mask]) * np.sum(right_mask)
        mse = (left_mse + right_mse) / len(y)
        return mse
        
    def _split(self, feature, threshold):
        left = np.where(feature < threshold)[0]
        right = np.where(feature >= threshold)[0]
        if len(left) == 0 or len(right) == 0:
            return [], []  # Prevent creating nodes with no data
        return left, right

    def predict(self, X):
        return np.array([self._traverse_tree(x, self.root) for x in np.array(X)])

    def _traverse_tree(self, x, node):
        if node.is_leaf_node():
            return node.value
        if x[node.feature_index] < node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)


In [50]:

# Convert DataFrame to numpy array for easier manipulation
X = jointure_oversampled.drop(columns=['target_encoded']).values
y = jointure_oversampled['target_encoded'].values

# Splitting the dataset into training and testing sets
np.random.seed(42)  # for reproducibility
indices = np.arange(X.shape[0])
np.random.shuffle(indices)

# Calculate the split index
split_idx = int(0.8 * len(X))

# Create training and testing sets
X_train, X_test = X[indices[:split_idx]], X[indices[split_idx:]]
y_train, y_test = y[indices[:split_idx]], y[indices[split_idx:]]

# Now create the Regression Tree and fit it on the training data
tree = RegressionTree(max_depth=5, min_samples_split=10)
tree.fit(X_train, y_train)

# Making predictions on the test set
predictions = tree.predict(X_test)

# Optionally, calculate and print the mean squared error on the test set
mse = np.mean((predictions - y_test)**2)
print(f"Mean Squared Error on Test Set: {mse:.2f}")


Mean Squared Error on Test Set: 0.04


In [51]:
from sklearn.metrics import mean_squared_error, r2_score

def regression_results(y_true, y_pred):
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_true, y_pred)
    
    print("Regression Model Evaluation")
    print("----------------------------")
    print(f"Mean Squared Error (MSE): {mse:.2f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.2f}")
    print(f"R-squared: {r2:.2f}")

# Assuming 'y_test' and 'predictions' are already defined as your test targets and model predictions, respectively
regression_results(y_test, predictions)


Regression Model Evaluation
----------------------------
Mean Squared Error (MSE): 0.04
Root Mean Squared Error (RMSE): 0.21
R-squared: 0.94
