# Basic settings and Importing required packages

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install nilearn
!pip install einops
!pip install beartype
!pip install numpy==1.23.0
import sys
sys.path.append('/content/drive/MyDrive/GraduationStudy/HOT')



In [3]:
import os
import torch
import math
import numpy as np
import pandas as pd
import torch.nn as nn
import networkx as nx
from einops import *
from tqdm import tqdm
from google.colab import drive
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
import torch.utils.checkpoint
import zipfile
import gzip
import shutil


from models.MemoryModel import MemoryModel, compute_src_dst_node_time_shifts
from models.modules import TimeEncoder
from utils.utils import NeighborSampler
from utils.utils import get_neighbor_sampler, NegativeEdgeSampler
from utils.DataLoader import get_idx_data_loader
from models.block_recurrent_transformer_pytorch import BlockRecurrentTransformer
from nilearn import image, maskers, datasets

import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)
%cd drive
%cd MyDrive
%cd GraduationStudy


/content/drive
/content/drive/MyDrive
/content/drive/MyDrive/GraduationStudy


# Saving data & Data preprocessing

The helper functions are defined in the cell below.

In [None]:
# Normalization of the fMRI timeseries
# Shape of timeseries : (# of timestamps, # of ROIs)
def normalize(ts):
  return (ts-torch.mean(ts, dim=0, keepdims=True))/(torch.std(ts, dim=0, keepdims=True)+1e-10)

# corrcoef based on
# https://github.com/pytorch/pytorch/issues/1254
def corrcoef(x):
    mean_x = torch.mean(x, 1, keepdim=True)
    xm = x.sub(mean_x.expand_as(x))
    c = xm.mm(xm.t())
    c = c / (x.size(1) - 1)
    d = torch.diag(c)
    stddev = torch.pow(d, 0.5)
    c = c.div(stddev.expand_as(c))
    c = c.div(stddev.expand_as(c).t())
    c = torch.clamp(c, -1.0, 1.0)
    return c

# process_dynamic_fc
# Source : https://github.com/egyptdj/stagin/blob/main/util/bold.py
def get_fc(timeseries, sampling_point, window_size, self_loop):
    fc = corrcoef(timeseries[sampling_point:sampling_point+window_size].T)
    if not self_loop: fc-= torch.eye(fc.shape[0])
    return fc

def get_minibatch_fc(minibatch_timeseries, sampling_point, window_size, self_loop):
    fc_list = []
    for timeseries in minibatch_timeseries:
        fc = get_fc(timeseries, sampling_point, window_size, self_loop)
        fc_list.append(fc)
    return torch.stack(fc_list)


def process_dynamic_fc(minibatch_timeseries, window_size, window_stride, dynamic_length=None, sampling_init=None, self_loop=True):
    # assumes input shape [minibatch x time x node]
    # output shape [minibatch x time x node x node]
    if dynamic_length is None:
        dynamic_length = minibatch_timeseries.shape[1]
        sampling_init = 0
    else:
        if isinstance(sampling_init, int):
            assert minibatch_timeseries.shape[1] > sampling_init + dynamic_length
    assert sampling_init is None or isinstance(sampling_init, int)
    assert minibatch_timeseries.ndim==3
    assert dynamic_length > window_size

    if sampling_init is None:
        sampling_init = randrange(minibatch_timeseries.shape[1]-dynamic_length+1)
    sampling_points = list(range(sampling_init, sampling_init+dynamic_length-window_size, window_stride))

    minibatch_fc_list = [get_minibatch_fc(minibatch_timeseries, sampling_point, window_size, self_loop) for sampling_point in sampling_points]
    dynamic_fc = torch.stack(minibatch_fc_list, dim=1)

    return dynamic_fc, sampling_points

def readtxt(filePath):
  # Read the .txt file
  data = []
  with open(filePath, 'r') as file:
      for line in file:
          # Split the line by whitespace or commas and convert to float
          row = [float(value) for value in line.split()]
          data.append(row)

  # Convert list to a PyTorch tensor
  tensor = torch.tensor(data)

  return tensor



