In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

COLS_TO_DROP = ['Index', 'First Name', 'Last Name', 'Birthday', 'Defense Against the Dark Arts', 'Arithmancy', 'Care of Magical Creatures', 'Birthday Month', 'Birthday Year', 'Birthday Weekday']
NUMERICAL_COLS = ['Astronomy', 'Herbology', 'Divination', 'Muggle Studies', 'Ancient Runes', 'Charms', 'Potions', 'Transfiguration', 'History of Magic', 'Flying']
CATEGORICAL_COLS = ['Best Hand']
LEARNING_RATE = 0.02
TARGET_CLASSES = ['Slytherin', 'Hufflepuff', 'Gryffindor', 'Ravenclaw']
EPOCHS = 1000

In [2]:
data = pd.read_csv("datasets/dataset_train.csv")
data['Birthday'] = pd.to_datetime(data['Birthday'])
data['Birthday Weekday'] = data['Birthday'].dt.dayofweek
data['Birthday Year'] = data['Birthday'].dt.year
data['Birthday Month'] = data['Birthday'].dt.month
data = data.drop(columns=COLS_TO_DROP)
data

Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
0,Ravenclaw,Left,-487.886086,5.727180,4.722,272.035831,532.484226,5.231058,1039.788281,3.790369,-232.79405,-26.89
1,Slytherin,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45
2,Ravenclaw,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42
3,Gryffindor,Left,697.742809,-6.497214,4.026,-537.001128,523.982133,-4.809637,920.391449,0.821911,-256.84675,200.64
4,Gryffindor,Left,436.775204,-7.820623,2.236,-444.262537,599.324514,-3.444377,937.434724,4.311066,-256.38730,157.98
...,...,...,...,...,...,...,...,...,...,...,...,...
1595,Gryffindor,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83
1596,Slytherin,Left,367.531174,6.061064,1.757,-643.271092,445.827565,2.238112,1056.147366,5.825263,-246.42719,44.80
1597,Gryffindor,Right,544.018925,-3.203269,6.065,-385.150457,635.211486,-5.984257,953.866685,1.709808,-251.63679,198.47
1598,Hufflepuff,Left,453.676219,3.442831,6.738,-831.741123,383.444937,3.813111,1087.949205,3.904100,-246.19072,-76.81


In [3]:
def ft_train_test_split(data, test_size=0.25, stratify_col=None, random_state=None):
    if not stratify_col:
        data_train = data.sample(frac=1-test_size, random_state=random_state)
    else:
        groups = data.groupby(stratify_col).groups
        data_train = pd.DataFrame()
        for group in groups.values():
            group_sample = data.iloc[group].sample(frac=1-test_size, random_state=random_state)
            data_train = pd.concat([data_train, group_sample])
    data_test = data.iloc[data.index.difference(data_train.index)]
    return (data_train, data_test)

In [4]:
data_train, data_test = ft_train_test_split(data, stratify_col='Hogwarts House')
display(data_train)
display(data_test)

Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
263,Gryffindor,Right,237.510101,-4.884012,5.614,-697.136575,607.395493,-6.969016,940.452086,0.806860,-251.18317,220.86
478,Gryffindor,Right,278.462545,-6.279350,5.109,-296.834007,563.656879,-5.136535,945.247235,-0.427582,-252.16262,142.10
297,Gryffindor,Left,447.409851,-5.586751,4.649,-596.427045,560.805587,-5.753091,939.794031,1.309745,-253.63740,202.40
28,Gryffindor,Left,236.888879,-5.077751,5.517,-544.758909,604.096378,-4.955320,977.896778,4.377129,-249.38786,173.06
1042,Gryffindor,Left,420.547643,-6.949362,4.202,-582.448533,569.204830,-4.035152,940.895902,3.690323,-253.59301,185.49
...,...,...,...,...,...,...,...,...,...,...,...,...
897,Slytherin,Right,-376.775355,-1.942163,-5.949,-504.131371,432.158807,2.606439,1042.283637,8.870901,-250.31200,6.62
1470,Slytherin,Left,-344.639322,-6.710850,-5.191,-273.396055,447.028599,3.414097,1048.417290,9.960919,-252.41067,-59.42
72,Slytherin,Right,-508.019034,-4.079019,-5.725,,407.140184,4.735857,1052.071559,10.657080,-249.20965,-39.65
13,Slytherin,Left,-544.192049,-7.308856,-6.180,-319.946875,391.652916,2.914732,1082.581409,10.948791,-251.12516,-80.42


Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
1,Slytherin,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45
2,Ravenclaw,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42
6,Gryffindor,Left,628.046051,-4.861976,,-926.892512,583.742442,-7.322486,923.539573,1.646666,-257.83447,261.55
14,Ravenclaw,Right,-197.527318,2.742444,6.603,527.356323,605.590600,5.480097,1063.522361,9.407484,-232.65964,-19.94
15,Ravenclaw,Left,-447.649812,4.046727,4.949,810.154483,615.531088,3.653495,1075.853850,9.622899,-229.38229,17.00
...,...,...,...,...,...,...,...,...,...,...,...,...
1589,Hufflepuff,Left,708.202206,4.850931,5.660,-504.777873,417.520448,4.568628,1046.345436,6.272776,-245.03263,28.70
1593,Ravenclaw,Left,-426.175401,5.681107,6.205,473.879478,647.238809,6.254227,1046.815627,7.206156,-230.80139,-29.82
1594,Hufflepuff,Left,599.901612,5.479485,5.543,-525.883264,467.950418,5.933211,1034.394428,9.344054,-241.96940,71.23
1595,Gryffindor,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83


