# Exercise 2
Due:  Tue November 5, 8:00am

In [12]:
!pip install torch_geometric ogb torch-scatter



In [54]:
import torch
import torch_geometric as pyg
import numpy as np
import scipy
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder
import ogb
import torch_scatter
import copy

import time
import random
from tqdm import tqdm
import os

In this exercise, we use sparse message passing to make our networks scale to larger graphs. 


1) In this exercise we are working with the node-classification dataset Cora and the graph regression dataset ZINC. When working with a new dataset, it makes sense to at least quickly look into the data and some statistics for it. So for Cora: which is the second-biggest label class and what does it stand for? And for ZINC: how many HCO molecules (i.e. molecules consisting only of Hydrogen, Carbon, and Oxygen) are in the train set?

1) When working on the Cora dataset your model should at least reach an accuracy of 0.6 (an accuracy of 0.7-0.8 is well within reach).
Cora is a node classification dataset, so there is only one graph and we perform message passing on the whole graph (but evaluate the loss only on the nodes selected by cora_graph.train_mask).
The dataset is mostly balanced, so we evaluate the accuracy.
When implementing the message passing step, keep in mind that the graph does not contain self-loops (so one needs to somehow treat the "old" state).
Since Cora is small enough to be run with dense tensors too, you can verify your implementation this way.

2) ZINC is a small molecular regression dataset. Please compare the difference in performance between the (trainable) Atomencoder provided by ogb and the one-hot encoding you implemented in the first exercise.
Note that since you need batches, you need to modify the pooling layer to respect the batches.


In [14]:
# find device
if torch.cuda.is_available(): # NVIDIA
    device = torch.device('cuda')
elif torch.backends.mps.is_available(): # apple M1/M2
    device = torch.device('mps') 
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [15]:
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

set_seed()


Random seed set as 42


## Cora

In [16]:
cora = pyg.datasets.Planetoid(root = "dataset/cora", name="Cora")
cora_graph = cora[0]
cora_dense_adj = pyg.utils.to_dense_adj(cora_graph.edge_index).to(device)
# cora_graph.x = cora_graph.x.unsqueeze(0) # Add an empty batch dimension. I needed that for compatibility with MolHIV later.
cora_graph = cora_graph.to(device)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [17]:
torch.unique(cora_graph.y)

tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')

In [18]:
# dataloader with batch size 1
cora_loader = pyg.data.DataLoader([cora_graph], batch_size=1, shuffle=True)



In [19]:
def get_accuracy(model, cora_loader, mask):
    model.eval()
    with torch.no_grad():
        accs = []
        for data in cora_loader:
            outputs = model(data.x, data.edge_index, data.batch)
            correct = (outputs[mask].argmax(-1) == data.y[mask]).sum()
            accs.append(int(correct) / int(mask.sum()))
        return np.mean(accs)

In [116]:
class GCNLayer(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation=torch.nn.functional.relu,
        reduction="mean",
    ):
        super(GCNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.reduction = reduction
        self.U = torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(out_features, out_features, bias=False),
        )

        # Initialize weights
        torch.nn.init.kaiming_normal_(self.U[0].weight)
        torch.nn.init.kaiming_normal_(self.U[2].weight)

    def _add_remaining_self_loops(self, edge_index):
        num_nodes = edge_index.max().item() + 1
        mask = edge_index[0] != edge_index[1]
        loop_index = (
            torch.arange(num_nodes, device=edge_index.device).view(1, -1).repeat(2, 1)
        )
        edge_index = edge_index[:, mask]
        edge_index = torch.cat([edge_index, loop_index], dim=1)
        return edge_index

    def forward(self, H: torch.Tensor, edge_index: torch.Tensor):
        edge_index = self._add_remaining_self_loops(edge_index)
        Y = H[edge_index[0]]
        Z = pyg.utils.scatter(
            Y, edge_index[1], dim=0, dim_size=H.size(0), reduce=self.reduction
        )
#         print("H shape: ", H.shape)
#         print("Z shape: ", Z.shape)
        return self.activation(self.U(Z))

In [127]:
class GraphNet(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        num_layers=2,
        activation=torch.nn.functional.relu,
        dropout=0.1,
        reduction="mean",
        pooling="mean",
        task="node",
        feature_encoder=None,
    ):
        super(GraphNet, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.activation = activation
        self.dropout = dropout
        self.reduction = reduction
        self.pooling = pooling
        self.task = task
        self.feature_encoder = feature_encoder

        self.layers = torch.nn.ModuleList(
            [
                GCNLayer(
                    in_features if i == 0 else hidden_features,
                    (
                        hidden_features
                        if (i < num_layers - 1) # if not the last layer
                        else (out_features if task == "node" else hidden_features) # if last layer: output features if node task, else hidden features as we will add a MLP
                    ),
                    activation,
                    reduction,
                )
                for i in range(num_layers)
            ]
        )
        
        if self.task != "node":
            self.bns = torch.nn.ModuleList(
                [
                    torch.nn.BatchNorm1d(
                        hidden_features
                        if i < num_layers - 1
                        else (out_features if task == "node" else hidden_features)
                    )
                    for i in range(num_layers)
                ]
            )

            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(hidden_features, hidden_features),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout),
                torch.nn.Linear(hidden_features, out_features),
            )

    def forward(self, H: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor):
        if self.feature_encoder is not None:
            H = self.feature_encoder(H)
#         print("H shape:", H.shape)
        for i in range(len(self.layers)):
            H = self.layers[i](H, edge_index)
            if self.task != "node":
                H = self.bns[i](H)
                H = torch.nn.functional.dropout(H, p=self.dropout, training=self.training)
            
        
        if self.task != "node":
            # global pooling via scatter
            H = pyg.utils.scatter(
                H, batch, dim=0, dim_size=batch.max().item() + 1, reduce=self.pooling
            )
            H = self.mlp(H)
        return H

In [22]:
cora_num_classes = len(torch.unique(cora_graph.y))
model = GraphNet(
    in_features=cora_graph.num_node_features,
    out_features=cora_num_classes,
    hidden_features=32,
    num_layers=2,
    reduction="mean",
).to(device)

In [23]:
model

GraphNet(
  (layers): ModuleList(
    (0): GCNLayer(
      (U): Sequential(
        (0): Linear(in_features=1433, out_features=32, bias=False)
        (1): ReLU()
        (2): Linear(in_features=32, out_features=32, bias=False)
      )
    )
    (1): GCNLayer(
      (U): Sequential(
        (0): Linear(in_features=32, out_features=7, bias=False)
        (1): ReLU()
        (2): Linear(in_features=7, out_features=7, bias=False)
      )
    )
  )
)

In [24]:
# Training loop goes here
def train_cora(model, data_loader, epochs=200, lr=0.01, weight_decay=5e-4, device=device):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in tqdm(range(epochs)):
        val_accs = []
        train_accs = []
        for data in data_loader:
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            print(out.shape)
            loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()   
            train_accs.append(get_accuracy(model, data_loader, data.train_mask))
            val_accs.append(get_accuracy(model, data_loader, data.val_mask))
        print(f"Epoch {epoch}, Validation Accuracy: {np.mean(val_accs)}, Training Accuracy: {np.mean(train_accs)}")

In [25]:
train_cora(model, cora_loader, epochs=200, lr=0.01, weight_decay=5e-4, device=device)

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([2708, 7])


  4%|▍         | 8/200 [00:00<00:14, 13.22it/s]

Epoch 0, Validation Accuracy: 0.16, Training Accuracy: 0.2357142857142857
torch.Size([2708, 7])
Epoch 1, Validation Accuracy: 0.318, Training Accuracy: 0.45714285714285713
torch.Size([2708, 7])
Epoch 2, Validation Accuracy: 0.374, Training Accuracy: 0.5
torch.Size([2708, 7])
Epoch 3, Validation Accuracy: 0.37, Training Accuracy: 0.5857142857142857
torch.Size([2708, 7])
Epoch 4, Validation Accuracy: 0.416, Training Accuracy: 0.6357142857142857
torch.Size([2708, 7])
Epoch 5, Validation Accuracy: 0.47, Training Accuracy: 0.6642857142857143
torch.Size([2708, 7])
Epoch 6, Validation Accuracy: 0.5, Training Accuracy: 0.7357142857142858
torch.Size([2708, 7])
Epoch 7, Validation Accuracy: 0.592, Training Accuracy: 0.7857142857142857
torch.Size([2708, 7])
Epoch 8, Validation Accuracy: 0.636, Training Accuracy: 0.8214285714285714
torch.Size([2708, 7])
Epoch 9, Validation Accuracy: 0.654, Training Accuracy: 0.8285714285714286
torch.Size([2708, 7])
Epoch 10, Validation Accuracy: 0.676, Training Ac

 10%|█         | 21/200 [00:00<00:05, 32.51it/s]

Epoch 13, Validation Accuracy: 0.73, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 14, Validation Accuracy: 0.732, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 15, Validation Accuracy: 0.73, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 16, Validation Accuracy: 0.73, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 17, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 18, Validation Accuracy: 0.718, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 19, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 20, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 21, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 22, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 23, Validation Accu

 17%|█▋        | 34/200 [00:01<00:03, 44.95it/s]

Epoch 26, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 27, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 28, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 29, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 30, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 31, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 32, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 33, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 34, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 35, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 36, Validation Accu

 24%|██▍       | 48/200 [00:01<00:02, 54.85it/s]

Epoch 40, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 41, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 42, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 43, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 44, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 45, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 46, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 47, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 48, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 49, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 50, Validation Acc

 31%|███       | 62/200 [00:01<00:02, 61.71it/s]

Epoch 54, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 55, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 56, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 57, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 58, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 59, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 60, Validation Accuracy: 0.706, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 61, Validation Accuracy: 0.706, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 62, Validation Accuracy: 0.706, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 63, Validation Accuracy: 0.706, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 64, Validation 

 38%|███▊      | 77/200 [00:01<00:01, 65.56it/s]

torch.Size([2708, 7])
Epoch 69, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 70, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 71, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 72, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 73, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 74, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 75, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 76, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 77, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 78, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch

 46%|████▌     | 91/200 [00:02<00:01, 66.84it/s]

torch.Size([2708, 7])
Epoch 83, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 84, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 85, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 86, Validation Accuracy: 0.718, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 87, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 88, Validation Accuracy: 0.718, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 89, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 90, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 91, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 92, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Ep

 53%|█████▎    | 106/200 [00:02<00:01, 68.51it/s]

