# Config

### Imports

In [503]:
import itertools
import os
import random
import shutil
import time
from tqdm.notebook import tqdm
import uuid
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import matplotlib.pyplot as plt

In [504]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

### Constants

In [554]:
TRAIN_ON_TF = 'TF_31'
TRAIN_PROBE = 'A'
TEST_PROBE = 'B'
BATCH_SIZE = 64

In [555]:
DNA_BASES = 'ACGT'
RNA_BASES = 'ACGU'

In [556]:
PBM_DATA = "./data/dream5/pbm"

# Data preparation

### Read DREAM5 sequence data with binding scores

In [557]:
df_seq = pd.read_csv(f"{PBM_DATA}/sequences.tsv", sep='\t')
df_seq.head()

Unnamed: 0,Fold ID,Event ID,seq
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT


In [558]:
df_targets = pd.read_csv(f"{PBM_DATA}/targets.tsv", sep='\t')
df_targets.head()

Unnamed: 0,TF_40,TF_41,TF_42,TF_43,TF_44,TF_45,TF_46,TF_47,TF_48,TF_49,...,C_19,C_18,C_15,C_14,C_17,C_16,C_11,C_10,C_13,C_12
0,823.914118,12702.625538,2124.023125,2314.305782,1474.888697,1131.785521,4597.003319,14589.890994,1556.951404,34180.775942,...,1651.254953,1242.303363,724.10585,3184.883349,8935.394363,12689.558779,4102.312624,505.126184,12946.381724,1313.790253
1,1307.840222,4316.426121,2554.658908,3415.320661,3408.586803,1697.342725,5272.763446,22903.130555,2181.551097,10000.297243,...,3505.604759,2516.00012,1640.114829,3463.713253,19535.468264,18006.72169,6890.427794,1402.597,38309.856355,3024.107809
2,1188.353499,3436.803941,2088.909658,3708.324021,2219.741833,1571.646567,6225.376501,13858.014077,1971.053716,18800.025304,...,3270.572883,1693.419147,997.792996,3196.992198,16695.027604,14486.992627,13517.968701,10680.866586,25648.825592,2675.530918
3,1806.103795,6531.268855,2406.186212,3601.204703,2828.415329,2746.861783,5810.10465,25701.749693,2191.273065,19213.880658,...,2701.555739,2059.614815,1432.163042,4927.163643,18896.765835,18784.043322,8608.167421,4624.044391,23651.726053,3679.449867
4,1417.411525,3951.243575,2581.309532,3375.884699,2764.716964,1806.919566,5033.976283,26364.859152,2311.790793,16139.097553,...,2457.214141,1901.709222,1672.531034,3877.787322,14699.253953,17119.871513,8995.328144,12641.425965,27999.405431,3128.844808


### Build a dataframe for single transcription factor

In [559]:
def build_df(tf, df_seq, df_targets):
    df = df_seq.copy()
    df["Target"] = df_targets[tf].values
    return df

In [560]:
df = build_df(TRAIN_ON_TF, df_seq, df_targets)
df.head()

Unnamed: 0,Fold ID,Event ID,seq,Target
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,24652.209903
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,23603.65978
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,58487.438924
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,28969.180606
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,25171.915989


### Remove probe specific biases for each sequence

In [561]:
biases = df_targets.median(axis=1).values
biases

array([2641.988164 , 4170.2475   , 3699.8877625, ..., 2287.3149815,
       1884.915039 , 1231.9427485], shape=(80856,))

In [562]:
# Note this won't cause data leak because then normalization is sequence specific and does not use other sequences
df['TargetNorm'] = df['Target'].values - biases
df.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,24652.209903,22010.221739
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,23603.65978,19433.41228
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,58487.438924,54787.551161
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,28969.180606,24346.930618
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,25171.915989,21294.128667


### Calculate NA content and drop rows

In [563]:
# Only 4% NA content, it is safe to drop the rows
df['Target'].isna().mean()

np.float64(0.012850004947066389)

In [564]:
df.dropna(subset=['Target'], inplace=True)
df['Target'].isna().mean()

np.float64(0.0)

### Train/Test data split

In [565]:
df_train = df[df['Fold ID'] == TRAIN_PROBE]
df_train.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm
40526,A,HK26479,ACGCGCGTCAGCTTTTTGGAATATTGCGGAGAGTTCCTGT,3690.235644,1035.297721
40527,A,HK26478,GACTATGGGATGGGCCGCCTTTGATTACGCGCGTCCCTGT,10197.241727,7155.318333
40528,A,HK22998,CCTTCGTGAGCCATGTGTTTCAGGCTGTGCGTGTCCCTGT,4336.389547,1962.749611
40529,A,HK22999,GGCGGGTGGTAAAGGCCCCGGAAGCGGACACGCACCCTGT,11108.69666,9423.523044
40530,A,HK26473,GTTGTGGTTTGTCCTTTTGTATTAACAGTGTATGGCCTGT,10835.459353,7094.844587


In [566]:
df_test = df[df['Fold ID'] == TEST_PROBE]
df_test.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,24652.209903,22010.221739
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,23603.65978,19433.41228
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,58487.438924,54787.551161
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,28969.180606,24346.930618
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,25171.915989,21294.128667


In [567]:
df_train.shape, df_test.shape

((40063, 5), (39754, 5))

### Add labels for ROC and AUC

In [568]:
# DREAM5 https://pmc.ncbi.nlm.nih.gov/articles/PMC3687085/
def add_label(d):
    # Add positive label to only rows with binding score higher than mean + 4*std
    mean = d['Target'].mean()
    std = d['Target'].std()
    lower_limit = mean + 4*std

    # Limits MAX: 1300 rows MIN: 50 rows
    top = d[d['Target'] > lower_limit].copy()
    if len(top) >= 50:
        top = top.sort_values(by='Target', ascending=False).head(1300)
    else:
        top = d.sort_values(by='Target', ascending=False, inplace=False).head(50)

    d['Label'] = 0
    d.loc[top.index, 'Label'] = 1

In [569]:
add_label(df_train)
df_train.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label
40526,A,HK26479,ACGCGCGTCAGCTTTTTGGAATATTGCGGAGAGTTCCTGT,3690.235644,1035.297721,0
40527,A,HK26478,GACTATGGGATGGGCCGCCTTTGATTACGCGCGTCCCTGT,10197.241727,7155.318333,0
40528,A,HK22998,CCTTCGTGAGCCATGTGTTTCAGGCTGTGCGTGTCCCTGT,4336.389547,1962.749611,0
40529,A,HK22999,GGCGGGTGGTAAAGGCCCCGGAAGCGGACACGCACCCTGT,11108.69666,9423.523044,0
40530,A,HK26473,GTTGTGGTTTGTCCTTTTGTATTAACAGTGTATGGCCTGT,10835.459353,7094.844587,0


In [570]:
df_train[df_train['Label'] == 1].head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label
40674,A,HK36930,ACCAGATGGGCCCAATTCTCCTGCCTTAAGGTTGGCCTGT,17016.087534,15106.7999,1
41024,A,HK05502,TCCGCGCGGGAACCTAGGCCGCTATTTAACCTAAACCTGT,15340.928216,12716.468335,1
41030,A,HK05508,AGGGCTGTGGATTTTGGGCCTGTCGCGTTCCGATACCTGT,14621.229434,11308.391338,1
41723,A,HK24333,CGGAATTAGGCCTACAGCTGCGCACGTACTCTCTGCCTGT,15065.605272,12237.795894,1
41964,A,HK12672,CCGGATAGGCCCCATTCATCTGCGTGAGCAACGTACCTGT,14946.657754,12438.001401,1


In [571]:
df_train[df_train['Label'] == 1].shape

(125, 6)

In [572]:
add_label(df_test)
df_test.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,24652.209903,22010.221739,0
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,23603.65978,19433.41228,0
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,58487.438924,54787.551161,0
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,28969.180606,24346.930618,0
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,25171.915989,21294.128667,0


In [573]:
df_test[df_test['Label'] == 1].head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label
341,B,MEforward20030,ATCTGAGGCCCATATGGTCGGAGACAGTACTATAGCCTGT,68671.537017,64612.614275,1
1487,B,MEforward11688,ATGAAGAGGCCTCGCTTTAGTGTACTGCGGAGGTGCCTGT,68032.254725,64478.573684,1
2016,B,MEforward06388,TTATCGTAGTACTCCCAGGCCCGTATATAGGAGGCCCTGT,69187.251617,65724.472753,1
2987,B,MEforward17176,CTAAGGATCTTAGATTAGTCGTAGGCCTACGTAGGCCTGT,69369.70538,65430.144099,1
3022,B,MEforward01404,GCTAGGCCATGTTTGACACAGGGCCCCTATTCCGCCCTGT,72123.869553,68392.039092,1


In [574]:
df_test[df_test['Label'] == 1].shape

(50, 6)

### Mathmatical target normalization to N(0, 1)

In [575]:
target_norm = df_train['TargetNorm'].values
df_train['TargetNormFinal'] = (target_norm - target_norm.mean()) / target_norm.std()
df_train.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label,TargetNormFinal
40526,A,HK26479,ACGCGCGTCAGCTTTTTGGAATATTGCGGAGAGTTCCTGT,3690.235644,1035.297721,0,-0.570562
40527,A,HK26478,GACTATGGGATGGGCCGCCTTTGATTACGCGCGTCCCTGT,10197.241727,7155.318333,0,1.934831
40528,A,HK22998,CCTTCGTGAGCCATGTGTTTCAGGCTGTGCGTGTCCCTGT,4336.389547,1962.749611,0,-0.190885
40529,A,HK22999,GGCGGGTGGTAAAGGCCCCGGAAGCGGACACGCACCCTGT,11108.69666,9423.523044,0,2.863381
40530,A,HK26473,GTTGTGGTTTGTCCTTTTGTATTAACAGTGTATGGCCTGT,10835.459353,7094.844587,0,1.910074


In [576]:
df_test['TargetNormFinal'] = (df_test['TargetNorm'].values - target_norm.mean()) / target_norm.std()
df_test.head()

Unnamed: 0,Fold ID,Event ID,seq,Target,TargetNorm,Label,TargetNormFinal
0,B,MEreverse14075,TAAAACTATGAGGAAGGATTCAGGGTCGGACAGTGCCTGT,24652.209903,22010.221739,0,8.016079
1,B,MEforward19438,CTTATGATCAGAAGCGGCTAGGTGTATTACATGTCCCTGT,23603.65978,19433.41228,0,6.961194
2,B,MEforward19439,CCGCCGTAGGCCCCGAAACAGTACCAGACATGTAACCTGT,58487.438924,54787.551161,0,21.434348
3,B,MEforward19436,GACCAAACGAGTCCTAGGATTCCAAGCGTTACGACCCTGT,28969.180606,24346.930618,0,8.972673
4,B,MEforward19437,CGTTACGACGGAGTTTGGATCCCGAACTTATGATCCCTGT,25171.915989,21294.128667,0,7.722928


### DNA/RNA sequence to Matrix logic

In [577]:
def fill_cell(motif_len, row, col, bases, seq):
    num_rows = len(seq) + 2 * motif_len - 2
    
    # First M-1 rows are filled with 0.25
    if row < motif_len-1:
        return 0.25

    # Last M-1 rows are filled with 0.25
    if num_rows-1-row < motif_len-1:
        return 0.25

    idx = row - motif_len + 1
    if seq[idx] == bases[col]:
        return 1.0

    return 0.0

def seq2matrix(seq, motif_len, typ='DNA'):
    bases = DNA_BASES if typ == 'DNA' else RNA_BASES
    num_rows = len(seq) + 2 * motif_len - 2
    result = np.empty([num_rows, 4])
    for row in range(num_rows):
        for col in range(4):
            result[row, col] = fill_cell(motif_len, row, col, bases, seq)
    return np.transpose(result)

In [578]:
# Test the function
S = seq2matrix("ATGG", 3, 'DNA')
S

array([[0.25, 0.25, 1.  , 0.  , 0.  , 0.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 0.  , 0.  , 0.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 0.  , 1.  , 1.  , 0.25, 0.25],
       [0.25, 0.25, 0.  , 1.  , 0.  , 0.  , 0.25, 0.25]])

In [579]:
S.shape

(4, 8)

### Train Data Augmentation (Reverse Complement)

In [580]:
def reverse_complement_batch(x):
    # x: (B, 4, L)
    # reverse the sequence
    x = torch.flip(x, dims=[2])
    
    # swap A-T, C-G
    # A C G T → T G C A
    x = x[:, [3, 2, 1, 0], :]

    return x

def reverse_complement_M(x):
    # x: (4, L)
    # reverse the sequence
    x = torch.flip(x, dims=[1])
    
    # swap A-T, C-G
    # A C G T → T G C A
    x = x[[3, 2, 1, 0], :]

    return x

### Sequence Dataset and Loader

In [581]:
class SeqDataset(Dataset):
    def __init__(self, df, augment=False):
        self.sequences = df['seq'].values
        self.targets = df['TargetNormFinal'].values
        self.labels = df['Label'].values
        self.aug = augment

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        
        M = seq2matrix(seq, MOTIF_LEN, 'DNA')
        
        x = torch.tensor(M, dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        label = self.labels[idx].copy()

        # Augment the data
        if self.aug and random.random() > 0.5:
            x = reverse_complement_M(x)
        
        return x, y, label

In [582]:
train_dataset = SeqDataset(df_train, augment=True)
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True
)

In [583]:
xs, ys, labels = [], [], []

for x, y, label in train_dataset:
    xs.append(x)
    ys.append(y)
    labels.append(label)

x_train = torch.stack(xs)
y_train = torch.tensor(ys, dtype=torch.float32)
label_train = torch.tensor(labels)

In [584]:
x, target, label = next(iter(train_loader))
x.shape, target.shape, label.shape

(torch.Size([64, 4, 102]), torch.Size([64]), torch.Size([64]))

In [585]:
# Do NOT augment test dataset
test_dataset = SeqDataset(df_test, augment=False)

xs, ys, labels = [], [], []

for x, y, label in test_dataset:
    xs.append(x)
    ys.append(y)
    labels.append(label)

x_test = torch.stack(xs)
y_test = torch.tensor(ys, dtype=torch.float32)
label_test = torch.tensor(labels)

In [586]:
x_test.shape, y_test.shape, label_test.shape

(torch.Size([39754, 4, 102]), torch.Size([39754]), torch.Size([39754]))

# Model

### DeepBind Model

In [587]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class DeepBindShallow(nn.Module):
    def __init__(self, 
                 num_motif_detectors, 
                 motif_len,
                 dropout,
                 conv_init_w,
                 fc_init_w):
        super().__init__()

        self.conv = nn.Conv1d(in_channels=4, out_channels=num_motif_detectors, kernel_size=motif_len)
        self.dropout = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(num_motif_detectors, 1)
        
        self.init_weights(conv_init_w, fc_init_w)

    def init_weights(self, conv_init_w, fc_init_w):
        init.normal_(self.conv.weight, mean=0.0, std=conv_init_w)
        init.constant_(self.conv.bias, 0.0)

        init.normal_(self.fc1.weight, mean=0.0, std=fc_init_w)
        init.constant_(self.fc1.bias, 0.0)

    def forward_pass(self, x):
        
        x = self.conv(x)
        x = F.relu(x)
        x, _ = torch.max(x, dim=2)
        x = self.dropout(x)
        x = self.fc1(x)

        return x

    def forward(self, x):
        r = self.forward_pass(x)
        
        x_comp = reverse_complement_batch(x)
        r_comp = self.forward_pass(x_comp)
        
        r_final = torch.max(r, r_comp)
        
        return r_final