The fMRI data should be downloaded from ConnectomeDB website.
(https://db.humanconnectome.org/app/template/Login.vm)

The directory tree should be constructed like the following diagram.

MyDrive\
$\quad$├─GraduationStudy\
$\quad$$\quad$$\quad$      ├─Emotion Task Original Files\
$\quad$$\quad$$\quad$$\quad$$\quad$            ├─######_3T_tfMRI_EMOTION_preproc.zip\
$\quad$$\quad$$\quad$$\quad$$\quad$            ├─@@@@@@_3T_tfMRI_EMOTION_preproc.zip\
$\quad$$\quad$$\quad$$\quad$$\quad$            ├─......\
$\quad$$\quad$$\quad$$\quad$$\quad$            ├─&&&&&&_3T_tfMRI_EMOTION_preproc.zip\
$\quad$$\quad$$\quad$├─DatasetSplit


The cell below constructs two 4D tensors of shape (100, 15, 116, 116)\
"tripletTensor40.pth" holds the 40th Quantile Dataset.\
"tripletTensor80.pth" holds the 80th Quantile Dataset.

In [None]:
baseDir='Emotion Task Original Files'
prefix='MNINonLinear/Results/tfMRI_EMOTION_LR'
subjectIDs=[]
for zipFile in os.listdir(baseDir):
  if zipFile[-1]!='p':
    continue
  subjectID=zipFile[0:6]
  print(subjectID)
  subjectIDs.append(subjectID)
  files_to_extract = [os.path.join(subjectID, prefix,'EVs/fear.txt'), os.path.join(subjectID, prefix, 'tfMRI_EMOTION_LR.nii.gz')]
  with zipfile.ZipFile(os.path.join(baseDir,zipFile), 'r') as zip_ref:
    for file in files_to_extract:
        zip_ref.extract(file, path=os.path.join(baseDir, subjectID))

for ID in subjectIDs:
  gzFile = os.path.join(baseDir, ID, ID, prefix,'tfMRI_EMOTION_LR.nii.gz')
  niiFile = os.path.join(baseDir, ID, 'tfMRI_EMOTION_LR.nii')
  txtFile = os.path.join(baseDir, ID, ID, prefix, 'EVs/fear.txt')
  with gzip.open(gzFile, 'rb') as f_in:
    with open(niiFile, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
  shutil.copyfile(txtFile, os.path.join(baseDir,ID,'fear.txt'))

atlasData=datasets.fetch_atlas_aal(data_dir=baseDir)
for ID in subjectIDs:
  img=image.load_img(os.path.join(baseDir, ID, 'tfMRI_EMOTION_LR.nii'))
  masker = maskers.NiftiLabelsMasker(image.load_img(atlasData['maps']))
  timeseries = masker.fit_transform(img) # (numTP, N=116)
  torch.save(timeseries, os.path.join(baseDir, ID, 'Emotion_Task.pth'))

windowSize=20
triplet40=[]
triplet80=[]
for ID in subjectIDs:
  timeSeries=torch.from_numpy(torch.load(os.path.join(baseDir, ID, 'Emotion_Task.pth')))
  timeSeries=normalize(timeSeries)
  print(timeSeries.shape)
  onsetFile=os.path.join(baseDir, ID, 'fear.txt')
  onsetTensor=readtxt(onsetFile)
  onsetStamp1=math.floor(onsetTensor[0,0]/0.72)
  onsetStamp2=math.floor(onsetTensor[1,0]/0.72)
  onsetStamp3=math.floor(onsetTensor[2,0]/0.72)

  fcList=[]
  onsets=[onsetStamp1, onsetStamp2, onsetStamp3]
  for onset in onsets:
    for stride in range(-3, 12, 3):
      fc=get_fc(timeSeries, onset+stride, windowSize, True)
      fcList.append(fc)
  fcSubject=torch.stack(fcList)

  samplePoints=[]
  for onset in onsets:
    for stride in range(-3, 12, 3):
      samplePoints.append(onset+stride)

  triplet40S=[]
  triplet80S=[]
  for sampleidx, fc2D in enumerate(fcSubject):
    triplet80List=[]
    fc80=fc2D[40].clone()
    fcSort80=fc80.sort(descending=True)
    fc80[fcSort80.indices[math.ceil(116*0.8):]]=0
    idx80=torch.nonzero(fc80)[:,0]
    idx1D=samplePoints[sampleidx]*torch.ones_like(idx80)
    for i in range(idx80.shape[0]):
      triplet80LIst.append([40, idx80[i].item(), idx1D[i].item()])

    fcTemp=fc2D.flatten()
    fcSort=fcTemp.sort(descending=True)
    fcTemp[fcSort.indices[math.ceil(116*116*0.4):]]=0
    fcMasked=fcTemp.reshape(fc2D.shape)
    idx2D, idx3D=torch.nonzero(fcMasked, as_tuple=True)
    idx1D=samplePoints[sampleidx]*torch.ones_like(idx2D)
    assert idx1D.shape[0]==math.ceil(116*116*0.4)
    for j in range(idx1D.shape[0]):
      triplet40S.append([idx2D[j].item(), idx3D[j].item(), idx1D[j].item()])
    for j in range(idx1D.shape[0]):
      if idx2D[j]!=40:
        triplet80List.append([idx2D[j].item(), idx3D[j].item(), idx1D[j].item()])

    if len(triplet80S)>math.ceil(116*116*0.4):
      subVal=len(triplet80S)-math.ceil(116*116*0.4)
      for k in range(subVal):
        valRemove=random.randint(0, len(triplet80S)-1)
        triplet80S.pop(valRemove)
      while not (torch.tensor(triplet80S)[:,0].unique().shape[0]==116 and torch.tensor(triplet80S)[:,1].unique().shape[0]==116):
        if torch.tensor(triplet80S)[:,0].unique().shape[0]!=116:
          triplet80S.pop(random.randint(0, len(triplet80S)-1))
          missing=[i for i in range(116) if i not in (torch.tensor(triplet80S))[:,0].unique()]
          for k in range(len(missing)):
            triplet80S.append([missing[k], random.randint(0, 115), samplePoints[sampleidx]])
        if torch.tensor(triplet80S)[:,0].unique().shape[0]==116 and torch.tensor(triplet80S)[:,1].unique().shape[0]==116:
          break
        if torch.tensor(triplet80S)[:,1].unique().shape[0]!=116:
          triplet80S.pop(random.randint(0, len(triplet80S)-1))
          missing=[i for i in range(116) if i not in (torch.tensor(triplet80S))[:,1].unique()]
          for k in range(len(missing)):
            triplet80S.append([random.randint(0, 115), missing[k], samplePoints[sampleidx]])
    elif len(triplet80S)<math.ceil(116*116*0.4):
      addVal=math.ceil(116*116*0.4)-len(triplet80S)
      for k in range(addVal):
        val1=random.randint(0, 115)
        while val1==40:
          val1=random.randint(0, 115)
        valZero=torch.nonzero(fc2D[fc2D==val1])[:,0]
        cand=[i for i in range(116) if i not in valZero]
        val2=random.choice(cand)
        triplet80S.append([val1, val2, samplePoints[sampleidx]])

    triplet403D=torch.tensor(triplet40S).to('cuda')
    triplet40S.append(triplet403D)

    assert len(triplet80S)==math.ceil(116*116*0.4)
    triplet803D=torch.tensor(triplet80S).to('cuda')
    assert torch.unique(triplet803D[:,0]).shape[0]==116
    assert torch.unique(triplet803D[:,1]).shape[0]==116

    triplet80S.append(triplet803D)

  triplet40S=torch.stack(triplet40S).to('cuda')
  triplet40.append(triplet40S)

  triplet80S=torch.tensor(triplet80S).to('cuda')
  triplet80.append(triplet80S)

tripletTensor=torch.stack(triplet40).to('cuda')
torch.save(tripletTensor, os.path.join('DatasetSplit','tripletTensor40.pth'))
triplet80Tensor=torch.stack(triplet80).to('cuda')
torch.save(triplet80Tensor, os.path.join('DatasetSplit','tripletTensor80.pth'))

The cell below divides the whole dataset tensor into training, validation, and test data.\
"train@.pth", "valid@.pth", and "test@.pth" holds the 40th Quantile  train, validation, and test data of @-th subject, respectively.\
"train80@.pth", "valid80@.pth", and "test80@.pth" holds the 80th Quantile  train, validation, and test data of @-th subject, respectively.

In [None]:
numList=[i for i in range(100)]
random.shuffle(numList)
trainIdx=numList[0:16]
validIdx=numList[16:24]
testIdx=numList[24:34]


triplet40s=torch.load(os.path.join('DatasetSplit','tripletTensor40.pth')).to('cuda')
triplet80s=torch.load(os.path.join('DatasetSplit','tripletTensor80.pth')).to('cuda')
trainTriplet=triplet40s[trainIdx]
testTriplet=triplet40s[testIdx]
validTriplet=triplet40s[validIdx]
train80=triplet80s[trainIdx]
test80=triplet80s[testIdx]
valid80=triplet80s[validIdx]
for i in range(16):
  torch.save(trainTriplet[i].squeeze(0), os.path.join('DatasetSplit','train'+str(i)+'.pth'))
  torch.save(train80[i].squeeze(0), os.path.join('DatasetSplit','train80'+str(i)+'.pth'))
for i in range(8):
  torch.save(validTriplet[i].squeeze(0), os.path.join('DatasetSplit','valid'+str(i)+'.pth'))
  torch.save(valid80[i].squeeze(0), os.path.join('DatasetSplit','valid80'+str(i)+'.pth'))
for i in range(10):
  torch.save(testTriplet[i].squeeze(0), os.path.join('DatasetSplit','test'+str(i)+'.pth'))
  torch.save(test80[i].squeeze(0), os.path.join('DatasetSplit','test80'+str(i)+'.pth'))

  triplet4D=torch.load(os.path.join('DatasetSplit','tripletTensor.pth')).to('cuda')
  triplet40s=torch.load(os.path.join('DatasetSplit','triplet40Tensor.pth')).to('cuda')


# Models - HOT and Triadic Decoder

The cell below is a definition of the HOT encoder, introduced in [Besta et al., 2024](https://arxiv.org/abs/2311.18526).

In [None]:
class HOT(nn.Module):

    def __init__(self, node_raw_features: np.ndarray, edge_raw_features: np.ndarray, neighbor_sampler: NeighborSampler,
                 time_feat_dim: int, channel_embedding_dim: int, patch_size: int = 1, num_layers: int = 2, num_heads: int = 4,
                 block_size: int = 16, num_state_vectors: int = 32, segment_size: int = 32, num2hop: int = 0,
                 dropout: float = 0.1, max_input_sequence_length: int = 65536, device: str = 'cpu'):
        """
        DyGFormer model.
        :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: neighbor sampler
        :param time_feat_dim: int, dimension of time features (encodings)
        :param channel_embedding_dim: int, dimension of each channel embedding
        :param patch_size: int, patch size

        :param num_layers: int, number of transformer layers
        :param num_heads: int, number of attention heads
        :param dropout: float, dropout rate
        :param max_input_sequence_length: int, maximal length of the input sequence for each node
        :param device: str, device
        """
        super(HOT, self).__init__()

        self.node_raw_features = torch.from_numpy(node_raw_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_raw_features.astype(np.float32)).to(device)

        self.neighbor_sampler = neighbor_sampler
        self.node_feat_dim = self.node_raw_features.shape[1] + 2
        self.edge_feat_dim = self.edge_raw_features.shape[1]
        self.time_feat_dim = time_feat_dim
        self.channel_embedding_dim = channel_embedding_dim
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.block_size = block_size
        self.segment_size = segment_size
        self.num_state_vectors = num_state_vectors
        self.num_heads = num_heads
        self.dropout = dropout
        self.max_input_sequence_length = max_input_sequence_length
        self.device = device

        self.num2hop = num2hop

        self.time_encoder = TimeEncoder(time_dim=time_feat_dim).to(self.device)

        self.neighbor_co_occurrence_feat_dim = self.channel_embedding_dim
        self.neighbor_co_occurrence_encoder = NeighborCooccurrenceEncoder(neighbor_co_occurrence_feat_dim=self.neighbor_co_occurrence_feat_dim, device=self.device)

        self.projection_layer = nn.ModuleDict({
            'node': nn.Linear(in_features=self.patch_size * self.node_feat_dim, out_features=self.channel_embedding_dim, bias=True),
            'edge': nn.Linear(in_features=self.patch_size * self.edge_feat_dim, out_features=self.channel_embedding_dim, bias=True),
            'time': nn.Linear(in_features=self.patch_size * self.time_feat_dim, out_features=self.channel_embedding_dim, bias=True),
            'neighbor_co_occurrence': nn.Linear(in_features=self.patch_size * self.neighbor_co_occurrence_feat_dim, out_features=self.channel_embedding_dim, bias=True)
        }).to(self.device)

        self.num_channels = 4

        self.brt = BlockRecurrentTransformer(
            dim = 2 * self.num_channels * self.channel_embedding_dim,
            depth = self.num_layers,
            dim_head = self.num_channels * self.channel_embedding_dim // (self.num_heads // 2),
            heads = self.num_heads,
            max_seq_len = self.segment_size,
            block_width = self.block_size,
            num_state_vectors = self.num_state_vectors,
            recurrent_layers = (1, ),
            use_compressed_mem = False,
            use_flash_attn = False
        ).to(self.device)

        self.output_layer = nn.Linear(in_features=self.num_channels * self.channel_embedding_dim, out_features=self.node_raw_features.shape[1], bias=True).to(self.device)

    def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, node_interact_times: np.ndarray):
        """
        compute source and destination node temporal embeddings
        :param src_node_ids: ndarray, shape (batch_size, )
        :param dst_node_ids: ndarray, shape (batch_size, )
        :param node_interact_times: ndarray, shape (batch_size, )
        :return:
        """

        # Extract 1-hop and 2-hop interactions
        src_firt_hop_lengths, src_nodes_neighbor_ids_list, src_nodes_edge_ids_list, src_nodes_neighbor_times_list = \
            self.neighbor_sampler.get_all_second_hop_neighbors(node_ids=src_node_ids, node_interact_times=node_interact_times,
                                                                max_seq_len=self.max_input_sequence_length, max_2hop=self.num2hop)

        dst_firt_hop_lengths, dst_nodes_neighbor_ids_list, dst_nodes_edge_ids_list, dst_nodes_neighbor_times_list = \
            self.neighbor_sampler.get_all_second_hop_neighbors(node_ids=dst_node_ids, node_interact_times=node_interact_times,
                                                                max_seq_len=self.max_input_sequence_length, max_2hop=self.num2hop)

        # pad the sequences
        src_padded_nodes_neighbor_ids, src_padded_nodes_edge_ids, src_padded_nodes_neighbor_times = \
            self.pad_sequences(node_ids=src_node_ids, node_interact_times=node_interact_times, nodes_neighbor_ids_list=src_nodes_neighbor_ids_list,
                               nodes_edge_ids_list=src_nodes_edge_ids_list, nodes_neighbor_times_list=src_nodes_neighbor_times_list,
                               patch_size=self.patch_size, max_input_sequence_length=1048576)

        dst_padded_nodes_neighbor_ids, dst_padded_nodes_edge_ids, dst_padded_nodes_neighbor_times = \
            self.pad_sequences(node_ids=dst_node_ids, node_interact_times=node_interact_times, nodes_neighbor_ids_list=dst_nodes_neighbor_ids_list,
                               nodes_edge_ids_list=dst_nodes_edge_ids_list, nodes_neighbor_times_list=dst_nodes_neighbor_times_list,
                               patch_size=self.patch_size, max_input_sequence_length=1048576)

        # get HO encoding
        src_padded_nodes_neighbor_co_occurrence_features, dst_padded_nodes_neighbor_co_occurrence_features = \
            self.neighbor_co_occurrence_encoder(src_padded_nodes_neighbor_ids=src_padded_nodes_neighbor_ids,
                                                dst_padded_nodes_neighbor_ids=dst_padded_nodes_neighbor_ids)

        # get features
        src_padded_nodes_neighbor_node_raw_features, src_padded_nodes_edge_raw_features, src_padded_nodes_neighbor_time_features = \
            self.get_features(node_interact_times=node_interact_times, padded_nodes_neighbor_ids=src_padded_nodes_neighbor_ids,
                              padded_nodes_edge_ids=src_padded_nodes_edge_ids, padded_nodes_neighbor_times=src_padded_nodes_neighbor_times, time_encoder=self.time_encoder)

        dst_padded_nodes_neighbor_node_raw_features, dst_padded_nodes_edge_raw_features, dst_padded_nodes_neighbor_time_features = \
            self.get_features(node_interact_times=node_interact_times, padded_nodes_neighbor_ids=dst_padded_nodes_neighbor_ids,
                              padded_nodes_edge_ids=dst_padded_nodes_edge_ids, padded_nodes_neighbor_times=dst_padded_nodes_neighbor_times, time_encoder=self.time_encoder)


        # Add one-hot encoding
        batch_size = len(src_padded_nodes_neighbor_node_raw_features)

        new_src_padded_nodes_neighbor_node_raw_features = torch.empty((src_padded_nodes_neighbor_node_raw_features.shape[0],
                                                                       src_padded_nodes_neighbor_node_raw_features.shape[1],
                                                                       src_padded_nodes_neighbor_node_raw_features.shape[2] + 2)).to(self.device)

        new_dst_padded_nodes_neighbor_node_raw_features = torch.empty((dst_padded_nodes_neighbor_node_raw_features.shape[0],
                                                                       dst_padded_nodes_neighbor_node_raw_features.shape[1],
                                                                       dst_padded_nodes_neighbor_node_raw_features.shape[2] + 2)).to(self.device)

        for i in range(batch_size):
            if (src_firt_hop_lengths[i] != 0):
                new_src_padded_nodes_neighbor_node_raw_features[i, -src_firt_hop_lengths[i]:, :] = \
                    torch.cat([src_padded_nodes_neighbor_node_raw_features[i, -src_firt_hop_lengths[i]:, :],
                            torch.ones((src_firt_hop_lengths[i], 1)).float().to(self.device),
                            torch.zeros((src_firt_hop_lengths[i], 1)).float().to(self.device),], dim=1)
                new_src_padded_nodes_neighbor_node_raw_features[i, :-src_firt_hop_lengths[i], :] = \
                    torch.cat([src_padded_nodes_neighbor_node_raw_features[i, :-src_firt_hop_lengths[i], :],
                            torch.zeros((src_padded_nodes_neighbor_node_raw_features.shape[1] - src_firt_hop_lengths[i], 1)).float().to(self.device),
                            torch.ones((src_padded_nodes_neighbor_node_raw_features.shape[1] - src_firt_hop_lengths[i], 1)).float().to(self.device)], dim=1)

            else:
                new_src_padded_nodes_neighbor_node_raw_features[i, :, :] = \
                    torch.cat([src_padded_nodes_neighbor_node_raw_features[i, :, :],
                            torch.zeros((src_padded_nodes_neighbor_node_raw_features.shape[1], 1)).float().to(self.device),
                            torch.ones((src_padded_nodes_neighbor_node_raw_features.shape[1], 1)).float().to(self.device)], dim=1)

            if (dst_firt_hop_lengths[i] != 0):
                new_dst_padded_nodes_neighbor_node_raw_features[i, -dst_firt_hop_lengths[i]:, :] = \
                    torch.cat([dst_padded_nodes_neighbor_node_raw_features[i, -dst_firt_hop_lengths[i]:, :],
                            torch.ones((dst_firt_hop_lengths[i], 1)).float().to(self.device),
                            torch.zeros((dst_firt_hop_lengths[i], 1)).float().to(self.device)], dim=1)
                new_dst_padded_nodes_neighbor_node_raw_features[i, :-dst_firt_hop_lengths[i], :] = \
                    torch.cat([dst_padded_nodes_neighbor_node_raw_features[i, :-dst_firt_hop_lengths[i], :],
                            torch.zeros((dst_padded_nodes_neighbor_node_raw_features.shape[1] - dst_firt_hop_lengths[i], 1)).float().to(self.device),
                            torch.ones((dst_padded_nodes_neighbor_node_raw_features.shape[1] - dst_firt_hop_lengths[i], 1)).float().to(self.device)], dim=1)
            else:
                new_dst_padded_nodes_neighbor_node_raw_features[i, :, :] = \
                    torch.cat([dst_padded_nodes_neighbor_node_raw_features[i, :, :],
                            torch.zeros((dst_padded_nodes_neighbor_node_raw_features.shape[1], 1)).float().to(self.device),
                            torch.ones((dst_padded_nodes_neighbor_node_raw_features.shape[1], 1)).float().to(self.device)], dim=1)


        # Patching
        src_patches_nodes_neighbor_node_raw_features, src_patches_nodes_edge_raw_features, \
        src_patches_nodes_neighbor_time_features, src_patches_nodes_neighbor_co_occurrence_features = \
            self.get_patches(padded_nodes_neighbor_node_raw_features=new_src_padded_nodes_neighbor_node_raw_features,
                             padded_nodes_edge_raw_features=src_padded_nodes_edge_raw_features,
                             padded_nodes_neighbor_time_features=src_padded_nodes_neighbor_time_features,
                             padded_nodes_neighbor_co_occurrence_features=src_padded_nodes_neighbor_co_occurrence_features,
                             patch_size=self.patch_size)

        dst_patches_nodes_neighbor_node_raw_features, dst_patches_nodes_edge_raw_features, \
        dst_patches_nodes_neighbor_time_features, dst_patches_nodes_neighbor_co_occurrence_features = \
            self.get_patches(padded_nodes_neighbor_node_raw_features=new_dst_padded_nodes_neighbor_node_raw_features,
                             padded_nodes_edge_raw_features=dst_padded_nodes_edge_raw_features,
                             padded_nodes_neighbor_time_features=dst_padded_nodes_neighbor_time_features,
                             padded_nodes_neighbor_co_occurrence_features=dst_padded_nodes_neighbor_co_occurrence_features,
                             patch_size=self.patch_size)

        # Alignment
        src_patches_nodes_neighbor_node_raw_features = self.projection_layer['node'](src_patches_nodes_neighbor_node_raw_features)
        src_patches_nodes_edge_raw_features = self.projection_layer['edge'](src_patches_nodes_edge_raw_features)
        src_patches_nodes_neighbor_time_features = self.projection_layer['time'](src_patches_nodes_neighbor_time_features)
        src_patches_nodes_neighbor_co_occurrence_features = self.projection_layer['neighbor_co_occurrence'](src_patches_nodes_neighbor_co_occurrence_features)

        dst_patches_nodes_neighbor_node_raw_features = self.projection_layer['node'](dst_patches_nodes_neighbor_node_raw_features)
        dst_patches_nodes_edge_raw_features = self.projection_layer['edge'](dst_patches_nodes_edge_raw_features)
        dst_patches_nodes_neighbor_time_features = self.projection_layer['time'](dst_patches_nodes_neighbor_time_features)
        dst_patches_nodes_neighbor_co_occurrence_features = self.projection_layer['neighbor_co_occurrence'](dst_patches_nodes_neighbor_co_occurrence_features)


        # Concatenation
        src_num_patches = src_patches_nodes_neighbor_node_raw_features.shape[1]
        dst_num_patches = dst_patches_nodes_neighbor_node_raw_features.shape[1]

        big_patch_size = src_num_patches

        if src_num_patches > dst_num_patches:
            dst_patches_nodes_neighbor_node_raw_features = F.pad(dst_patches_nodes_neighbor_node_raw_features, (0, 0, src_num_patches - dst_num_patches, 0), value=0)
            dst_patches_nodes_edge_raw_features = F.pad(dst_patches_nodes_edge_raw_features, (0, 0, src_num_patches - dst_num_patches, 0), value=0)
            dst_patches_nodes_neighbor_time_features = F.pad(dst_patches_nodes_neighbor_time_features, (0, 0, src_num_patches - dst_num_patches, 0), value=0)
            dst_patches_nodes_neighbor_co_occurrence_features = F.pad(dst_patches_nodes_neighbor_co_occurrence_features, (0, 0, src_num_patches - dst_num_patches, 0), value=0)
        elif src_num_patches < dst_num_patches:
            big_patch_size = dst_num_patches
            src_patches_nodes_neighbor_node_raw_features = F.pad(src_patches_nodes_neighbor_node_raw_features, (0, 0, dst_num_patches - src_num_patches, 0), value=0)
            src_patches_nodes_edge_raw_features = F.pad(src_patches_nodes_edge_raw_features, (0, 0, dst_num_patches - src_num_patches, 0), value=0)
            src_patches_nodes_neighbor_time_features = F.pad(src_patches_nodes_neighbor_time_features, (0, 0, dst_num_patches - src_num_patches, 0), value=0)
            src_patches_nodes_neighbor_co_occurrence_features = F.pad(src_patches_nodes_neighbor_co_occurrence_features, (0, 0, dst_num_patches - src_num_patches, 0), value=0)

        n_segments = big_patch_size // self.segment_size \
            + (1 if big_patch_size % self.segment_size != 0 else 0)
        if (big_patch_size % self.segment_size == 0):
            first_segment_size = self.segment_size
        else:
            first_segment_size = big_patch_size % self.segment_size

        if (n_segments == 1):
            last_segment_size = big_patch_size
        else:
            last_segment_size = self.segment_size

        src_out = torch.empty(batch_size, big_patch_size, self.num_channels * self.channel_embedding_dim,
                              dtype=src_patches_nodes_edge_raw_features.dtype, device=self.device)
        dst_out = torch.empty(batch_size, big_patch_size, self.num_channels * self.channel_embedding_dim,
                              dtype=src_patches_nodes_edge_raw_features.dtype, device=self.device)
        for i in range(0, n_segments):

            if (i == 0):
                start_segment = 0
                end_segment = start_segment + first_segment_size
            else:
                start_segment = (i - 1) * self.segment_size + first_segment_size
                end_segment = start_segment + self.segment_size

            patches_data = [src_patches_nodes_neighbor_node_raw_features[:, start_segment:end_segment, :],
                            src_patches_nodes_edge_raw_features[:, start_segment:end_segment, :],
                            src_patches_nodes_neighbor_time_features[:, start_segment:end_segment, :],
                            src_patches_nodes_neighbor_co_occurrence_features[:, start_segment:end_segment, :],
                            dst_patches_nodes_neighbor_node_raw_features[:, start_segment:end_segment, :],
                            dst_patches_nodes_edge_raw_features[:, start_segment:end_segment, :],
                            dst_patches_nodes_neighbor_time_features[:, start_segment:end_segment, :],
                            dst_patches_nodes_neighbor_co_occurrence_features[:, start_segment:end_segment, :]]

            # Tensor, shape (batch_size, src_num_patches + dst_num_patches, num_channels, channel_embedding_dim)
            patches_data = torch.stack(patches_data, dim=2)
            # Tensor, shape (batch_size, src_num_patches + dst_num_patches, num_channels * channel_embedding_dim)
            patches_data = patches_data.reshape(batch_size, (end_segment - start_segment), 2 * self.num_channels * self.channel_embedding_dim)

            # Tensor, shape (batch_size, src_num_patches + dst_num_patches, num_channels * channel_embedding_dim)
            if (i > 0 and mem is not None):
                patches_data, mem, stat = self.brt(patches_data, xl_memories=mem, states=stat)
            else:
                patches_data, mem, stat = self.brt(patches_data)


            # src_patches_data, Tensor, shape (batch_size, src_num_patches, num_channels * channel_embedding_dim)
            src_out[:, start_segment:end_segment, :] = patches_data[:, :, :self.num_channels * self.channel_embedding_dim]
            # dst_patches_data, Tensor, shape (batch_size, dst_num_patches, num_channels * channel_embedding_dim)
            dst_out[:, start_segment:end_segment, :] = patches_data[:, :, self.num_channels * self.channel_embedding_dim:]


        # src_patches_data, Tensor, shape (batch_size, num_channels * channel_embedding_dim)
        src_out = torch.mean(src_out, dim=1)
        # dst_patches_data, Tensor, shape (batch_size, num_channels * channel_embedding_dim)
        dst_out = torch.mean(dst_out, dim=1)

        # Tensor, shape (batch_size, node_feat_dim)
        src_node_embeddings = self.output_layer(src_out)
        # Tensor, shape (batch_size, node_feat_dim)
        dst_node_embeddings = self.output_layer(dst_out)

        return src_node_embeddings, dst_node_embeddings

    def pad_sequences(self, node_ids: np.ndarray, node_interact_times: np.ndarray, nodes_neighbor_ids_list: list, nodes_edge_ids_list: list,
                      nodes_neighbor_times_list: list, patch_size: int = 1, max_input_sequence_length: int = 256):
        """
        pad the sequences for nodes in node_ids
        :param node_ids: ndarray, shape (batch_size, )
        :param node_interact_times: ndarray, shape (batch_size, )
        :param nodes_neighbor_ids_list: list of ndarrays, each ndarray contains neighbor ids for nodes in node_ids
        :param nodes_edge_ids_list: list of ndarrays, each ndarray contains edge ids for nodes in node_ids
        :param nodes_neighbor_times_list: list of ndarrays, each ndarray contains neighbor interaction timestamp for nodes in node_ids
        :param patch_size: int, patch size
        :param max_input_sequence_length: int, maximal number of neighbors for each node
        :return:
        """
        assert max_input_sequence_length - 1 > 0, 'Maximal number of neighbors for each node should be greater than 1!'
        max_seq_length = 0
        # first cut the sequence of nodes whose number of neighbors is more than max_input_sequence_length - 1 (we need to include the target node in the sequence)
        for idx in range(len(nodes_neighbor_ids_list)):
            assert len(nodes_neighbor_ids_list[idx]) == len(nodes_edge_ids_list[idx]) == len(nodes_neighbor_times_list[idx])
            if len(nodes_neighbor_ids_list[idx]) > max_input_sequence_length - 1:
                # cut the sequence by taking the most recent max_input_sequence_length interactions
                nodes_neighbor_ids_list[idx] = nodes_neighbor_ids_list[idx][-(max_input_sequence_length - 1):]
                nodes_edge_ids_list[idx] = nodes_edge_ids_list[idx][-(max_input_sequence_length - 1):]
                nodes_neighbor_times_list[idx] = nodes_neighbor_times_list[idx][-(max_input_sequence_length - 1):]
            if len(nodes_neighbor_ids_list[idx]) > max_seq_length:
                max_seq_length = len(nodes_neighbor_ids_list[idx])

        # include the target node itself
        max_seq_length += 1
        if max_seq_length % patch_size != 0:
            max_seq_length += (patch_size - max_seq_length % patch_size)
        assert max_seq_length % patch_size  == 0

        # pad the sequences
        # three ndarrays with shape (batch_size, max_seq_length)
        padded_nodes_neighbor_ids = np.zeros((len(node_ids), max_seq_length)).astype(np.long)
        padded_nodes_edge_ids = np.zeros((len(node_ids), max_seq_length)).astype(np.long)
        padded_nodes_neighbor_times = np.zeros((len(node_ids), max_seq_length)).astype(np.float32)

        for idx in range(len(node_ids)):
            padded_nodes_neighbor_ids[idx, 0] = node_ids[idx]
            padded_nodes_edge_ids[idx, 0] = 0
            padded_nodes_neighbor_times[idx, 0] = node_interact_times[idx]

            if len(nodes_neighbor_ids_list[idx]) > 0:
                padded_nodes_neighbor_ids[idx, 1: len(nodes_neighbor_ids_list[idx]) + 1] = nodes_neighbor_ids_list[idx]
                padded_nodes_edge_ids[idx, 1: len(nodes_edge_ids_list[idx]) + 1] = nodes_edge_ids_list[idx]
                padded_nodes_neighbor_times[idx, 1: len(nodes_neighbor_times_list[idx]) + 1] = nodes_neighbor_times_list[idx]

        # three ndarrays with shape (batch_size, max_seq_length)
        return padded_nodes_neighbor_ids, padded_nodes_edge_ids, padded_nodes_neighbor_times

    def get_features(self, node_interact_times: np.ndarray, padded_nodes_neighbor_ids: np.ndarray, padded_nodes_edge_ids: np.ndarray,
                     padded_nodes_neighbor_times: np.ndarray, time_encoder: TimeEncoder):
        """
        get node, edge and time features
        :param node_interact_times: ndarray, shape (batch_size, )
        :param padded_nodes_neighbor_ids: ndarray, shape (batch_size, max_seq_length)
        :param padded_nodes_edge_ids: ndarray, shape (batch_size, max_seq_length)
        :param padded_nodes_neighbor_times: ndarray, shape (batch_size, max_seq_length)
        :param time_encoder: TimeEncoder, time encoder
        :return:
        """
        # Tensor, shape (batch_size, max_seq_length, node_feat_dim)
        padded_nodes_neighbor_node_raw_features = self.node_raw_features[torch.from_numpy(padded_nodes_neighbor_ids)]
        # Tensor, shape (batch_size, max_seq_length, edge_feat_dim)
        padded_nodes_edge_raw_features = self.edge_raw_features[torch.from_numpy(padded_nodes_edge_ids)]
        # Tensor, shape (batch_size, max_seq_length, time_feat_dim)
        padded_nodes_neighbor_time_features = time_encoder(timestamps=torch.from_numpy(node_interact_times[:, np.newaxis] - padded_nodes_neighbor_times).float().to(self.device))

        # ndarray, set the time features to all zeros for the padded timestamp
        padded_nodes_neighbor_time_features[torch.from_numpy(padded_nodes_neighbor_ids == 0)] = 0.0

        return padded_nodes_neighbor_node_raw_features, padded_nodes_edge_raw_features, padded_nodes_neighbor_time_features

    def get_patches(self, padded_nodes_neighbor_node_raw_features: torch.Tensor, padded_nodes_edge_raw_features: torch.Tensor,
                    padded_nodes_neighbor_time_features: torch.Tensor, padded_nodes_neighbor_co_occurrence_features: torch.Tensor = None, patch_size: int = 1):
        """
        get the sequence of patches for nodes
        :param padded_nodes_neighbor_node_raw_features: Tensor, shape (batch_size, max_seq_length, node_feat_dim)
        :param padded_nodes_edge_raw_features: Tensor, shape (batch_size, max_seq_length, edge_feat_dim)
        :param padded_nodes_neighbor_time_features: Tensor, shape (batch_size, max_seq_length, time_feat_dim)
        :param padded_nodes_neighbor_co_occurrence_features: Tensor, shape (batch_size, max_seq_length, neighbor_co_occurrence_feat_dim)
        :param patch_size: int, patch size
        :return:
        """
        assert padded_nodes_neighbor_node_raw_features.shape[1] % patch_size == 0
        num_patches = padded_nodes_neighbor_node_raw_features.shape[1] // patch_size

        # list of Tensors with shape (num_patches, ), each Tensor with shape (batch_size, patch_size, node_feat_dim)
        patches_nodes_neighbor_node_raw_features, patches_nodes_edge_raw_features, \
        patches_nodes_neighbor_time_features, patches_nodes_neighbor_co_occurrence_features = [], [], [], []

        for patch_id in range(num_patches):
            start_idx = patch_id * patch_size
            end_idx = patch_id * patch_size + patch_size
            patches_nodes_neighbor_node_raw_features.append(padded_nodes_neighbor_node_raw_features[:, start_idx: end_idx, :])
            patches_nodes_edge_raw_features.append(padded_nodes_edge_raw_features[:, start_idx: end_idx, :])
            patches_nodes_neighbor_time_features.append(padded_nodes_neighbor_time_features[:, start_idx: end_idx, :])
            patches_nodes_neighbor_co_occurrence_features.append(padded_nodes_neighbor_co_occurrence_features[:, start_idx: end_idx, :])

        batch_size = len(padded_nodes_neighbor_node_raw_features)
        # Tensor, shape (batch_size, num_patches, patch_size * node_feat_dim)
        patches_nodes_neighbor_node_raw_features = torch.stack(patches_nodes_neighbor_node_raw_features, dim=1).reshape(batch_size, num_patches, patch_size * self.node_feat_dim)
        # Tensor, shape (batch_size, num_patches, patch_size * edge_feat_dim)
        patches_nodes_edge_raw_features = torch.stack(patches_nodes_edge_raw_features, dim=1).reshape(batch_size, num_patches, patch_size * self.edge_feat_dim)
        # Tensor, shape (batch_size, num_patches, patch_size * time_feat_dim)
        patches_nodes_neighbor_time_features = torch.stack(patches_nodes_neighbor_time_features, dim=1).reshape(batch_size, num_patches, patch_size * self.time_feat_dim)

        patches_nodes_neighbor_co_occurrence_features = torch.stack(patches_nodes_neighbor_co_occurrence_features, dim=1).reshape(batch_size, num_patches, patch_size * self.neighbor_co_occurrence_feat_dim)

        return patches_nodes_neighbor_node_raw_features, patches_nodes_edge_raw_features, patches_nodes_neighbor_time_features, patches_nodes_neighbor_co_occurrence_features

    def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler):
        """
        set neighbor sampler to neighbor_sampler and reset the random state (for reproducing the results for uniform and time_interval_aware sampling)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :return:
        """
        self.neighbor_sampler = neighbor_sampler
        if self.neighbor_sampler.sample_neighbor_strategy in ['uniform', 'time_interval_aware']:
            assert self.neighbor_sampler.seed is not None
            self.neighbor_sampler.reset_random_state()


class NeighborCooccurrenceEncoder(nn.Module):

    def __init__(self, neighbor_co_occurrence_feat_dim: int, device: str = 'cpu'):
        """
        Neighbor co-occurrence encoder.
        :param neighbor_co_occurrence_feat_dim: int, dimension of neighbor co-occurrence features (encodings)
        :param device: str, device
        """
        super(NeighborCooccurrenceEncoder, self).__init__()
        self.neighbor_co_occurrence_feat_dim = neighbor_co_occurrence_feat_dim
        self.device = device

        self.neighbor_co_occurrence_encode_layer = nn.Sequential(
            nn.Linear(in_features=1, out_features=self.neighbor_co_occurrence_feat_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.neighbor_co_occurrence_feat_dim, out_features=self.neighbor_co_occurrence_feat_dim)).to(self.device)

    def count_nodes_appearances(self, src_padded_nodes_neighbor_ids: np.ndarray, dst_padded_nodes_neighbor_ids: np.ndarray):
        """
        count the appearances of nodes in the sequences of source and destination nodes
        :param src_padded_nodes_neighbor_ids: ndarray, shape (batch_size, src_max_seq_length)
        :param dst_padded_nodes_neighbor_ids:: ndarray, shape (batch_size, dst_max_seq_length)
        :return:
        """
        # two lists to store the appearances of source and destination nodes
        src_padded_nodes_appearances, dst_padded_nodes_appearances = [], []
        # src_padded_node_neighbor_ids, ndarray, shape (src_max_seq_length, )
        # dst_padded_node_neighbor_ids, ndarray, shape (dst_max_seq_length, )
        for src_padded_node_neighbor_ids, dst_padded_node_neighbor_ids in zip(src_padded_nodes_neighbor_ids, dst_padded_nodes_neighbor_ids):

            # src_unique_keys, ndarray, shape (num_src_unique_keys, )
            # src_inverse_indices, ndarray, shape (src_max_seq_length, )
            # src_counts, ndarray, shape (num_src_unique_keys, )
            # we can use src_unique_keys[src_inverse_indices] to reconstruct the original input, and use src_counts[src_inverse_indices] to get counts of the original input
            src_unique_keys, src_inverse_indices, src_counts = np.unique(src_padded_node_neighbor_ids, return_inverse=True, return_counts=True)
            # Tensor, shape (src_max_seq_length, )
            src_padded_node_neighbor_counts_in_src = torch.from_numpy(src_counts[src_inverse_indices]).float().to(self.device)
            # dictionary, store the mapping relation from unique neighbor id to its appearances for the source node
            src_mapping_dict = dict(zip(src_unique_keys, src_counts))

            # dst_unique_keys, ndarray, shape (num_dst_unique_keys, )
            # dst_inverse_indices, ndarray, shape (dst_max_seq_length, )
            # dst_counts, ndarray, shape (num_dst_unique_keys, )
            # we can use dst_unique_keys[dst_inverse_indices] to reconstruct the original input, and use dst_counts[dst_inverse_indices] to get counts of the original input
            dst_unique_keys, dst_inverse_indices, dst_counts = np.unique(dst_padded_node_neighbor_ids, return_inverse=True, return_counts=True)
            # Tensor, shape (dst_max_seq_length, )
            dst_padded_node_neighbor_counts_in_dst = torch.from_numpy(dst_counts[dst_inverse_indices]).float().to(self.device)
            # dictionary, store the mapping relation from unique neighbor id to its appearances for the destination node
            dst_mapping_dict = dict(zip(dst_unique_keys, dst_counts))

            # we need to use copy() to avoid the modification of src_padded_node_neighbor_ids
            # Tensor, shape (src_max_seq_length, )
            src_padded_node_neighbor_counts_in_dst = torch.from_numpy(src_padded_node_neighbor_ids.copy()).apply_(lambda neighbor_id: dst_mapping_dict.get(neighbor_id, 0.0)).float().to(self.device)
            # Tensor, shape (src_max_seq_length, 2)
            src_padded_nodes_appearances.append(torch.stack([src_padded_node_neighbor_counts_in_src, src_padded_node_neighbor_counts_in_dst], dim=1))

            # we need to use copy() to avoid the modification of dst_padded_node_neighbor_ids
            # Tensor, shape (dst_max_seq_length, )
            dst_padded_node_neighbor_counts_in_src = torch.from_numpy(dst_padded_node_neighbor_ids.copy()).apply_(lambda neighbor_id: src_mapping_dict.get(neighbor_id, 0.0)).float().to(self.device)
            # Tensor, shape (dst_max_seq_length, 2)
            dst_padded_nodes_appearances.append(torch.stack([dst_padded_node_neighbor_counts_in_src, dst_padded_node_neighbor_counts_in_dst], dim=1))

        # Tensor, shape (batch_size, src_max_seq_length, 2)
        src_padded_nodes_appearances = torch.stack(src_padded_nodes_appearances, dim=0)
        # Tensor, shape (batch_size, dst_max_seq_length, 2)
        dst_padded_nodes_appearances = torch.stack(dst_padded_nodes_appearances, dim=0)

        # set the appearances of the padded node (with zero index) to zeros
        # Tensor, shape (batch_size, src_max_seq_length, 2)
        src_padded_nodes_appearances[torch.from_numpy(src_padded_nodes_neighbor_ids == 0)] = 0.0
        # Tensor, shape (batch_size, dst_max_seq_length, 2)
        dst_padded_nodes_appearances[torch.from_numpy(dst_padded_nodes_neighbor_ids == 0)] = 0.0

        return src_padded_nodes_appearances, dst_padded_nodes_appearances

    def forward(self, src_padded_nodes_neighbor_ids: np.ndarray, dst_padded_nodes_neighbor_ids: np.ndarray):
        """
        compute the neighbor co-occurrence features of nodes in src_padded_nodes_neighbor_ids and dst_padded_nodes_neighbor_ids
        :param src_padded_nodes_neighbor_ids: ndarray, shape (batch_size, src_max_seq_length)
        :param dst_padded_nodes_neighbor_ids:: ndarray, shape (batch_size, dst_max_seq_length)
        :return:
        """
        # src_padded_nodes_appearances, Tensor, shape (batch_size, src_max_seq_length, 2)
        # dst_padded_nodes_appearances, Tensor, shape (batch_size, dst_max_seq_length, 2)
        src_padded_nodes_appearances, dst_padded_nodes_appearances = self.count_nodes_appearances(src_padded_nodes_neighbor_ids=src_padded_nodes_neighbor_ids,
                                                                                                  dst_padded_nodes_neighbor_ids=dst_padded_nodes_neighbor_ids)

        # sum the neighbor co-occurrence features in the sequence of source and destination nodes
        # Tensor, shape (batch_size, src_max_seq_length, neighbor_co_occurrence_feat_dim)
        src_padded_nodes_neighbor_co_occurrence_features = self.neighbor_co_occurrence_encode_layer(src_padded_nodes_appearances.unsqueeze(dim=-1)).sum(dim=2)
        # Tensor, shape (batch_size, dst_max_seq_length, neighbor_co_occurrence_feat_dim)
        dst_padded_nodes_neighbor_co_occurrence_features = self.neighbor_co_occurrence_encode_layer(dst_padded_nodes_appearances.unsqueeze(dim=-1)).sum(dim=2)

        # src_padded_nodes_neighbor_co_occurrence_features, Tensor, shape (batch_size, src_max_seq_length, neighbor_co_occurrence_feat_dim)
        # dst_padded_nodes_neighbor_co_occurrence_features, Tensor, shape (batch_size, dst_max_seq_length, neighbor_co_occurrence_feat_dim)
        return src_padded_nodes_neighbor_co_occurrence_features, dst_padded_nodes_neighbor_co_occurrence_features

This cell holds the definition of get_neighbor_sampler method in HOT encoder.\
The original method utilizes pandas DataFrame type arguments, but I modified it to be fed with numpy ndarray type ones.

In [None]:
def get_neighbor_sampler(data, sample_neighbor_strategy: str, time_scaling_factor: float = 0.0, seed: int = None):
    """
    get neighbor sampler
    :param data: Data
    :param sample_neighbor_strategy: str, how to sample historical neighbors, 'uniform', 'recent', or 'time_interval_aware''
    :param time_scaling_factor: float, a hyper-parameter that controls the sampling preference with time interval,
    a large time_scaling_factor tends to sample more on recent links, this parameter works when sample_neighbor_strategy == 'time_interval_aware'
    :param seed: int, random seed
    :return:
    """
    max_node_id = max(data[:,0].max(), data[:,1].max())
    # the adjacency vector stores edges for each node (source or destination), undirected
    # adj_list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp)
    # the list at the first position in adj_list is empty
    adj_list = [[] for _ in range(max_node_id + 1)]
    edge_ids=torch.as_tensor([0 for i in range(data.shape[0])])
    for src_node_id, dst_node_id, edge_id, node_interact_time in zip(data[:,0], data[:,1], edge_ids, data[2]):
        adj_list[src_node_id].append((dst_node_id, edge_id, node_interact_time))
        adj_list[dst_node_id].append((src_node_id, edge_id, node_interact_time))

    return NeighborSampler(adj_list=adj_list, sample_neighbor_strategy=sample_neighbor_strategy, time_scaling_factor=time_scaling_factor, seed=seed)


This cell holds my own code implementation of the Triadic Decoder, introduced in [Shi et al., 2020](https://arxiv.org/abs/1911.11322).

In [None]:
class TriadicDecoder(nn.Module):
  def __init__(self, latentDim, numFilter, device):
    super(TriadicDecoder,self).__init__()
    self.latentDim=latentDim
    self.numFilter=numFilter
    self.device=device
    self.conv1=nn.Conv1d(in_channels=1, out_channels=self.numFilter, kernel_size=1).to(device)
    self.conv2=nn.Conv1d(in_channels=1, out_channels=self.numFilter, kernel_size=1).to(device)
    self.conv3=nn.Conv1d(in_channels=1, out_channels=self.numFilter, kernel_size=1).to(device)
    self.relu=nn.ReLU().to(device)
    self.ss=nn.Softsign().to(device)
    self.BN=nn.BatchNorm1d(num_features=self.latentDim).to(device)
    self.LN=nn.LayerNorm(normalized_shape=self.latentDim).to(device)
    self.dropout=nn.Dropout(p=0.1).to(device)
    # FC layers - how they transformed (1, 4, 116) z_triplet into (1, 1, 3) is not specified
    # Therefore I applied 2 FC layers without placing nonlinearities in between
    self.fc1=nn.Linear(in_features=self.latentDim, out_features=3).to(device)
    self.convFinal=nn.Conv1d(in_channels=self.numFilter, out_channels=1, kernel_size=1).to(device)


  def forward(self, embed1, embed2, embed3):
    IP=torch.zeros((3,)).to(self.device)
    embed1=torch.squeeze(embed1, dim=0)
    embed2=torch.squeeze(embed2, dim=0)
    embed3=torch.squeeze(embed3, dim=0)
    norm1=torch.sqrt(torch.sum(embed1*embed1)+1e-10)
    norm2=torch.sqrt(torch.sum(embed2*embed2)+1e-10)
    norm3=torch.sqrt(torch.sum(embed3*embed3)+1e-10)
    embed1=embed1/norm1
    embed2=embed2/norm2
    embed3=embed3/norm3

    IP[0]=torch.dot(embed1, embed2)
    IP[1]=torch.dot(embed1, embed3)
    IP[2]=torch.dot(embed2, embed3)
    v1=self.conv1(embed1.unsqueeze(0))
    v2=self.conv2(embed2.unsqueeze(0))
    v3=self.conv3(embed3.unsqueeze(0))
    zTriplet=self.relu(v1+v2+v3) # (4, 116)
    w=self.fc1(zTriplet) # (4, 3)
    x=self.convFinal(w) # (1, 3)
    y=self.relu(x)
    z=torch.sigmoid(y+IP).squeeze(0)
    return z

'\nclass EmbeddingGenerator(nn.Module):\n  def __init__(self, inputFeat, hiddenFeat, embedFeat, device):\n    super(EmbeddingGenerator, self).__init__()\n    self.device=device\n    self.inputFeat=inputFeat\n    self.embedFeat=embedFeat\n    self.hiddenFeat=hiddenFeat\n    self.layer1=nn.Linear(in_features=self.inputFeat, out_features=self.hiddenFeat).to(device)\n    self.layer2=nn.Linear(in_features=self.hiddenFeat, out_features=self.embedFeat).to(device)\n    self.relu=nn.ReLU()\n    self.ss=nn.Softsign()\n    self.elu=nn.ELU()\n  def forward(self, embed):\n    return self.layer2(self.ss(self.layer1(embed)))\n'

# Training Helper functions

This cell contains the definitions of trianing helper functions.

In [None]:
def sampleFrom(from1, from2, p, device):
  sampleFrom1=torch.bernoulli(torch.tensor([p])).to(device)
  if sampleFrom1:
    idx = torch.randint(len(from1), (1,)).item()
    return from1[idx]
  else:
    idx = torch.randint(len(from2), (1,)).item()
    return from2[idx]

# Parameter B is equivalent to the number of triads being sampled
def triadSample(triplets, B, p, device):
  triadList=[]
  i=0
  nodes=torch.Tensor([i for i in range(116)]).to(device)
  while i<B:
    node1=torch.randperm(116)[0].to(device)
    idx1=((triplets[:,0]==node1).nonzero())[:,0].to(device)
    neighbors1=triplets[idx1,1]
    mask1 = torch.isin(nodes, neighbors1, invert=True).to(device)
    faraways = nodes[mask1].to(device)
    # Have to remove node1 from faraways1
    temp=(faraways==node1).nonzero()
    if len(temp)>0:
      continue
    node2=sampleFrom(neighbors1, faraways, p, device).to(device)

    idx2=((triplets[:,0]==node2).nonzero())[:,0].to(device)
    neighbors2=triplets[idx2,1]
    mask2=torch.isin(nodes, neighbors2, invert=True).to(device)
    faraways = nodes[mask2].to(device)
    # Have to remove node1 from faraways1
    temp=(faraways==node2).nonzero()
    if len(temp)>0:
      continue
    node3=sampleFrom(neighbors2, faraways, p, device).to(device)
    triadList.append(torch.Tensor([node1, node2, node3]).to(device))
    i+=1


  triads=torch.stack(triadList, dim=0).int().to(device)
  assert triads.shape==(B, 3)
  return triads

def calculateM(triads, pair):
  pair=pair.squeeze(0)
  assert pair.shape==(2,)
  assert triads.shape[1]==3
  occurrence=0
  if pair[0] in triads[:,0]:
    idx1=((triads[:,0]==pair[0]).nonzero())[:,0]
    if pair[1] in triads[idx1,1]:
      occurrence+=((pair[1]==triads[idx1,1]).nonzero())[:,0].shape[0]
    if pair[1] in triads[idx1,2]:
      occurrence+=((pair[1]==triads[idx1,2]).nonzero())[:,0].shape[0]
  if pair[0] in triads[:,1]:
    idx2=((triads[:,1]==pair[0]).nonzero())[:,0]
    if pair[1] in triads[idx2,2]:
      occurrence+=((pair[1]==triads[idx2,2]).nonzero())[:,0].shape[0]
  return occurrence

def check_submodules_have_attributes(model):
  dic={}
  for name, submodule in model.named_modules():
      # Skip the top-level model itself
      if name == "":
          continue

      has_parameters_attr = hasattr(submodule, 'parameters')
      has_weight_attr = hasattr(submodule, 'weight')
      if has_parameters_attr:
        has_gradient_attr = hasattr(submodule.parameters(), 'requires_grad')
      #print(f"Submodule '{name}' has 'parameters' attribute: {has_parameters_attr}")
      #print(f"Submodule '{name}' has 'weight' attribute: {has_weight_attr}")
      dic[name]=(has_parameters_attr, has_weight_attr, has_gradient_attr)
  return dic

def init_kaiming_normal(model):
  check=check_submodules_have_attributes(model)
  for name, submodule in model.named_modules():
    if name == "":
      continue
    if check[name]==(True, True, True):
      torch.nn.init.kaiming_normal_(submodule.weight)
      torch.nn.init.zeros_(submodule.bias)


def pairIn(pair, duets):
  if pair[0] in duets[:,0]:
    idx1=((duets[:,0]==pair[0]).nonzero())[:,0]
    if pair[1] in duets[idx1,1]:
      return True

# HOT Encoder + Triadic Decoder - Experiment



The cell below is the training and validation loop for 40th Quantile Dataset experiment. \
The model will be saved in "checkpoint8.pth" file.

In [None]:
# 40th Quantile Dataset Experiment

torch.autograd.set_detect_anomaly(False)
if os.path.isfile('checkpoint8.pth'):
        print('resuming checkpoint experiment')
        checkpoint = torch.load('checkpoint8.pth', map_location='cuda')
else:
  checkpoint = {
    'epoch': 0,
    'subject': 0,
    'timeIdx':0,
    'encoder': None,
    'decoder': None,
    'optimizerE': None,
    'lossSubject': None,
    'prevLoss' : np.Inf,
    'patience': 0,
    'lossSubjectV': np.Inf,
    'lr': 1e-4
    }

# Argv setting
epc=10
sampleStrategy='uniform'
negSampleStrategy='inductive'
N=116 # Number of ROIs
numEdges=math.ceil(116*116*0.4)
nodeDim=N # One-Hot encoding as a node embedding
edgeDim=4 # Arbitrarily set
timeDim=100 # Half of the value mentioned in the HOT paper
channelDim=50 # The value mentioned in the HOT paper
latentDim=N # Positive node embedding and negative node embedding concatenated across dim=1
numFilter=4 # Following Triadic Decoder paper
B=10000
prob=0.7398
nodeFeat=torch.cat((torch.eye(N),torch.zeros((1,N))),dim=0).numpy()
edgeFeat=torch.zeros((numEdges+1,edgeDim)).numpy()
learning_rateE=1e-4
pt=2
patience=0
prevLoss=np.inf
lossSubjectV=np.inf


encoder=HOT(nodeFeat, edgeFeat, None, timeDim, channelDim, patch_size=8, num_state_vectors=32, num_layers=2, num2hop=0, dropout=0.1, max_input_sequence_length=4096, device='cuda')
encoder.apply(init_kaiming_normal)
decoder=TriadicDecoder(latentDim, numFilter, device='cuda')
decoder.apply(init_kaiming_normal)
model=nn.Sequential(encoder, decoder).to('cuda')
lossFunc=nn.BCELoss().to('cuda')
optimizerE=torch.optim.Adam(model.parameters(), lr=checkpoint['lr'])


for epoch in range(checkpoint['epoch'], epc):
  print("Epoch : "+str(epoch))
  if os.path.isfile('checkpoint8.pth'):
    checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  for i in range(checkpoint['subject'], 16):
    if os.path.isfile('checkpoint8.pth'):
      checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
    model.train()
    if checkpoint['lr']!=0: learning_rateE=checkpoint['lr']
    if checkpoint['optimizerE'] is not None: optimizerE.load_state_dict(checkpoint['optimizerE'])
    if checkpoint['prevLoss'] is not None: prevLoss=checkpoint['prevLoss']
    if checkpoint['patience'] != 0: patience=checkpoint['patience']
    if checkpoint['lossSubjectV'] is not None: lossSubjectV=checkpoint['lossSubjectV']
    if checkpoint['encoder'] is not None: model[0].load_state_dict(checkpoint['encoder'])
    if checkpoint['decoder'] is not None: model[1].load_state_dict(checkpoint['decoder'])
    print("\nSubject index : "+str(i)+" out of 16")

    train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))
    train3D=train3D.reshape(-1,3).detach().to('cpu').numpy()
    train3D=train3D.reshape(15, -1, 3)

    print("Sampling probability : "+str(prob))

    lossSubject=torch.zeros(1).to('cuda')
    lossSubject.requires_grad_()

    for t in range(checkpoint['timeIdx'], 15):
      print("Time index : "+str(t)+" out of 15")

      timeBatch=train3D[t]
      srcNodes=timeBatch[:,0]
      dstNodes=timeBatch[:,1]
      timepoints=timeBatch[:,2]

      adjMat=torch.zeros((N,N)).to('cuda')
      for row in timeBatch:
        adjMat[int(row[0]), int(row[1])]=1

      nodePosEmbedding=torch.zeros((N,N)).to('cuda')
      nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
      nodePosTimes=torch.zeros(N).to('cuda')
      nodeNegTimes=torch.zeros(N).to('cuda')
      nodeEmbedding=torch.zeros((N,N)).to('cuda')

      # Sample neighbor edges among the dynamic graphs of a single subject
      train_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=0)
      # Sample negative edges among the dynamic graphs of a single subject
      train_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints) , negative_sample_strategy=negSampleStrategy, seed=1)
      model[0].set_neighbor_sampler(train_neighbor_sampler)

      subject_src_node_embeddings, subject_dst_node_embeddings = \
        torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, srcNodes, dstNodes, timepoints, use_reentrant=False)
      negSrcNodes, negDstNodes = train_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
      subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
        torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, negSrcNodes, negDstNodes, timepoints, use_reentrant=False)

      for tb, row in enumerate(zip(srcNodes, dstNodes)):
        nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
        nodePosTimes[row[0]]+=1
        nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
        nodePosTimes[row[1]]+=1
      for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
        nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
        nodeNegTimes[row[0]]+=1
        nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
        nodeNegTimes[row[1]]+=1

      for node in range(N):
        if nodePosTimes[node]>0:
          nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
        if nodeNegTimes[node]>0:
          nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]


      nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
      nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

      graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
      graph2D=torch.tensor(graph2D).to('cuda')
      graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
      graph2DN=torch.tensor(graph2DN).to('cuda')
      triads=triadSample(graph2D, B, prob, device='cuda')

      mSquare=torch.zeros((N,N)).to('cuda')
      eSquare=torch.zeros((N,N)).to('cuda')

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)

        if mSquare[triad[0], triad[1]]==0:
          pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
          M1=calculateM(triads, pair1)
          mSquare[triad[0], triad[1]]=M1
        if mSquare[triad[0], triad[2]]==0:
          pair2=torch.cat((triad0, triad2), dim=1).to('cuda')
          M2=calculateM(triads, pair2)
          mSquare[triad[0], triad[2]]=M2
        if mSquare[triad[1], triad[2]]==0:
          pair3=torch.cat((triad1, triad2), dim=1).to('cuda')
          M3=calculateM(triads, pair3)
          mSquare[triad[1], triad[2]]=M3

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)
        pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
        pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
        pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

        emb1=torch.zeros((1,N)).to('cuda')
        emb2=torch.zeros((1,N)).to('cuda')
        emb3=torch.zeros((1,N)).to('cuda')

        if pairIn(pair1,graph2D[:,0:2]):
          emb1=nodePosEmbedding[triad[0]]
        if pairIn(pair1,graph2DN):
          emb1=nodeNegEmbedding[triad[0]]
        if pairIn(pair2,graph2D[:,0:2]):
          emb2=nodePosEmbedding[triad[1]]
        if pairIn(pair2,graph2DN):
          emb2=nodeNegEmbedding[triad[1]]
        if pairIn(pair3,graph2D[:,0:2]):
          emb3=nodePosEmbedding[triad[2]]
        if pairIn(pair3,graph2DN):
          emb3=nodeNegEmbedding[triad[2]]

        eTriplet=model[1].forward(emb1, emb2, emb3)
        if pairIn(pair1,graph2D[:,0:2]):
          eSquare[triad[0], triad[1]]+=eTriplet[0]
        if pairIn(pair1,graph2DN):
          eSquare[triad[0], triad[1]]-=eTriplet[0]
        if pairIn(pair2,graph2D[:,0:2]):
          eSquare[triad[0], triad[2]]+=eTriplet[1]
        if pairIn(pair2,graph2DN):
          eSquare[triad[0], triad[2]]-=eTriplet[1]
        if pairIn(pair3,graph2D[:,0:2]):
          eSquare[triad[1], triad[2]]+=eTriplet[2]
        if pairIn(pair3,graph2DN):
          eSquare[triad[1], triad[2]]-=eTriplet[2]
      eSquare.clamp_(min=0.0)

      nonzeros=((mSquare>0).nonzero())
      adjRecon=[]
      adjPart=[]
      adjFull=torch.zeros((N,N)).to('cuda')
      for nz in nonzeros:
        nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
        em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
        if pairIn(nzPair,graph2D[:,0:2]):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([1.0]))
        if pairIn(nzPair, graph2DN):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([0.0]))
        if em1>0.5:
          adjFull[int(nz[0]), int(nz[1])]=1
        else:
          adjFull[int(nz[0]), int(nz[1])]=0
      adjRecon=torch.stack(adjRecon).to('cuda')

      adjPart=torch.stack(adjPart).squeeze(1).to('cuda')


      loss=lossFunc(input=adjRecon, target=adjPart)
      lossSubject=lossSubject+loss

      adjReconC=adjRecon.clone().detach().cpu().numpy()
      adjPartC=adjPart.clone().detach().cpu().numpy()

      adjReconC[adjReconC>0.5]=1
      adjReconC[adjReconC<=0.5]=0
      F1=f1_score(adjPartC, adjReconC)
      AUROC=roc_auc_score(adjPartC, adjReconC)
      AP=average_precision_score(adjPartC, adjReconC)


      print("Epoch : "+str(epoch)+" / Subject : "+str(i)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AUROC : "+str(AUROC)+" / F1 : "+str(F1)+" / AP : "+str(AP)+"\n")

      torch.cuda.empty_cache()

    optimizerE.zero_grad()
    lossSubject.backward()
    optimizerE.step()




    # Validation Step
    if (i+1)%2==0:
      model.eval()
      with torch.no_grad():
        patience=checkpoint['patience']
        prevLoss=checkpoint['prevLoss']
        valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))
        valid3D=valid3D.reshape(-1, 3).detach().to('cpu').numpy()
        valid3D=valid3D.reshape(15, -1, 3)
        lossSubjectV=0

        print("Validation Subject index : "+str((i+1)//2-1)+" out of 8")
        for t in range(15):
          adjRecon=torch.zeros((N,N)).to('cuda')

          timeBatch=valid3D[t]
          timeBatch[:,2]=timeBatch[:,2]/3-1
          srcNodes=timeBatch[:,0]
          dstNodes=timeBatch[:,1]
          timepoints=timeBatch[:,2]

          adjMat=torch.zeros((N,N)).to('cuda')
          for row in timeBatch:
            adjMat[int(row[0]), int(row[1])]=1

          nodePosEmbedding=torch.zeros((N,N)).to('cuda')
          nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
          nodePosTimes=torch.zeros(N).to('cuda')
          nodeNegTimes=torch.zeros(N).to('cuda')
          nodeEmbedding=torch.zeros((N,N)).to('cuda')

          # Sample neighbor edges among the dynamic graphs of a single subject
          valid_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=2)
          # Sample negative edges among the dynamic graphs of a single subject
          valid_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints), negative_sample_strategy=negSampleStrategy, seed=3)
          model[0].set_neighbor_sampler(valid_neighbor_sampler)

          subject_src_node_embeddings, subject_dst_node_embeddings = \
            model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=srcNodes, dst_node_ids=dstNodes, node_interact_times=timepoints)
          negSrcNodes, negDstNodes = valid_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
          subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
            model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=negSrcNodes,
                                                                          dst_node_ids=negDstNodes,
                                                                          node_interact_times=timepoints)

          for tb, row in enumerate(zip(srcNodes, dstNodes)):
            nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
            nodePosTimes[row[0]]+=1
            nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
            nodePosTimes[row[1]]+=1
          for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
            nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
            nodeNegTimes[row[0]]+=1
            nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
            nodeNegTimes[row[1]]+=1

          for node in range(N):
            if nodePosTimes[node]>0:
              nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
            if nodeNegTimes[node]>0:
              nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]

          nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
          nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

          graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
          graph2D=torch.tensor(graph2D).to('cuda')
          graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
          graph2DN=torch.tensor(graph2DN).to('cuda')

          triads=triadSample(graph2D, B, prob, device='cuda')
          mSquare=torch.zeros((N,N)).to('cuda')
          eSquare=torch.zeros((N,N)).to('cuda')

          for j, triad in enumerate(triads):
            triad0=triad[0].unsqueeze(0).unsqueeze(0)
            triad1=triad[1].unsqueeze(0).unsqueeze(0)
            triad2=triad[2].unsqueeze(0).unsqueeze(0)

            if mSquare[triad[0], triad[1]]==0:
              pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
              M1=calculateM(triads, pair1)
              mSquare[triad[0], triad[1]]=M1

            if mSquare[triad[0], triad[2]]==0:
              pair1=torch.cat((triad0, triad2), dim=1).to('cuda')
              M2=calculateM(triads, pair1)
              mSquare[triad[0], triad[2]]=M2

            if mSquare[triad[1], triad[2]]==0:
              pair1=torch.cat((triad1, triad2), dim=1).to('cuda')
              M3=calculateM(triads, pair1)
              mSquare[triad[1], triad[2]]=M3

          for j, triad in enumerate(triads):
            triad0=triad[0].unsqueeze(0).unsqueeze(0)
            triad1=triad[1].unsqueeze(0).unsqueeze(0)
            triad2=triad[2].unsqueeze(0).unsqueeze(0)
            pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
            pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
            pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

            if pair1 in graph2D[:,0:2]:
              emb1=nodePosEmbedding[triad[0]]
            if pair1 in graph2DN:
              emb1=nodeNegEmbedding[triad[0]]
            if pair2 in graph2D[:,0:2]:
              emb2=nodePosEmbedding[triad[1]]
            if pair2 in graph2DN:
              emb2=nodeNegEmbedding[triad[1]]
            if pair3 in graph2D[:,0:2]:
              emb3=nodePosEmbedding[triad[2]]
            if pair3 in graph2DN:
              emb3=nodeNegEmbedding[triad[2]]

            eTriplet=model[1].forward(emb1, emb2, emb3)
            if pairIn(pair1,graph2D[:,0:2]):
              eSquare[triad[0], triad[1]]+=eTriplet[0]
            if pairIn(pair1,graph2DN):
              eSquare[triad[0], triad[1]]-=eTriplet[0]
            if pairIn(pair2,graph2D[:,0:2]):
              eSquare[triad[0], triad[2]]+=eTriplet[1]
            if pairIn(pair2,graph2DN):
              eSquare[triad[0], triad[2]]-=eTriplet[1]
            if pairIn(pair3,graph2D[:,0:2]):
              eSquare[triad[1], triad[2]]+=eTriplet[2]
            if pairIn(pair3,graph2DN):
              eSquare[triad[1], triad[2]]-=eTriplet[2]

          eSquare.clamp_(min=0.0)
          nonzeros=((mSquare>0).nonzero())
          adjRecon=[]
          adjPart=[]
          adjFull=torch.zeros((N,N)).to('cuda')
          for nz in nonzeros:
            nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
            em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
            if pairIn(nzPair, graph2D[:,0:2]):
              adjRecon.append(em1)
              adjPart.append(torch.tensor([1.0]))
            if pairIn(nzPair,graph2DN):
              adjRecon.append(em1)
              adjPart.append(torch.tensor([0.0]))
            if em1>0.5:
              adjFull[int(nz[0]), int(nz[1])]=1
            else:
              adjFull[int(nz[0]), int(nz[1])]=0
          adjRecon=torch.stack(adjRecon).to('cuda')
          adjPart=torch.stack(adjPart).squeeze(1).to('cuda')

          loss=lossFunc(input=adjRecon, target=adjPart)
          lossSubjectV=lossSubjectV+loss

          adjReconC=adjRecon.clone().detach().cpu().numpy()
          adjPartC=adjPart.clone().detach().cpu().numpy()

          adjReconC[adjReconC>0.5]=1
          adjReconC[adjReconC<=0.5]=0
          F1=f1_score(adjPartC, adjReconC)
          AUROC=roc_auc_score(adjPartC, adjReconC)
          AP=average_precision_score(adjPartC, adjReconC)

          print("Validation / Subject : "+str((i+1)//2-1)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AP : "+str(AP)+" / F1 : "+str(F1)+" / AUROC : "+str(AUROC)+'\n')


        if lossSubjectV>prevLoss:
          patience+=1
        else:
          patience=0

        if patience==1 and learning_rateE<=2.5e-5:
          torch.save({
          'epoch': epoch,
          'encoder': model[0].state_dict(),
          'decoder': model[1].state_dict(),
          'subject': i,
          'timeIdx': t,
          'optimizerE': optimizerE.state_dict(),
          'lossSubject': lossSubject,
          'patience': patience,
          'prevLoss': prevLoss,
          'lossSubjectV': lossSubjectV,
          'lr': learning_rateE
          },
          'checkpoint8.pth')
          raise Exception('Early Stopping')
        if patience==pt and learning_rateE>2.5e-5:
          learning_rateE=learning_rateE/2
          patience=0

    if lossSubjectV is not None:
      prevLoss=lossSubjectV
    torch.save({
        'epoch': epoch,
        'encoder': model[0].state_dict(),
        'decoder': model[1].state_dict(),
        'subject': i+1,
        'timeIdx': 0,
        'optimizerE': optimizerE.state_dict(),
        'lossSubject': None,
        'patience': patience,
        'prevLoss': prevLoss,
        'lossSubjectV': lossSubjectV,
        'lr': learning_rateE
        },
        'checkpoint8.pth')

    torch.cuda.empty_cache()

  torch.save({
              'epoch': epoch+1,
              'encoder': model[0].state_dict(),
              'decoder': model[1].state_dict(),
              'subject': 0,
              'timeIdx': 0,
              'optimizerE': optimizerE.state_dict(),
              'lossSubject': None,
              'patience': patience,
              'prevLoss': prevLoss,
              'lossSubjectV': lossSubjectV,
              'lr': learning_rateE
              },
              'checkpoint8.pth')


resuming checkpoint experiment
Epoch : 1


  checkpoint = torch.load('checkpoint8.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 3 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3844, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 5, 37, 40, 41, 87, 88], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 0 / Loss : 0.22708791494369507 / AUROC : 0.9818903738136742 / F1 : 0.9815563665055725 / AP : 0.9844232026444585

Time index : 1 out of 15
tensor(0.3866, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 35,  37,  40,  41,  52,  87, 106], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 1 / Loss : 0.22421588003635406 / AUROC : 0.9831252438548577 / F1 : 0.9828355987697193 / AP : 0.9854983737800311

Time index : 2 out of 15
tensor(0.3892, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  5,  21,  37,  38,  40,  52, 106], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 2 / Loss : 0.21966208517551422 / AUROC : 0.9799567695028493 / F1 : 0.979546821736515 / AP : 0.9827476558626496

Time index : 3 out of 15
tensor(0.3877, device='cuda:0', grad_fn=<MeanB

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 1 out of 8
Validation / Subject : 1 / Timepoint : 0 / Loss : 0.21057777106761932 / AP : 0.9906430730117177 / F1 : 0.9888866838658464 / AUROC : 0.9890088321884201

Validation / Subject : 1 / Timepoint : 1 / Loss : 0.21078670024871826 / AP : 0.9912093862304165 / F1 : 0.989517819706499 / AUROC : 0.9896265560165975

Validation / Subject : 1 / Timepoint : 2 / Loss : 0.20829203724861145 / AP : 0.9902741060565458 / F1 : 0.9884965489646894 / AUROC : 0.9886273734177216

Validation / Subject : 1 / Timepoint : 3 / Loss : 0.21202948689460754 / AP : 0.9888648981999949 / F1 : 0.9866640127388535 / AUROC : 0.9868395207228442

Validation / Subject : 1 / Timepoint : 4 / Loss : 0.2074330598115921 / AP : 0.9887732489858675 / F1 : 0.9867463876432486 / AUROC : 0.9869197482297404

Validation / Subject : 1 / Timepoint : 5 / Loss : 0.21078164875507355 / AP : 0.9880540754843059 / F1 : 0.9856411286273722 / AUROC : 0.9858443872500495

Validation / Subject : 1 / Timepoint : 6 / Loss : 0.

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 4 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3876, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   4,   6,   8,   9,  21,  23,  24,  28,  29,  30,  31,  32,  34,
         36,  40,  66,  73,  82, 101, 107], device='cuda:0')
Epoch : 1 / Subject : 4 / Timepoint : 0 / Loss : 0.22912509739398956 / AUROC : 0.9778120779471349 / F1 : 0.9773086029992107 / AP : 0.9810638196482793

Time index : 1 out of 15
tensor(0.3873, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   8,  28,  29,  31,  34,  37,  40,  42,  43,  45,  46,  47,  48,
         55,  66,  82, 101, 107], device='cuda:0')
Epoch : 1 / Subject : 4 / Timepoint : 1 / Loss : 0.22636009752750397 / AUROC : 0.979864472410455 / F1 : 0.9794507014424026 / AP : 0.9827276179743819

Time index : 2 out of 15
tensor(0.3856, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  8,  28,  29,  31,  40,  42,  43,  44,  45,  46,  48,  49,  51,  55,
         59,  63,  66,  80,  82, 101, 109, 113, 

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 5 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3883, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,   4,   7,  13,  20,  24,  27,  36,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  53,  55,  58,  59,  60,
         61,  65,  81,  83,  85,  87,  89,  91,  93,  94,  95,  97,  99, 101,
        108, 111, 113], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 0 / Loss : 0.2238374501466751 / AUROC : 0.9795327283259316 / F1 : 0.9791050660358762 / AP : 0.9824160459238095

Time index : 1 out of 15
tensor(0.3911, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,  13,  24,  27,  39,  40,  41,  42,  43,  44,  45,  46,
         47,  48,  51,  53,  55,  59,  61,  83,  85,  87,  89,  94,  95,  97,
         99, 101, 103, 108, 111, 113], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 1 / Loss : 0.22573025524616241 / AUROC : 0.9795682343870471 / F1 : 0.9791420700511609 / AP : 0.9826267734832353

Time index : 2 out

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 2 out of 8
Validation / Subject : 2 / Timepoint : 0 / Loss : 0.2071903795003891 / AP : 0.9901161042195914 / F1 : 0.9882049757161264 / AUROC : 0.9883424764890283

Validation / Subject : 2 / Timepoint : 1 / Loss : 0.20336194336414337 / AP : 0.9918432178631496 / F1 : 0.9903713892709766 / AUROC : 0.9904632152588556

Validation / Subject : 2 / Timepoint : 2 / Loss : 0.20100165903568268 / AP : 0.9912161915690987 / F1 : 0.9897189856065799 / AUROC : 0.9898236092265943

Validation / Subject : 2 / Timepoint : 3 / Loss : 0.2029627412557602 / AP : 0.9919761931004077 / F1 : 0.9906103286384976 / AUROC : 0.9906976744186047

Validation / Subject : 2 / Timepoint : 4 / Loss : 0.2027023434638977 / AP : 0.993788283234571 / F1 : 0.9926420092220151 / AUROC : 0.992695753798208

Validation / Subject : 2 / Timepoint : 5 / Loss : 0.2007509022951126 / AP : 0.9943095417077783 / F1 : 0.9933133055528637 / AUROC : 0.9933577204466693

Validation / Subject : 2 / Timepoint : 6 / Loss : 0.2017

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 6 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.4029, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([16, 21, 38, 40, 56, 62, 79, 88], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 0 / Loss : 0.22597737610340118 / AUROC : 0.9741980474198048 / F1 : 0.9735146743020758 / AP : 0.9786425038084289

Time index : 1 out of 15
tensor(0.3972, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 28,  40,  50,  62,  64,  80,  96, 113], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 1 / Loss : 0.22300845384597778 / AUROC : 0.9759167492566898 / F1 : 0.9753224332283944 / AP : 0.9796908870265059

Time index : 2 out of 15
tensor(0.3916, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 10,  21,  28,  40,  41,  42,  44,  45,  50,  56,  62,  64,  77,  80,
         90,  92,  96, 113], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 2 / Loss : 0.22723546624183655 / AUROC : 0.9755836575875486 / F1 : 0.9749725795193938 / AP : 0.979268524310

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 7 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3896, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   6,   7,   8,  10,  12,  13,  15,  17,  19,  30,  32,  33,
         38,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
         53,  54,  55,  56,  57,  60,  61,  63,  65,  66,  67,  71,  84,  85,
         86,  88,  90,  91,  92,  93,  98,  99, 100, 101, 103, 105, 107, 111],
       device='cuda:0')
