# Model training in train set and prediction in validation set

reference link: https://wandb.ai/manan-goel/gnn-recommender/reports/Recommending-Amazon-Products-using-Graph-Neural-Networks-in-PyTorch-Geometric--VmlldzozMTA3MzYw

In [None]:
pip install torch_geometric

In [None]:
pip install pyvis

In [5]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric import utils
import torch_geometric as pyg
from tqdm.auto import tqdm
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
from pyvis.network import Network
import matplotlib.pyplot as plt
from itertools import chain
from sklearn.metrics import accuracy_score

In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


### creating a sample set for model

In [6]:
# using 10000 nodes
graph = torch.load('/Users/lokesh/Downloads/Database_433_Project/DataProcessing/amazon0302.pt')
mask = np.zeros(graph.x.shape[0])
mask[:10000] = 1
mask = torch.tensor(mask == 1)
g = Data(x=graph.x[mask], edge_index=utils.subgraph(mask, graph.edge_index)[0])
torch.save(g, '/Users/lokesh/Downloads/Database_433_Project/gnn/model_graph.pt')

In [7]:
# Create a PyVis network
net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white")

# Add nodes and edges to the PyVis network
for e in g.edge_index.T:
    src = e[0].item()
    dst = e[1].item()
    net.add_node(src)
    net.add_node(dst)
    net.add_edge(src, dst, value=0.1)
# Save the PyVis visualization to an HTML file
net.save_graph('/Users/lokesh/Downloads/Database_433_Project/gnn/model_graph.html')

### splitting the sample dataset into the training and validation set

In [8]:
#using RandomLinkSplit to split the graph nodes
graph = torch.load('/Users/lokesh/Downloads/Database_433_Project/gnn/model_graph.pt')
transform = RandomLinkSplit(num_val=0.1, num_test=0, split_labels=True)
train_data, val_data, test_data = transform(graph)

torch.save(train_data, '/Users/lokesh/Downloads/Database_433_Project/gnn/train.pt')
torch.save(val_data, '/Users/lokesh/Downloads/Database_433_Project/gnn/val.pt')

### Define the GNN model and link predictor

In [9]:
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super(GNN, self).__init__()

        conv_model = pyg.nn.SAGEConv

        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim))
        self.dropout = dropout
        self.num_layers = num_layers



        # Create num_layers GraphSAGE convs
        assert (self.num_layers >= 1), 'Number of layers is not >=1'
        for l in range(self.num_layers - 1):
            self.convs.append(conv_model(hidden_dim, hidden_dim))

        # post-message-passing processing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        print(f"Input Shape: {x.shape}")
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)
        print(f"Output Shape: {x.shape}")

        return x

In [10]:
class LinkPredictor(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        # Create linear layers
        self.lins = nn.ModuleList()
        self.lins.append(nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

### Train the model and output training loss

In [11]:
# the training data
train_graph = torch.load('/Users/lokesh/Downloads/Database_433_Project/gnn/train.pt')
# the setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optim_wd = 0
hidden_dim = 1024
dropout = 0.1
num_layers = 2
lr = 1e-5
node_emb_dim = 1
batch_size = 1024

train_graph = train_graph.to(device)

In [12]:
# the graph neural network that takes all the node embeddings as inputs to message pass and agregate
model = GNN(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout).to(device)
link_predictor = LinkPredictor(hidden_dim, hidden_dim, 1, num_layers + 1, dropout).to(device)

In [13]:
# define the train function
def train(model, link_predictor, x, edge_index, pos_train_edge, batch_size, optimizer):
    model.train()
    link_predictor.train()

    train_losses = []

    for edge_id in tqdm(pyg.loader.DataLoader(range(pos_train_edge.shape[0]), batch_size, shuffle=True), leave=True):
        optimizer.zero_grad()

        # Run message passing on the initial node features to get updated embeddings
        node_emb = model(x, edge_index)

        # Positive edges
        pos_edge = pos_train_edge[edge_id].T
        pos_pred = link_predictor(node_emb[pos_edge[0]], node_emb[pos_edge[1]])

        # Negative edges sampling
        neg_edge = negative_sampling(edge_index, num_nodes=x.shape[0], num_neg_samples=edge_id.shape[0], method='dense')
        neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])

        # Compute the corresponding negative log likelihood loss on the positive and negative edges
        loss = -torch.log(pos_pred + 1e-15).mean() - torch.log(1 - neg_pred + 1e-15).mean()

        # Backpropagate and update parameters
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    return train_losses

In [14]:
# define the optimizer
optimizer = torch.optim.Adam(list(model.parameters()) + list(link_predictor.parameters()), lr=lr, weight_decay=optim_wd)

In [15]:
# Training loop
num_epochs = 30
train_losses = []