In [588]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class DeepBind(nn.Module):
    def __init__(self, 
                 num_motif_detectors, 
                 motif_len,
                 dropout,
                 conv_init_w,
                 fc_init_w):
        super().__init__()

        self.conv = nn.Conv1d(in_channels=4, out_channels=num_motif_detectors, kernel_size=motif_len)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(num_motif_detectors, 32)
        self.dropout2 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(32, 1)

        self.init_weights(conv_init_w, fc_init_w)

    def init_weights(self, conv_init_w, fc_init_w):
        init.normal_(self.conv.weight, mean=0.0, std=conv_init_w)
        init.constant_(self.conv.bias, 0.0)

        init.normal_(self.fc1.weight, mean=0.0, std=fc_init_w)
        init.constant_(self.fc1.bias, 0.0)
        
        init.normal_(self.fc2.weight, mean=0.0, std=fc_init_w)
        init.constant_(self.fc2.bias, 0.0)

    def forward_pass(self, x):
        
        x = self.conv(x)
        x = F.relu(x)
        x, _ = torch.max(x, dim=2)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        return x

    def forward(self, x):
        r = self.forward_pass(x)
        
        x_comp = reverse_complement_batch(x)
        r_comp = self.forward_pass(x_comp)
        
        r_final = torch.max(r, r_comp)
        
        return r_final

### Model Wrapper for training

In [589]:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

class ModelWrapper:
    def __init__(self, m, lr=1e-3, w_decay=1e-3):
        self.model = m.to(DEVICE)
        #self.opt = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        self.opt = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=w_decay)
        
        self.criterion = torch.nn.MSELoss()

    def train_step(self, x, target):
        self.model.train()
        
        x = x.to(DEVICE)
        target = target.to(DEVICE)

        self.opt.zero_grad()

        pred = self.model(x)
        loss = self.criterion(pred.squeeze(), target)

        loss.backward()
        self.opt.step()
        
        return loss.item()

    def predict(self, x):
        self.model.eval()
        x = x.to(DEVICE)
        pred = self.model(x).squeeze()
        pred = pred.detach().cpu().numpy()
        return pred

    def evaluate(self, x, y_true, label_true):
        y_pred = self.predict(x)

        fpr, tpr, thresholds = roc_curve(label_true, y_pred)
        roc_auc = auc(fpr, tpr)
        
        pearson_corr, _ = pearsonr(y_true, y_pred)
        spearman_corr, _ = spearmanr(y_true, y_pred)

        # plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}")
        # plt.plot([0, 1], [0, 1], linestyle='--')
        # plt.xlabel("False Positive Rate")
        # plt.ylabel("True Positive Rate")
        # plt.legend()
        # plt.show()

        return {
            'pearson': pearson_corr,
            'spearman': spearman_corr,
            'auc': roc_auc,
        }

    def train_one_epoch(self, loader):
        epoch_loss = 0
        for i, (x, y, label) in enumerate(tqdm(loader)):
            epoch_loss += self.train_step(x, y)
        epoch_loss = epoch_loss / len(loader)
        return epoch_loss

### Sanity Check: Overfit on single mini-batch

In [590]:
x, y, label = next(iter(train_loader))
x.shape, y.shape, label.shape

(torch.Size([64, 4, 102]), torch.Size([64]), torch.Size([64]))

In [591]:
m_sanity = DeepBindShallow(num_motif_detectors=16, motif_len=24, dropout=0.95, conv_init_w=0.001, fc_init_w=0.001)

In [592]:
mw_sanity = ModelWrapper(m_sanity)

In [544]:
for i in range(10000):
    loss = mw_sanity.train_step(x, y)
    if i % 1000 == 0:
        print(loss)

1.5963393449783325
1.333681344985962
0.8059917688369751
1.4180238246917725
1.2546128034591675
1.347851037979126
0.6801857948303223
0.7549934983253479
0.7918535470962524
1.4197611808776855


In [545]:
pred_sanity = mw_sanity.predict(x)
pred_sanity[:10]

array([-1.3913959e-03,  3.9298367e-02,  5.8030114e-02,  1.6137331e+00,
        1.7538747e-01, -2.2062328e-02, -1.1360131e-02, -1.5314184e-02,
       -3.7912540e-03, -2.8610535e-02], dtype=float32)

In [546]:
target_sanity = y.cpu().numpy()
target_sanity[:10]

array([-0.27307722,  0.08115514,  0.2736761 ,  6.3071117 ,  0.6529401 ,
       -0.5904817 , -0.66178095, -0.55244404, -0.6513517 , -0.62807626],
      dtype=float32)

In [547]:
for name, p in mw_sanity.model.named_parameters():
    print(name, p.data.abs().mean().item())

conv.weight 0.06308583170175552
conv.bias 0.017959177494049072
fc1.weight 0.10577191412448883
fc1.bias 0.03392793983221054


### Sanity Check: Reverse compliment

In [548]:
np.allclose(mw_sanity.predict(x), mw_sanity.predict(reverse_complement_batch(reverse_complement_batch(x))))

True

# Training

In [593]:
import os

def save_snapshot(mw, epoch, params):
    snap_path = "/tmp/snaps"
    os.makedirs(snap_path, exist_ok=True)
    name = f"{snap_path}/model_epoch_{epoch}.pt"
    torch.save({
        "model": mw.model.state_dict(),
        "params": params,
    }, name)
    return name

def load_snapshot(path):
    ckpt = torch.load(path, map_location="cpu")
    params = ckpt["params"]
    print(params)

    m = DeepBind(
        num_motif_detectors=params['num_motif_detectors'],
        motif_len=params['motif_len'],
        dropout=params['dropout'],
        conv_init_w=params['conv_init_weight_std'],
        fc_init_w=params['fc_init_weight_std'],
    )
    mw = ModelWrapper(m, lr=params['lr'], w_decay=params['weight_decay'])
    
    mw.model.load_state_dict(ckpt["model"])

    return mw

def load_snapshot_shallow(path):
    ckpt = torch.load(path, map_location="cpu")
    params = ckpt["params"]
    print(params)

    m = DeepBindShallow(
        num_motif_detectors=params['num_motif_detectors'],
        motif_len=params['motif_len'],
        dropout=params['dropout'],
        conv_init_w=params['conv_init_weight_std'],
        fc_init_w=params['fc_init_weight_std'],
    )
    mw = ModelWrapper(m, lr=params['lr'], w_decay=params['weight_decay'])
    
    mw.model.load_state_dict(ckpt["model"])

    return mw

In [594]:
import wandb

def train_wandb(run_name, mw, loader, epochs, params):
    run = wandb.init(
        entity="vvaza22-free-university-of-tbilisi",
        project="CompBioFinal",
        name=run_name,
        config={
            "model": "DeepBind",
            "dataset": "DREAM5",
            "epochs": epochs,
            **params
        },
    )

    for epoch in range(1, epochs+1):
        print(f"===== EPOCH {epoch} =====")
        epoch_loss = mw.train_one_epoch(loader)
        print(f"Loss: {epoch_loss}")

        metrics_train = mw.evaluate(x_train, y_train, label_train)
        metrics_test = mw.evaluate(x_test, y_test, label_test)
        
        wandb.log({
            'epoch_loss': epoch_loss,
            'train/pearson': metrics_train['pearson'],
            'train/spearman': metrics_train['spearman'],
            'train/auc': metrics_train['auc'],
            'test/pearson': metrics_test['pearson'],
            'test/spearman': metrics_test['spearman'],
            'test/auc': metrics_test['auc'],
        }, step=epoch)

    run.finish()

In [597]:
import wandb

MIN_DELTA = 1e-3

def early_stopping_train_wandb(run_name, mw, loader, epochs, params):
    run = wandb.init(
        entity="vvaza22-free-university-of-tbilisi",
        project="CompBioBaseline__TF_31",
        name=run_name,
        config={
            "model": "DeepBind",
            "dataset": "DREAM5",
            "epochs": epochs,
            "tf": TRAIN_ON_TF,
            **params
        },
    )

    best_test_pearson = 0

    for epoch in range(1, epochs+1):
        print(f"===== EPOCH {epoch} =====")
        epoch_loss = mw.train_one_epoch(loader)
        print(f"Loss: {epoch_loss}")

        metrics_train = mw.evaluate(x_train, y_train, label_train)
        metrics_test = mw.evaluate(x_test, y_test, label_test)

        results = {
            'epoch_loss': epoch_loss,
            'train/pearson': metrics_train['pearson'],
            'train/spearman': metrics_train['spearman'],
            'train/auc': metrics_train['auc'],
            'test/pearson': metrics_test['pearson'],
            'test/spearman': metrics_test['spearman'],
            'test/auc': metrics_test['auc'],
        }
        print(results)

        # Snapshot the model
        snap = save_snapshot(mw, epoch, params)
        artifact = wandb.Artifact(
            name=f"model_epoch_{epoch}",
            type="model"
        )
        artifact.add_file(snap)
        wandb.log_artifact(artifact)
        
        wandb.log(results, step=epoch)

        if np.isnan(metrics_test['pearson']):
            # Vanishing Gradient? something went very wrong
            break

        if metrics_test['pearson'] > best_test_pearson:
            best_test_pearson = metrics_test['pearson']
        elif metrics_test['pearson'] + MIN_DELTA < best_test_pearson:
            # Overfitting
            print(f"Early stopping on epoch {epoch}")
            break

    run.finish()

In [598]:
from sklearn.model_selection import ParameterGrid
from scipy.stats import loguniform

param_grid = {
    'lr': [0.0005],
    'dropout': [0.5],
    'num_motif_detectors': [16],
    'motif_len': [12, 16, 24],
    'weight_decay': [1e-3, 5e-3],
    'conv_init_weight_std': [1e-3, 1e-4],
    'fc_init_weight_std': [1e-2, 1e-3],
}
print(param_grid)

grid = ParameterGrid(param_grid)

print(f"Total: {len(grid)}")

{'lr': [0.0005], 'dropout': [0.5], 'num_motif_detectors': [16], 'motif_len': [12, 16, 24], 'weight_decay': [0.001, 0.005], 'conv_init_weight_std': [0.001, 0.0001], 'fc_init_weight_std': [0.01, 0.001]}
Total: 24


In [599]:
run_id = 1
for params in grid:
    print(f"Calibration with: {params}")

    run_name = f"{TRAIN_ON_TF}_Best_Calib_{run_id}"
    
    m = DeepBind(
        num_motif_detectors=params['num_motif_detectors'],
        motif_len=params['motif_len'],
        dropout=params['dropout'],
        conv_init_w=params['conv_init_weight_std'],
        fc_init_w=params['fc_init_weight_std'],
    )
    mw = ModelWrapper(m, lr=params['lr'], w_decay=params['weight_decay'])

    early_stopping_train_wandb(run_name, mw, train_loader, 50, params=params)

    run_id = run_id + 1

Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.7069444726831235
{'epoch_loss': 0.7069444726831235, 'train/pearson': np.float32(0.86206955), 'train/spearman': np.float64(0.6732880220166023), 'train/auc': 0.9563555511042116, 'test/pearson': np.float32(0.8023156), 'test/spearman': np.float64(0.5213805637378037), 'test/auc': 0.9616617973000202}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3557169448477201
{'epoch_loss': 0.3557169448477201, 'train/pearson': np.float32(0.8962597), 'train/spearman': np.float64(0.6959509930942255), 'train/auc': 0.9622992638589813, 'test/pearson': np.float32(0.83666104), 'test/spearman': np.float64(0.5576800377010575), 'test/auc': 0.9644887165021158}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32288602221840484
{'epoch_loss': 0.32288602221840484, 'train/pearson': np.float32(0.9034595), 'train/spearman': np.float64(0.7010477537557538), 'train/auc': 0.9639316941258952, 'test/pearson': np.float32(0.84177107), 'test/spearman': np.float64(0.560304370161485), 'test/auc': 0.9665655853314528}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3132284680803934
{'epoch_loss': 0.3132284680803934, 'train/pearson': np.float32(0.9054462), 'train/spearman': np.float64(0.7115397555287163), 'train/auc': 0.9640893384746356, 'test/pearson': np.float32(0.84105027), 'test/spearman': np.float64(0.5561710650827525), 'test/auc': 0.9649939552689906}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3001353345192469
{'epoch_loss': 0.3001353345192469, 'train/pearson': np.float32(0.9061997), 'train/spearman': np.float64(0.7197577623197301), 'train/auc': 0.9661675597175623, 'test/pearson': np.float32(0.84405303), 'test/spearman': np.float64(0.573365624270739), 'test/auc': 0.9668103969373363}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.298591814352015
{'epoch_loss': 0.298591814352015, 'train/pearson': np.float32(0.90715224), 'train/spearman': np.float64(0.718903087641918), 'train/auc': 0.9673800390605439, 'test/pearson': np.float32(0.84494054), 'test/spearman': np.float64(0.5743102684105946), 'test/auc': 0.9676687487406811}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2959747375152743
{'epoch_loss': 0.2959747375152743, 'train/pearson': np.float32(0.9060044), 'train/spearman': np.float64(0.7181167081622334), 'train/auc': 0.9677157594271121, 'test/pearson': np.float32(0.844078), 'test/spearman': np.float64(0.5775600716281036), 'test/auc': 0.9663953254080193}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29061015783407435
{'epoch_loss': 0.29061015783407435, 'train/pearson': np.float32(0.906088), 'train/spearman': np.float64(0.7188913141406407), 'train/auc': 0.9675997796584707, 'test/pearson': np.float32(0.84476477), 'test/spearman': np.float64(0.5781174668591544), 'test/auc': 0.9674639834777352}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29081103571068745
{'epoch_loss': 0.29081103571068745, 'train/pearson': np.float32(0.9067642), 'train/spearman': np.float64(0.7241882183436261), 'train/auc': 0.9675390855826531, 'test/pearson': np.float32(0.8437962), 'test/spearman': np.float64(0.5751812598491058), 'test/auc': 0.9662132278863591}
Early stopping on epoch 9


0,1
epoch_loss,█▂▂▁▁▁▁▁▁
test/auc,▁▄▇▅▇█▇█▆
test/pearson,▁▇▇▇█████
test/spearman,▁▅▆▅▇████
train/auc,▁▅▆▆▇████
train/pearson,▁▆▇██████
train/spearman,▁▄▅▆▇▇▇▇█

0,1
epoch_loss,0.29081
test/auc,0.96621
test/pearson,0.8438
test/spearman,0.57518
train/auc,0.96754
train/pearson,0.90676
train/spearman,0.72419


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6237202054038405
{'epoch_loss': 0.6237202054038405, 'train/pearson': np.float32(0.8941626), 'train/spearman': np.float64(0.6586413775646419), 'train/auc': 0.9640905403375231, 'test/pearson': np.float32(0.8343228), 'test/spearman': np.float64(0.5269611737804929), 'test/auc': 0.9661326314728995}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.31156695066406703
{'epoch_loss': 0.31156695066406703, 'train/pearson': np.float32(0.90225655), 'train/spearman': np.float64(0.6760059839169884), 'train/auc': 0.9672520406630276, 'test/pearson': np.float32(0.8415128), 'test/spearman': np.float64(0.5404893498672849), 'test/auc': 0.967808281281483}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3055881473679131
{'epoch_loss': 0.3055881473679131, 'train/pearson': np.float32(0.9042347), 'train/spearman': np.float64(0.6806166322518449), 'train/auc': 0.9684777404977716, 'test/pearson': np.float32(0.8431283), 'test/spearman': np.float64(0.5454329736697534), 'test/auc': 0.9680032238565384}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30158856167151527
{'epoch_loss': 0.30158856167151527, 'train/pearson': np.float32(0.90515214), 'train/spearman': np.float64(0.686447381241922), 'train/auc': 0.9690588412038659, 'test/pearson': np.float32(0.84334093), 'test/spearman': np.float64(0.5438495942238257), 'test/auc': 0.9685054402579085}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30476016732164846
{'epoch_loss': 0.30476016732164846, 'train/pearson': np.float32(0.9050811), 'train/spearman': np.float64(0.6806931965233752), 'train/auc': 0.9704423856978317, 'test/pearson': np.float32(0.843176), 'test/spearman': np.float64(0.5413725874555214), 'test/auc': 0.9696494056014506}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.298391242818044
{'epoch_loss': 0.298391242818044, 'train/pearson': np.float32(0.90565914), 'train/spearman': np.float64(0.6888817648200447), 'train/auc': 0.9702759276879162, 'test/pearson': np.float32(0.84384924), 'test/spearman': np.float64(0.5489672068982702), 'test/auc': 0.9688555309288737}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29934322958984694
{'epoch_loss': 0.29934322958984694, 'train/pearson': np.float32(0.90622205), 'train/spearman': np.float64(0.6871461521159569), 'train/auc': 0.9702012118784116, 'test/pearson': np.float32(0.8446055), 'test/spearman': np.float64(0.547134220480654), 'test/auc': 0.9693204714890188}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2939879371763799
{'epoch_loss': 0.2939879371763799, 'train/pearson': np.float32(0.905995), 'train/spearman': np.float64(0.6849638524016838), 'train/auc': 0.970236967299314, 'test/pearson': np.float32(0.8428654), 'test/spearman': np.float64(0.5399598310814886), 'test/auc': 0.970288132178118}
Early stopping on epoch 8


