<a href="https://colab.research.google.com/github/sdgroeve/D012554_Machine_Learning_2023/blob/main/02_neural_networks_in_pytorch_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
#@title
import requests
from pathlib import Path 

request = requests.get("https://raw.githubusercontent.com/sdgroeve/D012554_Machine_Learning_2023/main/utils/utils.py")
with open("utils.py", "wb") as f:
  f.write(request.content)

from utils import plot_decision_boundary

In [18]:
#@title
!pip install tqdm
!pip install pytorch-lightning

# 2. Neural networks in PyTorch


In [19]:
import torch
from torch import nn 

torch.manual_seed(46)

# Check PyTorch version
torch.__version__

## Preparing the data

The dataset for this notebook is in a flat file called `dataset_neural_networks.csv`. 

We read this file into a Pandas DataFrame and wrap it in a PyTorch `Dataset` that returns one sample and its label at a time.

In [20]:
import pandas as pd
from torch.utils.data import Dataset

class XORDataset(Dataset):
    # Loads the CSV file and converts the features and labels to tensors
    def __init__(self):
        # load data
        fn = "https://raw.githubusercontent.com/sdgroeve/D012554_Machine_Learning_2023/main/datasets/dataset_neural_networks.csv"
        self.df = pd.read_csv(fn)
        # extract labels (kept as a 2-D column so they match the model output shape)
        self.df_labels = self.df[['y']]
        self.df.pop('y')
        # convert to torch dtypes
        self.dataset = torch.tensor(self.df.to_numpy(), dtype=torch.float)
        self.labels = torch.tensor(self.df_labels.to_numpy(), dtype=torch.float)

    # Returns the total number of samples in the Dataset
    def __len__(self):
        return len(self.dataset)

    # Given an index, returns the i-th sample and its label
    def __getitem__(self, idx):
        return self.dataset[idx], self.labels[idx]
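
The map-style dataset protocol comes down to two methods, `__len__` and `__getitem__`. As a minimal sketch in plain Python (no PyTorch needed), the hypothetical `TinyXORDataset` below holds four hand-written XOR-style rows instead of the CSV file. The class name and the row values are made up for illustration and are not taken from `dataset_neural_networks.csv`.

```python
class TinyXORDataset:
    """Hypothetical in-memory stand-in for a map-style dataset (illustration only)."""

    def __init__(self):
        # Each sample is (x_1, x_2); each label is a one-element list so that
        # a label has the same shape as a single model output.
        self.samples = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
        self.labels = [[0.0], [1.0], [1.0], [0.0]]

    def __len__(self):
        # Total number of samples
        return len(self.samples)

    def __getitem__(self, idx):
        # The idx-th sample together with its label
        return self.samples[idx], self.labels[idx]


tiny = TinyXORDataset()
print(len(tiny))
print(tiny[1])
```

Indexing and `len()` are all a consumer needs from such a dataset to fetch samples one by one and group them into batches.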

In [21]:
import torch.utils.data as data
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class XORDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.data = XORDataset()

        # Random split
        self.train_set_size = int(len(self.data) * 0.8)
        self.valid_set_size = len(self.data) - self.train_set_size
        self.train_set, self.valid_set = data.random_split(self.data, [self.train_set_size, self.valid_set_size])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.valid_set, batch_size=self.batch_size)
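
The split arithmetic in `setup` and the resulting number of batches per epoch can be checked by hand. The sketch below mirrors that arithmetic in plain Python. The helper names `split_sizes` and `num_batches` are hypothetical, and the dataset size of 1000 is an assumed value chosen for illustration, not the actual number of rows in the CSV file.

```python
import math


def split_sizes(n_samples, train_fraction=0.8):
    """Training size is rounded down; validation gets the remainder."""
    train_size = int(n_samples * train_fraction)
    valid_size = n_samples - train_size
    return train_size, valid_size


def num_batches(n_samples, batch_size=32):
    """Batches per epoch when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)


# 1000 is an assumed dataset size, used only for illustration.
train_size, valid_size = split_sizes(1000)
print(train_size, valid_size)
print(num_batches(train_size), num_batches(valid_size))
```

Because the validation size is computed as the remainder, the two subsets always add up to the full dataset, which `random_split` requires when it is given absolute lengths.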

## Building the model

We increase the complexity of our model by adding a second linear layer to the **model architecture**.
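
Why a hidden layer helps can be illustrated with the XOR function, which the class name `XORDataset` suggests the data resembles (an assumption; the CSV file may well contain noisy, continuous points). XOR is not linearly separable, so a single linear layer followed by a sigmoid cannot represent it, whereas one hidden layer with a non-linear activation can. The sketch below uses two hidden ReLU neurons with hand-picked weights, chosen for illustration rather than learned by training.

```python
def relu(z):
    return max(0.0, z)


def tiny_xor_net(x1, x2):
    # Hidden layer: two ReLU neurons with hand-picked weights and biases
    h1 = relu(x1 + x2)
    h2 = relu(x1 + x2 - 1.0)
    # Output layer: a linear combination of the two hidden neurons
    return h1 - 2.0 * h2


for x1, x2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print((x1, x2), tiny_xor_net(x1, x2))
```

The output is 1 only when exactly one input is 1. Training a network amounts to finding weights with this kind of effect automatically instead of by hand.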

In [22]:
import pytorch_lightning as pl


class NeuralNetwork(pl.LightningModule):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        num_neurons_layer_2 = 6

        self.layer_1 = nn.Linear(in_features=input_dim, out_features=num_neurons_layer_2)
        self.layer_2 = nn.Linear(in_features=num_neurons_layer_2, out_features=output_dim)
        
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.layer_1(x))
        x = self.layer_2(x)
        x = self.sigmoid(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return [optimizer]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.binary_cross_entropy(y_hat, y)
        #print("train_loss = %f"%loss)
        return loss    

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = nn.functional.binary_cross_entropy(y_hat, y)
        print("val_loss = %f"%val_loss)

### `__init__()`

Our neural network has two linear layers. The first layer `layer_1` has `input_dim` (the number of features in our dataset) input features that form the **input layer**. It has `num_neurons_layer_2` output features that form the **hidden layer** where these features are typically called **hidden neurons**.

The second layer `layer_2` has `num_neurons_layer_2` input features (neurons) and `output_dim` output features (1 for two-class classification) that form the **output layer**.
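
The layer sizes determine the number of trainable parameters. As a small sketch, assuming each linear layer keeps its bias term (the default for `nn.Linear`), the hypothetical helper `linear_param_count` counts one weight per input-output pair plus one bias per output feature.

```python
def linear_param_count(in_features, out_features):
    # Weight matrix entries plus one bias per output feature
    return in_features * out_features + out_features


input_dim, output_dim, num_neurons_layer_2 = 2, 1, 6

layer_1_params = linear_param_count(input_dim, num_neurons_layer_2)
layer_2_params = linear_param_count(num_neurons_layer_2, output_dim)

print(layer_1_params, layer_2_params, layer_1_params + layer_2_params)
```

With two inputs, six hidden neurons and one output this gives 18 and 7 parameters, 25 in total, which agrees with the model summary that the trainer prints when training starts.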

An example of this model architecture with `num_neurons_layer_2 = 6` can be seen [here](https://playground.tensorflow.org/#activation=sigmoid&batchSize=30&dataset=xor&regDataset=reg-plane&learningRate=0.03&regularizationRate=0&noise=0&networkShape=6&seed=0.86658&showTestData=false&discretize=false&percTrainData=70&x=true&y=true&xTimesY=false&xSquared=false&ySquared=false&cosX=false&sinX=false&cosY=false&sinY=false&collectStats=false&problem=classification&initZero=false&hideText=false).

We use the Rectified Linear Unit (ReLU) activation function in the hidden layer. In the output layer we apply the sigmoid function explicitly in `forward`, so the model outputs probabilities that we pass to `binary_cross_entropy`. The alternative, `BCEWithLogitsLoss`, applies the sigmoid inside the loss and expects raw logits instead (see the notebook about logistic regression).
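
To make the computation in `forward` and the loss concrete, here is a minimal sketch in plain Python for a single sample. It is not the PyTorch implementation: the weights are hypothetical hand-picked values, the hidden layer has only two neurons to keep the example short, and the loss omits the numerical safeguards a library version would include.

```python
import math


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(x, w1, b1, w2, b2):
    # Hidden layer: one weighted sum per hidden neuron, followed by ReLU
    hidden = [relu(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    # Output layer: a single weighted sum, squashed to (0, 1) by the sigmoid
    logit = sum(w * h for w, h in zip(w2, hidden)) + b2
    return sigmoid(logit)


def binary_cross_entropy(y_hat, y):
    # Loss for one sample with predicted probability y_hat and label y in {0, 1}
    return -(y * math.log(y_hat) + (1.0 - y) * math.log(1.0 - y_hat))


# Hypothetical weights: 2 inputs, 2 hidden neurons, 1 output
w1 = [[1.0, 1.0], [1.0, 1.0]]
b1 = [0.0, -1.0]
w2 = [1.0, -2.0]
b2 = 0.0

y_hat = forward([0.0, 0.0], w1, b1, w2, b2)
print(y_hat)
print(binary_cross_entropy(y_hat, 0.0))
```

For this input the prediction is 0.5, and the loss equals ln 2, roughly 0.693. A loss near that value is what an uninformative prediction of 0.5 produces, which is a useful reference point when reading the loss values at the start of training.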

Next, we create an instance of the class `NeuralNetwork`.

In [23]:
# Two inputs x_1 and x_2
input_dim = 2  
# Single binary output 
output_dim = 1 

# Create an instance of the model (this is a subclass of nn.Module that contains nn.Parameter(s))
model = NeuralNetwork(input_dim, output_dim)

model.state_dict()

In [24]:
print(model.layer_1.weight.dtype)

Let's plot the decision boundary of this initial neural network.

In [25]:
#plot_decision_boundary(model, X_train, y_train)

In [None]:
import pandas as pd
trainer = pl.Trainer(max_epochs=8000)
xor = XORDataModule()
trainer.fit(model,xor)
#trainer.validate(model,xor)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type    | Params
------------------------------------
0 | layer_1 | Linear  | 18    
1 | layer_2 | Linear  | 7     
2 | sigmoid | Sigmoid | 0     
3 | relu    | ReLU    | 0     
------------------------------------
25        Trainable params
0         Non-trainable params
25        Total params
0.000     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

val_loss = 0.672341
val_loss = 0.696593



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.666439
val_loss = 0.686506
val_loss = 0.701799
val_loss = 0.718100
val_loss = 0.786214
val_loss = 0.704597
val_loss = 0.817964


Validation: 0it [00:00, ?it/s]

val_loss = 0.662128
val_loss = 0.678674
val_loss = 0.694516
val_loss = 0.714913
val_loss = 0.765497
val_loss = 0.693772
val_loss = 0.792002


Validation: 0it [00:00, ?it/s]

val_loss = 0.658658
val_loss = 0.672335
val_loss = 0.688636
val_loss = 0.712522
val_loss = 0.748720
val_loss = 0.685175
val_loss = 0.770447


Validation: 0it [00:00, ?it/s]

val_loss = 0.655646
val_loss = 0.666930
val_loss = 0.683593
val_loss = 0.710654
val_loss = 0.734824
val_loss = 0.678149
val_loss = 0.752264


Validation: 0it [00:00, ?it/s]

val_loss = 0.652836
val_loss = 0.662169
val_loss = 0.679196
val_loss = 0.708942
val_loss = 0.723193
val_loss = 0.672163
val_loss = 0.736852


Validation: 0it [00:00, ?it/s]

val_loss = 0.650051
val_loss = 0.657796
val_loss = 0.675194
val_loss = 0.707216
val_loss = 0.713260
val_loss = 0.666943
val_loss = 0.723476


Validation: 0it [00:00, ?it/s]

val_loss = 0.647293
val_loss = 0.653633
val_loss = 0.671357
val_loss = 0.705402
val_loss = 0.704619
val_loss = 0.662237
val_loss = 0.711777


Validation: 0it [00:00, ?it/s]

val_loss = 0.644441
val_loss = 0.649561
val_loss = 0.667540
val_loss = 0.703495
val_loss = 0.696988
val_loss = 0.657867
val_loss = 0.701320


Validation: 0it [00:00, ?it/s]

val_loss = 0.641456
val_loss = 0.645546
val_loss = 0.663706
val_loss = 0.701421
val_loss = 0.690093
val_loss = 0.653723
val_loss = 0.691806


Validation: 0it [00:00, ?it/s]

val_loss = 0.638307
val_loss = 0.641611
val_loss = 0.659828
val_loss = 0.699243
val_loss = 0.683808
val_loss = 0.649758
val_loss = 0.683071


Validation: 0it [00:00, ?it/s]

val_loss = 0.634998
val_loss = 0.637611
val_loss = 0.655995
val_loss = 0.696914
val_loss = 0.677979
val_loss = 0.645872
val_loss = 0.674965


Validation: 0it [00:00, ?it/s]

val_loss = 0.631513
val_loss = 0.633573
val_loss = 0.652203
val_loss = 0.694395
val_loss = 0.672537
val_loss = 0.642054
val_loss = 0.667413


Validation: 0it [00:00, ?it/s]

val_loss = 0.627873
val_loss = 0.629553
val_loss = 0.648335
val_loss = 0.691748
val_loss = 0.667456
val_loss = 0.638352
val_loss = 0.660394


Validation: 0it [00:00, ?it/s]

val_loss = 0.624072
val_loss = 0.625454
val_loss = 0.644401
val_loss = 0.688915
val_loss = 0.662615
val_loss = 0.634617
val_loss = 0.653711


Validation: 0it [00:00, ?it/s]

val_loss = 0.611764
val_loss = 0.612570
val_loss = 0.632234
val_loss = 0.679471
val_loss = 0.649096
val_loss = 0.623224
val_loss = 0.635222


Validation: 0it [00:00, ?it/s]

val_loss = 0.607462
val_loss = 0.608090
val_loss = 0.628077
val_loss = 0.676037
val_loss = 0.644736
val_loss = 0.619471
val_loss = 0.629447


Validation: 0it [00:00, ?it/s]

val_loss = 0.603147
val_loss = 0.603559
val_loss = 0.623927
val_loss = 0.672531
val_loss = 0.640496
val_loss = 0.615731
val_loss = 0.623828


Validation: 0it [00:00, ?it/s]

val_loss = 0.585550
val_loss = 0.585347
val_loss = 0.606878
val_loss = 0.657819
val_loss = 0.624177
val_loss = 0.600431
val_loss = 0.602486


Validation: 0it [00:00, ?it/s]

val_loss = 0.581032
val_loss = 0.580688
val_loss = 0.602485
val_loss = 0.653845
val_loss = 0.620193
val_loss = 0.596441
val_loss = 0.597357


Validation: 0it [00:00, ?it/s]

val_loss = 0.576465
val_loss = 0.576014
val_loss = 0.598277
val_loss = 0.649970
val_loss = 0.616198
val_loss = 0.592408
val_loss = 0.592301


Validation: 0it [00:00, ?it/s]

val_loss = 0.553940
val_loss = 0.552478
val_loss = 0.577693
val_loss = 0.630643
val_loss = 0.596826
val_loss = 0.572020
val_loss = 0.567949


Validation: 0it [00:00, ?it/s]

val_loss = 0.549355
val_loss = 0.547799
val_loss = 0.573691
val_loss = 0.626780
val_loss = 0.593126
val_loss = 0.567889
val_loss = 0.563246


Validation: 0it [00:00, ?it/s]

val_loss = 0.544774
val_loss = 0.543186
val_loss = 0.569679
val_loss = 0.622953
val_loss = 0.589463
val_loss = 0.563838
val_loss = 0.558614


Validation: 0it [00:00, ?it/s]

val_loss = 0.527045
val_loss = 0.525046
val_loss = 0.553897
val_loss = 0.607651
val_loss = 0.575051
val_loss = 0.547975
val_loss = 0.540661


Validation: 0it [00:00, ?it/s]

val_loss = 0.522738
val_loss = 0.520551
val_loss = 0.549893
val_loss = 0.603752
val_loss = 0.571483
val_loss = 0.544177
val_loss = 0.536278


Validation: 0it [00:00, ?it/s]

val_loss = 0.518391
val_loss = 0.516082
val_loss = 0.545868
val_loss = 0.599893
val_loss = 0.568035
val_loss = 0.540396
val_loss = 0.531964


Validation: 0it [00:00, ?it/s]

val_loss = 0.492573
val_loss = 0.489626
val_loss = 0.521384
val_loss = 0.576150
val_loss = 0.547422
val_loss = 0.517865
val_loss = 0.507731


Validation: 0it [00:00, ?it/s]

val_loss = 0.488377
val_loss = 0.485298
val_loss = 0.517247
val_loss = 0.572036
val_loss = 0.544033
val_loss = 0.514204
val_loss = 0.503874


Validation: 0it [00:00, ?it/s]

val_loss = 0.484143
val_loss = 0.480943
val_loss = 0.513056
val_loss = 0.567804
val_loss = 0.540607
val_loss = 0.510487
val_loss = 0.500002


Validation: 0it [00:00, ?it/s]

val_loss = 0.449528
val_loss = 0.445353
val_loss = 0.477307
val_loss = 0.530133
val_loss = 0.511658
val_loss = 0.478126
val_loss = 0.467012


Validation: 0it [00:00, ?it/s]

val_loss = 0.445040
val_loss = 0.440736
val_loss = 0.472491
val_loss = 0.524862
val_loss = 0.507918
val_loss = 0.473763
val_loss = 0.462691


Validation: 0it [00:00, ?it/s]

val_loss = 0.440493
val_loss = 0.436050
val_loss = 0.467585
val_loss = 0.519412
val_loss = 0.504081
val_loss = 0.469282
val_loss = 0.458283


Validation: 0it [00:00, ?it/s]

val_loss = 0.377434
val_loss = 0.371172
val_loss = 0.397282
val_loss = 0.438289
val_loss = 0.447066
val_loss = 0.404586
val_loss = 0.397468


Validation: 0it [00:00, ?it/s]

val_loss = 0.372381
val_loss = 0.365962
val_loss = 0.391457
val_loss = 0.431347
val_loss = 0.442143
val_loss = 0.399211
val_loss = 0.392794


Validation: 0it [00:00, ?it/s]

val_loss = 0.367323
val_loss = 0.360806
val_loss = 0.385617
val_loss = 0.424382
val_loss = 0.437208
val_loss = 0.393794
val_loss = 0.388121


Validation: 0it [00:00, ?it/s]

val_loss = 0.337519
val_loss = 0.330118
val_loss = 0.350760
val_loss = 0.382491
val_loss = 0.407105
val_loss = 0.360944
val_loss = 0.360226


Validation: 0it [00:00, ?it/s]

val_loss = 0.332682
val_loss = 0.325064
val_loss = 0.345100
val_loss = 0.375737
val_loss = 0.402129
val_loss = 0.355516
val_loss = 0.355661


Validation: 0it [00:00, ?it/s]

val_loss = 0.327897
val_loss = 0.320070
val_loss = 0.339510
val_loss = 0.369088
val_loss = 0.397166
val_loss = 0.350127
val_loss = 0.351142


Validation: 0it [00:00, ?it/s]

val_loss = 0.323163
val_loss = 0.315104
val_loss = 0.334004
val_loss = 0.362624
val_loss = 0.392242
val_loss = 0.344792
val_loss = 0.346648
val_loss = 0.296611
val_loss = 0.286768
val_loss = 0.302059
val_loss = 0.325850
val_loss = 0.363707
val_loss = 0.314710
val_loss = 0.320751


Validation: 0it [00:00, ?it/s]

val_loss = 0.292425
val_loss = 0.282334
val_loss = 0.296938
val_loss = 0.320118
val_loss = 0.359139
val_loss = 0.310054
val_loss = 0.316600


Validation: 0it [00:00, ?it/s]

val_loss = 0.288327
val_loss = 0.277946
val_loss = 0.291895
val_loss = 0.314455
val_loss = 0.354644
val_loss = 0.305490
val_loss = 0.312517


Validation: 0it [00:00, ?it/s]

val_loss = 0.262595
val_loss = 0.249426
val_loss = 0.259220
val_loss = 0.278110
val_loss = 0.325294
val_loss = 0.276163
val_loss = 0.286119


Validation: 0it [00:00, ?it/s]

val_loss = 0.259287
val_loss = 0.245688
val_loss = 0.254777
val_loss = 0.273439
val_loss = 0.321320
val_loss = 0.272345
val_loss = 0.282670


Validation: 0it [00:00, ?it/s]

val_loss = 0.256065
val_loss = 0.242056
val_loss = 0.250408
val_loss = 0.268890
val_loss = 0.317430
val_loss = 0.268625
val_loss = 0.279304


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.232818
val_loss = 0.215875
val_loss = 0.217971
val_loss = 0.236948
val_loss = 0.288899
val_loss = 0.241558
val_loss = 0.255102


Validation: 0it [00:00, ?it/s]

val_loss = 0.230140
val_loss = 0.212970
val_loss = 0.214243
val_loss = 0.233383
val_loss = 0.285638
val_loss = 0.238560
val_loss = 0.252399


Validation: 0it [00:00, ?it/s]

val_loss = 0.227535
val_loss = 0.210142
val_loss = 0.210603
val_loss = 0.229916
val_loss = 0.282453
val_loss = 0.235647
val_loss = 0.249766


Validation: 0it [00:00, ?it/s]

val_loss = 0.211153
val_loss = 0.192242
val_loss = 0.187446
val_loss = 0.208335
val_loss = 0.262003
val_loss = 0.216874
val_loss = 0.233023


Validation: 0it [00:00, ?it/s]

val_loss = 0.209065
val_loss = 0.189944
val_loss = 0.184444
val_loss = 0.205597
val_loss = 0.259248
val_loss = 0.214445
val_loss = 0.230867


Validation: 0it [00:00, ?it/s]

val_loss = 0.207036
val_loss = 0.187707
val_loss = 0.181516
val_loss = 0.202943
val_loss = 0.256552
val_loss = 0.212086
val_loss = 0.228771


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.179658
val_loss = 0.159054
val_loss = 0.143744
val_loss = 0.169563
val_loss = 0.220266
val_loss = 0.181149
val_loss = 0.201244


Validation: 0it [00:00, ?it/s]

val_loss = 0.178310
val_loss = 0.157638
val_loss = 0.141862
val_loss = 0.167711
val_loss = 0.218400
val_loss = 0.179601
val_loss = 0.199839


Validation: 0it [00:00, ?it/s]

val_loss = 0.177002
val_loss = 0.156262
val_loss = 0.140063
val_loss = 0.165912
val_loss = 0.216552
val_loss = 0.178084
val_loss = 0.198470


Validation: 0it [00:00, ?it/s]

val_loss = 0.167651
val_loss = 0.145978
val_loss = 0.127336
val_loss = 0.153083
val_loss = 0.203958
val_loss = 0.167208
val_loss = 0.188649


Validation: 0it [00:00, ?it/s]

val_loss = 0.166605
val_loss = 0.144717
val_loss = 0.125921
val_loss = 0.151650
val_loss = 0.202522
val_loss = 0.165992
val_loss = 0.187556


Validation: 0it [00:00, ?it/s]

val_loss = 0.165583
val_loss = 0.143484
val_loss = 0.124542
val_loss = 0.150249
val_loss = 0.201115
val_loss = 0.164806
val_loss = 0.186490


Validation: 0it [00:00, ?it/s]

val_loss = 0.158221
val_loss = 0.134535
val_loss = 0.114720
val_loss = 0.140140
val_loss = 0.190824
val_loss = 0.156262
val_loss = 0.178807


Validation: 0it [00:00, ?it/s]

val_loss = 0.157390
val_loss = 0.133518
val_loss = 0.113623
val_loss = 0.139002
val_loss = 0.189640
val_loss = 0.155298
val_loss = 0.177966


Validation: 0it [00:00, ?it/s]

val_loss = 0.156582
val_loss = 0.132522
val_loss = 0.112552
val_loss = 0.137884
val_loss = 0.188480
val_loss = 0.154351
val_loss = 0.177140


Validation: 0it [00:00, ?it/s]

val_loss = 0.155790
val_loss = 0.131546
val_loss = 0.111506
val_loss = 0.136789
val_loss = 0.187343
val_loss = 0.153425
val_loss = 0.176333
val_loss = 0.151359
val_loss = 0.126063
val_loss = 0.105719
val_loss = 0.130660
val_loss = 0.180941
val_loss = 0.148284
val_loss = 0.172510


Validation: 0it [00:00, ?it/s]

val_loss = 0.150671
val_loss = 0.125208
val_loss = 0.104830
val_loss = 0.129707
val_loss = 0.179936
val_loss = 0.147490
val_loss = 0.171960


Validation: 0it [00:00, ?it/s]

val_loss = 0.150006
val_loss = 0.124368
val_loss = 0.103962
val_loss = 0.128771
val_loss = 0.178949
val_loss = 0.146713
val_loss = 0.171425


Validation: 0it [00:00, ?it/s]

val_loss = 0.145170
val_loss = 0.118261
val_loss = 0.097740
val_loss = 0.121979
val_loss = 0.171592
val_loss = 0.141044
val_loss = 0.167661


Validation: 0it [00:00, ?it/s]

val_loss = 0.144625
val_loss = 0.117562
val_loss = 0.097037
val_loss = 0.121202
val_loss = 0.170738
val_loss = 0.140394
val_loss = 0.167251


Validation: 0it [00:00, ?it/s]

val_loss = 0.144094
val_loss = 0.116876
val_loss = 0.096300
val_loss = 0.120442
val_loss = 0.169896
val_loss = 0.139800
val_loss = 0.166853


Validation: 0it [00:00, ?it/s]

val_loss = 0.137609
val_loss = 0.108557
val_loss = 0.086849
val_loss = 0.111166
val_loss = 0.159264
val_loss = 0.133343
val_loss = 0.162420


Validation: 0it [00:00, ?it/s]

val_loss = 0.137204
val_loss = 0.108037
val_loss = 0.086260
val_loss = 0.110586
val_loss = 0.158580
val_loss = 0.132960
val_loss = 0.162159


Validation: 0it [00:00, ?it/s]

val_loss = 0.136806
val_loss = 0.107522
val_loss = 0.085680
val_loss = 0.110013
val_loss = 0.157906
val_loss = 0.132582
val_loss = 0.161898


Validation: 0it [00:00, ?it/s]

val_loss = 0.132548
val_loss = 0.101958
val_loss = 0.079441
val_loss = 0.103821
val_loss = 0.150452
val_loss = 0.128591
val_loss = 0.159197


Validation: 0it [00:00, ?it/s]

val_loss = 0.132232
val_loss = 0.101540
val_loss = 0.078976
val_loss = 0.103358
val_loss = 0.149876
val_loss = 0.128301
val_loss = 0.159012


Validation: 0it [00:00, ?it/s]

val_loss = 0.131922
val_loss = 0.101130
val_loss = 0.078520
val_loss = 0.102904
val_loss = 0.149306
val_loss = 0.128018
val_loss = 0.158838


Validation: 0it [00:00, ?it/s]

val_loss = 0.131618
val_loss = 0.100726
val_loss = 0.078071
val_loss = 0.102458
val_loss = 0.148742
val_loss = 0.127741
val_loss = 0.158671
val_loss = 0.129073
val_loss = 0.097358
val_loss = 0.074330
val_loss = 0.098742
val_loss = 0.143935
val_loss = 0.125510
val_loss = 0.157421


Validation: 0it [00:00, ?it/s]

val_loss = 0.128809
val_loss = 0.097010
val_loss = 0.073946
val_loss = 0.098361
val_loss = 0.143426
val_loss = 0.125289
val_loss = 0.157306


Validation: 0it [00:00, ?it/s]

val_loss = 0.128548
val_loss = 0.096666
val_loss = 0.073567
val_loss = 0.097986
val_loss = 0.142923
val_loss = 0.125071
val_loss = 0.157193


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.125931
val_loss = 0.093181
val_loss = 0.069741
val_loss = 0.094184
val_loss = 0.137727
val_loss = 0.122944
val_loss = 0.156186


Validation: 0it [00:00, ?it/s]

val_loss = 0.125714
val_loss = 0.092891
val_loss = 0.069422
val_loss = 0.093866
val_loss = 0.137284
val_loss = 0.122773
val_loss = 0.156116


Validation: 0it [00:00, ?it/s]

val_loss = 0.125500
val_loss = 0.092604
val_loss = 0.069107
val_loss = 0.093553
val_loss = 0.136846
val_loss = 0.122606
val_loss = 0.156051


Validation: 0it [00:00, ?it/s]

val_loss = 0.123967
val_loss = 0.090472
val_loss = 0.066786
val_loss = 0.091266
val_loss = 0.133364
val_loss = 0.121354
val_loss = 0.155639


Validation: 0it [00:00, ?it/s]

val_loss = 0.123788
val_loss = 0.090220
val_loss = 0.066513
val_loss = 0.090996
val_loss = 0.132950
val_loss = 0.121212
val_loss = 0.155600


Validation: 0it [00:00, ?it/s]

val_loss = 0.123611
val_loss = 0.089971
val_loss = 0.066243
val_loss = 0.090729
val_loss = 0.132541
val_loss = 0.121073
val_loss = 0.155566


Validation: 0it [00:00, ?it/s]

val_loss = 0.121071
val_loss = 0.086370
val_loss = 0.062340
val_loss = 0.086858
val_loss = 0.126604
val_loss = 0.119212
val_loss = 0.155332


Validation: 0it [00:00, ?it/s]

val_loss = 0.120913
val_loss = 0.086162
val_loss = 0.062120
val_loss = 0.086640
val_loss = 0.126253
val_loss = 0.119129
val_loss = 0.155351


Validation: 0it [00:00, ?it/s]

val_loss = 0.120758
val_loss = 0.085955
val_loss = 0.061902
val_loss = 0.086424
val_loss = 0.125907
val_loss = 0.119047
val_loss = 0.155371


Validation: 0it [00:00, ?it/s]

val_loss = 0.118944
val_loss = 0.083461
val_loss = 0.059273
val_loss = 0.083797
val_loss = 0.121732
val_loss = 0.118094
val_loss = 0.155673


Validation: 0it [00:00, ?it/s]

val_loss = 0.118819
val_loss = 0.083283
val_loss = 0.059086
val_loss = 0.083613
val_loss = 0.121431
val_loss = 0.118031
val_loss = 0.155705


Validation: 0it [00:00, ?it/s]

val_loss = 0.118695
val_loss = 0.083108
val_loss = 0.058902
val_loss = 0.083430
val_loss = 0.121133
val_loss = 0.117970
val_loss = 0.155738


Validation: 0it [00:00, ?it/s]

val_loss = 0.117193
val_loss = 0.080972
val_loss = 0.056688
val_loss = 0.081216
val_loss = 0.117465
val_loss = 0.117318
val_loss = 0.156292


Validation: 0it [00:00, ?it/s]

val_loss = 0.117085
val_loss = 0.080818
val_loss = 0.056529
val_loss = 0.081057
val_loss = 0.117197
val_loss = 0.117277
val_loss = 0.156341


Validation: 0it [00:00, ?it/s]

val_loss = 0.116979
val_loss = 0.080665
val_loss = 0.056372
val_loss = 0.080900
val_loss = 0.116932
val_loss = 0.117238
val_loss = 0.156392


Validation: 0it [00:00, ?it/s]

val_loss = 0.115799
val_loss = 0.078940
val_loss = 0.054599
val_loss = 0.079154
val_loss = 0.113898
val_loss = 0.116843
val_loss = 0.157067


Validation: 0it [00:00, ?it/s]

val_loss = 0.115711
val_loss = 0.078805
val_loss = 0.054460
val_loss = 0.079020
val_loss = 0.113658
val_loss = 0.116813
val_loss = 0.157120


Validation: 0it [00:00, ?it/s]

val_loss = 0.115622
val_loss = 0.078671
val_loss = 0.054322
val_loss = 0.078887
val_loss = 0.113418
val_loss = 0.116783
val_loss = 0.157173


Validation: 0it [00:00, ?it/s]

val_loss = 0.114404
val_loss = 0.076808
val_loss = 0.052415
val_loss = 0.077057
val_loss = 0.110006
val_loss = 0.116457
val_loss = 0.158088


Validation: 0it [00:00, ?it/s]

val_loss = 0.114330
val_loss = 0.076693
val_loss = 0.052297
val_loss = 0.076944
val_loss = 0.109790
val_loss = 0.116443
val_loss = 0.158156


Validation: 0it [00:00, ?it/s]

val_loss = 0.114256
val_loss = 0.076579
val_loss = 0.052180
val_loss = 0.076833
val_loss = 0.109575
val_loss = 0.116429
val_loss = 0.158226


Validation: 0it [00:00, ?it/s]

val_loss = 0.113493
val_loss = 0.075386
val_loss = 0.050964
val_loss = 0.075680
val_loss = 0.107303
val_loss = 0.116337
val_loss = 0.159056


Validation: 0it [00:00, ?it/s]

val_loss = 0.113428
val_loss = 0.075283
val_loss = 0.050859
val_loss = 0.075582
val_loss = 0.107104
val_loss = 0.116333
val_loss = 0.159137


Validation: 0it [00:00, ?it/s]

val_loss = 0.113364
val_loss = 0.075181
val_loss = 0.050755
val_loss = 0.075484
val_loss = 0.106907
val_loss = 0.116330
val_loss = 0.159218


Validation: 0it [00:00, ?it/s]

val_loss = 0.112808
val_loss = 0.074295
val_loss = 0.049858
val_loss = 0.074640
val_loss = 0.105176
val_loss = 0.116336
val_loss = 0.159981


Validation: 0it [00:00, ?it/s]

val_loss = 0.112748
val_loss = 0.074200
val_loss = 0.049762
val_loss = 0.074550
val_loss = 0.104989
val_loss = 0.116341
val_loss = 0.160068


Validation: 0it [00:00, ?it/s]

val_loss = 0.112689
val_loss = 0.074106
val_loss = 0.049667
val_loss = 0.074460
val_loss = 0.104803
val_loss = 0.116345
val_loss = 0.160156


Validation: 0it [00:00, ?it/s]

val_loss = 0.112237
val_loss = 0.073376
val_loss = 0.048934
val_loss = 0.073773
val_loss = 0.103353
val_loss = 0.116405
val_loss = 0.160880


Validation: 0it [00:00, ?it/s]

val_loss = 0.112183
val_loss = 0.073287
val_loss = 0.048845
val_loss = 0.073690
val_loss = 0.103174
val_loss = 0.116415
val_loss = 0.160972


Validation: 0it [00:00, ?it/s]

val_loss = 0.112131
val_loss = 0.073198
val_loss = 0.048757
val_loss = 0.073608
val_loss = 0.102997
val_loss = 0.116426
val_loss = 0.161064


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.111621
val_loss = 0.072269
val_loss = 0.047826
val_loss = 0.072736
val_loss = 0.101132
val_loss = 0.116522
val_loss = 0.161993


Validation: 0it [00:00, ?it/s]

val_loss = 0.111578
val_loss = 0.072188
val_loss = 0.047744
val_loss = 0.072659
val_loss = 0.100969
val_loss = 0.116533
val_loss = 0.162079


Validation: 0it [00:00, ?it/s]

val_loss = 0.111535
val_loss = 0.072108
val_loss = 0.047664
val_loss = 0.072584
val_loss = 0.100807
val_loss = 0.116545
val_loss = 0.162166


Validation: 0it [00:00, ?it/s]

val_loss = 0.110842
val_loss = 0.070828
val_loss = 0.046387
val_loss = 0.071407
val_loss = 0.098164
val_loss = 0.116841
val_loss = 0.163768


Validation: 0it [00:00, ?it/s]

val_loss = 0.110802
val_loss = 0.070758
val_loss = 0.046317
val_loss = 0.071343
val_loss = 0.098016
val_loss = 0.116863
val_loss = 0.163867


Validation: 0it [00:00, ?it/s]

val_loss = 0.110763
val_loss = 0.070688
val_loss = 0.046248
val_loss = 0.071279
val_loss = 0.097868
val_loss = 0.116886
val_loss = 0.163966


Validation: 0it [00:00, ?it/s]

val_loss = 0.110724
val_loss = 0.070618
val_loss = 0.046179
val_loss = 0.071216
val_loss = 0.097722
val_loss = 0.116909
val_loss = 0.164066
val_loss = 0.110461
val_loss = 0.070144
val_loss = 0.045709
val_loss = 0.070787
val_loss = 0.096720
val_loss = 0.117077
val_loss = 0.164767


Validation: 0it [00:00, ?it/s]

val_loss = 0.110425
val_loss = 0.070078
val_loss = 0.045644
val_loss = 0.070728
val_loss = 0.096580
val_loss = 0.117102
val_loss = 0.164868


Validation: 0it [00:00, ?it/s]

val_loss = 0.110390
val_loss = 0.070012
val_loss = 0.045579
val_loss = 0.070668
val_loss = 0.096441
val_loss = 0.117127
val_loss = 0.164967


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.110118
val_loss = 0.069498
val_loss = 0.045074
val_loss = 0.070207
val_loss = 0.095353
val_loss = 0.117332
val_loss = 0.165758


Validation: 0it [00:00, ?it/s]

val_loss = 0.110086
val_loss = 0.069435
val_loss = 0.045012
val_loss = 0.070151
val_loss = 0.095220
val_loss = 0.117359
val_loss = 0.165857


Validation: 0it [00:00, ?it/s]

val_loss = 0.110054
val_loss = 0.069373
val_loss = 0.044951
val_loss = 0.070095
val_loss = 0.095088
val_loss = 0.117386
val_loss = 0.165957


Validation: 0it [00:00, ?it/s]

val_loss = 0.109806
val_loss = 0.068888
val_loss = 0.044478
val_loss = 0.069664
val_loss = 0.094055
val_loss = 0.117615
val_loss = 0.166771


Validation: 0it [00:00, ?it/s]

val_loss = 0.109776
val_loss = 0.068829
val_loss = 0.044421
val_loss = 0.069612
val_loss = 0.093928
val_loss = 0.117646
val_loss = 0.166875


Validation: 0it [00:00, ?it/s]

val_loss = 0.109746
val_loss = 0.068770
val_loss = 0.044364
val_loss = 0.069560
val_loss = 0.093801
val_loss = 0.117676
val_loss = 0.166978


Validation: 0it [00:00, ?it/s]

val_loss = 0.109525
val_loss = 0.068314
val_loss = 0.043920
val_loss = 0.069159
val_loss = 0.092814
val_loss = 0.117929
val_loss = 0.167818


Validation: 0it [00:00, ?it/s]

val_loss = 0.109498
val_loss = 0.068258
val_loss = 0.043866
val_loss = 0.069111
val_loss = 0.092694
val_loss = 0.117962
val_loss = 0.167924


Validation: 0it [00:00, ?it/s]

val_loss = 0.109472
val_loss = 0.068203
val_loss = 0.043812
val_loss = 0.069063
val_loss = 0.092574
val_loss = 0.117994
val_loss = 0.168030


Validation: 0it [00:00, ?it/s]

val_loss = 0.109446
val_loss = 0.068148
val_loss = 0.043759
val_loss = 0.069015
val_loss = 0.092454
val_loss = 0.118027
val_loss = 0.168136
val_loss = 0.109107
val_loss = 0.067417
val_loss = 0.043052
val_loss = 0.068386
val_loss = 0.090841
val_loss = 0.118516
val_loss = 0.169648


Validation: 0it [00:00, ?it/s]

val_loss = 0.109084
val_loss = 0.067366
val_loss = 0.043003
val_loss = 0.068343
val_loss = 0.090730
val_loss = 0.118553
val_loss = 0.169758


Validation: 0it [00:00, ?it/s]

val_loss = 0.109061
val_loss = 0.067316
val_loss = 0.042955
val_loss = 0.068300
val_loss = 0.090619
val_loss = 0.118590
val_loss = 0.169867


Validation: 0it [00:00, ?it/s]

val_loss = 0.108802
val_loss = 0.066737
val_loss = 0.042397
val_loss = 0.067809
val_loss = 0.089333
val_loss = 0.119045
val_loss = 0.171193


Validation: 0it [00:00, ?it/s]

val_loss = 0.108781
val_loss = 0.066691
val_loss = 0.042353
val_loss = 0.067770
val_loss = 0.089229
val_loss = 0.119084
val_loss = 0.171304


Validation: 0it [00:00, ?it/s]

val_loss = 0.108761
val_loss = 0.066644
val_loss = 0.042308
val_loss = 0.067731
val_loss = 0.089125
val_loss = 0.119123
val_loss = 0.171416


Validation: 0it [00:00, ?it/s]

val_loss = 0.108602
val_loss = 0.066280
val_loss = 0.041960
val_loss = 0.067426
val_loss = 0.088313
val_loss = 0.119443
val_loss = 0.172308


Validation: 0it [00:00, ?it/s]

val_loss = 0.108583
val_loss = 0.066236
val_loss = 0.041917
val_loss = 0.067389
val_loss = 0.088214
val_loss = 0.119483
val_loss = 0.172420


Validation: 0it [00:00, ?it/s]

val_loss = 0.108564
val_loss = 0.066191
val_loss = 0.041875
val_loss = 0.067352
val_loss = 0.088115
val_loss = 0.119524
val_loss = 0.172531


Validation: 0it [00:00, ?it/s]

val_loss = 0.108545
val_loss = 0.066147
val_loss = 0.041833
val_loss = 0.067315
val_loss = 0.088016
val_loss = 0.119564
val_loss = 0.172643
val_loss = 0.108415
val_loss = 0.065843
val_loss = 0.041544
val_loss = 0.067063
val_loss = 0.087339
val_loss = 0.119852
val_loss = 0.173421


Validation: 0it [00:00, ?it/s]

val_loss = 0.108396
val_loss = 0.065800
val_loss = 0.041503
val_loss = 0.067028
val_loss = 0.087243
val_loss = 0.119894
val_loss = 0.173532


Validation: 0it [00:00, ?it/s]

val_loss = 0.108378
val_loss = 0.065758
val_loss = 0.041463
val_loss = 0.066992
val_loss = 0.087148
val_loss = 0.119936
val_loss = 0.173643


Validation: 0it [00:00, ?it/s]

val_loss = 0.108223
val_loss = 0.065383
val_loss = 0.041109
val_loss = 0.066684
val_loss = 0.086314
val_loss = 0.120315
val_loss = 0.174642


Validation: 0it [00:00, ?it/s]

val_loss = 0.108206
val_loss = 0.065343
val_loss = 0.041070
val_loss = 0.066651
val_loss = 0.086223
val_loss = 0.120358
val_loss = 0.174753


Validation: 0it [00:00, ?it/s]

val_loss = 0.108190
val_loss = 0.065302
val_loss = 0.041032
val_loss = 0.066618
val_loss = 0.086133
val_loss = 0.120400
val_loss = 0.174864


Validation: 0it [00:00, ?it/s]

val_loss = 0.108034
val_loss = 0.064907
val_loss = 0.040660
val_loss = 0.066296
val_loss = 0.085253
val_loss = 0.120832
val_loss = 0.175973


Validation: 0it [00:00, ?it/s]

val_loss = 0.108019
val_loss = 0.064869
val_loss = 0.040624
val_loss = 0.066265
val_loss = 0.085168
val_loss = 0.120876
val_loss = 0.176084


Validation: 0it [00:00, ?it/s]

val_loss = 0.108005
val_loss = 0.064831
val_loss = 0.040588
val_loss = 0.066234
val_loss = 0.085082
val_loss = 0.120919
val_loss = 0.176194


Validation: 0it [00:00, ?it/s]

val_loss = 0.107843
val_loss = 0.064385
val_loss = 0.040168
val_loss = 0.065875
val_loss = 0.084086
val_loss = 0.121448
val_loss = 0.177519


Validation: 0it [00:00, ?it/s]

val_loss = 0.107830
val_loss = 0.064349
val_loss = 0.040134
val_loss = 0.065846
val_loss = 0.084005
val_loss = 0.121492
val_loss = 0.177629


Validation: 0it [00:00, ?it/s]

val_loss = 0.107817
val_loss = 0.064313
val_loss = 0.040100
val_loss = 0.065817
val_loss = 0.083925
val_loss = 0.121537
val_loss = 0.177740


Validation: 0it [00:00, ?it/s]

val_loss = 0.107698
val_loss = 0.063960
val_loss = 0.039770
val_loss = 0.065537
val_loss = 0.083138
val_loss = 0.121986
val_loss = 0.178841


Validation: 0it [00:00, ?it/s]

val_loss = 0.107687
val_loss = 0.063926
val_loss = 0.039738
val_loss = 0.065509
val_loss = 0.083061
val_loss = 0.122031
val_loss = 0.178951


Validation: 0it [00:00, ?it/s]

val_loss = 0.107676
val_loss = 0.063891
val_loss = 0.039705
val_loss = 0.065482
val_loss = 0.082985
val_loss = 0.122076
val_loss = 0.179061


Validation: 0it [00:00, ?it/s]

val_loss = 0.107579
val_loss = 0.063588
val_loss = 0.039423
val_loss = 0.065244
val_loss = 0.082309
val_loss = 0.122487
val_loss = 0.180051


Validation: 0it [00:00, ?it/s]

val_loss = 0.107568
val_loss = 0.063555
val_loss = 0.039392
val_loss = 0.065218
val_loss = 0.082235
val_loss = 0.122533
val_loss = 0.180161


Validation: 0it [00:00, ?it/s]

val_loss = 0.107557
val_loss = 0.063522
val_loss = 0.039361
val_loss = 0.065192
val_loss = 0.082162
val_loss = 0.122579
val_loss = 0.180270


Validation: 0it [00:00, ?it/s]

val_loss = 0.107466
val_loss = 0.063230
val_loss = 0.039091
val_loss = 0.064965
val_loss = 0.081514
val_loss = 0.122993
val_loss = 0.181253


Validation: 0it [00:00, ?it/s]

val_loss = 0.107461
val_loss = 0.063205
val_loss = 0.039063
val_loss = 0.064943
val_loss = 0.081456
val_loss = 0.123038
val_loss = 0.181366


Validation: 0it [00:00, ?it/s]

val_loss = 0.107455
val_loss = 0.063180
val_loss = 0.039036
val_loss = 0.064922
val_loss = 0.081398
val_loss = 0.123083
val_loss = 0.181479


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

val_loss = 0.107406
val_loss = 0.062982
val_loss = 0.038821
val_loss = 0.064759
val_loss = 0.080913
val_loss = 0.123444
val_loss = 0.182386


Validation: 0it [00:00, ?it/s]

val_loss = 0.107400
val_loss = 0.062956
val_loss = 0.038795
val_loss = 0.064739
val_loss = 0.080849
val_loss = 0.123490
val_loss = 0.182499


Validation: 0it [00:00, ?it/s]

val_loss = 0.107393
val_loss = 0.062930
val_loss = 0.038768
val_loss = 0.064719
val_loss = 0.080785
val_loss = 0.123535
val_loss = 0.182612


Validation: 0it [00:00, ?it/s]

val_loss = 0.107289
val_loss = 0.062521
val_loss = 0.038359
val_loss = 0.064402
val_loss = 0.079766
val_loss = 0.124267
val_loss = 0.184390


Validation: 0it [00:00, ?it/s]

val_loss = 0.107282
val_loss = 0.062495
val_loss = 0.038334
val_loss = 0.064383
val_loss = 0.079703
val_loss = 0.124313
val_loss = 0.184500


Validation: 0it [00:00, ?it/s]

val_loss = 0.107276
val_loss = 0.062470
val_loss = 0.038310
val_loss = 0.064364
val_loss = 0.079640
val_loss = 0.124359
val_loss = 0.184610


Validation: 0it [00:00, ?it/s]

val_loss = 0.107270
val_loss = 0.062444
val_loss = 0.038285
val_loss = 0.064344
val_loss = 0.079577
val_loss = 0.124404
val_loss = 0.184719
val_loss = 0.107184
val_loss = 0.062091
val_loss = 0.037947
val_loss = 0.064078
val_loss = 0.078715
val_loss = 0.125044
val_loss = 0.186230


Validation: 0it [00:00, ?it/s]

val_loss = 0.107179
val_loss = 0.062066
val_loss = 0.037924
val_loss = 0.064059
val_loss = 0.078654
val_loss = 0.125089
val_loss = 0.186337


Validation: 0it [00:00, ?it/s]

val_loss = 0.107173
val_loss = 0.062041
val_loss = 0.037900
val_loss = 0.064040
val_loss = 0.078594
val_loss = 0.125135
val_loss = 0.186444


Validation: 0it [00:00, ?it/s]

val_loss = 0.107117
val_loss = 0.061771
val_loss = 0.037647
val_loss = 0.063839
val_loss = 0.077944
val_loss = 0.125637
val_loss = 0.187609


Validation: 0it [00:00, ?it/s]

val_loss = 0.107112
val_loss = 0.061747
val_loss = 0.037624
val_loss = 0.063821
val_loss = 0.077886
val_loss = 0.125683
val_loss = 0.187714


Validation: 0it [00:00, ?it/s]

val_loss = 0.107107
val_loss = 0.061722
val_loss = 0.037601
val_loss = 0.063803
val_loss = 0.077828
val_loss = 0.125729
val_loss = 0.187819


Validation: 0it [00:00, ?it/s]

val_loss = 0.107069
val_loss = 0.061531
val_loss = 0.037424
val_loss = 0.063662
val_loss = 0.077373
val_loss = 0.126095
val_loss = 0.188657


Validation: 0it [00:00, ?it/s]

val_loss = 0.107065
val_loss = 0.061507
val_loss = 0.037402
val_loss = 0.063644
val_loss = 0.077317
val_loss = 0.126140
val_loss = 0.188761


Validation: 0it [00:00, ?it/s]

val_loss = 0.107061
val_loss = 0.061483
val_loss = 0.037380
val_loss = 0.063627
val_loss = 0.077261
val_loss = 0.126186
val_loss = 0.188865


Validation: 0it [00:00, ?it/s]

val_loss = 0.106995
val_loss = 0.061112
val_loss = 0.037040
val_loss = 0.063356
val_loss = 0.076391
val_loss = 0.126917
val_loss = 0.190512


Validation: 0it [00:00, ?it/s]

val_loss = 0.106992
val_loss = 0.061089
val_loss = 0.037019
val_loss = 0.063339
val_loss = 0.076338
val_loss = 0.126963
val_loss = 0.190614


Validation: 0it [00:00, ?it/s]

val_loss = 0.106988
val_loss = 0.061066
val_loss = 0.036999
val_loss = 0.063323
val_loss = 0.076286
val_loss = 0.127008
val_loss = 0.190716


Validation: 0it [00:00, ?it/s]

val_loss = 0.106984
val_loss = 0.061044
val_loss = 0.036978
val_loss = 0.063307
val_loss = 0.076234
val_loss = 0.127054
val_loss = 0.190817
val_loss = 0.106948
val_loss = 0.060819
val_loss = 0.036776
val_loss = 0.063146
val_loss = 0.075718
val_loss = 0.127512
val_loss = 0.191831


Validation: 0it [00:00, ?it/s]

val_loss = 0.106945
val_loss = 0.060797
val_loss = 0.036756
val_loss = 0.063130
val_loss = 0.075667
val_loss = 0.127557
val_loss = 0.191932


Validation: 0it [00:00, ?it/s]

val_loss = 0.106942
val_loss = 0.060775
val_loss = 0.036737
val_loss = 0.063114
val_loss = 0.075617
val_loss = 0.127603
val_loss = 0.192033


Validation: 0it [00:00, ?it/s]

val_loss = 0.106914
val_loss = 0.060579
val_loss = 0.036561
val_loss = 0.062975
val_loss = 0.075171
val_loss = 0.128014
val_loss = 0.192934


Validation: 0it [00:00, ?it/s]

val_loss = 0.106911
val_loss = 0.060557
val_loss = 0.036542
val_loss = 0.062959
val_loss = 0.075122
val_loss = 0.128060
val_loss = 0.193034


Validation: 0it [00:00, ?it/s]

val_loss = 0.106909
val_loss = 0.060536
val_loss = 0.036523
val_loss = 0.062944
val_loss = 0.075074
val_loss = 0.128106
val_loss = 0.193133


Validation: 0it [00:00, ?it/s]

val_loss = 0.106906
val_loss = 0.060515
val_loss = 0.036504
val_loss = 0.062929
val_loss = 0.075025
val_loss = 0.128151
val_loss = 0.193233
val_loss = 0.106888
val_loss = 0.060367
val_loss = 0.036372
val_loss = 0.062825
val_loss = 0.074691
val_loss = 0.128471
val_loss = 0.193929


Validation: 0it [00:00, ?it/s]

### `forward()`

The `forward()` method applies the neural network to the provided feature vectors. The data first passes through `layer_1`, then through the ReLU activation, and the activated values then pass through `layer_2`, which returns the raw output (logit) of the network.
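
To make this data flow concrete, here is a minimal sketch in plain Python (no PyTorch) of the same sequence: a linear layer, ReLU, and a second linear layer. The two-input, two-hidden-unit shape and the hand-picked weights are assumptions chosen for illustration; they are not the trained model's parameters.

```python
def relu(values):
    """Element-wise ReLU: negative entries become 0."""
    return [max(0.0, v) for v in values]


def linear(inputs, weights, biases):
    """Fully connected layer: one weighted sum plus bias per output unit."""
    return [
        sum(w * x for w, x in zip(row, inputs)) + b
        for row, b in zip(weights, biases)
    ]


# Hypothetical, hand-picked weights (not taken from the notebook's model)
W1, B1 = [[1.0, 1.0], [1.0, 1.0]], [0.0, -1.0]  # layer_1: 2 inputs -> 2 hidden units
W2, B2 = [[2.0, -4.0]], [-1.0]                  # layer_2: 2 hidden units -> 1 logit


def forward(x):
    hidden = relu(linear(x, W1, B1))  # layer_1 followed by ReLU
    return linear(hidden, W2, B2)[0]  # layer_2 returns the raw logit


for x in ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]):
    print(x, forward(x))
```

With these weights the logit is positive only when exactly one input is 1, which is the XOR pattern. Without the ReLU between the layers, the two linear layers would collapse into a single linear map, which cannot separate XOR.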

In [16]:
from sklearn.metrics import accuracy_score

# forward pass on the test set without tracking gradients
with torch.inference_mode(): 
    predictions = model(X_test)

# logits -> probabilities (sigmoid) -> hard 0/1 labels (round)
predictions = torch.squeeze(torch.round(torch.sigmoid(predictions)))
predictions = predictions.detach().numpy()

print("test set accuracy: {}".format(accuracy_score(y_test,predictions)))

## Training the model

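We use `BCEWithLogitsLoss` as the loss function and stochastic gradient descent, `torch.optim.SGD(params, lr)`, as the optimizer. `BCEWithLogitsLoss` applies the sigmoid internally, so the model must output raw logits rather than probabilities.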

In [26]:
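learning_rate = 0.005

# the loss function: binary cross-entropy computed on raw logits
loss_func = torch.nn.BCEWithLogitsLoss()

# the optimizer: updates every model parameter with step size learning_rate
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)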
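
As a rough illustration of what these two objects compute, the sketch below reimplements both in plain Python. It assumes the default settings (mean reduction for the loss, no momentum or weight decay for SGD); the function names `bce_with_logits` and `sgd_step` are hypothetical helpers, not PyTorch APIs.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits."""
    total = 0.0
    for z, y in zip(logits, targets):
        # Numerically stable form of
        # -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def sgd_step(params, grads, lr):
    """Plain SGD update: move each parameter against its gradient."""
    return [p - lr * g for p, g in zip(params, grads)]


# A logit of 0 means probability 0.5, so the loss is log(2) for either label
print(bce_with_logits([0.0], [1.0]))

# One update with the learning rate used above
print(sgd_step([1.0, -2.0], [0.5, -1.0], lr=0.005))
```

Working on logits instead of probabilities avoids taking the logarithm of a value that has already been rounded to 0 or 1 by the sigmoid.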

Now we can create and run our training and validation loop.



In [27]:
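import matplotlib.pyplot as plt

# number of times we iterate through the training set (one full-batch pass per epoch)
num_epochs = 8000

for epoch in range(num_epochs):

    # step 1: forward pass, dropping the size-1 dimensions of the logits
    predictions_train = torch.squeeze(model(X_train))

    # step 2: compute the loss on the training set
    loss = loss_func(predictions_train, y_train)

    # step 3: reset the gradients left over from the previous iteration
    optimizer.zero_grad()

    # step 4: backpropagation, compute the gradient of the loss
    loss.backward()

    # step 5: update the model parameters
    optimizer.step()

    # every 500 epochs: report the training loss and the validation accuracy
    if epoch % 500 == 0:
        print("training loss: {}".format(loss.item()))
        model.eval()
        with torch.inference_mode():
            predictions_val = torch.squeeze(torch.round(torch.sigmoid(model(X_val)))).detach().numpy()
            print("validation accuracy: {}".format(accuracy_score(y_val, predictions_val)))
        model.train()
        plot_decision_boundary(model, X_train, y_train)
        plt.show()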
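
The five steps of the loop can also be written out by hand. The sketch below does so in plain Python for a one-feature logistic model, with the gradient derived analytically instead of by `loss.backward()`. The toy data, the learning rate of 0.1, and the 200 epochs are assumptions made for illustration only.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical toy data: the label is 1 when the feature is positive
xs = [-2.0, -1.0, 1.0, 2.0]
ys = [0.0, 0.0, 1.0, 1.0]

w, b = 0.0, 0.0  # model parameters
lr = 0.1
losses = []

for epoch in range(200):
    # step 1: forward pass, one logit per sample
    logits = [w * x + b for x in xs]

    # step 2: mean binary cross-entropy on the logits
    loss = sum(
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, ys)
    ) / len(xs)
    losses.append(loss)

    # step 3: reset the gradients, otherwise they would keep accumulating
    grad_w, grad_b = 0.0, 0.0

    # step 4: backward pass, d(loss)/d(logit) = (sigmoid(z) - y) / n
    for x, z, y in zip(xs, logits, ys):
        d = (sigmoid(z) - y) / len(xs)
        grad_w += d * x
        grad_b += d

    # step 5: gradient descent update
    w -= lr * grad_w
    b -= lr * grad_b

print("first loss:", losses[0], "last loss:", losses[-1])
```

Step 3 matters in PyTorch because `backward()` adds to the existing `.grad` values instead of overwriting them; skipping `zero_grad()` would mix gradients from earlier iterations into each update.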


## Computing predictions and evaluating the model


In [None]:
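# switch the model to evaluation mode
model.eval()

# forward pass on the test set without tracking gradients
with torch.inference_mode(): 
    predictions_test = model(X_test)

# logits -> probabilities (sigmoid) -> hard 0/1 labels (round)
predictions_test = torch.round(torch.sigmoid(torch.squeeze(predictions_test))).detach().numpy()

print("test set accuracy: {}".format(accuracy_score(y_test,predictions_test)))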
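
Rounding the sigmoid output is the same as thresholding the logit at zero, which the following plain-Python sketch illustrates. The logits and labels are made-up values, and accuracy is computed by hand rather than with `accuracy_score`.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical logits and true labels for four samples
logits = [-2.3, 0.4, 1.7, -0.1]
labels = [0, 1, 0, 0]

# Route used in the cell above: sigmoid, then round to the nearest integer
preds_via_sigmoid = [round(sigmoid(z)) for z in logits]

# Equivalent shortcut: a probability above 0.5 means a positive logit
preds_via_sign = [1 if z > 0 else 0 for z in logits]

# Accuracy is the fraction of predictions that match the labels
accuracy = sum(p == t for p, t in zip(preds_via_sigmoid, labels)) / len(labels)

print(preds_via_sigmoid, preds_via_sign, accuracy)
```

Both Python's `round` and `torch.round` round halves to the nearest even integer, so a probability of exactly 0.5 is assigned to class 0 in either version.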

In [None]:
plot_decision_boundary(model, X_test, y_test)