for epoch in range(num_epochs):
    epoch_loss = train(
        model,
        link_predictor,
        torch.tensor(train_graph.x).to(torch.float32).to(device),
        train_graph.edge_index,
        train_graph.pos_edge_label_index.T,
        batch_size,
        optimizer
    )
    train_losses.extend(epoch_loss)

    if epoch % 10 == 0:
        node_emb = model(torch.tensor(train_graph.x).to(torch.float32).to(device), train_graph.edge_index)
        pos_edge = train_graph.pos_edge_label_index.T
        pos_pred = link_predictor(node_emb[pos_edge[0]], node_emb[pos_edge[1]])
        print(f"Epoch {epoch}, Shape of node_emb: {node_emb.shape}, Shape of pos_pred: {pos_pred.shape}")

  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:16<09:58, 16.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:37<11:08, 19.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [01:05<13:09, 23.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [01:16<10:01, 18.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [01:26<08:14, 15.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [01:37<07:14, 14.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [01:46<06:04, 12.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [01:54<05:15, 10.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [02:02<04:39,  9.97s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [02:11<04:21,  9.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [02:20<04:07,  9.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [02:32<04:14, 10.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [02:42<04:06, 10.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [03:08<05:43, 14.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [03:23<05:30, 15.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [03:34<04:50, 13.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [03:42<04:01, 12.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [03:51<03:28, 10.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [03:58<02:58,  9.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [04:05<02:35,  9.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [04:12<02:15,  8.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [04:19<02:00,  8.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [04:27<01:50,  7.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [04:35<01:42,  7.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [04:45<01:41,  8.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [04:53<01:31,  8.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [05:00<01:21,  8.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [05:07<01:09,  7.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [05:14<01:00,  7.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [05:22<00:52,  7.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [05:29<00:44,  7.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [05:35<00:35,  7.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [05:42<00:28,  7.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [05:49<00:21,  7.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [06:00<00:16,  8.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [06:10<00:08,  8.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [06:15<00:00, 10.14s/it]


Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Epoch 0, Shape of node_emb: torch.Size([10000, 1024]), Shape of pos_pred: torch.Size([2, 1])


  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:24,  5.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:10<02:54,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:14<02:41,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:19<02:33,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:23<02:27,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:28<02:29,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:34<02:28,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:39<02:23,  4.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:43<02:16,  4.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:48<02:08,  4.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:52<02:01,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:57<01:58,  4.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:02<01:52,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:07<01:51,  4.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:12<01:45,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:18<01:47,  5.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:22<01:41,  5.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:30<01:48,  5.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:35<01:40,  5.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:42<01:43,  6.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:47<01:29,  5.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:51<01:20,  5.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:59<01:22,  5.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:04<01:14,  5.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:09<01:06,  5.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:15<01:01,  5.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:20<00:54,  5.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:25<00:47,  5.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:30<00:41,  5.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:35<00:35,  5.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:41<00:32,  5.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:46<00:27,  5.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:51<00:20,  5.20s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:55<00:14,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:00<00:09,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:05<00:04,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:09<00:00,  5.13s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:01,  5.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:44,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:13<02:35,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:18<02:30,  4.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:23<02:32,  4.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:29<02:35,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:34<02:32,  5.08s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:39<02:23,  4.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:43<02:11,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:47<02:03,  4.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:52<01:59,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:56<01:55,  4.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:01<01:54,  4.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:06<01:50,  4.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:11<01:42,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:15<01:36,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:20<01:32,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:25<01:30,  4.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:29<01:24,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:34<01:20,  4.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:40<01:20,  5.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:44<01:12,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:49<01:06,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:54<01:04,  4.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:59<00:56,  4.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:03<00:51,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:08<00:47,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:12<00:41,  4.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:17<00:37,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:22<00:33,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:28<00:29,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:33<00:25,  5.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:38<00:20,  5.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:43<00:15,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:48<00:10,  5.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:53<00:04,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:02<00:00,  4.93s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:17,  5.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:10<02:58,  5.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:15<02:55,  5.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:20<02:45,  5.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:25<02:40,  5.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:30<02:35,  5.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:35<02:27,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:39<02:21,  4.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:45<02:21,  5.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:49<02:12,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:55<02:13,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:00<02:04,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:05<02:04,  5.18s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:10<01:55,  5.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:15<01:49,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:21<01:50,  5.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:28<01:55,  5.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:35<01:58,  6.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:41<01:51,  6.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:47<01:44,  6.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:54<01:39,  6.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [02:00<01:33,  6.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:06<01:26,  6.20s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:12<01:20,  6.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:19<01:16,  6.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:25<01:09,  6.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:33<01:08,  6.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:45<01:13,  8.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:53<01:06,  8.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [03:02<00:58,  8.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [03:09<00:47,  7.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [03:15<00:37,  7.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [03:21<00:28,  7.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [03:28<00:20,  6.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:34<00:13,  6.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:41<00:06,  6.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:47<00:00,  6.16s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<03:50,  6.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:13<04:00,  6.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:19<03:39,  6.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:25<03:24,  6.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:32<03:24,  6.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:39<03:28,  6.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:48<03:42,  7.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:57<03:47,  7.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [01:05<03:47,  8.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [01:14<03:40,  8.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [01:21<03:27,  7.97s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:29<03:16,  7.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:37<03:12,  8.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:46<03:09,  8.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:53<02:53,  7.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [02:00<02:39,  7.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [02:08<02:34,  7.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [02:14<02:19,  7.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [02:20<02:02,  6.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [02:27<01:57,  6.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [02:33<01:45,  6.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [02:39<01:37,  6.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:50<01:50,  7.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:56<01:35,  7.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [03:02<01:21,  6.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [03:09<01:17,  7.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [03:27<01:42, 10.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [03:35<01:26,  9.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [03:42<01:09,  8.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [03:49<00:58,  8.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [03:57<00:48,  8.12s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [04:04<00:38,  7.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [04:10<00:28,  7.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [04:15<00:20,  6.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [04:21<00:12,  6.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [04:26<00:05,  5.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [04:31<00:00,  7.34s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:53,  4.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:45,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:14<02:44,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:20<02:48,  5.12s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:25<02:44,  5.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:29<02:34,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:35<02:40,  5.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:40<02:25,  5.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:44<02:17,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:49<02:13,  4.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:54<02:06,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:59<02:01,  4.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:03<01:53,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:09<01:51,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:13<01:46,  4.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:19<01:43,  4.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:23<01:37,  4.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:28<01:31,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:33<01:28,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:38<01:22,  4.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:43<01:17,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:48<01:13,  4.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:53<01:08,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:58<01:05,  5.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:03<00:59,  4.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:07<00:53,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:13<00:49,  4.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:17<00:44,  4.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:22<00:38,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:28<00:35,  5.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:32<00:29,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:38<00:25,  5.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:43<00:21,  5.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:48<00:15,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:53<00:09,  4.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:58<00:05,  5.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:03<00:00,  4.97s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<03:50,  6.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:12<03:41,  6.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:19<03:43,  6.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:25<03:25,  6.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:30<03:09,  5.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:34<02:45,  5.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:39<02:32,  5.08s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:44<02:27,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:49<02:19,  4.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:55<02:21,  5.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [01:02<02:37,  6.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:08<02:26,  5.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:15<02:31,  6.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:20<02:14,  5.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:25<01:59,  5.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:34<02:22,  6.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:40<02:07,  6.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:45<01:56,  6.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:53<02:00,  6.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [02:02<02:01,  7.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [02:11<02:05,  7.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [02:18<01:52,  7.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:21<01:27,  6.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:24<01:08,  5.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:29<01:01,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:33<00:54,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:37<00:46,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:41<00:37,  4.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:44<00:31,  3.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:47<00:26,  3.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:51<00:22,  3.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:55<00:18,  3.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:58<00:14,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [03:01<00:10,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:04<00:06,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:07<00:03,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:11<00:00,  5.16s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:38,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:07,  3.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:10<01:57,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:53,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:46,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:22<01:58,  3.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:27<02:05,  4.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:32<02:15,  4.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:39<02:25,  5.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:45<02:26,  5.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:52<02:33,  5.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:58<02:31,  6.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:04<02:23,  5.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:10<02:21,  6.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:15<02:07,  5.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:21<02:00,  5.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:25<01:46,  5.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:30<01:38,  5.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:35<01:30,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:39<01:22,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:44<01:18,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:49<01:15,  5.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:56<01:16,  5.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:12<01:51,  8.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:17<01:31,  7.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:22<01:13,  6.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:26<01:00,  6.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:33<00:55,  6.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:37<00:43,  5.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:41<00:35,  5.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:44<00:27,  4.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:48<00:22,  4.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:52<00:17,  4.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:57<00:13,  4.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:00<00:08,  4.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:04<00:04,  4.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:08<00:00,  5.10s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:25,  4.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:14,  3.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:12<02:24,  4.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:16<02:15,  4.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:22<02:31,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:28<02:46,  5.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:36<03:08,  6.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:46<03:35,  7.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:51<03:03,  6.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:55<02:33,  5.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:59<02:13,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:02<01:57,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:06<01:48,  4.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:10<01:40,  4.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:16<01:43,  4.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:20<01:32,  4.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:24<01:25,  4.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:28<01:19,  4.18s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:31<01:12,  4.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:35<01:07,  3.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:39<01:03,  3.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:43<00:59,  3.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:48<01:00,  4.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:54<01:03,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:06<01:23,  6.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:12<01:12,  6.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:17<01:01,  6.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:22<00:51,  5.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:27<00:44,  5.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:35<00:43,  6.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:39<00:34,  5.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:43<00:25,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:47<00:19,  4.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:51<00:13,  4.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:55<00:08,  4.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:59<00:04,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:06<00:00,  5.03s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:05,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:11<03:26,  5.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:22<04:36,  8.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:26<03:34,  6.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:30<03:03,  5.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:34<02:37,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:38<02:16,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:41<02:02,  4.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:44<01:49,  3.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:47<01:38,  3.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:51<01:31,  3.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:54<01:29,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:57<01:23,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:10<02:19,  6.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:18<02:26,  6.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:24<02:18,  6.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:31<02:13,  6.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:39<02:13,  7.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:46<02:07,  7.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:50<01:43,  6.08s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:54<01:27,  5.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:58<01:15,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:04<01:15,  5.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:08<01:03,  4.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:11<00:54,  4.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:15<00:46,  4.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:19<00:40,  4.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:22<00:34,  3.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:25<00:29,  3.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:29<00:25,  3.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:32<00:21,  3.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:36<00:18,  3.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:40<00:14,  3.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:43<00:10,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:47<00:07,  3.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:51<00:03,  3.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:54<00:00,  4.72s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:03<02:00,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:06<02:02,  3.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:11<02:07,  3.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:57,  3.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:53,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:21<01:55,  3.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:25<01:46,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:28<01:39,  3.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:32<01:42,  3.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:35<01:36,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:38<01:29,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:42<01:27,  3.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:45<01:22,  3.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:49<01:18,  3.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:52<01:13,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [00:55<01:11,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [00:59<01:06,  3.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:02<01:06,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:06<01:03,  3.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:10<01:01,  3.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:13<00:56,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:17<00:53,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:21<00:50,  3.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:24<00:45,  3.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:27<00:41,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:31<00:39,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:34<00:34,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [01:37<00:30,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [01:41<00:26,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [01:44<00:23,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [01:47<00:20,  3.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [01:51<00:16,  3.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [01:54<00:12,  3.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [01:57<00:09,  3.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:01<00:06,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:06<00:03,  3.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:10<00:00,  3.53s/it]


Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Epoch 10, Shape of node_emb: torch.Size([10000, 1024]), Shape of pos_pred: torch.Size([2, 1])


  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:51,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:14,  3.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:11<02:00,  3.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:54,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:46,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:20<01:43,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:24<01:41,  3.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:27<01:36,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:30<01:32,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:34<01:31,  3.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:37<01:25,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:40<01:21,  3.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:44<01:21,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:47<01:16,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:50<01:12,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [00:54<01:10,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [00:58<01:09,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:01<01:05,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:04<01:00,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:07<00:55,  3.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:11<00:53,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:14<00:49,  3.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:17<00:45,  3.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:21<00:43,  3.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:24<00:39,  3.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:27<00:36,  3.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:31<00:33,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [01:34<00:30,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [01:38<00:28,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [01:41<00:23,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [01:45<00:20,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [01:48<00:16,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [01:51<00:13,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [01:54<00:09,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [01:58<00:06,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:01<00:03,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:04<00:00,  3.38s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:44,  4.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:13,  3.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:11<01:59,  3.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:52,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:50,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:21<01:47,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:25<01:55,  3.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:29<01:46,  3.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:33<01:43,  3.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:36<01:36,  3.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:39<01:33,  3.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:42<01:26,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:46<01:22,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:49<01:16,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:53<01:14,  3.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [00:56<01:12,  3.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [00:59<01:06,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:03<01:04,  3.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:06<01:00,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:09<00:55,  3.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:12<00:51,  3.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:16<00:52,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:20<00:48,  3.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:23<00:43,  3.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:26<00:41,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:30<00:37,  3.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:33<00:34,  3.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [01:36<00:30,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [01:40<00:27,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [01:44<00:24,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [01:47<00:20,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [01:51<00:17,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [01:54<00:13,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [01:58<00:10,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:01<00:07,  3.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:05<00:03,  3.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:08<00:00,  3.48s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:03<02:09,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:03,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:10<02:02,  3.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:57,  3.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:55,  3.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:21<01:53,  3.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:24<01:45,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:28<01:38,  3.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:31<01:35,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:34<01:31,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:38<01:28,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:41<01:23,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:45<01:24,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:50<01:31,  4.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:54<01:27,  3.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [00:59<01:30,  4.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:04<01:29,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:09<01:30,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:13<01:19,  4.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:18<01:17,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:24<01:21,  5.10s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:31<01:24,  5.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:37<01:21,  5.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:45<01:22,  6.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:49<01:09,  5.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:54<00:59,  5.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:03<01:06,  6.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:08<00:53,  5.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:12<00:43,  5.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:16<00:34,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:23<00:34,  5.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:29<00:28,  5.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:33<00:21,  5.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:37<00:14,  4.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:42<00:09,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:47<00:04,  4.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:51<00:00,  4.63s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<04:09,  6.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:15<04:37,  7.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:19<03:22,  5.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:23<02:52,  5.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:27<02:30,  4.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:30<02:10,  4.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:33<01:56,  3.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:38<02:01,  4.18s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:46<02:32,  5.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:51<02:19,  5.18s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:55<02:11,  5.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:00<02:04,  4.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:07<02:15,  5.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:11<01:53,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:14<01:40,  4.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:19<01:35,  4.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:23<01:31,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:27<01:24,  4.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:33<01:24,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:37<01:18,  4.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:41<01:11,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:46<01:08,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:51<01:06,  4.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:57<01:06,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:03<01:03,  5.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:07<00:53,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:11<00:45,  4.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:16<00:42,  4.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:20<00:37,  4.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:26<00:35,  5.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:31<00:29,  4.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:36<00:24,  4.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:41<00:19,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:45<00:14,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:50<00:09,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:54<00:04,  4.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:59<00:00,  4.86s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:31,  4.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:08<02:39,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:13<02:28,  4.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:18<02:38,  4.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:22<02:25,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:27<02:18,  4.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:30<02:08,  4.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:34<01:55,  3.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:37<01:45,  3.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:40<01:36,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:44<01:32,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:47<01:26,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:51<01:26,  3.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:55<01:28,  3.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:59<01:21,  3.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:02<01:18,  3.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:06<01:14,  3.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:20<02:10,  6.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:25<01:50,  6.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:30<01:38,  5.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:37<01:40,  6.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:42<01:27,  5.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:47<01:17,  5.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:52<01:10,  5.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:56<01:01,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:01<00:53,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:06<00:49,  4.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:10<00:43,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:16<00:42,  5.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:22<00:38,  5.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:29<00:34,  5.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:34<00:27,  5.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:39<00:21,  5.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:56<00:26,  8.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:03<00:16,  8.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:09<00:07,  7.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:14<00:00,  5.25s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<03:45,  6.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:12<03:38,  6.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:17<03:12,  5.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:21<02:52,  5.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:27<02:46,  5.20s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:32<02:37,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:45<03:52,  7.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:51<03:26,  7.12s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:54<02:49,  6.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:59<02:35,  5.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [01:05<02:24,  5.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:09<02:12,  5.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:13<01:54,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:20<02:08,  5.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:26<02:04,  5.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:44<03:17,  9.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:50<02:48,  8.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:56<02:26,  7.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [02:02<02:07,  7.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [02:10<02:03,  7.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [02:16<01:53,  7.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [02:21<01:37,  6.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:26<01:24,  6.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:31<01:13,  5.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:50<01:55,  9.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:57<01:36,  8.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [03:02<01:18,  7.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [03:09<01:07,  7.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [03:15<00:56,  7.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [03:20<00:44,  6.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [03:27<00:39,  6.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [03:31<00:29,  5.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [03:36<00:22,  5.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [03:41<00:16,  5.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:46<00:10,  5.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:51<00:05,  5.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:56<00:00,  6.39s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:57,  4.94s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:12<03:39,  6.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:17<03:20,  5.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:24<03:27,  6.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:29<03:08,  5.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:34<02:53,  5.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:39<02:37,  5.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:43<02:26,  5.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:49<02:22,  5.08s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:52<02:04,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:56<01:52,  4.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:00<01:46,  4.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:03<01:36,  4.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:07<01:28,  3.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:12<01:31,  4.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:15<01:23,  3.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:19<01:17,  3.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:22<01:11,  3.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:26<01:05,  3.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:29<01:01,  3.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:33<00:57,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:36<00:53,  3.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:40<00:51,  3.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:44<00:48,  3.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:48<00:44,  3.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:51<00:40,  3.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:56<00:38,  3.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:05<00:48,  5.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:12<00:47,  5.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:18<00:41,  5.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:23<00:35,  5.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:31<00:32,  6.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:37<00:24,  6.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:42<00:17,  5.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:48<00:11,  5.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:53<00:05,  5.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:57<00:00,  4.80s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:56,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:45,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:14<02:44,  4.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:18<02:22,  4.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:21<02:11,  4.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:26<02:16,  4.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:34<02:44,  5.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:39<02:32,  5.25s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:44<02:23,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:48<02:14,  4.97s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:54<02:16,  5.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:59<02:10,  5.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:04<02:05,  5.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:09<01:53,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:13<01:46,  4.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:18<01:38,  4.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:25<01:50,  5.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:29<01:35,  5.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:33<01:25,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:37<01:16,  4.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:41<01:07,  4.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:45<01:02,  4.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:49<00:59,  4.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:53<00:53,  4.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:57<00:47,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:00<00:42,  3.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:04<00:37,  3.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:08<00:34,  3.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:12<00:30,  3.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:17<00:29,  4.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:21<00:24,  4.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:24<00:19,  3.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:28<00:16,  4.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:32<00:11,  3.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:36<00:07,  3.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:40<00:03,  3.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:44<00:00,  4.45s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:28,  4.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:09,  3.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:11<02:05,  3.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:58,  3.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:18<01:54,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:21<01:49,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:25<01:45,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:30<01:55,  3.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:34<01:56,  4.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:39<01:57,  4.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:45<02:02,  4.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:49<01:56,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:54<01:52,  4.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:58<01:41,  4.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:01<01:33,  4.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:05<01:25,  4.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:09<01:18,  3.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:12<01:11,  3.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:19<01:22,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:23<01:16,  4.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:27<01:12,  4.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:32<01:07,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:37<01:03,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:41<00:59,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:45<00:53,  4.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:50<00:49,  4.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:55<00:46,  4.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:01<00:45,  5.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:05<00:38,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:10<00:33,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:14<00:28,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:19<00:23,  4.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:23<00:18,  4.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:28<00:13,  4.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:32<00:09,  4.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:37<00:04,  4.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:42<00:00,  4.38s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:49,  4.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:08<02:31,  4.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:13<02:29,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:17<02:28,  4.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:22<02:20,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:26<02:18,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:31<02:14,  4.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:36<02:14,  4.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:40<02:08,  4.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:45<02:07,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:50<02:04,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:55<02:00,  4.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:59<01:52,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:04<01:44,  4.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:09<01:45,  4.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:14<01:39,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:18<01:32,  4.65s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:24<01:34,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:29<01:29,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:33<01:22,  4.87s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:38<01:18,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:46<01:26,  5.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:51<01:17,  5.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:56<01:07,  5.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:00<01:00,  5.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:05<00:52,  4.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:09<00:47,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:14<00:42,  4.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:19<00:38,  4.77s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:23<00:32,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:28<00:27,  4.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:31<00:21,  4.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:34<00:15,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:39<00:12,  4.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:43<00:08,  4.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:48<00:04,  4.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:52<00:00,  4.67s/it]


Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Epoch 20, Shape of node_emb: torch.Size([10000, 1024]), Shape of pos_pred: torch.Size([2, 1])


  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:03,  5.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:47,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:13<02:33,  4.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:18<02:35,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:23<02:34,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:28<02:29,  4.82s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:33<02:22,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:37<02:10,  4.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:42<02:08,  4.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:46<02:02,  4.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:51<01:57,  4.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:56<01:56,  4.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:00<01:53,  4.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:06<01:51,  4.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:09<01:39,  4.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:13<01:31,  4.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:17<01:24,  4.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:22<01:21,  4.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:26<01:18,  4.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:30<01:08,  4.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:34<01:08,  4.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:39<01:04,  4.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:44<01:04,  4.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:49<01:00,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:53<00:56,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:58<00:51,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:02<00:45,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:08<00:43,  4.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:12<00:38,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:16<00:31,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:21<00:27,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:26<00:23,  4.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:31<00:18,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:37<00:15,  5.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:41<00:09,  4.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:46<00:04,  4.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:51<00:00,  4.62s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:05,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:11<03:31,  6.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:17<03:14,  5.73s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:22<03:00,  5.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:27<02:50,  5.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:31<02:35,  5.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:36<02:27,  4.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:40<02:17,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:45<02:09,  4.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:49<02:01,  4.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:59<02:44,  6.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:05<02:30,  6.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:10<02:16,  5.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:15<02:05,  5.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:20<02:01,  5.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:25<01:49,  5.20s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:29<01:40,  5.04s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:35<01:40,  5.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:39<01:28,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:44<01:22,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:48<01:15,  4.69s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:53<01:12,  4.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:59<01:12,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:04<01:05,  5.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:08<00:56,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:12<00:48,  4.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:15<00:41,  4.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:22<00:43,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:27<00:38,  4.83s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:31<00:32,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:36<00:29,  4.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:40<00:23,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:44<00:17,  4.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:47<00:12,  4.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:51<00:07,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:54<00:03,  3.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:58<00:00,  4.82s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:35,  4.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:12,  3.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:11<02:01,  3.58s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:51,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:52,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:23<02:11,  4.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:29<02:23,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:35<02:34,  5.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:40<02:20,  5.01s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:44<02:05,  4.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:47<01:51,  4.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:50<01:40,  4.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:54<01:30,  3.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:57<01:24,  3.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:00<01:18,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:04<01:15,  3.57s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:07<01:09,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:10<01:04,  3.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:14<00:59,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:17<00:58,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:20<00:53,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:24<00:50,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:27<00:46,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:31<00:44,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:34<00:41,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:38<00:37,  3.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:41<00:33,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [01:44<00:30,  3.37s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [01:47<00:26,  3.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [01:51<00:22,  3.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [01:54<00:19,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [01:57<00:16,  3.33s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:01<00:13,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:05<00:10,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:08<00:07,  3.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:11<00:03,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:15<00:00,  3.65s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:03<02:03,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:06<01:58,  3.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:10<02:01,  3.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:14<01:56,  3.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:51,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:20<01:47,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:24<01:43,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:27<01:40,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:31<01:36,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:34<01:35,  3.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:38<01:30,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:41<01:27,  3.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:45<01:21,  3.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:48<01:20,  3.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [00:51<01:14,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [00:55<01:12,  3.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [00:58<01:08,  3.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:02<01:03,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:05<01:00,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:09<00:58,  3.46s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:15<01:07,  4.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:18<01:00,  4.07s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:21<00:53,  3.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:25<00:46,  3.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:28<00:42,  3.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:31<00:37,  3.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:34<00:33,  3.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [01:38<00:29,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [01:41<00:26,  3.34s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [01:44<00:23,  3.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [01:48<00:19,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [01:51<00:16,  3.28s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [01:54<00:13,  3.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [01:58<00:10,  3.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:01<00:06,  3.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:04<00:03,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:07<00:00,  3.45s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:03<02:13,  3.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:07<02:01,  3.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:10<01:51,  3.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:13<01:52,  3.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:17<01:52,  3.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:21<01:58,  3.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:29<02:30,  5.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:33<02:21,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:38<02:13,  4.75s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:42<02:04,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:46<01:57,  4.51s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:51<01:50,  4.43s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:55<01:43,  4.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:59<01:39,  4.31s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:03<01:34,  4.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:08<01:32,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:12<01:28,  4.42s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:17<01:23,  4.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:21<01:18,  4.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:25<01:14,  4.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:30<01:10,  4.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:35<01:08,  4.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:39<01:03,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:44<00:58,  4.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:48<00:54,  4.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:54<00:52,  4.80s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [01:58<00:47,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:03<00:42,  4.67s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:07<00:36,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:12<00:31,  4.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:16<00:26,  4.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:20<00:22,  4.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:25<00:17,  4.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:29<00:13,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:34<00:08,  4.45s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:38<00:04,  4.41s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:43<00:00,  4.41s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:04<02:45,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:50,  4.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:14<02:45,  4.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:19<02:37,  4.78s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:24<02:41,  5.03s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:29<02:39,  5.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:34<02:27,  4.93s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:38<02:17,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:43<02:10,  4.66s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:55<03:12,  7.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [01:00<02:41,  6.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [01:03<02:14,  5.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:06<01:53,  4.74s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:10<01:41,  4.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:14<01:31,  4.17s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:19<01:35,  4.54s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:24<01:32,  4.64s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:31<01:39,  5.24s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:36<01:34,  5.26s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:40<01:24,  4.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:46<01:26,  5.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:55<01:35,  6.40s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [02:00<01:22,  5.92s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [02:09<01:30,  6.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [02:19<01:33,  7.79s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [02:25<01:18,  7.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:31<01:09,  6.90s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:35<00:55,  6.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:40<00:44,  5.56s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:45<00:37,  5.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:51<00:33,  5.63s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:56<00:27,  5.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [03:02<00:22,  5.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [03:06<00:15,  5.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [03:10<00:09,  4.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [03:15<00:04,  4.96s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [03:21<00:00,  5.45s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<04:01,  6.71s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:11<03:07,  5.35s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:15<02:52,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:21<03:00,  5.48s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:26<02:48,  5.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:31<02:41,  5.22s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:36<02:34,  5.14s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:41<02:24,  4.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:45<02:09,  4.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:49<01:58,  4.39s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:53<01:49,  4.23s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:57<01:45,  4.21s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [01:01<01:38,  4.12s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [01:05<01:34,  4.11s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:09<01:29,  4.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:14<01:35,  4.53s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:18<01:27,  4.38s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:24<01:29,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:30<01:32,  5.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:35<01:25,  5.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [01:39<01:16,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [01:43<01:07,  4.47s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [01:47<01:00,  4.29s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [01:50<00:53,  4.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [01:54<00:48,  4.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [01:59<00:45,  4.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [02:02<00:40,  4.05s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [02:08<00:40,  4.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [02:12<00:35,  4.50s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [02:16<00:30,  4.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [02:20<00:25,  4.19s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [02:27<00:25,  5.00s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [02:32<00:20,  5.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [02:37<00:14,  4.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [02:41<00:09,  4.76s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [02:46<00:04,  4.70s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [02:50<00:00,  4.61s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:06<03:45,  6.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:43,  4.68s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:12<02:11,  3.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:17<02:12,  4.02s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:20<02:07,  3.98s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:24<02:01,  3.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:28<01:56,  3.89s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:32<01:52,  3.88s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:36<01:48,  3.86s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:40<01:46,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:44<01:42,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 32%|███▏      | 12/37 [00:48<01:38,  3.95s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 35%|███▌      | 13/37 [00:52<01:33,  3.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 38%|███▊      | 14/37 [00:58<01:45,  4.60s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 41%|████      | 15/37 [01:01<01:33,  4.27s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 43%|████▎     | 16/37 [01:04<01:20,  3.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 46%|████▌     | 17/37 [01:07<01:10,  3.52s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 49%|████▊     | 18/37 [01:10<01:05,  3.44s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 51%|█████▏    | 19/37 [01:13<00:59,  3.30s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 54%|█████▍    | 20/37 [01:17<01:00,  3.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 57%|█████▋    | 21/37 [1:10:27<5:32:48, 1248.06s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 59%|█████▉    | 22/37 [1:10:53<3:40:18, 881.26s/it] 

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 62%|██████▏   | 23/37 [1:11:02<2:24:34, 619.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 65%|██████▍   | 24/37 [1:11:08<1:34:22, 435.59s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 68%|██████▊   | 25/37 [1:11:14<1:01:18, 306.55s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 70%|███████   | 26/37 [1:17:26<59:48, 326.24s/it]  

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 73%|███████▎  | 27/37 [1:17:33<38:23, 230.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 76%|███████▌  | 28/37 [1:17:54<25:08, 167.61s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 78%|███████▊  | 29/37 [1:17:59<15:49, 118.72s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 81%|████████  | 30/37 [1:18:03<09:50, 84.33s/it] 

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 84%|████████▍ | 31/37 [1:18:06<05:59, 59.99s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 86%|████████▋ | 32/37 [1:18:09<03:34, 42.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 89%|████████▉ | 33/37 [1:18:12<02:04, 31.09s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 92%|█████████▏| 34/37 [1:18:16<01:08, 22.97s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 95%|█████████▍| 35/37 [1:18:20<00:34, 17.18s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 97%|█████████▋| 36/37 [1:18:24<00:13, 13.16s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


100%|██████████| 37/37 [1:34:58<00:00, 154.01s/it]
  0%|          | 0/37 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  3%|▎         | 1/37 [00:05<03:04,  5.13s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  5%|▌         | 2/37 [00:09<02:53,  4.97s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  8%|▊         | 3/37 [00:14<02:44,  4.84s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 11%|█         | 4/37 [00:21<03:11,  5.81s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 14%|█▎        | 5/37 [00:25<02:37,  4.91s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 16%|█▌        | 6/37 [00:28<02:15,  4.36s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 19%|█▉        | 7/37 [00:32<02:04,  4.15s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 22%|██▏       | 8/37 [00:35<01:51,  3.85s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 24%|██▍       | 9/37 [00:38<01:41,  3.62s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 27%|██▋       | 10/37 [00:41<01:34,  3.49s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


 30%|██▉       | 11/37 [00:44<01:26,  3.32s/it]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


In [1]:
# Plotting the training losses
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.legend()
plt.show()

NameError: name 'plt' is not defined

### accuracy and prediction in validation set

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

val_graph = torch.load('/content/drive/MyDrive/Graph /val.pt')
val_graph = val_graph.to(device)

In [16]:
def validate(model, link_predictor, x, edge_index, pos_val_edge, batch_size, max_steps=None):
    model.eval()
    link_predictor.eval()

    val_losses = []
    predictions = []
    ground_truth_labels = []

    # Determine the number of steps for validation
    num_steps = min(len(pos_val_edge), max_steps) if max_steps else len(pos_val_edge)

    if num_steps == 0:
        print("No steps for validation.")
        return val_losses, predictions, ground_truth_labels

    for step in tqdm(range(num_steps), leave=True):
        edge_id = step % len(pos_val_edge)

        # No need to zero_grad or backpropagate during validation
        with torch.no_grad():
            # Run message passing on the initial node features to get updated embeddings
            node_emb = model(x, edge_index)

            # Positive edges
            pos_edge = pos_val_edge[edge_id].T
            pos_pred = link_predictor(node_emb[pos_edge[0]], node_emb[pos_edge[1]])

            # Negative edges sampling (for validation, you might want to use all possible negative edges)
            neg_edge = negative_sampling(edge_index, num_nodes=x.shape[0], num_neg_samples=x.shape[0] - len(pos_val_edge), method='dense')
            neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])

            # Compute the corresponding negative log likelihood loss on the positive and negative edges
            loss = -torch.log(pos_pred + 1e-15).mean() - torch.log(1 - neg_pred + 1e-15).mean()

            val_losses.append(loss.item())

            # Store predictions and ground truth labels for later evaluation
            predictions.extend(torch.sigmoid(pos_pred).cpu().numpy().flatten().tolist())
            ground_truth_labels.extend([1] * len(pos_pred))
            predictions.extend(torch.sigmoid(neg_pred).cpu().numpy().flatten().tolist())
            ground_truth_labels.extend([0] * len(neg_pred))

    return val_losses, predictions, ground_truth_labels

In [17]:
max_validation_steps = 50
val_losses, val_predictions, val_ground_truth = validate(
    model,
    link_predictor,
    torch.tensor(val_graph.x).to(torch.float32).to(device),
    val_graph.edge_index,
    val_graph.pos_edge_label_index.T,
    batch_size,
    max_steps=max_validation_steps
)

  0%|          | 0/50 [00:00<?, ?it/s]

Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])


  pos_edge = pos_val_edge[edge_id].T


Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape: torch.Size([10000, 1])
Output Shape: torch.Size([10000, 1024])
Input Shape:

In [18]:
# Calculate accuracy
val_accuracy = accuracy_score(val_ground_truth, np.round(val_predictions))

# Output validation loss and accuracy
print(f"val Loss: {np.mean(val_losses)}")
print(f"val Accuracy: {val_accuracy}")

val Loss: 1.282436707019806
val Accuracy: 0.0031622317596566524


### using a small graph to show the prediction

In [19]:
#using a small graph
graph = torch.load('/content/drive/MyDrive/Graph /amazon0302.pt')
mask = np.zeros(graph.x.shape[0])
mask[10000:10500] = 1
mask = torch.tensor(mask == 1)

small_graph= Data(x=graph.x[mask], edge_index=utils.subgraph(mask, graph.edge_index)[0])
small_graph.edge_index = small_graph.edge_index - 10000

In [20]:
#adding metadata
metadata= {}
product_data={}
with open('/content/drive/MyDrive/Graph /amazon-meta.txt', 'r', encoding="utf8") as file:
    for _ in range(2):
        next(file)
    for line in file:
        line = line.strip()
        #i=i+1
        #if i>100:
        #    break
        if line:
            try:
                key, value = map(str.strip, line.split(':', 1))
                product_data[key] = value
            except Exception:  #when only value, no key, then pass
                key=''
                value=''
                pass
            #print(key,',', value)
        else:                                              #An empty line indicates the end of one product's data
            if product_data:
                product_id = product_data.get('Id')
                if product_id:
                    metadata[product_id] = product_data
                product_data = {}                          #Reset the current_product dictionary for the next product

In [21]:

# Initialize PyVis network
net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white")
k=0
for e in tqdm(small_graph.edge_index.T):
    k=k+1
    src = e[0].item()
    dst = e[1].item()

    # Get predicted labels for the nodes
    src_predicted_label = val_predictions[src]
    dst_predicted_label = val_predictions[dst]

    # Extract metadata for each node
    src_metadata = metadata.get(str(src), {})
    dst_metadata = metadata.get(str(dst), {})
    if k==1:
        print(src_metadata['group'])
        print(dst_metadata)

    # Add nodes with predicted labels and metadata
    #src_title = f"Node {src}\nPredicted Label: {src_predicted_label:.4f}\n\nMetadata:\n{src_metadata.get('group', 'N/A')}"
    #dst_title = f"Node {dst}\nPredicted Label: {dst_predicted_label:.4f}\n\nMetadata:\n{dst_metadata.get('group', 'N/A')}"
    src_title = f"Node {src}\nMetadata:\n{src_metadata.get('group', 'N/A')}"
    dst_title = f"Node {dst}\nMetadata:\n{dst_metadata.get('group', 'N/A')}"

    net.add_node(src, label=src_title, title=src_title)
    net.add_node(dst, label=dst_title, title=dst_title)

    net.add_edge(src, dst, value=0.1)

# Save the PyVis visualization to an HTML file
net.show("/content/drive/MyDrive/Graph /smaller_graph_with_predictions_and_metadata.html", notebook=False)

  0%|          | 0/470 [00:00<?, ?it/s]

Book
{'Id': '3', 'ASIN': '0486287785', 'title': 'World War II Allied Fighter Planes Trading Cards', 'group': 'Book', 'salesrank': '1270652', 'similar': '0', 'categories': '1', 'reviews': 'total: 1  downloaded: 1  avg rating: 5', '2003-7-10  cutomer': 'A3IDGASRQAW8B2  rating: 5  votes:   2  helpful:   2'}
/content/drive/MyDrive/Graph /smaller_graph_with_predictions_and_metadata.html
