In [3]:
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

In [2]:
class SVM:
    def __init__(self,lr=0.001,alpha=0.01,n_iters=1000):
        self.lr = lr
        self.alpha = alpha
        self.n_iters = n_iters
        self.cls_map = None
        self.w = None
        self.b = None
    def _init_weights_bias(self,X):
        n_features = X.shape[1]
        self.w = np.zeros(n_features)
        self.b = 0
    def _get_cls_map(self,y):
        return np.where(y<=0,-1,1)

    def satisfy_constraint(self,x,idx):
        linear_model = np.dot(x,self.w)+self.b
        return self.cls_map[idx]*linear_model >= 1

    def _get_gradient_descent(self,constrain,x,idx):
        if(constrain):
            dw = self.alpha*self.w
            db = 0
        else:
            dw = self.alpha*self.w - np.dot(self.cls_map[idx],x)
            db = - self.cls_map[idx]
        return dw,db
    
    def update_weights_bias(self,dw,db):
        self.w -= self.lr*dw
        self.b -= self.lr*db

    def fit(self,X,y):
        self._init_weights_bias(X)
        self.cls_map = self._get_cls_map()
        for i in range(self.n_iters):
            for idx,x in enumerate(X):
                constrain = self.satisfy_constraint(x,idx)
                dw,db = self._get_gradient_descent(constrain,x,idx)
                self.update_weights_bias(dw,db)
    def predict(self,X):
        estimate = np.dot(X,self.w)+self.b
        prediction = np.sign(estimate)

        return np.where(prediction==-1,0,1)

In [6]:
X,y = make_blobs(n_samples=250,n_features=2,centers=2,cluster_std=1.05,random_state=0)

# print(X.shape,y.shape)
# print(X)

(250, 2) (250,)
[[-0.81531362  6.35210149]
 [ 0.51935885  6.24551424]
 [ 2.53555391 -0.11517895]
 [ 1.10106611  0.45213016]
 [ 0.21158241  1.37114484]
 [ 0.02217373  1.09588119]
 [ 0.94790763  1.75892389]
 [ 1.28942016  5.69649252]
 [ 2.13091265  4.99181424]
 [ 1.51187253 -0.13010769]
 [ 0.65180646  2.54398333]
 [ 1.87594222  3.62021045]
 [ 0.43992468  3.06412353]
 [ 2.58568825  5.84661404]
 [ 2.65092231  0.6638548 ]
 [ 2.75496976  0.41390788]
 [-1.70436923  4.99008685]
 [ 0.04409504  2.22395104]
 [ 1.44232647  4.65414537]
 [ 4.10897545  1.30726165]
 [ 0.8069653   0.36550649]
 [ 1.86096117  0.04963275]
 [ 1.15127725  4.97057034]
 [ 0.07198311  6.30935553]
 [ 1.68217957  4.73162226]
 [ 1.39577558  0.39258519]
 [ 0.55293428  5.58735465]
 [ 0.06193307  3.69599518]
 [ 2.66271509  1.26480084]
 [ 1.74883829  0.09809684]
 [ 2.99535921  5.85832786]
 [ 0.95775149  2.16936621]
 [ 3.87462477  1.61638982]
 [ 1.96504022  1.10167124]
 [ 0.19972893  5.92395265]
 [ 1.04611316  4.62138282]
 [ 0.9872951