In [5]:
X_train = data_train.drop(columns=['Hogwarts House'])
data_train['Slytherin'] = (data_train['Hogwarts House'] == 'Slytherin').astype(int)
data_train['Hufflepuff'] = (data_train['Hogwarts House'] == 'Hufflepuff').astype(int)
data_train['Gryffindor'] = (data_train['Hogwarts House'] == 'Gryffindor').astype(int)
data_train['Ravenclaw'] = (data_train['Hogwarts House'] == 'Ravenclaw').astype(int)
Y_train = data_train[TARGET_CLASSES]
display(X_train)
display(Y_train)

Unnamed: 0,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
263,Right,237.510101,-4.884012,5.614,-697.136575,607.395493,-6.969016,940.452086,0.806860,-251.18317,220.86
478,Right,278.462545,-6.279350,5.109,-296.834007,563.656879,-5.136535,945.247235,-0.427582,-252.16262,142.10
297,Left,447.409851,-5.586751,4.649,-596.427045,560.805587,-5.753091,939.794031,1.309745,-253.63740,202.40
28,Left,236.888879,-5.077751,5.517,-544.758909,604.096378,-4.955320,977.896778,4.377129,-249.38786,173.06
1042,Left,420.547643,-6.949362,4.202,-582.448533,569.204830,-4.035152,940.895902,3.690323,-253.59301,185.49
...,...,...,...,...,...,...,...,...,...,...,...
897,Right,-376.775355,-1.942163,-5.949,-504.131371,432.158807,2.606439,1042.283637,8.870901,-250.31200,6.62
1470,Left,-344.639322,-6.710850,-5.191,-273.396055,447.028599,3.414097,1048.417290,9.960919,-252.41067,-59.42
72,Right,-508.019034,-4.079019,-5.725,,407.140184,4.735857,1052.071559,10.657080,-249.20965,-39.65
13,Left,-544.192049,-7.308856,-6.180,-319.946875,391.652916,2.914732,1082.581409,10.948791,-251.12516,-80.42


Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
263,0,0,1,0
478,0,0,1,0
297,0,0,1,0
28,0,0,1,0
1042,0,0,1,0
...,...,...,...,...
897,1,0,0,0
1470,1,0,0,0
72,1,0,0,0
13,1,0,0,0