torch.Size([2708, 7])
Epoch 97, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 98, Validation Accuracy: 0.72, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 99, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 100, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 101, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 102, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 103, Validation Accuracy: 0.712, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 104, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 105, Validation Accuracy: 0.716, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 106, Validation Accuracy: 0.714, Training Accuracy: 0.8571428571428571
torch.Size([27

 60%|██████    | 120/200 [00:02<00:01, 68.76it/s]

Epoch 111, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 112, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 113, Validation Accuracy: 0.708, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 114, Validation Accuracy: 0.71, Training Accuracy: 0.8571428571428571
torch.Size([2708, 7])
Epoch 115, Validation Accuracy: 0.714, Training Accuracy: 0.8857142857142857
torch.Size([2708, 7])
Epoch 116, Validation Accuracy: 0.744, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 117, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 118, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 119, Validation Accuracy: 0.758, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 120, Validation Accuracy: 0.768, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 121, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Ep

 68%|██████▊   | 135/200 [00:02<00:00, 69.36it/s]

Epoch 125, Validation Accuracy: 0.694, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 126, Validation Accuracy: 0.714, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 127, Validation Accuracy: 0.746, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 128, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 129, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 130, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 131, Validation Accuracy: 0.758, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 132, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 133, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 134, Validation Accuracy: 0.746, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 135, Validation Accuracy: 0.748, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 136, Validation Accuracy: 0.75, Training Accuracy: 1.0
torch.Size([270

 74%|███████▍  | 149/200 [00:02<00:00, 69.46it/s]

torch.Size([2708, 7])
Epoch 140, Validation Accuracy: 0.73, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 141, Validation Accuracy: 0.732, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 142, Validation Accuracy: 0.73, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 143, Validation Accuracy: 0.738, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 144, Validation Accuracy: 0.742, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 145, Validation Accuracy: 0.744, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 146, Validation Accuracy: 0.744, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 147, Validation Accuracy: 0.746, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 148, Validation Accuracy: 0.74, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 149, Validation Accuracy: 0.746, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 150, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 151, Validation Accuracy: 0.76, Training Accuracy: 

 82%|████████▏ | 164/200 [00:03<00:00, 69.70it/s]

Epoch 154, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 155, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 156, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 157, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 158, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 159, Validation Accuracy: 0.75, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 160, Validation Accuracy: 0.75, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 161, Validation Accuracy: 0.748, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 162, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 163, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 164, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 165, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708

 90%|████████▉ | 179/200 [00:03<00:00, 70.14it/s]

Epoch 169, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 170, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 171, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 172, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 173, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 174, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 175, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 176, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 177, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 178, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 179, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 180, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([27

 98%|█████████▊| 195/200 [00:03<00:00, 70.67it/s]

torch.Size([2708, 7])
Epoch 184, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 185, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 186, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 187, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 188, Validation Accuracy: 0.756, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 189, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 190, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 191, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 192, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 193, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 194, Validation Accuracy: 0.752, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 195, Validation Accuracy: 0.752, Training Accura

100%|██████████| 200/200 [00:03<00:00, 55.43it/s]

Epoch 198, Validation Accuracy: 0.754, Training Accuracy: 1.0
torch.Size([2708, 7])
Epoch 199, Validation Accuracy: 0.754, Training Accuracy: 1.0





In [26]:
get_accuracy(model, cora_loader, cora_graph.test_mask)

0.791

In [27]:
class GCNLayerDense(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation=torch.nn.functional.relu,
        skip_connection=False,
    ):
        super(GCNLayerDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.skip_connection = skip_connection
        self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
        torch.nn.init.kaiming_normal_(self.weight)

    def forward(self, H: torch.Tensor, adj: torch.Tensor):
        h_x = torch.bmm(adj, torch.matmul(H, self.weight))
        if self.skip_connection:
            return self.activation(h_x + H)
        return self.activation(h_x)


class MeanPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...]):
        super(MeanPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        return H.mean(dim=self.dim)


class SumPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...]):
        super(SumPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        return H.sum(dim=self.dim)


class GraphGCNDense(torch.nn.Module):
    def __init__(
        self,
        num_layers: int,
        in_features: int,
        hidden_features: int,
        out_features: int,
        pooling: MeanPooling | SumPooling,
        activation=torch.nn.functional.relu,
        skip_connection: bool = False,
        # mlp_dropout_rate=0.1,
    ):
        super(GraphGCNDense, self).__init__()
        self.pooling = pooling
        self.activation = activation
        self.skip_connection = skip_connection
        self.layers = torch.nn.ModuleList(
            [
                GCNLayerDense(
                    in_features=in_features if i == 0 else hidden_features,
                    out_features=(
                        hidden_features if i < num_layers - 1 else out_features
                    ),
                    activation=activation,
                    skip_connection=skip_connection if (i != 0 and i != num_layers - 1) else False,
                )
                for i in range(num_layers)
            ]
        )
        # self.mlp = torch.nn.Sequential(
        #     torch.nn.Linear(hidden_features, hidden_features),
        #     torch.nn.ReLU(),
        #     torch.nn.Dropout(mlp_dropout_rate),
        #     torch.nn.Linear(hidden_features, out_features),
        # )

    def forward(self, H_in: torch.Tensor, adj: torch.Tensor):
        H = H_in
        for i in range(len(self.layers)):
            H = self.layers[i](H, adj)
        # H = self.pooling(H)
        # return self.mlp(H)
        return H


class GraphDataSetVectorized(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self._dataset = dataset
        self._largest_graph_size = int(dataset.get_summary().num_nodes.max)
        self.targets = self._dataset.data.y

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        graph = self._dataset[idx]
        # print(graph)

        # adjacency matrix
        A = torch.zeros((self._largest_graph_size, self._largest_graph_size))
        # symmetric
        A[graph.edge_index[0], graph.edge_index[1]] = 1
        # self loop
        A = A + torch.eye(self._largest_graph_size)
        # Degree matrix
        D = torch.diag(torch.sum(A, axis=1))
        # Normalized adjacency matrix
        d_inv_sqrt = torch.pow(D, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
        A_normalized = d_inv_sqrt @ A @ d_inv_sqrt

        # node features
        H = torch.zeros((self._largest_graph_size, graph.x.shape[1]))
        H[: graph.x.shape[0]] = graph.x

        # target
        target = graph.y
        return A_normalized, H, target, graph.train_mask, graph.val_mask, graph.test_mask

    def num_features(self):
        return self._dataset.num_features

    def compute_class_weights(self):
        class_counts = np.unique(self.targets, return_counts=True)[-1]
        frequencies = class_counts / len(self.targets)
        weights = np.round(1 / frequencies, 2)
        return torch.FloatTensor(weights / weights.sum()).to(device)

In [28]:
# Load the dataset
cora_dataset_vectorized = GraphDataSetVectorized(cora)
# print(len(cora_dataset_vectorized))
cora_loader_vectorized = torch.utils.data.DataLoader(
    cora_dataset_vectorized, batch_size=1, shuffle=True
)
# for A, H, target, train_mask, val_mask, test_mask in cora_loader_vectorized:
#     print(A.shape, H.shape, target.shape)
#     print(train_mask, val_mask, test_mask)
#     break

  std=data.std().item(),


In [29]:
# Train the model using GraphGCNDense
model_dense = GraphGCNDense(
    num_layers=2,
    in_features=cora_dataset_vectorized.num_features(),
    hidden_features=32,
    out_features=cora_num_classes,
    pooling=MeanPooling(dim=1),
    activation=torch.nn.functional.relu,
    skip_connection=True,
).to(device)
# print(model_dense)
optimizer_dense = torch.optim.Adam(model_dense.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn_dense = torch.nn.CrossEntropyLoss()
num_epochs = 200

for epoch in tqdm(range(num_epochs)):
    val_accs = []
    train_accs = []
    losses = []
    for adj, features, target, train_mask, val_mask, test_mask in cora_loader_vectorized:
        model_dense.train()
        adj, features, target = adj.to(device), features.to(device), target.to(device)
        train_mask, val_mask, test_mask = train_mask.to(device), val_mask.to(device), test_mask.to(device)
        model_dense.train()
        optimizer_dense.zero_grad()
        out = model_dense(features, adj)
         
        loss = loss_fn_dense(out[train_mask], target[train_mask])
        loss.backward()
        optimizer_dense.step()
        train_acc = out[train_mask].argmax(-1) == target[train_mask]
        train_accs.append(train_acc.float().sum() / train_mask.sum())
        val_acc = out[val_mask].argmax(-1) == target[val_mask]
        val_accs.append(val_acc.float().sum() / val_mask.sum())
        losses.append(loss.item())
    print(f"Epoch {epoch}, Validation Accuracy: {torch.mean(torch.tensor(val_accs))}, Training Accuracy: {torch.mean(torch.tensor(train_accs))}, Loss: {torch.mean(torch.tensor(losses))}")


  0%|          | 1/200 [00:00<01:32,  2.16it/s]

Epoch 0, Validation Accuracy: 0.20399999618530273, Training Accuracy: 0.2142857164144516, Loss: 2.1051907539367676


  1%|          | 2/200 [00:00<01:21,  2.43it/s]

Epoch 1, Validation Accuracy: 0.25999999046325684, Training Accuracy: 0.4000000059604645, Loss: 1.7094868421554565


  2%|▏         | 3/200 [00:01<01:18,  2.51it/s]

Epoch 2, Validation Accuracy: 0.3199999928474426, Training Accuracy: 0.5357142686843872, Loss: 1.4457322359085083


  2%|▏         | 4/200 [00:01<01:15,  2.60it/s]

Epoch 3, Validation Accuracy: 0.4020000100135803, Training Accuracy: 0.6571428775787354, Loss: 1.2366348505020142


  2%|▎         | 5/200 [00:01<01:15,  2.57it/s]

Epoch 4, Validation Accuracy: 0.46399998664855957, Training Accuracy: 0.7285714149475098, Loss: 1.054105281829834


  3%|▎         | 6/200 [00:02<01:13,  2.64it/s]

Epoch 5, Validation Accuracy: 0.5260000228881836, Training Accuracy: 0.7857142686843872, Loss: 0.8902105689048767


  4%|▎         | 7/200 [00:02<01:13,  2.61it/s]

Epoch 6, Validation Accuracy: 0.5580000281333923, Training Accuracy: 0.8357142806053162, Loss: 0.7465620040893555


  4%|▍         | 8/200 [00:03<01:13,  2.61it/s]

Epoch 7, Validation Accuracy: 0.6100000143051147, Training Accuracy: 0.8642857074737549, Loss: 0.6254264116287231


  4%|▍         | 9/200 [00:03<01:13,  2.59it/s]

Epoch 8, Validation Accuracy: 0.6439999938011169, Training Accuracy: 0.8714285492897034, Loss: 0.518104076385498


  5%|▌         | 10/200 [00:03<01:12,  2.61it/s]

Epoch 9, Validation Accuracy: 0.6579999923706055, Training Accuracy: 0.9071428775787354, Loss: 0.4279771149158478


  6%|▌         | 11/200 [00:04<01:11,  2.63it/s]

Epoch 10, Validation Accuracy: 0.6800000071525574, Training Accuracy: 0.9142857193946838, Loss: 0.3537556827068329


  6%|▌         | 12/200 [00:04<01:12,  2.58it/s]

Epoch 11, Validation Accuracy: 0.6940000057220459, Training Accuracy: 0.9357143044471741, Loss: 0.29500019550323486


  6%|▋         | 13/200 [00:05<01:12,  2.58it/s]

Epoch 12, Validation Accuracy: 0.7059999704360962, Training Accuracy: 0.9571428298950195, Loss: 0.24530982971191406


  7%|▋         | 14/200 [00:05<01:10,  2.64it/s]

Epoch 13, Validation Accuracy: 0.7260000109672546, Training Accuracy: 0.9642857313156128, Loss: 0.2018093466758728


  8%|▊         | 15/200 [00:05<01:11,  2.59it/s]

Epoch 14, Validation Accuracy: 0.7279999852180481, Training Accuracy: 0.9714285731315613, Loss: 0.16586828231811523


  8%|▊         | 16/200 [00:06<01:10,  2.61it/s]

Epoch 15, Validation Accuracy: 0.7379999756813049, Training Accuracy: 0.9714285731315613, Loss: 0.13916493952274323


  8%|▊         | 17/200 [00:06<01:08,  2.66it/s]

Epoch 16, Validation Accuracy: 0.7300000190734863, Training Accuracy: 0.9857142567634583, Loss: 0.1193307489156723


  9%|▉         | 18/200 [00:06<01:09,  2.60it/s]

Epoch 17, Validation Accuracy: 0.7319999933242798, Training Accuracy: 0.9785714149475098, Loss: 0.10271567851305008


 10%|▉         | 19/200 [00:07<01:10,  2.58it/s]

Epoch 18, Validation Accuracy: 0.7400000095367432, Training Accuracy: 0.9857142567634583, Loss: 0.08750458806753159


 10%|█         | 20/200 [00:07<01:10,  2.56it/s]

Epoch 19, Validation Accuracy: 0.75, Training Accuracy: 0.9928571581840515, Loss: 0.07377466559410095


 10%|█         | 21/200 [00:08<01:10,  2.55it/s]

Epoch 20, Validation Accuracy: 0.7580000162124634, Training Accuracy: 0.9928571581840515, Loss: 0.0625838190317154


 11%|█         | 22/200 [00:08<01:09,  2.56it/s]

Epoch 21, Validation Accuracy: 0.7620000243186951, Training Accuracy: 0.9928571581840515, Loss: 0.054152119904756546


 12%|█▏        | 23/200 [00:08<01:07,  2.63it/s]

Epoch 22, Validation Accuracy: 0.7580000162124634, Training Accuracy: 0.9928571581840515, Loss: 0.04772933945059776


 12%|█▏        | 24/200 [00:09<01:07,  2.61it/s]

Epoch 23, Validation Accuracy: 0.7559999823570251, Training Accuracy: 0.9928571581840515, Loss: 0.04270455241203308


 12%|█▎        | 25/200 [00:09<01:06,  2.62it/s]

Epoch 24, Validation Accuracy: 0.7620000243186951, Training Accuracy: 0.9928571581840515, Loss: 0.03862597420811653


 13%|█▎        | 26/200 [00:10<01:05,  2.66it/s]

Epoch 25, Validation Accuracy: 0.7599999904632568, Training Accuracy: 0.9928571581840515, Loss: 0.035225559026002884


 14%|█▎        | 27/200 [00:10<01:05,  2.65it/s]

Epoch 26, Validation Accuracy: 0.7599999904632568, Training Accuracy: 0.9928571581840515, Loss: 0.032518502324819565


 14%|█▍        | 28/200 [00:10<01:04,  2.67it/s]

Epoch 27, Validation Accuracy: 0.7639999985694885, Training Accuracy: 0.9928571581840515, Loss: 0.03031018376350403


 14%|█▍        | 29/200 [00:11<01:04,  2.64it/s]

Epoch 28, Validation Accuracy: 0.7620000243186951, Training Accuracy: 0.9928571581840515, Loss: 0.028512900695204735


 15%|█▌        | 30/200 [00:11<01:04,  2.64it/s]

Epoch 29, Validation Accuracy: 0.7620000243186951, Training Accuracy: 0.9928571581840515, Loss: 0.027101295068860054


 16%|█▌        | 31/200 [00:11<01:04,  2.64it/s]

Epoch 30, Validation Accuracy: 0.7599999904632568, Training Accuracy: 0.9928571581840515, Loss: 0.025918681174516678


 16%|█▌        | 32/200 [00:12<01:04,  2.62it/s]

Epoch 31, Validation Accuracy: 0.7620000243186951, Training Accuracy: 0.9928571581840515, Loss: 0.024785662069916725


 16%|█▋        | 33/200 [00:12<01:04,  2.60it/s]

Epoch 32, Validation Accuracy: 0.7580000162124634, Training Accuracy: 0.9928571581840515, Loss: 0.023718958720564842


 17%|█▋        | 34/200 [00:13<01:03,  2.60it/s]

Epoch 33, Validation Accuracy: 0.7559999823570251, Training Accuracy: 0.9928571581840515, Loss: 0.02275148034095764


 18%|█▊        | 35/200 [00:13<01:03,  2.61it/s]

Epoch 34, Validation Accuracy: 0.7580000162124634, Training Accuracy: 1.0, Loss: 0.02204231359064579


 18%|█▊        | 36/200 [00:13<01:03,  2.58it/s]

Epoch 35, Validation Accuracy: 0.7599999904632568, Training Accuracy: 1.0, Loss: 0.021552162244915962


 18%|█▊        | 37/200 [00:14<01:03,  2.57it/s]

Epoch 36, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.021189959719777107


 19%|█▉        | 38/200 [00:14<01:02,  2.59it/s]

Epoch 37, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.020943578332662582


 20%|█▉        | 39/200 [00:15<01:01,  2.60it/s]

Epoch 38, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.020786890760064125


 20%|██        | 40/200 [00:15<01:01,  2.59it/s]

Epoch 39, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.02069295011460781


 20%|██        | 41/200 [00:15<01:01,  2.59it/s]

Epoch 40, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.020635822787880898


 21%|██        | 42/200 [00:16<01:01,  2.58it/s]

Epoch 41, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.020597467198967934


 22%|██▏       | 43/200 [00:16<01:00,  2.58it/s]

Epoch 42, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.02056862786412239


 22%|██▏       | 44/200 [00:16<01:01,  2.54it/s]

Epoch 43, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.020549491047859192


 22%|██▎       | 45/200 [00:17<01:01,  2.50it/s]

Epoch 44, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.02054470032453537


 23%|██▎       | 46/200 [00:17<01:00,  2.53it/s]

Epoch 45, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.02056165412068367


 24%|██▎       | 47/200 [00:18<01:00,  2.52it/s]

Epoch 46, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.02060505747795105


 24%|██▍       | 48/200 [00:18<01:00,  2.51it/s]

Epoch 47, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.020673522725701332


 24%|██▍       | 49/200 [00:18<01:00,  2.51it/s]

Epoch 48, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.020762594416737556


 25%|██▌       | 50/200 [00:19<00:59,  2.52it/s]

Epoch 49, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.020865947008132935


 26%|██▌       | 51/200 [00:19<00:58,  2.53it/s]

Epoch 50, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.020979296416044235


 26%|██▌       | 52/200 [00:20<00:58,  2.55it/s]

Epoch 51, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.021101588383316994


 26%|██▋       | 53/200 [00:20<00:57,  2.55it/s]

Epoch 52, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.021232910454273224


 27%|██▋       | 54/200 [00:20<00:58,  2.51it/s]

Epoch 53, Validation Accuracy: 0.7699999809265137, Training Accuracy: 0.9928571581840515, Loss: 0.021376371383666992


 28%|██▊       | 55/200 [00:21<00:57,  2.50it/s]

Epoch 54, Validation Accuracy: 0.7720000147819519, Training Accuracy: 0.9928571581840515, Loss: 0.021533766761422157


 28%|██▊       | 56/200 [00:21<00:57,  2.51it/s]

Epoch 55, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.021688828244805336


 28%|██▊       | 57/200 [00:22<00:56,  2.54it/s]

Epoch 56, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.021857352927327156


 29%|██▉       | 58/200 [00:22<00:54,  2.59it/s]

Epoch 57, Validation Accuracy: 0.7739999890327454, Training Accuracy: 0.9928571581840515, Loss: 0.022051796317100525


 30%|██▉       | 59/200 [00:22<00:54,  2.59it/s]

Epoch 58, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.022220535203814507


 30%|███       | 60/200 [00:23<00:54,  2.58it/s]

Epoch 59, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.022379104048013687


 30%|███       | 61/200 [00:23<00:53,  2.58it/s]

Epoch 60, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.022547872737050056


 31%|███       | 62/200 [00:24<00:54,  2.53it/s]

Epoch 61, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.022709911689162254


 32%|███▏      | 63/200 [00:24<00:53,  2.56it/s]

Epoch 62, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.022864442318677902


 32%|███▏      | 64/200 [00:24<00:53,  2.54it/s]

Epoch 63, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.023012546822428703


 32%|███▎      | 65/200 [00:25<00:52,  2.59it/s]

Epoch 64, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.023156071081757545


 33%|███▎      | 66/200 [00:25<00:51,  2.59it/s]

Epoch 65, Validation Accuracy: 0.7739999890327454, Training Accuracy: 0.9928571581840515, Loss: 0.023304400965571404


 34%|███▎      | 67/200 [00:25<00:50,  2.66it/s]

Epoch 66, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.023432500660419464


 34%|███▍      | 68/200 [00:26<00:50,  2.63it/s]

Epoch 67, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.023626986891031265


 34%|███▍      | 69/200 [00:26<00:48,  2.68it/s]

Epoch 68, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.02374320849776268


 35%|███▌      | 70/200 [00:27<00:48,  2.69it/s]

Epoch 69, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.023802829906344414


 36%|███▌      | 71/200 [00:27<00:49,  2.62it/s]

Epoch 70, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.023910874500870705


 36%|███▌      | 72/200 [00:27<00:48,  2.64it/s]

Epoch 71, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.02400895580649376


 36%|███▋      | 73/200 [00:28<00:47,  2.69it/s]

Epoch 72, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024098562076687813


 37%|███▋      | 74/200 [00:28<00:56,  2.22it/s]

Epoch 73, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.024179812520742416


 38%|███▊      | 75/200 [00:29<01:00,  2.06it/s]

Epoch 74, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.024280328303575516


 38%|███▊      | 76/200 [00:29<01:03,  1.96it/s]

Epoch 75, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.024325242266058922


 38%|███▊      | 77/200 [00:30<01:04,  1.92it/s]

Epoch 76, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.024401677772402763


 39%|███▉      | 78/200 [00:30<00:58,  2.09it/s]

Epoch 77, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.02442311868071556


 40%|███▉      | 79/200 [00:31<00:54,  2.21it/s]

Epoch 78, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.02445870451629162


 40%|████      | 80/200 [00:31<00:51,  2.33it/s]

Epoch 79, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.024512987583875656


 40%|████      | 81/200 [00:32<00:48,  2.43it/s]

Epoch 80, Validation Accuracy: 0.7760000228881836, Training Accuracy: 0.9928571581840515, Loss: 0.024527421221137047


 41%|████      | 82/200 [00:32<00:48,  2.45it/s]

Epoch 81, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024510353803634644


 42%|████▏     | 83/200 [00:32<00:47,  2.46it/s]

Epoch 82, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024520156905055046


 42%|████▏     | 84/200 [00:33<00:46,  2.47it/s]

Epoch 83, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024523945525288582


 42%|████▎     | 85/200 [00:33<00:45,  2.50it/s]

Epoch 84, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.02452089451253414


 43%|████▎     | 86/200 [00:33<00:44,  2.57it/s]

Epoch 85, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024509519338607788


 44%|████▎     | 87/200 [00:34<00:44,  2.53it/s]

Epoch 86, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.024489708244800568


 44%|████▍     | 88/200 [00:34<00:43,  2.55it/s]

Epoch 87, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.02446378953754902


 44%|████▍     | 89/200 [00:35<00:43,  2.55it/s]

Epoch 88, Validation Accuracy: 0.777999997138977, Training Accuracy: 0.9928571581840515, Loss: 0.024443579837679863


 45%|████▌     | 90/200 [00:35<00:43,  2.53it/s]

Epoch 89, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.024164043366909027


 46%|████▌     | 91/200 [00:35<00:43,  2.53it/s]

Epoch 90, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.020482217893004417


 46%|████▌     | 92/200 [00:36<00:42,  2.54it/s]

Epoch 91, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.0160346869379282


 46%|████▋     | 93/200 [00:36<00:42,  2.54it/s]

Epoch 92, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.013307441957294941


 47%|████▋     | 94/200 [00:37<00:41,  2.56it/s]

Epoch 93, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.012348471209406853


 48%|████▊     | 95/200 [00:37<00:42,  2.45it/s]

Epoch 94, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.011944808065891266


 48%|████▊     | 96/200 [00:37<00:40,  2.54it/s]

Epoch 95, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.01151229813694954


 48%|████▊     | 97/200 [00:38<00:40,  2.56it/s]

Epoch 96, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.011040585115551949


 49%|████▉     | 98/200 [00:38<00:38,  2.62it/s]

Epoch 97, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.010684682987630367


 50%|████▉     | 99/200 [00:39<00:38,  2.62it/s]

Epoch 98, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.010534672997891903


 50%|█████     | 100/200 [00:39<00:38,  2.57it/s]

Epoch 99, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.010568101890385151


 50%|█████     | 101/200 [00:39<00:39,  2.49it/s]

Epoch 100, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.010647383518517017


 51%|█████     | 102/200 [00:40<00:38,  2.54it/s]

Epoch 101, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.010604221373796463


 52%|█████▏    | 103/200 [00:40<00:39,  2.49it/s]

Epoch 102, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.010391486808657646


 52%|█████▏    | 104/200 [00:41<00:38,  2.47it/s]

Epoch 103, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.010115833953022957


 52%|█████▎    | 105/200 [00:41<00:37,  2.52it/s]

Epoch 104, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.009904399514198303


 53%|█████▎    | 106/200 [00:41<00:37,  2.54it/s]

Epoch 105, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.00980384647846222


 54%|█████▎    | 107/200 [00:42<00:36,  2.57it/s]

Epoch 106, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.009784781374037266


 54%|█████▍    | 108/200 [00:42<00:35,  2.58it/s]

Epoch 107, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.009785991162061691


 55%|█████▍    | 109/200 [00:43<00:35,  2.58it/s]

Epoch 108, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.009748967364430428


 55%|█████▌    | 110/200 [00:43<00:34,  2.59it/s]

Epoch 109, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.009652973152697086


 56%|█████▌    | 111/200 [00:43<00:33,  2.63it/s]

Epoch 110, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.009518003091216087


 56%|█████▌    | 112/200 [00:44<00:33,  2.60it/s]

Epoch 111, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.00939126405864954


 56%|█████▋    | 113/200 [00:44<00:33,  2.63it/s]

Epoch 112, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.009309138171374798


 57%|█████▋    | 114/200 [00:44<00:33,  2.59it/s]

Epoch 113, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.009278771467506886


 57%|█████▊    | 115/200 [00:45<00:33,  2.56it/s]

Epoch 114, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.009274926967918873


 58%|█████▊    | 116/200 [00:45<00:32,  2.58it/s]

Epoch 115, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.00926296878606081


 58%|█████▊    | 117/200 [00:46<00:32,  2.59it/s]

Epoch 116, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.00922226719558239


 59%|█████▉    | 118/200 [00:46<00:31,  2.64it/s]

Epoch 117, Validation Accuracy: 0.777999997138977, Training Accuracy: 1.0, Loss: 0.009160105139017105


 60%|█████▉    | 119/200 [00:46<00:30,  2.68it/s]

Epoch 118, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.009102527983486652


 60%|██████    | 120/200 [00:47<00:29,  2.71it/s]

Epoch 119, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.009068422019481659


 60%|██████    | 121/200 [00:47<00:29,  2.63it/s]

Epoch 120, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.009058346040546894


 61%|██████    | 122/200 [00:47<00:29,  2.68it/s]

Epoch 121, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.009059065021574497


 62%|██████▏   | 123/200 [00:48<00:28,  2.70it/s]

Epoch 122, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.009054486639797688


 62%|██████▏   | 124/200 [00:48<00:28,  2.69it/s]

Epoch 123, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.00903492234647274


 62%|██████▎   | 125/200 [00:49<00:28,  2.63it/s]

Epoch 124, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.009004554711282253


 63%|██████▎   | 126/200 [00:49<00:28,  2.59it/s]

Epoch 125, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008974337950348854


 64%|██████▎   | 127/200 [00:49<00:28,  2.60it/s]

Epoch 126, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.00895355362445116


 64%|██████▍   | 128/200 [00:50<00:27,  2.59it/s]

Epoch 127, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008947115391492844


 64%|██████▍   | 129/200 [00:50<00:27,  2.55it/s]

Epoch 128, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008950164541602135


 65%|██████▌   | 130/200 [00:51<00:27,  2.55it/s]

Epoch 129, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008953816257417202


 66%|██████▌   | 131/200 [00:51<00:27,  2.54it/s]

Epoch 130, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.0089521873742342


 66%|██████▌   | 132/200 [00:51<00:26,  2.53it/s]

Epoch 131, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008945927023887634


 66%|██████▋   | 133/200 [00:52<00:26,  2.58it/s]

Epoch 132, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.00894039124250412


 67%|██████▋   | 134/200 [00:52<00:25,  2.56it/s]

Epoch 133, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008939463645219803


 68%|██████▊   | 135/200 [00:53<00:25,  2.55it/s]

Epoch 134, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.008944183588027954


 68%|██████▊   | 136/200 [00:53<00:25,  2.55it/s]

Epoch 135, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.008951236493885517


 68%|██████▊   | 137/200 [00:53<00:24,  2.53it/s]

Epoch 136, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.008956389501690865


 69%|██████▉   | 138/200 [00:54<00:24,  2.52it/s]

Epoch 137, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.008957399055361748


 70%|██████▉   | 139/200 [00:54<00:23,  2.56it/s]

Epoch 138, Validation Accuracy: 0.7760000228881836, Training Accuracy: 1.0, Loss: 0.008955050259828568


 70%|███████   | 140/200 [00:54<00:23,  2.54it/s]

Epoch 139, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.00895118247717619


 70%|███████   | 141/200 [00:55<00:23,  2.54it/s]

Epoch 140, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008949306793510914


 71%|███████   | 142/200 [00:55<00:23,  2.51it/s]

Epoch 141, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008948156610131264


 72%|███████▏  | 143/200 [00:56<00:23,  2.44it/s]

Epoch 142, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008946805261075497


 72%|███████▏  | 144/200 [00:56<00:23,  2.43it/s]

Epoch 143, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.00894306879490614


 72%|███████▎  | 145/200 [00:57<00:21,  2.54it/s]

Epoch 144, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008937053382396698


 73%|███████▎  | 146/200 [00:57<00:20,  2.58it/s]

Epoch 145, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008929322473704815


 74%|███████▎  | 147/200 [00:57<00:20,  2.56it/s]

Epoch 146, Validation Accuracy: 0.7739999890327454, Training Accuracy: 1.0, Loss: 0.008920423686504364


 74%|███████▍  | 148/200 [00:58<00:20,  2.59it/s]

Epoch 147, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.00891073327511549


 74%|███████▍  | 149/200 [00:58<00:19,  2.64it/s]

Epoch 148, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008899608626961708


 75%|███████▌  | 150/200 [00:58<00:19,  2.61it/s]

Epoch 149, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008886762894690037


 76%|███████▌  | 151/200 [00:59<00:18,  2.62it/s]

Epoch 150, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008871263824403286


 76%|███████▌  | 152/200 [00:59<00:18,  2.64it/s]

Epoch 151, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008854182437062263


 76%|███████▋  | 153/200 [01:00<00:17,  2.62it/s]

Epoch 152, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008835642598569393


 77%|███████▋  | 154/200 [01:00<00:17,  2.59it/s]

Epoch 153, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008817055262625217


 78%|███████▊  | 155/200 [01:00<00:18,  2.37it/s]

Epoch 154, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.00879843533039093


 78%|███████▊  | 156/200 [01:01<00:20,  2.16it/s]

Epoch 155, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008779521100223064


 78%|███████▊  | 157/200 [01:02<00:20,  2.07it/s]

Epoch 156, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008760428056120872


 79%|███████▉  | 158/200 [01:02<00:21,  1.95it/s]

Epoch 157, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008741130121052265


 80%|███████▉  | 159/200 [01:03<00:21,  1.93it/s]

Epoch 158, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008722187019884586


 80%|████████  | 160/200 [01:03<00:19,  2.10it/s]

Epoch 159, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008704054169356823


 80%|████████  | 161/200 [01:03<00:17,  2.21it/s]

Epoch 160, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008685591630637646


 81%|████████  | 162/200 [01:04<00:16,  2.33it/s]

Epoch 161, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.00866659451276064


 82%|████████▏ | 163/200 [01:04<00:15,  2.39it/s]

Epoch 162, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008647526614367962


 82%|████████▏ | 164/200 [01:05<00:14,  2.46it/s]

Epoch 163, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008628077805042267


 82%|████████▎ | 165/200 [01:05<00:13,  2.53it/s]

Epoch 164, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008608829230070114


 83%|████████▎ | 166/200 [01:05<00:13,  2.57it/s]

Epoch 165, Validation Accuracy: 0.7720000147819519, Training Accuracy: 1.0, Loss: 0.008589590899646282


 84%|████████▎ | 167/200 [01:06<00:12,  2.62it/s]

Epoch 166, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.00857075396925211


 84%|████████▍ | 168/200 [01:06<00:11,  2.67it/s]

Epoch 167, Validation Accuracy: 0.7699999809265137, Training Accuracy: 1.0, Loss: 0.008552981540560722


 84%|████████▍ | 169/200 [01:06<00:11,  2.66it/s]

Epoch 168, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008535518310964108


 85%|████████▌ | 170/200 [01:07<00:11,  2.61it/s]

Epoch 169, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008519116789102554


 86%|████████▌ | 171/200 [01:07<00:11,  2.57it/s]

Epoch 170, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008503436110913754


 86%|████████▌ | 172/200 [01:08<00:10,  2.57it/s]

Epoch 171, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008488968014717102


 86%|████████▋ | 173/200 [01:08<00:10,  2.52it/s]

Epoch 172, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.00847447570413351


 87%|████████▋ | 174/200 [01:08<00:10,  2.52it/s]

Epoch 173, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.00846051424741745


 88%|████████▊ | 175/200 [01:09<00:09,  2.53it/s]

Epoch 174, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008446944877505302


 88%|████████▊ | 176/200 [01:09<00:09,  2.54it/s]

Epoch 175, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008433408103883266


 88%|████████▊ | 177/200 [01:10<00:09,  2.50it/s]

Epoch 176, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008420350030064583


 89%|████████▉ | 178/200 [01:10<00:08,  2.52it/s]

Epoch 177, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.00840750802308321


 90%|████████▉ | 179/200 [01:10<00:08,  2.50it/s]

Epoch 178, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008394409902393818


 90%|█████████ | 180/200 [01:11<00:07,  2.52it/s]

Epoch 179, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008381584659218788


 90%|█████████ | 181/200 [01:11<00:07,  2.53it/s]

Epoch 180, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008369273506104946


 91%|█████████ | 182/200 [01:12<00:07,  2.55it/s]

Epoch 181, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008357871323823929


 92%|█████████▏| 183/200 [01:12<00:06,  2.53it/s]

Epoch 182, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008347363211214542


 92%|█████████▏| 184/200 [01:12<00:06,  2.54it/s]

Epoch 183, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.008336829021573067


 92%|█████████▎| 185/200 [01:13<00:06,  2.50it/s]

Epoch 184, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.00832640752196312


 93%|█████████▎| 186/200 [01:13<00:05,  2.51it/s]

Epoch 185, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.00831651221960783


 94%|█████████▎| 187/200 [01:14<00:05,  2.53it/s]

Epoch 186, Validation Accuracy: 0.7639999985694885, Training Accuracy: 1.0, Loss: 0.00830627977848053


 94%|█████████▍| 188/200 [01:14<00:04,  2.59it/s]

Epoch 187, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.00829597283154726


 94%|█████████▍| 189/200 [01:14<00:04,  2.64it/s]

Epoch 188, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008286173455417156


 95%|█████████▌| 190/200 [01:15<00:03,  2.60it/s]

Epoch 189, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008275505155324936


 96%|█████████▌| 191/200 [01:15<00:03,  2.60it/s]

Epoch 190, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008264695294201374


 96%|█████████▌| 192/200 [01:15<00:03,  2.65it/s]

Epoch 191, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008254830725491047


 96%|█████████▋| 193/200 [01:16<00:02,  2.63it/s]

Epoch 192, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.00824509747326374


 97%|█████████▋| 194/200 [01:16<00:02,  2.66it/s]

Epoch 193, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008235638029873371


 98%|█████████▊| 195/200 [01:17<00:01,  2.66it/s]

Epoch 194, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.00822632946074009


 98%|█████████▊| 196/200 [01:17<00:01,  2.55it/s]

Epoch 195, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008217033930122852


 98%|█████████▊| 197/200 [01:17<00:01,  2.54it/s]

Epoch 196, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008207477629184723


 99%|█████████▉| 198/200 [01:18<00:00,  2.57it/s]

Epoch 197, Validation Accuracy: 0.7680000066757202, Training Accuracy: 1.0, Loss: 0.008197390474379063


100%|█████████▉| 199/200 [01:18<00:00,  2.56it/s]

Epoch 198, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008186987601220608


100%|██████████| 200/200 [01:19<00:00,  2.53it/s]

Epoch 199, Validation Accuracy: 0.765999972820282, Training Accuracy: 1.0, Loss: 0.008177277632057667





In [30]:
# Evaluate the model
model_dense.eval()
with torch.no_grad():
    accs = []
    for adj, features, target, train_mask, val_mask, test_mask in cora_loader_vectorized:
        adj, features, target = adj.to(device), features.to(device), target.to(device)
        out = model_dense(features, adj)
        correct = (out[test_mask].argmax(-1) == target[test_mask]).sum()
        accs.append(int(correct) / int(test_mask.sum()))
    print(f"Test Accuracy: {np.mean(accs)}")

Test Accuracy: 0.797


In [31]:
# from torch_geometric.nn import GCNConv
# import torch.nn.functional as F

# class GCN(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = GCNConv(1433, 16) # shape（输入的节点特征维度 * 中间隐藏层的维度）
#         self.conv2 = GCNConv(16, 7) # shape（中间隐藏层的维度 * 节点类别）

#     def forward(self, data):
#         x, edge_index = data.x, data.edge_index # 加载节点特征和邻接关系

#         x = self.conv1(x, edge_index) # 传入卷积层
#         x = F.relu(x) # 激活函数
#         x = F.dropout(x, training=self.training) # dropout层，防止过拟合
#         x = self.conv2(x, edge_index) # 第二层卷积层

#         # 将经过两层卷积得到的特征输入log_softmax函数得到概率分布
#         return F.log_softmax(x, dim=1)

In [32]:
# m2 = GCN().to(device)
# m2

In [33]:
# optimizer = torch.optim.Adam(m2.parameters(), lr=0.01, weight_decay=5e-4)
# num_epoch = 200

# m2.train()
# for epoch in range(50):
#     optimizer.zero_grad() # 梯度清空为0
#     out = m2(cora_graph) # 模型输出
#     print(out.shape)
#     loss = F.nll_loss(out[cora_graph.train_mask], cora_graph.y[cora_graph.train_mask]) # 计算损失函数
#     correct = (out.argmax(dim=1)[cora_graph.train_mask] == cora_graph.y[cora_graph.train_mask]).sum() # 计算正确的个数
#     acc = int(correct) / int(cora_graph.train_mask.sum()) # 得出准确率
#     loss.backward() # 反向传播计算梯度
#     optimizer.step() # 利用梯度更新模型参数
#     if (epoch+1) % 10 == 0:
#         print('Epoch: {}, Loss: {:.4f}, Training Acc: {:.4f}'.format(epoch+1, loss.item(), acc))



## ZINC

In [34]:
# Load the dataset
dataset = pyg.datasets.ZINC(root='dataset/ZINC', split='train', subset=True)
dataset_val = pyg.datasets.ZINC(root='dataset/ZINC', split='val', subset=True)
dataset_test = pyg.datasets.ZINC(root='dataset/ZINC', split='test', subset=True)
print(dataset[0])
# Create data loaders
batch_size=128
num_workers = 8
train_loader = pyg.loader.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = pyg.loader.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = pyg.loader.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting dataset/ZINC/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
  return torch.load(io.BytesIO(b))
Processing train dataset: 100%|██████████| 10000/10000 [00:00<00:00, 11413.43it/s]
Processing val dataset: 100%|██████████| 1000/1000 [00:00<00:00, 3653.44it/s]
Processing test dataset: 100%|██████████| 1000/1000 [00:00<00:00, 6488.90it/s]


Data(x=[29, 1], edge_index=[2, 64], edge_attr=[64], y=[1])


Done!


In [35]:
for data in train_loader:
    print(data)
    break

DataBatch(x=[2946, 1], edge_index=[2, 6342], edge_attr=[6342], y=[128], batch=[2946], ptr=[129])


In [45]:
dataset.x.shape

torch.Size([231664, 1])

In [48]:
torch.max(dataset[:].x, dim=0)[0]

tensor([20])

In [52]:
torch.unique(dataset[:].x, dim=0).size(0)

21

In [118]:
max_num_atoms = torch.unique(dataset[:].x, dim=0).size(0)
atom_encoder = AtomEncoder(emb_dim=max_num_atoms)
one_hot_encoder = lambda x: torch.cat([torch.nn.functional.one_hot(x[:, 0].long(), num_classes=max_num_atoms).float(), x[:, 1:]], dim=-1)

In [131]:
# your implementation goes here
def train_zinc(model, train_loader, val_loader, epochs=200, lr=0.01, weight_decay=5e-2, device=device):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.L1Loss()

    train_loss = 0.0
    for epoch in tqdm(range(epochs)):
        train_loss = 0.0
        val_loss = 0.0
        for data in train_loader:
            model.train()
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = loss_fn(out.squeeze(), data.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                out = model(data.x, data.edge_index, data.batch)
                loss = loss_fn(out, data.y)
                val_loss += loss.item()
        val_loss /= len(val_loader)            
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


In [132]:
zinc_model_one_hot = GraphNet(
    in_features=max_num_atoms,
    out_features=1,
    hidden_features=64,
    num_layers=2,
    dropout=0.1,
    reduction="mean",
    pooling="mean",
    task="regression",
    feature_encoder=one_hot_encoder,
).to(device)

In [133]:
zinc_model_one_hot

GraphNet(
  (layers): ModuleList(
    (0): GCNLayer(
      (U): Sequential(
        (0): Linear(in_features=21, out_features=64, bias=False)
        (1): ReLU()
        (2): Linear(in_features=64, out_features=64, bias=False)
      )
    )
    (1): GCNLayer(
      (U): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): ReLU()
        (2): Linear(in_features=64, out_features=64, bias=False)
      )
    )
  )
  (bns): ModuleList(
    (0-1): 2 x BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [134]:
train_zinc(zinc_model_one_hot, train_loader, val_loader, epochs=200, lr=0.001, weight_decay=5e-2, device=device)

  0%|          | 1/200 [00:02<08:10,  2.47s/it]

Epoch 1/200, Training Loss: 1.0287, Validation Loss: 1.6672


  1%|          | 2/200 [00:04<07:28,  2.26s/it]

Epoch 2/200, Training Loss: 0.6967, Validation Loss: 1.7971


  2%|▏         | 3/200 [00:06<07:15,  2.21s/it]

Epoch 3/200, Training Loss: 0.6806, Validation Loss: 1.9582


  2%|▏         | 4/200 [00:08<07:10,  2.19s/it]

Epoch 4/200, Training Loss: 0.6774, Validation Loss: 1.8252


  2%|▎         | 5/200 [00:11<07:04,  2.18s/it]

Epoch 5/200, Training Loss: 0.6756, Validation Loss: 1.9174


  3%|▎         | 6/200 [00:13<07:00,  2.17s/it]

Epoch 6/200, Training Loss: 0.6796, Validation Loss: 1.8628


  4%|▎         | 7/200 [00:15<06:55,  2.15s/it]

Epoch 7/200, Training Loss: 0.6769, Validation Loss: 1.8561


  4%|▍         | 8/200 [00:17<06:50,  2.14s/it]

Epoch 8/200, Training Loss: 0.6776, Validation Loss: 1.8465


  4%|▍         | 9/200 [00:19<06:47,  2.13s/it]

Epoch 9/200, Training Loss: 0.6798, Validation Loss: 2.0287


  5%|▌         | 10/200 [00:21<06:49,  2.16s/it]

Epoch 10/200, Training Loss: 0.6929, Validation Loss: 1.8145


  6%|▌         | 11/200 [00:23<06:43,  2.14s/it]

Epoch 11/200, Training Loss: 0.6828, Validation Loss: 1.7505


  6%|▌         | 12/200 [00:25<06:41,  2.14s/it]

Epoch 12/200, Training Loss: 0.6734, Validation Loss: 2.0462


  6%|▋         | 13/200 [00:28<06:39,  2.14s/it]

Epoch 13/200, Training Loss: 0.6760, Validation Loss: 1.8619


  7%|▋         | 14/200 [00:30<06:35,  2.13s/it]

Epoch 14/200, Training Loss: 0.6720, Validation Loss: 1.8917


  8%|▊         | 15/200 [00:32<06:38,  2.15s/it]

Epoch 15/200, Training Loss: 0.6725, Validation Loss: 1.8524


  8%|▊         | 16/200 [00:34<06:49,  2.22s/it]

Epoch 16/200, Training Loss: 0.6682, Validation Loss: 1.9069


  8%|▊         | 17/200 [00:36<06:42,  2.20s/it]

Epoch 17/200, Training Loss: 0.6830, Validation Loss: 1.8494


  9%|▉         | 18/200 [00:39<06:36,  2.18s/it]

Epoch 18/200, Training Loss: 0.6742, Validation Loss: 1.8403


 10%|▉         | 19/200 [00:41<06:33,  2.18s/it]

Epoch 19/200, Training Loss: 0.6733, Validation Loss: 1.8324


 10%|█         | 20/200 [00:43<06:28,  2.16s/it]

Epoch 20/200, Training Loss: 0.6697, Validation Loss: 1.8731


 10%|█         | 21/200 [00:45<06:25,  2.15s/it]

Epoch 21/200, Training Loss: 0.6614, Validation Loss: 1.8882


 11%|█         | 22/200 [00:47<06:23,  2.16s/it]

Epoch 22/200, Training Loss: 0.6664, Validation Loss: 1.9811


 12%|█▏        | 23/200 [00:49<06:18,  2.14s/it]

Epoch 23/200, Training Loss: 0.6804, Validation Loss: 1.9068


 12%|█▏        | 24/200 [00:51<06:17,  2.15s/it]

Epoch 24/200, Training Loss: 0.6771, Validation Loss: 1.8611


 12%|█▎        | 25/200 [00:54<06:14,  2.14s/it]

Epoch 25/200, Training Loss: 0.6698, Validation Loss: 1.9135


 13%|█▎        | 26/200 [00:56<06:11,  2.14s/it]

Epoch 26/200, Training Loss: 0.6748, Validation Loss: 1.9742


 14%|█▎        | 27/200 [00:58<06:09,  2.14s/it]

Epoch 27/200, Training Loss: 0.6687, Validation Loss: 1.8946


 14%|█▍        | 28/200 [01:00<06:09,  2.15s/it]

Epoch 28/200, Training Loss: 0.6780, Validation Loss: 1.8656


 14%|█▍        | 29/200 [01:02<06:04,  2.13s/it]

Epoch 29/200, Training Loss: 0.6865, Validation Loss: 1.9862


 15%|█▌        | 30/200 [01:04<06:06,  2.16s/it]

Epoch 30/200, Training Loss: 0.6747, Validation Loss: 1.9869


 16%|█▌        | 31/200 [01:07<06:12,  2.21s/it]

Epoch 31/200, Training Loss: 0.6780, Validation Loss: 1.8072


 16%|█▌        | 32/200 [01:09<06:05,  2.18s/it]

Epoch 32/200, Training Loss: 0.6743, Validation Loss: 1.8559


 16%|█▋        | 33/200 [01:11<06:01,  2.16s/it]

Epoch 33/200, Training Loss: 0.6767, Validation Loss: 1.8430


 17%|█▋        | 34/200 [01:13<05:55,  2.14s/it]

Epoch 34/200, Training Loss: 0.6713, Validation Loss: 1.8379


 18%|█▊        | 35/200 [01:15<05:52,  2.13s/it]

Epoch 35/200, Training Loss: 0.6844, Validation Loss: 1.8722


 18%|█▊        | 36/200 [01:17<05:47,  2.12s/it]

Epoch 36/200, Training Loss: 0.6736, Validation Loss: 1.8068


 18%|█▊        | 37/200 [01:19<05:47,  2.13s/it]

Epoch 37/200, Training Loss: 0.6734, Validation Loss: 1.8818


 19%|█▉        | 38/200 [01:22<05:46,  2.14s/it]

Epoch 38/200, Training Loss: 0.6707, Validation Loss: 1.9041


 20%|█▉        | 39/200 [01:24<05:42,  2.13s/it]

Epoch 39/200, Training Loss: 0.6711, Validation Loss: 1.8659


 20%|██        | 40/200 [01:26<05:40,  2.13s/it]

Epoch 40/200, Training Loss: 0.6808, Validation Loss: 1.8232


 20%|██        | 41/200 [01:28<05:36,  2.12s/it]

Epoch 41/200, Training Loss: 0.6679, Validation Loss: 1.8299


 21%|██        | 42/200 [01:30<05:34,  2.12s/it]

Epoch 42/200, Training Loss: 0.6690, Validation Loss: 1.8333


 22%|██▏       | 43/200 [01:32<05:32,  2.12s/it]

Epoch 43/200, Training Loss: 0.6723, Validation Loss: 1.8646


 22%|██▏       | 44/200 [01:34<05:29,  2.11s/it]

Epoch 44/200, Training Loss: 0.6701, Validation Loss: 1.8373


 22%|██▎       | 45/200 [01:36<05:33,  2.15s/it]

Epoch 45/200, Training Loss: 0.6717, Validation Loss: 1.8356


 23%|██▎       | 46/200 [01:39<05:39,  2.20s/it]

Epoch 46/200, Training Loss: 0.6691, Validation Loss: 1.8704


 24%|██▎       | 47/200 [01:41<05:36,  2.20s/it]

Epoch 47/200, Training Loss: 0.6603, Validation Loss: 1.8288


 24%|██▍       | 48/200 [01:43<05:31,  2.18s/it]

Epoch 48/200, Training Loss: 0.6772, Validation Loss: 1.8952


 24%|██▍       | 49/200 [01:45<05:26,  2.16s/it]

Epoch 49/200, Training Loss: 0.6657, Validation Loss: 1.7932


 25%|██▌       | 50/200 [01:47<05:21,  2.14s/it]

Epoch 50/200, Training Loss: 0.6673, Validation Loss: 1.8551


 26%|██▌       | 51/200 [01:49<05:18,  2.14s/it]

Epoch 51/200, Training Loss: 0.6717, Validation Loss: 1.8850


 26%|██▌       | 52/200 [01:52<05:17,  2.15s/it]

Epoch 52/200, Training Loss: 0.6660, Validation Loss: 1.8361


 26%|██▋       | 53/200 [01:54<05:14,  2.14s/it]

Epoch 53/200, Training Loss: 0.6695, Validation Loss: 1.8873


 27%|██▋       | 54/200 [01:56<05:11,  2.13s/it]

Epoch 54/200, Training Loss: 0.6737, Validation Loss: 1.8371


 28%|██▊       | 55/200 [01:58<05:07,  2.12s/it]

Epoch 55/200, Training Loss: 0.6698, Validation Loss: 1.8033


 28%|██▊       | 56/200 [02:00<05:07,  2.14s/it]

Epoch 56/200, Training Loss: 0.6737, Validation Loss: 1.8278


 28%|██▊       | 57/200 [02:02<05:03,  2.12s/it]

Epoch 57/200, Training Loss: 0.6754, Validation Loss: 1.8574


 29%|██▉       | 58/200 [02:04<05:00,  2.12s/it]

Epoch 58/200, Training Loss: 0.6628, Validation Loss: 1.8480


 30%|██▉       | 59/200 [02:06<04:58,  2.11s/it]

Epoch 59/200, Training Loss: 0.6631, Validation Loss: 1.8291


 30%|███       | 60/200 [02:09<05:01,  2.16s/it]

Epoch 60/200, Training Loss: 0.6621, Validation Loss: 1.8345


 30%|███       | 61/200 [02:11<05:05,  2.20s/it]

Epoch 61/200, Training Loss: 0.6690, Validation Loss: 1.8162


 31%|███       | 62/200 [02:13<04:58,  2.16s/it]

Epoch 62/200, Training Loss: 0.6655, Validation Loss: 1.8838


 32%|███▏      | 63/200 [02:15<04:52,  2.14s/it]

Epoch 63/200, Training Loss: 0.6670, Validation Loss: 1.8173


 32%|███▏      | 64/200 [02:17<04:48,  2.12s/it]

Epoch 64/200, Training Loss: 0.6687, Validation Loss: 1.8813


 32%|███▎      | 65/200 [02:19<04:45,  2.12s/it]

Epoch 65/200, Training Loss: 0.6659, Validation Loss: 1.9122


 33%|███▎      | 66/200 [02:21<04:44,  2.12s/it]

Epoch 66/200, Training Loss: 0.6634, Validation Loss: 1.8622


 34%|███▎      | 67/200 [02:23<04:40,  2.11s/it]

Epoch 67/200, Training Loss: 0.6633, Validation Loss: 1.8227


 34%|███▍      | 68/200 [02:26<04:40,  2.12s/it]

Epoch 68/200, Training Loss: 0.6620, Validation Loss: 1.9368


 34%|███▍      | 69/200 [02:28<04:36,  2.11s/it]

Epoch 69/200, Training Loss: 0.6725, Validation Loss: 1.8008


 35%|███▌      | 70/200 [02:30<04:34,  2.11s/it]

Epoch 70/200, Training Loss: 0.6686, Validation Loss: 1.8442


 36%|███▌      | 71/200 [02:32<04:32,  2.11s/it]

Epoch 71/200, Training Loss: 0.6659, Validation Loss: 1.8595


 36%|███▌      | 72/200 [02:34<04:31,  2.12s/it]

Epoch 72/200, Training Loss: 0.6743, Validation Loss: 1.8730


 36%|███▋      | 73/200 [02:36<04:28,  2.11s/it]

Epoch 73/200, Training Loss: 0.6768, Validation Loss: 1.8672


 37%|███▋      | 74/200 [02:38<04:26,  2.11s/it]

Epoch 74/200, Training Loss: 0.6620, Validation Loss: 1.8791


 38%|███▊      | 75/200 [02:41<04:29,  2.16s/it]

Epoch 75/200, Training Loss: 0.6676, Validation Loss: 1.8977


 38%|███▊      | 76/200 [02:43<04:32,  2.20s/it]

Epoch 76/200, Training Loss: 0.6669, Validation Loss: 1.7993


 38%|███▊      | 77/200 [02:45<04:27,  2.17s/it]

Epoch 77/200, Training Loss: 0.6700, Validation Loss: 1.9333


 39%|███▉      | 78/200 [02:47<04:23,  2.16s/it]

Epoch 78/200, Training Loss: 0.6678, Validation Loss: 1.9031


 40%|███▉      | 79/200 [02:49<04:19,  2.15s/it]

Epoch 79/200, Training Loss: 0.6641, Validation Loss: 1.9031


 40%|████      | 80/200 [02:51<04:17,  2.15s/it]

Epoch 80/200, Training Loss: 0.6829, Validation Loss: 1.8592


 40%|████      | 81/200 [02:53<04:14,  2.14s/it]

Epoch 81/200, Training Loss: 0.6647, Validation Loss: 1.8435


 41%|████      | 82/200 [02:56<04:10,  2.13s/it]

Epoch 82/200, Training Loss: 0.6699, Validation Loss: 1.8048


 42%|████▏     | 83/200 [02:58<04:06,  2.11s/it]

Epoch 83/200, Training Loss: 0.6661, Validation Loss: 1.8521


 42%|████▏     | 84/200 [03:00<04:05,  2.12s/it]

Epoch 84/200, Training Loss: 0.6708, Validation Loss: 1.8603


 42%|████▎     | 85/200 [03:02<04:04,  2.13s/it]

Epoch 85/200, Training Loss: 0.6704, Validation Loss: 1.8664


 43%|████▎     | 86/200 [03:04<04:02,  2.12s/it]

Epoch 86/200, Training Loss: 0.6687, Validation Loss: 1.9419


 44%|████▎     | 87/200 [03:06<03:59,  2.12s/it]

Epoch 87/200, Training Loss: 0.6678, Validation Loss: 1.8253


 44%|████▍     | 88/200 [03:08<03:57,  2.12s/it]

Epoch 88/200, Training Loss: 0.6714, Validation Loss: 1.8510


 44%|████▍     | 89/200 [03:10<03:58,  2.14s/it]

Epoch 89/200, Training Loss: 0.6704, Validation Loss: 1.8437


 45%|████▌     | 90/200 [03:13<04:00,  2.19s/it]

Epoch 90/200, Training Loss: 0.6680, Validation Loss: 1.8608


 46%|████▌     | 91/200 [03:15<04:01,  2.22s/it]

Epoch 91/200, Training Loss: 0.6600, Validation Loss: 1.8530


 46%|████▌     | 92/200 [03:17<03:55,  2.18s/it]

Epoch 92/200, Training Loss: 0.6685, Validation Loss: 1.8053


 46%|████▋     | 93/200 [03:19<03:50,  2.15s/it]

Epoch 93/200, Training Loss: 0.6674, Validation Loss: 1.8885


 47%|████▋     | 94/200 [03:21<03:47,  2.15s/it]

Epoch 94/200, Training Loss: 0.6660, Validation Loss: 1.8204


 48%|████▊     | 95/200 [03:23<03:43,  2.13s/it]

Epoch 95/200, Training Loss: 0.6727, Validation Loss: 1.8909


 48%|████▊     | 96/200 [03:26<03:42,  2.14s/it]

Epoch 96/200, Training Loss: 0.6655, Validation Loss: 1.8166


 48%|████▊     | 97/200 [03:28<03:38,  2.12s/it]

Epoch 97/200, Training Loss: 0.6693, Validation Loss: 1.8534


 49%|████▉     | 98/200 [03:30<03:37,  2.13s/it]

Epoch 98/200, Training Loss: 0.6740, Validation Loss: 1.8714


 50%|████▉     | 99/200 [03:32<03:36,  2.14s/it]

Epoch 99/200, Training Loss: 0.6559, Validation Loss: 1.8411


 50%|█████     | 100/200 [03:34<03:33,  2.13s/it]

Epoch 100/200, Training Loss: 0.6629, Validation Loss: 1.8760


 50%|█████     | 101/200 [03:36<03:30,  2.12s/it]

Epoch 101/200, Training Loss: 0.6658, Validation Loss: 1.8096


 51%|█████     | 102/200 [03:38<03:27,  2.12s/it]

Epoch 102/200, Training Loss: 0.6624, Validation Loss: 1.8342


 52%|█████▏    | 103/200 [03:41<03:26,  2.13s/it]

Epoch 103/200, Training Loss: 0.6685, Validation Loss: 1.8292


 52%|█████▏    | 104/200 [03:43<03:23,  2.12s/it]

Epoch 104/200, Training Loss: 0.6796, Validation Loss: 1.8481


 52%|█████▎    | 105/200 [03:45<03:26,  2.17s/it]

Epoch 105/200, Training Loss: 0.6643, Validation Loss: 1.8626


 53%|█████▎    | 106/200 [03:47<03:26,  2.20s/it]

Epoch 106/200, Training Loss: 0.6704, Validation Loss: 1.9177


 54%|█████▎    | 107/200 [03:49<03:21,  2.17s/it]

Epoch 107/200, Training Loss: 0.6673, Validation Loss: 1.8424


 54%|█████▍    | 108/200 [03:51<03:18,  2.16s/it]

Epoch 108/200, Training Loss: 0.6714, Validation Loss: 1.8530


 55%|█████▍    | 109/200 [03:53<03:14,  2.14s/it]

Epoch 109/200, Training Loss: 0.6640, Validation Loss: 1.9014


 55%|█████▌    | 110/200 [03:56<03:11,  2.13s/it]

Epoch 110/200, Training Loss: 0.6628, Validation Loss: 1.8781


 56%|█████▌    | 111/200 [03:58<03:09,  2.13s/it]

Epoch 111/200, Training Loss: 0.6660, Validation Loss: 1.8549


 56%|█████▌    | 112/200 [04:00<03:07,  2.13s/it]

Epoch 112/200, Training Loss: 0.6744, Validation Loss: 1.8283


 56%|█████▋    | 113/200 [04:02<03:05,  2.13s/it]

Epoch 113/200, Training Loss: 0.6650, Validation Loss: 1.8797


 57%|█████▋    | 114/200 [04:04<03:02,  2.12s/it]

Epoch 114/200, Training Loss: 0.6671, Validation Loss: 1.7739


 57%|█████▊    | 115/200 [04:06<03:00,  2.12s/it]

Epoch 115/200, Training Loss: 0.6622, Validation Loss: 1.9036


 58%|█████▊    | 116/200 [04:08<02:59,  2.13s/it]

Epoch 116/200, Training Loss: 0.6671, Validation Loss: 1.8553


 58%|█████▊    | 117/200 [04:11<02:58,  2.15s/it]

Epoch 117/200, Training Loss: 0.6635, Validation Loss: 1.8731


 59%|█████▉    | 118/200 [04:13<02:56,  2.15s/it]

Epoch 118/200, Training Loss: 0.6687, Validation Loss: 1.8698


 60%|█████▉    | 119/200 [04:15<02:54,  2.15s/it]

Epoch 119/200, Training Loss: 0.6710, Validation Loss: 1.8589


 60%|██████    | 120/200 [04:17<02:57,  2.21s/it]

Epoch 120/200, Training Loss: 0.6664, Validation Loss: 1.8459


 60%|██████    | 121/200 [04:19<02:54,  2.21s/it]

Epoch 121/200, Training Loss: 0.6647, Validation Loss: 1.8212


 61%|██████    | 122/200 [04:22<02:51,  2.20s/it]

Epoch 122/200, Training Loss: 0.6750, Validation Loss: 1.8777


 62%|██████▏   | 123/200 [04:24<02:47,  2.18s/it]

Epoch 123/200, Training Loss: 0.6643, Validation Loss: 1.8217


 62%|██████▏   | 124/200 [04:26<02:44,  2.16s/it]

Epoch 124/200, Training Loss: 0.6627, Validation Loss: 1.8348


 62%|██████▎   | 125/200 [04:28<02:41,  2.16s/it]

Epoch 125/200, Training Loss: 0.6670, Validation Loss: 1.8344


 63%|██████▎   | 126/200 [04:30<02:39,  2.15s/it]

Epoch 126/200, Training Loss: 0.6616, Validation Loss: 1.8126


 64%|██████▎   | 127/200 [04:32<02:35,  2.13s/it]

Epoch 127/200, Training Loss: 0.6603, Validation Loss: 1.8635


 64%|██████▍   | 128/200 [04:34<02:33,  2.13s/it]

Epoch 128/200, Training Loss: 0.6641, Validation Loss: 1.8569


 64%|██████▍   | 129/200 [04:36<02:31,  2.13s/it]

Epoch 129/200, Training Loss: 0.6742, Validation Loss: 1.8194


 65%|██████▌   | 130/200 [04:39<02:29,  2.14s/it]

Epoch 130/200, Training Loss: 0.6605, Validation Loss: 1.8342


 66%|██████▌   | 131/200 [04:41<02:28,  2.15s/it]

Epoch 131/200, Training Loss: 0.6630, Validation Loss: 1.8371


 66%|██████▌   | 132/200 [04:43<02:25,  2.14s/it]

Epoch 132/200, Training Loss: 0.6707, Validation Loss: 1.9257


 66%|██████▋   | 133/200 [04:45<02:22,  2.13s/it]

Epoch 133/200, Training Loss: 0.6726, Validation Loss: 1.8570


 67%|██████▋   | 134/200 [04:47<02:20,  2.13s/it]

Epoch 134/200, Training Loss: 0.6750, Validation Loss: 2.2255


 68%|██████▊   | 135/200 [04:50<02:24,  2.23s/it]

Epoch 135/200, Training Loss: 0.6684, Validation Loss: 1.9070


 68%|██████▊   | 136/200 [04:52<02:21,  2.21s/it]

Epoch 136/200, Training Loss: 0.6664, Validation Loss: 1.8610


 68%|██████▊   | 137/200 [04:54<02:18,  2.19s/it]

Epoch 137/200, Training Loss: 0.6726, Validation Loss: 1.8412


 69%|██████▉   | 138/200 [04:56<02:14,  2.17s/it]

Epoch 138/200, Training Loss: 0.6662, Validation Loss: 1.8453


 70%|██████▉   | 139/200 [04:58<02:11,  2.16s/it]

Epoch 139/200, Training Loss: 0.6689, Validation Loss: 1.9772


 70%|███████   | 140/200 [05:00<02:08,  2.15s/it]

Epoch 140/200, Training Loss: 0.6829, Validation Loss: 1.9134


 70%|███████   | 141/200 [05:02<02:06,  2.14s/it]

Epoch 141/200, Training Loss: 0.6694, Validation Loss: 1.8659


 71%|███████   | 142/200 [05:05<02:03,  2.13s/it]

Epoch 142/200, Training Loss: 0.6679, Validation Loss: 1.8586


 72%|███████▏  | 143/200 [05:07<02:00,  2.12s/it]

Epoch 143/200, Training Loss: 0.6629, Validation Loss: 1.8422


 72%|███████▏  | 144/200 [05:09<01:58,  2.12s/it]

Epoch 144/200, Training Loss: 0.6649, Validation Loss: 1.8185


 72%|███████▎  | 145/200 [05:11<01:57,  2.14s/it]

Epoch 145/200, Training Loss: 0.6649, Validation Loss: 2.1087


 73%|███████▎  | 146/200 [05:13<01:55,  2.14s/it]

Epoch 146/200, Training Loss: 0.6619, Validation Loss: 1.9048


 74%|███████▎  | 147/200 [05:15<01:52,  2.13s/it]

Epoch 147/200, Training Loss: 0.6571, Validation Loss: 1.8710


 74%|███████▍  | 148/200 [05:17<01:50,  2.12s/it]

Epoch 148/200, Training Loss: 0.6703, Validation Loss: 1.9344


 74%|███████▍  | 149/200 [05:19<01:48,  2.12s/it]

Epoch 149/200, Training Loss: 0.6684, Validation Loss: 1.8669


 75%|███████▌  | 150/200 [05:22<01:51,  2.22s/it]

Epoch 150/200, Training Loss: 0.6583, Validation Loss: 1.8402


 76%|███████▌  | 151/200 [05:24<01:46,  2.18s/it]

Epoch 151/200, Training Loss: 0.6723, Validation Loss: 1.8099


 76%|███████▌  | 152/200 [05:26<01:43,  2.16s/it]

Epoch 152/200, Training Loss: 0.6682, Validation Loss: 1.8194


 76%|███████▋  | 153/200 [05:28<01:41,  2.15s/it]

Epoch 153/200, Training Loss: 0.6692, Validation Loss: 1.8006


 77%|███████▋  | 154/200 [05:30<01:38,  2.15s/it]

Epoch 154/200, Training Loss: 0.6604, Validation Loss: 1.9014


 78%|███████▊  | 155/200 [05:32<01:36,  2.14s/it]

Epoch 155/200, Training Loss: 0.6590, Validation Loss: 1.8037


 78%|███████▊  | 156/200 [05:35<01:34,  2.14s/it]

Epoch 156/200, Training Loss: 0.6690, Validation Loss: 1.8225


 78%|███████▊  | 157/200 [05:37<01:31,  2.13s/it]

Epoch 157/200, Training Loss: 0.6609, Validation Loss: 1.7807


 79%|███████▉  | 158/200 [05:39<01:29,  2.14s/it]

Epoch 158/200, Training Loss: 0.6633, Validation Loss: 1.8504


 80%|███████▉  | 159/200 [05:41<01:27,  2.14s/it]

Epoch 159/200, Training Loss: 0.6652, Validation Loss: 1.9025


 80%|████████  | 160/200 [05:43<01:25,  2.13s/it]

Epoch 160/200, Training Loss: 0.6668, Validation Loss: 1.8622


 80%|████████  | 161/200 [05:45<01:23,  2.13s/it]

Epoch 161/200, Training Loss: 0.6699, Validation Loss: 1.8411


 81%|████████  | 162/200 [05:47<01:20,  2.12s/it]

Epoch 162/200, Training Loss: 0.6620, Validation Loss: 1.8730


 82%|████████▏ | 163/200 [05:49<01:18,  2.12s/it]

Epoch 163/200, Training Loss: 0.6649, Validation Loss: 1.8794


 82%|████████▏ | 164/200 [05:52<01:16,  2.13s/it]

Epoch 164/200, Training Loss: 0.6607, Validation Loss: 1.8675


 82%|████████▎ | 165/200 [05:54<01:17,  2.21s/it]

Epoch 165/200, Training Loss: 0.6647, Validation Loss: 1.8036


 83%|████████▎ | 166/200 [05:56<01:14,  2.18s/it]

Epoch 166/200, Training Loss: 0.6591, Validation Loss: 1.8140


 84%|████████▎ | 167/200 [05:58<01:11,  2.16s/it]

Epoch 167/200, Training Loss: 0.6707, Validation Loss: 1.8252


 84%|████████▍ | 168/200 [06:00<01:09,  2.16s/it]

Epoch 168/200, Training Loss: 0.6628, Validation Loss: 1.8977


 84%|████████▍ | 169/200 [06:03<01:06,  2.15s/it]

Epoch 169/200, Training Loss: 0.6700, Validation Loss: 1.8389


 85%|████████▌ | 170/200 [06:05<01:04,  2.15s/it]

Epoch 170/200, Training Loss: 0.6640, Validation Loss: 1.8365


 86%|████████▌ | 171/200 [06:07<01:01,  2.13s/it]

Epoch 171/200, Training Loss: 0.6642, Validation Loss: 1.8392


 86%|████████▌ | 172/200 [06:09<00:59,  2.12s/it]

Epoch 172/200, Training Loss: 0.6566, Validation Loss: 1.8485


 86%|████████▋ | 173/200 [06:11<00:57,  2.13s/it]

Epoch 173/200, Training Loss: 0.6669, Validation Loss: 1.8609


 87%|████████▋ | 174/200 [06:13<00:55,  2.12s/it]

Epoch 174/200, Training Loss: 0.6669, Validation Loss: 1.8320


 88%|████████▊ | 175/200 [06:15<00:52,  2.11s/it]

Epoch 175/200, Training Loss: 0.6600, Validation Loss: 1.9185


 88%|████████▊ | 176/200 [06:17<00:50,  2.11s/it]

Epoch 176/200, Training Loss: 0.6614, Validation Loss: 1.8888


 88%|████████▊ | 177/200 [06:19<00:48,  2.11s/it]

Epoch 177/200, Training Loss: 0.6582, Validation Loss: 1.8678


 89%|████████▉ | 178/200 [06:22<00:46,  2.13s/it]

Epoch 178/200, Training Loss: 0.6554, Validation Loss: 1.8478


 90%|████████▉ | 179/200 [06:24<00:44,  2.11s/it]

Epoch 179/200, Training Loss: 0.6729, Validation Loss: 1.8638


 90%|█████████ | 180/200 [06:26<00:44,  2.21s/it]

Epoch 180/200, Training Loss: 0.6611, Validation Loss: 1.8456


 90%|█████████ | 181/200 [06:28<00:41,  2.18s/it]

Epoch 181/200, Training Loss: 0.6711, Validation Loss: 1.8506


 91%|█████████ | 182/200 [06:30<00:39,  2.17s/it]

Epoch 182/200, Training Loss: 0.6676, Validation Loss: 1.8817


 92%|█████████▏| 183/200 [06:32<00:36,  2.15s/it]

Epoch 183/200, Training Loss: 0.6935, Validation Loss: 1.8281


 92%|█████████▏| 184/200 [06:35<00:34,  2.14s/it]

Epoch 184/200, Training Loss: 0.6654, Validation Loss: 1.8297


 92%|█████████▎| 185/200 [06:37<00:31,  2.13s/it]

Epoch 185/200, Training Loss: 0.6700, Validation Loss: 1.8387


 93%|█████████▎| 186/200 [06:39<00:29,  2.13s/it]

Epoch 186/200, Training Loss: 0.6664, Validation Loss: 1.9240


 94%|█████████▎| 187/200 [06:41<00:27,  2.13s/it]

Epoch 187/200, Training Loss: 0.6677, Validation Loss: 1.8955


 94%|█████████▍| 188/200 [06:43<00:25,  2.14s/it]

Epoch 188/200, Training Loss: 0.6733, Validation Loss: 1.8306


 94%|█████████▍| 189/200 [06:45<00:23,  2.12s/it]

Epoch 189/200, Training Loss: 0.6663, Validation Loss: 1.8004


 95%|█████████▌| 190/200 [06:47<00:21,  2.12s/it]

Epoch 190/200, Training Loss: 0.6639, Validation Loss: 1.9091


 96%|█████████▌| 191/200 [06:49<00:19,  2.13s/it]

Epoch 191/200, Training Loss: 0.6791, Validation Loss: 1.8660


 96%|█████████▌| 192/200 [06:52<00:17,  2.14s/it]

Epoch 192/200, Training Loss: 0.6585, Validation Loss: 1.8121


 96%|█████████▋| 193/200 [06:54<00:14,  2.14s/it]

Epoch 193/200, Training Loss: 0.6649, Validation Loss: 1.8598


 97%|█████████▋| 194/200 [06:56<00:12,  2.15s/it]

Epoch 194/200, Training Loss: 0.6673, Validation Loss: 1.8182


 98%|█████████▊| 195/200 [06:58<00:11,  2.22s/it]

Epoch 195/200, Training Loss: 0.6575, Validation Loss: 1.8789


 98%|█████████▊| 196/200 [07:00<00:08,  2.19s/it]

Epoch 196/200, Training Loss: 0.6575, Validation Loss: 1.8442


 98%|█████████▊| 197/200 [07:03<00:06,  2.16s/it]

Epoch 197/200, Training Loss: 0.6624, Validation Loss: 1.8817


 99%|█████████▉| 198/200 [07:05<00:04,  2.16s/it]

Epoch 198/200, Training Loss: 0.6726, Validation Loss: 1.8658


100%|█████████▉| 199/200 [07:07<00:02,  2.15s/it]

Epoch 199/200, Training Loss: 0.6562, Validation Loss: 1.8716


100%|██████████| 200/200 [07:09<00:00,  2.15s/it]

Epoch 200/200, Training Loss: 0.6614, Validation Loss: 1.8594





In [135]:
# test the model
zinc_model_one_hot.eval()
loss_fn = torch.nn.L1Loss()
with torch.no_grad():
    test_loss = 0.0
    for data in test_loader:
        data = data.to(device)
        out = zinc_model_one_hot(data.x, data.edge_index, data.batch)
        loss = loss_fn(out, data.y)
        test_loss += loss.item()
    test_loss /= len(test_loader)            
    print(f"Test Loss: {test_loss:.4f}")

Test Loss: 1.9577


In [147]:
zinc_model_atom_emb = GraphNet(
    in_features=max_num_atoms,
    out_features=1,
    hidden_features=32,
    num_layers=2,
    dropout=0.2,
    reduction="mean",
    pooling="mean",
    task="regression",
    feature_encoder=atom_encoder,
).to(device)

In [148]:
train_zinc(zinc_model_atom_emb, train_loader, val_loader, epochs=200, lr=0.001, weight_decay=5e-2, device=device)

  0%|          | 1/200 [00:02<07:12,  2.17s/it]

Epoch 1/200, Training Loss: 1.1416, Validation Loss: 1.5411


  1%|          | 2/200 [00:04<07:08,  2.17s/it]

Epoch 2/200, Training Loss: 0.7408, Validation Loss: 1.9055


  2%|▏         | 3/200 [00:06<07:08,  2.17s/it]

Epoch 3/200, Training Loss: 0.7259, Validation Loss: 1.8368


  2%|▏         | 4/200 [00:08<07:02,  2.16s/it]

Epoch 4/200, Training Loss: 0.7206, Validation Loss: 1.7890


  2%|▎         | 5/200 [00:10<07:01,  2.16s/it]

Epoch 5/200, Training Loss: 0.7105, Validation Loss: 1.7627


  3%|▎         | 6/200 [00:12<06:59,  2.16s/it]

Epoch 6/200, Training Loss: 0.7005, Validation Loss: 1.8605


  4%|▎         | 7/200 [00:15<06:56,  2.16s/it]

Epoch 7/200, Training Loss: 0.7111, Validation Loss: 1.8861


  4%|▍         | 8/200 [00:17<06:52,  2.15s/it]

Epoch 8/200, Training Loss: 0.7066, Validation Loss: 1.8411


  4%|▍         | 9/200 [00:19<06:48,  2.14s/it]

Epoch 9/200, Training Loss: 0.7045, Validation Loss: 1.8914


  5%|▌         | 10/200 [00:21<06:58,  2.20s/it]

Epoch 10/200, Training Loss: 0.7004, Validation Loss: 1.8039


  6%|▌         | 11/200 [00:23<06:59,  2.22s/it]

Epoch 11/200, Training Loss: 0.7048, Validation Loss: 1.8538


  6%|▌         | 12/200 [00:26<06:54,  2.21s/it]

Epoch 12/200, Training Loss: 0.7178, Validation Loss: 1.8448


  6%|▋         | 13/200 [00:28<06:50,  2.20s/it]

Epoch 13/200, Training Loss: 0.7057, Validation Loss: 1.8480


  7%|▋         | 14/200 [00:30<06:46,  2.18s/it]

Epoch 14/200, Training Loss: 0.7055, Validation Loss: 1.7943


  8%|▊         | 15/200 [00:32<06:44,  2.19s/it]

Epoch 15/200, Training Loss: 0.7124, Validation Loss: 1.8302


  8%|▊         | 16/200 [00:34<06:41,  2.18s/it]

Epoch 16/200, Training Loss: 0.7051, Validation Loss: 1.8516


  8%|▊         | 17/200 [00:37<06:40,  2.19s/it]

Epoch 17/200, Training Loss: 0.7149, Validation Loss: 1.8100


  9%|▉         | 18/200 [00:39<06:36,  2.18s/it]

Epoch 18/200, Training Loss: 0.7001, Validation Loss: 1.8395


 10%|▉         | 19/200 [00:41<06:37,  2.19s/it]

Epoch 19/200, Training Loss: 0.7148, Validation Loss: 1.7920


 10%|█         | 20/200 [00:43<06:32,  2.18s/it]

Epoch 20/200, Training Loss: 0.7056, Validation Loss: 1.8240


 10%|█         | 21/200 [00:45<06:31,  2.18s/it]

Epoch 21/200, Training Loss: 0.7067, Validation Loss: 1.9052


 11%|█         | 22/200 [00:47<06:25,  2.17s/it]

Epoch 22/200, Training Loss: 0.7105, Validation Loss: 1.8210


 12%|█▏        | 23/200 [00:50<06:22,  2.16s/it]

Epoch 23/200, Training Loss: 0.6995, Validation Loss: 1.7994


 12%|█▏        | 24/200 [00:52<06:19,  2.15s/it]

Epoch 24/200, Training Loss: 0.7135, Validation Loss: 1.8050


 12%|█▎        | 25/200 [00:54<06:31,  2.23s/it]

Epoch 25/200, Training Loss: 0.7091, Validation Loss: 1.9210


 13%|█▎        | 26/200 [00:56<06:24,  2.21s/it]

Epoch 26/200, Training Loss: 0.7155, Validation Loss: 1.8579


 14%|█▎        | 27/200 [00:58<06:17,  2.18s/it]

Epoch 27/200, Training Loss: 0.7112, Validation Loss: 1.8506


 14%|█▍        | 28/200 [01:00<06:12,  2.16s/it]

Epoch 28/200, Training Loss: 0.7096, Validation Loss: 1.8667


 14%|█▍        | 29/200 [01:03<06:08,  2.16s/it]

Epoch 29/200, Training Loss: 0.7103, Validation Loss: 1.8500


 15%|█▌        | 30/200 [01:05<06:06,  2.16s/it]

Epoch 30/200, Training Loss: 0.7088, Validation Loss: 1.8053


 16%|█▌        | 31/200 [01:07<06:04,  2.16s/it]

Epoch 31/200, Training Loss: 0.7152, Validation Loss: 1.8443


 16%|█▌        | 32/200 [01:09<06:01,  2.15s/it]

Epoch 32/200, Training Loss: 0.7057, Validation Loss: 1.8559


 16%|█▋        | 33/200 [01:11<05:58,  2.15s/it]

Epoch 33/200, Training Loss: 0.7089, Validation Loss: 1.8148


 17%|█▋        | 34/200 [01:13<05:56,  2.15s/it]

Epoch 34/200, Training Loss: 0.7129, Validation Loss: 1.8535


 18%|█▊        | 35/200 [01:16<05:55,  2.15s/it]

Epoch 35/200, Training Loss: 0.7149, Validation Loss: 1.8765


 18%|█▊        | 36/200 [01:18<05:50,  2.14s/it]

Epoch 36/200, Training Loss: 0.7029, Validation Loss: 1.8457


 18%|█▊        | 37/200 [01:20<05:48,  2.14s/it]

Epoch 37/200, Training Loss: 0.7106, Validation Loss: 1.8653


 19%|█▉        | 38/200 [01:22<05:48,  2.15s/it]

Epoch 38/200, Training Loss: 0.7039, Validation Loss: 1.7952


 20%|█▉        | 39/200 [01:24<05:46,  2.15s/it]

Epoch 39/200, Training Loss: 0.7033, Validation Loss: 1.8106


 20%|██        | 40/200 [01:27<06:00,  2.26s/it]

Epoch 40/200, Training Loss: 0.7074, Validation Loss: 1.8201


 20%|██        | 41/200 [01:29<05:53,  2.22s/it]

Epoch 41/200, Training Loss: 0.7176, Validation Loss: 1.8389


 21%|██        | 42/200 [01:31<05:46,  2.20s/it]

Epoch 42/200, Training Loss: 0.7095, Validation Loss: 1.8358


 22%|██▏       | 43/200 [01:33<05:42,  2.18s/it]

Epoch 43/200, Training Loss: 0.7065, Validation Loss: 1.8020


 22%|██▏       | 44/200 [01:35<05:39,  2.18s/it]

Epoch 44/200, Training Loss: 0.7076, Validation Loss: 1.8412


 22%|██▎       | 45/200 [01:37<05:35,  2.17s/it]

Epoch 45/200, Training Loss: 0.7105, Validation Loss: 1.8390


 23%|██▎       | 46/200 [01:40<05:34,  2.17s/it]

Epoch 46/200, Training Loss: 0.7133, Validation Loss: 1.8425


 24%|██▎       | 47/200 [01:42<05:30,  2.16s/it]

Epoch 47/200, Training Loss: 0.7000, Validation Loss: 1.8488


 24%|██▍       | 48/200 [01:44<05:27,  2.15s/it]

Epoch 48/200, Training Loss: 0.7014, Validation Loss: 1.8085


 24%|██▍       | 49/200 [01:46<05:25,  2.15s/it]

Epoch 49/200, Training Loss: 0.7036, Validation Loss: 1.8456


 25%|██▌       | 50/200 [01:48<05:23,  2.16s/it]

Epoch 50/200, Training Loss: 0.7023, Validation Loss: 1.8673


 26%|██▌       | 51/200 [01:50<05:22,  2.16s/it]

Epoch 51/200, Training Loss: 0.7088, Validation Loss: 1.8017


 26%|██▌       | 52/200 [01:52<05:18,  2.15s/it]

Epoch 52/200, Training Loss: 0.7046, Validation Loss: 1.9489


 26%|██▋       | 53/200 [01:55<05:16,  2.15s/it]

Epoch 53/200, Training Loss: 0.7036, Validation Loss: 1.8635


 27%|██▋       | 54/200 [01:57<05:22,  2.21s/it]

Epoch 54/200, Training Loss: 0.7046, Validation Loss: 1.8865


 28%|██▊       | 55/200 [01:59<05:23,  2.23s/it]

Epoch 55/200, Training Loss: 0.6999, Validation Loss: 1.8075


 28%|██▊       | 56/200 [02:01<05:17,  2.20s/it]

Epoch 56/200, Training Loss: 0.7050, Validation Loss: 1.8191


 28%|██▊       | 57/200 [02:03<05:11,  2.18s/it]

Epoch 57/200, Training Loss: 0.7077, Validation Loss: 1.8112


 29%|██▉       | 58/200 [02:06<05:09,  2.18s/it]

Epoch 58/200, Training Loss: 0.7013, Validation Loss: 1.7906


 30%|██▉       | 59/200 [02:08<05:04,  2.16s/it]

Epoch 59/200, Training Loss: 0.7082, Validation Loss: 1.8502


 30%|███       | 60/200 [02:10<05:02,  2.16s/it]

Epoch 60/200, Training Loss: 0.7072, Validation Loss: 1.8276


 30%|███       | 61/200 [02:12<04:59,  2.16s/it]

Epoch 61/200, Training Loss: 0.7055, Validation Loss: 1.8352


 31%|███       | 62/200 [02:14<04:56,  2.15s/it]

Epoch 62/200, Training Loss: 0.7076, Validation Loss: 1.8150


 32%|███▏      | 63/200 [02:16<04:57,  2.17s/it]

Epoch 63/200, Training Loss: 0.7101, Validation Loss: 1.7858


 32%|███▏      | 64/200 [02:19<04:52,  2.15s/it]

Epoch 64/200, Training Loss: 0.7098, Validation Loss: 1.8112


 32%|███▎      | 65/200 [02:21<04:49,  2.15s/it]

Epoch 65/200, Training Loss: 0.7118, Validation Loss: 1.7857


 33%|███▎      | 66/200 [02:23<04:47,  2.15s/it]

Epoch 66/200, Training Loss: 0.7129, Validation Loss: 1.7720


 34%|███▎      | 67/200 [02:25<04:45,  2.15s/it]

Epoch 67/200, Training Loss: 0.7054, Validation Loss: 1.8323


 34%|███▍      | 68/200 [02:27<04:42,  2.14s/it]

Epoch 68/200, Training Loss: 0.7027, Validation Loss: 1.8340


 34%|███▍      | 69/200 [02:30<04:52,  2.24s/it]

Epoch 69/200, Training Loss: 0.7087, Validation Loss: 1.8103


 35%|███▌      | 70/200 [02:32<04:48,  2.22s/it]

Epoch 70/200, Training Loss: 0.7059, Validation Loss: 1.8426


 36%|███▌      | 71/200 [02:34<04:42,  2.19s/it]

Epoch 71/200, Training Loss: 0.6996, Validation Loss: 1.7917


 36%|███▌      | 72/200 [02:36<04:39,  2.18s/it]

Epoch 72/200, Training Loss: 0.7065, Validation Loss: 1.8199


 36%|███▋      | 73/200 [02:38<04:33,  2.16s/it]

Epoch 73/200, Training Loss: 0.7125, Validation Loss: 1.8416


 37%|███▋      | 74/200 [02:40<04:31,  2.16s/it]

Epoch 74/200, Training Loss: 0.7027, Validation Loss: 1.8015


 38%|███▊      | 75/200 [02:42<04:28,  2.15s/it]

Epoch 75/200, Training Loss: 0.6982, Validation Loss: 1.8503


 38%|███▊      | 76/200 [02:45<04:26,  2.15s/it]

Epoch 76/200, Training Loss: 0.7013, Validation Loss: 1.8495


 38%|███▊      | 77/200 [02:47<04:24,  2.15s/it]

Epoch 77/200, Training Loss: 0.6969, Validation Loss: 1.8291


 39%|███▉      | 78/200 [02:49<04:22,  2.15s/it]

Epoch 78/200, Training Loss: 0.7003, Validation Loss: 1.8409


 40%|███▉      | 79/200 [02:51<04:18,  2.14s/it]

Epoch 79/200, Training Loss: 0.7036, Validation Loss: 1.8183


 40%|████      | 80/200 [02:53<04:15,  2.13s/it]

Epoch 80/200, Training Loss: 0.6986, Validation Loss: 1.8015


 40%|████      | 81/200 [02:55<04:15,  2.15s/it]

Epoch 81/200, Training Loss: 0.7110, Validation Loss: 1.8294


 41%|████      | 82/200 [02:57<04:14,  2.15s/it]

Epoch 82/200, Training Loss: 0.6899, Validation Loss: 1.8444


 42%|████▏     | 83/200 [03:00<04:11,  2.15s/it]

Epoch 83/200, Training Loss: 0.7073, Validation Loss: 1.8212


 42%|████▏     | 84/200 [03:02<04:19,  2.24s/it]

Epoch 84/200, Training Loss: 0.6999, Validation Loss: 1.8242


 42%|████▎     | 85/200 [03:04<04:15,  2.22s/it]

Epoch 85/200, Training Loss: 0.6992, Validation Loss: 1.8549


 43%|████▎     | 86/200 [03:06<04:12,  2.21s/it]

Epoch 86/200, Training Loss: 0.7064, Validation Loss: 1.7842


 44%|████▎     | 87/200 [03:08<04:06,  2.18s/it]

Epoch 87/200, Training Loss: 0.7054, Validation Loss: 1.8111


 44%|████▍     | 88/200 [03:11<04:02,  2.16s/it]

Epoch 88/200, Training Loss: 0.7067, Validation Loss: 1.8325


 44%|████▍     | 89/200 [03:13<03:58,  2.15s/it]

Epoch 89/200, Training Loss: 0.6998, Validation Loss: 1.8301


 45%|████▌     | 90/200 [03:15<03:59,  2.17s/it]

Epoch 90/200, Training Loss: 0.7100, Validation Loss: 1.8004


 46%|████▌     | 91/200 [03:17<03:55,  2.16s/it]

Epoch 91/200, Training Loss: 0.6993, Validation Loss: 1.7912


 46%|████▌     | 92/200 [03:19<03:52,  2.15s/it]

Epoch 92/200, Training Loss: 0.6964, Validation Loss: 1.7945


 46%|████▋     | 93/200 [03:21<03:49,  2.14s/it]

Epoch 93/200, Training Loss: 0.7110, Validation Loss: 1.8324


 47%|████▋     | 94/200 [03:23<03:46,  2.13s/it]

Epoch 94/200, Training Loss: 0.6975, Validation Loss: 1.8368


 48%|████▊     | 95/200 [03:26<03:45,  2.15s/it]

Epoch 95/200, Training Loss: 0.7036, Validation Loss: 1.8242


 48%|████▊     | 96/200 [03:28<03:43,  2.15s/it]

Epoch 96/200, Training Loss: 0.7035, Validation Loss: 1.7932


 48%|████▊     | 97/200 [03:30<03:40,  2.14s/it]

Epoch 97/200, Training Loss: 0.7052, Validation Loss: 1.8177


 49%|████▉     | 98/200 [03:32<03:39,  2.15s/it]

Epoch 98/200, Training Loss: 0.7030, Validation Loss: 1.7975


 50%|████▉     | 99/200 [03:35<03:45,  2.23s/it]

Epoch 99/200, Training Loss: 0.7046, Validation Loss: 1.8723


 50%|█████     | 100/200 [03:37<03:40,  2.21s/it]

Epoch 100/200, Training Loss: 0.7055, Validation Loss: 1.8127


 50%|█████     | 101/200 [03:39<03:36,  2.19s/it]

Epoch 101/200, Training Loss: 0.7027, Validation Loss: 1.8245


 51%|█████     | 102/200 [03:41<03:33,  2.18s/it]

Epoch 102/200, Training Loss: 0.7076, Validation Loss: 1.8029


 52%|█████▏    | 103/200 [03:43<03:29,  2.16s/it]

Epoch 103/200, Training Loss: 0.7051, Validation Loss: 1.8312


 52%|█████▏    | 104/200 [03:45<03:27,  2.16s/it]

Epoch 104/200, Training Loss: 0.7044, Validation Loss: 1.8139


 52%|█████▎    | 105/200 [03:47<03:24,  2.15s/it]

Epoch 105/200, Training Loss: 0.7041, Validation Loss: 1.8170


 53%|█████▎    | 106/200 [03:50<03:22,  2.15s/it]

Epoch 106/200, Training Loss: 0.7145, Validation Loss: 1.7887


 54%|█████▎    | 107/200 [03:52<03:20,  2.16s/it]

Epoch 107/200, Training Loss: 0.7073, Validation Loss: 1.7925


 54%|█████▍    | 108/200 [03:54<03:18,  2.16s/it]

Epoch 108/200, Training Loss: 0.7020, Validation Loss: 1.8117


 55%|█████▍    | 109/200 [03:56<03:16,  2.16s/it]

Epoch 109/200, Training Loss: 0.7060, Validation Loss: 1.8018


 55%|█████▌    | 110/200 [03:58<03:13,  2.15s/it]

Epoch 110/200, Training Loss: 0.6963, Validation Loss: 1.8122


 56%|█████▌    | 111/200 [04:00<03:11,  2.15s/it]

Epoch 111/200, Training Loss: 0.7083, Validation Loss: 1.7840


 56%|█████▌    | 112/200 [04:02<03:08,  2.15s/it]

Epoch 112/200, Training Loss: 0.7012, Validation Loss: 1.8283


 56%|█████▋    | 113/200 [04:05<03:11,  2.21s/it]

Epoch 113/200, Training Loss: 0.6943, Validation Loss: 1.8551


 57%|█████▋    | 114/200 [04:07<03:12,  2.23s/it]

Epoch 114/200, Training Loss: 0.6996, Validation Loss: 1.8353


 57%|█████▊    | 115/200 [04:09<03:07,  2.21s/it]

Epoch 115/200, Training Loss: 0.6995, Validation Loss: 1.7931


 58%|█████▊    | 116/200 [04:11<03:03,  2.18s/it]

Epoch 116/200, Training Loss: 0.7073, Validation Loss: 1.8030


 58%|█████▊    | 117/200 [04:13<03:00,  2.17s/it]

Epoch 117/200, Training Loss: 0.7002, Validation Loss: 1.7798


 59%|█████▉    | 118/200 [04:16<02:57,  2.17s/it]

Epoch 118/200, Training Loss: 0.7047, Validation Loss: 1.7966


 60%|█████▉    | 119/200 [04:18<02:55,  2.16s/it]

Epoch 119/200, Training Loss: 0.7054, Validation Loss: 1.8417


 60%|██████    | 120/200 [04:20<02:52,  2.15s/it]

Epoch 120/200, Training Loss: 0.6947, Validation Loss: 1.7936


 60%|██████    | 121/200 [04:22<02:50,  2.16s/it]

Epoch 121/200, Training Loss: 0.6977, Validation Loss: 1.7904


 61%|██████    | 122/200 [04:24<02:47,  2.14s/it]

Epoch 122/200, Training Loss: 0.7001, Validation Loss: 1.8443


 62%|██████▏   | 123/200 [04:26<02:45,  2.15s/it]

Epoch 123/200, Training Loss: 0.6983, Validation Loss: 1.8049


 62%|██████▏   | 124/200 [04:29<02:43,  2.15s/it]

Epoch 124/200, Training Loss: 0.7039, Validation Loss: 1.7957


 62%|██████▎   | 125/200 [04:31<02:40,  2.14s/it]

Epoch 125/200, Training Loss: 0.7066, Validation Loss: 1.8057


 63%|██████▎   | 126/200 [04:33<02:38,  2.14s/it]

Epoch 126/200, Training Loss: 0.6984, Validation Loss: 1.8591


 64%|██████▎   | 127/200 [04:35<02:37,  2.16s/it]

Epoch 127/200, Training Loss: 0.6990, Validation Loss: 1.8226


 64%|██████▍   | 128/200 [04:37<02:40,  2.23s/it]

Epoch 128/200, Training Loss: 0.7006, Validation Loss: 1.7872


 64%|██████▍   | 129/200 [04:40<02:37,  2.21s/it]

Epoch 129/200, Training Loss: 0.7020, Validation Loss: 1.8210


 65%|██████▌   | 130/200 [04:42<02:33,  2.20s/it]

Epoch 130/200, Training Loss: 0.7088, Validation Loss: 1.8344


 66%|██████▌   | 131/200 [04:44<02:31,  2.19s/it]

Epoch 131/200, Training Loss: 0.7079, Validation Loss: 1.8026


 66%|██████▌   | 132/200 [04:46<02:28,  2.18s/it]

Epoch 132/200, Training Loss: 0.7132, Validation Loss: 1.8468


 66%|██████▋   | 133/200 [04:48<02:26,  2.18s/it]

Epoch 133/200, Training Loss: 0.7055, Validation Loss: 1.7958


 67%|██████▋   | 134/200 [04:50<02:22,  2.16s/it]

Epoch 134/200, Training Loss: 0.7050, Validation Loss: 1.8408


 68%|██████▊   | 135/200 [04:53<02:20,  2.17s/it]

Epoch 135/200, Training Loss: 0.6963, Validation Loss: 1.7957


 68%|██████▊   | 136/200 [04:55<02:19,  2.17s/it]

Epoch 136/200, Training Loss: 0.7120, Validation Loss: 1.8008


 68%|██████▊   | 137/200 [04:57<02:16,  2.17s/it]

Epoch 137/200, Training Loss: 0.7111, Validation Loss: 1.8303


 69%|██████▉   | 138/200 [04:59<02:13,  2.16s/it]

Epoch 138/200, Training Loss: 0.7009, Validation Loss: 1.8055


 70%|██████▉   | 139/200 [05:01<02:11,  2.15s/it]

Epoch 139/200, Training Loss: 0.7038, Validation Loss: 1.8434


 70%|███████   | 140/200 [05:03<02:09,  2.15s/it]

Epoch 140/200, Training Loss: 0.7068, Validation Loss: 1.8351


 70%|███████   | 141/200 [05:05<02:07,  2.16s/it]

Epoch 141/200, Training Loss: 0.7023, Validation Loss: 1.8466


 71%|███████   | 142/200 [05:08<02:05,  2.17s/it]

Epoch 142/200, Training Loss: 0.7054, Validation Loss: 1.8129


 72%|███████▏  | 143/200 [05:10<02:08,  2.25s/it]

Epoch 143/200, Training Loss: 0.7013, Validation Loss: 1.8238


 72%|███████▏  | 144/200 [05:12<02:04,  2.22s/it]

Epoch 144/200, Training Loss: 0.6937, Validation Loss: 1.8443


 72%|███████▎  | 145/200 [05:14<02:00,  2.20s/it]

Epoch 145/200, Training Loss: 0.7049, Validation Loss: 1.8084


 73%|███████▎  | 146/200 [05:17<01:58,  2.19s/it]

Epoch 146/200, Training Loss: 0.6955, Validation Loss: 1.7965


 74%|███████▎  | 147/200 [05:19<01:55,  2.17s/it]

Epoch 147/200, Training Loss: 0.7039, Validation Loss: 1.8585


 74%|███████▍  | 148/200 [05:21<01:52,  2.17s/it]

Epoch 148/200, Training Loss: 0.7018, Validation Loss: 1.8089


 74%|███████▍  | 149/200 [05:23<01:49,  2.15s/it]

Epoch 149/200, Training Loss: 0.6981, Validation Loss: 1.8193


 75%|███████▌  | 150/200 [05:25<01:47,  2.16s/it]

Epoch 150/200, Training Loss: 0.7033, Validation Loss: 1.8554


 76%|███████▌  | 151/200 [05:27<01:46,  2.17s/it]

Epoch 151/200, Training Loss: 0.7039, Validation Loss: 1.8057


 76%|███████▌  | 152/200 [05:29<01:43,  2.15s/it]

Epoch 152/200, Training Loss: 0.7010, Validation Loss: 1.7867


 76%|███████▋  | 153/200 [05:32<01:40,  2.14s/it]

Epoch 153/200, Training Loss: 0.7033, Validation Loss: 1.8789


 77%|███████▋  | 154/200 [05:34<01:38,  2.13s/it]

Epoch 154/200, Training Loss: 0.7081, Validation Loss: 1.7880


 78%|███████▊  | 155/200 [05:36<01:36,  2.14s/it]

Epoch 155/200, Training Loss: 0.6962, Validation Loss: 1.8139


 78%|███████▊  | 156/200 [05:38<01:34,  2.14s/it]

Epoch 156/200, Training Loss: 0.7109, Validation Loss: 1.7866


 78%|███████▊  | 157/200 [05:40<01:32,  2.15s/it]

Epoch 157/200, Training Loss: 0.6998, Validation Loss: 1.8283


 79%|███████▉  | 158/200 [05:43<01:33,  2.22s/it]

Epoch 158/200, Training Loss: 0.7054, Validation Loss: 1.8772


 80%|███████▉  | 159/200 [05:45<01:30,  2.22s/it]

Epoch 159/200, Training Loss: 0.7065, Validation Loss: 1.8422


 80%|████████  | 160/200 [05:47<01:27,  2.19s/it]

Epoch 160/200, Training Loss: 0.7013, Validation Loss: 1.7853


 80%|████████  | 161/200 [05:49<01:25,  2.18s/it]

Epoch 161/200, Training Loss: 0.7019, Validation Loss: 1.8288


 81%|████████  | 162/200 [05:51<01:23,  2.19s/it]

Epoch 162/200, Training Loss: 0.6972, Validation Loss: 1.7825


 82%|████████▏ | 163/200 [05:53<01:20,  2.18s/it]

Epoch 163/200, Training Loss: 0.7003, Validation Loss: 1.7791


 82%|████████▏ | 164/200 [05:56<01:18,  2.18s/it]

Epoch 164/200, Training Loss: 0.7003, Validation Loss: 1.8250


 82%|████████▎ | 165/200 [05:58<01:15,  2.17s/it]

Epoch 165/200, Training Loss: 0.6968, Validation Loss: 1.8303


 83%|████████▎ | 166/200 [06:00<01:13,  2.16s/it]

Epoch 166/200, Training Loss: 0.7045, Validation Loss: 1.8089


 84%|████████▎ | 167/200 [06:02<01:11,  2.16s/it]

Epoch 167/200, Training Loss: 0.7031, Validation Loss: 1.7797


 84%|████████▍ | 168/200 [06:04<01:08,  2.15s/it]

Epoch 168/200, Training Loss: 0.6963, Validation Loss: 1.8065


 84%|████████▍ | 169/200 [06:06<01:06,  2.16s/it]

Epoch 169/200, Training Loss: 0.7024, Validation Loss: 1.8840


 85%|████████▌ | 170/200 [06:08<01:04,  2.15s/it]

Epoch 170/200, Training Loss: 0.7002, Validation Loss: 1.8015


 86%|████████▌ | 171/200 [06:11<01:02,  2.14s/it]

Epoch 171/200, Training Loss: 0.7021, Validation Loss: 1.7983


 86%|████████▌ | 172/200 [06:13<01:01,  2.20s/it]

Epoch 172/200, Training Loss: 0.7014, Validation Loss: 1.8322


 86%|████████▋ | 173/200 [06:15<00:59,  2.22s/it]

Epoch 173/200, Training Loss: 0.7080, Validation Loss: 1.7886


 87%|████████▋ | 174/200 [06:17<00:56,  2.19s/it]

Epoch 174/200, Training Loss: 0.7110, Validation Loss: 1.8106


 88%|████████▊ | 175/200 [06:19<00:54,  2.18s/it]

Epoch 175/200, Training Loss: 0.6961, Validation Loss: 1.8114


 88%|████████▊ | 176/200 [06:22<00:51,  2.16s/it]

Epoch 176/200, Training Loss: 0.6972, Validation Loss: 1.8499


 88%|████████▊ | 177/200 [06:24<00:49,  2.15s/it]

Epoch 177/200, Training Loss: 0.6996, Validation Loss: 1.8545


 89%|████████▉ | 178/200 [06:26<00:47,  2.15s/it]

Epoch 178/200, Training Loss: 0.7073, Validation Loss: 1.8181


 90%|████████▉ | 179/200 [06:28<00:45,  2.15s/it]

Epoch 179/200, Training Loss: 0.6965, Validation Loss: 1.8176


 90%|█████████ | 180/200 [06:30<00:42,  2.15s/it]

Epoch 180/200, Training Loss: 0.7054, Validation Loss: 1.7852


 90%|█████████ | 181/200 [06:32<00:40,  2.14s/it]

Epoch 181/200, Training Loss: 0.6978, Validation Loss: 1.8331


 91%|█████████ | 182/200 [06:34<00:38,  2.14s/it]

Epoch 182/200, Training Loss: 0.7019, Validation Loss: 1.8060


 92%|█████████▏| 183/200 [06:37<00:36,  2.15s/it]

Epoch 183/200, Training Loss: 0.7043, Validation Loss: 1.8204


 92%|█████████▏| 184/200 [06:39<00:34,  2.14s/it]

Epoch 184/200, Training Loss: 0.7004, Validation Loss: 1.8083


 92%|█████████▎| 185/200 [06:41<00:32,  2.14s/it]

Epoch 185/200, Training Loss: 0.6989, Validation Loss: 1.8221


 93%|█████████▎| 186/200 [06:43<00:30,  2.15s/it]

Epoch 186/200, Training Loss: 0.7044, Validation Loss: 1.8463


 94%|█████████▎| 187/200 [06:46<00:29,  2.26s/it]

Epoch 187/200, Training Loss: 0.7003, Validation Loss: 1.8129


 94%|█████████▍| 188/200 [06:48<00:26,  2.24s/it]

Epoch 188/200, Training Loss: 0.6996, Validation Loss: 1.8694


 94%|█████████▍| 189/200 [06:50<00:24,  2.22s/it]

Epoch 189/200, Training Loss: 0.6931, Validation Loss: 1.8332


 95%|█████████▌| 190/200 [06:52<00:21,  2.19s/it]

Epoch 190/200, Training Loss: 0.6901, Validation Loss: 1.8091


 96%|█████████▌| 191/200 [06:54<00:19,  2.17s/it]

Epoch 191/200, Training Loss: 0.6952, Validation Loss: 1.8381


 96%|█████████▌| 192/200 [06:56<00:17,  2.18s/it]

Epoch 192/200, Training Loss: 0.6982, Validation Loss: 1.8610


 96%|█████████▋| 193/200 [06:58<00:15,  2.17s/it]

Epoch 193/200, Training Loss: 0.6963, Validation Loss: 1.8416


 97%|█████████▋| 194/200 [07:01<00:12,  2.16s/it]

Epoch 194/200, Training Loss: 0.6951, Validation Loss: 1.8672


 98%|█████████▊| 195/200 [07:03<00:10,  2.15s/it]

Epoch 195/200, Training Loss: 0.7020, Validation Loss: 1.8236


 98%|█████████▊| 196/200 [07:05<00:08,  2.16s/it]

Epoch 196/200, Training Loss: 0.7061, Validation Loss: 1.8618


 98%|█████████▊| 197/200 [07:07<00:06,  2.15s/it]

Epoch 197/200, Training Loss: 0.6941, Validation Loss: 1.8169


 99%|█████████▉| 198/200 [07:09<00:04,  2.15s/it]

Epoch 198/200, Training Loss: 0.7022, Validation Loss: 1.8266


100%|█████████▉| 199/200 [07:11<00:02,  2.15s/it]

Epoch 199/200, Training Loss: 0.7032, Validation Loss: 1.8577


100%|██████████| 200/200 [07:14<00:00,  2.17s/it]

Epoch 200/200, Training Loss: 0.7026, Validation Loss: 1.8092





In [149]:
zinc_model_atom_emb.eval()
loss_fn = torch.nn.L1Loss()
with torch.no_grad():
    test_loss = 0.0
    for data in test_loader:
        data = data.to(device)
        out = zinc_model_atom_emb(data.x, data.edge_index, data.batch)
        loss = loss_fn(out, data.y)
        test_loss += loss.item()
    test_loss /= len(test_loader)            
    print(f"Test Loss: {test_loss:.4f}")

Test Loss: 1.9081