Epoch : 1 / Subject : 7 / Timepoint : 0 / Loss : 0.2266380935907364 / AUROC : 0.9770597422881686 / F1 : 0.9765211309821161 / AP : 0.9804234917399819

Time index : 1 out of 15
tensor(0.3881, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   6,   7,   8,  10,  11,  12,  13,  14,  19,  24,  26,  27,
         30,  34,  36,  38,  40,  41,  42,  44,  46,  48,  49,  50,  51,  52,
         53,  54,  55,  57,  61,  65,  67,  69,  70,  71,  77,  86,  90,  91,
         92,  93,  98,  99, 100, 101, 103, 104, 105, 107, 111],
       device='cuda:0

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 3 out of 8
Validation / Subject : 3 / Timepoint : 0 / Loss : 0.20382912456989288 / AP : 0.9904266891943506 / F1 : 0.9886777591808605 / AUROC : 0.9888045171339563

Validation / Subject : 3 / Timepoint : 1 / Loss : 0.20086340606212616 / AP : 0.9892503054923268 / F1 : 0.9873642645607108 / AUROC : 0.9875219341002145

Validation / Subject : 3 / Timepoint : 2 / Loss : 0.20127616822719574 / AP : 0.9896326722679721 / F1 : 0.9877965439812555 / AUROC : 0.9879436728395061

Validation / Subject : 3 / Timepoint : 3 / Loss : 0.20090089738368988 / AP : 0.9896057709297786 / F1 : 0.9877240841777085 / AUROC : 0.9878729547641963

Validation / Subject : 3 / Timepoint : 4 / Loss : 0.21210989356040955 / AP : 0.9976722508175512 / F1 : 0.9972266244057052 / AUROC : 0.99723429474516

Validation / Subject : 3 / Timepoint : 5 / Loss : 0.19854477047920227 / AP : 0.9935369328804656 / F1 : 0.9924094978590892 / AUROC : 0.9924666795441375

Validation / Subject : 3 / Timepoint : 6 / Loss : 0.

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 8 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3931, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,   7,  11,  13,  17,  23,  38,  39,  40,  53,  63,  68,
         76,  81,  85,  93,  94,  95, 101, 109], device='cuda:0')
Epoch : 1 / Subject : 8 / Timepoint : 0 / Loss : 0.2177598923444748 / AUROC : 0.9817347138112913 / F1 : 0.981394886080987 / AP : 0.9843632823711859

Time index : 1 out of 15
tensor(0.3904, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  2,   3,   5,   7,   9,  17,  19,  21,  22,  23,  24,  34,  38,  39,
         40,  42,  47,  50,  51,  52,  53,  55,  56,  58,  60,  63,  64,  66,
         68,  69,  78,  80,  81,  84,  87,  88,  90,  91,  93,  94, 107],
       device='cuda:0')
Epoch : 1 / Subject : 8 / Timepoint : 1 / Loss : 0.22423212230205536 / AUROC : 0.9783867232728676 / F1 : 0.9779092702169625 / AP : 0.9815330353597067

Time index : 2 out of 15
tensor(0.3898, device='cuda:0', grad_fn=<MeanBackward0

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 9 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3916, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 23,  32,  36,  39,  40,  52,  53,  54,  55,  68,  72,  74,  87,  88,
        114, 115], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 0 / Loss : 0.22216688096523285 / AUROC : 0.9746277429467085 / F1 : 0.9739672328877275 / AP : 0.9782620935762482

Time index : 1 out of 15
tensor(0.3952, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  6,   7,  11,  17,  18,  19,  22,  25,  30,  32,  33,  36,  39,  40,
         52,  53,  54,  55,  56,  62,  63,  68,  72,  74,  76,  78,  79,  84,
         87,  88,  89,  99, 114], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 1 / Loss : 0.22557367384433746 / AUROC : 0.9742790335151987 / F1 : 0.9736 / AP : 0.9782744110195916

Time index : 2 out of 15
tensor(0.3974, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([16, 18, 19, 36, 40, 56, 62, 72, 78, 79, 80], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 2 / L

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 4 out of 8
Validation / Subject : 4 / Timepoint : 0 / Loss : 0.20166558027267456 / AP : 0.9902276300820525 / F1 : 0.9882191242882388 / AUROC : 0.9883562973025422

Validation / Subject : 4 / Timepoint : 1 / Loss : 0.2025403529405594 / AP : 0.9922004020028203 / F1 : 0.9904780411970463 / AUROC : 0.9905678537054861

Validation / Subject : 4 / Timepoint : 2 / Loss : 0.19814689457416534 / AP : 0.9923342164958967 / F1 : 0.9909108871764474 / AUROC : 0.9909927550420992

Validation / Subject : 4 / Timepoint : 3 / Loss : 0.19909712672233582 / AP : 0.9920640024139318 / F1 : 0.9905596107055961 / AUROC : 0.9906478981874277

Validation / Subject : 4 / Timepoint : 4 / Loss : 0.19617627561092377 / AP : 0.9939751762056553 / F1 : 0.992923137409082 / AUROC : 0.9929728674604723

Validation / Subject : 4 / Timepoint : 5 / Loss : 0.1976359486579895 / AP : 0.9926957259252286 / F1 : 0.9913793103448276 / AUROC : 0.9914529914529915

Validation / Subject : 4 / Timepoint : 6 / Loss : 0.1

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 10 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3939, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 35,  39,  40,  41,  78,  87,  95,  96, 109], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 0 / Loss : 0.22314892709255219 / AUROC : 0.9767101929448451 / F1 : 0.976154843859124 / AP : 0.980217111980059

Time index : 1 out of 15
tensor(0.3947, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40, 41, 64, 74, 75, 78], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 1 / Loss : 0.22097021341323853 / AUROC : 0.9773830324208892 / F1 : 0.9768596682888072 / AP : 0.980758925171988

Time index : 2 out of 15
tensor(0.3942, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40, 75], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 2 / Loss : 0.22163984179496765 / AUROC : 0.9769503546099291 / F1 : 0.9764065335753176 / AP : 0.9803742922043587

Time index : 3 out of 15
tensor(0.3915, device='cuda:0', grad_fn=<MeanBackward0>)
tens

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 11 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3994, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 13,  29,  36,  37,  38,  39,  40,  42,  44,  45,  46,  48,  49,  53,
         54,  55,  58,  67,  72,  73,  74,  77,  79,  80,  81,  86,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
        105, 108, 109, 110, 111, 112, 113, 114, 115], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 0 / Loss : 0.2189072072505951 / AUROC : 0.9808176709939934 / F1 : 0.9804425128407744 / AP : 0.9838326514050272

