In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import pandas as pd
import datetime
import seaborn as sns
import ogb
from tqdm import tqdm

In [3]:
import torch
from torch_geometric.data import Data, DataLoader

In [4]:
# Remember the directory the notebook was started from; a later cell changes
# directory temporarily and uses this value to switch back.
cwd = os.getcwd()
print(cwd)

/cluster/home/skyriakos/chemprop_run/git/notebooks


In [5]:
# The project package is imported from the parent directory, so step up one
# level for the duration of the imports only.
os.chdir('..')
try:
    import deepadr
    from deepadr.dataset import *
    from deepadr.utilities import *
    from deepadr.run_workflow import *
    from deepadr.chemfeatures import *
    from deepadr.model_gnn_ogb import GNN, DeepAdr_SiameseTrf
    from ogb.graphproppred import Evaluator
finally:
    # Always return to the notebook directory, even when an import fails;
    # otherwise every relative path used afterwards would silently resolve
    # against the wrong folder.
    os.chdir(cwd)



In [6]:
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw

In [7]:
# from tdc.single_pred import Tox
from tdc.multi_pred import DDI

In [8]:
# Relative locations of the parent directory and of the raw / processed data folders.
up_dir = '..'
rawdata_dir, processed_dir = '../data/raw/', '../data/processed/'

In [9]:
# Print a summary of the CUDA devices visible to this kernel (count, name, memory usage).
report_available_cuda_devices()

number of GPUs available: 1
cuda:0, name:GeForce GTX 1080 Ti
total memory available: 10.91650390625 GB
total memory allocated on device: 0.0 GB
max memory allocated on device: 0.0 GB
total memory cached on device: 0.0 GB
max memory cached  on device: 0.0 GB



In [10]:
# Number of CUDA devices reported by PyTorch; echoed as the cell output.
n_gpu = torch.cuda.device_count()
n_gpu

1

In [11]:
# Device handles used below: one for the CPU and one for GPU index 0.
# NOTE(review): presumably get_device returns torch.device objects — confirm
# against its definition in the project package.
device_cpu = get_device(to_gpu=False)
device_gpu = get_device(True, index=0)

### Preparing dataset 

In [12]:
# Dataset / experiment identifiers; both are used to build the target data path.
DSdataset_name = 'TWOSIDES'  # alternative: 'DrugBank'
data_fname = 'data_v1'

# Similarity settings (not referenced by the cells visible in this notebook section).
similarity_types = ['chem']
kernel_option = 'sqeuclidean'

In [13]:
# Root folder for this dataset/version, followed by its "raw" and "processed"
# sub-folders (created in that order via the project's create_directory helper).
targetdata_dir = create_directory(os.path.join(processed_dir, DSdataset_name, data_fname))
targetdata_dir_raw, targetdata_dir_processed = (
    create_directory(os.path.join(targetdata_dir, subdir))
    for subdir in ("raw", "processed")
)
print(targetdata_dir)

path_current_dir /cluster/home/skyriakos/chemprop_run/git/deepadr
path_current_dir /cluster/home/skyriakos/chemprop_run/git/deepadr
path_current_dir /cluster/home/skyriakos/chemprop_run/git/deepadr
/cluster/home/skyriakos/chemprop_run/git/data/processed/TWOSIDES/data_v1


In [14]:
%%time
# Build the dataset object rooted at the target data directory.
dataset = MoleculeDataset(root=targetdata_dir)

CPU times: user 8.3 ms, sys: 652 ms, total: 660 ms
Wall time: 650 ms


