In [1]:
import pandas as pd
import os
import torch
from model.database_util import *

In [2]:
import torch
from model.util import Normalizer

# Label normaliser handed to train() in the training cell.
# Bounds previously tried here: (1, 100), (-3.61192, 12.290855) and (5, 2611).
# NOTE(review): whether these bounds are raw or log-scale min/max is defined by
# model.util.Normalizer — confirm the pair covers the range of this dataset's cost labels.
COST_NORM_LOWER = 8.26
COST_NORM_UPPER = 11.12
cost_norm = Normalizer(COST_NORM_LOWER, COST_NORM_UPPER)

## Define Model

In [3]:
class Args:
    """Hyperparameters for this notebook, exposed as class attributes.

    An instance is created as ``args`` and read as ``args.<name>`` by the
    model-construction cell and by ``train(...)``.
    """

    # --- passed to train() via args ---
    bs = 64              # batch size (reduced from an earlier 1024)
    lr = 0.001
    epochs = 50          # reduced from an earlier 200
    clip_size = 50
    sch_decay = 0.6

    # --- read by the QueryFormer constructor ---
    embed_size = 64
    pred_hid = 128
    ffn_dim = 128
    head_size = 12
    n_layers = 8
    dropout = 0.1

    # --- runtime / output ---
    device = 'cpu'       # was 'cuda:0' when a GPU was available
    newpath = 'job_queries_training'   # directory created right after instantiation
    to_predict = 'cost'
args = Args()

import os

# Ensure the output directory named by args.newpath exists.
# exist_ok=True makes this idempotent on re-run and avoids the exists()/makedirs() race;
# unlike the exists() check, it also fails loudly if a regular file has that name.
os.makedirs(args.newpath, exist_ok=True)

In [4]:
from model.model import QueryFormer

# Build QueryFormer from the hyperparameters in `args`.
# use_sample / use_hist are both switched off (presumably the sample-bitmap and
# histogram inputs — confirm in model/model.py).
model = QueryFormer(
    emb_size=args.embed_size,
    ffn_dim=args.ffn_dim,
    head_size=args.head_size,
    dropout=args.dropout,
    n_layers=args.n_layers,
    use_sample=False,
    use_hist=False,
    pred_hid=args.pred_hid,
)

In [5]:
from model.dataset import PlanTreeDataset

## Load all plan tensors and create a reproducible train/validation split

In [6]:
import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from model.dataset import PlanTreeDataset  # Assuming PlanTreeDataset is defined elsewhere
import json

# Directory holding one serialized plan-tensor dict per query.
tensors_dir = "./job_queries/tensors"

# Keys every tensor file must provide.
TENSOR_KEYS = ("x", "rel_pos", "attn_bias", "heights", "cost_labels", "raw_costs")

if not os.path.exists(tensors_dir):
    raise FileNotFoundError(f"Tensors directory '{tensors_dir}' not found.")

# Sorted so the file order, and therefore the seeded split below, is deterministic.
tensor_files = sorted(os.listdir(tensors_dir))
if not tensor_files:
    raise FileNotFoundError(f"No tensor files found in '{tensors_dir}'.")

# Per-example components; position i in every list refers to the same query.
x_list = []
rel_pos_list = []
attn_bias_list = []
heights_list = []
cost_labels_list = []
raw_costs_list = []
# File name of each successfully loaded example, aligned with the lists above.
# Split indices refer to positions in these lists, NOT to positions in `tensor_files`.
loaded_files = []
skipped_files = []

for tensor_file in tensor_files:
    tensor_path = os.path.join(tensors_dir, tensor_file)
    try:
        # NOTE: torch.load unpickles the file; only load tensor files you produced yourself.
        loaded_tensors = torch.load(tensor_path)
        # Read every field before appending anything, so a missing key cannot leave
        # some lists one element longer than others.
        fields = {key: loaded_tensors[key] for key in TENSOR_KEYS}
    except Exception as e:
        # Best-effort load: report and skip unreadable or incomplete files.
        print(f"Error loading tensor file '{tensor_file}': {e}")
        skipped_files.append(tensor_file)
        continue
    x_list.append(fields["x"])
    rel_pos_list.append(fields["rel_pos"])
    attn_bias_list.append(fields["attn_bias"])
    heights_list.append(fields["heights"])
    cost_labels_list.append(fields["cost_labels"])
    raw_costs_list.append(fields["raw_costs"])
    loaded_files.append(tensor_file)