Time index : 1 out of 15
tensor(0.3991, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 13,  29,  36,  38,  39,  40,  42,  44,  45,  46,  48,  49,  53,  54,
         55,  58,  72,  73,  74,  77,  80,  81,  86,  89,  90,  91,  92,  93,
         94,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 108, 109, 110,
        111, 112, 113, 114, 115], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 1 / Loss : 0.2181131839752

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 5 out of 8
Validation / Subject : 5 / Timepoint : 0 / Loss : 0.19679400324821472 / AP : 0.9912467988473291 / F1 : 0.9896625707041155 / AUROC : 0.9897683397683398

Validation / Subject : 5 / Timepoint : 1 / Loss : 0.19635999202728271 / AP : 0.9908041373870508 / F1 : 0.9891970802919708 / AUROC : 0.9893125361062969

Validation / Subject : 5 / Timepoint : 2 / Loss : 0.19814267754554749 / AP : 0.9904216331807693 / F1 : 0.9886531820424272 / AUROC : 0.988780487804878

Validation / Subject : 5 / Timepoint : 3 / Loss : 0.1998693197965622 / AP : 0.9901373764891213 / F1 : 0.9881791993642595 / AUROC : 0.9883172982525035

Validation / Subject : 5 / Timepoint : 4 / Loss : 0.20101967453956604 / AP : 0.9897055352143478 / F1 : 0.9876790543924672 / AUROC : 0.9878290124678408

Validation / Subject : 5 / Timepoint : 5 / Loss : 0.198898583650589 / AP : 0.9903754017078166 / F1 : 0.9885370118515641 / AUROC : 0.988666922781406

Validation / Subject : 5 / Timepoint : 6 / Loss : 0.198

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 12 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3928, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  13,  17,  21,  22,  23,  24,  26,  27,  30,  31,  32,  35,  36,
         40,  43,  44,  45,  46,  47,  57,  64,  65,  66,  71,  72,  78,  79,
         81,  82,  83,  84,  85,  86,  87,  95,  96,  97, 101, 104, 114, 115],
       device='cuda:0')
Epoch : 1 / Subject : 12 / Timepoint : 0 / Loss : 0.21716007590293884 / AUROC : 0.9797389440872784 / F1 : 0.979319944322927 / AP : 0.9826069885859879

Time index : 1 out of 15
tensor(0.3937, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  16,  17,  18,  21,  22,  23,  24,  27,  31,  32,  35,  36,  40,
         43,  47,  57,  64,  65,  69,  71,  72,  77,  78,  79,  80,  81,  82,
         84,  87,  94,  95,  96, 114, 115], device='cuda:0')
Epoch : 1 / Subject : 12 / Timepoint : 1 / Loss : 0.21752046048641205 / AUROC : 0.9821981424148607 / F1 : 0.9818754925137904 / AP : 0.984750267130606

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 13 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3984, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,   6,  16,  17,  18,  28,  30,  31,  32,  33,  34,  35,
         38,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
         53,  54,  55,  56,  57,  58,  59,  62,  64,  66,  67,  73,  76,  79,
         80,  81,  85,  89,  92,  93,  98,  99, 100, 101, 102, 109, 110],
       device='cuda:0')
Epoch : 1 / Subject : 13 / Timepoint : 0 / Loss : 0.21666280925273895 / AUROC : 0.9803288314738696 / F1 : 0.9799341120095837 / AP : 0.9833413329624103

Time index : 1 out of 15
tensor(0.3988, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   6,  16,  17,  26,  28,  30,  31,  32,  33,  34,  35,  38,
         40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
         54,  55,  56,  57,  58,  59,  62,  64,  66,  67,  72,  76,  80,  81,
         85,  93,  96,  98,  99, 100, 101, 109], device='cuda:0')
Epoch : 1 / Subject : 

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 6 out of 8
Validation / Subject : 6 / Timepoint : 0 / Loss : 0.2074742466211319 / AP : 0.9880241691326596 / F1 : 0.9850595598627094 / AUROC : 0.9852794907499502

Validation / Subject : 6 / Timepoint : 1 / Loss : 0.20313693583011627 / AP : 0.9880431372451115 / F1 : 0.9854545454545455 / AUROC : 0.985663082437276

Validation / Subject : 6 / Timepoint : 2 / Loss : 0.2020639032125473 / AP : 0.9883357461721894 / F1 : 0.9859525012632643 / AUROC : 0.9861470998604744

Validation / Subject : 6 / Timepoint : 3 / Loss : 0.19874262809753418 / AP : 0.9878227879406557 / F1 : 0.9856659366912204 / AUROC : 0.9858684985279687

Validation / Subject : 6 / Timepoint : 4 / Loss : 0.2003001570701599 / AP : 0.9880379365272826 / F1 : 0.9857029388403494 / AUROC : 0.9859044635865309

Validation / Subject : 6 / Timepoint : 5 / Loss : 0.19694110751152039 / AP : 0.9903743523201497 / F1 : 0.9885820240070264 / AUROC : 0.9887109224237747

Validation / Subject : 6 / Timepoint : 6 / Loss : 0.19

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 14 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3918, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  4,  19,  21,  28,  29,  36,  40,  41,  46,  47,  49,  51,  53,  55,
         56,  57,  58,  59,  62,  66,  67,  68,  69,  72,  73,  74,  75,  76,
         79,  83,  95,  96,  98, 106, 110, 111, 113], device='cuda:0')
Epoch : 1 / Subject : 14 / Timepoint : 0 / Loss : 0.22013132274150848 / AUROC : 0.9776776972782455 / F1 : 0.9771680352493491 / AP : 0.9808416453948546

Time index : 1 out of 15
tensor(0.3920, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 17,  21,  25,  36,  40,  41,  46,  47,  53,  54,  55,  56,  57,  58,
         59,  66,  67,  68,  69,  72,  73,  74,  75,  76,  79,  80,  81,  83,
         95,  96,  97,  98, 110, 111, 113], device='cuda:0')
Epoch : 1 / Subject : 14 / Timepoint : 1 / Loss : 0.22059008479118347 / AUROC : 0.9759259259259259 / F1 : 0.9753320683111955 / AP : 0.9793451733122258

Time index : 2 out of 15
ten

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 15 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.4071, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([36, 40, 86], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 0 / Loss : 0.22376112639904022 / AUROC : 0.9736737136019147 / F1 : 0.9729619008603032 / AP : 0.9783661089098568

Time index : 1 out of 15
tensor(0.4067, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 1 / Loss : 0.2223655879497528 / AUROC : 0.9741293532338309 / F1 : 0.9734422880490297 / AP : 0.9786751781569457

Time index : 2 out of 15
tensor(0.4069, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 2 / Loss : 0.22338609397411346 / AUROC : 0.9760627731426301 / F1 : 0.975475730131271 / AP : 0.9803126807881842

Time index : 3 out of 15
tensor(0.4050, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 3 / Loss : 0.22290568

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 7 out of 8
Validation / Subject : 7 / Timepoint : 0 / Loss : 0.20069259405136108 / AP : 0.9919767638513497 / F1 : 0.9901652242328874 / AUROC : 0.9902610050642774

Validation / Subject : 7 / Timepoint : 1 / Loss : 0.1943122148513794 / AP : 0.9953136897610178 / F1 : 0.9943670323154462 / AUROC : 0.9943985849056604

Validation / Subject : 7 / Timepoint : 2 / Loss : 0.1938323825597763 / AP : 0.9960751848401488 / F1 : 0.9952913478516775 / AUROC : 0.9953134153485648

Validation / Subject : 7 / Timepoint : 3 / Loss : 0.19741004705429077 / AP : 0.9925193797275649 / F1 : 0.9909820632246557 / AUROC : 0.9910626595953644

Validation / Subject : 7 / Timepoint : 4 / Loss : 0.19975729286670685 / AP : 0.9900229364934171 / F1 : 0.9879253567508233 / AUROC : 0.9880694143167028

Validation / Subject : 7 / Timepoint : 5 / Loss : 0.1993909478187561 / AP : 0.990185691860575 / F1 : 0.9881485907778109 / AUROC : 0.9882874015748031

Validation / Subject : 7 / Timepoint : 6 / Loss : 0.20

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 0 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3953, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 3,  7, 23, 25, 26, 30, 31, 34, 40, 42, 49, 51, 64, 65, 67, 72, 94],
       device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 0 / Loss : 0.21996457874774933 / AUROC : 0.9775710088148873 / F1 : 0.9770564071736298 / AP : 0.9809099933246098

Time index : 1 out of 15
tensor(0.3987, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  3,   7,  13,  18,  22,  23,  28,  30,  31,  36,  40,  42,  48,  49,
         51,  54,  77,  80,  82,  83,  86,  94,  97,  99, 107],
       device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 1 / Loss : 0.21925915777683258 / AUROC : 0.9776251226692836 / F1 : 0.9771130295121462 / AP : 0.9810947250483836

Time index : 2 out of 15
tensor(0.3962, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 7, 13, 18, 28, 36, 38, 40, 42, 48, 49, 51, 53, 54, 77, 83, 86, 94],
       device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 2 / Loss : 0.222

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 1 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.4061, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 7, 11, 13, 19, 22, 28, 29, 30, 40, 42, 43, 46, 49, 50, 51, 52, 53, 55,
        59, 61, 68, 69, 73, 78], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 0 / Loss : 0.222495436668396 / AUROC : 0.9758016331408086 / F1 : 0.9752015511786917 / AP : 0.9800143501118522

Time index : 1 out of 15
tensor(0.3957, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  4,   5,   7,  11,  13,  22,  24,  27,  28,  29,  30,  40,  42,  43,
         48,  49,  50,  51,  52,  53,  55,  58,  59,  61,  65,  68,  69,  73,
         74,  78, 108, 115], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 1 / Loss : 0.22151842713356018 / AUROC : 0.9771622934888241 / F1 : 0.9766285430134262 / AP : 0.980622706941838

Time index : 2 out of 15
tensor(0.3899, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  18,
         19,  2

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 0 out of 8
Validation / Subject : 0 / Timepoint : 0 / Loss : 0.19115430116653442 / AP : 0.993305731502881 / F1 : 0.9921522464194624 / AUROC : 0.9922133540977225

Validation / Subject : 0 / Timepoint : 1 / Loss : 0.1923324316740036 / AP : 0.9933794309203073 / F1 : 0.9921614736429551 / AUROC : 0.9922224382656037

Validation / Subject : 0 / Timepoint : 2 / Loss : 0.1918681263923645 / AP : 0.9936024169154845 / F1 : 0.992444313610048 / AUROC : 0.9925009738994937

Validation / Subject : 0 / Timepoint : 3 / Loss : 0.19340823590755463 / AP : 0.992782100507876 / F1 : 0.9914378506052554 / AUROC : 0.9915105386416863

Validation / Subject : 0 / Timepoint : 4 / Loss : 0.19122196733951569 / AP : 0.993467764876766 / F1 : 0.9923167848699763 / AUROC : 0.992375366568915

Validation / Subject : 0 / Timepoint : 5 / Loss : 0.19472187757492065 / AP : 0.9908139590003616 / F1 : 0.9891421304900714 / AUROC : 0.9892587574995162

Validation / Subject : 0 / Timepoint : 6 / Loss : 0.19608

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 2 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3946, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   7,  11,  13,  19,  25,  27,  29,  36,  40,  42,  43,  44,  45,
         46,  47,  49,  51,  53,  55,  56,  57,  58,  59,  62,  63,  65,  67,
         69,  72,  77,  78,  81,  94,  97,  98,  99, 103, 104, 109, 111],
       device='cuda:0')
Epoch : 2 / Subject : 2 / Timepoint : 0 / Loss : 0.22022657096385956 / AUROC : 0.9770472670686636 / F1 : 0.9765080629106112 / AP : 0.9804250038461302

Time index : 1 out of 15
tensor(0.3916, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   7,  11,  13,  19,  29,  36,  40,  42,  43,  44,  45,  46,  49,
         57,  58,  59,  61,  62,  63,  67,  69,  73,  77,  78,  81,  94,  98,
         99, 104, 109, 111], device='cuda:0')
Epoch : 2 / Subject : 2 / Timepoint : 1 / Loss : 0.21849694848060608 / AUROC : 0.975609756097561 / F1 : 0.975 / AP : 0.9789817421290747

Time index : 2 out of 15
tensor(0

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 3 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3946, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  5,  21,  31,  35,  37,  40,  41,  52,  87,  88, 106],
       device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 0 / Loss : 0.2121940702199936 / AUROC : 0.9859957504346146 / F1 : 0.9857968459202664 / AP : 0.9879589053401026

Time index : 1 out of 15
tensor(0.3999, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  5,  12,  21,  25,  35,  37,  40,  41,  52,  53,  87, 106],
       device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 1 / Loss : 0.21318994462490082 / AUROC : 0.9851973684210527 / F1 : 0.9849749582637729 / AP : 0.9874611060557138

Time index : 2 out of 15
tensor(0.3979, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  5,  21,  37,  38,  40,  52, 106], device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 2 / Loss : 0.20851637423038483 / AUROC : 0.986364528153816 / F1 : 0.9861760318249627 / AP : 0.9882980527234025

Time index : 3 out of 15
tensor(0.3960, device='cuda:0', grad_fn=<MeanBackward0>)
tens

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 1 out of 8
Validation / Subject : 1 / Timepoint : 0 / Loss : 0.19600310921669006 / AP : 0.9902871999302749 / F1 : 0.9884467265725289 / AUROC : 0.9885786802030456

Validation / Subject : 1 / Timepoint : 1 / Loss : 0.19550782442092896 / AP : 0.9908138714044628 / F1 : 0.9890547263681592 / AUROC : 0.9891732283464567

Validation / Subject : 1 / Timepoint : 2 / Loss : 0.19564925134181976 / AP : 0.9903162306507456 / F1 : 0.988482922954726 / AUROC : 0.988614055751865

Validation / Subject : 1 / Timepoint : 3 / Loss : 0.1988307684659958 / AP : 0.9888181999866764 / F1 : 0.9865871833084948 / AUROC : 0.986764705882353

Validation / Subject : 1 / Timepoint : 4 / Loss : 0.197074294090271 / AP : 0.9891531841667807 / F1 : 0.9870780326555144 / AUROC : 0.9872428797468354

Validation / Subject : 1 / Timepoint : 5 / Loss : 0.198197141289711 / AP : 0.9874953660594398 / F1 : 0.9851703406813628 / AUROC : 0.985387045813586

Validation / Subject : 1 / Timepoint : 6 / Loss : 0.1958873

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 4 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3920, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   4,   6,   8,   9,  21,  23,  24,  28,  29,  30,  31,  32,  34,
         36,  40,  66,  73,  82, 101, 107], device='cuda:0')
Epoch : 2 / Subject : 4 / Timepoint : 0 / Loss : 0.2181396484375 / AUROC : 0.9785465790490916 / F1 : 0.9780762393837645 / AP : 0.9815371131741577

Time index : 1 out of 15
tensor(0.3927, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   8,  28,  29,  31,  34,  37,  40,  42,  43,  45,  46,  47,  48,
         55,  66,  82, 101, 107], device='cuda:0')
Epoch : 2 / Subject : 4 / Timepoint : 1 / Loss : 0.22028520703315735 / AUROC : 0.9779752704791345 / F1 : 0.9774792572105887 / AP : 0.9811466952294591

Time index : 2 out of 15
tensor(0.3958, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  8,  28,  29,  31,  40,  42,  43,  44,  45,  46,  48,  49,  51,  55,
         59,  63,  66,  80,  82, 101, 109, 113, 114

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 5 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3939, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,   4,   7,  13,  20,  24,  27,  36,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  53,  55,  58,  59,  60,
         61,  81,  83,  85,  87,  89,  91,  93,  94,  95,  97,  99, 101, 103,
        108, 111, 113], device='cuda:0')
Epoch : 2 / Subject : 5 / Timepoint : 0 / Loss : 0.21830277144908905 / AUROC : 0.9813850308641976 / F1 : 0.981031941031941 / AP : 0.9840701389825612

Time index : 1 out of 15
tensor(0.3952, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,  13,  24,  27,  39,  40,  41,  42,  43,  44,  45,  46,
         47,  48,  51,  53,  55,  59,  61,  83,  85,  87,  89,  94,  95,  97,
         99, 101, 103, 108, 111, 113], device='cuda:0')
Epoch : 2 / Subject : 5 / Timepoint : 1 / Loss : 0.21866823732852936 / AUROC : 0.9792430971229967 / F1 : 0.9788031154490782 / AP : 0.9823009926509558

Time index : 2 out of 15
tensor(0.3908, device='cuda:0', grad_fn=<MeanBac

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 2 out of 8
Validation / Subject : 2 / Timepoint : 0 / Loss : 0.19400925934314728 / AP : 0.990089040741316 / F1 : 0.9883408071748879 / AUROC : 0.9884751773049645

Validation / Subject : 2 / Timepoint : 1 / Loss : 0.19138582050800323 / AP : 0.9918242027940488 / F1 : 0.9904386397240019 / AUROC : 0.9905291935168912

Validation / Subject : 2 / Timepoint : 2 / Loss : 0.1931968778371811 / AP : 0.9913357869862789 / F1 : 0.9897610921501706 / AUROC : 0.9898648648648649

Validation / Subject : 2 / Timepoint : 3 / Loss : 0.1912250518798828 / AP : 0.9923332776492011 / F1 : 0.9910121141070731 / AUROC : 0.9910921766072811

Validation / Subject : 2 / Timepoint : 4 / Loss : 0.18946267664432526 / AP : 0.9936062476055096 / F1 : 0.9925211572525093 / AUROC : 0.9925766751318617

Validation / Subject : 2 / Timepoint : 5 / Loss : 0.19011004269123077 / AP : 0.9941732952859069 / F1 : 0.9931261496756705 / AUROC : 0.9931730769230769

Validation / Subject : 2 / Timepoint : 6 / Loss : 0.1

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 6 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.4052, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([16, 21, 38, 40, 56, 62, 79, 88], device='cuda:0')
Epoch : 2 / Subject : 6 / Timepoint : 0 / Loss : 0.22354845702648163 / AUROC : 0.9735897946980268 / F1 : 0.9728733749616133 / AP : 0.9781592480352865

Time index : 1 out of 15
tensor(0.4007, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 28,  40,  50,  62,  64,  80,  96, 113], device='cuda:0')
Epoch : 2 / Subject : 6 / Timepoint : 1 / Loss : 0.2221885770559311 / AUROC : 0.9730343737653102 / F1 : 0.9722870774540656 / AP : 0.9773976435599984

Time index : 2 out of 15
tensor(0.3969, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 10,  21,  28,  40,  41,  42,  44,  45,  50,  56,  62,  64,  77,  80,
         90,  92,  96, 113], device='cuda:0')
Epoch : 2 / Subject : 6 / Timepoint : 2 / Loss : 0.21994224190711975 / AUROC : 0.9764878048780488 / F1 : 0.9759216704965531 / AP : 0.9800573043285

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 7 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3940, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   6,   7,   8,  11,  12,  13,  15,  17,  19,  30,  32,  33,
         38,  40,  41,  42,  43,  45,  46,  47,  49,  51,  53,  54,  55,  57,
         61,  63,  65,  66,  67,  71,  72,  85,  86,  88,  91,  92,  93,  98,
         99, 100, 101, 103, 104, 105, 107, 111], device='cuda:0')
Epoch : 2 / Subject : 7 / Timepoint : 0 / Loss : 0.21983455121517181 / AUROC : 0.9759129759129759 / F1 : 0.9753184713375797 / AP : 0.9794336552657826

Time index : 1 out of 15
tensor(0.3911, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   6,   7,   8,  10,  11,  12,  13,  14,  19,  24,  30,  34,
         38,  40,  41,  42,  43,  46,  49,  50,  51,  52,  53,  54,  55,  57,
         61,  65,  67,  69,  70,  71,  77,  90,  92,  93,  98,  99, 100, 101,
        103, 104, 105, 108, 111], device='cuda:0')
Epoch : 2 / Subject : 7 / Timepoint : 1 / Loss : 0.2189498394727707 / AUROC : 0.9777581814618852 / F1 : 0.9772522296823329 / 

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 3 out of 8
Validation / Subject : 3 / Timepoint : 0 / Loss : 0.19626280665397644 / AP : 0.990866247659949 / F1 : 0.9890196078431372 / AUROC : 0.9891388673390225

Validation / Subject : 3 / Timepoint : 1 / Loss : 0.19505129754543304 / AP : 0.9896470523326254 / F1 : 0.9877366820366918 / AUROC : 0.9878852490792789

Validation / Subject : 3 / Timepoint : 2 / Loss : 0.19461430609226227 / AP : 0.990225056525223 / F1 : 0.9884117246080436 / AUROC : 0.988544474393531

Validation / Subject : 3 / Timepoint : 3 / Loss : 0.19470196962356567 / AP : 0.9895751943584538 / F1 : 0.9876856919468334 / AUROC : 0.9878354894767329

Validation / Subject : 3 / Timepoint : 4 / Loss : 0.1896277219057083 / AP : 0.9975415213085058 / F1 : 0.9971008697390783 / AUROC : 0.997109250398724

Validation / Subject : 3 / Timepoint : 5 / Loss : 0.19052934646606445 / AP : 0.9935317326050704 / F1 : 0.9923723841189126 / AUROC : 0.9924301242236024

Validation / Subject : 3 / Timepoint : 6 / Loss : 0.191

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 8 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3958, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,   7,  13,  17,  23,  38,  39,  40,  53,  63,  68,  76,
         81,  85,  93,  94,  95, 101, 109], device='cuda:0')
Epoch : 2 / Subject : 8 / Timepoint : 0 / Loss : 0.215034618973732 / AUROC : 0.9817120622568094 / F1 : 0.9813713832738803 / AP : 0.9843711997225323

Time index : 1 out of 15
tensor(0.3906, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  2,   3,   7,   9,  13,  16,  17,  19,  21,  22,  23,  24,  34,  38,
         39,  40,  42,  47,  50,  51,  52,  53,  55,  56,  58,  60,  63,  64,
         66,  68,  69,  78,  80,  81,  84,  87,  88,  89,  90,  91,  93,  94,
        107], device='cuda:0')
Epoch : 2 / Subject : 8 / Timepoint : 1 / Loss : 0.21629981696605682 / AUROC : 0.979469300794112 / F1 : 0.9790389559027091 / AP : 0.9822327450889381

Time index : 2 out of 15
tensor(0.3913, device='cuda:0', grad_fn=<MeanBack

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 9 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3920, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 18,  23,  32,  36,  39,  40,  52,  53,  54,  55,  72,  74,  87,  88,
        114, 115], device='cuda:0')
Epoch : 2 / Subject : 9 / Timepoint : 0 / Loss : 0.21777470409870148 / AUROC : 0.975896531452087 / F1 : 0.9753012048192771 / AP : 0.9792178120680536

Time index : 1 out of 15
tensor(0.3961, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 6,  7, 11, 17, 18, 19, 22, 23, 25, 30, 32, 33, 36, 39, 40, 52, 53, 54,
        55, 56, 62, 63, 68, 72, 74, 76, 78, 79, 84, 87, 88, 89, 97, 99],
       device='cuda:0')
Epoch : 2 / Subject : 9 / Timepoint : 1 / Loss : 0.22165536880493164 / AUROC : 0.9764224473889322 / F1 : 0.975853123129116 / AP : 0.9799992933415986

Time index : 2 out of 15
tensor(0.4014, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([16, 18, 19, 36, 40, 56, 62, 72, 78, 79, 80], device='cuda:0')
Epoch : 2 / Subject : 9 / Timepoint : 2 / Loss : 0.2177683413028717 / AUROC : 0.9806042884990254 / F1 : 0.980220654010

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 4 out of 8
Validation / Subject : 4 / Timepoint : 0 / Loss : 0.19649668037891388 / AP : 0.9910070463617028 / F1 : 0.9891165800568683 / AUROC : 0.9892337536372454

Validation / Subject : 4 / Timepoint : 1 / Loss : 0.19447015225887299 / AP : 0.9923776922734259 / F1 : 0.9908221050576059 / AUROC : 0.9909055727554179

Validation / Subject : 4 / Timepoint : 2 / Loss : 0.19380199909210205 / AP : 0.9922000690439426 / F1 : 0.990672557682867 / AUROC : 0.9907587548638133

Validation / Subject : 4 / Timepoint : 3 / Loss : 0.19270265102386475 / AP : 0.9921507988591319 / F1 : 0.9906778015148573 / AUROC : 0.9907639022512988

Validation / Subject : 4 / Timepoint : 4 / Loss : 0.19118377566337585 / AP : 0.9941759113986484 / F1 : 0.9930589500439926 / AUROC : 0.9931067961165049

Validation / Subject : 4 / Timepoint : 5 / Loss : 0.19122068583965302 / AP : 0.9927764177529014 / F1 : 0.9914730961481917 / AUROC : 0.9915451895043732

Validation / Subject : 4 / Timepoint : 6 / Loss : 0

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 10 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3943, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 35,  39,  40,  41,  78,  87,  95,  96, 109], device='cuda:0')
Epoch : 2 / Subject : 10 / Timepoint : 0 / Loss : 0.22015532851219177 / AUROC : 0.9769351055512119 / F1 : 0.976390556222489 / AP : 0.9803127881332361

Time index : 1 out of 15
tensor(0.3949, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40, 41, 64, 74, 75, 78], device='cuda:0')
Epoch : 2 / Subject : 10 / Timepoint : 1 / Loss : 0.2173086702823639 / AUROC : 0.9774818876052477 / F1 : 0.9769631410256411 / AP : 0.9807254239560151

Time index : 2 out of 15
tensor(0.3958, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40, 73, 75], device='cuda:0')
Epoch : 2 / Subject : 10 / Timepoint : 2 / Loss : 0.21675440669059753 / AUROC : 0.9774465581486567 / F1 : 0.9769261637239165 / AP : 0.9807068087776724

Time index : 3 out of 15
tensor(0.3974, device='cuda:0', grad_fn=<MeanBackward0>)

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 11 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3973, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 13,  29,  36,  37,  38,  39,  40,  42,  44,  45,  46,  48,  49,  53,
         54,  55,  58,  67,  72,  74,  77,  79,  80,  86,  89,  90,  91,  92,
         93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 108,
        109, 110, 111, 112, 113, 114, 115], device='cuda:0')
Epoch : 2 / Subject : 11 / Timepoint : 0 / Loss : 0.21534481644630432 / AUROC : 0.9792559407869108 / F1 : 0.9788165091994033 / AP : 0.9823081054594209

Time index : 1 out of 15
tensor(0.4008, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 13,  29,  36,  38,  39,  40,  42,  44,  45,  46,  48,  49,  53,  54,
         55,  58,  72,  73,  74,  77,  80,  81,  86,  89,  90,  91,  92,  93,
         94,  95,  96,  97,  99, 100, 101, 102, 103, 104, 105, 108, 109, 110,
        111, 112, 113, 114, 115], device='cuda:0')
Epoch : 2 / Subject : 11 / Timepoint : 1 / Loss : 0.2149111032485962 / AUROC : 0.981376200744952 / F1 : 0.9810227726727927 / AP :

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 5 out of 8
Validation / Subject : 5 / Timepoint : 0 / Loss : 0.193945050239563 / AP : 0.9916253191253191 / F1 : 0.9899970865300572 / AUROC : 0.9900961538461539

Validation / Subject : 5 / Timepoint : 1 / Loss : 0.192504420876503 / AP : 0.9910328825154404 / F1 : 0.9894324853228963 / AUROC : 0.9895429899302866

Validation / Subject : 5 / Timepoint : 2 / Loss : 0.193409264087677 / AP : 0.9905787645499011 / F1 : 0.9888658981180412 / AUROC : 0.9889885012668096

Validation / Subject : 5 / Timepoint : 3 / Loss : 0.19508890807628632 / AP : 0.9902642943718951 / F1 : 0.9883953581432573 / AUROC : 0.9885284810126582

