# A regression model for the California Housing Prices dataset

To replicate the experiments conducted in the paper, we tried to reproduce the model used by the authors as closely as possible:
"We used a 8:2 train-test split and trained a regression model with 3 hidden layers with 168K parameters,
using Adam optimizer minimizing MSE for 200 epochs"

1) Import the necessary libraries and the California Housing dataset fetching function:

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from tqdm import tqdm
from sklearn.datasets import fetch_california_housing
from torchmetrics import ExplainedVariance

2) Set up a helper that saves checkpoints in a dedicated folder, so they can be reused in the implementation of TracInCP and wherever else they are needed:

In [3]:
def savecheckpoint(checkpoint, filename="chpcheckpoint.pth.tar", dir_name="batch"):
    # Save a checkpoint (e.g. a model state_dict) to ./<dir_name>/<filename>
    print("=> saving checkpoint")
    ph = f"./{dir_name}/{filename}"
    torch.save(checkpoint, ph)
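
As a rough sketch of what ends up on disk, the snippet below lists the paths that `savecheckpoint` would write, assuming the file naming and schedule used in the training cell of this notebook (one checkpoint at the start of every 10th epoch, 200 epochs, batch size 8). `checkpoint_paths` is a hypothetical helper introduced here for illustration only; it does not touch the file system.

```python
def checkpoint_paths(num_epochs=200, every=10, batch_size=8):
    """List the files written when a checkpoint is saved at the start of every `every`-th epoch."""
    dir_name = f"CHP_checkpoints/batch_{batch_size}"
    return [
        f"./{dir_name}/{epoch}chpcheckpoint.pth.tar"
        for epoch in range(num_epochs)
        if epoch % every == 0
    ]

paths = checkpoint_paths()
print(len(paths))   # number of checkpoints
print(paths[0])     # saved before any training step
print(paths[-1])    # saved after 190 completed epochs
```

Under these assumptions the first checkpoint holds the randomly initialized weights and the last one the weights after 190 epochs; the weights after the final epoch are not saved by this schedule.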

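3) Definition of the model: to approximate the architecture described in the paper, a fully connected network is defined with three linear layers (two hidden layers of 400 units followed by a single-output layer), for a total of about 164K parameters, close to the 168K reported by the authors.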

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(8, 400)
        self.fc2 = nn.Linear(400, 400)
        self.fc3 = nn.Linear(400, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of the network
model = Net()

# Count the total number of parameters in the network
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters:", total_params)

Total Parameters: 164401
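
The reported total can be checked by hand: each `nn.Linear(n_in, n_out)` layer holds an `n_in x n_out` weight matrix plus `n_out` biases. The sketch below redoes that arithmetic in plain Python; `linear_params` is a hypothetical helper name, and the layer sizes are copied from `Net` above.

```python
def linear_params(n_in, n_out):
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

layer_sizes = [(8, 400), (400, 400), (400, 1)]  # fc1, fc2, fc3
per_layer = [linear_params(n_in, n_out) for n_in, n_out in layer_sizes]
total = sum(per_layer)

print(per_layer)  # parameters of fc1, fc2, fc3
print(total)      # should match the count printed by PyTorch
```

Almost all of the parameters sit in the 400 x 400 layer, so the hidden width is the main knob for moving the total closer to the paper's figure.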


4) The dataset is fetched and separated into features X and labels y:

In [4]:
data = fetch_california_housing()
X = data.data.astype(np.float32)
y = data.target

5) The data is then processed further: the features are first standardized with a `StandardScaler` and then divided into train and test sets. The train/test ratio is set to 8:2, as in the paper. Note that the scaler is fit on the full dataset before the split.

In [5]:
# Standardize the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Split the data into training and testing sets
X_train, X_test, Y_train, Y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
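
To make the 8:2 split concrete, the sketch below works out the resulting set sizes and the number of mini-batches per epoch. It assumes the dataset has 20,640 samples (the usual size of the California Housing dataset), that the test size is rounded up when it is not an integer, and the batch size of 8 and 200 epochs used later in this notebook. `ceil_div` is a hypothetical helper.

```python
def ceil_div(a, b):
    """Integer ceiling division, avoiding floating-point rounding."""
    return -(-a // b)

n_samples = 20640   # assumed dataset size
batch_size = 8      # value used later in this notebook
num_epochs = 200

n_test = ceil_div(n_samples * 2, 10)   # test_size=0.2 written as the fraction 2/10
n_train = n_samples - n_test

train_batches = ceil_div(n_train, batch_size)
test_batches = ceil_div(n_test, batch_size)
total_steps = train_batches * num_epochs   # optimizer updates over the whole run

print(n_train, n_test)
print(train_batches, test_batches)
print(total_steps)
```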

6) The split sets are converted into tensors and batched. The model is moved to CUDA if it is available; otherwise it runs on the CPU.

In [6]:
# Convert numpy arrays to PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float)
X_test = torch.tensor(X_test, dtype=torch.float)
Y_train = torch.tensor(Y_train, dtype=torch.float).view(-1, 1)
Y_test = torch.tensor(Y_test, dtype=torch.float).view(-1, 1)

datasets = torch.utils.data.TensorDataset(X_train, Y_train)
datasets_test = torch.utils.data.TensorDataset(X_test, Y_test)

batch_size = 8

train_iter = torch.utils.data.DataLoader(datasets, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(datasets_test, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

# Create an instance of the model
model = Net()
model.to(device)

cuda


Net(
  (fc1): Linear(in_features=8, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=400, bias=True)
  (fc3): Linear(in_features=400, out_features=1, bias=True)
  (relu): ReLU()
)

7) Training: to follow the paper's procedure here as well, the Adam optimizer is used to minimize the MSE, and the network is trained for 200 epochs.

In [7]:
import os
if not os.path.exists("CHP_checkpoints/batch_"+str(batch_size)):
    os.makedirs("CHP_checkpoints/batch_"+str(batch_size))

# Define the loss function and optimizer

loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_list = []

num_epochs = 200

for epoch in tqdm(range(num_epochs)):
    if epoch % 10 == 0:
        savecheckpoint(model.state_dict(), str(epoch) + "chpcheckpoint.pth.tar", "CHP_checkpoints/batch_"+str(batch_size))
    for x, y in train_iter:
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        l = loss(output, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    # Note: `l` is the loss of the last mini-batch of the epoch, not the epoch average
    loss_list.append(l.item())
    print("epoch {} loss: {:.4f}".format(epoch + 1, l.item()))

=> saving checkpoint
epoch 1 loss: 0.1048
epoch 2 loss: 0.1644
epoch 3 loss: 0.3006
epoch 4 loss: 0.7569
epoch 5 loss: 0.1401
epoch 6 loss: 0.2148
epoch 7 loss: 0.2315
epoch 8 loss: 0.4091
epoch 9 loss: 0.2213
epoch 10 loss: 0.1291
=> saving checkpoint
epoch 11 loss: 0.4148
epoch 12 loss: 0.0967
epoch 13 loss: 0.7460
epoch 14 loss: 0.2547
epoch 15 loss: 0.1355
epoch 16 loss: 0.0724
epoch 17 loss: 0.6862
epoch 18 loss: 0.0944
epoch 19 loss: 0.6030
epoch 20 loss: 0.2041
=> saving checkpoint
epoch 21 loss: 0.0824
epoch 22 loss: 0.1423
epoch 23 loss: 0.1628
epoch 24 loss: 0.0742
epoch 25 loss: 0.0221
epoch 26 loss: 0.0620
epoch 27 loss: 0.0769
epoch 28 loss: 0.0748
epoch 29 loss: 0.0733
epoch 30 loss: 0.2195
=> saving checkpoint
epoch 31 loss: 0.1946
epoch 32 loss: 0.0573
epoch 33 loss: 0.1784
epoch 34 loss: 0.1377
epoch 35 loss: 0.0884
epoch 36 loss: 0.2047
epoch 37 loss: 0.0269
epoch 38 loss: 0.1016
epoch 39 loss: 0.0591
epoch 40 loss: 0.1833
=> saving checkpoint
epoch 41 loss: 0.0586
epoch 42 loss: 0.3670
epoch 43 loss: 1.1081
epoch 44 loss: 0.4695
epoch 45 loss: 0.1252
epoch 46 loss: 0.2074
epoch 47 loss: 0.0371
epoch 48 loss: 0.4170
epoch 49 loss: 0.0498
epoch 50 loss: 0.5287
=> saving checkpoint
epoch 51 loss: 0.0562
epoch 52 loss: 0.2288
epoch 53 loss: 0.0204
epoch 54 loss: 0.2253
epoch 55 loss: 0.0693
epoch 56 loss: 0.1608
epoch 57 loss: 0.3940
epoch 58 loss: 0.3277
epoch 59 loss: 0.0516
epoch 60 loss: 0.0242
=> saving checkpoint
epoch 61 loss: 0.1165
epoch 62 loss: 0.2623
epoch 63 loss: 0.0239
epoch 64 loss: 0.0288
epoch 65 loss: 0.3158
epoch 66 loss: 0.0587
epoch 67 loss: 0.4682
epoch 68 loss: 0.0763
epoch 69 loss: 0.7741
epoch 70 loss: 0.0660
=> saving checkpoint
epoch 71 loss: 0.0566
epoch 72 loss: 0.1474
epoch 73 loss: 0.1050
epoch 74 loss: 0.0856
epoch 75 loss: 0.0661
epoch 76 loss: 0.0400
epoch 77 loss: 0.1098
epoch 78 loss: 0.1124
epoch 79 loss: 0.0647
epoch 80 loss: 0.1891
=> saving checkpoint
epoch 81 loss: 0.1151
epoch 82 loss: 0.1531
epoch 83 loss: 0.0344
epoch 84 loss: 0.0555
epoch 85 loss: 0.0833
epoch 86 loss: 0.1621
epoch 87 loss: 0.0986
epoch 88 loss: 0.0629
epoch 89 loss: 0.2165
epoch 90 loss: 0.1361
=> saving checkpoint
epoch 91 loss: 0.3768
epoch 92 loss: 0.3322
epoch 93 loss: 0.2624
epoch 94 loss: 0.1677
epoch 95 loss: 0.1472
epoch 96 loss: 0.2296
epoch 97 loss: 0.1486
epoch 98 loss: 0.2623
epoch 99 loss: 0.1138
epoch 100 loss: 0.1055
=> saving checkpoint
epoch 101 loss: 0.2290
epoch 102 loss: 0.0348
epoch 103 loss: 0.3343
epoch 104 loss: 0.1162
epoch 105 loss: 0.0889
epoch 106 loss: 0.0408
epoch 107 loss: 0.0404
epoch 108 loss: 0.0553
epoch 109 loss: 0.0832
epoch 110 loss: 0.1016
=> saving checkpoint
epoch 111 loss: 0.0102
epoch 112 loss: 0.0734
epoch 113 loss: 0.5250
epoch 114 loss: 0.1473
epoch 115 loss: 0.0466
epoch 116 loss: 0.1033
epoch 117 loss: 0.0284
epoch 118 loss: 0.0339
epoch 119 loss: 0.1645
epoch 120 loss: 0.1061
=> saving checkpoint
epoch 121 loss: 0.0776
epoch 122 loss: 0.0855
epoch 123 loss: 0.0682
epoch 124 loss: 0.0660
epoch 125 loss: 0.0084
epoch 126 loss: 0.2039
epoch 127 loss: 0.1299
epoch 128 loss: 0.0239
epoch 129 loss: 0.0420
epoch 130 loss: 0.1459
=> saving checkpoint
epoch 131 loss: 0.0155
epoch 132 loss: 0.0958
epoch 133 loss: 0.0410
epoch 134 loss: 0.1030
epoch 135 loss: 0.1279
epoch 136 loss: 0.2852
epoch 137 loss: 0.0644
epoch 138 loss: 0.2878
epoch 139 loss: 0.0466
epoch 140 loss: 0.5624
=> saving checkpoint
epoch 141 loss: 0.0447
epoch 142 loss: 0.0204
epoch 143 loss: 0.0593
epoch 144 loss: 0.0962
epoch 145 loss: 0.2807
epoch 146 loss: 0.0611
epoch 147 loss: 0.3791
epoch 148 loss: 0.0944
epoch 149 loss: 0.2204
epoch 150 loss: 0.0570
=> saving checkpoint
epoch 151 loss: 0.0998
epoch 152 loss: 0.1915
epoch 153 loss: 0.0454
epoch 154 loss: 0.0368
epoch 155 loss: 0.0911
epoch 156 loss: 0.1720
epoch 157 loss: 0.0136
epoch 158 loss: 0.0381
epoch 159 loss: 0.1145
epoch 160 loss: 0.0728
=> saving checkpoint
epoch 161 loss: 0.0740
epoch 162 loss: 0.0555
epoch 163 loss: 0.0486
epoch 164 loss: 0.0248
epoch 165 loss: 0.0667
epoch 166 loss: 0.0547
epoch 167 loss: 0.6867
epoch 168 loss: 0.0884
epoch 169 loss: 0.0636
epoch 170 loss: 0.0816
=> saving checkpoint
epoch 171 loss: 0.1034
epoch 172 loss: 0.1558
epoch 173 loss: 0.0590
epoch 174 loss: 0.0750
epoch 175 loss: 0.3733
epoch 176 loss: 0.0439
epoch 177 loss: 0.1281
epoch 178 loss: 0.0290
epoch 179 loss: 0.0423
epoch 180 loss: 0.0219
=> saving checkpoint
epoch 181 loss: 0.1593
epoch 182 loss: 0.2695
epoch 183 loss: 0.5447
epoch 184 loss: 0.2959
epoch 185 loss: 0.0901
epoch 186 loss: 0.1352
epoch 187 loss: 0.1027
epoch 188 loss: 0.2528
epoch 189 loss: 0.0680
epoch 190 loss: 0.0602
=> saving checkpoint
epoch 191 loss: 0.0402
epoch 192 loss: 0.2767
epoch 193 loss: 0.1397
epoch 194 loss: 0.0995
epoch 195 loss: 0.0111
epoch 196 loss: 0.0296
epoch 197 loss: 0.1141
epoch 198 loss: 0.0322
epoch 199 loss: 0.6919
epoch 200 loss: 0.1518
100%|█████████████████████████████████████████| 200/200 [09:45<00:00,  2.93s/it]
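
The values logged above are the loss of the last mini-batch of each epoch (the variable `l` once the inner loop has finished), which is one reason they jump around from epoch to epoch. A minimal sketch of an epoch-level average is shown below, assuming per-batch mean losses are available as plain floats. `EpochAverage` is a hypothetical helper, and the batch losses are made-up numbers, not values from this run.

```python
class EpochAverage:
    """Accumulate per-batch mean losses and report their sample-weighted mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_loss, batch_size):
        # A mean-reduced loss is an average over the batch, so weight it by the batch size
        self.total += batch_loss * batch_size
        self.count += batch_size

    def mean(self):
        return self.total / self.count

# Hypothetical per-batch losses for one epoch (illustrative only)
batch_losses = [0.30, 0.10, 0.50, 0.10]
avg = EpochAverage()
for batch_loss in batch_losses:
    avg.update(batch_loss, batch_size=8)

print("last batch:", batch_losses[-1])
print("epoch mean:", round(avg.mean(), 4))
```

In the training loop, the equivalent would be to call `update(l.item(), x.size(0))` after every step and log `mean()` once per epoch.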





8) Results: the model achieves an explained variance of 0.77 on the test set and 0.93 on the train set. The paper reports 0.72 and 0.70 respectively, so the test score is close to the paper's, while the train score is noticeably higher. Explained variance is defined as:
<p align = "center">
$$\text{Explained Variance} = 1 - \frac{\mathrm{Var}(y - \hat{y})}{\mathrm{Var}(y)}$$
</p>
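
The formula can be checked on a toy example. The sketch below is a plain-Python version, assuming population variance for both terms (the choice cancels out in the ratio as long as it is the same in numerator and denominator). It is not the torchmetrics implementation, only an illustration of the definition, and `variance` and `explained_variance` are hypothetical helper names.

```python
def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

def explained_variance(y_true, y_pred):
    """1 - Var(y - y_hat) / Var(y), following the formula above."""
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1.0 - variance(residuals) / variance(y_true)

y = [1.0, 2.0, 3.0]
perfect = explained_variance(y, [1.0, 2.0, 3.0])     # exact predictions
offset = explained_variance(y, [2.0, 3.0, 4.0])      # every prediction off by +1
mean_only = explained_variance(y, [2.0, 2.0, 2.0])   # always predicting the mean

print(perfect, offset, mean_only)
```

A constant offset leaves the residual variance at zero, so explained variance does not penalize a systematic bias; it is worth reading together with the MSE and MAE reported in the evaluation cell.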


In [27]:
# Evaluate the model
model.eval()
with torch.no_grad():
    X_test = X_test.to(device)
    Y_test = Y_test.to(device)
    predicted = model(X_test)
    mse = loss(predicted, Y_test)
    mae = torch.mean(torch.abs(predicted - Y_test))

    # Keep the metric on the same device as the predictions
    explained_variance = ExplainedVariance().to(device)
    exp_var_test = explained_variance(predicted, Y_test)

    X_train = X_train.to(device)
    Y_train = Y_train.to(device)
    predicted_train = model(X_train)

    exp_var_train = explained_variance(predicted_train, Y_train)

    print("Mean Squared Error: {:.4f}".format(mse))
    print("Mean Absolute Error: {:.4f}".format(mae))
    print("Explained Variance on test set: {:.4f}".format(exp_var_test))
    print("Explained Variance on train set: {:.4f}".format(exp_var_train))

Mean Squared Error: 0.3043
Mean Absolute Error: 0.3598
Explained Variance on test set: 0.7678
Explained Variance on train set: 0.9305