num_examples = len(x_list)
if num_examples == 0:
    raise ValueError("No valid tensors loaded for dataset creation.")
if skipped_files:
    print(f"Skipped {len(skipped_files)} of {len(tensor_files)} tensor files.")

all_indices = np.arange(num_examples)

# Train/validation split with a fixed seed for reproducibility.
train_indices, val_indices = train_test_split(all_indices, test_size=0.2, random_state=0)


def _make_plan_dataset(indices):
    """Build a PlanTreeDataset from the loaded examples at ``indices``.

    Components are passed positionally in the order this notebook uses:
    count, x, attn_bias, rel_pos, heights, cost_labels, raw_costs.
    """
    return PlanTreeDataset(
        len(indices),
        [x_list[i] for i in indices],
        [attn_bias_list[i] for i in indices],
        [rel_pos_list[i] for i in indices],
        [heights_list[i] for i in indices],
        [cost_labels_list[i] for i in indices],
        [raw_costs_list[i] for i in indices],
    )


train_dataset = _make_plan_dataset(train_indices)
val_dataset = _make_plan_dataset(val_indices)

# Persist the split. File names are taken from `loaded_files`, so they match the saved
# indices even when some files in the directory were skipped.
val_file_names = [loaded_files[i] for i in val_indices]
val_data = {
    "val_indices": val_indices.tolist(),
    "file_names": val_file_names,
}

val_data_file = "./val_data.json"
with open(val_data_file, "w") as f:
    json.dump(val_data, f)

train_file_names = [loaded_files[i] for i in train_indices]
train_data = {
    "train_indices": train_indices.tolist(),
    "file_names": train_file_names,
}

train_data_file = "./train_data.json"
with open(train_data_file, "w") as f:
    json.dump(train_data, f)

print("Training Dataset length:", len(train_dataset))
print("Validation Dataset length:", len(val_dataset))
print(f"Validation data saved to {val_data_file}")
print(f"Training data saved to {train_data_file}")

Training Dataset length: 1865
Validation Dataset length: 467
Validation data saved to ./val_data.json
Training data saved to ./train_data.json


In [7]:
import pandas as pd
import json

def _split_json_to_csv(json_path, index_key, index_column, csv_path, title):
    """Load a split JSON written by the data-loading cell and re-save it as CSV.

    Parameters
    ----------
    json_path : str
        JSON file containing an ``index_key`` list and a ``"file_names"`` list.
    index_key : str
        Key of the index list inside the JSON (e.g. ``"val_indices"``).
    index_column : str
        Column name used for the indices in the DataFrame / CSV.
    csv_path : str
        Destination CSV path (written without the DataFrame index).
    title : str
        Label used in the printed messages, e.g. ``"Validation"``.

    Returns
    -------
    pandas.DataFrame
        Columns ``index_column`` and ``"filename"``.
    """
    with open(json_path, "r") as f:
        split = json.load(f)

    split_df = pd.DataFrame({
        index_column: split[index_key],
        "filename": split["file_names"],
    })

    print(f"{title} DataFrame:")
    print(split_df.head())

    split_df.to_csv(csv_path, index=False)
    print(f"{title} data saved to {csv_path}")
    return split_df


# File paths
val_data_file = "./val_data.json"
train_data_file = "./train_data.json"

val_data_df = _split_json_to_csv(val_data_file, "val_indices", "val_index", "./val_data.csv", "Validation")
train_data_df = _split_json_to_csv(train_data_file, "train_indices", "train_index", "./train_data.csv", "Training")


Validation DataFrame:
   val_index                                  filename