Validation / Subject : 5 / Timepoint : 4 / Loss : 0.19760514795780182 / AP : 0.989664736293738 / F1 : 0.9875435974090683 / AUROC : 0.9876968503937008

Validation / Subject : 5 / Timepoint : 5 / Loss : 0.1951228678226471 / AP : 0.990665206164416 / F1 : 0.9888360353363751 / AUROC : 0.9889592933947773

Validation / Subject : 5 / Timepoint : 6 / Loss : 0.1961227

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 12 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3983, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  13,  17,  21,  22,  23,  24,  26,  27,  30,  31,  32,  35,  36,
         40,  43,  44,  45,  46,  47,  57,  64,  65,  66,  71,  72,  78,  79,
         81,  82,  83,  84,  86,  87,  94,  96,  97, 101, 104, 114, 115],
       device='cuda:0')
Epoch : 2 / Subject : 12 / Timepoint : 0 / Loss : 0.21704375743865967 / AUROC : 0.9806613807774125 / F1 : 0.9802800236639716 / AP : 0.9836168201881992

Time index : 1 out of 15
tensor(0.3957, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  16,  17,  18,  21,  22,  23,  24,  27,  31,  32,  35,  36,  40,
         43,  56,  57,  64,  65,  69,  71,  72,  77,  78,  79,  80,  81,  82,
         84,  87,  94,  95,  96, 114, 115], device='cuda:0')
Epoch : 2 / Subject : 12 / Timepoint : 1 / Loss : 0.2177072912454605 / AUROC : 0.9807095773555642 / F1 : 0.9803301373925076 / AP : 0.983552399066595

Tim

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 13 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.3990, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,   3,  16,  17,  18,  28,  30,  32,  33,  34,  35,  38,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  54,  55,  56,
         57,  58,  59,  62,  64,  66,  67,  73,  76,  77,  79,  80,  81,  85,
         89,  92,  93,  96,  98,  99, 100, 101, 102, 109, 110],
       device='cuda:0')
Epoch : 2 / Subject : 13 / Timepoint : 0 / Loss : 0.21501867473125458 / AUROC : 0.980284989264103 / F1 : 0.9798884906411788 / AP : 0.9832768130605802

Time index : 1 out of 15
tensor(0.3967, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,   2,  16,  17,  26,  28,  30,  31,  32,  33,  35,  38,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,
         57,  58,  59,  62,  64,  66,  67,  72,  76,  80,  81,  85,  93,  96,
         98,  99, 100, 101, 109], device='cuda:0')
Epoch : 2 / Subject : 13 / Timepoint : 1 / Loss : 0.21288280189037323 / AUROC : 0.9816118935837246 / F1

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 6 out of 8
Validation / Subject : 6 / Timepoint : 0 / Loss : 0.20351853966712952 / AP : 0.9875703513075917 / F1 : 0.9846462619167085 / AUROC : 0.984878434473216

Validation / Subject : 6 / Timepoint : 1 / Loss : 0.20105111598968506 / AP : 0.988505078700473 / F1 : 0.9859154929577465 / AUROC : 0.9861111111111112

Validation / Subject : 6 / Timepoint : 2 / Loss : 0.19896510243415833 / AP : 0.9879904968213331 / F1 : 0.9855511771243811 / AUROC : 0.9857569721115538

Validation / Subject : 6 / Timepoint : 3 / Loss : 0.19582895934581757 / AP : 0.9879247921310684 / F1 : 0.985758390598546 / AUROC : 0.9859583660644148

Validation / Subject : 6 / Timepoint : 4 / Loss : 0.19573001563549042 / AP : 0.9882668495548212 / F1 : 0.9861276258422513 / AUROC : 0.9863174354964817

Validation / Subject : 6 / Timepoint : 5 / Loss : 0.19286830723285675 / AP : 0.9909915863672758 / F1 : 0.9893419380072358 / AUROC : 0.989454334365325

Validation / Subject : 6 / Timepoint : 6 / Loss : 0.19

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))



Subject index : 14 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3922, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  4,  19,  21,  28,  29,  36,  40,  41,  46,  47,  49,  51,  53,  55,
         56,  57,  58,  59,  62,  66,  67,  69,  72,  73,  74,  75,  76,  79,
         83,  95,  96,  98, 106, 110, 111, 113], device='cuda:0')
Epoch : 2 / Subject : 14 / Timepoint : 0 / Loss : 0.21574373543262482 / AUROC : 0.978273577552611 / F1 : 0.9777910566676625 / AP : 0.981231680232519

Time index : 1 out of 15
tensor(0.3939, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 17,  21,  25,  36,  40,  41,  46,  47,  53,  54,  55,  56,  57,  58,
         59,  66,  67,  68,  69,  72,  73,  74,  75,  76,  79,  80,  81,  83,
         95,  96,  97,  98, 110, 111, 113], device='cuda:0')
Epoch : 2 / Subject : 14 / Timepoint : 1 / Loss : 0.22093328833580017 / AUROC : 0.9730994152046784 / F1 : 0.9723557692307693 / AP : 0.9770264859438145

Time index : 2 out of 15
tensor(0.3

  checkpoint=torch.load('checkpoint8.pth', map_location='cuda')



Subject index : 15 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train'+str(i)+'.pth'))


tensor(0.4114, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([36, 40, 86], device='cuda:0')
Epoch : 2 / Subject : 15 / Timepoint : 0 / Loss : 0.21998989582061768 / AUROC : 0.9799404170804369 / F1 : 0.9795297932711796 / AP : 0.9835981055005475

Time index : 1 out of 15
tensor(0.4106, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 2 / Subject : 15 / Timepoint : 1 / Loss : 0.22080931067466736 / AUROC : 0.9758497316636852 / F1 : 0.9752520623281393 / AP : 0.9802306362755915

Time index : 2 out of 15
tensor(0.4094, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 2 / Subject : 15 / Timepoint : 2 / Loss : 0.22485992312431335 / AUROC : 0.9735486427580742 / F1 : 0.9728299582782131 / AP : 0.9784316359445017

Time index : 3 out of 15
tensor(0.4031, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Epoch : 2 / Subject : 15 / Timepoint : 3 / Loss : 0.22104163467884064 / AUROC : 0.9747364233141038 / F1 : 0.97408

  valid3D=torch.load(os.path.join('DatasetSplit','valid'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 7 out of 8
Validation / Subject : 7 / Timepoint : 0 / Loss : 0.196870818734169 / AP : 0.9920491237060474 / F1 : 0.9902719872940242 / AUROC : 0.9903657097915848

Validation / Subject : 7 / Timepoint : 1 / Loss : 0.19023564457893372 / AP : 0.9953721317271603 / F1 : 0.9944510503369005 / AUROC : 0.994481671265274

Validation / Subject : 7 / Timepoint : 2 / Loss : 0.18991710245609283 / AP : 0.9963322371343107 / F1 : 0.995583472372166 / AUROC : 0.9956028923197187

Validation / Subject : 7 / Timepoint : 3 / Loss : 0.19373288750648499 / AP : 0.9925073595979679 / F1 : 0.9909892068521636 / AUROC : 0.9910696761530913

Validation / Subject : 7 / Timepoint : 4 / Loss : 0.19824828207492828 / AP : 0.9900153042974286 / F1 : 0.9878461844989042 / AUROC : 0.9879921259842519

Validation / Subject : 7 / Timepoint : 5 / Loss : 0.1974443793296814 / AP : 0.9903280633412573 / F1 : 0.9882633777600954 / AUROC : 0.988399528116398

Validation / Subject : 7 / Timepoint : 6 / Loss : 0.2007

Exception: Early Stopping

This cell is the training and validation loop for the 80th Quantile Dataset experiment.\
The model will be saved in "checkpoint9.pth" file.

In [None]:
# 80th Quantile Dataset Experiment
torch.autograd.set_detect_anomaly(False)
if os.path.isfile('checkpoint9.pth'):
        print('resuming checkpoint experiment')
        checkpoint = torch.load('checkpoint9.pth', map_location='cuda')
else:
  checkpoint = {
    'epoch': 0,
    'subject': 0,
    'timeIdx':0,
    'encoder': None,
    'decoder': None,
    'optimizerE': None,
    'lossSubject': None,
    'prevLoss' : np.Inf,
    'patience': 0,
    'lossSubjectV': np.Inf,
    'lr': 1e-4
    }

# Argv setting
epc=10
sampleStrategy='uniform'
negSampleStrategy='inductive'
N=116 # Number of ROIs
numEdges=math.ceil(116*116*0.4)
nodeDim=N # One-Hot encoding as a node embedding
edgeDim=4 # Arbitrarily set
timeDim=100 # Half of the value mentioned in HOT encoder paper
channelDim=50 # The value mentioned in HOT encoder paper
latentDim=N # Positive node embedding and negative node embedding concatenated across dim=1
numFilter=4 # Following Triadic Decoder paper
B=10000
prob=0.7398
nodeFeat=torch.cat((torch.eye(N),torch.zeros((1,N))),dim=0).numpy()
edgeFeat=torch.zeros((numEdges+1,edgeDim)).numpy()
learning_rateE=1e-4
pt=2
patience=0
prevLoss=np.inf
lossSubjectV=np.inf


encoder=HOT(nodeFeat, edgeFeat, None, timeDim, channelDim, patch_size=8, num_state_vectors=32, num_layers=2, num2hop=0, dropout=0.1, max_input_sequence_length=4096, device='cuda')
encoder.apply(init_kaiming_normal)
decoder=TriadicDecoder(latentDim, numFilter, device='cuda')
decoder.apply(init_kaiming_normal)
model=nn.Sequential(encoder, decoder).to('cuda')
lossFunc=nn.BCELoss().to('cuda')
optimizerE=torch.optim.Adam(model.parameters(), lr=checkpoint['lr'])


for epoch in range(checkpoint['epoch'], epc):
  print("Epoch : "+str(epoch))
  if os.path.isfile('checkpoint9.pth'):
    checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  for i in range(checkpoint['subject'], 16):
    if os.path.isfile('checkpoint9.pth'):
      checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
    model.train()
    if checkpoint['lr']!=0: learning_rateE=checkpoint['lr']
    if checkpoint['optimizerE'] is not None: optimizerE.load_state_dict(checkpoint['optimizerE'])
    if checkpoint['prevLoss'] is not None: prevLoss=checkpoint['prevLoss']
    if checkpoint['patience'] != 0: patience=checkpoint['patience']
    if checkpoint['lossSubjectV'] is not None: lossSubjectV=checkpoint['lossSubjectV']
    if checkpoint['encoder'] is not None: model[0].load_state_dict(checkpoint['encoder'])
    if checkpoint['decoder'] is not None: model[1].load_state_dict(checkpoint['decoder'])
    print("\nSubject index : "+str(i)+" out of 16")

    train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))
    train3D=train3D.reshape(-1,3).detach().to('cpu').numpy()
    train3D=train3D.reshape(15, -1, 3)

    print("Sampling probability : "+str(prob))

    lossSubject=torch.zeros(1).to('cuda')
    lossSubject.requires_grad_()

    for t in range(checkpoint['timeIdx'], 15):
      print("Time index : "+str(t)+" out of 15")

      timeBatch=train3D[t]
      srcNodes=timeBatch[:,0]
      dstNodes=timeBatch[:,1]
      timepoints=timeBatch[:,2]

      adjMat=torch.zeros((N,N)).to('cuda')
      for row in timeBatch:
        adjMat[int(row[0]), int(row[1])]=1

      nodePosEmbedding=torch.zeros((N,N)).to('cuda')
      nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
      nodePosTimes=torch.zeros(N).to('cuda')
      nodeNegTimes=torch.zeros(N).to('cuda')
      nodeEmbedding=torch.zeros((N,N)).to('cuda')

      # Sample neighbor edges among the dynamic graphs of a single subject
      train_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=0)
      # Sample negative edges among the dynamic graphs of a single subject
      train_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints), negative_sample_strategy=negSampleStrategy, seed=1)
      model[0].set_neighbor_sampler(train_neighbor_sampler)

      subject_src_node_embeddings, subject_dst_node_embeddings = \
        torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, srcNodes, dstNodes, timepoints, use_reentrant=False)
      negSrcNodes, negDstNodes = train_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
      subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
        torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, negSrcNodes, negDstNodes, timepoints, use_reentrant=False)

      for tb, row in enumerate(zip(srcNodes, dstNodes)):
        nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
        nodePosTimes[row[0]]+=1
        nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
        nodePosTimes[row[1]]+=1
      for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
        nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
        nodeNegTimes[row[0]]+=1
        nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
        nodeNegTimes[row[1]]+=1

      for node in range(N):
        if nodePosTimes[node]>0:
          nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
        if nodeNegTimes[node]>0:
          nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]


      nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
      nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

      graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
      graph2D=torch.tensor(graph2D).to('cuda')
      graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
      graph2DN=torch.tensor(graph2DN).to('cuda')
      triads=triadSample(graph2D, B, prob, device='cuda')

      mSquare=torch.zeros((N,N)).to('cuda')
      eSquare=torch.zeros((N,N)).to('cuda')

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)

        if mSquare[triad[0], triad[1]]==0:
          pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
          M1=calculateM(triads, pair1)
          mSquare[triad[0], triad[1]]=M1
        if mSquare[triad[0], triad[2]]==0:
          pair2=torch.cat((triad0, triad2), dim=1).to('cuda')
          M2=calculateM(triads, pair2)
          mSquare[triad[0], triad[2]]=M2
        if mSquare[triad[1], triad[2]]==0:
          pair3=torch.cat((triad1, triad2), dim=1).to('cuda')
          M3=calculateM(triads, pair3)
          mSquare[triad[1], triad[2]]=M3

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)
        pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
        pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
        pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

        emb1=torch.zeros((1,N)).to('cuda')
        emb2=torch.zeros((1,N)).to('cuda')
        emb3=torch.zeros((1,N)).to('cuda')

        if pairIn(pair1,graph2D[:,0:2]):
          emb1=nodePosEmbedding[triad[0]]
        if pairIn(pair1,graph2DN):
          emb1=nodeNegEmbedding[triad[0]]
        if pairIn(pair2,graph2D[:,0:2]):
          emb2=nodePosEmbedding[triad[1]]
        if pairIn(pair2,graph2DN):
          emb2=nodeNegEmbedding[triad[1]]
        if pairIn(pair3,graph2D[:,0:2]):
          emb3=nodePosEmbedding[triad[2]]
        if pairIn(pair3,graph2DN):
          emb3=nodeNegEmbedding[triad[2]]

        eTriplet=model[1].forward(emb1, emb2, emb3)
        if pairIn(pair1,graph2D[:,0:2]):
          eSquare[triad[0], triad[1]]+=eTriplet[0]
        if pairIn(pair1,graph2DN):
          eSquare[triad[0], triad[1]]-=eTriplet[0]
        if pairIn(pair2,graph2D[:,0:2]):
          eSquare[triad[0], triad[2]]+=eTriplet[1]
        if pairIn(pair2,graph2DN):
          eSquare[triad[0], triad[2]]-=eTriplet[1]
        if pairIn(pair3,graph2D[:,0:2]):
          eSquare[triad[1], triad[2]]+=eTriplet[2]
        if pairIn(pair3,graph2DN):
          eSquare[triad[1], triad[2]]-=eTriplet[2]
      eSquare.clamp_(min=0.0)

      nonzeros=((mSquare>0).nonzero())
      adjRecon=[]
      adjPart=[]
      adjFull=torch.zeros((N,N)).to('cuda')
      for nz in nonzeros:
        nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
        em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
        if pairIn(nzPair,graph2D[:,0:2]):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([1.0]))
        if pairIn(nzPair, graph2DN):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([0.0]))
        if em1>0.5:
          adjFull[int(nz[0]), int(nz[1])]=1
        else:
          adjFull[int(nz[0]), int(nz[1])]=0
      adjRecon=torch.stack(adjRecon).to('cuda')
      print(torch.mean(adjRecon))
      adjPart=torch.stack(adjPart).squeeze(1).to('cuda')



      print(torch.nonzero(adjFull[40])[:,0])
      loss=lossFunc(input=adjRecon, target=adjPart)
      lossSubject=lossSubject+loss

      adjReconC=adjRecon.clone().detach().cpu().numpy()
      adjPartC=adjPart.clone().detach().cpu().numpy()


      adjReconC[adjReconC>0.5]=1
      adjReconC[adjReconC<=0.5]=0
      F1=f1_score(adjPartC, adjReconC)
      AUROC=roc_auc_score(adjPartC, adjReconC)
      AP=average_precision_score(adjPartC, adjReconC)


      print("Epoch : "+str(epoch)+" / Subject : "+str(i)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AUROC : "+str(AUROC)+" / F1 : "+str(F1)+" / AP : "+str(AP)+"\n")

      torch.cuda.empty_cache()

    optimizerE.zero_grad()
    lossSubject.backward()
    optimizerE.step()

    # Validation Step
    if (i+1)%2==0:
      model.eval()
      with torch.no_grad():
        patience=checkpoint['patience']
        prevLoss=checkpoint['prevLoss']
        valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))
        valid3D=valid3D.reshape(-1, 3).detach().to('cpu').numpy()
        valid3D=valid3D.reshape(15, -1, 3)
        lossSubjectV=0

        print("Validation Subject index : "+str((i+1)//2-1)+" out of 8")
        for t in range(15):
          adjRecon=torch.zeros((N,N)).to('cuda')

          timeBatch=valid3D[t]
          srcNodes=timeBatch[:,0]
          dstNodes=timeBatch[:,1]
          timepoints=timeBatch[:,2]

          adjMat=torch.zeros((N,N)).to('cuda')
          for row in timeBatch:
            adjMat[int(row[0]), int(row[1])]=1

          nodePosEmbedding=torch.zeros((N,N)).to('cuda')
          nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
          nodePosTimes=torch.zeros(N).to('cuda')
          nodeNegTimes=torch.zeros(N).to('cuda')
          nodeEmbedding=torch.zeros((N,N)).to('cuda')

          # Sample neighbor edges among the dynamic graphs of a single subject
          valid_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=2)
          # Sample negative edges among the dynamic graphs of a single subject
          valid_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints), negative_sample_strategy=negSampleStrategy, seed=3)
          model[0].set_neighbor_sampler(valid_neighbor_sampler)

          subject_src_node_embeddings, subject_dst_node_embeddings = \
            model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=srcNodes, dst_node_ids=dstNodes, node_interact_times=timepoints)
          negSrcNodes, negDstNodes = valid_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
          subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
            model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=negSrcNodes,
                                                                          dst_node_ids=negDstNodes,
                                                                          node_interact_times=timepoints)

          for tb, row in enumerate(zip(srcNodes, dstNodes)):
            nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
            nodePosTimes[row[0]]+=1
            nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
            nodePosTimes[row[1]]+=1
          for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
            nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
            nodeNegTimes[row[0]]+=1
            nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
            nodeNegTimes[row[1]]+=1

          for node in range(N):
            if nodePosTimes[node]>0:
              nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
            if nodeNegTimes[node]>0:
              nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]

          nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
          nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

          graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
          graph2D=torch.tensor(graph2D).to('cuda')
          graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
          graph2DN=torch.tensor(graph2DN).to('cuda')

          triads=triadSample(graph2D, B, prob, device='cuda')
          mSquare=torch.zeros((N,N)).to('cuda')
          eSquare=torch.zeros((N,N)).to('cuda')

          for j, triad in enumerate(triads):
            triad0=triad[0].unsqueeze(0).unsqueeze(0)
            triad1=triad[1].unsqueeze(0).unsqueeze(0)
            triad2=triad[2].unsqueeze(0).unsqueeze(0)

            if mSquare[triad[0], triad[1]]==0:
              pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
              M1=calculateM(triads, pair1)
              mSquare[triad[0], triad[1]]=M1

            if mSquare[triad[0], triad[2]]==0:
              pair1=torch.cat((triad0, triad2), dim=1).to('cuda')
              M2=calculateM(triads, pair1)
              mSquare[triad[0], triad[2]]=M2

            if mSquare[triad[1], triad[2]]==0:
              pair1=torch.cat((triad1, triad2), dim=1).to('cuda')
              M3=calculateM(triads, pair1)
              mSquare[triad[1], triad[2]]=M3

          for j, triad in enumerate(triads):
            triad0=triad[0].unsqueeze(0).unsqueeze(0)
            triad1=triad[1].unsqueeze(0).unsqueeze(0)
            triad2=triad[2].unsqueeze(0).unsqueeze(0)
            pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
            pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
            pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

            if pair1 in graph2D[:,0:2]:
              emb1=nodePosEmbedding[triad[0]]
            if pair1 in graph2DN:
              emb1=nodeNegEmbedding[triad[0]]
            if pair2 in graph2D[:,0:2]:
              emb2=nodePosEmbedding[triad[1]]
            if pair2 in graph2DN:
              emb2=nodeNegEmbedding[triad[1]]
            if pair3 in graph2D[:,0:2]:
              emb3=nodePosEmbedding[triad[2]]
            if pair3 in graph2DN:
              emb3=nodeNegEmbedding[triad[2]]

            eTriplet=model[1].forward(emb1, emb2, emb3)
            if pairIn(pair1,graph2D[:,0:2]):
              eSquare[triad[0], triad[1]]+=eTriplet[0]
            if pairIn(pair1,graph2DN):
              eSquare[triad[0], triad[1]]-=eTriplet[0]
            if pairIn(pair2,graph2D[:,0:2]):
              eSquare[triad[0], triad[2]]+=eTriplet[1]
            if pairIn(pair2,graph2DN):
              eSquare[triad[0], triad[2]]-=eTriplet[1]
            if pairIn(pair3,graph2D[:,0:2]):
              eSquare[triad[1], triad[2]]+=eTriplet[2]
            if pairIn(pair3,graph2DN):
              eSquare[triad[1], triad[2]]-=eTriplet[2]

          eSquare.clamp_(min=0.0)
          nonzeros=((mSquare>0).nonzero())
          adjRecon=[]
          adjPart=[]
          adjFull=torch.zeros((N,N)).to('cuda')
          for nz in nonzeros:
            nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
            em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
            if pairIn(nzPair, graph2D[:,0:2]):
              adjRecon.append(em1)
              adjPart.append(torch.tensor([1.0]))
            if pairIn(nzPair,graph2DN):
              adjRecon.append(em1)
              adjPart.append(torch.tensor([0.0]))
            if em1>0.5:
              adjFull[int(nz[0]), int(nz[1])]=1
            else:
              adjFull[int(nz[0]), int(nz[1])]=0
          adjRecon=torch.stack(adjRecon).to('cuda')
          adjPart=torch.stack(adjPart).squeeze(1).to('cuda')


          loss=lossFunc(input=adjRecon, target=adjPart)
          lossSubjectV=lossSubjectV+loss

          adjReconC=adjRecon.clone().detach().cpu().numpy()
          adjPartC=adjPart.clone().detach().cpu().numpy()

          adjReconC[adjReconC>0.5]=1
          adjReconC[adjReconC<=0.5]=0
          F1=f1_score(adjPartC, adjReconC)
          AUROC=roc_auc_score(adjPartC, adjReconC)
          AP=average_precision_score(adjPartC, adjReconC)

          print("Validation / Subject : "+str((i+1)//2-1)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AP : "+str(AP)+" / F1 : "+str(F1)+" / AUROC : "+str(AUROC)+'\n')


        if lossSubjectV>prevLoss:
          patience+=1
        else:
          patience=0

        if patience==1 and learning_rateE<=2.5e-5:
          torch.save({
          'epoch': epoch,
          'encoder': model[0].state_dict(),
          'decoder': model[1].state_dict(),
          'subject': i,
          'timeIdx': t,
          'optimizerE': optimizerE.state_dict(),
          'lossSubject': lossSubject,
          'patience': patience,
          'prevLoss': prevLoss,
          'lossSubjectV': lossSubjectV,
          'lr': learning_rateE
          },
          'checkpoint9.pth')
          raise Exception('Early Stopping')
        if patience==pt and learning_rateE>2.5e-5:
          learning_rateE=learning_rateE/2
          patience=0

    if lossSubjectV is not None:
      prevLoss=lossSubjectV
    torch.save({
        'epoch': epoch,
        'encoder': model[0].state_dict(),
        'decoder': model[1].state_dict(),
        'subject': i+1,
        'timeIdx': 0,
        'optimizerE': optimizerE.state_dict(),
        'lossSubject': None,
        'patience': patience,
        'prevLoss': prevLoss,
        'lossSubjectV': lossSubjectV,
        'lr': learning_rateE
        },
        'checkpoint9.pth')

    torch.cuda.empty_cache()

  torch.save({
              'epoch': epoch+1,
              'encoder': model[0].state_dict(),
              'decoder': model[1].state_dict(),
              'subject': 0,
              'timeIdx': 0,
              'optimizerE': optimizerE.state_dict(),
              'lossSubject': None,
              'patience': patience,
              'prevLoss': prevLoss,
              'lossSubjectV': lossSubjectV,
              'lr': learning_rateE
              },
              'checkpoint9.pth')


resuming checkpoint experiment


  checkpoint = torch.load('checkpoint9.pth', map_location='cuda')


Epoch : 0


  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 9 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3664, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40, 76, 87], device='cuda:0')
Epoch : 0 / Subject : 9 / Timepoint : 0 / Loss : 0.26974180340766907 / AUROC : 0.951833073322933 / F1 : 0.9493956156525302 / AP : 0.9590845895446767

Time index : 1 out of 15
tensor(0.3630, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  2, 35, 40], device='cuda:0')
Epoch : 0 / Subject : 9 / Timepoint : 1 / Loss : 0.2738535702228546 / AUROC : 0.9512980675385516 / F1 : 0.9488047604391094 / AP : 0.9585441113076224

Time index : 2 out of 15
tensor(0.3620, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([1], device='cuda:0')
Epoch : 0 / Subject : 9 / Timepoint : 2 / Loss : 0.26985132694244385 / AUROC : 0.9547733847637416 / F1 : 0.9526310473689527 / AP : 0.961227210298833

Time index : 3 out of 15
tensor(0.3603, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 0 / Subject : 9 / Timepoint : 3 / L

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 4 out of 8
Validation / Subject : 4 / Timepoint : 0 / Loss : 0.24744191765785217 / AP : 0.9791541753523715 / F1 : 0.975141471301536 / AUROC : 0.975744429106685

Validation / Subject : 4 / Timepoint : 1 / Loss : 0.25276345014572144 / AP : 0.9750424945171867 / F1 : 0.9702153350774804 / AUROC : 0.971076802814149

Validation / Subject : 4 / Timepoint : 2 / Loss : 0.26036912202835083 / AP : 0.9727497304213747 / F1 : 0.9673108552631579 / AUROC : 0.9683456101931116

Validation / Subject : 4 / Timepoint : 3 / Loss : 0.26111656427383423 / AP : 0.9680558485895097 / F1 : 0.9618542771021723 / AUROC : 0.9632559070885063

Validation / Subject : 4 / Timepoint : 4 / Loss : 0.25488078594207764 / AP : 0.9708808918081931 / F1 : 0.9648425557933353 / AUROC : 0.9660366213821618

Validation / Subject : 4 / Timepoint : 5 / Loss : 0.2527993321418762 / AP : 0.9701400420605364 / F1 : 0.9641352378994964 / AUROC : 0.9653769841269841

Validation / Subject : 4 / Timepoint : 6 / Loss : 0.25

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 10 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3926, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,  20,  40,  94, 113], device='cuda:0')
Epoch : 0 / Subject : 10 / Timepoint : 0 / Loss : 0.2563875913619995 / AUROC : 0.9516096780643871 / F1 : 0.9491489808783358 / AP : 0.9606469783404497

Time index : 1 out of 15
tensor(0.3919, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40,  74,  83, 111], device='cuda:0')
Epoch : 0 / Subject : 10 / Timepoint : 1 / Loss : 0.2519092559814453 / AUROC : 0.9546460176991151 / F1 : 0.9524913093858632 / AP : 0.9627533914911651

Time index : 2 out of 15
tensor(0.3874, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,  11,  20,  39,  40,  75,  83, 107, 108, 111, 113, 114],
       device='cuda:0')
Epoch : 0 / Subject : 10 / Timepoint : 2 / Loss : 0.2504679560661316 / AUROC : 0.9537848605577689 / F1 : 0.9515455304928989 / AP : 0.9615106650820573

Time index : 3 out of 15
tensor(0.3862, d

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 11 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3802, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1,  4, 40, 41, 74], device='cuda:0')
Epoch : 0 / Subject : 11 / Timepoint : 0 / Loss : 0.25270092487335205 / AUROC : 0.9536908767069068 / F1 : 0.9514422079269558 / AP : 0.9608980484395425

Time index : 1 out of 15
tensor(0.3788, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40, 74], device='cuda:0')
Epoch : 0 / Subject : 11 / Timepoint : 1 / Loss : 0.25470054149627686 / AUROC : 0.9530689517736151 / F1 : 0.9507579717720858 / AP : 0.9603693370532749

