In [1]:
import os
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, make_scorer
from sklearn.model_selection import KFold, cross_val_score
from tqdm.notebook import tqdm

In [2]:
# Input files are read from the "data" directory that sits next to the
# current working directory (i.e. <cwd>/../data).
data_path = Path.cwd().parent / "data"

# Load the abalone table used throughout the rest of the notebook.
df = pd.read_csv(data_path / "abalone.csv")

In [3]:
def save_ans(*ans, delimiter=" "):
    """Write the given answers to ``res.txt``, joined by ``delimiter``.

    Every positional value is converted with ``str`` and the pieces are
    concatenated with ``delimiter`` between them (no trailing newline).
    The file is located in the parent directory of the module-level
    ``data_path`` and is opened in write mode, so previous content is
    replaced on every call.

    Args:
        *ans: Values to store; each one is stringified with ``str``.
        delimiter: Separator placed between consecutive values.
    """
    target = data_path.parent / "res.txt"
    with target.open("w") as out:
        out.write(delimiter.join(str(value) for value in ans))

In [4]:
# Peek at the first rows to check the columns that were loaded.
df.head()

Unnamed: 0,Sex,Length,Diameter,Height,WholeWeight,ShuckedWeight,VisceraWeight,ShellWeight,Rings
0,M,0.455,0.365,0.095,0.514,0.2245,0.101,0.15,15
1,M,0.35,0.265,0.09,0.2255,0.0995,0.0485,0.07,7
2,F,0.53,0.42,0.135,0.677,0.2565,0.1415,0.21,9
3,M,0.44,0.365,0.125,0.516,0.2155,0.114,0.155,10
4,I,0.33,0.255,0.08,0.205,0.0895,0.0395,0.055,7


In [5]:
# Encode the categorical "Sex" column as numbers: M -> 1, I -> 0, F -> -1.
# Series.map turns every value that is not a key of the dict into NaN, so
# running this cell a second time on the already-encoded column would wipe
# it out. The dtype guard makes the cell safe to re-run.
if not pd.api.types.is_numeric_dtype(df["Sex"]):
    df["Sex"] = df["Sex"].map({"M": 1, "I": 0, "F": -1})

In [6]:
# Target vector: the "Rings" column.
y = df.loc[:, "Rings"]
# Feature matrix: every column except the last one (selected by position).
X = df.iloc[:, :-1]

In [7]:
# Cross-validate random forests with 1..50 trees. A single shuffled KFold
# splitter (random_state=1) and a single R^2 scorer are shared by all runs.
cv = KFold(shuffle=True, random_state=1)
r2_scorer = make_scorer(r2_score)

# Maps number of trees -> cross_val_score output for that forest size.
res = {}
for n_trees in tqdm(range(1, 51)):
    forest = RandomForestRegressor(n_estimators=n_trees, random_state=1)
    res[n_trees] = cross_val_score(forest, X, y, cv=cv, scoring=r2_scorer)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [8]:
# Collapse the scores stored for each forest size into their mean.
res = {n_trees: np.mean(scores) for n_trees, scores in res.items()}

In [9]:
# Rank (n_trees, mean score) pairs from the highest mean score to the lowest.
sorted(res.items(), key=lambda pair: -pair[1])

[(50, 0.5309509147417047),
 (49, 0.530813058616495),
 (33, 0.5301073722643779),
 (48, 0.5300509396315634),
 (42, 0.5300433306143383),
 (34, 0.5299613734264366),
 (43, 0.5299135764090978),
 (36, 0.529910050667947),
 (35, 0.5298209779129148),
 (41, 0.5298087685207094),
 (44, 0.5296814957917958),
 (39, 0.529515898349607),
 (38, 0.5294715388671245),
 (40, 0.5294703580378128),
 (37, 0.5294320415136227),
 (47, 0.5291786571646144),
 (46, 0.5290083070325597),
 (32, 0.5289244806388986),
 (45, 0.528908112349864),
 (31, 0.5276420438225101),
 (30, 0.5270858715838138),
 (29, 0.5265556293057552),
 (28, 0.5256557724971402),
 (27, 0.5246393588459404),
 (26, 0.5243076139284634),
 (25, 0.5232486470488318),
 (24, 0.5231059969795335),
 (23, 0.521742855685855),
 (22, 0.5208044230080824),
 (21, 0.520529096463528),
 (19, 0.5198293095329432),
 (20, 0.51948435033775),
 (18, 0.5172203573170132),
 (17, 0.5148917747729636),
 (16, 0.5114105314179662),
 (15, 0.5091809969556578),
 (14, 0.5073168234618861),
 (13, 0.5

In [10]:
# NOTE(review): the answer is hard-coded rather than derived from `res`; the
# selection rule is not stated in this notebook — confirm 22 against the
# sorted scores before relying on res.txt.
save_ans(22)