0,1
epoch_loss,█▁▁▁▁▁▁▁
test/auc,▁▄▄▅▇▆▆█
test/pearson,▁▆▇▇▇▇█▇
test/spearman,▁▅▇▆▆█▇▅
train/auc,▁▄▆▆████
train/pearson,▁▆▇▇▇███
train/spearman,▁▅▆▇▆██▇

0,1
epoch_loss,0.29399
test/auc,0.97029
test/pearson,0.84287
test/spearman,0.53996
train/auc,0.97024
train/pearson,0.906
train/spearman,0.68496


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5904302431371646
{'epoch_loss': 0.5904302431371646, 'train/pearson': np.float32(0.89442766), 'train/spearman': np.float64(0.6630895270093345), 'train/auc': 0.9627322349641945, 'test/pearson': np.float32(0.8339365), 'test/spearman': np.float64(0.5159274812199184), 'test/auc': 0.966459298811203}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2912931977131496
{'epoch_loss': 0.2912931977131496, 'train/pearson': np.float32(0.90352803), 'train/spearman': np.float64(0.6860312840915495), 'train/auc': 0.9648735540087136, 'test/pearson': np.float32(0.8419306), 'test/spearman': np.float64(0.5413227827933309), 'test/auc': 0.9658502921619988}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28255214287450137
{'epoch_loss': 0.28255214287450137, 'train/pearson': np.float32(0.9056101), 'train/spearman': np.float64(0.6957108969867484), 'train/auc': 0.9685420401622514, 'test/pearson': np.float32(0.8437332), 'test/spearman': np.float64(0.557739233503035), 'test/auc': 0.9700770703203708}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28276825467713723
{'epoch_loss': 0.28276825467713723, 'train/pearson': np.float32(0.90779096), 'train/spearman': np.float64(0.7015648311858508), 'train/auc': 0.9690550353047223, 'test/pearson': np.float32(0.845716), 'test/spearman': np.float64(0.5607771647413836), 'test/auc': 0.9701012492444088}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27565677684193224
{'epoch_loss': 0.27565677684193224, 'train/pearson': np.float32(0.90823686), 'train/spearman': np.float64(0.7083013994091216), 'train/auc': 0.9706897691421703, 'test/pearson': np.float32(0.84627396), 'test/spearman': np.float64(0.5690842936811775), 'test/auc': 0.9691799314930487}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27841766878438834
{'epoch_loss': 0.27841766878438834, 'train/pearson': np.float32(0.90925926), 'train/spearman': np.float64(0.707038286089143), 'train/auc': 0.9710814762882468, 'test/pearson': np.float32(0.8460288), 'test/spearman': np.float64(0.5603881166836686), 'test/auc': 0.9677493451541406}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2729105829192808
{'epoch_loss': 0.2729105829192808, 'train/pearson': np.float32(0.9090433), 'train/spearman': np.float64(0.7013467093050341), 'train/auc': 0.9725183033702238, 'test/pearson': np.float32(0.8454899), 'test/spearman': np.float64(0.5507644260144602), 'test/auc': 0.9714587950836188}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2767204248819488
{'epoch_loss': 0.2767204248819488, 'train/pearson': np.float32(0.9104313), 'train/spearman': np.float64(0.711852533066745), 'train/auc': 0.9725526566177576, 'test/pearson': np.float32(0.8464794), 'test/spearman': np.float64(0.5589578619293788), 'test/auc': 0.9709575861374169}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2787044242191048
{'epoch_loss': 0.2787044242191048, 'train/pearson': np.float32(0.9106809), 'train/spearman': np.float64(0.7124247407842085), 'train/auc': 0.9741927988381991, 'test/pearson': np.float32(0.84682184), 'test/spearman': np.float64(0.562601487689809), 'test/auc': 0.9708664114446908}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2728218349000326
{'epoch_loss': 0.2728218349000326, 'train/pearson': np.float32(0.91042525), 'train/spearman': np.float64(0.7126624511442538), 'train/auc': 0.9734300165256148, 'test/pearson': np.float32(0.8457646), 'test/spearman': np.float64(0.5651670879562221), 'test/auc': 0.9694519443884748}
Early stopping on epoch 10


0,1
epoch_loss,█▁▁▁▁▁▁▁▁▁
test/auc,▂▁▆▆▅▃█▇▇▅
test/pearson,▁▅▆▇██▇██▇
test/spearman,▁▄▇▇█▇▆▇▇▇
train/auc,▁▂▅▅▆▆▇▇██
train/pearson,▁▅▆▇▇▇▇███
train/spearman,▁▄▆▆▇▇▆███

0,1
epoch_loss,0.27282
test/auc,0.96945
test/pearson,0.84576
test/spearman,0.56517
train/auc,0.97343
train/pearson,0.91043
train/spearman,0.71266


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6548637632078256
{'epoch_loss': 0.6548637632078256, 'train/pearson': np.float32(0.8845161), 'train/spearman': np.float64(0.6529399509206998), 'train/auc': 0.9630513295608193, 'test/pearson': np.float32(0.8245834), 'test/spearman': np.float64(0.4977873039186501), 'test/auc': 0.9645587346363087}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33558491956644926
{'epoch_loss': 0.33558491956644926, 'train/pearson': np.float32(0.9002519), 'train/spearman': np.float64(0.663166643039165), 'train/auc': 0.9682988632380191, 'test/pearson': np.float32(0.8392078), 'test/spearman': np.float64(0.5131778062003944), 'test/auc': 0.967574047954866}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3263102039790001
{'epoch_loss': 0.3263102039790001, 'train/pearson': np.float32(0.902915), 'train/spearman': np.float64(0.6819679368672739), 'train/auc': 0.9667859181731684, 'test/pearson': np.float32(0.8409681), 'test/spearman': np.float64(0.5386789828591924), 'test/auc': 0.965620592383639}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3195073300371536
{'epoch_loss': 0.3195073300371536, 'train/pearson': np.float32(0.9044894), 'train/spearman': np.float64(0.682315569060212), 'train/auc': 0.9691690119685512, 'test/pearson': np.float32(0.8409692), 'test/spearman': np.float64(0.5378799335772809), 'test/auc': 0.9672350392907516}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.31436067645827803
{'epoch_loss': 0.31436067645827803, 'train/pearson': np.float32(0.90551746), 'train/spearman': np.float64(0.6809114813783537), 'train/auc': 0.9705445440432671, 'test/pearson': np.float32(0.8428738), 'test/spearman': np.float64(0.5407220153199234), 'test/auc': 0.9679347169050977}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3156135838966781
{'epoch_loss': 0.3156135838966781, 'train/pearson': np.float32(0.90560037), 'train/spearman': np.float64(0.6902947330207626), 'train/auc': 0.9705213080274426, 'test/pearson': np.float32(0.8424588), 'test/spearman': np.float64(0.5465670526333489), 'test/auc': 0.9694710860366713}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3188752854546419
{'epoch_loss': 0.3188752854546419, 'train/pearson': np.float32(0.9053961), 'train/spearman': np.float64(0.6771654405060427), 'train/auc': 0.9706364865541589, 'test/pearson': np.float32(0.84176517), 'test/spearman': np.float64(0.5295888668530006), 'test/auc': 0.9680833165424139}
Early stopping on epoch 7


0,1
epoch_loss,█▁▁▁▁▁▁
test/auc,▁▅▃▅▆█▆
test/pearson,▁▇▇▇███
test/spearman,▁▃▇▇▇█▆
train/auc,▁▆▄▇███
train/pearson,▁▆▇████
train/spearman,▁▃▆▇▆█▆

0,1
epoch_loss,0.31888
test/auc,0.96808
test/pearson,0.84177
test/spearman,0.52959
train/auc,0.97064
train/pearson,0.9054
train/spearman,0.67717


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6313909600479916
{'epoch_loss': 0.6313909600479916, 'train/pearson': np.float32(0.8929843), 'train/spearman': np.float64(0.6481538494482312), 'train/auc': 0.9607321348089539, 'test/pearson': np.float32(0.829518), 'test/spearman': np.float64(0.4896186546224634), 'test/auc': 0.965653334676607}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2952783931011019
{'epoch_loss': 0.2952783931011019, 'train/pearson': np.float32(0.9021916), 'train/spearman': np.float64(0.685686116327535), 'train/auc': 0.965959637438029, 'test/pearson': np.float32(0.83876055), 'test/spearman': np.float64(0.5212803927237063), 'test/auc': 0.9685714285714286}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28338255102451615
{'epoch_loss': 0.28338255102451615, 'train/pearson': np.float32(0.90574074), 'train/spearman': np.float64(0.6931903410401336), 'train/auc': 0.9678681957033403, 'test/pearson': np.float32(0.8412477), 'test/spearman': np.float64(0.534417626620802), 'test/auc': 0.969350695144066}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28198038013027116
{'epoch_loss': 0.28198038013027116, 'train/pearson': np.float32(0.9070639), 'train/spearman': np.float64(0.7076138309987471), 'train/auc': 0.9681095698332415, 'test/pearson': np.float32(0.84254295), 'test/spearman': np.float64(0.5539242913899863), 'test/auc': 0.9698453556316744}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27865047306536483
{'epoch_loss': 0.27865047306536483, 'train/pearson': np.float32(0.9070672), 'train/spearman': np.float64(0.7046130077917364), 'train/auc': 0.9700057088487155, 'test/pearson': np.float32(0.84373474), 'test/spearman': np.float64(0.5576860508863759), 'test/auc': 0.970378299415676}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2734802339713985
{'epoch_loss': 0.2734802339713985, 'train/pearson': np.float32(0.908125), 'train/spearman': np.float64(0.7114420504412154), 'train/auc': 0.970686163553508, 'test/pearson': np.float32(0.8451911), 'test/spearman': np.float64(0.5663551933566754), 'test/auc': 0.9696927261736854}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27800924658251647
{'epoch_loss': 0.27800924658251647, 'train/pearson': np.float32(0.9085877), 'train/spearman': np.float64(0.7045324588459957), 'train/auc': 0.9696329310431167, 'test/pearson': np.float32(0.8446128), 'test/spearman': np.float64(0.5547275573540712), 'test/auc': 0.9699103364900262}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2816025101886199
{'epoch_loss': 0.2816025101886199, 'train/pearson': np.float32(0.9064831), 'train/spearman': np.float64(0.7140282669830229), 'train/auc': 0.9705052831889429, 'test/pearson': np.float32(0.8422444), 'test/spearman': np.float64(0.5616852724349453), 'test/auc': 0.9686117267781584}
Early stopping on epoch 8


0,1
epoch_loss,█▁▁▁▁▁▁▁
test/auc,▁▅▆▇█▇▇▅
test/pearson,▁▅▆▇▇██▇
test/spearman,▁▄▅▇▇█▇█
train/auc,▁▅▆▆██▇█
train/pearson,▁▅▇▇▇██▇
train/spearman,▁▅▆▇▇█▇█

0,1
epoch_loss,0.2816
test/auc,0.96861
test/pearson,0.84224
test/spearman,0.56169
train/auc,0.97051
train/pearson,0.90648
train/spearman,0.71403


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6265877441476329
{'epoch_loss': 0.6265877441476329, 'train/pearson': np.float32(0.89276737), 'train/spearman': np.float64(0.649719354065599), 'train/auc': 0.9606355851569933, 'test/pearson': np.float32(0.8298173), 'test/spearman': np.float64(0.4831059569202215), 'test/auc': 0.9647964940560145}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.301744599561817
{'epoch_loss': 0.301744599561817, 'train/pearson': np.float32(0.90154564), 'train/spearman': np.float64(0.6728748321845953), 'train/auc': 0.9659123641644549, 'test/pearson': np.float32(0.8377995), 'test/spearman': np.float64(0.5054186664531528), 'test/auc': 0.9701868829337095}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29206459113536554
{'epoch_loss': 0.29206459113536554, 'train/pearson': np.float32(0.9050582), 'train/spearman': np.float64(0.6961564931678331), 'train/auc': 0.966498873253543, 'test/pearson': np.float32(0.8422114), 'test/spearman': np.float64(0.5435836479137401), 'test/auc': 0.9698403183558331}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2868330380953729
{'epoch_loss': 0.2868330380953729, 'train/pearson': np.float32(0.9072079), 'train/spearman': np.float64(0.7097549261679392), 'train/auc': 0.9691381641544394, 'test/pearson': np.float32(0.8432337), 'test/spearman': np.float64(0.5523042912094153), 'test/auc': 0.9687628450533952}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2829922928620642
{'epoch_loss': 0.2829922928620642, 'train/pearson': np.float32(0.9070119), 'train/spearman': np.float64(0.7091470967010848), 'train/auc': 0.969585056838099, 'test/pearson': np.float32(0.8434015), 'test/spearman': np.float64(0.5565347485563282), 'test/auc': 0.9704422728188595}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2835522722988464
{'epoch_loss': 0.2835522722988464, 'train/pearson': np.float32(0.91037935), 'train/spearman': np.float64(0.7129638703117115), 'train/auc': 0.9718096048875757, 'test/pearson': np.float32(0.8450455), 'test/spearman': np.float64(0.5557272509143665), 'test/auc': 0.9720839210155149}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2818186386538961
{'epoch_loss': 0.2818186386538961, 'train/pearson': np.float32(0.90872663), 'train/spearman': np.float64(0.7147920471870625), 'train/auc': 0.9709303420301467, 'test/pearson': np.float32(0.84493655), 'test/spearman': np.float64(0.5659588276304968), 'test/auc': 0.9723025387870239}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2850214526200066
{'epoch_loss': 0.2850214526200066, 'train/pearson': np.float32(0.90855616), 'train/spearman': np.float64(0.7161694625228361), 'train/auc': 0.9719403074765888, 'test/pearson': np.float32(0.84521097), 'test/spearman': np.float64(0.5657667373422468), 'test/auc': 0.9727196252266774}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28286702907123507
{'epoch_loss': 0.28286702907123507, 'train/pearson': np.float32(0.90997857), 'train/spearman': np.float64(0.7172837827724353), 'train/auc': 0.9732154840002003, 'test/pearson': np.float32(0.8450931), 'test/spearman': np.float64(0.5619890942682731), 'test/auc': 0.9711691517227483}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27950325643959134
{'epoch_loss': 0.27950325643959134, 'train/pearson': np.float32(0.9101178), 'train/spearman': np.float64(0.7171408176912109), 'train/auc': 0.9750240873353698, 'test/pearson': np.float32(0.84548396), 'test/spearman': np.float64(0.5653574518028158), 'test/auc': 0.9732963933104977}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28270660803769343
{'epoch_loss': 0.28270660803769343, 'train/pearson': np.float32(0.9100444), 'train/spearman': np.float64(0.7112800291232118), 'train/auc': 0.9738360458711002, 'test/pearson': np.float32(0.84553933), 'test/spearman': np.float64(0.5578477791519652), 'test/auc': 0.9725720330445294}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2762236961493858
{'epoch_loss': 0.2762236961493858, 'train/pearson': np.float32(0.91051644), 'train/spearman': np.float64(0.712491628718639), 'train/auc': 0.9734929140167259, 'test/pearson': np.float32(0.8466012), 'test/spearman': np.float64(0.5612296122861651), 'test/auc': 0.9722622405802942}
===== EPOCH 13 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27812910838106186
{'epoch_loss': 0.27812910838106186, 'train/pearson': np.float32(0.91117775), 'train/spearman': np.float64(0.7197714969650285), 'train/auc': 0.9737086484050277, 'test/pearson': np.float32(0.84695596), 'test/spearman': np.float64(0.5686894125336959), 'test/auc': 0.9720990328430386}
===== EPOCH 14 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27516205474353445
{'epoch_loss': 0.27516205474353445, 'train/pearson': np.float32(0.9101796), 'train/spearman': np.float64(0.7139993591025326), 'train/auc': 0.97501447243227, 'test/pearson': np.float32(0.8466614), 'test/spearman': np.float64(0.5640385593377168), 'test/auc': 0.9758664114446907}
===== EPOCH 15 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2738988045900584
{'epoch_loss': 0.2738988045900584, 'train/pearson': np.float32(0.9113878), 'train/spearman': np.float64(0.7085256916199724), 'train/auc': 0.9745120937453052, 'test/pearson': np.float32(0.84715724), 'test/spearman': np.float64(0.5593915433774499), 'test/auc': 0.9757445093693331}
===== EPOCH 16 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2720470554983844
{'epoch_loss': 0.2720470554983844, 'train/pearson': np.float32(0.9119786), 'train/spearman': np.float64(0.7170673855886409), 'train/auc': 0.9756268215734388, 'test/pearson': np.float32(0.8472773), 'test/spearman': np.float64(0.5611577320339682), 'test/auc': 0.9739975821075961}
===== EPOCH 17 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27441439641931187
{'epoch_loss': 0.27441439641931187, 'train/pearson': np.float32(0.9118186), 'train/spearman': np.float64(0.719347225813609), 'train/auc': 0.9753331664079323, 'test/pearson': np.float32(0.8460173), 'test/spearman': np.float64(0.5552911523256869), 'test/auc': 0.97354473100947}
Early stopping on epoch 17