0       1056  query_1952_2024-12-03-21.34.02.429279.pt
1       2323   query_993_2024-12-03-21.10.08.618148.pt
2        655  query_1591_2024-12-03-21.23.17.861948.pt
3       2101   query_793_2024-12-03-21.05.49.462784.pt
4       1401  query_2262_2024-12-03-21.44.00.539251.pt
Validation data saved to ./val_data.csv
Training DataFrame:
   train_index                                  filename
0          812  query_1732_2024-12-03-21.27.45.652122.pt
1          233  query_1210_2024-12-03-21.14.50.599373.pt
2         1380  query_2243_2024-12-03-21.43.30.935782.pt
3         1752   query_479_2024-12-03-20.59.07.992844.pt
4         1334  query_2201_2024-12-03-21.42.07.400700.pt
Training data saved to ./train_data.csv


In [8]:
# Number of examples in the training split.
len(train_dataset)

1865

In [9]:
# Slicing check: relies on PlanTreeDataset.__getitem__ accepting a slice (TODO confirm in model/dataset.py).
len(train_dataset[:10])

10

In [10]:
# Train QueryFormer on the training split, also passing the validation split to train().
import numpy as np
import torch.nn as nn
import importlib

from model import trainer
# Reload so edits to model/trainer.py take effect without restarting the kernel;
# the from-import below must come after the reload to pick up the new functions.
importlib.reload(trainer)
from  model.trainer import train_single, train


# Mean-squared-error criterion; presumably applied to cost_norm-normalised labels inside
# train() — confirm in model/trainer.py.
crit = nn.MSELoss()

# Alternative single-dataset entry point, kept for reference (not executed):
# trained_model = train_single(model, dataset, dataset, crit, cost_norm, args)
# Rebinds `model`; the other three names are taken from train()'s return tuple
# (meaning per their names — TODO confirm against model/trainer.py).
# NOTE(review): train() prints "pred_output shape" on every batch (see output below);
# remove that debug print in model/trainer.py to keep the notebook output readable.
model, best_model_path, train_embeddings, val_embeddings = train(model, train_dataset, val_dataset, crit, cost_norm, args)


running epoch: 0


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


Epoch: 0  Avg Loss: 0.0016779799743408172, Time: 123.62960648536682
Median: 1.9086097045622883
Mean: 2.5780137706422686
running epoch: 1


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 2


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 3


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 4


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 5


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 6


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 7


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 8


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 9


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 10


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 11


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 12


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 13


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 14


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 15


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 16


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 17


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 18


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 19


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 20


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


Epoch: 20  Avg Loss: 0.00010610695865289894, Time: 5404.156476259232
Median: 1.072034101979081
Mean: 1.1689281099608375
running epoch: 21


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 22


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 23


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 24


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 25


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 26


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 27


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 28


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 29


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 30


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 31


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 32


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 33


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 34


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 35


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 36


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 37


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 38


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 39


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


running epoch: 40


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