Time index : 2 out of 15
tensor(0.3792, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 0 / Subject : 11 / Timepoint : 2 / Loss : 0.25568097829818726 / AUROC : 0.9574593217016272 / F1 : 0.9555692055692055 / AP : 0.964225619090598

Time index : 3 out of 15
tensor(0.3759, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 0 / Subject : 1

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 5 out of 8
Validation / Subject : 5 / Timepoint : 0 / Loss : 0.24312525987625122 / AP : 0.969952824397061 / F1 : 0.9632999696693965 / AUROC : 0.9645991808074897

Validation / Subject : 5 / Timepoint : 1 / Loss : 0.24242521822452545 / AP : 0.9697708484067052 / F1 : 0.9632270551084853 / AUROC : 0.9645313421104342

Validation / Subject : 5 / Timepoint : 2 / Loss : 0.23978766798973083 / AP : 0.969004759193091 / F1 : 0.9626888683317916 / AUROC : 0.9640309155766944

Validation / Subject : 5 / Timepoint : 3 / Loss : 0.2376752346754074 / AP : 0.9718494195524873 / F1 : 0.9661532225374949 / AUROC : 0.9672613213095471

Validation / Subject : 5 / Timepoint : 4 / Loss : 0.24031265079975128 / AP : 0.9694137367861366 / F1 : 0.9632277834525026 / AUROC : 0.9645320197044335

Validation / Subject : 5 / Timepoint : 5 / Loss : 0.2365896999835968 / AP : 0.9769456828375124 / F1 : 0.9722138629752232 / AUROC : 0.9729650595354284

Validation / Subject : 5 / Timepoint : 6 / Loss : 0.23

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 12 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3741, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40, 64], device='cuda:0')
Epoch : 0 / Subject : 12 / Timepoint : 0 / Loss : 0.2599889934062958 / AUROC : 0.9442501942501942 / F1 : 0.9409586504834396 / AP : 0.9524843367781725

Time index : 1 out of 15
tensor(0.3775, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 2, 20, 40, 64, 95], device='cuda:0')
Epoch : 0 / Subject : 12 / Timepoint : 1 / Loss : 0.26153457164764404 / AUROC : 0.9436467348544453 / F1 : 0.9402813965607087 / AP : 0.9524885999853726

Time index : 2 out of 15
tensor(0.3765, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 0 / Subject : 12 / Timepoint : 2 / Loss : 0.25207212567329407 / AUROC : 0.9559308872063677 / F1 : 0.9538992688870837 / AP : 0.9624016831291998

Time index : 3 out of 15
tensor(0.3744, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
E

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 13 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3798, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  20,  40, 106], device='cuda:0')
Epoch : 0 / Subject : 13 / Timepoint : 0 / Loss : 0.24550174176692963 / AUROC : 0.9548987523010841 / F1 : 0.952768555210453 / AP : 0.9613282360814858

Time index : 1 out of 15
tensor(0.3859, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1, 106, 108], device='cuda:0')
Epoch : 0 / Subject : 13 / Timepoint : 1 / Loss : 0.24572156369686127 / AUROC : 0.955492615820352 / F1 : 0.9534194367986449 / AP : 0.9624713069975961

Time index : 2 out of 15
tensor(0.3805, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  94,  95, 108, 113], device='cuda:0')
Epoch : 0 / Subject : 13 / Timepoint : 2 / Loss : 0.24497410655021667 / AUROC : 0.9526163988463123 / F1 : 0.9502595155709342 / AP : 0.9594525053776789

Time index : 3 out of 15
tensor(0.3812, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([], device='cuda:0', dtype=torch.

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 6 out of 8
Validation / Subject : 6 / Timepoint : 0 / Loss : 0.2312721610069275 / AP : 0.9679873859378779 / F1 : 0.9612 / AUROC : 0.9626492106276473

Validation / Subject : 6 / Timepoint : 1 / Loss : 0.2278071492910385 / AP : 0.9705941429317478 / F1 : 0.9647058823529412 / AUROC : 0.9659090909090908

Validation / Subject : 6 / Timepoint : 2 / Loss : 0.23792214691638947 / AP : 0.9657696352186451 / F1 : 0.9578620758682814 / AUROC : 0.9595658855167115

Validation / Subject : 6 / Timepoint : 3 / Loss : 0.22627979516983032 / AP : 0.9750968071922096 / F1 : 0.9704178103080208 / AUROC : 0.9712677725118484

Validation / Subject : 6 / Timepoint : 4 / Loss : 0.22027942538261414 / AP : 0.9798738384994632 / F1 : 0.9761760549162124 / AUROC : 0.9767304279234865

Validation / Subject : 6 / Timepoint : 5 / Loss : 0.2266000658273697 / AP : 0.9741343079031521 / F1 : 0.969125569830087 / AUROC : 0.9700502512562814

Validation / Subject : 6 / Timepoint : 6 / Loss : 0.23798909783363

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 14 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3791, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 0 / Subject : 14 / Timepoint : 0 / Loss : 0.24421565234661102 / AUROC : 0.9579393223010244 / F1 : 0.9560925449871466 / AP : 0.9638723503455578

Time index : 1 out of 15
tensor(0.3753, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1, 108], device='cuda:0')
Epoch : 0 / Subject : 14 / Timepoint : 1 / Loss : 0.2543909251689911 / AUROC : 0.945777257655549 / F1 : 0.9426685914621571 / AP : 0.9535873168284472

Time index : 2 out of 15
tensor(0.3737, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 2, 69], device='cuda:0')
Epoch : 0 / Subject : 14 / Timepoint : 2 / Loss : 0.25357696413993835 / AUROC : 0.9468085106382979 / F1 : 0.9438202247191011 / AP : 0.9542094808816226

Time index : 3 out of 15
tensor(0.3703, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1,  2, 40], device='cuda:0')
Epoch : 0 / 

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 15 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3809, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 0 / Subject : 15 / Timepoint : 0 / Loss : 0.23611126840114594 / AUROC : 0.9670086501709918 / F1 : 0.9658830871645517 / AP : 0.971563087887771

Time index : 1 out of 15
tensor(0.3856, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Epoch : 0 / Subject : 15 / Timepoint : 1 / Loss : 0.23420429229736328 / AUROC : 0.9692802808660035 / F1 : 0.9683066706912165 / AP : 0.973756092458264

Time index : 2 out of 15
tensor(0.3803, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 0 / Subject : 15 / Timepoint : 2 / Loss : 0.23356150090694427 / AUROC : 0.9701086956521738 / F1 : 0.969187675070028 / AP : 0.9741137619959708

Time index : 3 out of 15
tensor(0.3866, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 37, 40], device='cuda:0')
Epoch : 0 / Subject : 15 / Timepoint : 3 

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 7 out of 8
Validation / Subject : 7 / Timepoint : 0 / Loss : 0.2310342639684677 / AP : 0.9634037085074698 / F1 : 0.9559211862685354 / AUROC : 0.9577821011673151

Validation / Subject : 7 / Timepoint : 1 / Loss : 0.22682945430278778 / AP : 0.9728155812284681 / F1 : 0.967043974757087 / AUROC : 0.9680954228083786

Validation / Subject : 7 / Timepoint : 2 / Loss : 0.2227112352848053 / AP : 0.9733282858958844 / F1 : 0.9679106729705261 / AUROC : 0.9689083820662768

Validation / Subject : 7 / Timepoint : 3 / Loss : 0.22222383320331573 / AP : 0.9738931085298119 / F1 : 0.9686598348423042 / AUROC : 0.9696121937102065

Validation / Subject : 7 / Timepoint : 4 / Loss : 0.2265586405992508 / AP : 0.9716603307468143 / F1 : 0.9658 / AUROC : 0.9669309611293754

Validation / Subject : 7 / Timepoint : 5 / Loss : 0.2278447300195694 / AP : 0.9704927917000364 / F1 : 0.9642389442933413 / AUROC : 0.9654736432600661

Validation / Subject : 7 / Timepoint : 6 / Loss : 0.217237561941146

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 0 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3833, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40, 113], device='cuda:0')
Epoch : 1 / Subject : 0 / Timepoint : 0 / Loss : 0.24152778089046478 / AUROC : 0.9591256616349735 / F1 : 0.9573837506387327 / AP : 0.965163236871027

Time index : 1 out of 15
tensor(0.3751, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 1 / Subject : 0 / Timepoint : 1 / Loss : 0.245029479265213 / AUROC : 0.9567610062893082 / F1 : 0.9548069022185702 / AP : 0.9625580094581438

Time index : 2 out of 15
tensor(0.3776, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 0 / Timepoint : 2 / Loss : 0.25133296847343445 / AUROC : 0.9518048780487804 / F1 : 0.9493644936449365 / AP : 0.9588299662473486

Time index : 3 out of 15
tensor(0.3780, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 0 / Timepoint : 3 / 

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 1 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3821, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40, 115], device='cuda:0')
Epoch : 1 / Subject : 1 / Timepoint : 0 / Loss : 0.2525413930416107 / AUROC : 0.9505140092723241 / F1 : 0.9479376524228608 / AP : 0.9582474568385534

Time index : 1 out of 15
tensor(0.3756, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Epoch : 1 / Subject : 1 / Timepoint : 1 / Loss : 0.24634310603141785 / AUROC : 0.9528053473769496 / F1 : 0.95046768707483 / AP : 0.9591492976950756

Time index : 2 out of 15
tensor(0.3777, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 1 / Subject : 1 / Timepoint : 2 / Loss : 0.2413477599620819 / AUROC : 0.956884698487527 / F1 : 0.9549420096479524 / AP : 0.9626231695101085

Time index : 3 out of 15
tensor(0.3837, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 1 / Timepoint : 3 / Lo

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 0 out of 8
Validation / Subject : 0 / Timepoint : 0 / Loss : 0.22757840156555176 / AP : 0.9639862647521039 / F1 : 0.9569127797531897 / AUROC : 0.9586926007619812

Validation / Subject : 0 / Timepoint : 1 / Loss : 0.23366975784301758 / AP : 0.9629030745109951 / F1 : 0.9545923632610939 / AUROC : 0.9565646594274433

Validation / Subject : 0 / Timepoint : 2 / Loss : 0.2239033281803131 / AP : 0.9738545117253832 / F1 : 0.9682934101977935 / AUROC : 0.9692678227360308

Validation / Subject : 0 / Timepoint : 3 / Loss : 0.2234039306640625 / AP : 0.9716431401599249 / F1 : 0.9657200811359027 / AUROC : 0.9668562463228084

Validation / Subject : 0 / Timepoint : 4 / Loss : 0.22589504718780518 / AP : 0.9705385246370729 / F1 : 0.9641653905053599 / AUROC : 0.965405085748078

Validation / Subject : 0 / Timepoint : 5 / Loss : 0.22363746166229248 / AP : 0.9734245576509188 / F1 : 0.9675855801272342 / AUROC : 0.9686032863849765

Validation / Subject : 0 / Timepoint : 6 / Loss : 0.2

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 2 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3791, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 2 / Timepoint : 0 / Loss : 0.23966597020626068 / AUROC : 0.958639888645854 / F1 : 0.9568554241858536 / AP : 0.9642217768403427

Time index : 1 out of 15
tensor(0.3863, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40, 107], device='cuda:0')
Epoch : 1 / Subject : 2 / Timepoint : 1 / Loss : 0.2419258952140808 / AUROC : 0.9580733229329172 / F1 : 0.9562385507836353 / AP : 0.9645210411858931

Time index : 2 out of 15
tensor(0.3839, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   2,  20,  38,  40,  74, 100], device='cuda:0')
Epoch : 1 / Subject : 2 / Timepoint : 2 / Loss : 0.23350214958190918 / AUROC : 0.9632278863590571 / F1 : 0.9618240769793954 / AP : 0.9683423919515315

Time index : 3 out of 15
tensor(0.3899, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  40,

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 3 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3890, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 20, 26, 41], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 0 / Loss : 0.2421012967824936 / AUROC : 0.9571931196247068 / F1 : 0.955278742086992 / AP : 0.9639505710362191

Time index : 1 out of 15
tensor(0.3915, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 20, 21, 40], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 1 / Loss : 0.2462007850408554 / AUROC : 0.9550285376894312 / F1 : 0.9529108706852139 / AP : 0.9625496092635396

Time index : 2 out of 15
tensor(0.3918, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 20, 21, 35, 40, 41], device='cuda:0')
Epoch : 1 / Subject : 3 / Timepoint : 2 / Loss : 0.2470857948064804 / AUROC : 0.9523619693043652 / F1 : 0.9499790707408958 / AP : 0.960421911109313

Time index : 3 out of 15
tensor(0.3930, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1,  5, 14, 20, 26, 40, 87], device='cud

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 1 out of 8
Validation / Subject : 1 / Timepoint : 0 / Loss : 0.214779794216156 / AP : 0.9765806040587132 / F1 : 0.9721412048677461 / AUROC : 0.9728962818003914

Validation / Subject : 1 / Timepoint : 1 / Loss : 0.21533282101154327 / AP : 0.9767043446518495 / F1 : 0.9720603015075376 / AUROC : 0.9728197105983574

Validation / Subject : 1 / Timepoint : 2 / Loss : 0.2202107012271881 / AP : 0.9740343503839634 / F1 : 0.968414779499404 / AUROC : 0.9693818601964183

Validation / Subject : 1 / Timepoint : 3 / Loss : 0.2279650717973709 / AP : 0.9689221937257352 / F1 : 0.96198414311852 / AUROC : 0.963376419898159

Validation / Subject : 1 / Timepoint : 4 / Loss : 0.22846610844135284 / AP : 0.9670305372993476 / F1 : 0.9596928982725528 / AUROC : 0.9612546125461254

Validation / Subject : 1 / Timepoint : 5 / Loss : 0.2255374789237976 / AP : 0.9698297163877843 / F1 : 0.9630750605326877 / AUROC : 0.9643899591360187

Validation / Subject : 1 / Timepoint : 6 / Loss : 0.2241801

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 4 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3840, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  3, 40], device='cuda:0')
Epoch : 1 / Subject : 4 / Timepoint : 0 / Loss : 0.23037126660346985 / AUROC : 0.9667783911671924 / F1 : 0.9656367900479249 / AP : 0.9713031981049942