0,1
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▄▄▄▅▆▆▆▅▆▆▆▆██▇▇
test/pearson,▁▄▆▆▆▇▇▇▇▇▇█████▇
test/spearman,▁▃▆▇▇▇██▇█▇▇██▇▇▇
train/auc,▁▃▄▅▅▆▆▆▇█▇▇▇█▇██
train/pearson,▁▄▅▆▆▇▇▇▇▇▇▇█▇███
train/spearman,▁▃▆▇▇▇████▇▇█▇▇██

0,1
epoch_loss,0.27441
test/auc,0.97354
test/pearson,0.84602
test/spearman,0.55529
train/auc,0.97533
train/pearson,0.91182
train/spearman,0.71935


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.9521554043403448
{'epoch_loss': 0.9521554043403448, 'train/pearson': np.float32(0.8636645), 'train/spearman': np.float64(0.6719974193015238), 'train/auc': 0.9499103610596424, 'test/pearson': np.float32(0.7931971), 'test/spearman': np.float64(0.5058683129265431), 'test/auc': 0.944798508966351}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.7144088164305153
{'epoch_loss': 0.7144088164305153, 'train/pearson': np.float32(0.8789106), 'train/spearman': np.float64(0.6842755880058392), 'train/auc': 0.9587454554559568, 'test/pearson': np.float32(0.8087811), 'test/spearman': np.float64(0.5166617591960789), 'test/auc': 0.9544977835986299}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.575961925017948
{'epoch_loss': 0.575961925017948, 'train/pearson': np.float32(0.898989), 'train/spearman': np.float64(0.6762369461544252), 'train/auc': 0.9618953377735491, 'test/pearson': np.float32(0.82918155), 'test/spearman': np.float64(0.5269542633239255), 'test/auc': 0.9604473100947006}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5421367121723514
{'epoch_loss': 0.5421367121723514, 'train/pearson': np.float32(0.8985771), 'train/spearman': np.float64(0.6642400482693944), 'train/auc': 0.9635729380539837, 'test/pearson': np.float32(0.8318489), 'test/spearman': np.float64(0.5311346718142593), 'test/auc': 0.9616945395929881}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5431907751594489
{'epoch_loss': 0.5431907751594489, 'train/pearson': np.float32(0.89851934), 'train/spearman': np.float64(0.6686751160458436), 'train/auc': 0.963356101958035, 'test/pearson': np.float32(0.83349466), 'test/spearman': np.float64(0.5419195907330872), 'test/auc': 0.9627296997783599}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5322934077999082
{'epoch_loss': 0.5322934077999082, 'train/pearson': np.float32(0.90008795), 'train/spearman': np.float64(0.6754423093601097), 'train/auc': 0.9642233461865892, 'test/pearson': np.float32(0.8350921), 'test/spearman': np.float64(0.5489973877823695), 'test/auc': 0.9629382429981865}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5251215165034651
{'epoch_loss': 0.5251215165034651, 'train/pearson': np.float32(0.89792675), 'train/spearman': np.float64(0.6726614390360108), 'train/auc': 0.9634979217787571, 'test/pearson': np.float32(0.83199394), 'test/spearman': np.float64(0.5438039594920343), 'test/auc': 0.9614618174491234}
Early stopping on epoch 7


0,1
epoch_loss,█▄▂▁▁▁▁
test/auc,▁▅▇███▇
test/pearson,▁▄▇▇██▇
test/spearman,▁▃▄▅▇█▇
train/auc,▁▅▇████
train/pearson,▁▄█████
train/spearman,▄█▅▁▃▅▄

0,1
epoch_loss,0.52512
test/auc,0.96146
test/pearson,0.83199
test/spearman,0.5438
train/auc,0.9635
train/pearson,0.89793
train/spearman,0.67266


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6062808351966139
{'epoch_loss': 0.6062808351966139, 'train/pearson': np.float32(0.89487404), 'train/spearman': np.float64(0.6579250260259286), 'train/auc': 0.9620363563523462, 'test/pearson': np.float32(0.83754313), 'test/spearman': np.float64(0.5152812412309651), 'test/auc': 0.9640670965142051}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30373386202005154
{'epoch_loss': 0.30373386202005154, 'train/pearson': np.float32(0.90002286), 'train/spearman': np.float64(0.6747292757277255), 'train/auc': 0.9656657819620411, 'test/pearson': np.float32(0.8418994), 'test/spearman': np.float64(0.535557631869342), 'test/auc': 0.9667046141446706}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2948408271557988
{'epoch_loss': 0.2948408271557988, 'train/pearson': np.float32(0.9039635), 'train/spearman': np.float64(0.691128830342045), 'train/auc': 0.9671444739345986, 'test/pearson': np.float32(0.84594715), 'test/spearman': np.float64(0.5590252672855092), 'test/auc': 0.9662502518637921}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28654952454395566
{'epoch_loss': 0.28654952454395566, 'train/pearson': np.float32(0.9057441), 'train/spearman': np.float64(0.7065759502323093), 'train/auc': 0.9676622765286194, 'test/pearson': np.float32(0.84647655), 'test/spearman': np.float64(0.570204376549743), 'test/auc': 0.9677594197058231}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28404515791244017
{'epoch_loss': 0.28404515791244017, 'train/pearson': np.float32(0.9070218), 'train/spearman': np.float64(0.7100358184558311), 'train/auc': 0.9689432620561871, 'test/pearson': np.float32(0.8472857), 'test/spearman': np.float64(0.576859069330671), 'test/auc': 0.9692126737860165}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27970349107877895
{'epoch_loss': 0.27970349107877895, 'train/pearson': np.float32(0.9079053), 'train/spearman': np.float64(0.7105612435894444), 'train/auc': 0.9704267614802945, 'test/pearson': np.float32(0.8478832), 'test/spearman': np.float64(0.5791637531816721), 'test/auc': 0.9690985794882128}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28515785582625447
{'epoch_loss': 0.28515785582625447, 'train/pearson': np.float32(0.90766233), 'train/spearman': np.float64(0.7072065809558643), 'train/auc': 0.9708713505934199, 'test/pearson': np.float32(0.8482115), 'test/spearman': np.float64(0.5788162179752073), 'test/auc': 0.9684686681442676}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2860405396182126
{'epoch_loss': 0.2860405396182126, 'train/pearson': np.float32(0.9082923), 'train/spearman': np.float64(0.7092590022880919), 'train/auc': 0.9720865341278982, 'test/pearson': np.float32(0.8480011), 'test/spearman': np.float64(0.5757694418099358), 'test/auc': 0.9698841426556518}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2803200026765799
{'epoch_loss': 0.2803200026765799, 'train/pearson': np.float32(0.90860045), 'train/spearman': np.float64(0.7045140188201034), 'train/auc': 0.9729105112925034, 'test/pearson': np.float32(0.8492971), 'test/spearman': np.float64(0.5803272527658279), 'test/auc': 0.971751460809994}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28742919784908094
{'epoch_loss': 0.28742919784908094, 'train/pearson': np.float32(0.9088926), 'train/spearman': np.float64(0.7098504516518823), 'train/auc': 0.9727373428814663, 'test/pearson': np.float32(0.8474592), 'test/spearman': np.float64(0.572923183211543), 'test/auc': 0.9704621700584324}
Early stopping on epoch 10


0,1
epoch_loss,█▂▁▁▁▁▁▁▁▁
test/auc,▁▃▃▄▆▆▅▆█▇
test/pearson,▁▄▆▆▇▇▇▇█▇
test/spearman,▁▃▆▇█████▇
train/auc,▁▃▄▅▅▆▇▇██
train/pearson,▁▄▆▆▇█▇███
train/spearman,▁▃▅▇████▇█

0,1
epoch_loss,0.28743
test/auc,0.97046
test/pearson,0.84746
test/spearman,0.57292
train/auc,0.97274
train/pearson,0.90889
train/spearman,0.70985


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.9694220900249938
{'epoch_loss': 0.9694220900249938, 'train/pearson': np.float32(0.84405535), 'train/spearman': np.float64(0.6475968798591731), 'train/auc': 0.9442614051780259, 'test/pearson': np.float32(0.74407846), 'test/spearman': np.float64(0.4247090158877921), 'test/auc': 0.9355979246423535}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.8580240074533243
{'epoch_loss': 0.8580240074533243, 'train/pearson': np.float32(0.838325), 'train/spearman': np.float64(0.6449564356397889), 'train/auc': 0.9427104011217388, 'test/pearson': np.float32(0.74190545), 'test/spearman': np.float64(0.4425600982868595), 'test/auc': 0.9263379004634295}
Early stopping on epoch 2


0,1
epoch_loss,█▁
test/auc,█▁
test/pearson,█▁
test/spearman,▁█
train/auc,█▁
train/pearson,█▁
train/spearman,█▁

0,1
epoch_loss,0.85802
test/auc,0.92634
test/pearson,0.74191
test/spearman,0.44256
train/auc,0.94271
train/pearson,0.83833
train/spearman,0.64496


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6177041218541681
{'epoch_loss': 0.6177041218541681, 'train/pearson': np.float32(0.89390594), 'train/spearman': np.float64(0.6355181443716926), 'train/auc': 0.961267364414843, 'test/pearson': np.float32(0.8348392), 'test/spearman': np.float64(0.5055478940662022), 'test/auc': 0.964790449325005}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3015872577008919
{'epoch_loss': 0.3015872577008919, 'train/pearson': np.float32(0.9010331), 'train/spearman': np.float64(0.6621746240404398), 'train/auc': 0.9663792878962392, 'test/pearson': np.float32(0.841665), 'test/spearman': np.float64(0.5393673472325002), 'test/auc': 0.9676385250856336}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28827308914579525
{'epoch_loss': 0.28827308914579525, 'train/pearson': np.float32(0.9026616), 'train/spearman': np.float64(0.6802288382811033), 'train/auc': 0.9684849516750964, 'test/pearson': np.float32(0.8427513), 'test/spearman': np.float64(0.5443633519386254), 'test/auc': 0.969276143461616}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27993471962146865
{'epoch_loss': 0.27993471962146865, 'train/pearson': np.float32(0.90446943), 'train/spearman': np.float64(0.6859400994152635), 'train/auc': 0.9700297461064651, 'test/pearson': np.float32(0.8435468), 'test/spearman': np.float64(0.5501493067449719), 'test/auc': 0.9688152327221439}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29196314871930085
{'epoch_loss': 0.29196314871930085, 'train/pearson': np.float32(0.9050831), 'train/spearman': np.float64(0.6939145994721172), 'train/auc': 0.9718958385497521, 'test/pearson': np.float32(0.8442195), 'test/spearman': np.float64(0.5577932030306577), 'test/auc': 0.9712288434414669}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2806099258458462
{'epoch_loss': 0.2806099258458462, 'train/pearson': np.float32(0.9069579), 'train/spearman': np.float64(0.698127966239544), 'train/auc': 0.9717566227652862, 'test/pearson': np.float32(0.84550697), 'test/spearman': np.float64(0.5611870681434153), 'test/auc': 0.9728324602055209}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2845496398667558
{'epoch_loss': 0.2845496398667558, 'train/pearson': np.float32(0.9072176), 'train/spearman': np.float64(0.7005570107546751), 'train/auc': 0.9725675797486103, 'test/pearson': np.float32(0.84512544), 'test/spearman': np.float64(0.5614211124247992), 'test/auc': 0.9703687285915776}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2773458767479982
{'epoch_loss': 0.2773458767479982, 'train/pearson': np.float32(0.90804845), 'train/spearman': np.float64(0.7025279161623968), 'train/auc': 0.9719471180329511, 'test/pearson': np.float32(0.84506375), 'test/spearman': np.float64(0.562355971862709), 'test/auc': 0.9722446101148499}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2811163980216264
{'epoch_loss': 0.2811163980216264, 'train/pearson': np.float32(0.9076546), 'train/spearman': np.float64(0.701133495561382), 'train/auc': 0.9724282638089038, 'test/pearson': np.float32(0.8459814), 'test/spearman': np.float64(0.5672404629049196), 'test/auc': 0.9706734837799718}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27648678241065516
{'epoch_loss': 0.27648678241065516, 'train/pearson': np.float32(0.908125), 'train/spearman': np.float64(0.6993263096333312), 'train/auc': 0.9732044669237319, 'test/pearson': np.float32(0.84615266), 'test/spearman': np.float64(0.5681368679315911), 'test/auc': 0.971517227483377}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28024757161926916
{'epoch_loss': 0.28024757161926916, 'train/pearson': np.float32(0.9097615), 'train/spearman': np.float64(0.7044443428680484), 'train/auc': 0.9725903149882318, 'test/pearson': np.float32(0.8474375), 'test/spearman': np.float64(0.5706801993178077), 'test/auc': 0.9713127140842233}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2826595055623748
{'epoch_loss': 0.2826595055623748, 'train/pearson': np.float32(0.90972716), 'train/spearman': np.float64(0.7047657914999995), 'train/auc': 0.9742471831338575, 'test/pearson': np.float32(0.84795606), 'test/spearman': np.float64(0.5741833369655831), 'test/auc': 0.9726803344751158}
===== EPOCH 13 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2779036602035117
{'epoch_loss': 0.2779036602035117, 'train/pearson': np.float32(0.9098117), 'train/spearman': np.float64(0.705335220554094), 'train/auc': 0.974817366918724, 'test/pearson': np.float32(0.84678555), 'test/spearman': np.float64(0.5669477458273059), 'test/auc': 0.9718522063268185}
Early stopping on epoch 13


