In [4]:
import pandas as pd
import numpy as np

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, RandomizedSearchCV, GridSearchCV
from sklearn.linear_model import SGDRegressor, SGDClassifier
from sklearn.metrics import mean_absolute_error,mean_squared_error,accuracy_score,classification_report,roc_auc_score, roc_curve

In [1]:
from sklearn.datasets import load_boston

In [2]:
bhp = load_boston()
print(bhp.DESCR)

.. _boston_dataset:

Boston house prices dataset
---------------------------

**Data Set Characteristics:**  

    :Number of Instances: 506 

    :Number of Attributes: 13 numeric/categorical predictive. Median Value (attribute 14) is usually the target.

    :Attribute Information (in order):
        - CRIM     per capita crime rate by town
        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.
        - INDUS    proportion of non-retail business acres per town
        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)
        - NOX      nitric oxides concentration (parts per 10 million)
        - RM       average number of rooms per dwelling
        - AGE      proportion of owner-occupied units built prior to 1940
        - DIS      weighted distances to five Boston employment centres
        - RAD      index of accessibility to radial highways
        - TAX      full-value property-tax rate per $10,000
        - PTRATIO  pu


    The Boston housing prices dataset has an ethical problem. You can refer to
    the documentation of this function for further details.

    The scikit-learn maintainers therefore strongly discourage the use of this
    dataset unless the purpose of the code is to study and educate about
    ethical issues in data science and machine learning.

    In this special case, you can fetch the dataset from the original
    source::

        import pandas as pd
        import numpy as np


        data_url = "http://lib.stat.cmu.edu/datasets/boston"
        raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
        target = raw_df.values[1::2, 2]

    Alternative datasets include the California housing dataset (i.e.
    :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing
    dataset. You can load the datasets as follows::

        from sklearn.datasets import fetch_california_h

In [5]:
bhp = pd.read_csv('boston_house_prices.csv')

In [6]:
bhp.head(3)

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
0,0.00632,18.0,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7


In [8]:
X = bhp.drop(columns='MEDV')
y = bhp.MEDV

In [None]:
# train, test = train_test_split(bhp,test_size=.33,random_state=42)

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

In [11]:
sc = StandardScaler()

In [13]:
X_train_std = sc.fit_transform(X_train)
X_test_std = sc.transform(X_test)

In [32]:
sgd_reg = SGDRegressor(random_state=2022)

In [33]:
sgd_reg.fit(X_train_std,y_train.values)

SGDRegressor(random_state=2022)

In [35]:
medv_pred = sgd_reg.predict(X_test_std)

In [38]:
df_validation = pd.DataFrame({'medv_actual':y_test,'medv_pred':medv_pred})
df_validation

Unnamed: 0,medv_actual,medv_pred
173,23.6,28.414975
274,32.4,36.138594
491,13.6,16.807925
72,22.8,25.442495
452,16.1,18.764191
...,...,...
110,21.7,21.387833
321,23.1,25.010961
265,22.8,27.359265
29,21.0,20.968473


In [41]:
((df_validation.medv_actual - df_validation.medv_pred)**2).mean()

21.17002669359018

In [43]:
mean_squared_error(y_test,medv_pred)

21.17002669359017

In [46]:
mean_squared_error(y_test,medv_pred,squared=False)  # RMSE

4.601089728921853

In [47]:
mean_squared_error(y_train,sgd_reg.predict(X_train_std),squared=False)

4.80468801363288

### Coefficient of determination

In [50]:
sgd_reg.score(X_train_std,y_train.values)

0.7380998091037393

## Minibatch Stochastic Gradient Descent

In [100]:
def iter_mb(batch_size = 1):    
    
    start = 0
    
    while start < X_train_std.shape[0]:
        rows = range(start, start + batch_size)
        
        X_batch = X_train_std[rows,:]
        y_batch = y_train[rows]
        
        yield X_batch, y_batch
        start = start + batch_size

In [101]:
batch_iterator = iter_mb(batch_size = 113)

In [102]:
sgd_reg = SGDRegressor()

In [103]:
for X_, y_ in batch_iterator:
    sgd_reg.partial_fit(X_,y_)

In [104]:
mean_squared_error(y_test,sgd_reg.predict(X_test_std))

84.40913502331564

## Cross Validation

In [106]:
train, test = train_test_split(bhp,test_size=.33,random_state=42)

In [107]:
train

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV
478,10.23300,0.0,18.10,0,0.614,6.185,96.7,2.1705,24,666,20.2,379.70,18.03,14.6
26,0.67191,0.0,8.14,0,0.538,5.813,90.3,4.6820,4,307,21.0,376.88,14.81,16.6
7,0.14455,12.5,7.87,0,0.524,6.172,96.1,5.9505,5,311,15.2,396.90,19.15,27.1
492,0.11132,0.0,27.74,0,0.609,5.983,83.5,2.1099,4,711,20.1,396.90,13.35,20.1
108,0.12802,0.0,8.56,0,0.520,6.474,97.1,2.4329,5,384,20.9,395.24,12.27,19.8
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
106,0.17120,0.0,8.56,0,0.520,5.836,91.9,2.2110,5,384,20.9,395.67,18.66,19.5
270,0.29916,20.0,6.96,0,0.464,5.856,42.1,4.4290,3,223,18.6,388.65,13.00,21.1
348,0.01501,80.0,2.01,0,0.435,6.635,29.7,8.3440,4,280,17.0,390.94,5.99,24.5
435,11.16040,0.0,18.10,0,0.740,6.629,94.6,2.1247,24,666,20.2,109.85,23.27,13.4


In [113]:
train = train.sample(frac=1,random_state=2022).reset_index(drop=True)

In [114]:
train['kfold'] = -999

In [116]:
train.head(2)

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV,kfold
0,0.26938,0.0,9.9,0,0.544,6.266,82.8,3.2628,4,304,18.4,393.39,7.9,21.6,-999
1,4.83567,0.0,18.1,0,0.583,5.905,53.2,3.1523,24,666,20.2,388.22,11.45,20.6,-999


In [117]:
kf = KFold(n_splits=5)

In [123]:
# lst = ['a','b','c']
# for i in enumerate(lst,start=1):
#     print(i){
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "d16040f8-0273-4629-9c44-d4f14c96e4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, RandomizedSearchCV, GridSearchCV\n",
    "from sklearn.linear_model import SGDRegressor, SGDClassifier\n",
    "from sklearn.metrics import mean_absolute_error,mean_squared_error,accuracy_score,classification_report,roc_auc_score, roc_curve\n",
    "from sklearn.datasets import load_boston\n",
    "from sklearn.datasets import load_breast_cancer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "94b4d174-fdf0-45a4-b806-fb6081c90171",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_boston"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b45bd10-1824-4869-bbc4-d14093cf13c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".. _boston_dataset:\n",
      "\n",
      "Boston house prices dataset\n",
      "---------------------------\n",
      "\n",
      "**Data Set Characteristics:**  \n",
      "\n",
      "    :Number of Instances: 506 \n",
      "\n",
      "    :Number of Attributes: 13 numeric/categorical predictive. Median Value (attribute 14) is usually the target.\n",
      "\n",
      "    :Attribute Information (in order):\n",
      "        - CRIM     per capita crime rate by town\n",
      "        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.\n",
      "        - INDUS    proportion of non-retail business acres per town\n",
      "        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)\n",
      "        - NOX      nitric oxides concentration (parts per 10 million)\n",
      "        - RM       average number of rooms per dwelling\n",
      "        - AGE      proportion of owner-occupied units built prior to 1940\n",
      "        - DIS      weighted distances to five Boston employment centres\n",
      "        - RAD      index of accessibility to radial highways\n",
      "        - TAX      full-value property-tax rate per $10,000\n",
      "        - PTRATIO  pupil-teacher ratio by town\n",
      "        - B        1000(Bk - 0.63)^2 where Bk is the proportion of black people by town\n",
      "        - LSTAT    % lower status of the population\n",
      "        - MEDV     Median value of owner-occupied homes in $1000's\n",
      "\n",
      "    :Missing Attribute Values: None\n",
      "\n",
      "    :Creator: Harrison, D. and Rubinfeld, D.L.\n",
      "\n",
      "This is a copy of UCI ML housing dataset.\n",
      "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/\n",
      "\n",
      "\n",
      "This dataset was taken from the StatLib library which is maintained at Carnegie Mellon University.\n",
      "\n",
      "The Boston house-price data of Harrison, D. and Rubinfeld, D.L. 'Hedonic\n",
      "prices and the demand for clean air', J. Environ. Economics & Management,\n",
      "vol.5, 81-102, 1978.   Used in Belsley, Kuh & Welsch, 'Regression diagnostics\n",
      "...', Wiley, 1980.   N.B. Various transformations are used in the table on\n",
      "pages 244-261 of the latter.\n",
      "\n",
      "The Boston house-price data has been used in many machine learning papers that address regression\n",
      "problems.   \n",
      "     \n",
      ".. topic:: References\n",
      "\n",
      "   - Belsley, Kuh & Welsch, 'Regression diagnostics: Identifying Influential Data and Sources of Collinearity', Wiley, 1980. 244-261.\n",
      "   - Quinlan,R. (1993). Combining Instance-Based and Model-Based Learning. In Proceedings on the Tenth International Conference of Machine Learning, 236-243, University of Massachusetts, Amherst. Morgan Kaufmann.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.8/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function load_boston is deprecated; `load_boston` is deprecated in 1.0 and will be removed in 1.2.\n",
      "\n",
      "    The Boston housing prices dataset has an ethical problem. You can refer to\n",
      "    the documentation of this function for further details.\n",
      "\n",
      "    The scikit-learn maintainers therefore strongly discourage the use of this\n",
      "    dataset unless the purpose of the code is to study and educate about\n",
      "    ethical issues in data science and machine learning.\n",
      "\n",
      "    In this special case, you can fetch the dataset from the original\n",
      "    source::\n",
      "\n",
      "        import pandas as pd\n",
      "        import numpy as np\n",
      "\n",
      "\n",
      "        data_url = \"http://lib.stat.cmu.edu/datasets/boston\"\n",
      "        raw_df = pd.read_csv(data_url, sep=\"\\s+\", skiprows=22, header=None)\n",
      "        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])\n",
      "        target = raw_df.values[1::2, 2]\n",
      "\n",
      "    Alternative datasets include the California housing dataset (i.e.\n",
      "    :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing\n",
      "    dataset. You can load the datasets as follows::\n",
      "\n",
      "        from sklearn.datasets import fetch_california_housing\n",
      "        housing = fetch_california_housing()\n",
      "\n",
      "    for the California housing dataset and::\n",
      "\n",
      "        from sklearn.datasets import fetch_openml\n",
      "        housing = fetch_openml(name=\"house_prices\", as_frame=True)\n",
      "\n",
      "    for the Ames housing dataset.\n",
      "    \n",
      "  warnings.warn(msg, category=FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "bhp = load_boston()\n",
    "print(bhp.DESCR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "404edb9a-a674-4549-9609-68cb9ed81892",
   "metadata": {},
   "outputs": [],
   "source": [
    "bhp = pd.read_csv('boston_house_prices.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83d79954-f400-41e3-9a42-d0445b19ef5a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00632</td>\n",
       "      <td>18.0</td>\n",
       "      <td>2.31</td>\n",
       "      <td>0</td>\n",
       "      <td>0.538</td>\n",
       "      <td>6.575</td>\n",
       "      <td>65.2</td>\n",
       "      <td>4.0900</td>\n",
       "      <td>1</td>\n",
       "      <td>296</td>\n",
       "      <td>15.3</td>\n",
       "      <td>396.90</td>\n",
       "      <td>4.98</td>\n",
       "      <td>24.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.02731</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>6.421</td>\n",
       "      <td>78.9</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242</td>\n",
       "      <td>17.8</td>\n",
       "      <td>396.90</td>\n",
       "      <td>9.14</td>\n",
       "      <td>21.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02729</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>7.185</td>\n",
       "      <td>61.1</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242</td>\n",
       "      <td>17.8</td>\n",
       "      <td>392.83</td>\n",
       "      <td>4.03</td>\n",
       "      <td>34.7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  PTRATIO  \\\n",
       "0  0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296     15.3   \n",
       "1  0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242     17.8   \n",
       "2  0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242     17.8   \n",
       "\n",
       "        B  LSTAT  MEDV  \n",
       "0  396.90   4.98  24.0  \n",
       "1  396.90   9.14  21.6  \n",
       "2  392.83   4.03  34.7  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bhp.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f5c2f277-ecb0-4472-8b2e-b65e3ca1d7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = bhp.drop(columns='MEDV')\n",
    "y = bhp.MEDV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08cd1a0-21ec-4fe5-b35d-e5aa364d7877",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train, test = train_test_split(bhp,test_size=.33,random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "40091636-2fef-471a-815c-b1b2e0efd2b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "08ff7c2d-178f-4785-b02b-586ab7cb7683",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = StandardScaler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "53473ccc-517a-4392-9a02-1f77b23a8b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_std = sc.fit_transform(X_train)\n",
    "X_test_std = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c4e56e4d-b21b-4448-b408-552a5f49ab60",
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_reg = SGDRegressor(random_state=2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a557eb43-87a1-4a9a-ac74-db706c114b84",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SGDRegressor(random_state=2022)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sgd_reg.fit(X_train_std,y_train.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "290bbfa7-a16d-494e-8654-9fce49d4131c",
   "metadata": {},
   "outputs": [],
   "source": [
    "medv_pred = sgd_reg.predict(X_test_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "672df9ed-6852-436a-b0b2-d6903eef52d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>medv_actual</th>\n",
       "      <th>medv_pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>173</th>\n",
       "      <td>23.6</td>\n",
       "      <td>28.414975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>274</th>\n",
       "      <td>32.4</td>\n",
       "      <td>36.138594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>491</th>\n",
       "      <td>13.6</td>\n",
       "      <td>16.807925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>22.8</td>\n",
       "      <td>25.442495</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>452</th>\n",
       "      <td>16.1</td>\n",
       "      <td>18.764191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>110</th>\n",
       "      <td>21.7</td>\n",
       "      <td>21.387833</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>321</th>\n",
       "      <td>23.1</td>\n",
       "      <td>25.010961</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>265</th>\n",
       "      <td>22.8</td>\n",
       "      <td>27.359265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>21.0</td>\n",
       "      <td>20.968473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>262</th>\n",
       "      <td>48.8</td>\n",
       "      <td>40.288485</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>167 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     medv_actual  medv_pred\n",
       "173         23.6  28.414975\n",
       "274         32.4  36.138594\n",
       "491         13.6  16.807925\n",
       "72          22.8  25.442495\n",
       "452         16.1  18.764191\n",
       "..           ...        ...\n",
       "110         21.7  21.387833\n",
       "321         23.1  25.010961\n",
       "265         22.8  27.359265\n",
       "29          21.0  20.968473\n",
       "262         48.8  40.288485\n",
       "\n",
       "[167 rows x 2 columns]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_validation = pd.DataFrame({'medv_actual':y_test,'medv_pred':medv_pred})\n",
    "df_validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "79ecf9f8-6dbd-45c5-9671-a47a9fd8a778",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21.17002669359018"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "((df_validation.medv_actual - df_validation.medv_pred)**2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "12600be5-2812-424e-8d1f-1fd994846832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21.17002669359017"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_squared_error(y_test,medv_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "f4c23db0-ecb8-4a58-b640-c48baf8519e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.601089728921853"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_squared_error(y_test,medv_pred,squared=False)  # RMSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "84402d7e-8ca6-451c-ae72-a54e23b2639d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.80468801363288"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_squared_error(y_train,sgd_reg.predict(X_train_std),squared=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d99671-85ed-48e6-a8ee-95dab3dc8e5c",
   "metadata": {},
   "source": [
    "### Coefficient of determination"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "6a68989a-690b-44db-96ae-1a70ded2f991",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7380998091037393"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sgd_reg.score(X_train_std,y_train.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ea57e2-7bc1-4d26-8f46-d5d4320dffdf",
   "metadata": {},
   "source": [
    "## Minibatch Stochastic Gradient Descent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "1370878c-036e-496a-b248-d6b63f058af5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def iter_mb(batch_size = 1):    \n",
    "    \n",
    "    start = 0\n",
    "    \n",
    "    while start < X_train_std.shape[0]:\n",
    "        rows = range(start, start + batch_size)\n",
    "        \n",
    "        X_batch = X_train_std[rows,:]\n",
    "        y_batch = y_train[rows]\n",
    "        \n",
    "        yield X_batch, y_batch\n",
    "        start = start + batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "2bca24b0-5ec6-4a10-9b7d-a010e1dc074b",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_iterator = iter_mb(batch_size = 113)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "ee4dcdae-1bae-4dce-81f3-9b591859d8a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_reg = SGDRegressor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "52ac97ea-863f-43c0-a3f3-87e4f4a897f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "for X_, y_ in batch_iterator:\n",
    "    sgd_reg.partial_fit(X_,y_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "1ac878ce-c24b-4875-a7b4-82652848e768",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "84.40913502331564"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_squared_error(y_test,sgd_reg.predict(X_test_std))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f8af7e8-859a-43d9-8106-0ec1de99f01a",
   "metadata": {},
   "source": [
    "## Cross Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "965d3519-7ae4-4f84-86c8-98e1dbf2af4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(bhp,test_size=.33,random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "85b580a3-5dd8-4059-838a-e4c2839b6092",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>478</th>\n",
       "      <td>10.23300</td>\n",
       "      <td>0.0</td>\n",
       "      <td>18.10</td>\n",
       "      <td>0</td>\n",
       "      <td>0.614</td>\n",
       "      <td>6.185</td>\n",
       "      <td>96.7</td>\n",
       "      <td>2.1705</td>\n",
       "      <td>24</td>\n",
       "      <td>666</td>\n",
       "      <td>20.2</td>\n",
       "      <td>379.70</td>\n",
       "      <td>18.03</td>\n",
       "      <td>14.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.67191</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.14</td>\n",
       "      <td>0</td>\n",
       "      <td>0.538</td>\n",
       "      <td>5.813</td>\n",
       "      <td>90.3</td>\n",
       "      <td>4.6820</td>\n",
       "      <td>4</td>\n",
       "      <td>307</td>\n",
       "      <td>21.0</td>\n",
       "      <td>376.88</td>\n",
       "      <td>14.81</td>\n",
       "      <td>16.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.14455</td>\n",
       "      <td>12.5</td>\n",
       "      <td>7.87</td>\n",
       "      <td>0</td>\n",
       "      <td>0.524</td>\n",
       "      <td>6.172</td>\n",
       "      <td>96.1</td>\n",
       "      <td>5.9505</td>\n",
       "      <td>5</td>\n",
       "      <td>311</td>\n",
       "      <td>15.2</td>\n",
       "      <td>396.90</td>\n",
       "      <td>19.15</td>\n",
       "      <td>27.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>492</th>\n",
       "      <td>0.11132</td>\n",
       "      <td>0.0</td>\n",
       "      <td>27.74</td>\n",
       "      <td>0</td>\n",
       "      <td>0.609</td>\n",
       "      <td>5.983</td>\n",
       "      <td>83.5</td>\n",
       "      <td>2.1099</td>\n",
       "      <td>4</td>\n",
       "      <td>711</td>\n",
       "      <td>20.1</td>\n",
       "      <td>396.90</td>\n",
       "      <td>13.35</td>\n",
       "      <td>20.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108</th>\n",
       "      <td>0.12802</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.56</td>\n",
       "      <td>0</td>\n",
       "      <td>0.520</td>\n",
       "      <td>6.474</td>\n",
       "      <td>97.1</td>\n",
       "      <td>2.4329</td>\n",
       "      <td>5</td>\n",
       "      <td>384</td>\n",
       "      <td>20.9</td>\n",
       "      <td>395.24</td>\n",
       "      <td>12.27</td>\n",
       "      <td>19.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>0.17120</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.56</td>\n",
       "      <td>0</td>\n",
       "      <td>0.520</td>\n",
       "      <td>5.836</td>\n",
       "      <td>91.9</td>\n",
       "      <td>2.2110</td>\n",
       "      <td>5</td>\n",
       "      <td>384</td>\n",
       "      <td>20.9</td>\n",
       "      <td>395.67</td>\n",
       "      <td>18.66</td>\n",
       "      <td>19.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>270</th>\n",
       "      <td>0.29916</td>\n",
       "      <td>20.0</td>\n",
       "      <td>6.96</td>\n",
       "      <td>0</td>\n",
       "      <td>0.464</td>\n",
       "      <td>5.856</td>\n",
       "      <td>42.1</td>\n",
       "      <td>4.4290</td>\n",
       "      <td>3</td>\n",
       "      <td>223</td>\n",
       "      <td>18.6</td>\n",
       "      <td>388.65</td>\n",
       "      <td>13.00</td>\n",
       "      <td>21.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>348</th>\n",
       "      <td>0.01501</td>\n",
       "      <td>80.0</td>\n",
       "      <td>2.01</td>\n",
       "      <td>0</td>\n",
       "      <td>0.435</td>\n",
       "      <td>6.635</td>\n",
       "      <td>29.7</td>\n",
       "      <td>8.3440</td>\n",
       "      <td>4</td>\n",
       "      <td>280</td>\n",
       "      <td>17.0</td>\n",
       "      <td>390.94</td>\n",
       "      <td>5.99</td>\n",
       "      <td>24.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>435</th>\n",
       "      <td>11.16040</td>\n",
       "      <td>0.0</td>\n",
       "      <td>18.10</td>\n",
       "      <td>0</td>\n",
       "      <td>0.740</td>\n",
       "      <td>6.629</td>\n",
       "      <td>94.6</td>\n",
       "      <td>2.1247</td>\n",
       "      <td>24</td>\n",
       "      <td>666</td>\n",
       "      <td>20.2</td>\n",
       "      <td>109.85</td>\n",
       "      <td>23.27</td>\n",
       "      <td>13.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>102</th>\n",
       "      <td>0.22876</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.56</td>\n",
       "      <td>0</td>\n",
       "      <td>0.520</td>\n",
       "      <td>6.405</td>\n",
       "      <td>85.4</td>\n",
       "      <td>2.7147</td>\n",
       "      <td>5</td>\n",
       "      <td>384</td>\n",
       "      <td>20.9</td>\n",
       "      <td>70.80</td>\n",
       "      <td>10.63</td>\n",
       "      <td>18.6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>339 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  \\\n",
       "478  10.23300   0.0  18.10     0  0.614  6.185  96.7  2.1705   24  666   \n",
       "26    0.67191   0.0   8.14     0  0.538  5.813  90.3  4.6820    4  307   \n",
       "7     0.14455  12.5   7.87     0  0.524  6.172  96.1  5.9505    5  311   \n",
       "492   0.11132   0.0  27.74     0  0.609  5.983  83.5  2.1099    4  711   \n",
       "108   0.12802   0.0   8.56     0  0.520  6.474  97.1  2.4329    5  384   \n",
       "..        ...   ...    ...   ...    ...    ...   ...     ...  ...  ...   \n",
       "106   0.17120   0.0   8.56     0  0.520  5.836  91.9  2.2110    5  384   \n",
       "270   0.29916  20.0   6.96     0  0.464  5.856  42.1  4.4290    3  223   \n",
       "348   0.01501  80.0   2.01     0  0.435  6.635  29.7  8.3440    4  280   \n",
       "435  11.16040   0.0  18.10     0  0.740  6.629  94.6  2.1247   24  666   \n",
       "102   0.22876   0.0   8.56     0  0.520  6.405  85.4  2.7147    5  384   \n",
       "\n",
       "     PTRATIO       B  LSTAT  MEDV  \n",
       "478     20.2  379.70  18.03  14.6  \n",
       "26      21.0  376.88  14.81  16.6  \n",
       "7       15.2  396.90  19.15  27.1  \n",
       "492     20.1  396.90  13.35  20.1  \n",
       "108     20.9  395.24  12.27  19.8  \n",
       "..       ...     ...    ...   ...  \n",
       "106     20.9  395.67  18.66  19.5  \n",
       "270     18.6  388.65  13.00  21.1  \n",
       "348     17.0  390.94   5.99  24.5  \n",
       "435     20.2  109.85  23.27  13.4  \n",
       "102     20.9   70.80  10.63  18.6  \n",
       "\n",
       "[339 rows x 14 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17a940dc-0f3a-4ebe-b28f-3c12f830d4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = train.sample(frac=1,random_state=2022).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "93e5fd9b-bb93-4a36-99f5-86bcd5187cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "train['kfold'] = -999"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ecd7c6c0-fb9d-4a35-ad14-6860496c454c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "      <th>kfold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.26938</td>\n",
       "      <td>0.0</td>\n",
       "      <td>9.9</td>\n",
       "      <td>0</td>\n",
       "      <td>0.544</td>\n",
       "      <td>6.266</td>\n",
       "      <td>82.8</td>\n",
       "      <td>3.2628</td>\n",
       "      <td>4</td>\n",
       "      <td>304</td>\n",
       "      <td>18.4</td>\n",
       "      <td>393.39</td>\n",
       "      <td>7.90</td>\n",
       "      <td>21.6</td>\n",
       "      <td>-999</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.83567</td>\n",
       "      <td>0.0</td>\n",
       "      <td>18.1</td>\n",
       "      <td>0</td>\n",
       "      <td>0.583</td>\n",
       "      <td>5.905</td>\n",
       "      <td>53.2</td>\n",
       "      <td>3.1523</td>\n",
       "      <td>24</td>\n",
       "      <td>666</td>\n",
       "      <td>20.2</td>\n",
       "      <td>388.22</td>\n",
       "      <td>11.45</td>\n",
       "      <td>20.6</td>\n",
       "      <td>-999</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      CRIM   ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  PTRATIO  \\\n",
       "0  0.26938  0.0    9.9     0  0.544  6.266  82.8  3.2628    4  304     18.4   \n",
       "1  4.83567  0.0   18.1     0  0.583  5.905  53.2  3.1523   24  666     20.2   \n",
       "\n",
       "        B  LSTAT  MEDV  kfold  \n",
       "0  393.39   7.90  21.6   -999  \n",
       "1  388.22  11.45  20.6   -999  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a7323956-d4b8-4ef4-9616-cfddac3bb2eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "kf = KFold(n_splits=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ee806af3-057d-4c6a-a1ad-859e9860f3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# lst = ['a','b','c']\n",
    "# for i in enumerate(lst,start=1):\n",
    "#     print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d086e85d-a18b-4b72-84d1-7f0343be2fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "for fold, (t,v) in enumerate(kf.split(train)):\n",
    "    train.loc[v,'kfold'] = fold\n",
    "    \n",
    "    #print(fold,(t,v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1e6e273b-f8ba-4b6c-a347-9c43c3982a5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "      <th>kfold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.26938</td>\n",
       "      <td>0.0</td>\n",
       "      <td>9.90</td>\n",
       "      <td>0</td>\n",
       "      <td>0.544</td>\n",
       "      <td>6.266</td>\n",
       "      <td>82.8</td>\n",
       "      <td>3.2628</td>\n",
       "      <td>4</td>\n",
       "      <td>304</td>\n",
       "      <td>18.4</td>\n",
       "      <td>393.39</td>\n",
       "      <td>7.90</td>\n",
       "      <td>21.6</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.83567</td>\n",
       "      <td>0.0</td>\n",
       "      <td>18.10</td>\n",
       "      <td>0</td>\n",
       "      <td>0.583</td>\n",
       "      <td>5.905</td>\n",
       "      <td>53.2</td>\n",
       "      <td>3.1523</td>\n",
       "      <td>24</td>\n",
       "      <td>666</td>\n",
       "      <td>20.2</td>\n",
       "      <td>388.22</td>\n",
       "      <td>11.45</td>\n",
       "      <td>20.6</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.34284</td>\n",
       "      <td>0.0</td>\n",
       "      <td>19.58</td>\n",
       "      <td>0</td>\n",
       "      <td>0.605</td>\n",
       "      <td>6.066</td>\n",
       "      <td>100.0</td>\n",
       "      <td>1.7573</td>\n",
       "      <td>5</td>\n",
       "      <td>403</td>\n",
       "      <td>14.7</td>\n",
       "      <td>353.89</td>\n",
       "      <td>6.43</td>\n",
       "      <td>24.3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>28.65580</td>\n",
       "      <td>0.0</td>\n",
       "      <td>18.10</td>\n",
       "      <td>0</td>\n",
       "      <td>0.597</td>\n",
       "      <td>5.155</td>\n",
       "      <td>100.0</td>\n",
       "      <td>1.5894</td>\n",
       "      <td>24</td>\n",
       "      <td>666</td>\n",
       "      <td>20.2</td>\n",
       "      <td>210.97</td>\n",
       "      <td>20.08</td>\n",
       "      <td>16.3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.57529</td>\n",
       "      <td>0.0</td>\n",
       "      <td>6.20</td>\n",
       "      <td>0</td>\n",
       "      <td>0.507</td>\n",
       "      <td>8.337</td>\n",
       "      <td>73.3</td>\n",
       "      <td>3.8384</td>\n",
       "      <td>8</td>\n",
       "      <td>307</td>\n",
       "      <td>17.4</td>\n",
       "      <td>385.91</td>\n",
       "      <td>2.47</td>\n",
       "      <td>41.7</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       CRIM   ZN  INDUS  CHAS    NOX     RM    AGE     DIS  RAD  TAX  PTRATIO  \\\n",
       "0   0.26938  0.0   9.90     0  0.544  6.266   82.8  3.2628    4  304     18.4   \n",
       "1   4.83567  0.0  18.10     0  0.583  5.905   53.2  3.1523   24  666     20.2   \n",
       "2   1.34284  0.0  19.58     0  0.605  6.066  100.0  1.7573    5  403     14.7   \n",
       "3  28.65580  0.0  18.10     0  0.597  5.155  100.0  1.5894   24  666     20.2   \n",
       "4   0.57529  0.0   6.20     0  0.507  8.337   73.3  3.8384    8  307     17.4   \n",
       "\n",
       "        B  LSTAT  MEDV  kfold  \n",
       "0  393.39   7.90  21.6      0  \n",
       "1  388.22  11.45  20.6      0  \n",
       "2  353.89   6.43  24.3      0  \n",
       "3  210.97  20.08  16.3      0  \n",
       "4  385.91   2.47  41.7      0  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9cbf7bca-6a17-413a-85b4-6fe48d318b4f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>kfold</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>kfold</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>67</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       kfold\n",
       "kfold       \n",
       "0         68\n",
       "1         68\n",
       "2         68\n",
       "3         68\n",
       "4         67"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.groupby('kfold')[['kfold']].count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3f646229-1508-4418-b67d-0c2cc9972e3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',\n",
       "       'PTRATIO', 'B', 'LSTAT', 'MEDV', 'kfold'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a9d80add-b8ea-457e-9030-70651af9413c",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_vars = [var for var in train.columns if var not in ['MEDV', 'kfold']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8db7b683-cda0-4e73-a0a1-0210841969f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd = SGDRegressor(random_state=2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7f0ea479-36f9-4a0e-94b9-c512fde4ac46",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = StandardScaler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6e430ddf-cae4-438a-9aa1-1830a3b3ab47",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 4.727464674829044\n",
      "1 5.528386959309025\n",
      "2 4.061059476309088\n",
      "3 3.9848092266990327\n",
      "4 6.429927069431654\n"
     ]
    }
   ],
   "source": [
    "final_pred_lst = []\n",
    "rmse = []\n",
    "\n",
    "param = []\n",
    "for fold in range(5):\n",
    "    \n",
    "    train_cv = train[train.kfold != fold]\n",
    "    test_cv = train[train.kfold == fold]\n",
    "    \n",
    "    y_train = train_cv['MEDV']\n",
    "    y_test = test_cv['MEDV']\n",
    "    \n",
    "    x_train = train_cv[x_vars]\n",
    "    x_test = test_cv[x_vars]\n",
    "    \n",
    "    x_train_std = sc.fit_transform(x_train)\n",
    "    x_test_std = sc.transform(x_test)\n",
    "    \n",
    "    \n",
    "    sgd.fit(x_train_std, y_train)\n",
    "    \n",
    "    pred_test = sgd.predict(x_test_std)\n",
    "    \n",
    "    x_deploy_std = sc.transform(test[x_vars])\n",
    "    \n",
    "    final_pred = sgd.predict(x_deploy_std)\n",
    "    \n",
    "    rmse.append(mean_squared_error(y_test,pred_test,squared=False))\n",
    "    \n",
    "    final_pred_lst.append(final_pred)\n",
    "    \n",
    "    param.append(sgd.coef_)\n",
    "    \n",
    "    print(fold, mean_squared_error(y_test,pred_test,squared=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b0993d29-5ba6-4049-8f73-b270ce8528e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[4.727464674829044,\n",
       " 5.528386959309025,\n",
       " 4.061059476309088,\n",
       " 3.9848092266990327,\n",
       " 6.429927069431654]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rmse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e0c872c8-c5a0-4011-9365-f4c24bab0ccf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([ 2.88799946e+01,  3.59214586e+01,  1.62974841e+01,  2.55253999e+01,\n",
       "         1.86254835e+01,  2.34567703e+01,  1.78591311e+01,  1.45834448e+01,\n",
       "         2.26516622e+01,  2.11839198e+01,  2.46565619e+01,  1.85979724e+01,\n",
       "        -6.02272538e+00,  2.21241524e+01,  1.97398688e+01,  2.56588726e+01,\n",
       "         1.95843949e+01,  5.94895343e+00,  3.91936940e+01,  1.74207726e+01,\n",
       "         2.72941083e+01,  2.96905487e+01,  1.19337949e+01,  2.44462243e+01,\n",
       "         1.79020900e+01,  1.58537535e+01,  2.34273428e+01,  1.44101876e+01,\n",
       "         2.26603014e+01,  1.99504436e+01,  2.26579300e+01,  2.54636216e+01,\n",
       "         2.49024052e+01,  1.83570256e+01,  1.64635594e+01,  1.74203258e+01,\n",
       "         3.09507666e+01,  2.06172289e+01,  2.41405042e+01,  2.49952212e+01,\n",
       "         1.47609553e+01,  3.13970199e+01,  4.08656419e+01,  1.82642486e+01,\n",
       "         2.75422844e+01,  1.70696011e+01,  1.47187857e+01,  2.61801708e+01,\n",
       "         2.00198888e+01,  3.03138867e+01,  2.12976542e+01,  3.35302759e+01,\n",
       "         1.64322872e+01,  2.67037146e+01,  3.85323077e+01,  2.25576315e+01,\n",
       "         1.87477988e+01,  3.21529644e+01,  2.52379377e+01,  1.33567150e+01,\n",
       "         2.25607055e+01,  2.97566493e+01,  3.14465030e+01,  1.70935582e+01,\n",
       "         2.13374177e+01,  1.71296314e+01,  2.01082345e+01,  2.62408488e+01,\n",
       "         3.04425011e+01,  1.19380117e+01,  2.07379144e+01,  2.69870539e+01,\n",
       "         1.10187062e+01,  1.70442320e+01,  2.42024130e+01,  5.69350457e+00,\n",
       "         2.22339807e+01,  3.96840764e+01,  1.80632995e+01,  1.05875531e+01,\n",
       "         2.17059428e+01,  1.29657346e+01,  2.13865805e+01,  9.57472944e+00,\n",
       "         2.33706293e+01,  3.18963644e+01,  1.88789450e+01,  2.58442426e+01,\n",
       "         2.86279222e+01,  2.04278759e+01,  2.58797026e+01,  5.74977215e+00,\n",
       "         2.05279743e+01,  1.62062207e+01,  1.41293848e+01,  2.11779411e+01,\n",
       "         2.44495536e+01,  3.47582563e-02,  1.36317948e+01,  1.55532426e+01,\n",
       "         2.25719590e+01,  2.48742692e+01,  1.08669299e+01,  2.03649492e+01,\n",
       "         2.37294543e+01,  1.20328262e+01,  1.92256636e+01,  2.57396170e+01,\n",
       "         2.13024664e+01,  2.44264703e+01,  7.73679137e+00,  1.87726260e+01,\n",
       "         2.22563904e+01,  2.75098105e+01,  3.19881110e+01,  1.51733509e+01,\n",
       "         3.43652578e+01,  1.38806418e+01,  2.15906629e+01,  2.82237748e+01,\n",
       "         1.62128866e+01,  2.48284894e+01,  3.91332370e+00,  2.43846867e+01,\n",
       "         2.59025323e+01,  2.34639937e+01,  2.48545843e+01,  3.34064451e+01,\n",
       "         2.04917540e+01,  3.76919222e+01,  1.34998060e+01,  2.60175940e+01,\n",
       "         1.85723346e+01,  2.12273351e+01,  1.00064867e+01,  2.07253262e+01,\n",
       "         2.26849946e+01,  3.10427202e+01,  3.09862712e+01,  1.63030229e+01,\n",
       "         1.76437967e+01,  2.84218543e+01,  2.44996497e+01,  1.72753741e+01,\n",
       "         6.50144614e+00,  2.60573146e+01,  2.31699185e+01,  1.72437620e+01,\n",
       "         1.43066410e+01,  3.94535003e+01,  1.64776998e+01,  1.81055132e+01,\n",
       "         2.57194736e+01,  2.41725324e+01,  2.24205239e+01,  2.18670430e+01,\n",
       "         1.66800398e+01,  2.29706603e+01,  2.92501074e+01,  6.86589741e+00,\n",
       "         2.44794855e+01,  1.74766938e+01,  2.18610009e+01,  2.54238789e+01,\n",
       "         2.73146629e+01,  2.15831585e+01,  3.99720241e+01]),\n",
       " array([27.71148689, 36.57204763, 16.51999503, 25.62143366, 18.7749978 ,\n",
       "        22.77408393, 16.83688847, 14.03854977, 22.32377501, 20.24188481,\n",
       "        24.26604101, 18.14565161, -4.97339598, 21.86297571, 18.4683187 ,\n",
       "        26.29399535, 18.93522324,  5.44487063, 40.23615682, 17.67271269,\n",
       "        27.2174564 , 29.60390558, 10.28389674, 23.88540429, 17.88076687,\n",
       "        15.23171169, 23.19813367, 14.74123582, 21.88688838, 19.07734443,\n",
       "        20.87147341, 24.94461863, 26.27075293, 18.07751506, 15.8797522 ,\n",
       "        17.85121246, 30.88634516, 19.71841895, 23.98300233, 25.17830976,\n",
       "        13.51207289, 30.40395217, 42.04495376, 17.61188033, 26.50873931,\n",
       "        16.8498418 , 13.04704081, 26.46306753, 19.86619143, 30.0671286 ,\n",
       "        20.89968711, 33.68727523, 15.11140921, 26.02707769, 39.14128978,\n",
       "        23.13638785, 18.8239833 , 32.40421707, 24.51706121, 12.36906976,\n",
       "        22.63567573, 30.69337637, 31.4537356 , 16.15194862, 22.12751619,\n",
       "        15.83417045, 20.34290194, 25.9146144 , 30.67723459, 11.63190732,\n",
       "        20.06448557, 27.39734682, 11.16410137, 17.38428323, 23.89961819,\n",
       "         7.09291417, 20.81156065, 40.73355727, 18.65585374, 10.36370795,\n",
       "        21.33381159, 12.10556975, 22.03174908,  9.13825134, 22.23051812,\n",
       "        31.86922982, 19.19771792, 25.55071182, 28.86305412, 20.15850661,\n",
       "        25.14338815,  6.21044244, 20.35556746, 15.67074066, 12.82287288,\n",
       "        21.00427545, 24.37527696, -0.98359332, 14.19332614, 15.53651976,\n",
       "        21.58384776, 24.74322663, 10.55456312, 19.57605154, 23.98027596,\n",
       "        11.52793532, 18.04412064, 25.92975615, 20.80793875, 24.96093265,\n",
       "         7.92543071, 19.02219672, 21.93853552, 26.10403718, 31.4142967 ,\n",
       "        15.5896829 , 33.59695307, 12.34863178, 20.85799875, 28.55046583,\n",
       "        15.20539698, 25.26315336,  7.77801477, 23.57836497, 26.03196793,\n",
       "        22.79935012, 25.69726309, 32.20943293, 21.92618091, 37.84959892,\n",
       "        12.5246386 , 26.15362665, 17.20055362, 20.14856034, 10.10367766,\n",
       "        20.68941064, 21.97678876, 31.81460473, 31.51016183, 14.89344648,\n",
       "        16.44919508, 29.15425949, 24.51602031, 16.12716196,  8.37741442,\n",
       "        25.54108003, 24.44973626, 17.69090162, 12.6663928 , 39.68962213,\n",
       "        16.56345179, 18.55832353, 24.95360166, 23.56619748, 21.74138048,\n",
       "        21.57680776, 17.26100753, 23.53207421, 29.01883023, 10.463424  ,\n",
       "        23.02093828, 16.49129733, 21.01054283, 24.85437459, 26.72302244,\n",
       "        20.73183863, 40.623578  ]),\n",
       " array([28.43278542, 36.73973495, 17.5511637 , 25.04778783, 19.32776696,\n",
       "        22.99136743, 17.45820679, 14.29483929, 22.61831242, 21.02832165,\n",
       "        25.46801196, 19.03857258, -6.05784795, 21.87645858, 19.08815083,\n",
       "        26.02434847, 18.70208328,  6.12575591, 39.46167679, 18.14750946,\n",
       "        26.70801009, 29.18915106, 10.97918182, 23.89020664, 18.5130177 ,\n",
       "        15.81063989, 23.23164995, 14.55165591, 23.03452048, 19.89120368,\n",
       "        22.02444205, 24.95819957, 25.51274544, 19.52562206, 16.25046718,\n",
       "        19.30383193, 30.67386742, 19.91196941, 24.69443138, 24.43467964,\n",
       "        14.46090741, 30.31381605, 40.89757371, 17.91265018, 26.9510429 ,\n",
       "        17.66083414, 13.81305815, 25.65260663, 20.15235936, 30.86805793,\n",
       "        21.81175497, 33.34508864, 15.61339894, 26.23134008, 39.1514091 ,\n",
       "        23.08986159, 19.56878242, 31.598256  , 24.51030563, 12.94986572,\n",
       "        22.27727454, 29.81712392, 31.15200354, 16.37550697, 22.43132451,\n",
       "        17.32093565, 20.63693331, 25.72166929, 30.18115524, 12.03471076,\n",
       "        20.00358891, 27.2007071 , 11.45922118, 17.37464212, 23.8632835 ,\n",
       "         6.2016703 , 21.50504668, 39.68955041, 18.41756817, 11.56653733,\n",
       "        21.49421286, 12.47869952, 22.14453399,  9.94213265, 22.95980276,\n",
       "        32.65140297, 19.44137949, 25.42764026, 28.21916635, 20.48756076,\n",
       "        25.62163731,  5.74276263, 20.75818481, 16.2905146 , 15.00991588,\n",
       "        21.42759369, 24.7352694 , -0.9661257 , 13.96218175, 14.88987802,\n",
       "        21.6820532 , 24.64798989, 10.81219457, 20.01185747, 24.02351487,\n",
       "        11.36236453, 18.62867401, 25.51575716, 20.81036343, 25.15768076,\n",
       "         7.75916659, 18.37571449, 21.87195127, 26.42746919, 31.47525486,\n",
       "        15.59163558, 33.32405518, 13.07952733, 21.52743996, 28.20291281,\n",
       "        15.95922217, 24.8075646 ,  4.63115135, 24.55764228, 25.56308966,\n",
       "        22.98514384, 25.57407139, 32.43323748, 23.40794481, 37.37617317,\n",
       "        12.23947189, 26.45776187, 17.92779403, 21.1332478 , 11.6393425 ,\n",
       "        20.43185036, 21.95422317, 31.63312774, 31.1892019 , 15.62567147,\n",
       "        16.99876109, 29.21652012, 24.74444642, 17.41303612,  7.09505302,\n",
       "        25.42773285, 26.40021734, 18.03056217, 13.45520366, 38.93802898,\n",
       "        16.76576824, 19.02341344, 25.23562728, 24.2992849 , 21.81066555,\n",
       "        21.60689055, 16.87352717, 24.30033932, 28.85474773,  8.01869781,\n",
       "        23.38049148, 16.9096955 , 21.23159677, 24.92641288, 27.94601162,\n",
       "        20.99958998, 40.72790413]),\n",
       " array([28.83423864, 36.02128579, 17.1565242 , 25.39305364, 19.24058089,\n",
       "        23.54108905, 16.92237993, 13.8632932 , 23.82497365, 20.87974514,\n",
       "        23.84421005, 17.85529784, -7.47512992, 21.96082762, 19.23426303,\n",
       "        26.74334852, 19.16119505,  5.28900065, 39.706381  , 17.93973941,\n",
       "        26.9260155 , 29.67125356, 10.48122889, 24.15046885, 18.76452991,\n",
       "        16.54006247, 23.50774606, 15.14648291, 21.65264236, 19.58958011,\n",
       "        22.10987686, 25.24628659, 25.8066537 , 19.72226667, 16.86955198,\n",
       "        17.57807436, 30.9819025 , 19.79324298, 23.15517256, 24.65628951,\n",
       "        13.73566183, 31.35915026, 41.53664235, 17.53588318, 27.19038372,\n",
       "        17.76217282, 13.72466995, 25.91536304, 20.76480022, 30.47026406,\n",
       "        20.93068587, 33.50212948, 15.317975  , 26.86247975, 38.9857063 ,\n",
       "        23.60100206, 19.46930377, 32.60719606, 24.58241821, 12.28978468,\n",
       "        22.01628142, 29.75489349, 31.60071727, 16.44337706, 21.46070382,\n",
       "        15.61068382, 20.78494699, 26.03943719, 30.65584994, 12.68441655,\n",
       "        19.61754794, 28.51112713, 11.06120456, 16.88179025, 24.25422611,\n",
       "         4.68553213, 22.10597811, 40.27525055, 18.2794981 , 10.92282331,\n",
       "        21.3696074 , 12.93588797, 21.1635858 ,  9.38448591, 23.20665183,\n",
       "        31.67224009, 19.6554632 , 25.63877005, 28.95614887, 20.65917922,\n",
       "        25.43693616,  5.47210258, 20.94884918, 15.35308018, 15.89490817,\n",
       "        21.72602679, 26.00617626, -0.65009772, 14.35966555, 16.16759265,\n",
       "        21.6114631 , 25.02421528, 10.79996924, 20.01832838, 23.5185626 ,\n",
       "        12.43507323, 18.73911519, 25.42646686, 20.52242831, 24.53059937,\n",
       "         7.6970919 , 19.39607972, 21.59490018, 26.90388139, 32.27328513,\n",
       "        16.16061675, 34.02648068, 13.01943313, 21.36381733, 28.15865077,\n",
       "        15.79461374, 24.62787041,  3.729277  , 24.69737675, 25.39099066,\n",
       "        23.17939029, 25.10601514, 33.42293247, 20.48555778, 37.94088805,\n",
       "        12.15286271, 25.85272701, 18.08287039, 20.95907791,  8.82646417,\n",
       "        22.10710134, 21.95066644, 31.56532506, 31.15792452, 15.43339235,\n",
       "        16.86632703, 28.8487641 , 24.54220039, 15.60581001,  5.72040316,\n",
       "        26.27572622, 23.06868501, 17.59925813, 13.14115229, 39.89039075,\n",
       "        17.41243624, 18.75429824, 25.32522694, 24.24224524, 21.6983644 ,\n",
       "        21.52475174, 17.04453902, 24.45189098, 28.97575161,  6.22147468,\n",
       "        23.51418517, 16.80547308, 21.40388906, 25.2541465 , 27.16127248,\n",
       "        21.09165336, 40.35609094]),\n",
       " array([ 2.78675305e+01,  3.64005816e+01,  1.66549109e+01,  2.53705966e+01,\n",
       "         1.81525393e+01,  2.30674727e+01,  1.77177141e+01,  1.52638576e+01,\n",
       "         2.16702530e+01,  2.05145085e+01,  2.54224660e+01,  1.88370608e+01,\n",
       "        -5.25183632e+00,  2.21187961e+01,  1.86580411e+01,  2.53396731e+01,\n",
       "         2.01794066e+01,  5.94401333e+00,  4.00597671e+01,  1.69414705e+01,\n",
       "         2.72438874e+01,  2.97807813e+01,  1.14290025e+01,  2.38954166e+01,\n",
       "         1.70435170e+01,  1.49007221e+01,  2.31745010e+01,  1.40973479e+01,\n",
       "         2.28995894e+01,  1.92948692e+01,  2.16491205e+01,  2.46767032e+01,\n",
       "         2.47217104e+01,  1.66493122e+01,  1.60530572e+01,  1.78271365e+01,\n",
       "         3.04765446e+01,  1.97051150e+01,  2.47099051e+01,  2.51415902e+01,\n",
       "         1.45026304e+01,  3.05666529e+01,  4.20294327e+01,  1.79713880e+01,\n",
       "         2.66752472e+01,  1.62001442e+01,  1.37769171e+01,  2.63389617e+01,\n",
       "         1.94568310e+01,  2.97044243e+01,  2.12181903e+01,  3.35992486e+01,\n",
       "         1.58484489e+01,  2.56418450e+01,  3.87542651e+01,  2.20668271e+01,\n",
       "         1.80592162e+01,  3.21882195e+01,  2.47336574e+01,  1.30955373e+01,\n",
       "         2.24035633e+01,  3.03862534e+01,  3.09601263e+01,  1.62475782e+01,\n",
       "         2.12076303e+01,  1.74644873e+01,  1.96347888e+01,  2.57983425e+01,\n",
       "         3.05130710e+01,  1.12637470e+01,  2.08331163e+01,  2.60921071e+01,\n",
       "         1.08747201e+01,  1.65304816e+01,  2.40329707e+01,  6.43782296e+00,\n",
       "         2.09650566e+01,  4.08189456e+01,  1.81034361e+01,  1.11600135e+01,\n",
       "         2.12282987e+01,  1.27649962e+01,  2.21138230e+01,  9.21590617e+00,\n",
       "         2.25092400e+01,  3.18057477e+01,  1.84896310e+01,  2.51508718e+01,\n",
       "         2.89242478e+01,  2.03131202e+01,  2.53644377e+01,  5.80649035e+00,\n",
       "         2.01769283e+01,  1.56869579e+01,  1.12435479e+01,  2.07269039e+01,\n",
       "         2.30455993e+01,  3.05759372e-02,  1.32217837e+01,  1.53323796e+01,\n",
       "         2.21947338e+01,  2.46897110e+01,  1.05730868e+01,  1.92198305e+01,\n",
       "         2.39543500e+01,  1.18207364e+01,  1.80483233e+01,  2.53781146e+01,\n",
       "         2.05327540e+01,  2.45450486e+01,  7.83486299e+00,  1.86243064e+01,\n",
       "         2.14725297e+01,  2.65399757e+01,  3.11566451e+01,  1.43697238e+01,\n",
       "         3.37703090e+01,  1.28458920e+01,  2.09231622e+01,  2.84004373e+01,\n",
       "         1.52842058e+01,  2.47309903e+01,  4.39328746e+00,  2.34779665e+01,\n",
       "         2.58600531e+01,  2.25662103e+01,  2.52776536e+01,  3.20908093e+01,\n",
       "         2.18191463e+01,  3.76903258e+01,  1.48332328e+01,  2.57771125e+01,\n",
       "         1.72943718e+01,  2.07851137e+01,  1.14312418e+01,  1.98442306e+01,\n",
       "         2.18724540e+01,  3.15238101e+01,  3.14246659e+01,  1.53509337e+01,\n",
       "         1.67551358e+01,  2.88204711e+01,  2.48943807e+01,  1.78879017e+01,\n",
       "         7.23227790e+00,  2.60340716e+01,  2.44073055e+01,  1.70407223e+01,\n",
       "         1.35261605e+01,  3.92441363e+01,  1.58554357e+01,  1.75546424e+01,\n",
       "         2.45532235e+01,  2.36637538e+01,  2.20863003e+01,  2.10209438e+01,\n",
       "         1.65379579e+01,  2.17534423e+01,  2.83764831e+01,  7.55471085e+00,\n",
       "         2.37828762e+01,  1.67234597e+01,  2.12343551e+01,  2.45490176e+01,\n",
       "         2.70360811e+01,  2.07239574e+01,  4.03583640e+01])]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_pred_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2b13ce4c-5dd7-4589-b0fd-e3da45104440",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([28.34520721, 36.33102171, 16.83601558, 25.39165434, 18.82427369,\n",
       "       23.16615667, 17.35886407, 14.40879693, 22.61779524, 20.76967598,\n",
       "       24.73145819, 18.49491105, -5.95618711, 21.98864207, 19.03772849,\n",
       "       26.01204759, 19.31246062,  5.75051879, 39.73153513, 17.62444094,\n",
       "       27.07789554, 29.58712804, 11.02142096, 24.05354414, 18.02078429,\n",
       "       15.66737794, 23.3078747 , 14.58938203, 22.42678841, 19.56068822,\n",
       "       21.86256856, 25.05788593, 25.44285353, 18.46634832, 16.30327759,\n",
       "       17.99611622, 30.79388525, 19.94919505, 24.13660312, 24.88121805,\n",
       "       14.19444556, 30.80811825, 41.47484887, 17.85921005, 26.97353951,\n",
       "       17.10851881, 13.81609434, 26.11003395, 20.05201416, 30.28475231,\n",
       "       21.23159448, 33.53280358, 15.66470384, 26.29329143, 38.9129956 ,\n",
       "       22.89034202, 18.93381689, 32.19017062, 24.71627603, 12.8121945 ,\n",
       "       22.3787001 , 30.08165929, 31.32261712, 16.46239381, 21.71291851,\n",
       "       16.67198172, 20.30156111, 25.94298243, 30.49396237, 11.91055867,\n",
       "       20.25133063, 27.2376684 , 11.11559067, 17.04308582, 24.0505023 ,\n",
       "        6.02228883, 21.52432454, 40.24027604, 18.30393113, 10.92012703,\n",
       "       21.42637466, 12.6501776 , 21.76805448,  9.4511011 , 22.85536839,\n",
       "       31.978997  , 19.13262731, 25.5224473 , 28.71810787, 20.40924853,\n",
       "       25.48922038,  5.79631403, 20.55350082, 15.84150279, 13.82012592,\n",
       "       21.21254817, 24.52237511, -0.50689651, 13.8737504 , 15.49592252,\n",
       "       21.92881138, 24.7958824 , 10.72134873, 19.83820342, 23.84123154,\n",
       "       11.83578715, 18.53717936, 25.59794235, 20.79519017, 24.72414632,\n",
       "        7.79066871, 18.83818466, 21.82686142, 26.69703479, 31.66151856,\n",
       "       15.37700199, 33.81661115, 13.0348252 , 21.25261624, 28.30724831,\n",
       "       15.69126506, 24.85161362,  4.88901085, 24.13920745, 25.74972671,\n",
       "       22.99881763, 25.30191749, 32.71257146, 21.62611675, 37.70978163,\n",
       "       13.05000241, 26.0517644 , 17.81558488, 20.85066696, 10.40144257,\n",
       "       20.75958382, 22.08782539, 31.51591756, 31.25364508, 15.52129339,\n",
       "       16.94264313, 28.89237381, 24.63933951, 16.86185679,  6.98531893,\n",
       "       25.86718506, 24.29917253, 17.52104123, 13.41911004, 39.44313569,\n",
       "       16.61495837, 18.39923816, 25.15743061, 23.98880277, 21.95144692,\n",
       "       21.51928737, 16.87941429, 23.40168143, 28.89518402,  7.82484095,\n",
       "       23.63559532, 16.88132388, 21.34827695, 25.00156609, 27.23621011,\n",
       "       21.02603958, 40.40759224])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(np.column_stack(final_pred_lst),axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "19bb2f4a-b459-4de9-8541-fb0545d6e3fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([-0.91087252,  0.52993232, -0.08898891,  0.82191116, -1.75743435,\n",
       "         2.78635451, -0.2751313 , -2.98125841,  1.39133042, -0.87548779,\n",
       "        -1.79472779,  1.00475511, -3.79564777]),\n",
       " array([-0.37186456,  0.65993185,  0.33715493,  0.98811186, -1.37824222,\n",
       "         3.17122818, -0.67997827, -2.54151232,  1.25839159, -0.64612266,\n",
       "        -1.99692831,  1.11032725, -4.01858337]),\n",
       " array([-0.92172263,  0.63388318, -0.00499632,  1.11389525, -1.08150927,\n",
       "         2.91278773, -0.31700001, -2.59564889,  0.95927578, -0.17600656,\n",
       "        -2.098587  ,  1.12121747, -4.16177789]),\n",
       " array([-1.12617702,  0.70845959,  0.17778608,  0.83884781, -1.85034259,\n",
       "         2.78777199, -0.08474836, -2.98940397,  1.55633497, -0.47720417,\n",
       "        -1.93403011,  0.91068495, -4.40561952]),\n",
       " array([-0.8626316 ,  0.6831068 ,  0.39354667,  0.95335961, -1.68464805,\n",
       "         2.97610466, -0.67902598, -2.76987726,  1.18174605, -0.97261911,\n",
       "        -2.17108074,  1.15340185, -3.00690351])]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "810de327-a2d1-4f0d-a939-fb283fb538d8",
   "metadata": {},
   "source": [
    "## Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "2676f215-7d62-48db-a239-95c18f58b424",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".. _breast_cancer_dataset:\n",
      "\n",
      "Breast cancer wisconsin (diagnostic) dataset\n",
      "--------------------------------------------\n",
      "\n",
      "**Data Set Characteristics:**\n",
      "\n",
      "    :Number of Instances: 569\n",
      "\n",
      "    :Number of Attributes: 30 numeric, predictive attributes and the class\n",
      "\n",
      "    :Attribute Information:\n",
      "        - radius (mean of distances from center to points on the perimeter)\n",
      "        - texture (standard deviation of gray-scale values)\n",
      "        - perimeter\n",
      "        - area\n",
      "        - smoothness (local variation in radius lengths)\n",
      "        - compactness (perimeter^2 / area - 1.0)\n",
      "        - concavity (severity of concave portions of the contour)\n",
      "        - concave points (number of concave portions of the contour)\n",
      "        - symmetry\n",
      "        - fractal dimension (\"coastline approximation\" - 1)\n",
      "\n",
      "        The mean, standard error, and \"worst\" or largest (mean of the three\n",
      "        worst/largest values) of these features were computed for each image,\n",
      "        resulting in 30 features.  For instance, field 0 is Mean Radius, field\n",
      "        10 is Radius SE, field 20 is Worst Radius.\n",
      "\n",
      "        - class:\n",
      "                - WDBC-Malignant\n",
      "                - WDBC-Benign\n",
      "\n",
      "    :Summary Statistics:\n",
      "\n",
      "    ===================================== ====== ======\n",
      "                                           Min    Max\n",
      "    ===================================== ====== ======\n",
      "    radius (mean):                        6.981  28.11\n",
      "    texture (mean):                       9.71   39.28\n",
      "    perimeter (mean):                     43.79  188.5\n",
      "    area (mean):                          143.5  2501.0\n",
      "    smoothness (mean):                    0.053  0.163\n",
      "    compactness (mean):                   0.019  0.345\n",
      "    concavity (mean):                     0.0    0.427\n",
      "    concave points (mean):                0.0    0.201\n",
      "    symmetry (mean):                      0.106  0.304\n",
      "    fractal dimension (mean):             0.05   0.097\n",
      "    radius (standard error):              0.112  2.873\n",
      "    texture (standard error):             0.36   4.885\n",
      "    perimeter (standard error):           0.757  21.98\n",
      "    area (standard error):                6.802  542.2\n",
      "    smoothness (standard error):          0.002  0.031\n",
      "    compactness (standard error):         0.002  0.135\n",
      "    concavity (standard error):           0.0    0.396\n",
      "    concave points (standard error):      0.0    0.053\n",
      "    symmetry (standard error):            0.008  0.079\n",
      "    fractal dimension (standard error):   0.001  0.03\n",
      "    radius (worst):                       7.93   36.04\n",
      "    texture (worst):                      12.02  49.54\n",
      "    perimeter (worst):                    50.41  251.2\n",
      "    area (worst):                         185.2  4254.0\n",
      "    smoothness (worst):                   0.071  0.223\n",
      "    compactness (worst):                  0.027  1.058\n",
      "    concavity (worst):                    0.0    1.252\n",
      "    concave points (worst):               0.0    0.291\n",
      "    symmetry (worst):                     0.156  0.664\n",
      "    fractal dimension (worst):            0.055  0.208\n",
      "    ===================================== ====== ======\n",
      "\n",
      "    :Missing Attribute Values: None\n",
      "\n",
      "    :Class Distribution: 212 - Malignant, 357 - Benign\n",
      "\n",
      "    :Creator:  Dr. William H. Wolberg, W. Nick Street, Olvi L. Mangasarian\n",
      "\n",
      "    :Donor: Nick Street\n",
      "\n",
      "    :Date: November, 1995\n",
      "\n",
      "This is a copy of UCI ML Breast Cancer Wisconsin (Diagnostic) datasets.\n",
      "https://goo.gl/U2Uwz2\n",
      "\n",
      "Features are computed from a digitized image of a fine needle\n",
      "aspirate (FNA) of a breast mass.  They describe\n",
      "characteristics of the cell nuclei present in the image.\n",
      "\n",
      "Separating plane described above was obtained using\n",
      "Multisurface Method-Tree (MSM-T) [K. P. Bennett, \"Decision Tree\n",
      "Construction Via Linear Programming.\" Proceedings of the 4th\n",
      "Midwest Artificial Intelligence and Cognitive Science Society,\n",
      "pp. 97-101, 1992], a classification method which uses linear\n",
      "programming to construct a decision tree.  Relevant features\n",
      "were selected using an exhaustive search in the space of 1-4\n",
      "features and 1-3 separating planes.\n",
      "\n",
      "The actual linear program used to obtain the separating plane\n",
      "in the 3-dimensional space is that described in:\n",
      "[K. P. Bennett and O. L. Mangasarian: \"Robust Linear\n",
      "Programming Discrimination of Two Linearly Inseparable Sets\",\n",
      "Optimization Methods and Software 1, 1992, 23-34].\n",
      "\n",
      "This database is also available through the UW CS ftp server:\n",
      "\n",
      "ftp ftp.cs.wisc.edu\n",
      "cd math-prog/cpo-dataset/machine-learn/WDBC/\n",
      "\n",
      ".. topic:: References\n",
      "\n",
      "   - W.N. Street, W.H. Wolberg and O.L. Mangasarian. Nuclear feature extraction \n",
      "     for breast tumor diagnosis. IS&T/SPIE 1993 International Symposium on \n",
      "     Electronic Imaging: Science and Technology, volume 1905, pages 861-870,\n",
      "     San Jose, CA, 1993.\n",
      "   - O.L. Mangasarian, W.N. Street and W.H. Wolberg. Breast cancer diagnosis and \n",
      "     prognosis via linear programming. Operations Research, 43(4), pages 570-577, \n",
      "     July-August 1995.\n",
      "   - W.H. Wolberg, W.N. Street, and O.L. Mangasarian. Machine learning techniques\n",
      "     to diagnose breast cancer from fine-needle aspirates. Cancer Letters 77 (1994) \n",
      "     163-171.\n"
     ]
    }
   ],
   "source": [
    "bc = load_breast_cancer()\n",
    "\n",
    "print(bc.DESCR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "06b3eca0-57c5-432c-a621-0c45ff1bef07",
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_data = pd.DataFrame(bc.data,columns=bc.feature_names)\n",
    "bc_data['target'] = bc.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "672efd73-1e1c-440b-979e-d203e116cd74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean radius</th>\n",
       "      <th>mean texture</th>\n",
       "      <th>mean perimeter</th>\n",
       "      <th>mean area</th>\n",
       "      <th>mean smoothness</th>\n",
       "      <th>mean compactness</th>\n",
       "      <th>mean concavity</th>\n",
       "      <th>mean concave points</th>\n",
       "      <th>mean symmetry</th>\n",
       "      <th>mean fractal dimension</th>\n",
       "      <th>...</th>\n",
       "      <th>worst texture</th>\n",
       "      <th>worst perimeter</th>\n",
       "      <th>worst area</th>\n",
       "      <th>worst smoothness</th>\n",
       "      <th>worst compactness</th>\n",
       "      <th>worst concavity</th>\n",
       "      <th>worst concave points</th>\n",
       "      <th>worst symmetry</th>\n",
       "      <th>worst fractal dimension</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>17.99</td>\n",
       "      <td>10.38</td>\n",
       "      <td>122.8</td>\n",
       "      <td>1001.0</td>\n",
       "      <td>0.11840</td>\n",
       "      <td>0.27760</td>\n",
       "      <td>0.3001</td>\n",
       "      <td>0.14710</td>\n",
       "      <td>0.2419</td>\n",
       "      <td>0.07871</td>\n",
       "      <td>...</td>\n",
       "      <td>17.33</td>\n",
       "      <td>184.6</td>\n",
       "      <td>2019.0</td>\n",
       "      <td>0.1622</td>\n",
       "      <td>0.6656</td>\n",
       "      <td>0.7119</td>\n",
       "      <td>0.2654</td>\n",
       "      <td>0.4601</td>\n",
       "      <td>0.11890</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>20.57</td>\n",
       "      <td>17.77</td>\n",
       "      <td>132.9</td>\n",
       "      <td>1326.0</td>\n",
       "      <td>0.08474</td>\n",
       "      <td>0.07864</td>\n",
       "      <td>0.0869</td>\n",
       "      <td>0.07017</td>\n",
       "      <td>0.1812</td>\n",
       "      <td>0.05667</td>\n",
       "      <td>...</td>\n",
       "      <td>23.41</td>\n",
       "      <td>158.8</td>\n",
       "      <td>1956.0</td>\n",
       "      <td>0.1238</td>\n",
       "      <td>0.1866</td>\n",
       "      <td>0.2416</td>\n",
       "      <td>0.1860</td>\n",
       "      <td>0.2750</td>\n",
       "      <td>0.08902</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>19.69</td>\n",
       "      <td>21.25</td>\n",
       "      <td>130.0</td>\n",
       "      <td>1203.0</td>\n",
       "      <td>0.10960</td>\n",
       "      <td>0.15990</td>\n",
       "      <td>0.1974</td>\n",
       "      <td>0.12790</td>\n",
       "      <td>0.2069</td>\n",
       "      <td>0.05999</td>\n",
       "      <td>...</td>\n",
       "      <td>25.53</td>\n",
       "      <td>152.5</td>\n",
       "      <td>1709.0</td>\n",
       "      <td>0.1444</td>\n",
       "      <td>0.4245</td>\n",
       "      <td>0.4504</td>\n",
       "      <td>0.2430</td>\n",
       "      <td>0.3613</td>\n",
       "      <td>0.08758</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3 rows × 31 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   mean radius  mean texture  mean perimeter  mean area  mean smoothness  \\\n",
       "0        17.99         10.38           122.8     1001.0          0.11840   \n",
       "1        20.57         17.77           132.9     1326.0          0.08474   \n",
       "2        19.69         21.25           130.0     1203.0          0.10960   \n",
       "\n",
       "   mean compactness  mean concavity  mean concave points  mean symmetry  \\\n",
       "0           0.27760          0.3001              0.14710         0.2419   \n",
       "1           0.07864          0.0869              0.07017         0.1812   \n",
       "2           0.15990          0.1974              0.12790         0.2069   \n",
       "\n",
       "   mean fractal dimension  ...  worst texture  worst perimeter  worst area  \\\n",
       "0                 0.07871  ...          17.33            184.6      2019.0   \n",
       "1                 0.05667  ...          23.41            158.8      1956.0   \n",
       "2                 0.05999  ...          25.53            152.5      1709.0   \n",
       "\n",
       "   worst smoothness  worst compactness  worst concavity  worst concave points  \\\n",
       "0            0.1622             0.6656           0.7119                0.2654   \n",
       "1            0.1238             0.1866           0.2416                0.1860   \n",
       "2            0.1444             0.4245           0.4504                0.2430   \n",
       "\n",
       "   worst symmetry  worst fractal dimension  target  \n",
       "0          0.4601                  0.11890       0  \n",
       "1          0.2750                  0.08902       0  \n",
       "2          0.3613                  0.08758       0  \n",
       "\n",
       "[3 rows x 31 columns]"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bc_data.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "6c364040-f418-4e49-8039-4b275e26a8ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_class = SGDClassifier(loss='log',random_state=2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "d4daf3f3-ace6-4450-bb7b-050ea127616b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_bc, test_bc = train_test_split(bc_data, test_size=.2,random_state=2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "bf7d5635-94f2-4ce0-90df-e880281c996e",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = bc_data.target\n",
    "X = bc_data.drop(columns='target')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "c89b8c17-c2bb-453c-96a4-616fedb3dc65",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "58c1211f-ffe5-40af-b9fd-cc1b8cdf87a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = StandardScaler()\n",
    "X_train_std = sc.fit_transform(X_train)\n",
    "X_test_std = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "4509e38e-20dc-444e-9a1f-1271de17a0dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 100,  101,  102, ..., 1997, 1998, 1999])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.arange(100,2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "60ec3412-640a-4ce2-9618-4892a577f8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "penalty = ['l2', 'l1', 'elasticnet']\n",
    "alpha = [.001,.1,1,10,100,1000]\n",
    "l1_ratio = [.01,.1,.9,.8,.6]\n",
    "learning_rate = ['constant','adaptive']\n",
    "eta0 = [.001,.1,.02,.2,.8,1,10,100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "c3417194-de42-4621-a1b8-bb17b8cb984f",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_dist = dict(penalty=penalty,alpha=alpha,l1_ratio=l1_ratio,learning_rate=learning_rate,eta0=eta0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "8692a2ee-d305-490d-972e-d3d68f726c0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'penalty': ['l2', 'l1', 'elasticnet'],\n",
       " 'alpha': [0.001, 0.1, 1, 10, 100, 1000],\n",
       " 'l1_ratio': [0.01, 0.1, 0.9, 0.8, 0.6],\n",
       " 'learning_rate': ['constant', 'adaptive'],\n",
       " 'eta0': [0.001, 0.1, 0.02, 0.2, 0.8, 1, 10, 100]}"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "10c591d0-0f99-4707-9205-cf31024273ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_class = GridSearchCV(estimator = sgd_class,param_grid=param_dist,n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "49e04edf-5390-470e-8968-68ec3b77068a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(estimator=SGDClassifier(loss='log', random_state=2022), n_jobs=-1,\n",
       "             param_grid={'alpha': [0.001, 0.1, 1, 10, 100, 1000],\n",
       "                         'eta0': [0.001, 0.1, 0.02, 0.2, 0.8, 1, 10, 100],\n",
       "                         'l1_ratio': [0.01, 0.1, 0.9, 0.8, 0.6],\n",
       "                         'learning_rate': ['constant', 'adaptive'],\n",
       "                         'penalty': ['l2', 'l1', 'elasticnet']})"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_class.fit(X_train_std,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "68089c9a-58cc-4164-b3a3-3ffc0628cd25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SGDClassifier(alpha=0.001, eta0=0.001, l1_ratio=0.01, learning_rate='constant',\n",
       "              loss='log', random_state=2022)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_class.best_estimator_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8d1e4f64-01af-4cea-a8ae-be7a5458b42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "c77ad451-bc5e-4f61-83ea-1898ff180e18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.2'"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sklearn.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "0079a088-b4a1-4ccd-ba7b-37d088e05115",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,\n",
       "       0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1,\n",
       "       1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0,\n",
       "       1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1,\n",
       "       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "       0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,\n",
       "       1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,\n",
       "       1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_class.predict(X_test_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "14485fa8-e2c3-470d-876b-c82d01422f6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    357\n",
       "0    212\n",
       "Name: target, dtype: int64"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bc_data.target.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "0d5379f8-46cd-4c3e-a137-f872e2b3937c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>col_0</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>target</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>71</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>113</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "col_0    0    1\n",
       "target         \n",
       "0       71    2\n",
       "1        2  113"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.crosstab(y_test,grid_class.predict(X_test_std))   # confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "6b6d9101-715b-4053-837c-6286ffe37499",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9787234042553191"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(71+113)/len(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "a4673b94-2fdc-472b-bc26-c11c553f68a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9787234042553191"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test,grid_class.predict(X_test_std))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "b32cab4a-f107-46d6-9669-35eaac0ef711",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(188, 2)"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_class.predict_proba(X_test_std).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "0fb9b228-b281-4019-8616-2163dfa1fe79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.00961892e-01, 2.37915879e-03, 5.77579170e-01, 8.24401305e-03,\n",
       "       1.21379703e-01, 1.36498164e-02, 5.79929347e-01, 9.94857360e-01,\n",
       "       9.43051986e-01, 9.99999974e-01, 9.99999659e-01, 7.50630275e-01,\n",
       "       9.60081006e-01, 9.94757081e-01, 2.75255347e-01, 9.88923623e-03,\n",
       "       5.97068205e-02, 8.44767197e-03, 3.40685572e-01, 9.93896027e-04,\n",
       "       3.57076541e-02, 1.74627004e-02, 7.46848661e-02, 5.02523547e-02,\n",
       "       6.03997733e-01, 4.69606746e-03, 4.75998100e-02, 1.42854701e-04,\n",
       "       9.85154833e-01, 5.28955721e-03, 2.32374047e-02, 1.70208310e-02,\n",
       "       9.99734136e-01, 5.83896736e-01, 2.79086216e-02, 9.83488911e-01,\n",
       "       9.93049153e-01, 4.26336273e-04, 9.99764906e-01, 3.84763167e-02,\n",
       "       3.46188176e-02, 9.14088548e-01, 9.98976926e-01, 3.10374760e-04,\n",
       "       9.59159111e-01, 1.08694758e-03, 9.99528233e-01, 2.29924734e-03,\n",
       "       9.99999975e-01, 9.99999721e-01, 9.67793709e-02, 8.58256764e-01,\n",
       "       6.70180704e-01, 1.02691400e-02, 9.99997168e-01, 9.40243380e-01,\n",
       "       3.56896430e-02, 2.26201548e-03, 2.24034463e-01, 1.25497778e-03,\n",
       "       9.99652902e-01, 4.09866892e-03, 2.33167023e-02, 9.90896290e-01,\n",
       "       5.63008616e-01, 4.16565603e-03, 9.51653069e-02, 5.33414682e-03,\n",
       "       8.71340637e-01, 1.91806763e-02, 3.63541554e-02, 9.99676886e-01,\n",
       "       1.84103146e-02, 2.44877629e-03, 9.03498728e-01, 1.18742369e-01,\n",
       "       9.99018283e-01, 9.84205682e-01, 9.95844767e-01, 4.79456022e-02,\n",
       "       3.42755711e-02, 2.44422976e-03, 9.92224643e-01, 2.03964731e-02,\n",
       "       2.94918513e-02, 9.99181175e-01, 9.98793446e-01, 9.98597703e-01,\n",
       "       1.75356297e-02, 9.99994350e-01, 6.72190426e-03, 9.98917169e-01,\n",
       "       2.61937696e-04, 5.92689716e-02, 9.87065674e-01, 1.16789481e-03,\n",
       "       7.28606739e-03, 5.68673235e-02, 6.87406696e-02, 1.13657447e-02,\n",
       "       9.98936880e-01, 6.54101368e-02, 8.48819931e-01, 9.99997645e-01,\n",
       "       1.20988680e-01, 1.79378960e-01, 9.47155322e-02, 9.99793445e-01,\n",
       "       7.63082143e-01, 1.16381529e-03, 3.12912844e-03, 2.29016577e-02,\n",
       "       5.36205745e-01, 2.60990290e-02, 5.91574316e-02, 1.94257136e-03,\n",
       "       1.75855673e-02, 9.07300659e-02, 9.13620842e-01, 1.10232999e-02,\n",
       "       9.47666146e-01, 3.82174686e-01, 1.35921097e-03, 1.48502621e-01,\n",
       "       1.20089046e-01, 8.60107215e-02, 2.36155857e-01, 3.65565013e-04,\n",
       "       9.99800394e-01, 1.74810271e-01, 1.56237521e-01, 9.99959332e-01,\n",
       "       9.90798668e-01, 1.32931664e-01, 6.89186467e-01, 1.61910388e-02,\n",
       "       6.43777224e-02, 1.83301206e-02, 9.99833956e-01, 8.98550263e-02,\n",
       "       7.71510927e-02, 3.94442640e-03, 9.53559609e-01, 1.07918137e-02,\n",
       "       9.99341427e-01, 1.05920636e-02, 1.03580726e-02, 9.32379036e-01,\n",
       "       3.19160870e-03, 9.86162666e-02, 4.65472815e-01, 6.94315590e-03,\n",
       "       9.99590132e-01, 9.97357456e-01, 1.25041014e-01, 9.96548897e-01,\n",
       "       9.96631127e-01, 9.95292317e-01, 9.97174293e-01, 1.21291736e-01,\n",
       "       7.46166134e-04, 4.49312836e-01, 2.19508906e-02, 5.64136548e-01,\n",
       "       9.99998036e-01, 5.90061960e-02, 9.99890755e-01, 1.08437633e-03,\n",
       "       3.60191442e-01, 9.99999181e-01, 2.60104369e-03, 9.37418304e-01,\n",
       "       9.40631572e-02, 9.90626352e-01, 2.79708465e-03, 9.96151249e-01,\n",
       "       1.19200749e-02, 3.21691252e-02, 9.14125004e-01, 2.25416587e-01,\n",
       "       5.96782764e-03, 1.01558857e-03, 8.40305001e-01, 9.52511765e-02,\n",
       "       2.55467860e-03, 3.41143320e-01, 3.16560931e-04, 3.28414014e-03])"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_class.predict_proba(X_test_std)[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "aa514985-377d-4358-8c50-610a37c4b1e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "       1.])"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(grid_class.predict_proba(X_test_std),axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd2b16f-234f-4d5d-97c9-25e9a1c2243d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


In [124]:
for fold, (t,v) in enumerate(kf.split(train)):
    train.loc[v,'kfold'] = fold
    
    #print(fold,(t,v))

In [125]:
train.head(5)

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV,kfold
0,0.26938,0.0,9.9,0,0.544,6.266,82.8,3.2628,4,304,18.4,393.39,7.9,21.6,0
1,4.83567,0.0,18.1,0,0.583,5.905,53.2,3.1523,24,666,20.2,388.22,11.45,20.6,0
2,1.34284,0.0,19.58,0,0.605,6.066,100.0,1.7573,5,403,14.7,353.89,6.43,24.3,0
3,28.6558,0.0,18.1,0,0.597,5.155,100.0,1.5894,24,666,20.2,210.97,20.08,16.3,0
4,0.57529,0.0,6.2,0,0.507,8.337,73.3,3.8384,8,307,17.4,385.91,2.47,41.7,0


In [2]:
train.groupby('kfold')[['kfold']].count()

NameError: name 'X_train' is not defined