Time index : 1 out of 15
tensor(0.3844, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 4 / Timepoint : 1 / Loss : 0.23830024898052216 / AUROC : 0.9629003381738612 / F1 : 0.9614709224253692 / AP : 0.9683077988548449

Time index : 2 out of 15
tensor(0.3831, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Epoch : 1 / Subject : 4 / Timepoint : 2 / Loss : 0.23665902018547058 / AUROC : 0.9621019730416097 / F1 : 0.9606091370558376 / AP : 0.9674255123080161

Time index : 3 out of 15
tensor(0.3844, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subj

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 5 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3827, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 0 / Loss : 0.2456054389476776 / AUROC : 0.9534632034632035 / F1 : 0.9511918274687855 / AP : 0.9603304940891031

Time index : 1 out of 15
tensor(0.3794, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 1 / Loss : 0.2442377358675003 / AUROC : 0.95110847189232 / F1 : 0.948595213319459 / AP : 0.957910292149916

Time index : 2 out of 15
tensor(0.3821, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 2 / Loss : 0.2468566596508026 / AUROC : 0.9537342386032978 / F1 : 0.9514898810129157 / AP : 0.9606042686736781

Time index : 3 out of 15
tensor(0.3820, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 5 / Timepoint : 3 / Lo

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 2 out of 8
Validation / Subject : 2 / Timepoint : 0 / Loss : 0.2305183708667755 / AP : 0.9622465221042817 / F1 : 0.9541683948569059 / AUROC : 0.9561768788419591

Validation / Subject : 2 / Timepoint : 1 / Loss : 0.22956441342830658 / AP : 0.9640271127364144 / F1 : 0.9561321246222778 / AUROC : 0.957975643841086

Validation / Subject : 2 / Timepoint : 2 / Loss : 0.22610192000865936 / AP : 0.9660661209604071 / F1 : 0.9589267285861713 / AUROC : 0.9605471847739888

Validation / Subject : 2 / Timepoint : 3 / Loss : 0.23317517340183258 / AP : 0.9636144680667316 / F1 : 0.9550194880438218 / AUROC : 0.9569556451612904

Validation / Subject : 2 / Timepoint : 4 / Loss : 0.2300053983926773 / AP : 0.9650127902845633 / F1 : 0.9570271953498027 / AUROC : 0.958797770700637

Validation / Subject : 2 / Timepoint : 5 / Loss : 0.21804465353488922 / AP : 0.9744507785134647 / F1 : 0.9690220692447864 / AUROC : 0.9699528672427338

Validation / Subject : 2 / Timepoint : 6 / Loss : 0.21

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 6 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3821, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 0 / Loss : 0.24065770208835602 / AUROC : 0.9572329523058731 / F1 : 0.9553222153592753 / AP : 0.9633122748987467

Time index : 1 out of 15
tensor(0.3770, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 1 / Loss : 0.242281973361969 / AUROC : 0.95399683419074 / F1 : 0.9517784921704864 / AP : 0.9600653369145147

Time index : 2 out of 15
tensor(0.3789, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  3, 40], device='cuda:0')
Epoch : 1 / Subject : 6 / Timepoint : 2 / Loss : 0.24768531322479248 / AUROC : 0.9511627906976744 / F1 : 0.9486552567237164 / AP : 0.9581024738256277

Time index : 3 out of 15
tensor(0.3858, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subj

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 7 out of 16


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3914, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 7 / Timepoint : 0 / Loss : 0.23355865478515625 / AUROC : 0.9645224171539961 / F1 : 0.9632174616006467 / AP : 0.9699621004950426

Time index : 1 out of 15
tensor(0.3925, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 7 / Timepoint : 1 / Loss : 0.2336975485086441 / AUROC : 0.965933014354067 / F1 : 0.9647315236774321 / AP : 0.9712596327802818

Time index : 2 out of 15
tensor(0.3869, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 7 / Timepoint : 2 / Loss : 0.2384049892425537 / AUROC : 0.9620743034055728 / F1 : 0.9605792437650845 / AP : 0.9677477615214113

Time index : 3 out of 15
tensor(0.3815, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 1 / Subject : 7 / Timepoint : 3 / Loss : 0.2

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 3 out of 8
Validation / Subject : 3 / Timepoint : 0 / Loss : 0.22898834943771362 / AP : 0.9727018493694383 / F1 : 0.9651258770119686 / AUROC : 0.9663010967098704

Validation / Subject : 3 / Timepoint : 1 / Loss : 0.23171357810497284 / AP : 0.970828258869858 / F1 : 0.9627249357326478 / AUROC : 0.9640644361833952

Validation / Subject : 3 / Timepoint : 2 / Loss : 0.22974097728729248 / AP : 0.9716589473515702 / F1 : 0.9639583333333334 / AUROC : 0.9652121455861653

Validation / Subject : 3 / Timepoint : 3 / Loss : 0.21824510395526886 / AP : 0.9747717401360402 / F1 : 0.9693214140593255 / AUROC : 0.97023457520205

Validation / Subject : 3 / Timepoint : 4 / Loss : 0.2088332623243332 / AP : 0.9796877565553191 / F1 : 0.9758219524532119 / AUROC : 0.9763927301461872

Validation / Subject : 3 / Timepoint : 5 / Loss : 0.22041267156600952 / AP : 0.9721518682674879 / F1 : 0.966132630298114 / AUROC : 0.9672420557081208

Validation / Subject : 3 / Timepoint : 6 / Loss : 0.222

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 8 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3864, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 10, 40], device='cuda:0')
Epoch : 1 / Subject : 8 / Timepoint : 0 / Loss : 0.2407984435558319 / AUROC : 0.9563164108618654 / F1 : 0.9543209876543209 / AP : 0.9628818031632962

Time index : 1 out of 15
tensor(0.3814, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 8 / Timepoint : 1 / Loss : 0.24108214676380157 / AUROC : 0.9547119886016691 / F1 : 0.9525636925700884 / AP : 0.9610600876698037

Time index : 2 out of 15
tensor(0.3841, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40, 74], device='cuda:0')
Epoch : 1 / Subject : 8 / Timepoint : 2 / Loss : 0.2387240082025528 / AUROC : 0.9595409453413732 / F1 : 0.957834988850598 / AP : 0.9653709749155273

Time index : 3 out of 15
tensor(0.3835, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch :

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 9 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3844, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 0 / Loss : 0.24857112765312195 / AUROC : 0.9507701306297525 / F1 : 0.9482210601866092 / AP : 0.9584160355584124

Time index : 1 out of 15
tensor(0.3842, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  2, 40], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 1 / Loss : 0.2464342713356018 / AUROC : 0.9530429710285826 / F1 : 0.9507293685606447 / AP : 0.9601975261933963

Time index : 2 out of 15
tensor(0.3799, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 2 / Loss : 0.24589934945106506 / AUROC : 0.9541859567901234 / F1 : 0.9519862529060952 / AP : 0.9607654719301256

Time index : 3 out of 15
tensor(0.3786, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 9 / Timepoint : 3 / Loss : 0.24166585505008698 / AUROC : 0.955581305523

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 4 out of 8
Validation / Subject : 4 / Timepoint : 0 / Loss : 0.20548921823501587 / AP : 0.9808314841060232 / F1 : 0.9773527931555108 / AUROC : 0.9778543307086613

Validation / Subject : 4 / Timepoint : 1 / Loss : 0.20890919864177704 / AP : 0.9798398014512784 / F1 : 0.97592867756315 / AUROC : 0.9764944863609983

Validation / Subject : 4 / Timepoint : 2 / Loss : 0.21073701977729797 / AP : 0.9792613027672534 / F1 : 0.9751020408163266 / AUROC : 0.9757068896853843

Validation / Subject : 4 / Timepoint : 3 / Loss : 0.20978815853595734 / AP : 0.9788739952641592 / F1 : 0.9748297916236848 / AUROC : 0.9754477762125175

Validation / Subject : 4 / Timepoint : 4 / Loss : 0.2193736582994461 / AP : 0.9730895941156615 / F1 : 0.9672213817448311 / AUROC : 0.96826171875

Validation / Subject : 4 / Timepoint : 5 / Loss : 0.21912577748298645 / AP : 0.9705249612730853 / F1 : 0.9646190134345195 / AUROC : 0.9658280507131538

Validation / Subject : 4 / Timepoint : 6 / Loss : 0.220705

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 10 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3971, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40,  74, 108], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 0 / Loss : 0.24341334402561188 / AUROC : 0.9547 / F1 : 0.9525505394364722 / AP : 0.9626690498588899

Time index : 1 out of 15
tensor(0.3977, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,  20,  72,  95, 108, 111], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 1 / Loss : 0.2433367371559143 / AUROC : 0.9541868932038835 / F1 : 0.9519872813990461 / AP : 0.9622831045084453

Time index : 2 out of 15
tensor(0.4007, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  11,  37,  39,  40,  83, 107, 114], device='cuda:0')
Epoch : 1 / Subject : 10 / Timepoint : 2 / Loss : 0.24370630085468292 / AUROC : 0.9578646141701204 / F1 : 0.9560111259915525 / AP : 0.9656065791170192

Time index : 3 out of 15
tensor(0.3926, device='cuda:0', grad_fn=<MeanBac

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 11 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3865, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40, 75], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 0 / Loss : 0.24342219531536102 / AUROC : 0.9548700654632017 / F1 : 0.9527370935909422 / AP : 0.9617996274939092

Time index : 1 out of 15
tensor(0.3829, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40, 75], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 1 / Loss : 0.2386823296546936 / AUROC : 0.9541808702626831 / F1 : 0.951980666176316 / AP : 0.9605962816157763

Time index : 2 out of 15
tensor(0.3861, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 2 / Loss : 0.24274380505084991 / AUROC : 0.9556912170145726 / F1 : 0.9536369256130228 / AP : 0.9624839557189023

Time index : 3 out of 15
tensor(0.3888, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 11 / Timepoint : 3 / Loss : 0.23845155537128448 / AUROC : 0.9

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 5 out of 8
Validation / Subject : 5 / Timepoint : 0 / Loss : 0.22277110815048218 / AP : 0.9691242086707181 / F1 : 0.9625063163213744 / AUROC : 0.963861289694136

Validation / Subject : 5 / Timepoint : 1 / Loss : 0.22173237800598145 / AP : 0.9705106285062717 / F1 : 0.9641214351425943 / AUROC : 0.9653641207815276

Validation / Subject : 5 / Timepoint : 2 / Loss : 0.22097234427928925 / AP : 0.9697348516550288 / F1 : 0.9634496919917864 / AUROC : 0.9647385103011094

Validation / Subject : 5 / Timepoint : 3 / Loss : 0.2201615869998932 / AP : 0.9720562569576662 / F1 : 0.9659809332664325 / AUROC : 0.9671001552795031

Validation / Subject : 5 / Timepoint : 4 / Loss : 0.2205701768398285 / AP : 0.9686225291972371 / F1 : 0.962432183437404 / AUROC : 0.9637924230465666

Validation / Subject : 5 / Timepoint : 5 / Loss : 0.21320879459381104 / AP : 0.977815910453763 / F1 : 0.9731031714171016 / AUROC : 0.973807662236122

Validation / Subject : 5 / Timepoint : 6 / Loss : 0.2056

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 12 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3775, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([1, 2], device='cuda:0')
Epoch : 1 / Subject : 12 / Timepoint : 0 / Loss : 0.25069063901901245 / AUROC : 0.9450097847358121 / F1 : 0.9418098985297163 / AP : 0.9527637918675158

Time index : 1 out of 15
tensor(0.3804, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Epoch : 1 / Subject : 12 / Timepoint : 1 / Loss : 0.25230827927589417 / AUROC : 0.9440415523324186 / F1 : 0.9407245925464549 / AP : 0.9523921484760773

Time index : 2 out of 15
tensor(0.3837, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 12 / Timepoint : 2 / Loss : 0.24064071476459503 / AUROC : 0.9571957878315133 / F1 : 0.9552816542731996 / AP : 0.9633761773144436

Time index : 3 out of 15
tensor(0.3804, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Su

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 13 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3861, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 20, 40], device='cuda:0')
Epoch : 1 / Subject : 13 / Timepoint : 0 / Loss : 0.23998890817165375 / AUROC : 0.9555669050051072 / F1 : 0.9535008017103154 / AP : 0.9621123638599742

Time index : 1 out of 15
tensor(0.3877, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 20, 40], device='cuda:0')
Epoch : 1 / Subject : 13 / Timepoint : 1 / Loss : 0.24134093523025513 / AUROC : 0.9556744749596122 / F1 : 0.9536185948230322 / AP : 0.9624429908875672

Time index : 2 out of 15
tensor(0.3859, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 13 / Timepoint : 2 / Loss : 0.2391081303358078 / AUROC : 0.9556008146639512 / F1 : 0.953537936913896 / AP : 0.9621183641171116

Time index : 3 out of 15
tensor(0.3887, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([], device='cuda:0', dtype=torch.int64)
Epoch : 1 / Subject : 13 / Timepoint : 3 / Loss : 0.23102474212646484 / AUROC 

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 6 out of 8
Validation / Subject : 6 / Timepoint : 0 / Loss : 0.2230307161808014 / AP : 0.9683950193169878 / F1 : 0.9615615615615616 / AUROC : 0.9629843840370156

Validation / Subject : 6 / Timepoint : 1 / Loss : 0.21802125871181488 / AP : 0.96998329400995 / F1 : 0.9640984267849939 / AUROC : 0.9653426791277259

Validation / Subject : 6 / Timepoint : 2 / Loss : 0.22285233438014984 / AP : 0.9663985617798043 / F1 : 0.9596122778675282 / AUROC : 0.9611801242236024

Validation / Subject : 6 / Timepoint : 3 / Loss : 0.21209535002708435 / AP : 0.9747662096784654 / F1 : 0.9700598802395209 / AUROC : 0.9709302325581395

Validation / Subject : 6 / Timepoint : 4 / Loss : 0.20463816821575165 / AP : 0.9795571691297507 / F1 : 0.9760533494998485 / AUROC : 0.9766133806986382

Validation / Subject : 6 / Timepoint : 5 / Loss : 0.21444067358970642 / AP : 0.9736952172219298 / F1 : 0.9685717249248004 / AUROC : 0.9695293644408689

Validation / Subject : 6 / Timepoint : 6 / Loss : 0.2

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 14 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3865, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 14 / Timepoint : 0 / Loss : 0.2389889657497406 / AUROC : 0.9576589024871693 / F1 : 0.9557868700401938 / AP : 0.9639297029913358

Time index : 1 out of 15
tensor(0.3809, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1,  2, 40], device='cuda:0')
Epoch : 1 / Subject : 14 / Timepoint : 1 / Loss : 0.24915148317813873 / AUROC : 0.9484335473827593 / F1 : 0.9456298727944193 / AP : 0.955985890080633

Time index : 2 out of 15
tensor(0.3745, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1,  2, 40], device='cuda:0')
Epoch : 1 / Subject : 14 / Timepoint : 2 / Loss : 0.24723762273788452 / AUROC : 0.948019801980198 / F1 : 0.9451697127937336 / AP : 0.9548170334274807

Time index : 3 out of 15
tensor(0.3731, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1,  2, 40], device='cuda:0')
Epoch :

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 15 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3815, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 0 / Loss : 0.22767594456672668 / AUROC : 0.9664855072463768 / F1 : 0.9653233364573571 / AP : 0.9707218363950323

Time index : 1 out of 15
tensor(0.3893, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 1 / Loss : 0.2274906188249588 / AUROC : 0.9695027195027195 / F1 : 0.968543378080545 / AP : 0.9738749559404751

Time index : 2 out of 15
tensor(0.3840, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 2 / Loss : 0.22573256492614746 / AUROC : 0.9694949494949495 / F1 : 0.9685351114815587 / AP : 0.9734569156527614

Time index : 3 out of 15
tensor(0.3924, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 1 / Subject : 15 / Timepoint : 3 / Loss : 0.23023924231529236 / AUROC : 0.969947481034818

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 7 out of 8
Validation / Subject : 7 / Timepoint : 0 / Loss : 0.22850091755390167 / AP : 0.963958750645597 / F1 : 0.9559812512736906 / AUROC : 0.9578372047628343

Validation / Subject : 7 / Timepoint : 1 / Loss : 0.21976158022880554 / AP : 0.9731114044157956 / F1 : 0.9670571743266246 / AUROC : 0.9681077937184955

Validation / Subject : 7 / Timepoint : 2 / Loss : 0.2163609117269516 / AP : 0.973345522457081 / F1 : 0.9678394999495917 / AUROC : 0.9688415706192616

Validation / Subject : 7 / Timepoint : 3 / Loss : 0.21766890585422516 / AP : 0.9727662762376349 / F1 : 0.9670307845084409 / AUROC : 0.9680830609498173

Validation / Subject : 7 / Timepoint : 4 / Loss : 0.2196178287267685 / AP : 0.9711649041639128 / F1 : 0.9650711513583441 / AUROC : 0.96625

Validation / Subject : 7 / Timepoint : 5 / Loss : 0.22028155624866486 / AP : 0.9707122431295065 / F1 : 0.9644886363636364 / AUROC : 0.9657064471879286

Validation / Subject : 7 / Timepoint : 6 / Loss : 0.2119285017251

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 0 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3890, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 36, 40], device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 0 / Loss : 0.23819008469581604 / AUROC : 0.9603622906116089 / F1 : 0.9587262955075551 / AP : 0.9664806239438474

Time index : 1 out of 15
tensor(0.3826, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 1 / Loss : 0.23938949406147003 / AUROC : 0.9579476068544416 / F1 : 0.9561015729412974 / AP : 0.9638889194523922

Time index : 2 out of 15
tensor(0.3816, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 2 / Loss : 0.2457381933927536 / AUROC : 0.9516536964980544 / F1 : 0.9491975876520494 / AP : 0.9586957803804843

Time index : 3 out of 15
tensor(0.3846, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 2 / Subject : 0 / Timepoint : 3 / Loss : 0.24135568737983704 / AUROC : 0.95647496

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 1 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3857, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 0 / Loss : 0.24626104533672333 / AUROC : 0.9517269238537669 / F1 : 0.9492784380305602 / AP : 0.95919728794077

Time index : 1 out of 15
tensor(0.3761, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 1 / Loss : 0.24101050198078156 / AUROC : 0.9522986167615948 / F1 : 0.9499092171312613 / AP : 0.9583754699639647

Time index : 2 out of 15
tensor(0.3827, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 2 / Loss : 0.23929783701896667 / AUROC : 0.9582512800315084 / F1 : 0.956432387998356 / AP : 0.9641323061874834

Time index : 3 out of 15
tensor(0.3837, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([1], device='cuda:0')
Epoch : 2 / Subject : 1 / Timepoint : 3 / Loss : 0.24037548899650574 / AUROC : 0.9566026645768024 / F1 : 0.9546338

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 0 out of 8
Validation / Subject : 0 / Timepoint : 0 / Loss : 0.22553576529026031 / AP : 0.963774967430382 / F1 : 0.9563941299790356 / AUROC : 0.9582161510646846

Validation / Subject : 0 / Timepoint : 1 / Loss : 0.22822575271129608 / AP : 0.9630877022560651 / F1 : 0.9551592619317596 / AUROC : 0.9570836621941594

Validation / Subject : 0 / Timepoint : 2 / Loss : 0.21533505618572235 / AP : 0.9739973546435218 / F1 : 0.9687061183550651 / AUROC : 0.9696557090060299

Validation / Subject : 0 / Timepoint : 3 / Loss : 0.22002825140953064 / AP : 0.9713540540777574 / F1 : 0.9651469098277609 / AUROC : 0.9663207362443704

Validation / Subject : 0 / Timepoint : 4 / Loss : 0.22076182067394257 / AP : 0.9705391497890773 / F1 : 0.9641946342956238 / AUROC : 0.9654323419342131

Validation / Subject : 0 / Timepoint : 5 / Loss : 0.21796336770057678 / AP : 0.972881596754848 / F1 : 0.9670835417708854 / AUROC : 0.9681325067803177

Validation / Subject : 0 / Timepoint : 6 / Loss : 0.

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))



Subject index : 2 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3846, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([1], device='cuda:0')
Epoch : 2 / Subject : 2 / Timepoint : 0 / Loss : 0.2385949343442917 / AUROC : 0.9570862239841427 / F1 : 0.9551620586103344 / AP : 0.9632988378026399

Time index : 1 out of 15
tensor(0.3870, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  20,  40, 107], device='cuda:0')
Epoch : 2 / Subject : 2 / Timepoint : 1 / Loss : 0.2379431128501892 / AUROC : 0.9593535749265426 / F1 : 0.9576314446145993 / AP : 0.9654150845745259

Time index : 2 out of 15
tensor(0.3904, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  2, 20], device='cuda:0')
Epoch : 2 / Subject : 2 / Timepoint : 2 / Loss : 0.23251497745513916 / AUROC : 0.9647972564050837 / F1 : 0.9635128071092525 / AP : 0.9701052498966445

Time index : 3 out of 15
tensor(0.3926, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  1,  40, 107], device='cuda:0')
Epo

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 3 out of 16
Sampling probability : 0.7398
Time index : 0 out of 15


  train3D=torch.load(os.path.join('DatasetSplit','train40'+str(i)+'.pth'))


tensor(0.3923, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 0 / Loss : 0.243011936545372 / AUROC : 0.9557737169517885 / F1 : 0.9537272449913556 / AP : 0.9630821490468549

Time index : 1 out of 15
tensor(0.3932, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 1 / Loss : 0.24735811352729797 / AUROC : 0.9519818576217709 / F1 : 0.9495598135680994 / AP : 0.9602317683920107

Time index : 2 out of 15
tensor(0.3953, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 14, 20, 26, 40], device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 2 / Loss : 0.24320609867572784 / AUROC : 0.9531436530043773 / F1 : 0.9508402045715478 / AP : 0.9611568307058654

Time index : 3 out of 15
tensor(0.3938, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([15, 26, 40], device='cuda:0')
Epoch : 2 / Subject : 3 / Timepoint : 3 / Loss : 0.24458646774291992 / AUROC : 0.9533678

  valid3D=torch.load(os.path.join('DatasetSplit','valid40'+str((i+1)//2-1)+'.pth'))


Validation Subject index : 1 out of 8
Validation / Subject : 1 / Timepoint : 0 / Loss : 0.20948141813278198 / AP : 0.9773184077752715 / F1 : 0.9729946792490713 / AUROC : 0.973704789833822

Validation / Subject : 1 / Timepoint : 1 / Loss : 0.21194475889205933 / AP : 0.9768780307980238 / F1 : 0.972152407761134 / AUROC : 0.9729068857589984

Validation / Subject : 1 / Timepoint : 2 / Loss : 0.2131851315498352 / AP : 0.9749184150878355 / F1 : 0.9699157641395909 / AUROC : 0.9707943925233644

Validation / Subject : 1 / Timepoint : 3 / Loss : 0.22286032140254974 / AP : 0.9692961979331771 / F1 : 0.9625290374709625 / AUROC : 0.963882398753894

Validation / Subject : 1 / Timepoint : 4 / Loss : 0.22532296180725098 / AP : 0.9674873597866345 / F1 : 0.9601778855872246 / AUROC : 0.9617029548989113

Validation / Subject : 1 / Timepoint : 5 / Loss : 0.22091566026210785 / AP : 0.9692566950278158 / F1 : 0.9628128483128989 / AUROC : 0.964146150840172

Validation / Subject : 1 / Timepoint : 6 / Loss : 0.219

Exception: Early Stopping

The cell below installs and initializes the wandb session, which is a dashboard service to record experimental details and compare the results.

In [None]:
# Installing wandb
!pip install wandb
import wandb
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


The two cells below is for testing the model trained with the 40th Quantile Dataset.

In [None]:
# 40th Quantile Dataset Testing

sampleStrategy='uniform'
negSampleStrategy='inductive'
N=116 # Number of ROIs
numEdges=math.ceil(116*116*0.4)
nodeDim=N # One-Hot encoding as a node embedding
edgeDim=4 # Arbitrarily set
timeDim=100 # Half of the value mentioned in D=the HOT paper
channelDim=50 # Half of the value mentioned in the HOT paper
latentDim=N # Positive node embedding and negative node embedding concatenated across dim=1
numFilter=4 # Following Triadic Decoder paper
B=10000
prob=0.7398
nodeFeat=torch.cat((torch.eye(N),torch.zeros((1,N))),dim=0).numpy()
edgeFeat=torch.zeros((numEdges+1,edgeDim)).numpy()

wandb.init(
    project="EmotionConnectivity - 40thQuantileTest",
    config={
        "Mode" : "Test",
        "initialization" : "kaiming_normal",
        "sampleStrategy" : sampleStrategy,
        "negSampleStrategy" : negSampleStrategy,
        "B": B,
        "prob": prob,
        "Dropout" : 0.1,
    }
)

checkpoint=torch.load('checkpoint8.pth')
encoder=HOT(nodeFeat, edgeFeat, None, timeDim, channelDim, patch_size=8, num_state_vectors=32, num_layers=2, num2hop=0, dropout=0.1, max_input_sequence_length=4096, device='cuda')
encoder.load_state_dict(checkpoint['encoder'])
encoder.eval()
decoder=TriadicDecoder(latentDim, numFilter, device='cuda')
decoder.load_state_dict(checkpoint['decoder'])
decoder.eval()
model=nn.Sequential(encoder,  decoder).to('cuda')
lossFunc=nn.BCELoss()
data=datasets.fetch_atlas_aal()
connDictAmygL={}
connDictInsL={}

with torch.no_grad():
  for i in range(10):
    model.eval()
    print("\nSubject index : "+str(i)+" out of 10")
    test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))
    test3D=test3D.reshape(-1,3).detach().to('cpu').numpy()
    test3D=test3D.reshape(15, -1, 3)
    print("Sampling probability : "+str(prob))

    lossSubject=0
    F1Subject=0
    APSubject=0
    AUROCSubject=0
    for t in range(15):

      print("Time index : "+str(t)+" out of 15")

      timeBatch=test3D[t]
      srcNodes=timeBatch[:,0]
      dstNodes=timeBatch[:,1]
      timepoints=timeBatch[:,2]

      adjMat=torch.zeros((N,N)).to('cuda')
      for row in timeBatch:
        adjMat[int(row[0]), int(row[1])]=1

      nodePosEmbedding=torch.zeros((N,N)).to('cuda')
      nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
      nodePosTimes=torch.zeros(N).to('cuda')
      nodeNegTimes=torch.zeros(N).to('cuda')
      nodeEmbedding=torch.zeros((N,N)).to('cuda')

      test_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=4)
      test_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints), negative_sample_strategy=negSampleStrategy, seed=5)
      model[0].set_neighbor_sampler(test_neighbor_sampler)

      subject_src_node_embeddings, subject_dst_node_embeddings = \
        model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=srcNodes, dst_node_ids=dstNodes, node_interact_times=timepoints)
      negSrcNodes, negDstNodes = test_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
      subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
        model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=negSrcNodes,
                                                                      dst_node_ids=negDstNodes,
                                                                      node_interact_times=timepoints)

      for tb, row in enumerate(zip(srcNodes, dstNodes)):
        nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
        nodePosTimes[row[0]]+=1
        nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
        nodePosTimes[row[1]]+=1
      for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
        nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
        nodeNegTimes[row[0]]+=1
        nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
        nodeNegTimes[row[1]]+=1

      for node in range(N):
        if nodePosTimes[node]>0:
          nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
        if nodeNegTimes[node]>0:
          nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]

      nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
      nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

      graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
      graph2D=torch.tensor(graph2D).to('cuda')
      graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
      graph2DN=torch.tensor(graph2DN).to('cuda')

      triads=triadSample(graph2D, B, prob, device='cuda')
      mSquare=torch.zeros((N,N)).to('cuda')
      eSquare=torch.zeros((N,N)).to('cuda')

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)

        if mSquare[triad[0], triad[1]]==0:
          pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
          M1=calculateM(triads, pair1)
          mSquare[triad[0], triad[1]]=M1

        if mSquare[triad[0], triad[2]]==0:
          pair1=torch.cat((triad0, triad2), dim=1).to('cuda')
          M2=calculateM(triads, pair1)
          mSquare[triad[0], triad[2]]=M2

        if mSquare[triad[1], triad[2]]==0:
          pair1=torch.cat((triad1, triad2), dim=1).to('cuda')
          M3=calculateM(triads, pair1)
          mSquare[triad[1], triad[2]]=M3

      for j, triad in enumerate(triads):
        triad0=triad[0].unsqueeze(0).unsqueeze(0)
        triad1=triad[1].unsqueeze(0).unsqueeze(0)
        triad2=triad[2].unsqueeze(0).unsqueeze(0)
        pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
        pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
        pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

        if pair1 in graph2D[:,0:2]:
          emb1=nodePosEmbedding[triad[0]]
        if pair1 in graph2DN:
          emb1=nodeNegEmbedding[triad[0]]
        if pair2 in graph2D[:,0:2]:
          emb2=nodePosEmbedding[triad[1]]
        if pair2 in graph2DN:
          emb2=nodeNegEmbedding[triad[1]]
        if pair3 in graph2D[:,0:2]:
          emb3=nodePosEmbedding[triad[2]]
        if pair3 in graph2DN:
          emb3=nodeNegEmbedding[triad[2]]

        eTriplet=model[1].forward(emb1, emb2, emb3)
        if pairIn(pair1,graph2D[:,0:2]):
          eSquare[triad[0], triad[1]]+=eTriplet[0]
        if pairIn(pair1,graph2DN):
          eSquare[triad[0], triad[1]]-=eTriplet[0]
        if pairIn(pair2,graph2D[:,0:2]):
          eSquare[triad[0], triad[2]]+=eTriplet[1]
        if pairIn(pair2,graph2DN):
          eSquare[triad[0], triad[2]]-=eTriplet[1]
        if pairIn(pair3,graph2D[:,0:2]):
          eSquare[triad[1], triad[2]]+=eTriplet[2]
        if pairIn(pair3,graph2DN):
          eSquare[triad[1], triad[2]]-=eTriplet[2]

      eSquare.clamp_(min=0.0)
      nonzeros=((mSquare>0).nonzero())
      adjRecon=[]
      adjPart=[]
      adjFull=torch.zeros((N,N)).to('cuda')
      for nz in nonzeros:
        nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
        em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
        if pairIn(nzPair, graph2D[:,0:2]):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([1.0]))
        if pairIn(nzPair,graph2DN):
          adjRecon.append(em1)
          adjPart.append(torch.tensor([0.0]))
        if em1>0.5:
          adjFull[int(nz[0]), int(nz[1])]=1
        else:
          adjFull[int(nz[0]), int(nz[1])]=0
      adjRecon=torch.stack(adjRecon).to('cuda')
      adjPart=torch.stack(adjPart).squeeze(1).to('cuda')

      loss=lossFunc(input=adjRecon, target=adjPart)
      lossSubject+=loss
      adjReconC=adjRecon.clone().detach().cpu().numpy()
      adjPartC=adjPart.clone().detach().cpu().numpy()

      adjReconC[adjReconC>0.5]=1
      adjReconC[adjReconC<=0.5]=0
      F1=f1_score(adjPartC, adjReconC)
      AUROC=roc_auc_score(adjPartC, adjReconC)
      AP=average_precision_score(adjPartC, adjReconC)
      F1Subject+=F1
      APSubject+=AP
      AUROCSubject+=AUROC

      print("Test / Subject : "+str(i)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AUROC : "+str(AUROC)+" / F1 : "+str(F1)+" / AP : "+str(AP)+"\n")


      AmygLConn=torch.nonzero(adjFull[40])[:,0]
      for elem in AmygLConn:
        if data.labels[elem] not in connDictAmygL.keys():
          connDictAmygL[data.labels[elem]]=1
        else:
          connDictAmygL[data.labels[elem]]+=1

      InsLConn=torch.nonzero(adjFull[28])[:,0]
      for elem in InsLConn:
        if data.labels[elem] not in connDictInsL.keys():
          connDictInsL[data.labels[elem]]=1
        else:
          connDictInsL[data.labels[elem]]+=1

    wandb.log({"TLoss": lossSubject.item()/15, "T_AP": APSubject/15, "T_F1": F1Subject/15, "T_AUROC":AUROCSubject/15})
    torch.cuda.empty_cache()

torch.save(connDictAmygL, 'connDictTestAmygL.pth')
torch.save(connDictInsL, 'connDictTestInsL.pth')

VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  checkpoint=torch.load('checkpoint8.pth')
  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))



Subject index : 0 out of 10
Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  8,   9,  14,  15,  19,  25,  31,  32,  33,  35,  37,  40,  54,  55,
         74,  82,  96,  97, 104, 108, 111, 115], device='cuda:0')
Test / Subject : 0 / Timepoint : 0 / Loss : 0.19278858602046967 / AUROC : 0.9902054926061072 / F1 : 0.9901086113266098 / AP : 0.9916804581986247

Time index : 1 out of 15
tensor([ 15,  19,  25,  32,  33,  35,  37,  40,  41,  48,  54,  82, 106, 111],
       device='cuda:0')
Test / Subject : 0 / Timepoint : 1 / Loss : 0.19345472753047943 / AUROC : 0.9889298892988929 / F1 : 0.9888059701492538 / AP : 0.9905546561033539

Time index : 2 out of 15
tensor([37, 40, 48, 82], device='cuda:0')
Test / Subject : 0 / Timepoint : 2 / Loss : 0.19252915680408478 / AUROC : 0.9891981315687037 / F1 : 0.9890801770782095 / AP : 0.9907337089526793

Time index : 3 out of 15
tensor([ 25,  40,  82, 107], device='cuda:0')
Test / Subject : 0 / Timepoint : 3 / Loss : 0.19231298565864563 / AU

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([ 19,  20,  21,  25,  26,  28,  30,  31,  35,  36,  40,  45,  47,  49,
         50,  58,  59,  60,  61,  66,  67,  70,  73,  75,  80,  83,  85,  90,
         92,  95,  96,  98,  99, 100, 102, 103, 104, 105, 109, 112, 114],
       device='cuda:0')
Test / Subject : 1 / Timepoint : 0 / Loss : 0.19138039648532867 / AUROC : 0.9900329011031546 / F1 : 0.9899325579122276 / AP : 0.9914382213296393

Time index : 1 out of 15
tensor([ 25,  31,  33,  36,  40,  42,  43,  47,  59,  60,  62,  66,  67,  75,
         83,  85,  90,  94,  95,  96, 100, 104, 109, 114], device='cuda:0')
Test / Subject : 1 / Timepoint : 1 / Loss : 0.19226908683776855 / AUROC : 0.9898629078972775 / F1 : 0.9897590948990539 / AP : 0.991335722915298

Time index : 2 out of 15
tensor([  1,  11,  28,  36,  40,  41,  47,  49,  50,  51,  53,  59,  60,  62,
         72,  75,  83,  85,  96, 114], device='cuda:0')
Test / Subject : 1 / Timepoint : 2 / Loss : 0.1934458464384079

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([12, 13, 32, 40, 60, 80], device='cuda:0')
Test / Subject : 2 / Timepoint : 0 / Loss : 0.20297621190547943 / AUROC : 0.9861636472227753 / F1 : 0.9859695165034824 / AP : 0.9887149388690355

Time index : 1 out of 15
tensor([40], device='cuda:0')
Test / Subject : 2 / Timepoint : 1 / Loss : 0.20567631721496582 / AUROC : 0.9858802323252553 / F1 : 0.9856780091416963 / AP : 0.988673023124782

Time index : 2 out of 15
tensor([40], device='cuda:0')
Test / Subject : 2 / Timepoint : 2 / Loss : 0.20154549181461334 / AUROC : 0.9870672502984481 / F1 : 0.9868978028623261 / AP : 0.9894196269931337

Time index : 3 out of 15
tensor([ 1, 29, 40, 61, 81], device='cuda:0')
Test / Subject : 2 / Timepoint : 3 / Loss : 0.19876013696193695 / AUROC : 0.9869812059514487 / F1 : 0.9868094813051671 / AP : 0.9891277482768998

Time index : 4 out of 15
tensor([  1,  11,  17,  19,  28,  29,  37,  40,  43,  45,  46,  48,  49,  50,
         51,  53,  55,  56, 

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([ 37,  40,  69,  76,  82, 103, 104, 105, 114], device='cuda:0')
Test / Subject : 3 / Timepoint : 0 / Loss : 0.194421648979187 / AUROC : 0.9876242447865913 / F1 : 0.98746916625555 / AP : 0.9894020613997735

Time index : 1 out of 15
tensor([ 40,  69,  76,  82, 103], device='cuda:0')
Test / Subject : 3 / Timepoint : 1 / Loss : 0.19457107782363892 / AUROC : 0.9878238846678356 / F1 : 0.9876737994280643 / AP : 0.9896002193858999

Time index : 2 out of 15
tensor([ 25,  36,  37,  40,  69,  72,  76,  90, 103, 106, 108, 110],
       device='cuda:0')
Test / Subject : 3 / Timepoint : 2 / Loss : 0.1919236183166504 / AUROC : 0.9903752673536846 / F1 : 0.9902817316187298 / AP : 0.9917934794926453

Time index : 3 out of 15
tensor([ 25,  37,  40,  46,  72, 103, 106, 108, 110], device='cuda:0')
Test / Subject : 3 / Timepoint : 3 / Loss : 0.1947091966867447 / AUROC : 0.9875386398763524 / F1 : 0.9873813948938668 / AP : 0.9893440348953724

Time i

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  2,  14,  19,  24,  25,  28,  30,  31,  33,  34,  38,  40,  44,  46,
         47,  48,  49,  50,  51,  52,  53,  57,  61,  66,  69,  71,  72,  76,
         78,  86,  91,  93,  96,  98,  99, 108, 110, 114], device='cuda:0')
Test / Subject : 4 / Timepoint : 0 / Loss : 0.18953324854373932 / AUROC : 0.9918111753371869 / F1 : 0.9917435648372996 / AP : 0.9929640365739232

Time index : 1 out of 15
tensor([  8,  10,  14,  15,  17,  19,  22,  23,  24,  25,  28,  30,  34,  35,
         40,  44,  47,  54,  56,  62,  63,  70,  71,  72,  78,  79,  82,  86,
         87,  93,  98, 101, 106, 109, 110, 113, 114], device='cuda:0')
Test / Subject : 4 / Timepoint : 1 / Loss : 0.19113793969154358 / AUROC : 0.9894277400581959 / F1 : 0.9893147730614645 / AP : 0.9908625301274075

Time index : 2 out of 15
tensor([ 15,  17,  19,  28,  30,  34,  40,  47,  54,  62,  70,  71,  72,  78,
         79,  82,  86,  93,  98, 106, 110, 113, 114], device='cuda

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  1,   3,   7,   8,  11,  13,  14,  15,  20,  23,  25,  27,  31,  34,
         38,  40,  42,  44,  45,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         61,  63,  64,  65,  66,  71,  78,  81,  84,  85,  86,  87,  88,  89,
         91,  92,  93,  94,  97, 101, 104, 108, 111], device='cuda:0')
Test / Subject : 5 / Timepoint : 0 / Loss : 0.1914660483598709 / AUROC : 0.9905038759689923 / F1 : 0.9904128350616318 / AP : 0.9918737015998096

Time index : 1 out of 15
tensor([  1,   3,   4,   8,   9,  11,  13,  14,  20,  23,  25,  26,  27,  31,
         33,  38,  40,  42,  45,  47,  48,  50,  51,  53,  54,  55,  63,  69,
         71,  78,  81,  85,  86,  87,  88,  89,  91,  93,  94,  97, 101, 104,
        111], device='cuda:0')
Test / Subject : 5 / Timepoint : 1 / Loss : 0.19132089614868164 / AUROC : 0.9902377730523875 / F1 : 0.990141532454856 / AP : 0.991623448974532

Time index : 2 out of 15
tensor([  1,   3,   4,   8,   9,  11

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  2,   3,   4,   6,   7,   8,   9,  10,  12,  14,  15,  17,  18,  22,
         23,  24,  28,  32,  34,  36,  40,  41,  42,  43,  44,  48,  53,  54,
         55,  56,  58,  60,  62,  64,  65,  66,  78,  79,  80,  81,  82,  83,
         84,  87,  90,  91,  92,  93,  94,  95,  96,  97,  99, 101, 102, 104,
        113], device='cuda:0')
Test / Subject : 6 / Timepoint : 0 / Loss : 0.19701030850410461 / AUROC : 0.9865549493374903 / F1 : 0.9863717163736915 / AP : 0.9885946383518207

Time index : 1 out of 15
tensor([  2,   4,  10,  11,  12,  14,  17,  22,  28,  36,  40,  42,  56,  58,
         60,  62,  66,  78,  82,  87,  91,  93,  94,  95,  96,  97,  99, 104,
        109, 113], device='cuda:0')
Test / Subject : 6 / Timepoint : 1 / Loss : 0.1934494525194168 / AUROC : 0.9882926829268293 / F1 : 0.9881539980256664 / AP : 0.9899603101247241

Time index : 2 out of 15
tensor([  2,   4,  12,  16,  17,  21,  22,  25,  28,  36,  40,  80,  

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  1,  16,  20,  21,  24,  25,  26,  33,  36,  38,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  56,  57,
         59,  66,  67,  69,  70,  71,  73,  77,  80,  81,  82,  87,  89,  92,
         94,  95,  96,  97,  98,  99, 100, 110, 111, 112, 114],
       device='cuda:0')
Test / Subject : 7 / Timepoint : 0 / Loss : 0.19652020931243896 / AUROC : 0.987223823246878 / F1 : 0.9870584801011969 / AP : 0.989186121671061

Time index : 1 out of 15
tensor([  1,  16,  20,  21,  24,  25,  26,  34,  36,  37,  39,  40,  41,  42,
         43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  56,  57,
         59,  66,  67,  69,  70,  71,  73,  75,  87,  89,  94,  96,  97,  98,
        100, 110, 111, 114, 115], device='cuda:0')
Test / Subject : 7 / Timepoint : 1 / Loss : 0.19627973437309265 / AUROC : 0.9874304356169641 / F1 : 0.987270430473229 / AP : 0.9893584365283064

Time index : 2 out of 15

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([  1,   2,   3,   4,   5,   6,   7,   8,  10,  11,  12,  13,  16,  17,
         19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,
         33,  37,  40,  41,  43,  45,  48,  49,  50,  51,  52,  55,  56,  58,
         59,  60,  63,  64,  67,  70,  71,  72,  73,  74,  76,  77,  78,  79,
         80,  81,  82,  83,  85,  86,  87,  88,  89,  95,  97, 113, 114],
       device='cuda:0')
Test / Subject : 8 / Timepoint : 0 / Loss : 0.1937376707792282 / AUROC : 0.989656518345043 / F1 : 0.989548412541905 / AP : 0.9912419135750735

Time index : 1 out of 15
tensor([  1,   3,   4,   5,   6,   7,   8,  11,  12,  13,  14,  16,  17,  18,
         19,  20,  21,  22,  23,  25,  26,  27,  28,  29,  30,  31,  32,  33,
         37,  40,  41,  43,  48,  49,  50,  52,  55,  57,  59,  62,  63,  67,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         85,  86,  88,  89,  95,  97, 113], device='cud

  test3D=torch.load(os.path.join('DatasetSplit','test'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor([ 31,  39,  40,  41,  49,  53,  95,  99, 105], device='cuda:0')
Test / Subject : 9 / Timepoint : 0 / Loss : 0.18685559928417206 / AUROC : 0.9955555555555555 / F1 : 0.9955357142857143 / AP : 0.9962182500030841

Time index : 1 out of 15
tensor([ 31,  36,  39,  40,  41,  43,  45,  47,  48,  49,  53,  58,  63,  67,
         94,  99, 102, 103, 105, 110, 113, 115], device='cuda:0')
Test / Subject : 9 / Timepoint : 1 / Loss : 0.1875666081905365 / AUROC : 0.9934786840568426 / F1 : 0.993435877339081 / AP : 0.9943894257953267

Time index : 2 out of 15
tensor([ 10,  36,  39,  40,  41,  43,  45,  47,  48,  49,  50,  51,  53,  58,
         62,  63,  67,  72,  75,  78,  80,  82,  83,  91,  93,  94,  99, 102,
        103, 105, 113], device='cuda:0')
Test / Subject : 9 / Timepoint : 2 / Loss : 0.1889793425798416 / AUROC : 0.9932196822936846 / F1 : 0.9931733957480008 / AP : 0.9942206006923048

Time index : 3 out of 15
tensor([  1,   2,   7, 

In [None]:
# Calculating the recorded frequencies of all brain regions

connAmyg=torch.load('connDictTestAmygL.pth')
connAmyg = dict(sorted(connAmyg.items(), key=lambda item: item[1], reverse=True))
print(connAmyg)

values=list(connAmyg.values())
values.remove(max(values))
print(torch.quantile(torch.tensor(list(values), dtype=float), 0.75))

connIns=torch.load('connDictTestInsL.pth')
connIns = dict(sorted(connIns.items(), key=lambda item: item[1], reverse=True))
print(connIns)

values=list(connIns.values())
values.remove(max(values))
print(torch.quantile(torch.tensor(list(values), dtype=float), 0.75))

  connAmyg=torch.load('connDictTestAmygL.pth')


{'Amygdala_L': 147, 'Precuneus_R': 76, 'Rolandic_Oper_R': 69, 'Cuneus_R': 66, 'Parietal_Sup_R': 65, 'Precentral_R': 64, 'Temporal_Sup_R': 64, 'Occipital_Mid_R': 63, 'Frontal_Med_Orb_R': 62, 'Amygdala_R': 62, 'Occipital_Sup_L': 61, 'Temporal_Inf_R': 61, 'Cerebelum_3_L': 61, 'Occipital_Sup_R': 60, 'Fusiform_L': 59, 'Frontal_Inf_Oper_R': 59, 'Hippocampus_L': 58, 'Temporal_Mid_R': 58, 'Calcarine_L': 57, 'Parietal_Inf_L': 57, 'Putamen_L': 56, 'Caudate_R': 56, 'Cerebelum_4_5_R': 55, 'Precuneus_L': 55, 'Supp_Motor_Area_R': 54, 'Cingulum_Mid_R': 54, 'Frontal_Inf_Tri_R': 54, 'Olfactory_L': 54, 'Insula_L': 54, 'Calcarine_R': 53, 'Occipital_Inf_R': 53, 'Thalamus_L': 53, 'Parietal_Sup_L': 52, 'Rectus_L': 52, 'ParaHippocampal_L': 52, 'Fusiform_R': 51, 'Lingual_R': 51, 'Postcentral_R': 51, 'Temporal_Inf_L': 51, 'Postcentral_L': 51, 'Vermis_6': 50, 'Cuneus_L': 50, 'Frontal_Sup_Medial_R': 50, 'Frontal_Inf_Orb_L': 49, 'Temporal_Pole_Sup_L': 49, 'Cerebelum_4_5_L': 49, 'Occipital_Inf_L': 49, 'Cerebelum_6

  connIns=torch.load('connDictTestInsL.pth')


{'Insula_L': 144, 'Temporal_Sup_L': 112, 'Rolandic_Oper_L': 106, 'Insula_R': 104, 'Supp_Motor_Area_L': 96, 'Frontal_Inf_Oper_L': 95, 'Postcentral_L': 95, 'Supp_Motor_Area_R': 94, 'Temporal_Pole_Sup_L': 94, 'Rolandic_Oper_R': 92, 'Parietal_Inf_L': 92, 'SupraMarginal_L': 92, 'Frontal_Inf_Tri_L': 91, 'Cingulum_Mid_L': 88, 'SupraMarginal_R': 87, 'Frontal_Mid_L': 86, 'Frontal_Sup_L': 85, 'Parietal_Sup_L': 82, 'Temporal_Sup_R': 81, 'Frontal_Inf_Tri_R': 81, 'Postcentral_R': 80, 'Frontal_Mid_R': 77, 'Cingulum_Ant_L': 74, 'Precentral_R': 73, 'Heschl_L': 73, 'Lingual_R': 71, 'Hippocampus_L': 70, 'Cingulum_Mid_R': 70, 'Occipital_Sup_L': 69, 'Putamen_L': 69, 'Frontal_Mid_Orb_L': 68, 'Cuneus_R': 68, 'Calcarine_L': 68, 'Frontal_Inf_Orb_L': 67, 'ParaHippocampal_L': 67, 'Occipital_Mid_L': 67, 'Putamen_R': 66, 'Frontal_Inf_Oper_R': 66, 'Parietal_Sup_R': 63, 'Parietal_Inf_R': 63, 'Temporal_Inf_R': 62, 'Fusiform_L': 61, 'Thalamus_L': 61, 'Occipital_Sup_R': 61, 'Fusiform_R': 60, 'Frontal_Med_Orb_L': 60, '

The two cells below is for testing the model trained with the 80th Quantile Dataset.

In [None]:
# 80th Quantile Dataset Testing

sampleStrategy='uniform'
negSampleStrategy='inductive'
N=116 # Number of ROIs
numEdges=math.ceil(116*116*0.4)
nodeDim=N # One-Hot encoding as a node embedding
edgeDim=4 # Arbitrarily set
timeDim=100 # Half of the value mentioned in DyGFormer - HOT
channelDim=50 # Half of the value mentioned in DyGFormer - HOT
latentDim=N # Positive node embedding and negative node embedding concatenated across dim=1
numFilter=4 # Following Triadic Decoder paper
B=10000
prob=0.7398
nodeFeat=torch.cat((torch.eye(N),torch.zeros((1,N))),dim=0).numpy()
edgeFeat=torch.zeros((numEdges+1,edgeDim)).numpy()

checkpoint=torch.load('checkpoint9.pth')
wandb.init(
    project="EmotionConnectivity - 80thQuantileTest",
    config={
        "Mode" : "Test",
        "initialization" : "kaiming_normal",
        "sampleStrategy" : sampleStrategy,
        "negSampleStrategy" : negSampleStrategy,
        "B": B,
        "prob": prob,
        "Dropout" : 0.1,
    }
)


encoder=HOT(nodeFeat, edgeFeat, None, timeDim, channelDim, patch_size=8, num_state_vectors=32, num_layers=2, num2hop=0, dropout=0.1, max_input_sequence_length=4096, device='cuda')
encoder.apply(init_kaiming_normal)
decoder=TriadicDecoder(latentDim, numFilter, device='cuda')
decoder.apply(init_kaiming_normal)
model=nn.Sequential(encoder, decoder).to('cuda')
lossFunc=nn.BCELoss().to('cuda')
optimizerE=torch.optim.Adam(model.parameters(), lr=checkpoint['lr'])
data=datasets.fetch_atlas_aal()
connDictAmygL={}
connDictInsL={}


if os.path.isfile('checkpoint9.pth'):
  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
for i in range(0, 10):
  if os.path.isfile('checkpoint9.pth'):
    checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  model.eval()
  if checkpoint['encoder'] is not None: model[0].load_state_dict(checkpoint['encoder'])
  if checkpoint['decoder'] is not None: model[1].load_state_dict(checkpoint['decoder'])
  print("\nSubject index : "+str(i)+" out of 10")

  test403D=torch.load(os.path.join('DatasetSplit','test80'+str(i)+'.pth'))
  test403D=test403D.reshape(-1,3).detach().to('cpu').numpy()
  test403D=test403D.reshape(15, -1, 3)

  print("Sampling probability : "+str(prob))

  lossSubject=torch.zeros(1).to('cuda')
  lossSubject.requires_grad_()
  F1Subject=0
  APSubject=0
  AUROCSubject=0
  for t in range(0, 15):
    print("Time index : "+str(t)+" out of 15")

    timeBatch=test403D[t]
    srcNodes=timeBatch[:,0]
    dstNodes=timeBatch[:,1]
    timepoints=timeBatch[:,2]

    adjMat=torch.zeros((N,N)).to('cuda')
    for row in timeBatch:
      adjMat[int(row[0]), int(row[1])]=1

    nodePosEmbedding=torch.zeros((N,N)).to('cuda')
    nodeNegEmbedding=torch.zeros((N,N)).to('cuda')
    nodePosTimes=torch.zeros(N).to('cuda')
    nodeNegTimes=torch.zeros(N).to('cuda')
    nodeEmbedding=torch.zeros((N,N)).to('cuda')

    # Sample neighbor edges among the dynamic graphs of a single subject
    test_neighbor_sampler = get_neighbor_sampler(data=timeBatch, sample_neighbor_strategy=sampleStrategy, seed=4)
    # Sample negative edges among the dynamic graphs of a single subject
    test_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=srcNodes, dst_node_ids=dstNodes, interact_times=timepoints, last_observed_time=np.min(timepoints), negative_sample_strategy=negSampleStrategy, seed=5)
    model[0].set_neighbor_sampler(test_neighbor_sampler)

    subject_src_node_embeddings, subject_dst_node_embeddings = \
      torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, srcNodes, dstNodes, timepoints, use_reentrant=False)
    negSrcNodes, negDstNodes = test_neg_edge_sampler.sample(len(srcNodes), srcNodes, dstNodes, np.min(timepoints), np.max(timepoints))
    subject_neg_src_node_embeddings, subject_neg_dst_node_embeddings = \
      torch.utils.checkpoint.checkpoint(model[0].compute_src_dst_node_temporal_embeddings, negSrcNodes, negDstNodes, timepoints, use_reentrant=False)

    for tb, row in enumerate(zip(srcNodes, dstNodes)):
      nodePosEmbedding[row[0]]=nodePosEmbedding[row[0]]+subject_src_node_embeddings[tb]
      nodePosTimes[row[0]]+=1
      nodePosEmbedding[row[1]]=nodePosEmbedding[row[1]]+subject_dst_node_embeddings[tb]
      nodePosTimes[row[1]]+=1
    for tb, row in enumerate(zip(negSrcNodes, negDstNodes)):
      nodeNegEmbedding[row[0]]=nodeNegEmbedding[row[0]]+subject_neg_src_node_embeddings[tb]
      nodeNegTimes[row[0]]+=1
      nodeNegEmbedding[row[1]]=nodeNegEmbedding[row[1]]+subject_neg_dst_node_embeddings[tb]
      nodeNegTimes[row[1]]+=1

    for node in range(N):
      if nodePosTimes[node]>0:
        nodePosEmbedding[node]=nodePosEmbedding[node]/nodePosTimes[node]
      if nodeNegTimes[node]>0:
        nodeNegEmbedding[node]=nodeNegEmbedding[node]/nodeNegTimes[node]


    nodePosEmbedding=(nodePosEmbedding-torch.mean(nodePosEmbedding, dim=0, keepdims=True))/(torch.std(nodePosEmbedding, dim=0, keepdims=True)+1e-10)
    nodeNegEmbedding=(nodeNegEmbedding-torch.mean(nodeNegEmbedding, dim=0, keepdims=True))/(torch.std(nodeNegEmbedding, dim=0, keepdims=True)+1e-10)

    graph2D=np.concatenate((np.expand_dims(srcNodes, axis=1), np.expand_dims(dstNodes, axis=1)), axis=1)
    graph2D=torch.tensor(graph2D).to('cuda')
    graph2DN=np.concatenate((np.expand_dims(negSrcNodes, axis=1), np.expand_dims(negDstNodes, axis=1)), axis=1)
    graph2DN=torch.tensor(graph2DN).to('cuda')
    triads=triadSample(graph2D, B, prob, device='cuda')

    mSquare=torch.zeros((N,N)).to('cuda')
    eSquare=torch.zeros((N,N)).to('cuda')

    for j, triad in enumerate(triads):
      triad0=triad[0].unsqueeze(0).unsqueeze(0)
      triad1=triad[1].unsqueeze(0).unsqueeze(0)
      triad2=triad[2].unsqueeze(0).unsqueeze(0)

      if mSquare[triad[0], triad[1]]==0:
        pair1=torch.cat((triad0, triad1), dim=1).to('cuda')
        M1=calculateM(triads, pair1)
        mSquare[triad[0], triad[1]]=M1
      if mSquare[triad[0], triad[2]]==0:
        pair2=torch.cat((triad0, triad2), dim=1).to('cuda')
        M2=calculateM(triads, pair2)
        mSquare[triad[0], triad[2]]=M2
      if mSquare[triad[1], triad[2]]==0:
        pair3=torch.cat((triad1, triad2), dim=1).to('cuda')
        M3=calculateM(triads, pair3)
        mSquare[triad[1], triad[2]]=M3

    for j, triad in enumerate(triads):
      triad0=triad[0].unsqueeze(0).unsqueeze(0)
      triad1=triad[1].unsqueeze(0).unsqueeze(0)
      triad2=triad[2].unsqueeze(0).unsqueeze(0)
      pair1=torch.cat((triad0, triad1), dim=1).squeeze(0).to('cuda')
      pair2=torch.cat((triad0, triad2), dim=1).squeeze(0).to('cuda')
      pair3=torch.cat((triad1, triad2), dim=1).squeeze(0).to('cuda')

      emb1=torch.zeros((1,N)).to('cuda')
      emb2=torch.zeros((1,N)).to('cuda')
      emb3=torch.zeros((1,N)).to('cuda')

      if pairIn(pair1,graph2D[:,0:2]):
        emb1=nodePosEmbedding[triad[0]]
      if pairIn(pair1,graph2DN):
        emb1=nodeNegEmbedding[triad[0]]
      if pairIn(pair2,graph2D[:,0:2]):
        emb2=nodePosEmbedding[triad[1]]
      if pairIn(pair2,graph2DN):
        emb2=nodeNegEmbedding[triad[1]]
      if pairIn(pair3,graph2D[:,0:2]):
        emb3=nodePosEmbedding[triad[2]]
      if pairIn(pair3,graph2DN):
        emb3=nodeNegEmbedding[triad[2]]

      eTriplet=model[1].forward(emb1, emb2, emb3)
      if pairIn(pair1,graph2D[:,0:2]):
        eSquare[triad[0], triad[1]]+=eTriplet[0]
      if pairIn(pair1,graph2DN):
        eSquare[triad[0], triad[1]]-=eTriplet[0]
      if pairIn(pair2,graph2D[:,0:2]):
        eSquare[triad[0], triad[2]]+=eTriplet[1]
      if pairIn(pair2,graph2DN):
        eSquare[triad[0], triad[2]]-=eTriplet[1]
      if pairIn(pair3,graph2D[:,0:2]):
        eSquare[triad[1], triad[2]]+=eTriplet[2]
      if pairIn(pair3,graph2DN):
        eSquare[triad[1], triad[2]]-=eTriplet[2]
    eSquare.clamp_(min=0.0)

    nonzeros=((mSquare>0).nonzero())
    adjRecon=[]
    adjPart=[]
    adjFull=torch.zeros((N,N)).to('cuda')
    for nz in nonzeros:
      nzPair=torch.tensor([int(nz[0]), int(nz[1])]).to('cuda')
      em1=eSquare[int(nz[0]), int(nz[1])]/mSquare[int(nz[0]), int(nz[1])]
      if pairIn(nzPair,graph2D[:,0:2]):
        adjRecon.append(em1)
        adjPart.append(torch.tensor([1.0]))
      if pairIn(nzPair, graph2DN):
        adjRecon.append(em1)
        adjPart.append(torch.tensor([0.0]))
      if em1>0.5:
        adjFull[int(nz[0]), int(nz[1])]=1
      else:
        adjFull[int(nz[0]), int(nz[1])]=0
    adjRecon=torch.stack(adjRecon).to('cuda')
    adjPart=torch.stack(adjPart).squeeze(1).to('cuda')

    loss=lossFunc(input=adjRecon, target=adjPart)
    lossSubject=lossSubject+loss

    adjReconC=adjRecon.clone().detach().cpu().numpy()
    adjPartC=adjPart.clone().detach().cpu().numpy()


    adjReconC[adjReconC>0.5]=1
    adjReconC[adjReconC<=0.5]=0
    F1=f1_score(adjPartC, adjReconC)
    AUROC=roc_auc_score(adjPartC, adjReconC)
    AP=average_precision_score(adjPartC, adjReconC)
    F1Subject+=F1
    APSubject+=AP
    AUROCSubject+=AUROC

    print("Test / Subject : "+str(i)+" / Timepoint : "+str(t) +" / Loss : "+str(loss.item())+" / AUROC : "+str(AUROC)+" / F1 : "+str(F1)+" / AP : "+str(AP)+"\n")

    AmygLConn=torch.nonzero(adjFull[40])[:,0]
    for elem in AmygLConn:
      if data.labels[elem] not in connDictAmygL.keys():
        connDictAmygL[data.labels[elem]]=1
      else:
        connDictAmygL[data.labels[elem]]+=1

    InsLConn=torch.nonzero(adjFull[28])[:,0]
    for elem in InsLConn:
      if data.labels[elem] not in connDictInsL.keys():
        connDictInsL[data.labels[elem]]=1
      else:
        connDictInsL[data.labels[elem]]+=1

  wandb.log({"TLoss": lossSubject.item()/15, "T_AP": APSubject/15, "T_F1": F1Subject/15, "T_AUROC":AUROCSubject/15})
  torch.cuda.empty_cache()

torch.save(connDictAmygL, 'connDictTestAmygL40.pth')
torch.save(connDictInsL, 'connDictTestInsL40.pth')

  checkpoint=torch.load('checkpoint9.pth')


  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')



Subject index : 0 out of 10
Sampling probability : 0.7398
Time index : 0 out of 15


  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


tensor(0.3846, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40, 95], device='cuda:0')
Test / Subject : 0 / Timepoint : 0 / Loss : 0.2437719851732254 / AUROC : 0.9509726861858911 / F1 : 0.9484450873024073 / AP : 0.9582595581295974

Time index : 1 out of 15
tensor(0.3864, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 0 / Timepoint : 1 / Loss : 0.24172593653202057 / AUROC : 0.9549911747401452 / F1 : 0.9528699045076496 / AP : 0.9617882185547824

Time index : 2 out of 15
tensor(0.3859, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 0 / Timepoint : 2 / Loss : 0.2384541630744934 / AUROC : 0.957004160887656 / F1 : 0.9550724637681159 / AP : 0.9632705238184449

Time index : 3 out of 15
tensor(0.3857, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 0 / Timepoint : 3 / Loss : 0.23961371183395386 / AUROC : 0.9555643721483833 / F1 : 0.95349802

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3843, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Test / Subject : 1 / Timepoint : 0 / Loss : 0.2418879270553589 / AUROC : 0.9549076773566569 / F1 : 0.9527783431711785 / AP : 0.9614972437754659

Time index : 1 out of 15
tensor(0.3842, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 1 / Timepoint : 1 / Loss : 0.24273209273815155 / AUROC : 0.952857984678845 / F1 : 0.9505256648113791 / AP : 0.9597700234117441

Time index : 2 out of 15
tensor(0.3857, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 1 / Timepoint : 2 / Loss : 0.24317775666713715 / AUROC : 0.955425219941349 / F1 : 0.9533456108041743 / AP : 0.9621613434497059

Time index : 3 out of 15
tensor(0.3852, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 1 / Timepoint : 3 / Loss : 0.2419394850730896 / 

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3890, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 2 / Timepoint : 0 / Loss : 0.2366773635149002 / AUROC : 0.9636768694304534 / F1 : 0.9623077696250879 / AP : 0.969197077401411

Time index : 1 out of 15
tensor(0.3860, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 3], device='cuda:0')
Test / Subject : 2 / Timepoint : 1 / Loss : 0.2354723960161209 / AUROC : 0.9616888193901485 / F1 : 0.9601626016260163 / AP : 0.9672108043846804

Time index : 2 out of 15
tensor(0.3789, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 2 / Timepoint : 2 / Loss : 0.2415512502193451 / AUROC : 0.9563058589870903 / F1 : 0.9543094496365524 / AP : 0.962217129473617

Time index : 3 out of 15
tensor(0.3822, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 2 / Timepoint : 3 / Loss : 0.24377073347568512

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3850, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  2, 40], device='cuda:0')
Test / Subject : 3 / Timepoint : 0 / Loss : 0.22696414589881897 / AUROC : 0.9678421158084551 / F1 : 0.9667736259186419 / AP : 0.9721366504222704

Time index : 1 out of 15
tensor(0.3861, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 3 / Timepoint : 1 / Loss : 0.2303195595741272 / AUROC : 0.9677876823338736 / F1 : 0.9667155118275068 / AP : 0.9722747299302976

Time index : 2 out of 15
tensor(0.3881, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Test / Subject : 3 / Timepoint : 2 / Loss : 0.22941486537456512 / AUROC : 0.9666797334378675 / F1 : 0.9655312246553123 / AP : 0.9714461701355087

Time index : 3 out of 15
tensor(0.3883, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  3, 40], device='cuda:0')
Test / Subject : 3 / Timepoint : 3 / Loss : 0.2293027639389038 

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3888, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Test / Subject : 4 / Timepoint : 0 / Loss : 0.23789207637310028 / AUROC : 0.9572859387274155 / F1 : 0.9553800389783568 / AP : 0.9637855894131975

Time index : 1 out of 15
tensor(0.3853, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([1], device='cuda:0')
Test / Subject : 4 / Timepoint : 1 / Loss : 0.244545578956604 / AUROC : 0.9532544378698224 / F1 : 0.9509621353196772 / AP : 0.9603480397560011

Time index : 2 out of 15
tensor(0.3820, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 4 / Timepoint : 2 / Loss : 0.2386275678873062 / AUROC : 0.9543564356435643 / F1 : 0.9521734619773835 / AP : 0.9606272857015431

Time index : 3 out of 15
tensor(0.3883, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 4 / Timepoint : 3 / Loss : 0.24052561819553375 / AU

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3892, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 5 / Timepoint : 0 / Loss : 0.24004600942134857 / AUROC : 0.9572498029944838 / F1 : 0.9553406050627701 / AP : 0.9639019964272147

Time index : 1 out of 15
tensor(0.3896, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([40], device='cuda:0')
Test / Subject : 5 / Timepoint : 1 / Loss : 0.23892849683761597 / AUROC : 0.9578517337584695 / F1 : 0.9559970872776449 / AP : 0.9643633087685207

Time index : 2 out of 15
tensor(0.3883, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([  0,   1,  40, 107], device='cuda:0')
Test / Subject : 5 / Timepoint : 2 / Loss : 0.24096301198005676 / AUROC : 0.9550173010380623 / F1 : 0.9528985507246377 / AP : 0.9619308096609549

Time index : 3 out of 15
tensor(0.3884, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 5 / Timepoint : 3 / Loss : 0.24196399748325348 

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3808, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 6 / Timepoint : 0 / Loss : 0.23986676335334778 / AUROC : 0.9553294573643412 / F1 : 0.9532406937823309 / AP : 0.9614802133727126

Time index : 1 out of 15
tensor(0.3839, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 6 / Timepoint : 1 / Loss : 0.24016691744327545 / AUROC : 0.9576337368215541 / F1 : 0.9557594291539245 / AP : 0.9637970498479611

Time index : 2 out of 15
tensor(0.3827, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 6 / Timepoint : 2 / Loss : 0.23482826352119446 / AUROC : 0.9619373776908023 / F1 : 0.9604312887803885 / AP : 0.9671210032971133

Time index : 3 out of 15
tensor(0.3838, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Test / Subject : 6 / Timepoint : 3 / Loss : 0.2395030111074447

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3796, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 7 / Timepoint : 0 / Loss : 0.24345709383487701 / AUROC : 0.9525348414461725 / F1 : 0.9501696352841391 / AP : 0.9590926713980921

Time index : 1 out of 15
tensor(0.3820, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 7 / Timepoint : 1 / Loss : 0.242991641163826 / AUROC : 0.9527308838133068 / F1 : 0.9503856577027309 / AP : 0.9594912795448883

Time index : 2 out of 15
tensor(0.3776, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0, 40], device='cuda:0')
Test / Subject : 7 / Timepoint : 2 / Loss : 0.24260054528713226 / AUROC : 0.9529182879377431 / F1 : 0.9505920783993467 / AP : 0.9591816657024455

Time index : 3 out of 15
tensor(0.3823, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 40], device='cuda:0')
Test / Subject : 7 / Timepoint : 3 / Loss : 0.2417932450771331

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3846, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 8 / Timepoint : 0 / Loss : 0.24377255141735077 / AUROC : 0.9518648701425503 / F1 : 0.9494307108421377 / AP : 0.959023594031175

Time index : 1 out of 15
tensor(0.3814, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([0, 1], device='cuda:0')
Test / Subject : 8 / Timepoint : 1 / Loss : 0.24033187329769135 / AUROC : 0.9544628432956381 / F1 : 0.9522902782185549 / AP : 0.9607588092208591

Time index : 2 out of 15
tensor(0.3927, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 8 / Timepoint : 2 / Loss : 0.249963641166687 / AUROC : 0.9513386750048857 / F1 : 0.9488496302382908 / AP : 0.9598070964821815

Time index : 3 out of 15
tensor(0.3954, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 41], device='cuda:0')
Test / Subject : 8 / Timepoint : 3 / Loss : 0.24873854219913483 /

  checkpoint=torch.load('checkpoint9.pth', map_location='cuda')
  test403D=torch.load(os.path.join('DatasetSplit','test40'+str(i)+'.pth'))


Sampling probability : 0.7398
Time index : 0 out of 15
tensor(0.3868, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 1, 38, 40], device='cuda:0')
Test / Subject : 9 / Timepoint : 0 / Loss : 0.24745668470859528 / AUROC : 0.9503567181926278 / F1 : 0.9477635283077885 / AP : 0.9581478052950344

Time index : 1 out of 15
tensor(0.3871, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 9 / Timepoint : 1 / Loss : 0.24714255332946777 / AUROC : 0.9498818432453722 / F1 : 0.9472374831553851 / AP : 0.957775995585798

Time index : 2 out of 15
tensor(0.3911, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 9 / Timepoint : 2 / Loss : 0.2513614594936371 / AUROC : 0.9482254290171607 / F1 : 0.9453984575835476 / AP : 0.9571088429555002

Time index : 3 out of 15
tensor(0.3884, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([ 0,  1, 40], device='cuda:0')
Test / Subject : 9 / Timepoint : 3 / Loss : 0.24874988

In [None]:
# Calculating the recorded frequencies of all brain regions

connAmyg=torch.load('connDictTestAmygL40.pth')
connAmyg = dict(sorted(connAmyg.items(), key=lambda item: item[1], reverse=True))
print(connAmyg)

values=list(connAmyg.values())
values.remove(max(values))
print(torch.quantile(torch.tensor(list(values), dtype=float), 0.75))

connIns=torch.load('connDictTestInsL40.pth')
connIns = dict(sorted(connIns.items(), key=lambda item: item[1], reverse=True))
print(connIns)

values=list(connIns.values())
values.remove(max(values))
print(torch.quantile(torch.tensor(list(values), dtype=float), 0.75))

{'Amygdala_L': 121, 'Precentral_R': 118, 'Precentral_L': 107, 'Frontal_Sup_L': 22, 'Frontal_Sup_R': 5, 'Temporal_Pole_Sup_L': 3, 'Cerebelum_10_R': 3, 'Frontal_Mid_Orb_R': 2, 'Amygdala_R': 2, 'Olfactory_L': 2, 'Cerebelum_3_R': 1, 'Hippocampus_L': 1, 'Pallidum_L': 1, 'Caudate_R': 1, 'Frontal_Sup_Orb_L': 1, 'Pallidum_R': 1, 'Temporal_Pole_Mid_L': 1, 'Vermis_6': 1, 'ParaHippocampal_L': 1, 'Thalamus_L': 1, 'Cerebelum_3_L': 1}
tensor(3., dtype=torch.float64)
{'Insula_L': 139, 'Rolandic_Oper_L': 104, 'Insula_R': 99, 'Frontal_Mid_L': 96, 'Cingulum_Mid_R': 95, 'Supp_Motor_Area_L': 94, 'Postcentral_L': 94, 'Parietal_Inf_L': 94, 'Temporal_Sup_L': 94, 'Frontal_Inf_Tri_R': 88, 'Supp_Motor_Area_R': 88, 'Frontal_Inf_Oper_L': 85, 'Frontal_Inf_Tri_L': 85, 'Cingulum_Mid_L': 85, 'Calcarine_R': 85, 'SupraMarginal_L': 82, 'Frontal_Mid_R': 82, 'SupraMarginal_R': 81, 'Rolandic_Oper_R': 79, 'Parietal_Sup_L': 79, 'Frontal_Inf_Orb_L': 77, 'Cingulum_Ant_L': 77, 'Lingual_L': 75, 'Temporal_Sup_R': 74, 'Frontal_Sup

  connAmyg=torch.load('connDictTestAmygL40.pth')
  connIns=torch.load('connDictTestInsL40.pth')