0,1
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▃▅▅▇█▆▇▆▇▇█▇
test/pearson,▁▅▅▆▆▇▆▆▇▇██▇
test/spearman,▁▄▅▆▆▇▇▇▇▇██▇
train/auc,▁▄▅▆▆▆▇▇▇▇▇██
train/pearson,▁▄▅▆▆▇▇▇▇▇███
train/spearman,▁▄▅▆▇▇███▇███

0,1
epoch_loss,0.2779
test/auc,0.97185
test/pearson,0.84679
test/spearman,0.56695
train/auc,0.97482
train/pearson,0.90981
train/spearman,0.70534


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.7321664836650459
{'epoch_loss': 0.7321664836650459, 'train/pearson': np.float32(0.8842634), 'train/spearman': np.float64(0.6207032740727649), 'train/auc': 0.960966397916771, 'test/pearson': np.float32(0.8221113), 'test/spearman': np.float64(0.44079448084639217), 'test/auc': 0.9663479750151118}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.34698457153032003
{'epoch_loss': 0.34698457153032003, 'train/pearson': np.float32(0.89607847), 'train/spearman': np.float64(0.6349452270317436), 'train/auc': 0.9641890930942961, 'test/pearson': np.float32(0.83583224), 'test/spearman': np.float64(0.47013806792002316), 'test/auc': 0.9662316139431795}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33183249495566464
{'epoch_loss': 0.33183249495566464, 'train/pearson': np.float32(0.899649), 'train/spearman': np.float64(0.6554501329652348), 'train/auc': 0.9654272121788772, 'test/pearson': np.float32(0.83865184), 'test/spearman': np.float64(0.5175933716750867), 'test/auc': 0.96609359258513}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33235682839878833
{'epoch_loss': 0.33235682839878833, 'train/pearson': np.float32(0.9017915), 'train/spearman': np.float64(0.6684101914820361), 'train/auc': 0.9681116730932946, 'test/pearson': np.float32(0.84010345), 'test/spearman': np.float64(0.5262279633350835), 'test/auc': 0.9682157968970381}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32968348214706295
{'epoch_loss': 0.32968348214706295, 'train/pearson': np.float32(0.902577), 'train/spearman': np.float64(0.6624810004498842), 'train/auc': 0.9700912414242076, 'test/pearson': np.float32(0.84102446), 'test/spearman': np.float64(0.5264653865952037), 'test/auc': 0.971200886560548}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3307821743856794
{'epoch_loss': 0.3307821743856794, 'train/pearson': np.float32(0.90300304), 'train/spearman': np.float64(0.6724707699888022), 'train/auc': 0.9707256247183134, 'test/pearson': np.float32(0.8410009), 'test/spearman': np.float64(0.535497165202316), 'test/auc': 0.9702438041507153}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3348675545055074
{'epoch_loss': 0.3348675545055074, 'train/pearson': np.float32(0.9040748), 'train/spearman': np.float64(0.6786165920395463), 'train/auc': 0.9694642696179078, 'test/pearson': np.float32(0.8421456), 'test/spearman': np.float64(0.5437558504149196), 'test/auc': 0.968482772516623}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33439435681310326
{'epoch_loss': 0.33439435681310326, 'train/pearson': np.float32(0.90380025), 'train/spearman': np.float64(0.6760958613877797), 'train/auc': 0.9693567028894787, 'test/pearson': np.float32(0.841655), 'test/spearman': np.float64(0.5370011143023892), 'test/auc': 0.9691673383034455}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3395461185743062
{'epoch_loss': 0.3395461185743062, 'train/pearson': np.float32(0.90405065), 'train/spearman': np.float64(0.676047496513148), 'train/auc': 0.9701058640893385, 'test/pearson': np.float32(0.84303784), 'test/spearman': np.float64(0.5452640186579275), 'test/auc': 0.9708941164618174}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33619323750130664
{'epoch_loss': 0.33619323750130664, 'train/pearson': np.float32(0.9040268), 'train/spearman': np.float64(0.6684359648664778), 'train/auc': 0.970175572136812, 'test/pearson': np.float32(0.8427153), 'test/spearman': np.float64(0.5446692023999069), 'test/auc': 0.9695496675397944}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33048427949983855
{'epoch_loss': 0.33048427949983855, 'train/pearson': np.float32(0.9049505), 'train/spearman': np.float64(0.674494570763396), 'train/auc': 0.971039711552907, 'test/pearson': np.float32(0.8430001), 'test/spearman': np.float64(0.5452558952602894), 'test/auc': 0.9720123916985695}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3298761214430149
{'epoch_loss': 0.3298761214430149, 'train/pearson': np.float32(0.90366447), 'train/spearman': np.float64(0.6738479015620442), 'train/auc': 0.9705173018178177, 'test/pearson': np.float32(0.84217304), 'test/spearman': np.float64(0.5441514090088704), 'test/auc': 0.9709268587547855}
===== EPOCH 13 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33288401591415984
{'epoch_loss': 0.33288401591415984, 'train/pearson': np.float32(0.9047807), 'train/spearman': np.float64(0.6773147938603351), 'train/auc': 0.9704658220241373, 'test/pearson': np.float32(0.84311825), 'test/spearman': np.float64(0.5483478015451942), 'test/auc': 0.9709555712270804}
===== EPOCH 14 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3314931239135349
{'epoch_loss': 0.3314931239135349, 'train/pearson': np.float32(0.9058464), 'train/spearman': np.float64(0.6840777883622124), 'train/auc': 0.9708542240472733, 'test/pearson': np.float32(0.8435174), 'test/spearman': np.float64(0.5480334115171459), 'test/auc': 0.9694640338504936}
===== EPOCH 15 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3322025432992286
{'epoch_loss': 0.3322025432992286, 'train/pearson': np.float32(0.90505666), 'train/spearman': np.float64(0.6823559310358681), 'train/auc': 0.9707925284190496, 'test/pearson': np.float32(0.8430645), 'test/spearman': np.float64(0.5512968313030224), 'test/auc': 0.9696584726979649}
===== EPOCH 16 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33228271192731185
{'epoch_loss': 0.33228271192731185, 'train/pearson': np.float32(0.90533996), 'train/spearman': np.float64(0.6800880261247738), 'train/auc': 0.9716547648855727, 'test/pearson': np.float32(0.8440486), 'test/spearman': np.float64(0.5593727884801739), 'test/auc': 0.9707419907314124}
===== EPOCH 17 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32821213323087356
{'epoch_loss': 0.32821213323087356, 'train/pearson': np.float32(0.90593547), 'train/spearman': np.float64(0.6842886948191539), 'train/auc': 0.9704411838349443, 'test/pearson': np.float32(0.843912), 'test/spearman': np.float64(0.5543439897699128), 'test/auc': 0.9704221237154946}
===== EPOCH 18 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3247341706682318
{'epoch_loss': 0.3247341706682318, 'train/pearson': np.float32(0.90638053), 'train/spearman': np.float64(0.6877556685697861), 'train/auc': 0.9698885272171867, 'test/pearson': np.float32(0.8431593), 'test/spearman': np.float64(0.5526480864648446), 'test/auc': 0.9687507555913761}
===== EPOCH 19 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32882921989209735
{'epoch_loss': 0.32882921989209735, 'train/pearson': np.float32(0.90595406), 'train/spearman': np.float64(0.6882859474580882), 'train/auc': 0.9699764635184538, 'test/pearson': np.float32(0.84419495), 'test/spearman': np.float64(0.5639749312441327), 'test/auc': 0.9680117872254684}
===== EPOCH 20 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3302686103045369
{'epoch_loss': 0.3302686103045369, 'train/pearson': np.float32(0.9055783), 'train/spearman': np.float64(0.6838551772074773), 'train/auc': 0.9711568931844359, 'test/pearson': np.float32(0.8429053), 'test/spearman': np.float64(0.5520142697045908), 'test/auc': 0.9709666532339312}
Early stopping on epoch 20


0,1
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▁▁▄▇▆▄▅▇▅█▇▇▅▅▆▆▄▃▇
test/pearson,▁▅▆▇▇▇▇▇███▇████████
test/spearman,▁▃▅▆▆▆▇▆▇▇▇▇▇▇▇█▇▇█▇
train/auc,▁▃▄▆▇▇▇▆▇▇█▇▇▇▇█▇▇▇█
train/pearson,▁▅▆▇▇▇▇▇▇▇█▇▇███████
train/spearman,▁▂▅▆▅▆▇▇▇▆▇▇▇█▇▇████

0,1
epoch_loss,0.33027
test/auc,0.97097
test/pearson,0.84291
test/spearman,0.55201
train/auc,0.97116
train/pearson,0.90558
train/spearman,0.68386