Epoch: 40  Avg Loss: 9.93397390225218e-05, Time: 10057.846321821213
Median: 1.0537982382824767
Mean: 1.1580006845732362
running epoch: 41


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [41566.223 12691.133 41567.69  11989.226 67215.79  12957.274 67199.
 60396.934 41562.816 11987.314 67207.14  11987.235 60396.824 60398.895
 67216.12  60398.434 41563.883 67216.69  60398.32  23163.676 60397.516
 41564.363 15488.422 60395.266 67215.85  67226.625 11987.875 67215.67
 12669.886 67216.18  60397.168 60397.688 41564.32  60397.688 67242.016
 16652.63  67217.336 41565.348 11986.915 60397.688 60397.75  67199.
 67318.12  60397.23  60394.52  41562.258 67215.54  41562.652 60395.555
 14346.281 11987.853 11987.063 12954.101 60396.934 41563.926 41564.277
 60398.38  67216.375 67230.09  67215.67  67214.9   11987.979 41563.96
 11987.818 60396.59  67215.02  60398.32  67215.6   11982.264 11986.366
 67203.3   67215.85  41566.418 11986.572 41566.855 60397.4   11987.27
 67216.88  12951.246 67211.95  67199.516 12955.36  11988.893 12957.547
 11983.155 14348.349 60398.84  67200.086 60397.63  11987.807 67215.92
 11987.784 11987.59  67199.77  23163.543 119

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
running epoch: 42


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [38184.656  11775.601  38185.344  10803.972  67068.01   11977.773
 67069.734  58417.816  38181.85   10803.157  67066.47   10802.745
 58418.2    58420.71   67067.43   58420.434  38182.8    67066.47
 58420.04   23802.004  58420.266  38184.11   14682.657  58414.523
 67066.02   67065.07   10803.302  67070.375  11768.393  67068.33
 58418.09   58420.375  38182.434  58420.152  67062.83   14881.453
 67070.88   38184.07   10802.734  58418.7    58418.145  67067.88
 67061.74   58418.312  58414.188  38181.418  67067.18   38181.49
 58415.977  13110.1875 10803.209  10802.55   11970.488  58417.977
 38181.996  38183.344  58418.594  67069.67   67070.56   67066.47
 67062.19   10803.528  38183.2    10803.385  58416.645  67068.52
 58419.87   67062.89   10797.802  10801.892  67065.26   67064.94
 38185.492  10802.138  38185.02   58418.04   10802.828  67068.01
 11956.603  67067.56   67068.97   11976.312  10804.291  11979.167
 10799.254  13111.461  58420.49   67062.8

ls: [tensor([35153.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([38815.], dtype=torch.float64), tensor([14175.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([39421.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([14187.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([66822.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([35343.], dtype=torch.float64), tensor([66823.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([19798.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([38590.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([66821.], dtype=t

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
running epoch: 43


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [37093.477  14344.764  37094.605  11113.522  61833.34   12195.88
 61832.395  53061.688  37094.434  11113.448  61834.516  11113.173
 53060.88   53063.51   61834.457  53063.465  37093.97   61834.043
 53063.258  28704.484  53061.945  37095.03   16811.672  53058.652
 61834.867  61834.754  11113.236  61834.402  14310.438  61833.16
 53061.742  53061.742  37096.234  53061.84   61835.695  23676.918
 61832.63   37094.04   11113.81   53062.35   53062.6    61834.64
 61836.055  53061.234  53056.883  37094.395  61833.52   37094.15
 53058.297  13070.275  11113.639  11113.184  12192.786  53060.88
 37094.82   37095.     53063.562  61835.58   61833.58   61834.695
 61836.055  11113.734  37094.15   11113.872  53061.08   61833.812
 53062.65   61833.93   11113.12   11113.502  61833.457  61835.344
 37093.83   11113.4795 37094.22   53062.043  11113.311  61833.395
 12189.636  61834.164  61832.332  12193.413  11113.713  12196.4375
 11113.66   13073.865  53063.965  618

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [34884.266  11365.968  34886.695  11031.749  63772.418  13032.749
 63771.684  60796.547  34879.676  11030.876  63774.184  11030.686
 60821.836  60731.07   63775.574  60796.26   34881.273  63775.945
 60801.945  13863.437  60787.273  34883.1    13520.197  60754.01
 63775.21   63770.594  11031.044  63775.21   11353.746  63772.777
 60792.027  60790.46   34878.676  60775.45   63772.777  14648.475
 63774.184  34886.527  11030.655  60755.344  60720.355  63772.113
 63769.68   60748.797  60776.492  34879.81   63772.055  34882.402
 60794.293  12742.847  11031.181  11030.613  13030.4375 60783.62
 34885.2    34881.273  60701.71   63775.273  63776.977  63772.84
 63773.633  11031.212  34881.97   11031.065  60713.41   63773.387
 60780.38   63771.75   11028.435  11030.424  63771.867  63774.6
 34884.     11030.486  34888.19   60775.97   11030.644  63773.996
 13026.672  63773.51   63771.32   13031.096  11031.581  13034.763
 11029.004  12745.702  60761.535  6377

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [35825.457  11315.46   35828.77   11125.1875 66954.51   12912.768
 66950.74   66059.266  35824.6    11122.906  66948.125  11122.97
 66060.22   66058.83   66949.66   66058.64   35825.215  66948.766
 66058.13   22457.984  66058.516  35824.465  13055.04   66057.51
 66949.28   66950.42   11123.373  66949.914  11311.402  66954.695
 66060.65   66058.695  35823.164  66058.26   66947.805  22223.92
 66954.26   35827.914  11123.224  66057.945  66060.91   66950.1
 66946.914  66060.91   66058.45   35824.223  66953.875  35826.312
 66057.     13116.014  11124.254  11122.747  12910.65   66060.02
 35828.324  35824.53   66059.96   66951.89   66954.58   66952.85
 66950.1    11124.231  35826.445  11124.254  66059.45   66949.28
 66059.45   66953.875  11120.499  11123.203  66949.28   66951.06
 35824.33   11122.535  35829.285  66060.086  11122.567  66954.83
 12907.363  66949.15   66950.94   12910.083  11124.656  12913.963
 11121.411  13116.753  66059.58   66948.95 

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [41637.992  12412.546  41633.113  12353.464  66927.89   12406.652
 66927.7    65590.47   41635.45   12351.392  66929.03   12351.474
 65592.16   65593.836  66928.59   65591.9    41633.984  66928.72
 65591.84   22970.398  65593.336  41643.355  12459.833  65586.64
 66928.66   66928.59   12352.109  66928.85   12409.528  66927.83
 65590.27   65593.52   41646.215  65593.02   66929.36   18243.799
 66926.734  41632.04   12351.508  65591.53   65592.72   66928.85
 66929.164  65593.65   65589.15   41639.504  66927.57   41628.383
 65590.02   12666.503  12352.392  12351.273  12405.766  65591.78
 41618.465  41638.707  65591.16   66928.98   66927.7    66926.99
 66929.164  12352.215  41635.688  12351.991  65591.21   66928.59
 65592.41   66927.375  12349.212  12351.309  66928.14   66928.914
 41645.46   12351.249  41633.348  65591.47   12351.309  66927.95
 12403.813  66928.46   66927.51   12405.6    12352.616  12407.173
 12349.989  12669.39   65592.84   66928.6

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [37084.773  13454.765  37066.11   10973.869  66604.94   12846.894
 66602.08   65717.125  37108.375  10972.163  66602.97   10972.54
 65719.375  65717.06   66600.69   65716.125  37107.91   66600.305
 65716.125  26759.908  65718.375  37101.4    14276.797  65712.62
 66601.83   66603.86   10972.78   66602.46   13446.452  66605.195
 65717.06   65718.76   37084.492  65718.195  66604.49   23600.607
 66596.43   37082.902  10972.121  65715.44   65717.88   66603.54
 66605.     65720.016  65716.75   37108.375  66605.     37093.086
 65715.44   12378.619  10973.188  10972.184  12844.639  65718.76
 37071.902  37100.62   65715.44   66603.09   66584.99   66603.35
 66604.37   10973.723  37106.883  10973.502  65717.06   66602.78
 65717.     66605.13   10968.407  10972.833  66603.35   66603.984
 37099.77   10971.106  37070.207  65717.88   10970.897  66604.3
 12840.449  66602.97   66602.27   12845.067  10973.974  12849.406
 10970.081  12380.603  65716.25   66603.9

ls: [tensor([35153.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([38815.], dtype=torch.float64), tensor([14175.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([39421.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([14187.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([66822.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([35343.], dtype=torch.float64), tensor([66823.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([19798.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([38590.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([66821.], dtype=t

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [39787.37   13295.704  39790.938  11684.441  65382.18   12352.286
 65383.617  57709.695  39787.145  11680.765  65386.74   11680.141
 57704.36   57703.758  65387.234  57718.066  39788.13   65387.74
 57720.49   20292.387  57703.64   39784.902  13461.067  57678.17
 65387.234  65385.047  11682.559  65387.05   13282.231  65381.75
 57703.15   57710.137  39784.64   57708.82   65386.176  16756.627
 65385.047  39788.582  11679.449  57707.773  57682.906  65385.43
 65385.047  57685.816  57659.145  39785.85   65381.934  39788.43
 57687.086  12640.125  11681.845  11678.715  12350.708  57700.29
 39790.33   39786.004  57697.203  65386.297  65387.797  65383.43
 65386.113  11680.508  39787.332  11680.597  57688.79   65386.92
 57707.824  65382.18   11668.917  11676.064  65386.668  65384.676
 39783.574  11679.772  39789.113  57709.918  11680.707  65382.367
 12348.423  65386.547  65383.56   12350.86   11682.602  12353.6045
 11671.244  12645.574  57705.242  65385.

pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([9, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([64, 1])


pred_output shape: torch.Size([19, 1])
ps: [35268.89   13455.445  35224.688  12253.904  67124.     13057.194
 67126.11   64274.676  35120.266  12242.937  67125.98   12240.671
 64274.12   64281.293  67125.98   64277.37   35124.42   67125.98
 64279.516  35361.035  64277.492  35242.492  14827.863  64269.77
 67125.22   67125.28   12244.523  67125.15   13451.096  67123.74
 64271.855  64278.1    35179.465  64277.242  67125.28   26790.066
 67125.86   35195.105  12242.434  64278.715  64277.125  67125.28
 67125.15   64275.836  64266.89   35163.535  67123.94   35120.6
 64270.266  12706.599  12242.725  12238.371  13052.152  64273.33
 35124.855  35229.96   64278.598  67124.9    67125.73   67123.55
 67124.96   12242.341  35124.29   12242.469  64275.594  67125.92
 64273.754  67124.     12226.81   12235.861  67125.86   67122.84
 35265.63   12241.43   35267.547  64275.652  12242.294  67123.74
 13045.471  67125.92   67126.11   13052.912  12246.007  13059.747
 12231.078  12708.962  64278.348  67125.86  

ls: [tensor([35153.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([38815.], dtype=torch.float64), tensor([14175.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([39421.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([14187.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([66822.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([35343.], dtype=torch.float64), tensor([66823.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([19798.], dtype=torch.float64), tensor([58512.], dtype=torch.float64), tensor([38590.], dtype=torch.float64), tensor([9951.], dtype=torch.float64), tensor([58513.], dtype=torch.float64), tensor([66821.], dtype=torch.float64), tensor([66821.], dtype=t

In [11]:
# Echo the checkpoint filename recorded by the training run — presumably the
# best-epoch checkpoint (the cells below report "best epoch" embeddings); confirm.
best_model_path

'-3799469494494950040.pt'

In [12]:
# Quick sanity check. Because of the slice, this evaluates to
# min(len(loaded_val_indices), 10) (assuming a list/array), so it only shows
# whether at least 10 validation indices were loaded.
len(loaded_val_indices[:10])

10

In [13]:
# Number of validation embeddings captured; the next cell requires this to equal
# the number of validation indices and file names.
len(val_embeddings)

467

In [14]:
import json

import numpy as np
import pandas as pd

# Inputs come from earlier cells (not shown here):
#   val_embeddings        - one embedding vector per validation query
#   loaded_val_indices    - dataset indices of the validation split
#   loaded_val_file_names - tensor file name for each validation index

# Stack the per-query vectors into a single numpy array.
val_embeddings = np.array(val_embeddings)

# Refuse to build a frame whose rows would be misaligned.
if len(loaded_val_indices) != len(loaded_val_file_names) or len(loaded_val_indices) != len(val_embeddings):
    raise ValueError("The sizes of validation indices, filenames, and embeddings must match.")

# One row per validation query; the embedding column keeps numpy arrays in memory.
val_data_df = pd.DataFrame({
    "val_index": loaded_val_indices,
    "filename": loaded_val_file_names,
    "embedding": list(val_embeddings),
})

# CSV has no array type. to_csv would write str(ndarray), and numpy summarizes any
# array longer than its print threshold (1000 elements by default) with "...", so
# the 1417-dim vectors would be silently truncated in the file. Write each embedding
# as a JSON list instead (readable back with json.loads). The in-memory
# val_data_df is left unchanged.
csv_file_path = "validation_embeddings.csv"
val_data_df.assign(
    embedding=val_data_df["embedding"].map(lambda emb: json.dumps(emb.tolist()))
).to_csv(csv_file_path, index=False)

# Binary copy of the full embedding matrix for efficient reloading with np.load.
embeddings_file_path = "validation_embeddings.npy"
np.save(embeddings_file_path, val_embeddings)

print(f"Validation DataFrame created and saved to {csv_file_path}")
print(f"Embeddings saved to {embeddings_file_path}")

# Print the first few rows for verification
print(val_data_df.head())


Validation DataFrame created and saved to validation_embeddings.csv
Embeddings saved to validation_embeddings.npy
   val_index                                  filename  \
0       1056  query_1952_2024-12-03-21.34.02.429279.pt   
1       2323   query_993_2024-12-03-21.10.08.618148.pt   
2        655  query_1591_2024-12-03-21.23.17.861948.pt   
3       2101   query_793_2024-12-03-21.05.49.462784.pt   
4       1401  query_2262_2024-12-03-21.44.00.539251.pt   

                                           embedding  
0  [-114.55693, -101.882355, -14.314699, -1.86743...  
1  [-104.10589, -100.892784, -15.305394, 19.10723...  
2  [-114.54547, -101.893394, -14.308007, -1.88108...  
3  [-89.35007, -103.53369, -24.114931, 38.514816,...  
4  [-99.47161, -60.7916, -35.936123, -1.4182578, ...  


In [15]:
# Rich-display preview of the first five rows of val_data_df (built in the cell above).
val_data_df.head()

Unnamed: 0,val_index,filename,embedding
0,1056,query_1952_2024-12-03-21.34.02.429279.pt,"[-114.55693, -101.882355, -14.314699, -1.86743..."
1,2323,query_993_2024-12-03-21.10.08.618148.pt,"[-104.10589, -100.892784, -15.305394, 19.10723..."
2,655,query_1591_2024-12-03-21.23.17.861948.pt,"[-114.54547, -101.893394, -14.308007, -1.88108..."
3,2101,query_793_2024-12-03-21.05.49.462784.pt,"[-89.35007, -103.53369, -24.114931, 38.514816,..."
4,1401,query_2262_2024-12-03-21.44.00.539251.pt,"[-99.47161, -60.7916, -35.936123, -1.4182578, ..."


In [16]:
# For the embeddings captured at the best epoch, report how many vectors each
# split holds and the shape of the first vector in each.
n_train = len(train_embeddings)
print(f"Number of training embeddings (best epoch): {n_train}")
first_train_shape = train_embeddings[0].shape
print(f"First training embedding shape: {first_train_shape}")
n_val = len(val_embeddings)
print(f"Number of validation embeddings (best epoch): {n_val}")
first_val_shape = val_embeddings[0].shape
print(f"First validation embedding shape: {first_val_shape}")


Number of training embeddings (best epoch): 1865
First training embedding shape: (1417,)
Number of validation embeddings (best epoch): 467
First validation embedding shape: (1417,)


In [17]:
# Print the values (not the shape) of the first validation embedding; numpy
# abbreviates long arrays with "..." in this view.
print(f"First validation embedding: {val_embeddings[0]}")

First validation embedding shape: [-114.55693  -101.882355  -14.314699 ...   51.675053   26.12046
   56.93147 ]


In [18]:
# Check the element type of the stored embeddings (val_embeddings was passed
# through np.array in an earlier cell).
type(val_embeddings[0])

numpy.ndarray

In [19]:
# Print the values of the second validation embedding (index 1); numpy
# abbreviates long arrays with "..." in this view.
print(f"Second validation embedding: {val_embeddings[1]}")

First validation embedding shape: [-104.10589  -100.892784  -15.305394 ...   48.47303    35.593414
   74.74174 ]


In [20]:
# Embedding dimensionality: the length of a single validation embedding vector.
len(val_embeddings[0])

1417

In [21]:
# import numpy as np

# # Disable truncation
# np.set_printoptions(threshold=np.inf, linewidth=1000)

# for i, emb in enumerate(val_embeddings):
#     print(f"Full Embedding {i}:\n{emb}")


In [22]:
import json
import numpy as np


def _embeddings_by_file(file_names, embeddings, split_name):
    """Map each file name to its embedding as a JSON-serializable list.

    Args:
        file_names: sequence of tensor file names, one per embedding.
        embeddings: numpy array (or sequence of arrays) aligned with ``file_names``.
        split_name: label used only in error messages.

    Returns:
        dict mapping file name -> embedding as a list of floats.

    Raises:
        ValueError: if the lengths differ (zip would silently drop the tail) or if
            a file name repeats (the dict would silently keep only the last one).
    """
    if len(file_names) != len(embeddings):
        raise ValueError(
            f"{split_name}: {len(file_names)} file names but {len(embeddings)} embeddings."
        )
    mapping = {name: emb.tolist() for name, emb in zip(file_names, embeddings)}
    if len(mapping) != len(file_names):
        raise ValueError(f"{split_name}: duplicate file names would overwrite embeddings.")
    return mapping


# Ensure embeddings are numpy arrays
train_embeddings = np.array(train_embeddings)
val_embeddings = np.array(val_embeddings)

# Paths for saving the JSON files
train_embeddings_file = "./train_embeddings.json"
val_embeddings_file = "./val_embeddings.json"

# File name -> embedding, validated so no row can be dropped silently.
train_embedding_dict = _embeddings_by_file(loaded_train_file_names, train_embeddings, "train")
val_embedding_dict = _embeddings_by_file(loaded_val_file_names, val_embeddings, "validation")

# Save training embeddings to JSON
with open(train_embeddings_file, "w") as f:
    json.dump(train_embedding_dict, f)

# Save validation embeddings to JSON
with open(val_embeddings_file, "w") as f:
    json.dump(val_embedding_dict, f)

print(f"Train embeddings saved to {train_embeddings_file}")
print(f"Validation embeddings saved to {val_embeddings_file}")

Train embeddings saved to ./train_embeddings.json
Validation embeddings saved to ./val_embeddings.json


In [23]:
import json

import pandas as pd

# Reload the exported validation embeddings (point this at ./train_embeddings.json
# to inspect the training split) and flatten the {file_name: embedding} mapping
# into a two-column table.
val_embeddings_file = "./val_embeddings.json"

with open(val_embeddings_file, "r") as f:
    embeddings_data = json.load(f)

# A dict's keys and values iterate in the same order, so the two columns stay aligned.
file_name_column = [name for name in embeddings_data]
embedding_column = [vector for vector in embeddings_data.values()]
val_embeddings_df = pd.DataFrame({"file_name": file_name_column, "embedding": embedding_column})

print(val_embeddings_df.head())

# After json.load the embeddings are plain Python lists, which str() writes out in full.
val_embeddings_df.to_csv("val_embeddings.csv", index=False)
print("Validation embeddings saved to val_embeddings.csv")


                                  file_name  \
0  query_1952_2024-12-03-21.34.02.429279.pt   
1   query_993_2024-12-03-21.10.08.618148.pt   
2  query_1591_2024-12-03-21.23.17.861948.pt   
3   query_793_2024-12-03-21.05.49.462784.pt   
4  query_2262_2024-12-03-21.44.00.539251.pt   

                                           embedding  
0  [-114.55693054199219, -101.88235473632812, -14...  
1  [-104.10588836669922, -100.89278411865234, -15...  
2  [-114.54547119140625, -101.89339447021484, -14...  
3  [-89.35006713867188, -103.53369140625, -24.114...  
4  [-99.47161102294922, -60.79159927368164, -35.9...  


Validation embeddings saved to val_embeddings.csv