In [6]:
X_test = data_test.drop(columns=['Hogwarts House'])
data_test['Slytherin'] = (data_test['Hogwarts House'] == 'Slytherin').astype(int)
data_test['Hufflepuff'] = (data_test['Hogwarts House'] == 'Hufflepuff').astype(int)
data_test['Gryffindor'] = (data_test['Hogwarts House'] == 'Gryffindor').astype(int)
data_test['Ravenclaw'] = (data_test['Hogwarts House'] == 'Ravenclaw').astype(int)
Y_test = data_test[TARGET_CLASSES]
display(X_test)
display(Y_test)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_test['Slytherin'] = (data_test['Hogwarts House'] == 'Slytherin').astype(int)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_test['Hufflepuff'] = (data_test['Hogwarts House'] == 'Hufflepuff').astype(int)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_test['Gryffindor'] = (data_test['H

Unnamed: 0,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
1,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45
2,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42
6,Left,628.046051,-4.861976,,-926.892512,583.742442,-7.322486,923.539573,1.646666,-257.83447,261.55
14,Right,-197.527318,2.742444,6.603,527.356323,605.590600,5.480097,1063.522361,9.407484,-232.65964,-19.94
15,Left,-447.649812,4.046727,4.949,810.154483,615.531088,3.653495,1075.853850,9.622899,-229.38229,17.00
...,...,...,...,...,...,...,...,...,...,...,...
1589,Left,708.202206,4.850931,5.660,-504.777873,417.520448,4.568628,1046.345436,6.272776,-245.03263,28.70
1593,Left,-426.175401,5.681107,6.205,473.879478,647.238809,6.254227,1046.815627,7.206156,-230.80139,-29.82
1594,Left,599.901612,5.479485,5.543,-525.883264,467.950418,5.933211,1034.394428,9.344054,-241.96940,71.23
1595,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83


Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
1,1,0,0,0
2,0,0,0,1
6,0,0,1,0
14,0,0,0,1
15,0,0,0,1
...,...,...,...,...
1589,0,1,0,0
1593,0,0,0,1
1594,0,1,0,0
1595,0,0,1,0


# Missing values

In [7]:
X = pd.concat([X_train, X_test])
Y = pd.concat([Y_train, Y_test])

In [8]:
X.isna().sum() / len(X) * 100

Best Hand           0.0000
Astronomy           2.0000
Herbology           2.0625
Divination          2.4375
Muggle Studies      2.1875
Ancient Runes       2.1875
History of Magic    2.6875
Transfiguration     2.1250
Potions             1.8750
Charms              0.0000
Flying              0.0000
dtype: float64

In [9]:
Y.isna().sum() / len(Y) * 100

Slytherin     0.0
Hufflepuff    0.0
Gryffindor    0.0
Ravenclaw     0.0
dtype: float64

In [10]:
(Y['Slytherin'] + Y['Gryffindor'] + Y['Ravenclaw'] + Y['Hufflepuff']).value_counts()

1    1600
Name: count, dtype: int64

# Preprocessings

In [11]:
from logreg_train import SimpleImputer, StandardScaler, OneHotEncoder, PreprocessorPipeline

imputer = SimpleImputer(NUMERICAL_COLS, CATEGORICAL_COLS)
scaler = StandardScaler(NUMERICAL_COLS)
ohe = OneHotEncoder(CATEGORICAL_COLS)
preprocessor = PreprocessorPipeline([imputer, scaler, ohe])
preprocessor

--- SimpleImputer ---
Means: None
Modes: None

--- StandardScaler ---
Means: None
Standard Deviations: None

--- OneHotEncoder ---
Columns mapping: {}
Drop Last: True


In [12]:
preprocessor.fit(X_train)
preprocessor

--- SimpleImputer ---
Means: {'Astronomy': 39.38310654128731, 'Herbology': 1.2042533235476416, 'Divination': 3.1663233788395906, 'Muggle Studies': -226.36888113323383, 'Ancient Runes': 496.3469566811535, 'Charms': -243.3103696833333, 'Potions': 5.971364787216205, 'Transfiguration': 1030.3923958342505, 'History of Magic': 2.9344002836301355, 'Flying': 22.782649999999997}
Modes: {'Best Hand': 'Right'}

--- StandardScaler ---
Means: {'Astronomy': 39.38310654128731, 'Herbology': 1.2042533235476416, 'Divination': 3.1663233788395906, 'Muggle Studies': -226.36888113323383, 'Ancient Runes': 496.3469566811535, 'Charms': -243.3103696833333, 'Potions': 5.971364787216205, 'Transfiguration': 1030.3923958342505, 'History of Magic': 2.9344002836301355, 'Flying': 22.782649999999997}
Standard Deviations: {'Astronomy': 518.7299496753959, 'Herbology': 5.182707092974763, 'Divination': 4.150225205421095, 'Muggle Studies': 487.72804469483935, 'Ancient Runes': 106.34795776747801, 'Charms': 8.778864852517279,

In [13]:
X_train_preprocessed = preprocessor.transform(X_train)
X_train_preprocessed

Unnamed: 0,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,Best Hand_Left
263,0.381946,-1.174727,0.589770,-0.965226,1.044200,-2.240281,-2.046265,-1.646574,-0.896790,2.033399,0
478,0.460894,-1.443956,0.468089,-0.144476,0.632922,-1.825750,-1.937169,-2.040146,-1.008359,1.224874,0
297,0.786588,-1.310320,0.357252,-0.758739,0.606111,-1.965223,-2.061237,-1.486242,-1.176351,1.843894,1
28,0.380749,-1.212109,0.566397,-0.652802,1.013178,-1.784757,-1.194347,-0.508283,-0.692287,1.542699,1
1042,0.734803,-1.573235,0.249547,-0.730078,0.685090,-1.576603,-2.036167,-0.727254,-1.171295,1.670302,1
...,...,...,...,...,...,...,...,...,...,...,...
897,-0.802264,-0.607099,-2.196344,-0.569503,-0.603567,-0.074189,0.270542,0.924445,-0.797555,-0.165921,0
1470,-0.740313,-1.527214,-2.013704,-0.096421,-0.463745,0.108514,0.410091,1.271971,-1.036615,-0.843866,1
72,-1.055274,-1.019404,-2.142371,0.000000,-0.838820,0.407513,0.493231,1.493924,-0.671987,-0.640914,0
13,-1.125008,-1.642599,-2.252004,-0.191865,-0.984448,-0.004449,1.187371,1.586929,-0.890182,-1.059445,1


In [14]:
from logreg_train import SortingHat

sorting_hat = SortingHat(X_train_preprocessed.shape[1], lr=LEARNING_RATE)
sorting_hat

<logreg_train.SortingHat at 0x7f0cb1af6d40>

In [15]:
sorting_hat.parameters

{'Slytherin': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'Hufflepuff': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'Gryffindor': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'Ravenclaw': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

In [16]:
X_test_preprocessed = preprocessor.transform(X_test)
X_test_preprocessed

Unnamed: 0,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,Best Hand_Left
1,-1.140176,-1.387634,-2.115144,-0.535076,-1.209113,0.265296,0.649601,0.407260,-1.010823,-1.398521,0
2,-0.781638,1.258177,0.716510,1.827376,0.998969,0.140519,1.312662,0.879054,1.818882,0.078403,1
6,1.134816,-1.170475,0.000000,-1.436300,0.821788,-2.320241,-2.431048,-1.378823,-1.654439,2.451109,1
14,-0.456712,0.296793,0.828070,1.545380,1.027228,0.575870,0.753752,1.095522,1.213224,-0.438577,0
15,-0.938895,0.548453,0.429537,2.125208,1.120700,0.162668,1.034310,1.164201,1.586547,-0.059363,1
...,...,...,...,...,...,...,...,...,...,...,...
1589,1.289340,0.703624,0.600853,-0.570828,-0.741213,0.369683,0.362953,0.096097,-0.196183,0.060746,1
1593,-0.897497,0.863806,0.732172,1.435735,1.418850,0.750988,0.373651,0.393682,1.424897,-0.540002,1
1594,1.080559,0.824903,0.572662,-0.614101,-0.267015,0.678370,0.091052,1.075298,0.152750,0.497345,1
1595,0.607054,-1.108704,0.610973,-0.555363,1.145986,-1.847282,-1.505516,-0.823296,-0.806897,1.673792,0


In [17]:
Y_train

Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
263,0,0,1,0
478,0,0,1,0
297,0,0,1,0
28,0,0,1,0
1042,0,0,1,0
...,...,...,...,...
897,1,0,0,0
1470,1,0,0,0
72,1,0,0,0
13,1,0,0,0


In [18]:
for i in range(EPOCHS):
    sorting_hat.train_step(X_train_preprocessed, Y_train, X_test_preprocessed, Y_test)

In [19]:
data_train['pred'] = sorting_hat.predict(X_train_preprocessed)
houses = {0 : 'Slytherin', 1 : 'Hufflepuff', 2 : 'Gryffindor', 3 : 'Ravenclaw'}
data_train = data_train.replace({'pred':houses})
display(data_train)
data_train['true']=(data_train['Hogwarts House'] == data_train['pred'])
data_train['true'].sum() / len(data_train)

Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,Slytherin,Hufflepuff,Gryffindor,Ravenclaw,pred
263,Gryffindor,Right,237.510101,-4.884012,5.614,-697.136575,607.395493,-6.969016,940.452086,0.806860,-251.18317,220.86,0,0,1,0,Gryffindor
478,Gryffindor,Right,278.462545,-6.279350,5.109,-296.834007,563.656879,-5.136535,945.247235,-0.427582,-252.16262,142.10,0,0,1,0,Gryffindor
297,Gryffindor,Left,447.409851,-5.586751,4.649,-596.427045,560.805587,-5.753091,939.794031,1.309745,-253.63740,202.40,0,0,1,0,Gryffindor
28,Gryffindor,Left,236.888879,-5.077751,5.517,-544.758909,604.096378,-4.955320,977.896778,4.377129,-249.38786,173.06,0,0,1,0,Gryffindor
1042,Gryffindor,Left,420.547643,-6.949362,4.202,-582.448533,569.204830,-4.035152,940.895902,3.690323,-253.59301,185.49,0,0,1,0,Gryffindor
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
897,Slytherin,Right,-376.775355,-1.942163,-5.949,-504.131371,432.158807,2.606439,1042.283637,8.870901,-250.31200,6.62,1,0,0,0,Slytherin
1470,Slytherin,Left,-344.639322,-6.710850,-5.191,-273.396055,447.028599,3.414097,1048.417290,9.960919,-252.41067,-59.42,1,0,0,0,Slytherin
72,Slytherin,Right,-508.019034,-4.079019,-5.725,,407.140184,4.735857,1052.071559,10.657080,-249.20965,-39.65,1,0,0,0,Slytherin
13,Slytherin,Left,-544.192049,-7.308856,-6.180,-319.946875,391.652916,2.914732,1082.581409,10.948791,-251.12516,-80.42,1,0,0,0,Slytherin


0.9833333333333333

In [20]:
data_test = data.iloc[X_test.index]

In [21]:
data_test['pred'] = sorting_hat.predict(X_test_preprocessed)
houses = {0 : 'Slytherin', 1 : 'Hufflepuff', 2 : 'Gryffindor', 3 : 'Ravenclaw'}
data_test = data_test.replace({'pred':houses})
data_test

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_test['pred'] = sorting_hat.predict(X_test_preprocessed)


Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,pred
1,Slytherin,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45,Slytherin
2,Ravenclaw,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42,Ravenclaw
6,Gryffindor,Left,628.046051,-4.861976,,-926.892512,583.742442,-7.322486,923.539573,1.646666,-257.83447,261.55,Gryffindor
14,Ravenclaw,Right,-197.527318,2.742444,6.603,527.356323,605.590600,5.480097,1063.522361,9.407484,-232.65964,-19.94,Ravenclaw
15,Ravenclaw,Left,-447.649812,4.046727,4.949,810.154483,615.531088,3.653495,1075.853850,9.622899,-229.38229,17.00,Ravenclaw
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1589,Hufflepuff,Left,708.202206,4.850931,5.660,-504.777873,417.520448,4.568628,1046.345436,6.272776,-245.03263,28.70,Hufflepuff
1593,Ravenclaw,Left,-426.175401,5.681107,6.205,473.879478,647.238809,6.254227,1046.815627,7.206156,-230.80139,-29.82,Ravenclaw
1594,Hufflepuff,Left,599.901612,5.479485,5.543,-525.883264,467.950418,5.933211,1034.394428,9.344054,-241.96940,71.23,Hufflepuff
1595,Gryffindor,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83,Gryffindor


In [22]:
losses = pd.DataFrame(sorting_hat.losses)
fig = go.Figure()
fig.add_trace(go.Line(x = losses['step'], y = losses['train_loss']))
fig.add_trace(go.Line(x = losses['step'], y = losses['test_loss']))
fig.show()


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [23]:
losses = pd.DataFrame(sorting_hat.losses)
fig = go.Figure()
fig.add_trace(go.Line(x = losses['step'][-100:], y = losses['train_loss'][-100:]))
fig.add_trace(go.Line(x = losses['step'][-100:], y = losses['test_loss'][-100:]))
fig.show()

In [24]:
data_train['Hogwarts House'].value_counts() / len(data_train)

Hogwarts House
Hufflepuff    0.330833
Ravenclaw     0.276667
Gryffindor    0.204167
Slytherin     0.188333
Name: count, dtype: float64

In [25]:
data_test['Hogwarts House'].value_counts() / len(data_test)

Hogwarts House
Hufflepuff    0.3300
Ravenclaw     0.2775
Gryffindor    0.2050
Slytherin     0.1875
Name: count, dtype: float64

In [26]:
data['Hogwarts House'].value_counts() / len(data)

Hogwarts House
Hufflepuff    0.330625
Ravenclaw     0.276875
Gryffindor    0.204375
Slytherin     0.188125
Name: count, dtype: float64

In [27]:
feature_importance_s = {X_train_preprocessed.columns[i]: sorting_hat.logregs['Slytherin'].weights[i] for i in range(X_train_preprocessed.shape[1])}
feature_importance_h = {X_train_preprocessed.columns[i]: sorting_hat.logregs['Hufflepuff'].weights[i] for i in range(X_train_preprocessed.shape[1])}
feature_importance_g = {X_train_preprocessed.columns[i]: sorting_hat.logregs['Gryffindor'].weights[i] for i in range(X_train_preprocessed.shape[1])}
feature_importance_r = {X_train_preprocessed.columns[i]: sorting_hat.logregs['Ravenclaw'].weights[i] for i in range(X_train_preprocessed.shape[1])}
feature_importance = pd.DataFrame([feature_importance_s, feature_importance_h, feature_importance_g, feature_importance_r],
                                  index=['Slytherin', 'Hufflepuff', 'Gryffindor', 'Ravenclaw']).T

feature_importance

Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
Astronomy,-0.759641,1.458894,0.278907,-0.836246
Herbology,-0.755232,1.140713,-0.624532,0.463126
Divination,-1.239785,0.739395,0.266627,0.29799
Muggle Studies,-0.272799,-0.841882,-0.112177,1.09505
Ancient Runes,-0.445783,-1.250534,0.586532,0.896267
History of Magic,0.125882,0.647803,-0.784329,0.168985
Transfiguration,0.226734,0.57881,-0.805397,0.158636
Potions,0.581076,-0.421289,-0.309418,0.130735
Charms,-0.458084,-0.204595,-0.384118,1.053814
Flying,-0.508221,-0.315199,0.797189,-0.063495


In [28]:
data_test['true']=(data_test['Hogwarts House'] == data_test['pred'])
print(data_train['true'].sum() / len(data_train))
print(data_test['true'].sum() / len(data_test))

0.9833333333333333
0.9775


In [29]:
px.bar(feature_importance, labels={
                     "index": "Feature",
                     "value": "Importance"
                 })

# Training model on whole dataset

In [30]:
data = pd.read_csv("datasets/dataset_train.csv")
data['Birthday'] = pd.to_datetime(data['Birthday'])
data['Birthday Weekday'] = data['Birthday'].dt.dayofweek
data['Birthday Year'] = data['Birthday'].dt.year
data['Birthday Month'] = data['Birthday'].dt.month
data = data.drop(columns=COLS_TO_DROP)
data

Unnamed: 0,Hogwarts House,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
0,Ravenclaw,Left,-487.886086,5.727180,4.722,272.035831,532.484226,5.231058,1039.788281,3.790369,-232.79405,-26.89
1,Slytherin,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45
2,Ravenclaw,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42
3,Gryffindor,Left,697.742809,-6.497214,4.026,-537.001128,523.982133,-4.809637,920.391449,0.821911,-256.84675,200.64
4,Gryffindor,Left,436.775204,-7.820623,2.236,-444.262537,599.324514,-3.444377,937.434724,4.311066,-256.38730,157.98
...,...,...,...,...,...,...,...,...,...,...,...,...
1595,Gryffindor,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83
1596,Slytherin,Left,367.531174,6.061064,1.757,-643.271092,445.827565,2.238112,1056.147366,5.825263,-246.42719,44.80
1597,Gryffindor,Right,544.018925,-3.203269,6.065,-385.150457,635.211486,-5.984257,953.866685,1.709808,-251.63679,198.47
1598,Hufflepuff,Left,453.676219,3.442831,6.738,-831.741123,383.444937,3.813111,1087.949205,3.904100,-246.19072,-76.81


In [31]:
X = data.drop(columns=['Hogwarts House'])
data['Slytherin'] = (data['Hogwarts House'] == 'Slytherin').astype(int)
data['Hufflepuff'] = (data['Hogwarts House'] == 'Hufflepuff').astype(int)
data['Gryffindor'] = (data['Hogwarts House'] == 'Gryffindor').astype(int)
data['Ravenclaw'] = (data['Hogwarts House'] == 'Ravenclaw').astype(int)
Y = data[TARGET_CLASSES]
display(X)
display(Y)

Unnamed: 0,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
0,Left,-487.886086,5.727180,4.722,272.035831,532.484226,5.231058,1039.788281,3.790369,-232.79405,-26.89
1,Right,-552.060507,-5.987446,-5.612,-487.340557,367.760303,4.107170,1058.944592,7.248742,-252.18425,-113.45
2,Left,-366.076117,7.725017,6.140,664.893521,602.585284,3.555579,1088.088348,8.728531,-227.34265,30.42
3,Left,697.742809,-6.497214,4.026,-537.001128,523.982133,-4.809637,920.391449,0.821911,-256.84675,200.64
4,Left,436.775204,-7.820623,2.236,-444.262537,599.324514,-3.444377,937.434724,4.311066,-256.38730,157.98
...,...,...,...,...,...,...,...,...,...,...,...
1595,Right,354.280086,-4.541837,5.702,-497.235066,618.220213,-5.231721,964.219853,3.389086,-250.39401,185.83
1596,Left,367.531174,6.061064,1.757,-643.271092,445.827565,2.238112,1056.147366,5.825263,-246.42719,44.80
1597,Right,544.018925,-3.203269,6.065,-385.150457,635.211486,-5.984257,953.866685,1.709808,-251.63679,198.47
1598,Left,453.676219,3.442831,6.738,-831.741123,383.444937,3.813111,1087.949205,3.904100,-246.19072,-76.81


Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
0,0,0,0,1
1,1,0,0,0
2,0,0,0,1
3,0,0,1,0
4,0,0,1,0
...,...,...,...,...
1595,0,0,1,0
1596,1,0,0,0
1597,0,0,1,0
1598,0,1,0,0


In [32]:
imputer = SimpleImputer(NUMERICAL_COLS, CATEGORICAL_COLS)
scaler = StandardScaler(NUMERICAL_COLS)
ohe = OneHotEncoder(CATEGORICAL_COLS)
preprocessor = PreprocessorPipeline([imputer, scaler, ohe])
preprocessor

--- SimpleImputer ---
Means: None
Modes: None

--- StandardScaler ---
Means: None
Standard Deviations: None

--- OneHotEncoder ---
Columns mapping: {}
Drop Last: True


In [33]:
preprocessor.fit(X)
display(preprocessor)
X_preprocessed = preprocessor.transform(X)
X_preprocessed

--- SimpleImputer ---
Means: {'Astronomy': 39.79713089016475, 'Herbology': 1.1410195296768046, 'Divination': 3.1539096732863547, 'Muggle Studies': -224.58991486346417, 'Ancient Runes': 495.74797005915786, 'Charms': -243.3744090125, 'Potions': 5.950372992780089, 'Transfiguration': 1030.0969463871306, 'History of Magic': 2.9630946151165936, 'Flying': 21.9580125}
Modes: {'Best Hand': 'Right'}

--- StandardScaler ---
Means: {'Astronomy': 39.79713089016475, 'Herbology': 1.1410195296768046, 'Divination': 3.1539096732863547, 'Muggle Studies': -224.58991486346417, 'Ancient Runes': 495.74797005915786, 'Charms': -243.3744090125, 'Potions': 5.950372992780089, 'Transfiguration': 1030.0969463871306, 'History of Magic': 2.9630946151165936, 'Flying': 21.9580125}
Standard Deviations: {'Astronomy': 520.2982676051708, 'Herbology': 5.2196819935318235, 'Divination': 4.155300897977581, 'Muggle Studies': 486.34483965206664, 'Ancient Runes': 106.28516457845274, 'Charms': 8.783639876017117, 'Potions': 3.14785

Unnamed: 0,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,Best Hand_Left
0,-1.014194,0.878628,0.377371,1.021139,0.345639,0.512444,0.219633,-0.686183,1.204553,-0.500330,1
1,-1.137535,-1.365690,-2.109573,-0.540256,-1.204191,0.258503,0.653769,0.412462,-1.002983,-1.386928,0
2,-0.780078,1.261379,0.718622,1.828915,1.005195,0.133871,1.314249,0.882556,1.825184,0.086673,1
3,1.264555,-1.463352,0.209874,-0.642366,0.265645,-1.756242,-2.486237,-1.629193,-1.533799,1.830165,1
4,0.762982,-1.716894,-0.220901,-0.451681,0.974516,-1.447763,-2.099988,-0.520770,-1.481492,1.393217,1
...,...,...,...,...,...,...,...,...,...,...,...
1595,0.604428,-1.088736,0.613214,-0.560600,1.152299,-1.851612,-1.492961,-0.813661,-0.799168,1.678473,0
1596,0.629896,0.942595,-0.336175,-0.860873,-0.469684,-0.163809,0.590376,-0.039745,-0.347553,0.233961,1
1597,0.969101,-0.832290,0.700573,-0.330137,1.312164,-2.021646,-1.727593,-1.347129,-0.940656,1.807939,0
1598,0.795465,0.440987,0.862534,-1.248397,-1.056620,0.192060,1.311096,-0.650053,-0.320631,-1.011640,1


In [34]:
sorting_hat = SortingHat(X_preprocessed.shape[1], lr=LEARNING_RATE)
for i in range(EPOCHS):
    sorting_hat.train_step(X_preprocessed, Y)

In [35]:
Y

Unnamed: 0,Slytherin,Hufflepuff,Gryffindor,Ravenclaw
0,0,0,0,1
1,1,0,0,0
2,0,0,0,1
3,0,0,1,0
4,0,0,1,0
...,...,...,...,...
1595,0,0,1,0
1596,1,0,0,0
1597,0,0,1,0
1598,0,1,0,0


In [36]:
sorting_hat.predict(X_preprocessed)

0        Ravenclaw
1        Slytherin
2        Ravenclaw
3       Gryffindor
4       Gryffindor
           ...    
1595    Gryffindor
1596    Hufflepuff
1597    Gryffindor
1598    Hufflepuff
1599    Hufflepuff
Length: 1600, dtype: object

In [37]:
data['pred'] = sorting_hat.predict(X_preprocessed)
houses = {0 : 'Slytherin', 1 : 'Hufflepuff', 2 : 'Gryffindor', 3 : 'Ravenclaw'}
data = data.replace({'pred':houses})
display(data.loc[X_preprocessed.index[0], :])
data['true']=(data['Hogwarts House'] == data['pred'])
data['true'].sum() / len(data)

Hogwarts House        Ravenclaw
Best Hand                  Left
Astronomy           -487.886086
Herbology               5.72718
Divination                4.722
Muggle Studies       272.035831
Ancient Runes        532.484226
History of Magic       5.231058
Transfiguration     1039.788281
Potions                3.790369
Charms               -232.79405
Flying                   -26.89
Slytherin                     0
Hufflepuff                    0
Gryffindor                    0
Ravenclaw                     1
pred                  Ravenclaw
Name: 0, dtype: object

0.981875

In [38]:
losses = pd.DataFrame(sorting_hat.losses)
fig = go.Figure()
fig.add_trace(go.Line(x = losses['step'], y = losses['train_loss']))
fig.show()


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [39]:
losses = pd.DataFrame(sorting_hat.losses)
fig = go.Figure()
fig.add_trace(go.Line(x = losses['step'][-100:], y = losses['train_loss'][-100:]))
fig.show()

# Predictions

In [40]:
test = pd.read_csv("datasets/dataset_test.csv")
test['Birthday'] = pd.to_datetime(test['Birthday'])
test['Birthday Weekday'] = test['Birthday'].dt.dayofweek
test['Birthday Year'] = test['Birthday'].dt.year
test['Birthday Month'] = test['Birthday'].dt.month
test = test.drop(columns=COLS_TO_DROP + ['Hogwarts House'])
test

Unnamed: 0,Best Hand,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying
0,Right,696.096071,3.020172,7.996,-365.151850,393.138185,4.207691,1046.742736,3.668983,-244.48172,-13.62
1,Left,-370.844655,2.965226,6.349,522.580486,602.853051,6.460017,1048.053878,8.514622,-231.29200,-26.26
2,Left,320.303990,-6.185697,4.619,-630.073207,588.071795,-5.565818,936.437358,1.850829,-252.99343,200.15
3,Right,407.202928,4.962442,,-449.179806,427.699966,,1043.397718,4.656573,-244.01660,-11.15
4,Right,288.337747,3.737656,4.886,-449.732166,385.712782,2.876347,1051.377936,2.750586,-243.99806,-7.12
...,...,...,...,...,...,...,...,...,...,...,...
395,Left,-554.181932,-5.647655,-3.799,-591.764651,392.973420,7.048482,1047.648405,10.408749,-248.39978,-94.89
396,Left,632.233530,6.754862,3.294,-221.848397,319.360250,3.921402,1035.681313,-0.169741,-246.87982,-15.53
397,Right,292.108738,5.234530,4.230,-787.036050,433.259967,3.898160,1069.794110,6.495579,-244.01333,1.25
398,Left,-726.418553,6.735582,3.908,511.960762,613.391514,7.244499,1042.058804,7.554259,-228.24290,-18.27


In [41]:
test_preprocessed = preprocessor.transform(test)
test_preprocessed

Unnamed: 0,Astronomy,Herbology,Divination,Muggle Studies,Ancient Runes,History of Magic,Transfiguration,Potions,Charms,Flying,Best Hand_Left
0,1.261390,0.360013,1.165280,-0.289017,-0.965420,0.281215,0.377241,-0.724744,-0.126065,-0.364411,0
1,-0.789243,0.349486,0.768919,1.536298,1.007714,0.790127,0.406955,0.814602,1.375558,-0.493877,1
2,0.539127,-1.403671,0.352583,-0.833736,0.868643,-1.927101,-2.122591,-1.302330,-1.095106,1.825147,1
3,0.706145,0.732118,0.000000,-0.461791,-0.640240,0.000000,0.301433,-0.411010,-0.073112,-0.339112,0
4,0.477689,0.497470,0.416839,-0.462927,-1.035283,-0.019601,0.482287,-1.016498,-0.071001,-0.297834,0
...,...,...,...,...,...,...,...,...,...,...,...
395,-1.141613,-1.300592,-1.673263,-0.754968,-0.966970,0.923090,0.397766,1.416322,-0.572129,-1.196826,1
396,1.138648,1.075514,0.033714,0.005637,-1.659570,0.216529,0.126558,-1.944218,-0.399084,-0.383974,1
397,0.484936,0.784245,0.258968,-1.156476,-0.587928,0.211277,0.899650,0.173199,-0.072740,-0.212104,0
398,-1.472647,1.071821,0.181477,1.514462,1.106867,0.967380,0.271090,0.509517,1.722692,-0.412039,1


In [42]:
sorting_hat.predict(test_preprocessed)

0      Hufflepuff
1       Ravenclaw
2      Gryffindor
3      Hufflepuff
4      Hufflepuff
          ...    
395     Slytherin
396    Hufflepuff
397    Hufflepuff
398     Ravenclaw
399     Ravenclaw
Length: 400, dtype: object

In [43]:
predictions = sorting_hat.predict(test_preprocessed)
houses = {0 : 'Slytherin', 1 : 'Hufflepuff', 2 : 'Gryffindor', 3 : 'Ravenclaw'}
predictions = predictions.replace(houses)
display(predictions)

0      Hufflepuff
1       Ravenclaw
2      Gryffindor
3      Hufflepuff
4      Hufflepuff
          ...    
395     Slytherin
396    Hufflepuff
397    Hufflepuff
398     Ravenclaw
399     Ravenclaw
Length: 400, dtype: object

In [44]:
truth = pd.read_csv("~/Downloads/dataset_truth.csv")
truth

Unnamed: 0,Index,Hogwarts House
0,0,Hufflepuff
1,1,Ravenclaw
2,2,Gryffindor
3,3,Hufflepuff
4,4,Hufflepuff
...,...,...
395,395,Slytherin
396,396,Hufflepuff
397,397,Hufflepuff
398,398,Ravenclaw


In [45]:
truth['pred'] = predictions

In [46]:
truth

Unnamed: 0,Index,Hogwarts House,pred
0,0,Hufflepuff,Hufflepuff
1,1,Ravenclaw,Ravenclaw
2,2,Gryffindor,Gryffindor
3,3,Hufflepuff,Hufflepuff
4,4,Hufflepuff,Hufflepuff
...,...,...,...
395,395,Slytherin,Slytherin
396,396,Hufflepuff,Hufflepuff
397,397,Hufflepuff,Hufflepuff
398,398,Ravenclaw,Ravenclaw


In [47]:
len(truth[truth['pred'] == truth['Hogwarts House']]) / len(truth) * 100

99.0

In [48]:
sorting_hat.parameters

{'Slytherin': [-3.0579327284528643,
  -0.7733964451437805,
  -0.7664903320987481,
  -1.2607581532754655,
  -0.2905443365499026,
  -0.4618680099643445,
  0.1543255914434574,
  0.23230334399099717,
  0.6280589936390716,
  -0.4619604521526662,
  -0.5207984254123783,
  0.0049656103200915284],
 'Hufflepuff': [-1.6866141778707584,
  1.4562828806759482,
  1.135949523783265,
  0.7558144854992594,
  -0.8637548825907131,
  -1.2640288016097028,
  0.630695660195349,
  0.581792666188333,
  -0.46335452241376457,
  -0.19801112216221947,
  -0.3325983544936224,
  -0.007376702291039833],
 'Gryffindor': [-3.024288384449755,
  0.320633946178578,
  -0.607497323721094,
  0.2611091011559458,
  -0.13562067509331935,
  0.573984959293176,
  -0.7661125160520594,
  -0.8156979329242279,
  -0.3091732923249972,
  -0.4089485003421139,
  0.8140605022469902,
  -0.017424643069245058],
 'Ravenclaw': [-2.304002710051526,
  -0.8417566379129177,
  0.47367433494433553,
  0.2983295429721254,
  1.1223061012951063,
  0.88251389

# Verification of model saving and loading

In [49]:
sorting_hat.save_model('model.json')

In [50]:
sorting_hat_verif = SortingHat(X_preprocessed.shape[1], lr=LEARNING_RATE)
print(sorting_hat_verif.logregs)
sorting_hat_verif.load_model('model.json')
sorting_hat_verif.parameters

{'Slytherin': <logreg_train.LogReg object at 0x7f0caeb34e80>, 'Hufflepuff': <logreg_train.LogReg object at 0x7f0caeb34f10>, 'Gryffindor': <logreg_train.LogReg object at 0x7f0caeb34eb0>, 'Ravenclaw': <logreg_train.LogReg object at 0x7f0caeb34100>}


{'Slytherin': [-3.0579327284528643,
  -0.7733964451437805,
  -0.7664903320987481,
  -1.2607581532754655,
  -0.2905443365499026,
  -0.4618680099643445,
  0.1543255914434574,
  0.23230334399099717,
  0.6280589936390716,
  -0.4619604521526662,
  -0.5207984254123783,
  0.0049656103200915284],
 'Hufflepuff': [-1.6866141778707584,
  1.4562828806759482,
  1.135949523783265,
  0.7558144854992594,
  -0.8637548825907131,
  -1.2640288016097028,
  0.630695660195349,
  0.581792666188333,
  -0.46335452241376457,
  -0.19801112216221947,
  -0.3325983544936224,
  -0.007376702291039833],
 'Gryffindor': [-3.024288384449755,
  0.320633946178578,
  -0.607497323721094,
  0.2611091011559458,
  -0.13562067509331935,
  0.573984959293176,
  -0.7661125160520594,
  -0.8156979329242279,
  -0.3091732923249972,
  -0.4089485003421139,
  0.8140605022469902,
  -0.017424643069245058],
 'Ravenclaw': [-2.304002710051526,
  -0.8417566379129177,
  0.47367433494433553,
  0.2983295429721254,
  1.1223061012951063,
  0.88251389

In [51]:
truth['pred_verif'] = sorting_hat_verif.predict(test_preprocessed).replace(houses)
truth

Unnamed: 0,Index,Hogwarts House,pred,pred_verif
0,0,Hufflepuff,Hufflepuff,Hufflepuff
1,1,Ravenclaw,Ravenclaw,Ravenclaw
2,2,Gryffindor,Gryffindor,Gryffindor
3,3,Hufflepuff,Hufflepuff,Hufflepuff
4,4,Hufflepuff,Hufflepuff,Hufflepuff
...,...,...,...,...
395,395,Slytherin,Slytherin,Slytherin
396,396,Hufflepuff,Hufflepuff,Hufflepuff
397,397,Hufflepuff,Hufflepuff,Hufflepuff
398,398,Ravenclaw,Ravenclaw,Ravenclaw


In [52]:
len(truth[truth['pred_verif'] == truth['Hogwarts House']]) / len(truth) * 100

99.0