Calibration with: {'conv_init_weight_std': 0.001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6498174946338605
{'epoch_loss': 0.6498174946338605, 'train/pearson': np.float32(0.89112073), 'train/spearman': np.float64(0.6321743944320661), 'train/auc': 0.9606706394912113, 'test/pearson': np.float32(0.8302341), 'test/spearman': np.float64(0.4741051923845194), 'test/auc': 0.963869131573645}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.31953123949777584
{'epoch_loss': 0.31953123949777584, 'train/pearson': np.float32(0.8978897), 'train/spearman': np.float64(0.640242084949214), 'train/auc': 0.9659430116680855, 'test/pearson': np.float32(0.8368043), 'test/spearman': np.float64(0.4753729925653378), 'test/auc': 0.9689129558734636}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32634810477304765
{'epoch_loss': 0.32634810477304765, 'train/pearson': np.float32(0.90098035), 'train/spearman': np.float64(0.6653821049864376), 'train/auc': 0.9674427362411738, 'test/pearson': np.float32(0.8405184), 'test/spearman': np.float64(0.5187047580246755), 'test/auc': 0.9690726375176304}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.315333921819354
{'epoch_loss': 0.315333921819354, 'train/pearson': np.float32(0.9031872), 'train/spearman': np.float64(0.6787833095076653), 'train/auc': 0.9680425659772647, 'test/pearson': np.float32(0.84106076), 'test/spearman': np.float64(0.5349761537843126), 'test/auc': 0.9716814426758009}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.31350471012698955
{'epoch_loss': 0.31350471012698955, 'train/pearson': np.float32(0.905079), 'train/spearman': np.float64(0.6852049275207641), 'train/auc': 0.9700171265461466, 'test/pearson': np.float32(0.843213), 'test/spearman': np.float64(0.5411693863041308), 'test/auc': 0.9728198670159177}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.31942896446147667
{'epoch_loss': 0.31942896446147667, 'train/pearson': np.float32(0.90548205), 'train/spearman': np.float64(0.6887733496634725), 'train/auc': 0.9684857529170213, 'test/pearson': np.float32(0.8442514), 'test/spearman': np.float64(0.5594457312361063), 'test/auc': 0.9700669957686883}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3149428591965296
{'epoch_loss': 0.3149428591965296, 'train/pearson': np.float32(0.90690976), 'train/spearman': np.float64(0.6900186336319031), 'train/auc': 0.9710653512945064, 'test/pearson': np.float32(0.8445485), 'test/spearman': np.float64(0.5551214448992672), 'test/auc': 0.9707722143864599}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3112968062511839
{'epoch_loss': 0.3112968062511839, 'train/pearson': np.float32(0.9058886), 'train/spearman': np.float64(0.6827479273319964), 'train/auc': 0.9721786769492714, 'test/pearson': np.float32(0.8434062), 'test/spearman': np.float64(0.553340808132532), 'test/auc': 0.9706447713076769}
Early stopping on epoch 8


0,1
epoch_loss,█▁▁▁▁▁▁▁
test/auc,▁▅▅▇█▆▆▆
test/pearson,▁▄▆▆▇██▇
test/spearman,▁▁▅▆▇██▇
train/auc,▁▄▅▅▇▆▇█
train/pearson,▁▄▅▆▇▇██
train/spearman,▁▂▅▇▇██▇

0,1
epoch_loss,0.3113
test/auc,0.97064
test/pearson,0.84341
test/spearman,0.55334
train/auc,0.97218
train/pearson,0.90589
train/spearman,0.68275


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5614421922035111
{'epoch_loss': 0.5614421922035111, 'train/pearson': np.float32(0.897409), 'train/spearman': np.float64(0.6626596568588927), 'train/auc': 0.9642211427712954, 'test/pearson': np.float32(0.83805895), 'test/spearman': np.float64(0.5228240971947916), 'test/auc': 0.9663232923634898}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27536581531643106
{'epoch_loss': 0.27536581531643106, 'train/pearson': np.float32(0.9029499), 'train/spearman': np.float64(0.6875944680850969), 'train/auc': 0.9663544493965647, 'test/pearson': np.float32(0.8438212), 'test/spearman': np.float64(0.5457317285315433), 'test/auc': 0.9689411646181745}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27211693223244465
{'epoch_loss': 0.27211693223244465, 'train/pearson': np.float32(0.906037), 'train/spearman': np.float64(0.6991794158729545), 'train/auc': 0.9675683309129149, 'test/pearson': np.float32(0.8471338), 'test/spearman': np.float64(0.5704018657980366), 'test/auc': 0.9673166431593794}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2697606302750187
{'epoch_loss': 0.2697606302750187, 'train/pearson': np.float32(0.9053751), 'train/spearman': np.float64(0.6993424049596241), 'train/auc': 0.9676524613150383, 'test/pearson': np.float32(0.84639484), 'test/spearman': np.float64(0.5717179192314444), 'test/auc': 0.9686998791053798}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2669800615872438
{'epoch_loss': 0.2669800615872438, 'train/pearson': np.float32(0.90701073), 'train/spearman': np.float64(0.6991703277619793), 'train/auc': 0.9694005708848715, 'test/pearson': np.float32(0.8473055), 'test/spearman': np.float64(0.5664524722173053), 'test/auc': 0.9690076566592787}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2669495276797313
{'epoch_loss': 0.2669495276797313, 'train/pearson': np.float32(0.9063881), 'train/spearman': np.float64(0.7038637533770109), 'train/auc': 0.9694732835895639, 'test/pearson': np.float32(0.84637004), 'test/spearman': np.float64(0.5685295032359877), 'test/auc': 0.9691308180535966}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26643299115018343
{'epoch_loss': 0.26643299115018343, 'train/pearson': np.float32(0.9073334), 'train/spearman': np.float64(0.7050133744430683), 'train/auc': 0.9706350843807903, 'test/pearson': np.float32(0.846158), 'test/spearman': np.float64(0.5658886799934699), 'test/auc': 0.9702019947612331}
Early stopping on epoch 7


0,1
epoch_loss,█▁▁▁▁▁▁
test/auc,▁▆▃▅▆▆█
test/pearson,▁▅█▇█▇▇
test/spearman,▁▄██▇█▇
train/auc,▁▃▅▅▇▇█
train/pearson,▁▅▇▇█▇█
train/spearman,▁▅▇▇▇██

0,1
epoch_loss,0.26643
test/auc,0.9702
test/pearson,0.84616
test/spearman,0.56589
train/auc,0.97064
train/pearson,0.90733
train/spearman,0.70501


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.563649331621183
{'epoch_loss': 0.563649331621183, 'train/pearson': np.float32(0.89690113), 'train/spearman': np.float64(0.6486040087801059), 'train/auc': 0.9642163353197456, 'test/pearson': np.float32(0.8391323), 'test/spearman': np.float64(0.5332052948451), 'test/auc': 0.965924843844449}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2835022570273747
{'epoch_loss': 0.2835022570273747, 'train/pearson': np.float32(0.90311223), 'train/spearman': np.float64(0.6848097095966379), 'train/auc': 0.9694538534728829, 'test/pearson': np.float32(0.8435368), 'test/spearman': np.float64(0.5479017462230051), 'test/auc': 0.9686462321176708}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2762254678165189
{'epoch_loss': 0.2762254678165189, 'train/pearson': np.float32(0.9059123), 'train/spearman': np.float64(0.6981971041277144), 'train/auc': 0.9719473183434323, 'test/pearson': np.float32(0.8466084), 'test/spearman': np.float64(0.568139819440737), 'test/auc': 0.9706719725972195}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2774179525935231
{'epoch_loss': 0.2774179525935231, 'train/pearson': np.float32(0.9074171), 'train/spearman': np.float64(0.6988583709470467), 'train/auc': 0.9719266863638639, 'test/pearson': np.float32(0.8470399), 'test/spearman': np.float64(0.5661609409724575), 'test/auc': 0.9704211162603265}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2754568917016252
{'epoch_loss': 0.2754568917016252, 'train/pearson': np.float32(0.9078989), 'train/spearman': np.float64(0.7032524851426303), 'train/auc': 0.9707037908758576, 'test/pearson': np.float32(0.84674656), 'test/spearman': np.float64(0.5684923243520793), 'test/auc': 0.9699037880314326}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26923613690625364
{'epoch_loss': 0.26923613690625364, 'train/pearson': np.float32(0.9089428), 'train/spearman': np.float64(0.7060665206633111), 'train/auc': 0.9728905803996194, 'test/pearson': np.float32(0.8478883), 'test/spearman': np.float64(0.5754011579664948), 'test/auc': 0.9705455369736047}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2711119269862914
{'epoch_loss': 0.2711119269862914, 'train/pearson': np.float32(0.9093154), 'train/spearman': np.float64(0.7061209267766687), 'train/auc': 0.9730400120186289, 'test/pearson': np.float32(0.84788215), 'test/spearman': np.float64(0.5714577331319416), 'test/auc': 0.9699027805762643}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27424528114140606
{'epoch_loss': 0.27424528114140606, 'train/pearson': np.float32(0.90983856), 'train/spearman': np.float64(0.7064837396612595), 'train/auc': 0.9737665381341079, 'test/pearson': np.float32(0.8485482), 'test/spearman': np.float64(0.5729587804331355), 'test/auc': 0.9713832359460004}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2728372487468651
{'epoch_loss': 0.2728372487468651, 'train/pearson': np.float32(0.9094186), 'train/spearman': np.float64(0.7027099632585567), 'train/auc': 0.975363813911563, 'test/pearson': np.float32(0.8483651), 'test/spearman': np.float64(0.5697783533589204), 'test/auc': 0.9723811202901471}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2730590144809062
{'epoch_loss': 0.2730590144809062, 'train/pearson': np.float32(0.91060895), 'train/spearman': np.float64(0.7062080882319639), 'train/auc': 0.9759497220692073, 'test/pearson': np.float32(0.84876806), 'test/spearman': np.float64(0.5702855148391012), 'test/auc': 0.9709714386459802}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2713406188681293
{'epoch_loss': 0.2713406188681293, 'train/pearson': np.float32(0.9123668), 'train/spearman': np.float64(0.7080669056842586), 'train/auc': 0.9760124192498373, 'test/pearson': np.float32(0.8516551), 'test/spearman': np.float64(0.5835657868717155), 'test/auc': 0.972364245416079}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27458508394825193
{'epoch_loss': 0.27458508394825193, 'train/pearson': np.float32(0.9128725), 'train/spearman': np.float64(0.6995530759351272), 'train/auc': 0.9775902649106115, 'test/pearson': np.float32(0.85093707), 'test/spearman': np.float64(0.5699371642039642), 'test/auc': 0.9711359057021962}
===== EPOCH 13 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2674622076054731
{'epoch_loss': 0.2674622076054731, 'train/pearson': np.float32(0.91190505), 'train/spearman': np.float64(0.703286157677081), 'train/auc': 0.9765792979117632, 'test/pearson': np.float32(0.8502404), 'test/spearman': np.float64(0.5740351847384598), 'test/auc': 0.9693471690509773}
Early stopping on epoch 13


0,1
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▄▆▆▅▆▅▇█▆█▇▅
test/pearson,▁▃▅▅▅▆▆▆▆▆██▇
test/spearman,▁▃▆▆▆▇▆▇▆▆█▆▇
train/auc,▁▄▅▅▄▆▆▆▇▇▇█▇
train/pearson,▁▄▅▆▆▆▆▇▆▇███
train/spearman,▁▅▇▇▇███▇██▇▇

0,1
epoch_loss,0.26746
test/auc,0.96935
test/pearson,0.85024
test/spearman,0.57404
train/auc,0.97658
train/pearson,0.91191
train/spearman,0.70329


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6022643364085176
{'epoch_loss': 0.6022643364085176, 'train/pearson': np.float32(0.8914993), 'train/spearman': np.float64(0.6663586531845915), 'train/auc': 0.9608951875406881, 'test/pearson': np.float32(0.82986164), 'test/spearman': np.float64(0.5093388045637991), 'test/auc': 0.9641502115655853}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2975893340671596
{'epoch_loss': 0.2975893340671596, 'train/pearson': np.float32(0.90177405), 'train/spearman': np.float64(0.6876936750690062), 'train/auc': 0.9655742400721117, 'test/pearson': np.float32(0.84118605), 'test/spearman': np.float64(0.544523690561021), 'test/auc': 0.9665484585935926}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29086324342856773
{'epoch_loss': 0.29086324342856773, 'train/pearson': np.float32(0.90629), 'train/spearman': np.float64(0.7029468181645251), 'train/auc': 0.9677974861034603, 'test/pearson': np.float32(0.8441998), 'test/spearman': np.float64(0.5607002933133495), 'test/auc': 0.9665439250453355}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28803957032319455
{'epoch_loss': 0.28803957032319455, 'train/pearson': np.float32(0.90625376), 'train/spearman': np.float64(0.7065918525970316), 'train/auc': 0.9675735389854275, 'test/pearson': np.float32(0.8449853), 'test/spearman': np.float64(0.572155758133861), 'test/auc': 0.9671489018738666}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2830219493625453
{'epoch_loss': 0.2830219493625453, 'train/pearson': np.float32(0.90768325), 'train/spearman': np.float64(0.7010498014030285), 'train/auc': 0.9691438730031549, 'test/pearson': np.float32(0.84638566), 'test/spearman': np.float64(0.5699709501278072), 'test/auc': 0.9688993552286924}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2871567439823486
{'epoch_loss': 0.2871567439823486, 'train/pearson': np.float32(0.90735745), 'train/spearman': np.float64(0.7026580342623228), 'train/auc': 0.969992288046472, 'test/pearson': np.float32(0.8442891), 'test/spearman': np.float64(0.5580044743597107), 'test/auc': 0.9685578279266573}
Early stopping on epoch 6


0,1
epoch_loss,█▁▁▁▁▁
test/auc,▁▅▅▅█▇
test/pearson,▁▆▇▇█▇
test/spearman,▁▅▇██▆
train/auc,▁▅▆▆▇█
train/pearson,▁▅▇▇██
train/spearman,▁▅▇█▇▇

0,1
epoch_loss,0.28716
test/auc,0.96856
test/pearson,0.84429
test/spearman,0.558
train/auc,0.96999
train/pearson,0.90736
train/spearman,0.70266


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6414757753951481
{'epoch_loss': 0.6414757753951481, 'train/pearson': np.float32(0.88562256), 'train/spearman': np.float64(0.6704935924380346), 'train/auc': 0.9620042065201061, 'test/pearson': np.float32(0.82489127), 'test/spearman': np.float64(0.519756726494347), 'test/auc': 0.9646922224461011}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3155457113807003
{'epoch_loss': 0.3155457113807003, 'train/pearson': np.float32(0.9031846), 'train/spearman': np.float64(0.6922809420175854), 'train/auc': 0.9641634533526966, 'test/pearson': np.float32(0.84152025), 'test/spearman': np.float64(0.5455377670617432), 'test/auc': 0.9667051178722548}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30669557476957765
{'epoch_loss': 0.30669557476957765, 'train/pearson': np.float32(0.90495753), 'train/spearman': np.float64(0.6970011563363391), 'train/auc': 0.9661347087986378, 'test/pearson': np.float32(0.8422942), 'test/spearman': np.float64(0.5476468563225694), 'test/auc': 0.967306568607697}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29392813186580763
{'epoch_loss': 0.29392813186580763, 'train/pearson': np.float32(0.9068613), 'train/spearman': np.float64(0.7102445666950176), 'train/auc': 0.9672545445440432, 'test/pearson': np.float32(0.8438718), 'test/spearman': np.float64(0.5673839817401602), 'test/auc': 0.9663540197461213}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29878551904528666
{'epoch_loss': 0.29878551904528666, 'train/pearson': np.float32(0.90670544), 'train/spearman': np.float64(0.7094172810352704), 'train/auc': 0.9691527868195704, 'test/pearson': np.float32(0.84274656), 'test/spearman': np.float64(0.5609996336659011), 'test/auc': 0.9672325206528309}
Early stopping on epoch 5


0,1
epoch_loss,█▁▁▁▁
test/auc,▁▆█▅█
test/pearson,▁▇▇██
test/spearman,▁▅▅█▇
train/auc,▁▃▅▆█
train/pearson,▁▇▇██
train/spearman,▁▅▆██

0,1
epoch_loss,0.29879
test/auc,0.96723
test/pearson,0.84275
test/spearman,0.561
train/auc,0.96915
train/pearson,0.90671
train/spearman,0.70942


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6927230547006709
{'epoch_loss': 0.6927230547006709, 'train/pearson': np.float32(0.8885034), 'train/spearman': np.float64(0.6339244899053741), 'train/auc': 0.9614528519204768, 'test/pearson': np.float32(0.8217704), 'test/spearman': np.float64(0.451730395578329), 'test/auc': 0.9668365907717107}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33585620210907713
{'epoch_loss': 0.33585620210907713, 'train/pearson': np.float32(0.9011686), 'train/spearman': np.float64(0.6663027338665053), 'train/auc': 0.9652777805598678, 'test/pearson': np.float32(0.83483714), 'test/spearman': np.float64(0.48250547580135816), 'test/auc': 0.9691179730002014}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3246880883630663
{'epoch_loss': 0.3246880883630663, 'train/pearson': np.float32(0.903278), 'train/spearman': np.float64(0.6851323077486277), 'train/auc': 0.9658702989633933, 'test/pearson': np.float32(0.8391825), 'test/spearman': np.float64(0.5286675032854716), 'test/auc': 0.9686777150916783}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3188330846639296
{'epoch_loss': 0.3188330846639296, 'train/pearson': np.float32(0.9047448), 'train/spearman': np.float64(0.6759077216734375), 'train/auc': 0.9675611197355901, 'test/pearson': np.float32(0.8392614), 'test/spearman': np.float64(0.5159756590387576), 'test/auc': 0.9686197864195043}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32653868341217407
{'epoch_loss': 0.32653868341217407, 'train/pearson': np.float32(0.9048366), 'train/spearman': np.float64(0.6918079899738928), 'train/auc': 0.9678920326506085, 'test/pearson': np.float32(0.8397784), 'test/spearman': np.float64(0.5299550298945082), 'test/auc': 0.9688222849083216}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32138664668170025
{'epoch_loss': 0.32138664668170025, 'train/pearson': np.float32(0.90549), 'train/spearman': np.float64(0.6853623005729412), 'train/auc': 0.9687541689618909, 'test/pearson': np.float32(0.8398155), 'test/spearman': np.float64(0.5260338272684599), 'test/auc': 0.9693144267580093}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3167493145067852
{'epoch_loss': 0.3167493145067852, 'train/pearson': np.float32(0.9062469), 'train/spearman': np.float64(0.6959563863983634), 'train/auc': 0.9677048425058842, 'test/pearson': np.float32(0.84201115), 'test/spearman': np.float64(0.5420903040583934), 'test/auc': 0.9682913560346564}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3209816706709016
{'epoch_loss': 0.3209816706709016, 'train/pearson': np.float32(0.9060225), 'train/spearman': np.float64(0.6893913242205616), 'train/auc': 0.9681780760178276, 'test/pearson': np.float32(0.8410422), 'test/spearman': np.float64(0.5356764275283515), 'test/auc': 0.9674012693935119}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3170737978606559
{'epoch_loss': 0.3170737978606559, 'train/pearson': np.float32(0.9045145), 'train/spearman': np.float64(0.6804557654119303), 'train/auc': 0.9685147979368021, 'test/pearson': np.float32(0.8388356), 'test/spearman': np.float64(0.5221776790628161), 'test/auc': 0.9680394922425951}
Early stopping on epoch 9


0,1
epoch_loss,█▁▁▁▁▁▁▁▁
test/auc,▁▇▆▆▇█▅▃▄
test/pearson,▁▆▇▇▇▇██▇
test/spearman,▁▃▇▆▇▇██▆
train/auc,▁▅▅▇▇█▇▇█
train/pearson,▁▆▇▇▇███▇
train/spearman,▁▅▇▆█▇█▇▆

0,1
epoch_loss,0.31707
test/auc,0.96804
test/pearson,0.83884
test/spearman,0.52218
train/auc,0.96851
train/pearson,0.90451
train/spearman,0.68046


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.01, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6218661912047444
{'epoch_loss': 0.6218661912047444, 'train/pearson': np.float32(0.89405346), 'train/spearman': np.float64(0.6464414514559302), 'train/auc': 0.9609228303871, 'test/pearson': np.float32(0.8310836), 'test/spearman': np.float64(0.49245627214373694), 'test/auc': 0.9651015011082007}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29663495479014734
{'epoch_loss': 0.29663495479014734, 'train/pearson': np.float32(0.90369296), 'train/spearman': np.float64(0.6884676615599294), 'train/auc': 0.9670523311132254, 'test/pearson': np.float32(0.8395614), 'test/spearman': np.float64(0.5179258447205686), 'test/auc': 0.9691731311706628}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28828528985238305
{'epoch_loss': 0.28828528985238305, 'train/pearson': np.float32(0.9069982), 'train/spearman': np.float64(0.7016069224273758), 'train/auc': 0.9665812008613351, 'test/pearson': np.float32(0.8412654), 'test/spearman': np.float64(0.5293754440008579), 'test/auc': 0.9688933104976829}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2860886491954136
{'epoch_loss': 0.2860886491954136, 'train/pearson': np.float32(0.9079145), 'train/spearman': np.float64(0.7107745762222794), 'train/auc': 0.969190345034804, 'test/pearson': np.float32(0.8426458), 'test/spearman': np.float64(0.5510427506084424), 'test/auc': 0.9682621398347773}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2832647771619189
{'epoch_loss': 0.2832647771619189, 'train/pearson': np.float32(0.9072465), 'train/spearman': np.float64(0.7103213296981941), 'train/auc': 0.9695147478591817, 'test/pearson': np.float32(0.8439543), 'test/spearman': np.float64(0.5562308737455556), 'test/auc': 0.9691774128551279}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27737312373547507
{'epoch_loss': 0.27737312373547507, 'train/pearson': np.float32(0.9077659), 'train/spearman': np.float64(0.7028810378728729), 'train/auc': 0.9694438379488206, 'test/pearson': np.float32(0.84317726), 'test/spearman': np.float64(0.5406496465086706), 'test/auc': 0.9705349586943381}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2773151626340307
{'epoch_loss': 0.2773151626340307, 'train/pearson': np.float32(0.90750873), 'train/spearman': np.float64(0.7086867469101371), 'train/auc': 0.970546146527117, 'test/pearson': np.float32(0.8450578), 'test/spearman': np.float64(0.562291291548557), 'test/auc': 0.971394317952851}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27172076815185836
{'epoch_loss': 0.27172076815185836, 'train/pearson': np.float32(0.90730834), 'train/spearman': np.float64(0.7121114301176111), 'train/auc': 0.9693430817767539, 'test/pearson': np.float32(0.84353507), 'test/spearman': np.float64(0.5565725259262039), 'test/auc': 0.9700982268789039}
Early stopping on epoch 8


0,1
epoch_loss,█▁▁▁▁▁▁▁
test/auc,▁▆▅▅▆▇█▇
test/pearson,▁▅▆▇▇▇█▇
test/spearman,▁▄▅▇▇▆█▇
train/auc,▁▅▅▇▇▇█▇
train/pearson,▁▆██████
train/spearman,▁▅▇██▇██

0,1
epoch_loss,0.27172
test/auc,0.9701
test/pearson,0.84354
test/spearman,0.55657
train/auc,0.96934
train/pearson,0.90731
train/spearman,0.71211


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6656348740283292
{'epoch_loss': 0.6656348740283292, 'train/pearson': np.float32(0.8871596), 'train/spearman': np.float64(0.6300131899119907), 'train/auc': 0.9598691972557464, 'test/pearson': np.float32(0.82968533), 'test/spearman': np.float64(0.5184228139816578), 'test/auc': 0.9620159177916582}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3281005968967566
{'epoch_loss': 0.3281005968967566, 'train/pearson': np.float32(0.897291), 'train/spearman': np.float64(0.6505524557807296), 'train/auc': 0.9627514647703941, 'test/pearson': np.float32(0.84058803), 'test/spearman': np.float64(0.5367329711583506), 'test/auc': 0.9643126637114648}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.33009095017473916
{'epoch_loss': 0.33009095017473916, 'train/pearson': np.float32(0.9003993), 'train/spearman': np.float64(0.6684987856955473), 'train/auc': 0.9662196404426862, 'test/pearson': np.float32(0.84362525), 'test/spearman': np.float64(0.5601397752041188), 'test/auc': 0.9671881926254281}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3227136903010999
{'epoch_loss': 0.3227136903010999, 'train/pearson': np.float32(0.90133464), 'train/spearman': np.float64(0.6738952903391796), 'train/auc': 0.9659474184986728, 'test/pearson': np.float32(0.8444448), 'test/spearman': np.float64(0.5682735028759516), 'test/auc': 0.9654639331049768}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32181741231784655
{'epoch_loss': 0.32181741231784655, 'train/pearson': np.float32(0.9027623), 'train/spearman': np.float64(0.6809249735546885), 'train/auc': 0.9679583354199008, 'test/pearson': np.float32(0.8446242), 'test/spearman': np.float64(0.5698826042560132), 'test/auc': 0.9668945194438848}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32012861702864925
{'epoch_loss': 0.32012861702864925, 'train/pearson': np.float32(0.9024301), 'train/spearman': np.float64(0.6798964813729677), 'train/auc': 0.9696325304221544, 'test/pearson': np.float32(0.84481645), 'test/spearman': np.float64(0.57910601260796), 'test/auc': 0.9691884948619787}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32515850013342146
{'epoch_loss': 0.32515850013342146, 'train/pearson': np.float32(0.90384895), 'train/spearman': np.float64(0.6820551954878508), 'train/auc': 0.9688282838399519, 'test/pearson': np.float32(0.84620154), 'test/spearman': np.float64(0.5795940516819001), 'test/auc': 0.9671232117670763}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3254723457482676
{'epoch_loss': 0.3254723457482676, 'train/pearson': np.float32(0.90461266), 'train/spearman': np.float64(0.6855017785817819), 'train/auc': 0.9693002153337674, 'test/pearson': np.float32(0.8473962), 'test/spearman': np.float64(0.5839460514675443), 'test/auc': 0.9678213781986702}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32409395038272243
{'epoch_loss': 0.32409395038272243, 'train/pearson': np.float32(0.903951), 'train/spearman': np.float64(0.6760970432241992), 'train/auc': 0.968978616856127, 'test/pearson': np.float32(0.845967), 'test/spearman': np.float64(0.5734638791178698), 'test/auc': 0.9664326012492443}
Early stopping on epoch 9


0,1
epoch_loss,█▁▁▁▁▁▁▁▁
test/auc,▁▃▆▄▆█▆▇▅
test/pearson,▁▅▇▇▇▇██▇
test/spearman,▁▃▅▆▆▇██▇
train/auc,▁▃▆▅▇█▇██
train/pearson,▁▅▆▇▇▇███
train/spearman,▁▄▆▇▇▇██▇

0,1
epoch_loss,0.32409
test/auc,0.96643
test/pearson,0.84597
test/spearman,0.57346
train/auc,0.96898
train/pearson,0.90395
train/spearman,0.6761


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 12, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6761966651430527
{'epoch_loss': 0.6761966651430527, 'train/pearson': np.float32(0.88396907), 'train/spearman': np.float64(0.659008624046477), 'train/auc': 0.9583782863438329, 'test/pearson': np.float32(0.82417375), 'test/spearman': np.float64(0.5121425581170783), 'test/auc': 0.9614567801732822}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3385629093351836
{'epoch_loss': 0.3385629093351836, 'train/pearson': np.float32(0.89995164), 'train/spearman': np.float64(0.678550626724448), 'train/auc': 0.9634050778706996, 'test/pearson': np.float32(0.8402819), 'test/spearman': np.float64(0.536371977115646), 'test/auc': 0.9661298609711867}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3272112980127906
{'epoch_loss': 0.3272112980127906, 'train/pearson': np.float32(0.9026953), 'train/spearman': np.float64(0.6909253427434332), 'train/auc': 0.9652397215684311, 'test/pearson': np.float32(0.8406535), 'test/spearman': np.float64(0.5432027980354422), 'test/auc': 0.9665595406004434}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32935784550044483
{'epoch_loss': 0.32935784550044483, 'train/pearson': np.float32(0.9033723), 'train/spearman': np.float64(0.6939949270024077), 'train/auc': 0.9665070859832741, 'test/pearson': np.float32(0.84190255), 'test/spearman': np.float64(0.5503418336686928), 'test/auc': 0.9684794982873262}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3214207190080001
{'epoch_loss': 0.3214207190080001, 'train/pearson': np.float32(0.9049376), 'train/spearman': np.float64(0.6975489461579208), 'train/auc': 0.9661204867544695, 'test/pearson': np.float32(0.8439103), 'test/spearman': np.float64(0.5590852007370657), 'test/auc': 0.967384646383236}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.32247477344954356
{'epoch_loss': 0.32247477344954356, 'train/pearson': np.float32(0.9050213), 'train/spearman': np.float64(0.694232972942105), 'train/auc': 0.9671028093544994, 'test/pearson': np.float32(0.84236246), 'test/spearman': np.float64(0.5483357356627886), 'test/auc': 0.9674231815434213}
Early stopping on epoch 6


0,1
epoch_loss,█▁▁▁▁▁
test/auc,▁▆▆█▇▇
test/pearson,▁▇▇▇█▇
test/spearman,▁▅▆▇█▆
train/auc,▁▅▇█▇█
train/pearson,▁▆▇▇██
train/spearman,▁▅▇▇█▇

0,1
epoch_loss,0.32247
test/auc,0.96742
test/pearson,0.84236
test/spearman,0.54834
train/auc,0.9671
train/pearson,0.90502
train/spearman,0.69423


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.5971460558545475
{'epoch_loss': 0.5971460558545475, 'train/pearson': np.float32(0.89145434), 'train/spearman': np.float64(0.6130116367942802), 'train/auc': 0.9631726175572137, 'test/pearson': np.float32(0.834214), 'test/spearman': np.float64(0.571078527164374), 'test/auc': 0.9657515615555108}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29543241274099763
{'epoch_loss': 0.29543241274099763, 'train/pearson': np.float32(0.89669806), 'train/spearman': np.float64(0.6561967057101717), 'train/auc': 0.9670000500776204, 'test/pearson': np.float32(0.8390458), 'test/spearman': np.float64(0.5684324642959175), 'test/auc': 0.9704024783397138}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28699080247324876
{'epoch_loss': 0.28699080247324876, 'train/pearson': np.float32(0.8996319), 'train/spearman': np.float64(0.6684624541592238), 'train/auc': 0.9667282287545697, 'test/pearson': np.float32(0.8420381), 'test/spearman': np.float64(0.5720259295083593), 'test/auc': 0.9689169856941364}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27888249940337084
{'epoch_loss': 0.27888249940337084, 'train/pearson': np.float32(0.90151733), 'train/spearman': np.float64(0.6699206898897672), 'train/auc': 0.9688645400370575, 'test/pearson': np.float32(0.8437788), 'test/spearman': np.float64(0.5683679009935826), 'test/auc': 0.9719464033850493}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2741763786719249
{'epoch_loss': 0.2741763786719249, 'train/pearson': np.float32(0.9029325), 'train/spearman': np.float64(0.6690527209218888), 'train/auc': 0.9681277980870349, 'test/pearson': np.float32(0.84482825), 'test/spearman': np.float64(0.5607057751365476), 'test/auc': 0.9696771106185775}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2769183446043216
{'epoch_loss': 0.2769183446043216, 'train/pearson': np.float32(0.90348214), 'train/spearman': np.float64(0.6737094306457969), 'train/auc': 0.969256747959337, 'test/pearson': np.float32(0.844126), 'test/spearman': np.float64(0.5579690160426407), 'test/auc': 0.9709767277856136}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27699987539134846
{'epoch_loss': 0.27699987539134846, 'train/pearson': np.float32(0.9039944), 'train/spearman': np.float64(0.6750684072200139), 'train/auc': 0.9714393309629927, 'test/pearson': np.float32(0.8451079), 'test/spearman': np.float64(0.5679970108438605), 'test/auc': 0.9714779367318154}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.28446344742045615
{'epoch_loss': 0.28446344742045615, 'train/pearson': np.float32(0.90434146), 'train/spearman': np.float64(0.6788367060617155), 'train/auc': 0.9708069507736993, 'test/pearson': np.float32(0.8449167), 'test/spearman': np.float64(0.56427010585535), 'test/auc': 0.9720894620189402}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27720004254684283
{'epoch_loss': 0.27720004254684283, 'train/pearson': np.float32(0.904986), 'train/spearman': np.float64(0.6757243779306136), 'train/auc': 0.9722310581401172, 'test/pearson': np.float32(0.84492576), 'test/spearman': np.float64(0.5549042979617482), 'test/auc': 0.9723136207938747}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2760031698301387
{'epoch_loss': 0.2760031698301387, 'train/pearson': np.float32(0.90451133), 'train/spearman': np.float64(0.6802979953924149), 'train/auc': 0.9728693474886072, 'test/pearson': np.float32(0.8450246), 'test/spearman': np.float64(0.5661205794154608), 'test/auc': 0.9724957183155349}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27254241239577054
{'epoch_loss': 0.27254241239577054, 'train/pearson': np.float32(0.90549845), 'train/spearman': np.float64(0.67283890039926), 'train/auc': 0.9733779358004908, 'test/pearson': np.float32(0.8456644), 'test/spearman': np.float64(0.5625925175063349), 'test/auc': 0.970904694741084}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2765250204684445
{'epoch_loss': 0.2765250204684445, 'train/pearson': np.float32(0.9061713), 'train/spearman': np.float64(0.6747468414528239), 'train/auc': 0.973988582302569, 'test/pearson': np.float32(0.8458596), 'test/spearman': np.float64(0.5636033408550187), 'test/auc': 0.970839210155148}
===== EPOCH 13 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26927529502743347
{'epoch_loss': 0.26927529502743347, 'train/pearson': np.float32(0.9068386), 'train/spearman': np.float64(0.6768524725164083), 'train/auc': 0.973251139265862, 'test/pearson': np.float32(0.84645486), 'test/spearman': np.float64(0.5616681401093662), 'test/auc': 0.9716592786620994}
===== EPOCH 14 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2733293469435872
{'epoch_loss': 0.2733293469435872, 'train/pearson': np.float32(0.9078486), 'train/spearman': np.float64(0.6784334082709783), 'train/auc': 0.9760028043467375, 'test/pearson': np.float32(0.8471723), 'test/spearman': np.float64(0.5609142995222498), 'test/auc': 0.9734631271408422}
===== EPOCH 15 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2712958425926134
{'epoch_loss': 0.2712958425926134, 'train/pearson': np.float32(0.90816975), 'train/spearman': np.float64(0.6797044199579841), 'train/auc': 0.9762177374931144, 'test/pearson': np.float32(0.84793305), 'test/spearman': np.float64(0.5673611136054851), 'test/auc': 0.9731890993350796}
===== EPOCH 16 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2657534429154838
{'epoch_loss': 0.2657534429154838, 'train/pearson': np.float32(0.90859056), 'train/spearman': np.float64(0.6791757379953327), 'train/auc': 0.9769933396764985, 'test/pearson': np.float32(0.84751177), 'test/spearman': np.float64(0.5622106706566049), 'test/auc': 0.972113640942978}
===== EPOCH 17 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26939377291038774
{'epoch_loss': 0.26939377291038774, 'train/pearson': np.float32(0.90991706), 'train/spearman': np.float64(0.683224009775188), 'train/auc': 0.9774607641844859, 'test/pearson': np.float32(0.8475874), 'test/spearman': np.float64(0.5602086236733373), 'test/auc': 0.9734485190409027}
===== EPOCH 18 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26397507844831997
{'epoch_loss': 0.26397507844831997, 'train/pearson': np.float32(0.9102391), 'train/spearman': np.float64(0.6795684096689772), 'train/auc': 0.9777565226100455, 'test/pearson': np.float32(0.84937656), 'test/spearman': np.float64(0.5662027920717508), 'test/auc': 0.9723483779971791}
===== EPOCH 19 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26581895797730637
{'epoch_loss': 0.26581895797730637, 'train/pearson': np.float32(0.9109106), 'train/spearman': np.float64(0.6783710479328765), 'train/auc': 0.9795786469026992, 'test/pearson': np.float32(0.85031575), 'test/spearman': np.float64(0.5610077743817299), 'test/auc': 0.9738192625428168}
===== EPOCH 20 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2678364893070425
{'epoch_loss': 0.2678364893070425, 'train/pearson': np.float32(0.91074336), 'train/spearman': np.float64(0.6780279783382819), 'train/auc': 0.9790511292503381, 'test/pearson': np.float32(0.8503076), 'test/spearman': np.float64(0.5684915630033925), 'test/auc': 0.9730767680838204}
===== EPOCH 21 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.27130262967877494
{'epoch_loss': 0.27130262967877494, 'train/pearson': np.float32(0.9113045), 'train/spearman': np.float64(0.678811797122209), 'train/auc': 0.9796534628674445, 'test/pearson': np.float32(0.8498423), 'test/spearman': np.float64(0.5592155246865382), 'test/auc': 0.971772113640943}
===== EPOCH 22 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26726457079092913
{'epoch_loss': 0.26726457079092913, 'train/pearson': np.float32(0.9117268), 'train/spearman': np.float64(0.6805412052613931), 'train/auc': 0.9796955280685062, 'test/pearson': np.float32(0.84957814), 'test/spearman': np.float64(0.5584152563838717), 'test/auc': 0.9724365303244007}
===== EPOCH 23 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2633617053421351
{'epoch_loss': 0.2633617053421351, 'train/pearson': np.float32(0.9117291), 'train/spearman': np.float64(0.6787783330034389), 'train/auc': 0.9816429465671791, 'test/pearson': np.float32(0.8505726), 'test/spearman': np.float64(0.5612119574645478), 'test/auc': 0.9756795285109813}
===== EPOCH 24 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2660055537693226
{'epoch_loss': 0.2660055537693226, 'train/pearson': np.float32(0.9120693), 'train/spearman': np.float64(0.6745838947635675), 'train/auc': 0.9812741749712054, 'test/pearson': np.float32(0.8508736), 'test/spearman': np.float64(0.5593325062337966), 'test/auc': 0.9746292564980858}
===== EPOCH 25 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2713119246946356
{'epoch_loss': 0.2713119246946356, 'train/pearson': np.float32(0.9121311), 'train/spearman': np.float64(0.6795071859338762), 'train/auc': 0.9811225399369022, 'test/pearson': np.float32(0.8506078), 'test/spearman': np.float64(0.5638195505273588), 'test/auc': 0.9744922425952046}
===== EPOCH 26 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2684171307391633
{'epoch_loss': 0.2684171307391633, 'train/pearson': np.float32(0.9124301), 'train/spearman': np.float64(0.6804459972181282), 'train/auc': 0.9812100756172067, 'test/pearson': np.float32(0.85117096), 'test/spearman': np.float64(0.5680159458238376), 'test/auc': 0.975279568809188}
===== EPOCH 27 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.26545631666533864
{'epoch_loss': 0.26545631666533864, 'train/pearson': np.float32(0.9114306), 'train/spearman': np.float64(0.6790462062669678), 'train/auc': 0.9810157744503981, 'test/pearson': np.float32(0.8499304), 'test/spearman': np.float64(0.5645946870199788), 'test/auc': 0.9746156558533146}
Early stopping on epoch 27


0,1
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▄▃▅▄▅▅▅▆▆▅▅▅▆▆▅▆▆▇▆▅▆█▇▇█▇
test/pearson,▁▃▄▅▅▅▅▅▅▅▆▆▆▆▇▆▇▇██▇▇████▇
test/spearman,█▇█▇▃▂▆▅▁▆▄▅▄▃▆▄▃▆▃▇▃▂▄▃▅▆▅
train/auc,▁▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇█████
train/pearson,▁▃▄▄▅▅▅▅▆▅▆▆▆▆▇▇▇▇▇▇███████
train/spearman,▁▅▇▇▇▇▇█▇█▇▇▇██████▇███▇███

0,1
epoch_loss,0.26546
test/auc,0.97462
test/pearson,0.84993
test/spearman,0.56459
train/auc,0.98102
train/pearson,0.91143
train/spearman,0.67905


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 16, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6062489719793629
{'epoch_loss': 0.6062489719793629, 'train/pearson': np.float32(0.89241856), 'train/spearman': np.float64(0.6474932994163616), 'train/auc': 0.9606728429065051, 'test/pearson': np.float32(0.8335815), 'test/spearman': np.float64(0.5025623402647148), 'test/auc': 0.9645763651017529}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30798294915320773
{'epoch_loss': 0.30798294915320773, 'train/pearson': np.float32(0.9010001), 'train/spearman': np.float64(0.6771286799941374), 'train/auc': 0.9652305072862937, 'test/pearson': np.float32(0.8421031), 'test/spearman': np.float64(0.5461047930259072), 'test/auc': 0.9676460809993955}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3018439589574124
{'epoch_loss': 0.3018439589574124, 'train/pearson': np.float32(0.9044434), 'train/spearman': np.float64(0.6923840216205547), 'train/auc': 0.9680274425359306, 'test/pearson': np.float32(0.8437598), 'test/spearman': np.float64(0.5599746070381335), 'test/auc': 0.9680284102357445}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29735624828277685
{'epoch_loss': 0.29735624828277685, 'train/pearson': np.float32(0.9068232), 'train/spearman': np.float64(0.7006607839428022), 'train/auc': 0.9680809254344234, 'test/pearson': np.float32(0.8456257), 'test/spearman': np.float64(0.5668550240713007), 'test/auc': 0.9684329034857948}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2983893880162376
{'epoch_loss': 0.2983893880162376, 'train/pearson': np.float32(0.9074981), 'train/spearman': np.float64(0.7017526970427305), 'train/auc': 0.96911252441284, 'test/pearson': np.float32(0.84675294), 'test/spearman': np.float64(0.5720204608991895), 'test/auc': 0.9684747128752771}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2915784969878273
{'epoch_loss': 0.2915784969878273, 'train/pearson': np.float32(0.9087777), 'train/spearman': np.float64(0.7040543387677155), 'train/auc': 0.9703163904051278, 'test/pearson': np.float32(0.84698045), 'test/spearman': np.float64(0.5759186472504901), 'test/auc': 0.9691764053999599}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29415902613212885
{'epoch_loss': 0.29415902613212885, 'train/pearson': np.float32(0.9092964), 'train/spearman': np.float64(0.7089815364479397), 'train/auc': 0.9710994040763183, 'test/pearson': np.float32(0.8474774), 'test/spearman': np.float64(0.5758228484663749), 'test/auc': 0.9701934313923031}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2941242610446561
{'epoch_loss': 0.2941242610446561, 'train/pearson': np.float32(0.9090152), 'train/spearman': np.float64(0.7031145520350925), 'train/auc': 0.9718151134258101, 'test/pearson': np.float32(0.84675455), 'test/spearman': np.float64(0.5729337854068208), 'test/auc': 0.9691854724964739}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2897703921951996
{'epoch_loss': 0.2897703921951996, 'train/pearson': np.float32(0.9088585), 'train/spearman': np.float64(0.7089409292516776), 'train/auc': 0.9698827182132306, 'test/pearson': np.float32(0.8469579), 'test/spearman': np.float64(0.5760767820398214), 'test/auc': 0.9677871247229498}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2981740154159336
{'epoch_loss': 0.2981740154159336, 'train/pearson': np.float32(0.90997016), 'train/spearman': np.float64(0.7024123275289678), 'train/auc': 0.9733715258650909, 'test/pearson': np.float32(0.8482745), 'test/spearman': np.float64(0.5763382825442119), 'test/auc': 0.9718315534958695}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29135661855482825
{'epoch_loss': 0.29135661855482825, 'train/pearson': np.float32(0.910014), 'train/spearman': np.float64(0.7035671680582769), 'train/auc': 0.9734041764735339, 'test/pearson': np.float32(0.8478336), 'test/spearman': np.float64(0.5736189979794628), 'test/auc': 0.9715615555107798}
===== EPOCH 12 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29305084434132606
{'epoch_loss': 0.29305084434132606, 'train/pearson': np.float32(0.91003495), 'train/spearman': np.float64(0.7103325676147876), 'train/auc': 0.9722706194601632, 'test/pearson': np.float32(0.8467384), 'test/spearman': np.float64(0.5723527749285633), 'test/auc': 0.9705853314527504}
Early stopping on epoch 12


0,1
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁
test/auc,▁▄▄▅▅▅▆▅▄██▇
test/pearson,▁▅▆▇▇▇█▇▇██▇
test/spearman,▁▅▆▇████████
train/auc,▁▄▅▅▆▆▇▇▆██▇
train/pearson,▁▄▆▇▇███████
train/spearman,▁▄▆▇▇▇█▇█▇▇█

0,1
epoch_loss,0.29305
test/auc,0.97059
test/pearson,0.84674
test/spearman,0.57235
train/auc,0.97227
train/pearson,0.91003
train/spearman,0.71033


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.001}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.8858649511211596
{'epoch_loss': 0.8858649511211596, 'train/pearson': np.float32(0.8602752), 'train/spearman': np.float64(0.6437972557119803), 'train/auc': 0.9512933046221642, 'test/pearson': np.float32(0.7748276), 'test/spearman': np.float64(0.4524364382305801), 'test/auc': 0.9508044529518437}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.4759619036040748
{'epoch_loss': 0.4759619036040748, 'train/pearson': np.float32(0.891678), 'train/spearman': np.float64(0.6569232036266895), 'train/auc': 0.9614434373278582, 'test/pearson': np.float32(0.8262202), 'test/spearman': np.float64(0.4972216444414766), 'test/auc': 0.9669620189401571}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.41592702643273355
{'epoch_loss': 0.41592702643273355, 'train/pearson': np.float32(0.89959437), 'train/spearman': np.float64(0.6673034026523025), 'train/auc': 0.964235364815464, 'test/pearson': np.float32(0.83718795), 'test/spearman': np.float64(0.5248175763481092), 'test/auc': 0.9672410840217611}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.41950875706375595
{'epoch_loss': 0.41950875706375595, 'train/pearson': np.float32(0.9005486), 'train/spearman': np.float64(0.6619633332916222), 'train/auc': 0.9652317091491812, 'test/pearson': np.float32(0.8354554), 'test/spearman': np.float64(0.5093623791586466), 'test/auc': 0.9651445698166432}
Early stopping on epoch 4


