In [6]:
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import utils
import numpy as np

class DataLoader(object):
    # this class has a standard iterator declared
    # __len__ returns the number of batches (size of the object)
    # __get_item__ handles integer based indexing of the object 
    def __init__(self, data_file, batch_size):
        with open(data_file, 'r') as df:
            data = df.readlines()
        data = data[1:]
        data = data[:(len(data)//batch_size)*batch_size]
        np.random.shuffle(data)
        data = np.array([[float(col) for col in row.split(',')] for row in data])
        input_data, targets = data[:, 1:-1], data[:, -1]
        input_data = np.hstack([input_data, np.ones((len(input_data), 1), dtype=np.float32)])

        self.num_features = input_data.shape[1]
        self.current_batch_index = 0
        self.input_batches = np.split(input_data, len(input_data)//batch_size)
        self.target_batches = np.split(targets, len(targets)//batch_size)

    def __len__(self):
        return len(self.input_batches)

    def __getitem__(self,i):

        batch_input_data = self.input_batches[i]
        batch_targets = self.target_batches[i]
        return batch_input_data, batch_targets

def classify(inputs, weights):
    #this functions returns w^Tx . The output  is batch_size*1
	return np.dot(inputs, np.reshape(weights, (np.size(weights), 1)).reshape((-1,)))

def get_objective_function(trainx,trainy,loss_type, regularizer_type, loss_weight):
    # this function calculates the loss for a current batch
    loss_function = utils.loss_functions[loss_type]
    if regularizer_type != None:

        regularizer_function = utils.regularizer_functions[regularizer_type]
    def objective_function(weights):
        loss = 0
        
        inputs, targets = trainx,trainy
        outputs = classify(inputs, weights)
        loss += loss_weight*loss_function(targets, outputs)
        if regularizer_type != None:
            # regulariser function is called from utils.py
            loss += regularizer_function(weights)
        return loss
    return objective_function

def get_gradient_function(trainx,trainy,loss_type, regularizer_type, loss_weight):
    # This is a way to declare function inside a function 
    # The get_gradient_function receives the train data from the current batch
    # and all other parameters on which the loss function and gradient depend
    # like C,regulariser_type and loss function
    loss_grad_function = utils.loss_grad_functions[loss_type]
    if regularizer_type != None:
        regularizer_grad_function = utils.regularizer_grad_functions[regularizer_type]
    # gradient function is called from scipy.optimise.minimise()
    # the only paramter its can send is weights 
    # hence there was a need to pass the current batch through get_objective_function


    def gradient_function(weights):

        gradient = np.zeros(len(weights), dtype=np.float32)
        X=trainx
        Y=trainy
        outputs = classify(X,weights)
        # loss_grad_function is called from utils.py
        gradient = loss_weight*loss_grad_function(weights,X,Y,outputs)/len(trainx)
        if regularizer_type != None:
            # regulariser grad function is called from utils.py
            gradient += regularizer_grad_function(weights)
        return gradient
    
    return gradient_function

def train(data_loader, loss_type, regularizer_type, loss_weight, num_epochs):
    initial_model_parameters = np.zeros(data_loader.num_features)

    num_epochs=num_epochs
    for i in range(num_epochs):
        loss=0
        if(i==0):
            start_parameters=initial_model_parameters
        for j in range(len(data_loader)):
            trainx,trainy=data_loader[j]
            objective_function = get_objective_function(trainx,trainy,loss_type, 
                                                regularizer_type,loss_weight)
            gradient_function = get_gradient_function(trainx,trainy, loss_type, 
                                              regularizer_type, loss_weight)
            # to know about this function please read about scipy.optimise.minimise
            trained_model_parameters = minimize(objective_function, 
                                        start_parameters, 
                                        method="L-BFGS-B", 
                                        jac=gradient_function,
                                        options={'disp': False,
                                                 'maxiter': 1})
            loss+=objective_function(trained_model_parameters.x)
            
            start_parameters=trained_model_parameters.x
        # prints the batch loss
        print("loss is  ",loss)
        
    print("Optimizer information:")
    print(trained_model_parameters)
    return trained_model_parameters.x
            

def test(inputs, weights):
    outputs = classify(inputs, weights)
    probs = 1/(1+np.exp(-outputs))
    # this is done to get all terms in 0 or 1 You can change for -1 and 1
    return np.round(probs).astype(int)

def write_csv_file(outputs, output_file):
    # dumps the output file
    with open(output_file, "w") as out_file:
        out_file.write("ID,Output\n")
        for i in range(len(outputs)):
            out_file.write("{}, {}".format(i+1, str(outputs[i])) + "\n")
def get_data(data_file):
    with open(data_file, 'r') as df:
        data = df.readlines()

    data = data[1:]
    data = np.array([[float(col) for col in row.split(',')] for row in data])
    input_data = np.hstack([data, np.ones((len(data), 1), dtype=np.float32)])
    return input_data


In [2]:
train_data_loader = DataLoader("train.csv",32)
test_data = get_data("train.csv")
print("Got files")

Got files


In [9]:
tr_accuracies=[]
c_list=[1]
num_epochs=[1,50,100,150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000]
for n in num_epochs:
    print("Started training for n=",n)
    trained_model_parameters = train(train_data_loader, "square_hinge_loss", None, 1, n)
    print("Predicting outputs")
    train_data_output = test(test_data[:,1:-1], trained_model_parameters)
    tr_accuracies.append(np.sum(train_data_output==test_data[:,-1])/test_data.shape[0])

Started training for n= 1
loss is   363.85595022984006
Optimizer information:
      fun: 4.922320204805989
 hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.07801152, -0.94687435,  0.97796282,  0.871122  , -0.21459737])
  message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 5
      nit: 1
   status: 1
  success: False
        x: array([-1.24262936, -0.83232319, -0.56045894, -0.14083791,  0.76047594])
Predicting outputs
Started training for n= 50
loss is   363.85595022984006
loss is   63.754211619405076
loss is   100.41359964463909
loss is   97.38657239567884
loss is   77.10339457982643
loss is   87.65059391964087
loss is   131.1207562648272
loss is   62.15487507648921
loss is   38.917965884897555
loss is   56.35349807167305
loss is   50.69217143713903
loss is   118.24032601512583
loss is   81.14311378321379
loss is   98.6832051798847
loss is   51.34548720420334
loss is   78.62099081807474
loss is   65.25458763485008
loss is   54.52556321667974
l

loss is   156.48902700109298
loss is   137.11584050116844
loss is   122.32901074219602
loss is   143.5649696578118
loss is   101.6751026793432
loss is   60.27089852247859
loss is   70.14026420468358
loss is   64.86049665677758
loss is   66.37697832322152
loss is   105.74708523969092
loss is   97.79953017519026
loss is   88.84809029344292
loss is   55.787805810818405
loss is   76.93389071584812
loss is   106.58571655285364
loss is   88.5394408532068
loss is   75.63089681297078
loss is   78.9078063418422
loss is   67.8199764960898
loss is   80.09004172354095
loss is   118.13870812870016
loss is   104.82391769757587
loss is   91.61257034987527
loss is   75.55213353133993
loss is   94.52913780255663
loss is   89.82627332482059
loss is   89.87931557133608
loss is   63.89202247047802
loss is   88.76841588974504
loss is   89.19293799591262
loss is   62.932735242348954
loss is   63.35153530229101
loss is   80.94355151223192
loss is   75.54843641427973
loss is   84.57676139440882
loss is   59.2

loss is   55.63999545292342
loss is   59.53770877584124
loss is   90.13803930806588
loss is   137.0401579309213
loss is   75.6976876315583
loss is   70.23371782495164
loss is   82.73305632153233
loss is   63.493347561865434
loss is   64.9551222701315
loss is   96.47337120179917
loss is   48.53078146581098
loss is   103.50189768470688
loss is   68.45513965287617
loss is   66.30496880200425
loss is   67.85090601053388
loss is   70.13154304571229
loss is   113.09050457583338
loss is   74.08637487739453
loss is   96.13851926383958
loss is   72.78884387096657
loss is   81.18203411374975
loss is   71.23946519480309
loss is   99.09305267563268
loss is   74.03130317349284
loss is   81.32454882588686
loss is   65.92840828210704
loss is   73.77102244977253
loss is   106.7429885123888
loss is   117.87585346103043
loss is   74.31066151443092
loss is   77.15235333845268
loss is   56.36735311379975
loss is   64.25323265561647
loss is   65.2630609993967
loss is   130.35625067561725
loss is   102.1064

loss is   96.76043914846386
loss is   62.25617846185352
loss is   85.94839182738752
loss is   85.83905635537768
loss is   67.71874372147097
loss is   99.08043144547406
loss is   124.19670643036922
loss is   88.25629228477762
loss is   117.93088875321844
loss is   119.83901018963594
loss is   142.0722924570478
loss is   86.29575470474232
loss is   103.72017820958307
loss is   86.27513477571213
loss is   73.67384016480518
loss is   133.06285290042644
loss is   127.2612790252002
loss is   61.22854180403937
loss is   115.65769220001557
loss is   74.80544302039607
loss is   93.14712289287141
loss is   103.74530287752214
loss is   96.85542721804798
loss is   74.66631227232197
loss is   84.80054609602519
loss is   89.83016466966133
loss is   108.47734183497745
loss is   142.86933262319977
loss is   180.62076522982574
loss is   106.97747779716994
loss is   110.27148862692799
loss is   101.24452237207996
loss is   106.84786397821705
loss is   142.06838488667674
loss is   155.26352970845258
loss

loss is   102.1064663897813
loss is   67.47095475846471
loss is   91.28153703782877
loss is   96.76043914846386
loss is   62.25617846185352
loss is   85.94839182738752
loss is   85.83905635537768
loss is   67.71874372147097
loss is   99.08043144547406
loss is   124.19670643036922
loss is   88.25629228477762
loss is   117.93088875321844
loss is   119.83901018963594
loss is   142.0722924570478
loss is   86.29575470474232
loss is   103.72017820958307
loss is   86.27513477571213
loss is   73.67384016480518
loss is   133.06285290042644
loss is   127.2612790252002
loss is   61.22854180403937
loss is   115.65769220001557
loss is   74.80544302039607
loss is   93.14712289287141
loss is   103.74530287752214
loss is   96.85542721804798
loss is   74.66631227232197
loss is   84.80054609602519
loss is   89.83016466966133
loss is   108.47734183497745
loss is   142.86933262319977
loss is   180.62076522982574
loss is   106.97747779716994
loss is   110.27148862692799
loss is   101.24452237207996
loss is

loss is   77.10339457982643
loss is   87.65059391964087
loss is   131.1207562648272
loss is   62.15487507648921
loss is   38.917965884897555
loss is   56.35349807167305
loss is   50.69217143713903
loss is   118.24032601512583
loss is   81.14311378321379
loss is   98.6832051798847
loss is   51.34548720420334
loss is   78.62099081807474
loss is   65.25458763485008
loss is   54.52556321667974
loss is   55.63999545292342
loss is   59.53770877584124
loss is   90.13803930806588
loss is   137.0401579309213
loss is   75.6976876315583
loss is   70.23371782495164
loss is   82.73305632153233
loss is   63.493347561865434
loss is   64.9551222701315
loss is   96.47337120179917
loss is   48.53078146581098
loss is   103.50189768470688
loss is   68.45513965287617
loss is   66.30496880200425
loss is   67.85090601053388
loss is   70.13154304571229
loss is   113.09050457583338
loss is   74.08637487739453
loss is   96.13851926383958
loss is   72.78884387096657
loss is   81.18203411374975
loss is   71.23946

loss is   99.42091330388907
loss is   102.07303422744391
loss is   75.8758892652027
loss is   86.57812180004346
loss is   153.76687732393503
loss is   120.05303822206355
loss is   93.83151615922296
loss is   95.13927586661377
loss is   112.1290016007156
loss is   83.54413654817154
loss is   72.51617957167024
loss is   88.75030366065904
loss is   71.95381644483857
loss is   68.76932212224135
loss is   55.653631879420615
loss is   69.08256150877037
loss is   96.78123342564393
loss is   129.1012888941969
loss is   84.49705727257202
loss is   86.22754458662753
loss is   53.61273167129941
loss is   79.08906627151455
loss is   113.56559277709063
loss is   44.59211844516391
loss is   82.02527722199773
loss is   86.50234630745572
loss is   118.87299294495317
loss is   161.25840187559285
loss is   131.04160350772483
loss is   88.12095270930544
loss is   71.1029831644623
loss is   59.3233842374985
loss is   73.7972970851325
loss is   56.900211234040405
loss is   71.2474031972452
loss is   53.205

loss is   91.32380334056141
loss is   53.98830008555951
loss is   67.8241560361108
loss is   98.7810709700822
loss is   72.52743641828542
loss is   119.98960055848106
loss is   89.93552161969893
loss is   64.53387917442544
loss is   65.03326338204306
loss is   62.11617037306624
loss is   60.10303706652493
loss is   90.24224666265737
loss is   69.08534862370723
loss is   61.94595242240097
loss is   66.79196226186261
loss is   73.14415358779614
loss is   83.40725698265322
loss is   44.79741482513415
loss is   55.29489916523243
loss is   83.10684023551092
loss is   111.40355810146951
loss is   104.44841482547751
loss is   68.47489322400041
loss is   78.00890085750349
loss is   60.10424528592636
loss is   75.97782367109915
loss is   62.289020382063846
loss is   82.07560068785418
loss is   84.50433635012119
loss is   64.58523938911254
loss is   91.1548246379415
loss is   84.24230581555152
loss is   76.48221250567381
loss is   92.92590851724671
loss is   68.40802503408946
loss is   74.474807

loss is   59.53770877584124
loss is   90.13803930806588
loss is   137.0401579309213
loss is   75.6976876315583
loss is   70.23371782495164
loss is   82.73305632153233
loss is   63.493347561865434
loss is   64.9551222701315
loss is   96.47337120179917
loss is   48.53078146581098
loss is   103.50189768470688
loss is   68.45513965287617
loss is   66.30496880200425
loss is   67.85090601053388
loss is   70.13154304571229
loss is   113.09050457583338
loss is   74.08637487739453
loss is   96.13851926383958
loss is   72.78884387096657
loss is   81.18203411374975
loss is   71.23946519480309
loss is   99.09305267563268
loss is   74.03130317349284
loss is   81.32454882588686
loss is   65.92840828210704
loss is   73.77102244977253
loss is   106.7429885123888
loss is   117.87585346103043
loss is   74.31066151443092
loss is   77.15235333845268
loss is   56.36735311379975
loss is   64.25323265561647
loss is   65.2630609993967
loss is   130.35625067561725
loss is   102.1064663897813
loss is   67.47095

loss is   96.92949424440836
loss is   52.35824872929077
loss is   63.53642612510304
loss is   71.9799624746093
loss is   78.45717383946342
loss is   79.03255238095727
loss is   78.77581138348752
loss is   66.03777169115678
loss is   71.62157268376136
loss is   108.96981324663125
loss is   86.29629341716779
loss is   96.43635965912874
loss is   108.27286643511272
loss is   89.04194887258333
loss is   76.99116203201321
loss is   112.72763736068373
loss is   90.60185710533867
loss is   57.88503045914604
loss is   92.31366186466529
loss is   74.46011863938844
loss is   75.56815003634604
loss is   94.47314070152392
loss is   139.3930919176322
loss is   65.62375801485275
loss is   39.06106232761871
loss is   92.66919162859193
loss is   76.93102536013545
loss is   53.228773936745114
loss is   72.09459972380904
loss is   40.496113467927586
loss is   55.32862106199084
loss is   79.24205931847644
loss is   87.04397887270089
loss is   80.74466522135789
loss is   52.46412394024168
loss is   91.025

loss is   95.33484212874795
loss is   115.84901888186866
loss is   132.68260888006012
loss is   68.82739527945213
loss is   103.8746459494812
loss is   71.05528427322527
loss is   68.97792331459087
loss is   98.87984030456921
loss is   84.37125422374628
loss is   119.76061336443016
loss is   92.78619911582811
loss is   66.97047200894902
loss is   83.99721456036077
loss is   98.97048606233814
loss is   123.12822124365272
loss is   115.3967248706597
loss is   90.07807037259907
loss is   81.06683731550665
loss is   82.49432557904487
loss is   76.9446628691747
loss is   96.31519464355527
loss is   90.78502907415395
loss is   106.96817707469992
loss is   96.97055108564227
loss is   82.93845886644732
loss is   127.36265055992229
loss is   122.78187219736739
loss is   82.77982588736991
loss is   61.13626202018049
loss is   73.37798978192052
loss is   74.6967075802776
loss is   109.80603148858533
loss is   71.78882153769653
loss is   46.347012387302016
loss is   57.52759872035844
loss is   68.

loss is   44.20346754781722
loss is   78.3969713354429
loss is   73.98200458524254
loss is   70.53317528641882
loss is   79.23556264449762
loss is   79.8970862661672
loss is   84.54483019241073
loss is   119.80173761825985
loss is   93.54540039198693
loss is   162.06379539436313
loss is   70.8636890048948
loss is   72.69713496264204
loss is   99.74894654343616
loss is   125.61911390001778
loss is   104.99837254893691
loss is   80.95568563503714
loss is   87.46753576608178
loss is   93.01969943421801
loss is   111.29486929134077
loss is   76.38165719823272
loss is   60.47179378937948
loss is   53.8371477415416
loss is   91.82192977094792
loss is   72.41751330792522
loss is   82.79342992355488
loss is   59.85030813364375
loss is   78.549869969914
loss is   85.29569931972682
loss is   49.076480732548276
loss is   53.23680713279095
loss is   76.87483264233283
loss is   73.16049093927238
loss is   88.98674706750782
loss is   75.64183760899161
loss is   78.3072761243424
loss is   72.37795131

loss is   91.78878799944768
loss is   64.55441790860107
loss is   89.25158666466325
loss is   106.18052279865583
loss is   81.66717832836194
loss is   88.65511962569856
loss is   67.65911236176984
loss is   62.317461408714465
loss is   112.3508446422156
loss is   134.5345266029598
loss is   78.52513419994777
loss is   76.57004070806515
loss is   78.62579226044171
loss is   107.10424908952842
loss is   128.38522753220397
loss is   93.12173240987721
loss is   86.27945992469083
loss is   73.53338657011177
loss is   112.86290841579647
loss is   149.6187147479658
loss is   86.60668487886697
loss is   87.77790779312298
loss is   120.38506388107078
loss is   119.09276915063745
loss is   111.56369458477094
loss is   86.55628389530682
loss is   47.09256449657515
loss is   69.2981726204806
loss is   59.12370096826645
loss is   48.202019355057
loss is   82.77418319550644
loss is   89.61576696585084
loss is   66.9567990799535
loss is   58.99904839537452
loss is   64.33735872353033
loss is   74.623

loss is   68.92413938555198
loss is   81.60348623966821
loss is   87.27750263278739
loss is   74.36460488069262
loss is   44.61419582380049
loss is   101.2800444887482
loss is   104.44347876458848
loss is   76.95011688533856
loss is   87.23731575569583
loss is   108.12599689665957
loss is   109.2437676014014
loss is   77.35920563009076
loss is   92.97927889970708
loss is   125.62489301456105
loss is   102.8485471111543
loss is   76.4737941907949
loss is   82.33157749741915
loss is   80.28366494287644
loss is   105.50507775137831
loss is   101.94818431393772
loss is   78.67975147417528
loss is   64.23509354697066
loss is   104.64608702420554
loss is   105.39632359608709
loss is   86.50567032507065
loss is   94.37451162045527
loss is   117.3630985301884
loss is   120.60161610934104
loss is   83.65443513576568
loss is   56.53131331505921
loss is   85.03452984355143
loss is   63.213574673435375
loss is   70.74549567805104
loss is   75.65622982115914
loss is   86.2751287138777
loss is   92.

loss is   116.63787772503824
loss is   73.89488857281108
loss is   77.9140813745515
loss is   93.03440271868811
loss is   94.80502681671925
loss is   88.66168995207967
loss is   60.579362280032434
loss is   91.48914321236822
loss is   68.62113419001795
loss is   66.09359659083714
loss is   64.52849317535322
loss is   71.2667104074039
loss is   85.19476088149476
loss is   99.42091330388907
loss is   102.07303422744391
loss is   75.8758892652027
loss is   86.57812180004346
loss is   153.76687732393503
loss is   120.05303822206355
loss is   93.83151615922296
loss is   95.13927586661377
loss is   112.1290016007156
loss is   83.54413654817154
loss is   72.51617957167024
loss is   88.75030366065904
loss is   71.95381644483857
loss is   68.76932212224135
loss is   55.653631879420615
loss is   69.08256150877037
loss is   96.78123342564393
loss is   129.1012888941969
loss is   84.49705727257202
loss is   86.22754458662753
loss is   53.61273167129941
loss is   79.08906627151455
loss is   113.565

loss is   70.74549567805104
loss is   75.65622982115914
loss is   86.2751287138777
loss is   92.33622122273694
loss is   99.74603629441769
loss is   87.35344878936593
loss is   69.37919815516581
loss is   91.94236541588963
loss is   102.94950971764223
loss is   123.91503601501161
loss is   87.84212077627907
loss is   62.31141228352862
loss is   74.84077447840617
loss is   80.67778511548619
loss is   92.11853075639604
loss is   89.74195984169555
loss is   80.46641229189302
loss is   82.99735528260015
loss is   88.82514277948435
loss is   84.94958730933438
loss is   78.36715377883222
loss is   86.50533745128565
loss is   88.25725238284522
loss is   57.28993793276939
loss is   64.69805758480801
loss is   83.2604580854806
loss is   89.95947451602996
loss is   98.78326461942956
loss is   72.65549929478259
loss is   47.22504487449856
loss is   88.5943983489865
loss is   75.69978131209056
loss is   99.70913355266427
loss is   102.10829462913199
loss is   79.92223285066368
loss is   87.9552088

loss is   107.10424908952842
loss is   128.38522753220397
loss is   93.12173240987721
loss is   86.27945992469083
loss is   73.53338657011177
loss is   112.86290841579647
loss is   149.6187147479658
loss is   86.60668487886697
loss is   87.77790779312298
loss is   120.38506388107078
loss is   119.09276915063745
loss is   111.56369458477094
loss is   86.55628389530682
loss is   47.09256449657515
loss is   69.2981726204806
loss is   59.12370096826645
loss is   48.202019355057
loss is   82.77418319550644
loss is   89.61576696585084
loss is   66.9567990799535
loss is   58.99904839537452
loss is   64.33735872353033
loss is   74.62302547151873
loss is   103.22668311392782
loss is   81.48268923952251
loss is   53.11172719520454
loss is   74.03651468229903
loss is   58.61609881197525
loss is   80.39028126367829
loss is   72.77905917915439
loss is   68.70127137618802
loss is   54.29881508563571
loss is   70.21159069007945
loss is   88.92333762524214
loss is   97.48379111262331
loss is   83.5239

loss is   78.67172633907211
loss is   117.87554624027634
loss is   65.35062165287543
loss is   85.81456919432046
loss is   90.59210761409611
loss is   56.73087659981771
loss is   80.06230108901462
loss is   85.1905890095747
loss is   77.73120153814853
loss is   57.78094824160982
loss is   68.92413938555198
loss is   81.60348623966821
loss is   87.27750263278739
loss is   74.36460488069262
loss is   44.61419582380049
loss is   101.2800444887482
loss is   104.44347876458848
loss is   76.95011688533856
loss is   87.23731575569583
loss is   108.12599689665957
loss is   109.2437676014014
loss is   77.35920563009076
loss is   92.97927889970708
loss is   125.62489301456105
loss is   102.8485471111543
loss is   76.4737941907949
loss is   82.33157749741915
loss is   80.28366494287644
loss is   105.50507775137831
loss is   101.94818431393772
loss is   78.67975147417528
loss is   64.23509354697066
loss is   104.64608702420554
loss is   105.39632359608709
loss is   86.50567032507065
loss is   94.3

loss is   92.78619911582811
loss is   66.97047200894902
loss is   83.99721456036077
loss is   98.97048606233814
loss is   123.12822124365272
loss is   115.3967248706597
loss is   90.07807037259907
loss is   81.06683731550665
loss is   82.49432557904487
loss is   76.9446628691747
loss is   96.31519464355527
loss is   90.78502907415395
loss is   106.96817707469992
loss is   96.97055108564227
loss is   82.93845886644732
loss is   127.36265055992229
loss is   122.78187219736739
loss is   82.77982588736991
loss is   61.13626202018049
loss is   73.37798978192052
loss is   74.6967075802776
loss is   109.80603148858533
loss is   71.78882153769653
loss is   46.347012387302016
loss is   57.52759872035844
loss is   68.85644323227682
loss is   71.0834655795753
loss is   63.22561470106108
loss is   67.17769285260681
loss is   92.1350607263928
loss is   91.32380334056141
loss is   53.98830008555951
loss is   67.8241560361108
loss is   98.7810709700822
loss is   72.52743641828542
loss is   119.989600

loss is   44.08824829553334
loss is   72.93772328758654
loss is   87.31866135313878
loss is   90.72901290553118
loss is   44.20346754781722
loss is   78.3969713354429
loss is   73.98200458524254
loss is   70.53317528641882
loss is   79.23556264449762
loss is   79.8970862661672
loss is   84.54483019241073
loss is   119.80173761825985
loss is   93.54540039198693
loss is   162.06379539436313
loss is   70.8636890048948
loss is   72.69713496264204
loss is   99.74894654343616
loss is   125.61911390001778
loss is   104.99837254893691
loss is   80.95568563503714
loss is   87.46753576608178
loss is   93.01969943421801
loss is   111.29486929134077
loss is   76.38165719823272
loss is   60.47179378937948
loss is   53.8371477415416
loss is   91.82192977094792
loss is   72.41751330792522
loss is   82.79342992355488
loss is   59.85030813364375
loss is   78.549869969914
loss is   85.29569931972682
loss is   49.076480732548276
loss is   53.23680713279095
loss is   76.87483264233283
loss is   73.1604909

Started training for n= 800
loss is   363.85595022984006
loss is   63.754211619405076
loss is   100.41359964463909
loss is   97.38657239567884
loss is   77.10339457982643
loss is   87.65059391964087
loss is   131.1207562648272
loss is   62.15487507648921
loss is   38.917965884897555
loss is   56.35349807167305
loss is   50.69217143713903
loss is   118.24032601512583
loss is   81.14311378321379
loss is   98.6832051798847
loss is   51.34548720420334
loss is   78.62099081807474
loss is   65.25458763485008
loss is   54.52556321667974
loss is   55.63999545292342
loss is   59.53770877584124
loss is   90.13803930806588
loss is   137.0401579309213
loss is   75.6976876315583
loss is   70.23371782495164
loss is   82.73305632153233
loss is   63.493347561865434
loss is   64.9551222701315
loss is   96.47337120179917
loss is   48.53078146581098
loss is   103.50189768470688
loss is   68.45513965287617
loss is   66.30496880200425
loss is   67.85090601053388
loss is   70.13154304571229
loss is   113.09

loss is   53.61273167129941
loss is   79.08906627151455
loss is   113.56559277709063
loss is   44.59211844516391
loss is   82.02527722199773
loss is   86.50234630745572
loss is   118.87299294495317
loss is   161.25840187559285
loss is   131.04160350772483
loss is   88.12095270930544
loss is   71.1029831644623
loss is   59.3233842374985
loss is   73.7972970851325
loss is   56.900211234040405
loss is   71.2474031972452
loss is   53.205847491336996
loss is   74.18111336766015
loss is   50.8668857558254
loss is   94.86348757809527
loss is   86.66567210161868
loss is   91.88286151634189
loss is   56.510518912064235
loss is   61.095329467938015
loss is   82.1171157683506
loss is   60.51409961154164
loss is   84.576692748028
loss is   96.92949424440836
loss is   52.35824872929077
loss is   63.53642612510304
loss is   71.9799624746093
loss is   78.45717383946342
loss is   79.03255238095727
loss is   78.77581138348752
loss is   66.03777169115678
loss is   71.62157268376136
loss is   108.9698132

loss is   72.65549929478259
loss is   47.22504487449856
loss is   88.5943983489865
loss is   75.69978131209056
loss is   99.70913355266427
loss is   102.10829462913199
loss is   79.92223285066368
loss is   87.95520881291644
loss is   116.10093342395993
loss is   84.92621457773998
loss is   78.9371020288476
loss is   77.06087167990896
loss is   79.20981556833318
loss is   89.15522445423177
loss is   106.65637921935297
loss is   99.37879726859663
loss is   78.68091027486827
loss is   88.96046567071683
loss is   87.25500616361819
loss is   84.66459984538822
loss is   37.94730423396265
loss is   91.32036317564787
loss is   86.95930129531413
loss is   64.88014335772728
loss is   85.54975349487488
loss is   96.89877697812378
loss is   62.260532130372965
loss is   91.28767635785158
loss is   102.17791967810659
loss is   108.6445940191104
loss is   101.94443519873569
loss is   67.93114619771274
loss is   96.06927677329436
loss is   89.70555015626317
loss is   73.83916753604957
loss is   64.822

loss is   90.07807037259907
loss is   81.06683731550665
loss is   82.49432557904487
loss is   76.9446628691747
loss is   96.31519464355527
loss is   90.78502907415395
loss is   106.96817707469992
loss is   96.97055108564227
loss is   82.93845886644732
loss is   127.36265055992229
loss is   122.78187219736739
loss is   82.77982588736991
loss is   61.13626202018049
loss is   73.37798978192052
loss is   74.6967075802776
loss is   109.80603148858533
loss is   71.78882153769653
loss is   46.347012387302016
loss is   57.52759872035844
loss is   68.85644323227682
loss is   71.0834655795753
loss is   63.22561470106108
loss is   67.17769285260681
loss is   92.1350607263928
loss is   91.32380334056141
loss is   53.98830008555951
loss is   67.8241560361108
loss is   98.7810709700822
loss is   72.52743641828542
loss is   119.98960055848106
loss is   89.93552161969893
loss is   64.53387917442544
loss is   65.03326338204306
loss is   62.11617037306624
loss is   60.10303706652493
loss is   90.2422466

loss is   55.709537847978964
loss is   66.73726187473068
loss is   98.95052831502326
loss is   75.49516085490062
loss is   87.18494041716622
loss is   71.90253998930447
loss is   82.95202412390478
loss is   79.11937986179328
loss is   82.29945879819326
loss is   75.86142640715336
loss is   68.33318108806058
loss is   123.2664750271316
loss is   56.48897595078082
loss is   117.87287429740526
loss is   97.82478429544395
loss is   99.59724963813441
loss is   82.36267882139839
loss is   96.10167195035032
loss is   93.88864684156403
loss is   100.04157566809312
loss is   78.67172633907211
loss is   117.87554624027634
loss is   65.35062165287543
loss is   85.81456919432046
loss is   90.59210761409611
loss is   56.73087659981771
loss is   80.06230108901462
loss is   85.1905890095747
loss is   77.73120153814853
loss is   57.78094824160982
loss is   68.92413938555198
loss is   81.60348623966821
loss is   87.27750263278739
loss is   74.36460488069262
loss is   44.61419582380049
loss is   101.280

loss is   123.63822330407503
loss is   55.23460993207981
loss is   103.34310860971233
loss is   57.84693088391642
loss is   83.7767059351282
loss is   59.097992046427244
loss is   85.44794790000964
loss is   72.25596568096425
loss is   120.15905517763805
Optimizer information:
      fun: 0.0
 hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>
      jac: array([0., 0., 0., 0., 0.])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 1
      nit: 0
   status: 0
  success: True
        x: array([-16.69932073,  -8.33174815, -11.71935761,  -1.57365565,
        14.2190542 ])
Predicting outputs
Started training for n= 900
loss is   363.85595022984006
loss is   63.754211619405076
loss is   100.41359964463909
loss is   97.38657239567884
loss is   77.10339457982643
loss is   87.65059391964087
loss is   131.1207562648272
loss is   62.15487507648921
loss is   38.917965884897555
loss is   56.35349807167305
loss is   50.69217143713903
loss is   118.24032601512583
loss is   

loss is   44.59211844516391
loss is   82.02527722199773
loss is   86.50234630745572
loss is   118.87299294495317
loss is   161.25840187559285
loss is   131.04160350772483
loss is   88.12095270930544
loss is   71.1029831644623
loss is   59.3233842374985
loss is   73.7972970851325
loss is   56.900211234040405
loss is   71.2474031972452
loss is   53.205847491336996
loss is   74.18111336766015
loss is   50.8668857558254
loss is   94.86348757809527
loss is   86.66567210161868
loss is   91.88286151634189
loss is   56.510518912064235
loss is   61.095329467938015
loss is   82.1171157683506
loss is   60.51409961154164
loss is   84.576692748028
loss is   96.92949424440836
loss is   52.35824872929077
loss is   63.53642612510304
loss is   71.9799624746093
loss is   78.45717383946342
loss is   79.03255238095727
loss is   78.77581138348752
loss is   66.03777169115678
loss is   71.62157268376136
loss is   108.96981324663125
loss is   86.29629341716779
loss is   96.43635965912874
loss is   108.2728664

loss is   57.03894072578469
loss is   50.831509851570125
loss is   94.29332021854879
loss is   92.6515179163649
loss is   74.37989211491072
loss is   88.11847372218095
loss is   91.55266801604986
loss is   92.01813572470874
loss is   134.0737643744535
loss is   71.08208051832348
loss is   102.73190341736024
loss is   102.98505569199042
loss is   95.67447614803595
loss is   79.61111665899107
loss is   85.121194442694
loss is   81.78639271691718
loss is   70.12700223977978
loss is   83.10495223482624
loss is   84.32391366566839
loss is   82.94955928002089
loss is   57.933979915315945
loss is   86.1240834078299
loss is   100.07151418371062
loss is   58.68497567230635
loss is   58.117689376681064
loss is   67.75599122145523
loss is   77.5079319598506
loss is   58.079564496358834
loss is   50.678041831803114
loss is   74.83684617672382
loss is   76.10286473359123
loss is   91.04300934520126
loss is   69.72382787968624
loss is   74.78325804339494
loss is   77.9619057718964
loss is   89.11166

loss is   98.87984030456921
loss is   84.37125422374628
loss is   119.76061336443016
loss is   92.78619911582811
loss is   66.97047200894902
loss is   83.99721456036077
loss is   98.97048606233814
loss is   123.12822124365272
loss is   115.3967248706597
loss is   90.07807037259907
loss is   81.06683731550665
loss is   82.49432557904487
loss is   76.9446628691747
loss is   96.31519464355527
loss is   90.78502907415395
loss is   106.96817707469992
loss is   96.97055108564227
loss is   82.93845886644732
loss is   127.36265055992229
loss is   122.78187219736739
loss is   82.77982588736991
loss is   61.13626202018049
loss is   73.37798978192052
loss is   74.6967075802776
loss is   109.80603148858533
loss is   71.78882153769653
loss is   46.347012387302016
loss is   57.52759872035844
loss is   68.85644323227682
loss is   71.0834655795753
loss is   63.22561470106108
loss is   67.17769285260681
loss is   92.1350607263928
loss is   91.32380334056141
loss is   53.98830008555951
loss is   67.8241

loss is   82.38284386557511
loss is   81.30425124990862
loss is   90.61661031255511
loss is   71.09791152293036
loss is   63.541673022250684
loss is   50.362806796223495
loss is   89.3120053946846
loss is   69.25720612028269
loss is   84.06281625871817
loss is   98.42104952121085
loss is   146.07149192960537
loss is   82.83469841856528
loss is   70.06150623633206
loss is   57.97401579374894
loss is   89.34533713286628
loss is   97.81184626279112
loss is   94.12038954695443
loss is   89.16752383392827
loss is   88.0065227276419
loss is   69.57997760152702
loss is   55.709537847978964
loss is   66.73726187473068
loss is   98.95052831502326
loss is   75.49516085490062
loss is   87.18494041716622
loss is   71.90253998930447
loss is   82.95202412390478
loss is   79.11937986179328
loss is   82.29945879819326
loss is   75.86142640715336
loss is   68.33318108806058
loss is   123.2664750271316
loss is   56.48897595078082
loss is   117.87287429740526
loss is   97.82478429544395
loss is   99.5972

loss is   83.7767059351282
loss is   59.097992046427244
loss is   85.44794790000964
loss is   72.25596568096425
loss is   120.15905517763805
loss is   56.69291800723641
loss is   100.819479706224
loss is   55.534039980676745
loss is   82.43976286085349
loss is   58.692003564414215
loss is   84.9578909011217
loss is   70.27725503701946
loss is   78.06946876601657
loss is   109.93065798333188
loss is   66.10132538942965
loss is   66.93871610023335
loss is   86.01120370184438
loss is   73.84724247818095
loss is   70.77390098545942
loss is   89.525311372423
loss is   85.30883933433785
loss is   89.43951160099903
loss is   82.18958595296094
loss is   92.9830475246378
loss is   108.48108576800834
loss is   95.16836237455819
loss is   70.91559921638265
loss is   92.8297902303466
loss is   99.379306648635
loss is   115.4245520970184
loss is   49.56858775369148
loss is   94.41136072385748
loss is   95.42255517061929
loss is   82.92042287413024
loss is   83.08189099724426
loss is   89.2560778908

loss is   102.77696049630312
loss is   134.83245236871812
loss is   91.78878799944768
loss is   64.55441790860107
loss is   89.25158666466325
loss is   106.18052279865583
loss is   81.66717832836194
loss is   88.65511962569856
loss is   67.65911236176984
loss is   62.317461408714465
loss is   112.3508446422156
loss is   134.5345266029598
loss is   78.52513419994777
loss is   76.57004070806515
loss is   78.62579226044171
loss is   107.10424908952842
loss is   128.38522753220397
loss is   93.12173240987721
loss is   86.27945992469083
loss is   73.53338657011177
loss is   112.86290841579647
loss is   149.6187147479658
loss is   86.60668487886697
loss is   87.77790779312298
loss is   120.38506388107078
loss is   119.09276915063745
loss is   111.56369458477094
loss is   86.55628389530682
loss is   47.09256449657515
loss is   69.2981726204806
loss is   59.12370096826645
loss is   48.202019355057
loss is   82.77418319550644
loss is   89.61576696585084
loss is   66.9567990799535
loss is   58.9

loss is   109.2437676014014
loss is   77.35920563009076
loss is   92.97927889970708
loss is   125.62489301456105
loss is   102.8485471111543
loss is   76.4737941907949
loss is   82.33157749741915
loss is   80.28366494287644
loss is   105.50507775137831
loss is   101.94818431393772
loss is   78.67975147417528
loss is   64.23509354697066
loss is   104.64608702420554
loss is   105.39632359608709
loss is   86.50567032507065
loss is   94.37451162045527
loss is   117.3630985301884
loss is   120.60161610934104
loss is   83.65443513576568
loss is   56.53131331505921
loss is   85.03452984355143
loss is   63.213574673435375
loss is   70.74549567805104
loss is   75.65622982115914
loss is   86.2751287138777
loss is   92.33622122273694
loss is   99.74603629441769
loss is   87.35344878936593
loss is   69.37919815516581
loss is   91.94236541588963
loss is   102.94950971764223
loss is   123.91503601501161
loss is   87.84212077627907
loss is   62.31141228352862
loss is   74.84077447840617
loss is   80.

loss is   82.75686466959237
loss is   64.3222022671889
loss is   100.8891271315956
loss is   66.1018129897457
loss is   78.50743953790733
loss is   54.59865932194564
loss is   59.52008416010823
loss is   92.64057654217864
loss is   102.0145467933473
loss is   64.67467533395252
loss is   53.103802528151625
loss is   85.330602114039
loss is   96.44605498038739
loss is   80.2916216628146
loss is   87.63210502228287
loss is   88.1328291832169
loss is   81.59254534550927
loss is   89.84984673992055
loss is   121.23018311173743
loss is   86.96868234244707
loss is   85.87840751204072
loss is   91.82671388958956
loss is   100.03773150941214
loss is   87.26441455135387
loss is   69.1633368283797
loss is   91.51836263326699
loss is   82.99658117734678
loss is   99.93027351198499
loss is   81.47779872485188
loss is   70.95518839181277
loss is   64.296671725044
loss is   94.42973407901759
loss is   92.89269746167878
loss is   103.1863927819509
loss is   57.65894155169655
loss is   72.1194086020022

In [8]:
tr_accuracies

[0.4238833181403829,
 0.4247948951686418,
 0.4074749316317229,
 0.4056517775752051]

In [5]:
trained_model_parameters = train(train_data_loader, "logistic_loss", None, 1)

loss is   173.9201264032248
loss is   50.11542105135275
loss is   30.795270317178367
loss is   28.675509637408442
loss is   24.54608758676125
loss is   18.69927034762205
loss is   19.407129566241
loss is   19.289731632021436
loss is   16.60769878958542
loss is   22.742887068652283
loss is   20.459325788978628
loss is   19.965357604600587
loss is   23.436849807684194
loss is   24.490023682104294
loss is   24.996582382547672
loss is   23.70017326514732
loss is   25.70143896582056
loss is   30.498091996848284
loss is   26.499856449670236
loss is   29.29380798084746
loss is   22.830110921188233
loss is   23.575939052049435
loss is   20.218677763118833
loss is   22.26943353230828
loss is   18.02696702290046
loss is   20.4579988252204
loss is   22.881822738470408
loss is   22.853294203002253
loss is   18.054245099920074
loss is   20.164993544838108
loss is   18.562608035404082
loss is   23.045859724002128
loss is   19.61164289212171
loss is   20.36239334688854
loss is   19.680211775857686
lo

loss is   72.59845134400263
loss is   89.40738084694752
loss is   79.78011906009445
loss is   129.3815104488328
loss is   52.716084803043046
loss is   51.48805098634219
loss is   54.153959203850384
loss is   43.34023990806093
loss is   118.3003171105091
loss is   42.172983316380844
loss is   120.44742607201151
loss is   42.5477838941505
loss is   77.61028145019728
loss is   99.20773599054613
loss is   74.51908490403433
loss is   97.15790032311482
loss is   118.84631798352987
loss is   27.039490169430195
loss is   110.12768743687899
loss is   123.75768024853018
loss is   41.59934206617763
loss is   115.80619900361927
loss is   128.32208435806487
loss is   28.18698050096148
loss is   112.8518917388581
loss is   116.78308112440835
loss is   128.29261532065735
loss is   46.57436135116028
loss is   84.89053458764108
loss is   41.302006643650074
loss is   110.30268926970241
loss is   127.67788608822184
loss is   49.23277293836418
loss is   126.57992540469664
loss is   126.2396428113366
loss 

loss is   88.6806850102859
loss is   90.7242983564021
loss is   91.25254709619622
loss is   88.99509950261358
loss is   75.587003967147
loss is   147.75071794307723
loss is   87.21786740855092
loss is   91.35836096871381
loss is   89.6642384230403
loss is   76.14187971153495
loss is   141.25427514079908
loss is   92.03833831194999
loss is   83.92255648866771
loss is   90.83017629016634
loss is   100.8105982429379
loss is   99.41096487455938
loss is   49.61446927626919
loss is   104.18538840662082
loss is   55.68678923455946
loss is   87.66370686693853
loss is   74.49493564971661
loss is   101.06256839206934
loss is   88.13241844394977
loss is   70.58161507259437
loss is   95.66767068723757
loss is   101.44157271723039
loss is   75.22938397940807
loss is   83.49053534252765
loss is   99.72609069936149
loss is   80.6702754302077
loss is   91.7577941805075
loss is   38.56719604760436
loss is   137.77976360633707
loss is   81.36779726931057
loss is   94.76955905541456
loss is   47.44904606

loss is   93.62799835162342
loss is   88.25665737004056
loss is   102.01278666990636
loss is   84.86158777806695
loss is   135.67524979013595
loss is   137.2480134615224
loss is   142.94254696567234
loss is   137.33450005165207
loss is   77.23055859670662
loss is   140.23383623774643
loss is   71.21123929453404
loss is   132.22108516094443
loss is   83.42490261391941
loss is   84.6709729175322
loss is   141.5326340089542
loss is   107.4985585378991
loss is   89.95231338929814
loss is   143.3653066748962
loss is   82.02179303492679
loss is   142.49896495821645
loss is   135.34582777452917
loss is   103.46081604762632
loss is   90.11397187695985
loss is   143.51798378232476
loss is   82.26517120824164
loss is   142.7217311430307
loss is   135.36181498386978
loss is   103.69915611431396
loss is   90.18164275908241
loss is   149.1227505266643
loss is   78.06395918250381
loss is   145.38101287291084
loss is   147.89691079644282
loss is   76.10498570188369
loss is   138.60346511051415
loss i