In [15]:
# Summarise the loaded dataset. The leading empty string reproduces the blank
# first line of the report.
print(
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

# Keep the first sample around for inspection in the next cell.
data0 = dataset[0]


Dataset: MoleculeDataset(149878):
Number of graphs: 149878
Number of features: 9
Number of classes: 2


In [16]:
# Show the first sample: a pair of graphs ("a" and "b") with node features,
# edge indices/attributes, an id and a label.
data0

PairData(edge_attr_a=[52, 3], edge_attr_b=[22, 3], edge_index_a=[2, 52], edge_index_b=[2, 22], id=[1], x_a=[24, 9], x_b=[12, 9], y=[1])

In [17]:
# Total number of samples in the full dataset.
len(dataset)

149878

In [18]:
# Restrict the experiment to the first fifth of the dataset
# (presumably to shorten runs — use `dataset` directly for a full run).
smaller_dataset_len = len(dataset) // 5
used_dataset = dataset[:smaller_dataset_len]

In [19]:
# Fractions of the shuffled subset assigned to train / validation / test.
train_val_test_frac = [0.7, 0.1, 0.2]
# Compare with a tolerance: decimal fractions are not exactly representable in
# binary floating point, so an exact `== 1` check can reject a valid split.
assert np.isclose(sum(train_val_test_frac), 1.0), "split fractions must sum to 1"

# Seed torch's global RNG before shuffling so the split is reproducible
# (assumes the dataset's shuffle() draws from that RNG — TODO confirm).
# NOTE(review): re-running this cell reshuffles the already shuffled subset.
torch.manual_seed(42)
used_dataset = used_dataset.shuffle()

# Cumulative boundaries: [0, num_train) train, [num_train, num_trainval) val, rest test.
num_train = round(train_val_test_frac[0] * len(used_dataset))
num_trainval = round((train_val_test_frac[0] + train_val_test_frac[1]) * len(used_dataset))

train_dataset = used_dataset[:num_train]
val_dataset = used_dataset[num_train:num_trainval]
test_dataset = used_dataset[num_trainval:]

In [20]:
# Report the size of each split.
for _split_label, _split_data in (
    ('training', train_dataset),
    ('val', val_dataset),
    ('test', test_dataset),
):
    print(f'Number of {_split_label} graphs: {len(_split_data)}')

Number of training graphs: 20982
Number of val graphs: 2998
Number of test graphs: 5995


In [21]:
# Spot-check one sample of the training split.
train_dataset[3]

PairData(edge_attr_a=[66, 3], edge_attr_b=[66, 3], edge_index_a=[2, 66], edge_index_b=[2, 66], id=[1], x_a=[29, 9], x_b=[30, 9], y=[1])

In [22]:
batch_size = 32


def _make_loader(split, shuffle):
    """Wrap a dataset split in a DataLoader with the shared batching options.

    `follow_batch` asks the loader to emit `x_a_batch` / `x_b_batch` assignment
    vectors, which the train/eval functions pass to the GNN.
    """
    return DataLoader(split, batch_size=batch_size, shuffle=shuffle, follow_batch=['x_a', 'x_b'])


# Only the training loader is shuffled; evaluation order stays fixed.
train_loader = _make_loader(train_dataset, True)
valid_loader = _make_loader(val_dataset, False)
test_loader = _make_loader(test_dataset, False)

In [23]:
# Number of target classes; used below as the output size of the models.
print(dataset.num_classes)

2


In [24]:
# Embedding width shared by the GNN encoder, the transformer and the siamese head.
emb_dim = 300

In [25]:
# model_name = "testGCN"

# deepadr_model = testGCN(num_node_features=dataset.num_node_features, 
#             hidden_channels=64, 
#             num_classes=model_intermediate_dimension)
# print(deepadr_model)

In [26]:
# Prefix used when labelling the individual sub-models.
model_name = "ogb"

# Graph encoder applied to both molecules of a pair.
# NOTE(review): return_multilayer=True presumably yields one embedding per
# layer rather than only the final one — confirm in deepadr.model_gnn_ogb.
gnn_model = GNN(
    gnn_type='gcn',
    num_tasks=dataset.num_classes,
    num_layer=5,
    emb_dim=emb_dim,
    drop_ratio=0.5,
    virtual_node=False,
    with_edge_attr=False,
    return_multilayer=True,
)
gnn_model = gnn_model.to(device_gpu)

In [27]:
# --- Transformer hyper-parameters -------------------------------------------
input_embed_dim = None
num_attn_heads = 2
num_transformer_units = 1
p_dropout = 0.3
# Use the fully qualified torch.nn path (as the loss cell does) instead of a
# bare `nn`, which is only in scope if one of the star imports happens to leak it.
nonlin_func = torch.nn.ReLU()
mlp_embed_factor = 2
pooling_mode = 'attn'

# --- Siamese / training hyper-parameters ------------------------------------
dist_opt = 'cosine'
l2_reg = 1e-6      # NOTE(review): not used by the optimizer built in the cells visible here
# NOTE(review): the data loaders were already built with batch_size=32; this
# assignment only overwrites the name and would change the loaders if their
# cell is re-run afterwards.
batch_size = 1000
num_epochs = 100
# NOTE(review): loss_w is reassigned to 0.1 in the optimizer/loss cell before training.
loss_w = 0.05
margin_v = 1.      # NOTE(review): the contrastive loss below is built with 0.5, not this value

# Sequence model applied to the GNN output of each molecule.
transformer_model = DeepAdr_Transformer(input_size=emb_dim,
                                        input_embed_dim=input_embed_dim,
                                        num_attn_heads=num_attn_heads,
                                        mlp_embed_factor=mlp_embed_factor,
                                        nonlin_func=nonlin_func,
                                        pdropout=p_dropout,
                                        num_transformer_units=num_transformer_units,
                                        pooling_mode=pooling_mode).to(device_gpu)

In [28]:
# Siamese head: consumes the two molecule embeddings and returns class
# log-scores plus a distance (see its use in train()/eval()).
deepadr_siamese = DeepAdr_SiameseTrf(
    input_dim=emb_dim,
    dist=dist_opt,
    num_classes=dataset.num_classes,
)
deepadr_siamese = deepadr_siamese.to(device_gpu)

updated


In [29]:
# (module, label) pairs for the three jointly trained sub-models.
models = [
    (gnn_model, f'{model_name}_GNN'),
    (transformer_model, f'{model_name}_Transformer'),
    (deepadr_siamese, f'{model_name}_Siamese'),
]

# Flat list of every trainable parameter, in the same module order, for the optimizer.
models_param = [param for module, _label in models for param in module.parameters()]

In [30]:
# Weight of the classification loss in train(); the contrastive loss gets (1 - loss_w).
loss_w = 0.1
# dtype the labels are cast to before being passed to the contrastive loss.
fdtype = torch.float32

# Classification loss on log-probabilities and pairwise contrastive loss.
# NOTE(review): the first ContrastiveLoss argument (0.5) is presumably the margin — confirm.
loss_nlll = torch.nn.NLLLoss(reduction='mean')
loss_contrastive = ContrastiveLoss(0.5, reduction='mean')

# A single Adam optimizer updates all three sub-models jointly.
optimizer = torch.optim.Adam(models_param, lr=0.001)

In [31]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level models, losses and optimizer. For every batch both
    molecules of each pair are encoded (GNN, then transformer), scored by the
    siamese head, and the parameters are updated on
    ``loss_w * classification_loss + (1 - loss_w) * contrastive_loss``.
    """
    for module, _label in models:
        module.train()

    for batch in tqdm(train_loader, desc="Iteration"):
        batch = batch.to(device_gpu)

        # Encode side "a" and side "b" with the shared GNN.
        h_a = gnn_model(batch.x_a, batch.edge_index_a, batch.edge_attr_a, batch.x_a_batch)
        h_b = gnn_model(batch.x_b, batch.edge_index_b, batch.edge_attr_b, batch.x_b_batch)

        # The attention weights returned by the transformer are not used here.
        z_a, _attn_a = transformer_model(h_a)
        z_b, _attn_b = transformer_model(h_b)

        logsoftmax_scores, dist = deepadr_siamese(z_a, z_b)

        class_loss = loss_nlll(logsoftmax_scores, batch.y)
        contrast_loss = loss_contrastive(dist.reshape(-1), batch.y.type(fdtype))
        loss = loss_w*class_loss + (1-loss_w)*contrast_loss

        loss.backward()        # accumulate gradients
        optimizer.step()       # apply the update
        optimizer.zero_grad()  # reset gradients for the next batch

def eval(loader, dsettype):
    """Evaluate the current models on ``loader`` and report performance.

    Args:
        loader: iterable of paired-graph batches exposing ``x_a`` / ``x_b``,
            ``edge_index_*``, ``edge_attr_*``, ``x_*_batch`` and ``y``.
        dsettype: label of the split (the training loop passes "train",
            "valid" and "test"); it becomes the stem of the ``.log`` path
            handed to ``perfmetric_report`` as ``outlog``.

    Returns:
        Whatever ``perfmetric_report`` returns for this split.

    Notes:
        * Reads the notebook-level variable ``epoch``, so it must be called
          from inside the training loop (or after ``epoch`` has been set).
        * The name shadows the builtin ``eval``; it is kept so existing calls
          continue to work.
    """
    for m, m_name in models:
        m.eval()

    pred_class = []
    ref_class = []
    prob_scores = []

    # Inference only: disable autograd so no computation graph is built or
    # kept alive for the evaluation batches.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Iteration"):
            batch = batch.to(device_gpu)

            h_a = gnn_model(batch.x_a, batch.edge_index_a, batch.edge_attr_a, batch.x_a_batch)
            h_b = gnn_model(batch.x_b, batch.edge_index_b, batch.edge_attr_b, batch.x_b_batch)

            # Attention weights are returned by the transformer but not used here.
            z_a, fattn_w_scores_a = transformer_model(h_a)
            z_b, fattn_w_scores_b = transformer_model(h_b)

            logsoftmax_scores, dist = deepadr_siamese(z_a, z_b)

            # Predicted class = index of the largest score along the last dimension.
            __, y_pred_clss = torch.max(logsoftmax_scores, -1)

            # exp() of the log-scores; treated as class probabilities below.
            y_pred_prob = torch.exp(logsoftmax_scores.detach().cpu()).numpy()

            pred_class.extend(y_pred_clss.view(-1).tolist())
            ref_class.extend(batch.y.view(-1).tolist())
            prob_scores.append(y_pred_prob)

    # Stack the per-batch arrays; column 1 is passed on as the score of class 1.
    prob_scores_arr = np.concatenate(prob_scores, axis=0)
    modelscore = perfmetric_report(pred_class, ref_class, prob_scores_arr[:,1], epoch,
                                  outlog = os.path.join(targetdata_dir_processed, dsettype + ".log"))
    return modelscore

In [32]:
%%time

# Per-epoch s_aupr of each split, as reported by eval().
valid_curve = []
test_curve = []
train_curve = []

# Use the configured epoch count rather than a second hard-coded 100.
# NOTE: the loop variable must stay named `epoch`, because eval() reads it as a global.
for epoch in range(num_epochs):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train()

    print('Evaluating...')
    train_perf = eval(train_loader, dsettype="train")
    valid_perf = eval(valid_loader, dsettype="valid")
    test_perf = eval(test_loader, dsettype="test")

    print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

    train_curve.append(train_perf.s_aupr)
    valid_curve.append(valid_perf.s_aupr)
    test_curve.append(test_perf.s_aupr)

# Model selection: pick the epoch with the highest validation score and
# report the test score of that same epoch.
best_val_epoch = np.argmax(np.array(valid_curve))
best_train = max(train_curve)

print('Finished training!')
print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
print('Test score: {}'.format(test_curve[best_val_epoch]))

Iteration:   0%|          | 0/656 [00:00<?, ?it/s]

=====Epoch 0
Training...


Iteration:   0%|          | 3/656 [00:00<01:45,  6.18it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   1%|          | 7/656 [00:00<00:54, 11.96it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   2%|▏         | 11/656 [00:00<00:41, 15.38it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   2%|▏         | 15/656 [00:01<00:39, 16.42it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   3%|▎         | 19/656 [00:01<00:37, 17.03it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   4%|▎         | 23/656 [00:01<00:36, 17.48it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   4%|▍         | 27/656 [00:01<00:34, 18.15it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   5%|▍         | 31/656 [00:02<00:34, 18.18it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   5%|▌         | 35/656 [00:02<00:33, 18.34it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   6%|▌         | 39/656 [00:02<00:33, 18.19it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   7%|▋         | 43/656 [00:02<00:34, 18.00it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   7%|▋         | 45/656 [00:02<00:36, 16.74it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   7%|▋         | 49/656 [00:03<00:37, 16.22it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   8%|▊         | 53/656 [00:03<00:35, 17.22it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   9%|▊         | 57/656 [00:03<00:34, 17.60it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:   9%|▉         | 61/656 [00:03<00:35, 16.81it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  10%|▉         | 65/656 [00:04<00:33, 17.43it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  11%|█         | 69/656 [00:04<00:37, 15.47it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  11%|█         | 73/656 [00:04<00:36, 16.07it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  12%|█▏        | 77/656 [00:04<00:34, 17.02it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  12%|█▏        | 81/656 [00:05<00:33, 17.36it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  13%|█▎        | 85/656 [00:05<00:32, 17.59it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  14%|█▎        | 89/656 [00:05<00:31, 17.74it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  14%|█▍        | 93/656 [00:05<00:31, 17.73it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  14%|█▍        | 95/656 [00:05<00:32, 17.32it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  15%|█▌        | 99/656 [00:06<00:35, 15.87it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  16%|█▌        | 103/656 [00:06<00:32, 16.84it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  16%|█▋        | 107/656 [00:06<00:31, 17.60it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  17%|█▋        | 109/656 [00:06<00:31, 17.64it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  17%|█▋        | 113/656 [00:07<00:37, 14.36it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  18%|█▊        | 117/656 [00:07<00:34, 15.84it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  18%|█▊        | 121/656 [00:07<00:31, 16.78it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  19%|█▉        | 125/656 [00:07<00:30, 17.23it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  20%|█▉        | 129/656 [00:07<00:30, 17.48it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  20%|██        | 133/656 [00:08<00:29, 17.52it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  21%|██        | 137/656 [00:08<00:29, 17.41it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  21%|██▏       | 141/656 [00:08<00:29, 17.71it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  22%|██▏       | 145/656 [00:08<00:29, 17.60it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  23%|██▎       | 149/656 [00:09<00:33, 15.17it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  23%|██▎       | 153/656 [00:09<00:30, 16.49it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  24%|██▍       | 157/656 [00:09<00:28, 17.37it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  25%|██▍       | 161/656 [00:09<00:28, 17.53it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  25%|██▌       | 165/656 [00:10<00:27, 18.07it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  26%|██▌       | 169/656 [00:10<00:26, 18.09it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  26%|██▌       | 171/656 [00:10<00:27, 17.84it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  27%|██▋       | 177/656 [00:10<00:26, 17.93it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  28%|██▊       | 181/656 [00:10<00:26, 17.70it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  28%|██▊       | 185/656 [00:11<00:26, 17.77it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  29%|██▉       | 189/656 [00:11<00:26, 17.70it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  29%|██▉       | 193/656 [00:11<00:26, 17.66it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  30%|███       | 197/656 [00:11<00:25, 17.82it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  31%|███       | 201/656 [00:12<00:25, 17.79it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  31%|███▏      | 205/656 [00:12<00:25, 17.70it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  32%|███▏      | 209/656 [00:12<00:25, 17.84it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  32%|███▏      | 213/656 [00:12<00:25, 17.68it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  33%|███▎      | 217/656 [00:12<00:24, 17.67it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  34%|███▎      | 221/656 [00:13<00:24, 17.67it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  34%|███▍      | 225/656 [00:13<00:24, 17.81it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  35%|███▍      | 229/656 [00:13<00:23, 17.82it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  36%|███▌      | 233/656 [00:13<00:23, 17.66it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  36%|███▌      | 235/656 [00:13<00:24, 16.95it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  37%|███▋      | 241/656 [00:14<00:23, 17.60it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  37%|███▋      | 245/656 [00:14<00:23, 17.80it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  38%|███▊      | 249/656 [00:14<00:22, 17.76it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  38%|███▊      | 251/656 [00:14<00:23, 17.56it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  39%|███▉      | 255/656 [00:15<00:23, 16.92it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  40%|███▉      | 261/656 [00:15<00:22, 17.86it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  40%|████      | 265/656 [00:15<00:21, 18.08it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  41%|████      | 269/656 [00:15<00:21, 17.93it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  42%|████▏     | 273/656 [00:16<00:21, 17.79it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  42%|████▏     | 277/656 [00:16<00:21, 17.85it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  43%|████▎     | 281/656 [00:16<00:21, 17.59it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  43%|████▎     | 285/656 [00:16<00:20, 17.68it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  44%|████▍     | 289/656 [00:17<00:20, 17.63it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  45%|████▍     | 293/656 [00:17<00:20, 17.65it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  45%|████▌     | 297/656 [00:17<00:20, 17.63it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  46%|████▌     | 301/656 [00:17<00:19, 17.77it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  46%|████▋     | 305/656 [00:17<00:19, 17.97it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  47%|████▋     | 309/656 [00:18<00:19, 18.12it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  48%|████▊     | 313/656 [00:18<00:18, 18.09it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  48%|████▊     | 315/656 [00:18<00:20, 16.78it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  49%|████▊     | 319/656 [00:18<00:20, 16.28it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  49%|████▉     | 323/656 [00:18<00:19, 17.10it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  50%|████▉     | 327/656 [00:19<00:18, 17.36it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  50%|█████     | 331/656 [00:19<00:18, 17.72it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  51%|█████     | 335/656 [00:19<00:18, 17.68it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  52%|█████▏    | 339/656 [00:19<00:17, 17.76it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  52%|█████▏    | 343/656 [00:20<00:17, 17.66it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  53%|█████▎    | 347/656 [00:20<00:17, 17.94it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  54%|█████▎    | 351/656 [00:20<00:17, 17.93it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  54%|█████▍    | 355/656 [00:20<00:16, 17.89it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  55%|█████▍    | 359/656 [00:21<00:16, 17.92it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  55%|█████▌    | 363/656 [00:21<00:16, 17.93it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  56%|█████▌    | 367/656 [00:21<00:15, 18.07it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  57%|█████▋    | 371/656 [00:21<00:16, 17.81it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  57%|█████▋    | 375/656 [00:21<00:15, 17.78it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])


Iteration:  58%|█████▊    | 378/656 [00:22<00:16, 16.97it/s]

h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])
h_a shape: torch.Size([32, 6, 300])
z_a shape: torch.Size([32, 300])





KeyboardInterrupt: 

In [None]:
valid_curve

In [None]:
# Collect the three metric curves into one table: one row per epoch,
# one column per data split.
df_curves = pd.DataFrame(
    np.array([train_curve, valid_curve, test_curve]).T,
    columns=['train', 'valid', 'test'],
).rename_axis("epoch")

In [None]:
df_curves

In [None]:
# Plot every column of df_curves (train / valid / test) as its own line.
sns.lineplot(data=df_curves)

In [37]:
# Toy 3x8 integer tensor: every row is four copies of one value followed by
# four copies of the next value.
b = torch.tensor([
    [1] * 4 + [2] * 4,
    [3] * 4 + [4] * 4,
    [5] * 4 + [6] * 4,
])

In [38]:
b

tensor([[1, 1, 1, 1, 2, 2, 2, 2],
        [3, 3, 3, 3, 4, 4, 4, 4],
        [5, 5, 5, 5, 6, 6, 6, 6]])

2.0

In [45]:
# Split every row of `b` into 4 consecutive chunks of equal width.
b.reshape(b.size(0), 4, b.size(1) // 4)

tensor([[[1, 1],
         [1, 1],
         [2, 2],
         [2, 2]],

        [[3, 3],
         [3, 3],
         [4, 4],
         [4, 4]],

        [[5, 5],
         [5, 5],
         [6, 6],
         [6, 6]]])

In [None]:
# NOTE(review): `h_graph_cat` is not assigned in any cell visible here -- this
# looks like a leftover debugging print; confirm it is defined before running
# the notebook top to bottom.
print("h_graph_cat shape:", h_graph_cat.shape)

### Generate data partitions (i.e. train/validation and test indices)

In [None]:
# Stratified split of the labels `y` into 5 folds with a fixed seed.
# valid_set_portion=0.1 is presumably the fraction reserved for validation --
# confirm against get_stratified_partitions.
dpartitions = get_stratified_partitions(
    y,
    num_folds=5,
    valid_set_portion=0.1,
    random_state=42,
)

In [None]:
# Write the fold partitions to disk inside targetdata_dir (the directory is
# created near the top of the notebook).
partitions_pth = os.path.join(targetdata_dir, 'data_partitions.pkl')
ReaderWriter.dump_data(dpartitions, partitions_pth)

### Create Tensors

In [None]:
# CPU and first-GPU device handles.
# NOTE(review): this repeats, with the same arguments, the device setup already
# done near the top of the notebook; the cell looks redundant.
device_cpu = get_device(to_gpu=False)
device_gpu = get_device(True, index=0)

### Using masking and inference with gip computation

In [None]:
# For every fold: hide the validation/test pairs in the interaction matrix,
# impute the hidden entries from each similarity matrix, fuse the imputations
# and compute a GIP kernel from the fused matrix. Result: gip_perfold[fold_id].
#
# NOTE(review): `fname_suffix`, `interaction_mat` and `sid_ddipairs_map` must come
# from earlier cells; the `fname_suffix` assignment near the top of the notebook
# is commented out -- confirm this cell still runs on a fresh kernel.
# NOTE(review): `up_dir` ('..') is joined with `rawdata_dir`, which already starts
# with '../' -- verify the resulting path points at the intended raw-data folder.
gip_perfold = {}
for fold_id in dpartitions:
    # `np.float` was removed in NumPy 1.24; np.float64 is the dtype it resolved to.
    # astype() returns a new array, so interaction_mat itself is never modified.
    masked_intermat = interaction_mat.astype(np.float64)
    for dsettype in ('validation', 'test'):
        # get validation/test ddi pair indices
        sids = dpartitions[fold_id][dsettype]
        row_ids = [sid_ddipairs_map[sid][0] for sid in sids]
        col_ids = [sid_ddipairs_map[sid][1] for sid in sids]
        # set both (i, j) and (j, i) to nan so the pair is hidden symmetrically
        masked_intermat[row_ids, col_ids] = np.nan
        masked_intermat[col_ids, row_ids] = np.nan

    intermat_infer_lst = []
    nanw_mat_lst = []
    for similarity_type in similarity_types:
        print('similarity_type', similarity_type)
        siminput_feat_pth = os.path.join(up_dir, rawdata_dir, DSdataset_name, '{}{}'.format(similarity_type, fname_suffix))
        sim_mat = get_similarity_matrix(siminput_feat_pth, DSdataset_name)
        imat_infer, nanw_m = impute_nan(masked_intermat, sim_mat, k=15)
        intermat_infer_lst.append(imat_infer)
        nanw_mat_lst.append(nanw_m)

    # fuse the per-similarity imputations into a single inferred matrix
    infer_mat_fus = weight_inferred_mat(nanw_mat_lst, intermat_infer_lst)

    print('norm(infer_mat-interaction_mat)', np.linalg.norm(infer_mat_fus - interaction_mat))

    # compute GIP here
    gip_kernel = compute_gip_kernel(infer_mat_fus, 1., kernel_option)
    print('norm(gip_kernel-interaction_mat)',np.linalg.norm(gip_kernel - interaction_mat))
    # share of entries where the kernel is off by more than 0.5; the denominator
    # (size - number of rows) is the off-diagonal entry count for a square matrix
    gip_diff = gip_kernel - interaction_mat
    print(np.sum(np.abs(gip_diff) > 0.5)/(gip_diff.size - gip_diff.shape[0]))
    gip_perfold[fold_id] = gip_kernel

### Compute features from similarity matrices

#### check if similarity matrix is symmetric

In [None]:
# Print, for each similarity type, whether its matrix equals its own transpose
# (within np.allclose tolerance).
num_sim_types = len(similarity_types)
for similarity_type in similarity_types:
    sim_fname = f'{similarity_type}{fname_suffix}'
    siminput_feat_pth = os.path.join(up_dir, rawdata_dir, DSdataset_name, sim_fname)
    sim_mat = get_similarity_matrix(siminput_feat_pth, DSdataset_name)
    is_symmetric = np.allclose(sim_mat, np.transpose(sim_mat))
    print(is_symmetric)

In [None]:
# Load the preprocessed features of every similarity type and concatenate them
# along axis 1 into a single array.
# `X_feat` is deliberately kept as a notebook-level name: a later cleanup cell
# deletes it explicitly.
num_sim_types = len(similarity_types)
X_feats = []
for similarity_type in similarity_types:
    sim_fname = f'{similarity_type}{fname_suffix}'
    siminput_feat_pth = os.path.join(up_dir, rawdata_dir, DSdataset_name, sim_fname)
    X_feat = preprocess_features(siminput_feat_pth, DSdataset_name, fill_diag=None)
    X_feats.append(X_feat)
X_feat_cat = np.concatenate(X_feats, axis=1)
print("X_feat_cat", X_feat_cat.shape)

In [None]:
# Build the feature array from the concatenated features. The second argument is
# 2*num_sim_types (presumably two drug slots per similarity type -- TODO confirm
# against create_setvector_features). The shape is displayed as a sanity check.
X = create_setvector_features(X_feat_cat, 2*num_sim_types)
X.shape

In [None]:
# De-interleave axis 1 of X: even positions go to X_a, odd positions to X_b.
even_cols = np.arange(0, 2 * num_sim_types, 2)
odd_cols = even_cols + 1
X_a = X[:, even_cols].copy()
X_b = X[:, odd_cols].copy()

In [None]:
# Report the in-memory size (element count * bytes per element) of the
# concatenated features and of the labels.
# NOTE(review): this imports from `ddi.utilities`, while the rest of the notebook
# imports from the `deepadr` package -- confirm `ddi` is still importable, and
# consider moving the import to the import cells at the top.
from ddi.utilities import format_bytes
print(format_bytes(X_feat_cat.size * X_feat_cat.itemsize))
print(format_bytes(y.size * y.itemsize))

In [None]:
# Drop the intermediate feature objects that are no longer referenced below,
# so the kernel can release their memory.
del X_feats, X_feat_cat, X_feat

In [None]:
# CPU and first-GPU device handles.
# NOTE(review): identical to the device setup near the top of the notebook;
# this cell looks redundant.
device_cpu = get_device(to_gpu=False)
device_gpu = get_device(True, index=0)

In [None]:
# Labels are stored as int64; the two feature blocks are stored as float32.
# All three tensors are created on the CPU device.
# NOTE(review): the previous comment here claimed the labels were float32 for a
# sigmoid output, which contradicts the torch.int64 used below -- confirm which
# dtype the loss function expects.
# X_a and X_b are rebound in place to their tensor versions before being wrapped.
y_tensor = torch.tensor(y, dtype = torch.int64, device = device_cpu) 
X_a = torch.tensor(X_a, dtype = torch.float32, device = device_cpu)
X_b = torch.tensor(X_b, dtype = torch.float32, device = device_cpu)
ddi_datatensor = DDIDataTensor(X_a, X_b, y_tensor)

In [None]:
targetdata_dir

In [None]:
# Write the two feature tensors and the label tensor to targetdata_dir,
# one file per tensor.
tensors_to_dump = (
    (X_a, 'X_a.torch'),
    (X_b, 'X_b.torch'),
    (y_tensor, 'y_tensor.torch'),
)
for tensor_obj, tensor_fname in tensors_to_dump:
    ReaderWriter.dump_tensor(tensor_obj, os.path.join(targetdata_dir, tensor_fname))

### Construct GIP datatensor for each fold

In [None]:
# Turn each fold's GIP kernel into a GIPDataTensor: derive features from the
# kernel, arrange them with create_setvector_features, then split position 0
# (X_a_gip) from position 1 (X_b_gip) along axis 1 and convert both to float32
# CPU tensors.
gip_dtensor_perfold = {}
for fold_id, gip_mat in gip_perfold.items():
    print('fold_id:', fold_id)
    print('gip_mat:', gip_mat.shape)
    gip_all = create_setvector_features(get_features_from_simmatrix(gip_mat), 2)
    print('gip_all:', gip_all.shape)
    # index with one-element lists so axis 1 is kept (length 1) in the result
    X_a_gip = gip_all[:, [0]].copy()
    X_b_gip = gip_all[:, [1]].copy()
    print('X_a_gip:', X_a_gip.shape)
    X_a_gip = torch.tensor(X_a_gip, dtype=torch.float32, device=device_cpu)
    X_b_gip = torch.tensor(X_b_gip, dtype=torch.float32, device=device_cpu)
    gip_dtensor_perfold[fold_id] = GIPDataTensor(X_a_gip, X_b_gip)

In [None]:
# dump data on disk: the whole fold_id -> GIPDataTensor dict goes into a single
# file inside targetdata_dir
ReaderWriter.dump_tensor(gip_dtensor_perfold, os.path.join(targetdata_dir, 'gip_dtensor_perfold.torch'))