0,1
epoch_loss,█▂▁▁
test/auc,▁██▇
test/pearson,▁▇██
test/spearman,▁▅█▇
train/auc,▁▆▇█
train/pearson,▁▆██
train/spearman,▁▅█▆

0,1
epoch_loss,0.41951
test/auc,0.96514
test/pearson,0.83546
test/spearman,0.50936
train/auc,0.96523
train/pearson,0.90055
train/spearman,0.66196


Calibration with: {'conv_init_weight_std': 0.0001, 'dropout': 0.5, 'fc_init_weight_std': 0.001, 'lr': 0.0005, 'motif_len': 24, 'num_motif_detectors': 16, 'weight_decay': 0.005}


===== EPOCH 1 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.6061417737041419
{'epoch_loss': 0.6061417737041419, 'train/pearson': np.float32(0.8943576), 'train/spearman': np.float64(0.6119071551775338), 'train/auc': 0.9641207872201912, 'test/pearson': np.float32(0.83472437), 'test/spearman': np.float64(0.4892624524968567), 'test/auc': 0.9660311303646988}
===== EPOCH 2 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3138051157989822
{'epoch_loss': 0.3138051157989822, 'train/pearson': np.float32(0.89791316), 'train/spearman': np.float64(0.6326439886305523), 'train/auc': 0.9680809254344234, 'test/pearson': np.float32(0.838981), 'test/spearman': np.float64(0.5086084449516942), 'test/auc': 0.969148700382833}
===== EPOCH 3 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.304343284961705
{'epoch_loss': 0.304343284961705, 'train/pearson': np.float32(0.90144867), 'train/spearman': np.float64(0.6567147946975039), 'train/auc': 0.9680608943862988, 'test/pearson': np.float32(0.8410294), 'test/spearman': np.float64(0.5297138040853995), 'test/auc': 0.9720400967156961}
===== EPOCH 4 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30237669646501925
{'epoch_loss': 0.30237669646501925, 'train/pearson': np.float32(0.9016481), 'train/spearman': np.float64(0.6615661595682417), 'train/auc': 0.9687417497120536, 'test/pearson': np.float32(0.8414036), 'test/spearman': np.float64(0.5466254202867272), 'test/auc': 0.9697466250251864}
===== EPOCH 5 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29775918194185047
{'epoch_loss': 0.29775918194185047, 'train/pearson': np.float32(0.90249807), 'train/spearman': np.float64(0.6686460945916577), 'train/auc': 0.9701092693675195, 'test/pearson': np.float32(0.8417609), 'test/spearman': np.float64(0.5508323397752936), 'test/auc': 0.9695390892605279}
===== EPOCH 6 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2982098839391535
{'epoch_loss': 0.2982098839391535, 'train/pearson': np.float32(0.9039859), 'train/spearman': np.float64(0.6725063069421072), 'train/auc': 0.971840352546447, 'test/pearson': np.float32(0.8440846), 'test/spearman': np.float64(0.5610749722506462), 'test/auc': 0.9721645174289744}
===== EPOCH 7 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29512254977092955
{'epoch_loss': 0.29512254977092955, 'train/pearson': np.float32(0.9050421), 'train/spearman': np.float64(0.6745662392244314), 'train/auc': 0.9710094646702387, 'test/pearson': np.float32(0.8435872), 'test/spearman': np.float64(0.5614744690553705), 'test/auc': 0.9702095506749949}
===== EPOCH 8 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.3044384857074331
{'epoch_loss': 0.3044384857074331, 'train/pearson': np.float32(0.90521437), 'train/spearman': np.float64(0.6755874515715988), 'train/auc': 0.971183734788923, 'test/pearson': np.float32(0.84401804), 'test/spearman': np.float64(0.5617944522609123), 'test/auc': 0.9696786218013298}
===== EPOCH 9 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.30017081679056246
{'epoch_loss': 0.30017081679056246, 'train/pearson': np.float32(0.90542424), 'train/spearman': np.float64(0.677599730420691), 'train/auc': 0.9719226801542391, 'test/pearson': np.float32(0.84377974), 'test/spearman': np.float64(0.5632639506517343), 'test/auc': 0.9707299012693935}
===== EPOCH 10 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.29264657313640885
{'epoch_loss': 0.29264657313640885, 'train/pearson': np.float32(0.90690154), 'train/spearman': np.float64(0.6852060478851444), 'train/auc': 0.9728915819520256, 'test/pearson': np.float32(0.8452103), 'test/spearman': np.float64(0.5689729612687193), 'test/auc': 0.9706553495869434}
===== EPOCH 11 =====


  0%|          | 0/626 [00:00<?, ?it/s]

