In [1]:
from __future__ import division

import numpy as np
import pandas as pd
from scipy import stats
from sklearn import datasets, linear_model

In [2]:
class LRPI(object):
    """Linear regression that also reports a prediction interval per row.

    A ``LinearRegression`` model (with intercept) is fitted, and for every
    predicted row the interval

        pred +/- t_value * sqrt(MSE * (1 + x0' (X'X)^-1 x0))

    is returned, where ``X`` is the training design matrix with a trailing
    column of ones, ``x0`` is the test row extended the same way, and ``MSE``
    is the residual sum of squares divided by ``n_samples - n_features - 1``.

    Parameters
    ----------
    normalize : bool, default False
        Forwarded to ``LinearRegression`` only when truthy.
    n_jobs : int, default 1
        Forwarded to ``LinearRegression``.
    t_value : float, default 2.13144955
        Multiplier applied to the standard error of prediction. The default
        equals the 97.5th percentile of Student's t with 15 degrees of
        freedom; it is NOT adapted to the training-set size unless
        ``confidence`` is given.
    confidence : float or None, default None
        Two-sided coverage in (0, 1), e.g. ``0.95``. When set, ``fit``
        replaces ``t_value`` with the matching Student-t quantile for the
        residual degrees of freedom of the training data.

    Attributes set by ``fit``: ``X_train`` (with a ``'const_one'`` column),
    ``y_train``, ``MSE`` (pandas Series, one entry per target column),
    ``XTX_inv``. Attributes set by ``predict``: ``X_test`` (with a
    ``'const_one'`` column) and ``pred``.

    Only a single target column is supported by ``predict``.
    """

    def __init__(self, normalize=False, n_jobs=1, t_value=2.13144955, confidence=None):
        if confidence is not None and not 0.0 < confidence < 1.0:
            raise ValueError(
                "confidence must lie strictly between 0 and 1, got {!r}".format(confidence)
            )
        self.normalize = normalize
        self.n_jobs = n_jobs
        self.t_value = t_value
        self.confidence = confidence

        # Only forward `normalize` when it is truthy: newer scikit-learn
        # releases no longer accept the keyword, and omitting it is the same
        # as normalize=False on the releases that do.
        lr_kwargs = {'n_jobs': self.n_jobs}
        if self.normalize:
            lr_kwargs['normalize'] = self.normalize
        self.LR = linear_model.LinearRegression(**lr_kwargs)

    def fit(self, X_train, y_train):
        """Fit the regression and cache the quantities the intervals need.

        Parameters
        ----------
        X_train : array-like or DataFrame, shape (n_samples, n_features)
        y_train : array-like, Series or DataFrame with a single target

        Raises
        ------
        ValueError
            If ``n_samples - n_features - 1`` is not positive, because the
            residual variance cannot be estimated.
        """
        # np.asarray accepts DataFrames, Series, ndarrays and lists alike and,
        # like `.values`, drops the original index and column labels.
        self.X_train = pd.DataFrame(np.asarray(X_train))
        self.y_train = pd.DataFrame(np.asarray(y_train))

        n_samples, n_features = self.X_train.shape
        dof = n_samples - n_features - 1  # the extra 1 accounts for the intercept
        if dof <= 0:
            raise ValueError(
                "need more than n_features + 1 = {} training rows to estimate "
                "the residual variance, got {}".format(n_features + 1, n_samples)
            )

        self.LR.fit(self.X_train, self.y_train)
        fitted = self.LR.predict(self.X_train)
        squared_residuals = np.power(self.y_train.subtract(fitted), 2)
        self.MSE = squared_residuals.sum(axis=0) / dof

        if self.confidence is not None:
            # Two-sided interval: put (1 - confidence) / 2 in each tail.
            self.t_value = stats.t.ppf(0.5 + self.confidence / 2.0, dof)

        # Append the intercept column AFTER fitting so the regressor never
        # sees it; predict() appends it in the same (last) position.
        self.X_train['const_one'] = 1
        design = self.X_train.values
        self.XTX_inv = np.linalg.inv(np.dot(design.T, design))

    def predict(self, X_test):
        """Predict and attach lower/upper prediction-interval bounds.

        Parameters
        ----------
        X_test : array-like or DataFrame, shape (n_rows, n_features)

        Returns
        -------
        pandas.DataFrame
            Columns ``'Pred'``, ``'lower'``, ``'upper'`` with a fresh
            0..n_rows-1 index (the index of ``X_test`` is not kept).
        """
        self.X_test = pd.DataFrame(np.asarray(X_test))
        self.pred = self.LR.predict(self.X_test)

        self.X_test['const_one'] = 1
        design = self.X_test.values
        # Row-wise quadratic form x0' (X'X)^-1 x0 without a Python loop.
        leverage = np.sum(np.dot(design, self.XTX_inv) * design, axis=1)

        results = pd.DataFrame(self.pred, columns=['Pred'])
        half_width = self.t_value * np.sqrt(self.MSE.values + leverage * self.MSE.values)
        results['lower'] = results['Pred'] - half_width
        results['upper'] = results['Pred'] + half_width

        return results

In [3]:
# Load the diabetes dataset once and build both frames from the same object,
# so features and target are guaranteed to come from a single load.
diabetes = datasets.load_diabetes()
data = pd.DataFrame(diabetes.data)
target = pd.DataFrame(diabetes.target)

In [4]:
# Hold out the last N_TEST rows for evaluation; everything before them is
# used for training. A single explicit split position keeps the four slices
# in sync and stays correct for N_TEST = 0 (where `[:-0]` would be empty).
N_TEST = 30
split_at = len(data) - N_TEST

X_train = data.iloc[:split_at]
y_train = target.iloc[:split_at]

X_test = data.iloc[split_at:]
y_test = target.iloc[split_at:]

In [5]:
# Build the interval-aware regressor with its default settings and fit it on
# the training split.
# NOTE(review): LRPI's default t_value (2.13144955) is a fixed constant, not
# derived from the size of this training set -- confirm it gives the interval
# coverage you intend.
model = LRPI()
model.fit(X_train=X_train, y_train=y_train)

In [6]:
# Score the held-out rows: one point prediction plus lower/upper bounds each.
results = model.predict(X_test=X_test)
# Show only the first ten rows rather than the whole frame.
results.head(n=10)

Unnamed: 0,Pred,lower,upper
0,234.040695,115.313069,352.76832
1,122.922466,4.996653,240.848279
2,166.354331,47.642949,285.065713
3,174.392229,56.049506,292.734952
4,226.948514,107.843118,346.05391
5,151.821902,33.002894,270.64091
6,100.943585,-17.585884,219.473054
7,83.089561,-35.107119,201.286241
8,143.129867,24.313062,261.946671
9,192.678237,74.918278,310.438195