Loss: 0.2992774634220349
{'epoch_loss': 0.2992774634220349, 'train/pearson': np.float32(0.90569854), 'train/spearman': np.float64(0.6811050498166905), 'train/auc': 0.9731638038960388, 'test/pearson': np.float32(0.8436426), 'test/spearman': np.float64(0.5677288432475684), 'test/auc': 0.9731649204110417}
Early stopping on epoch 11


0,1
epoch_loss,█▁▁▁▁▁▁▁▁▁▁
test/auc,▁▄▇▅▄▇▅▅▆▆█
test/pearson,▁▄▅▅▆▇▇▇▇█▇
test/spearman,▁▃▅▆▆▇▇▇▇██
train/auc,▁▄▄▅▆▇▆▆▇██
train/pearson,▁▃▅▅▆▆▇▇▇█▇
train/spearman,▁▃▅▆▆▇▇▇▇██

0,1
epoch_loss,0.29928
test/auc,0.97316
test/pearson,0.84364
test/spearman,0.56773
train/auc,0.97316
train/pearson,0.9057
train/spearman,0.68111


In [None]:
import wandb
run = wandb.init()
artifact = run.use_artifact('vvaza22-free-university-of-tbilisi/CompBioFinal/model_epoch_25:v72', type='model')
artifact_dir = artifact.download()

In [None]:
path = f"{artifact_dir}/model_epoch_25.pt"
path

In [None]:
mw = load_snapshot(path)

In [None]:
mw.evaluate(x_test, y_test, label_test)

### Build PWM matrix from Conv Layer

In [None]:
conv_weights = mw.model.conv.weight.detach().cpu().numpy()
conv_weights.shape

In [None]:
def filter_to_pwm(conv_filter):
    pwm = np.exp(conv_filter)
    pwm = pwm / pwm.sum(axis=0, keepdims=True)
    return pwm

def filter_to_pwm_df(conv_filter):
    pwm = filter_to_pwm(conv_filter)
    return pd.DataFrame(pwm.T, columns=['A','C','G','T'], dtype='float64')

In [None]:
pwm_df = filter_to_pwm_df(conv_weights[15])
pwm_df

In [None]:
info_df = logomaker.transform_matrix(pwm_df, from_type='probability', to_type='information')
info_df

In [None]:
fig, ax = plt.subplots(figsize=(10, 3))
    
logo = logomaker.Logo(info_df,
                      ax=ax,
                      color_scheme='classic',
                      vpad=.1,
                      width=.8)

logo.style_spines(visible=False)
logo.style_spines(spines=['left', 'bottom'], visible=True)
ax.set_ylabel("Information (Bits)")
ax.set_xlabel("Position in Motif")
#ax.set_ylim([0, 2])

plt.show()