In [1]:

import os
import sys
import re
import random
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import SequentialSampler

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.nn import GCNConv, GINConv, TopKPooling
from torch_geometric.nn import (
    global_mean_pool as gap,
    global_max_pool as gmp,
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, mean_squared_error as MSE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
GPU_NUM = 3
torch.cuda.set_device(GPU_NUM)

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(
        f"90server: {torch.cuda.get_device_name()} - cuda({torch.cuda.current_device()}) v{torch.version.cuda} is available"
    )
    print(f"Torch version: {torch.__version__}")
    print(f"Count of using GPUs:", torch.cuda.device_count())
    print()
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
else:
    print("Can not use GPU device!")
    device = torch.device("cpu")

90server: TITAN Xp - cuda(3) v10.2 is available
Torch version: 1.9.1+cu102
Count of using GPUs: 8



In [3]:
PATH        = "/nasdata4/pei4/feature_selection_gnn/experiments"
suffix      = 'Y'
saving_path = PATH + "/experiment_" + suffix

In [4]:
# data_path = '/nasdata4/pei4/feature_selection_gnn/sample_data/'

# # 데이터 파일 읽어오기
# subjects_list = open(data_path + 'subjects_list.txt', 'r')
# subjects_list = subjects_list.read()
# subjects_list = subjects_list.split('\n')

# subjects_diag = open(data_path + 'subjects_diagnosis.txt', 'r')
# subjects_diag = subjects_diag.read()
# subjects_diag = subjects_diag.split('\n')

# snp_Additive    = pd.read_csv(data_path + 'Data01_SNPs_Additive.csv')
# snp_GENEgroups  = pd.read_csv(data_path + 'Data02_SNPs_GENEgroups.csv')
# subcor_volumes  = pd.read_csv(data_path + 'Data04_Subcortical_volumes_ICVcorrected.csv')
# normal_subcor_volumes  = pd.read_csv(data_path + 'Data04_Subcortical_volumes_ICVcorrected_Normalized.csv')

# # 데이터 합치기
# smaple_data_features = []
# for sbj, label in zip(subjects_list, subjects_diag):
#     edge_feature5       = np.loadtxt(f'{data_path}edge_index5/{sbj}_edge_index.txt', dtype=int)
#     edge_feature10      = np.loadtxt(f'{data_path}edge_index10/{sbj}_edge_index.txt', dtype=int)
#     edge_feature15      = np.loadtxt(f'{data_path}edge_index15/{sbj}_edge_index.txt', dtype=int)
#     edge_feature20      = np.loadtxt(f'{data_path}edge_index20/{sbj}_edge_index.txt', dtype=int)
#     node_feature_pcc    = np.loadtxt(f'{data_path}x_partialCC/pcc_ROISignals_{sbj}.txt', delimiter=",")
#     node_feature_bold   = np.loadtxt(f'{data_path}x_roi_signal/ROISignals_{sbj}.txt')

#     graph_label = int(label[-1])

#     temp = snp_Additive.loc[snp_Additive['IID'] == sbj]
#     snp_data = temp.get(temp.columns[6:]).to_numpy()    # 3001개의 SNP값만 쏙 읽어오기

#     temp = subcor_volumes.loc[subcor_volumes['IID'] == sbj]
#     subcor_data = temp.get(temp.columns[1:]).to_numpy(dtype=float)  # 8개의 볼륨값만 쏙 읽어오기

#     temp = normal_subcor_volumes.loc[normal_subcor_volumes['IID'] == sbj]
#     normal_subcor_data = temp.get(temp.columns[1:]).to_numpy(dtype=float)  # 8개의 볼륨값만 쏙 읽어오기

#     smaple_data_features.append([sbj,                   # [0]
#                                  node_feature_pcc,      # [1]
#                                  node_feature_bold,     # [2]
#                                  edge_feature5,         # [3]
#                                  edge_feature10,        # [4]
#                                  edge_feature15,        # [5]
#                                  edge_feature20,        # [6]
#                                  graph_label,           # [7]
#                                  snp_data,              # [8]
#                                  subcor_data,           # [9]
#                                  normal_subcor_data     # [10]
#                                  ])

# # 데이터 저장
# smaple_data_features = np.asarray(smaple_data_features, dtype=object)
# np.save(data_path + 'smaple_data_features.npy', smaple_data_features)

# Define functions and classes

In [60]:
class my_dataset(Dataset):
    def __init__(self, data_path, node_feature_type="pcor", edge_sparse_type=5, snp_data_type="real", qt_data_type='norm', num_fold=5):
        super(my_dataset, self).__init__()
        self.node_feature_type  = node_feature_type
        self.edge_sparse_type   = edge_sparse_type
        self.snp_data_type      = snp_data_type
        self.qt_data_type       = qt_data_type

        snp_Additive    = pd.read_csv(data_path + 'Data01_SNPs_Additive.csv')   # snp additive 정보 읽어오기
        snp_GENEgroups  = pd.read_csv(data_path + 'Data02_SNPs_GENEgroups.csv') # snp gene group 정보 읽어오기
        add_snp_list    = snp_Additive.columns[6:].to_list()                    # csv파일에서 FID, IID, PAT, MAT, SEX, PHENOTYPE 제외한 SNP정보만 추출
        gene_group      = dict(zip(list(snp_GENEgroups.SNPID), list(snp_GENEgroups.GENEgroup))) # {SNP명 : group 번호} 생성
        # gene_group_list = list(snp_GENEgroups.SNPID)    # group 번호 정보를 담은 list data 생성

        self.snp_group = []
        for snp_name in add_snp_list:
            self.snp_group.append(gene_group[snp_name[:-2]])    # SNP명을 key로 받는 dictionary를 이용해서 add_snp_list의 SNP명 순서로 group 번호 할당
        self.snp_group_data = torch.tensor(self.snp_group, dtype=torch.long)

        self.my_Data = []
        self.num_sbj    = 157
        self.num_QT     = 8
        self.num_snp    = 3001

        smaple_data_features = np.load(data_path + 'smaple_data_features.npy', allow_pickle=True)
        for subject_data in smaple_data_features:
            if node_feature_type == "pcor":
                node_feature = subject_data[1]
            else:   # "bold"
                node_feature = subject_data[2]

            if edge_sparse_type == 5:
                edge_idx = subject_data[3]
            elif edge_sparse_type == 10:
                edge_idx = subject_data[4]
            elif edge_sparse_type == 15:
                edge_idx = subject_data[5]
            else:
                edge_idx = subject_data[6]

            graph_label = subject_data[7]

            if self.snp_data_type == "real":
                snp_data = subject_data[8]
            else:
                snp_data = torch.randint(0, 3, size=(157, 3001)).to(device=device)
                torch.save(snp_data, data_path + "/fake_snp_data.pt")

            if self.qt_data_type == "raw":
                t1_measure  = subject_data[9]
            else:
                t1_measure  = subject_data[10]

            # numpy to tensor
            node_feature    = torch.Tensor(node_feature).transpose(0, 1)
            graph_label     = torch.tensor(graph_label, dtype=torch.long)
            edge_flag       = generate_edge_flag(116, edge_idx)
            edge_idx        = torch.Tensor(edge_idx).transpose(0, 1)
            # edge_flag       = generate_edge_flag(116, edge_idx)
            edge_flag       = torch.Tensor(edge_flag).to(device=device)
            edge_attr       = torch.randint(0, 3, size=(1, 672)).to(device=device)
            edge_attr       = edge_attr.squeeze(0)
            # print("edge_idx: ", edge_idx.shape)
            # print("edge_attr: ", edge_attr.shape)
            fc_data         = Data(x=node_feature,
                                   edge_index=edge_idx.long(),
                                   y=graph_label,
                                   edge_attr=edge_attr,
                                   edge_flag=edge_flag.long())

            snp_data    = torch.Tensor(snp_data).to(device=device)
            t1_measure  = torch.Tensor(t1_measure).to(device=device)
            snp_data    = torch.squeeze(snp_data)
            t1_measure  = torch.squeeze(t1_measure)

            self.my_Data.append([fc_data, snp_data, t1_measure])

        rand_state = 42
        self.num_fold = num_fold
        skf = StratifiedKFold(n_splits=self.num_fold, shuffle=True, random_state=rand_state)

        y_idx = []
        for i in range(self.num_sbj):
            y_idx.append(self.my_Data[i][0].y.item())

        self.train_dataset  = []
        self.test_dataset   = []
        self.y_train = []
        for train_idx, test_idx in skf.split(range(self.num_sbj), y_idx):
            for i in train_idx:
                self.train_dataset.append([self.my_Data[i][0], self.my_Data[i][1], self.my_Data[i][2]])
                self.y_train.append(self.my_Data[i][0].y.item())
            for j in test_idx:
                self.test_dataset.append([self.my_Data[j][0], self.my_Data[j][1], self.my_Data[j][2]])
            break

    # 인덱스에 해당되는 데이터를 tensor 형태로 반환
    def __getitem__(self, idx):
        fc_map = self.my_Data[idx][0]
        minor_allele_cnt = self.my_Data[idx][1]
        t1_meausre = self.my_Data[idx][2]

        return fc_map, minor_allele_cnt, t1_meausre

    # 데이터 총 개수 반환
    def __len__(self):
        return len(self.num_sbj)

In [69]:
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_sparse import SparseTensor
from torch_geometric.typing import Adj, Size

import numpy as np
import torch
from torch.nn import Parameter, Linear
from torch.nn import functional as F
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.utils.num_nodes import maybe_num_nodes
import argparse
import sys

from typing import Tuple, Optional
from torch import Tensor
from sklearn import metrics
import torch.nn.functional as F
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold
from torch_geometric.data import DataLoader
import nni
import os
import random
from typing import List

def maybe_num_nodes(edge_index, num_nodes=None):
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1
    else:
        return max(edge_index.size(0), edge_index.size(1))

def _remove_self_loops(edge_index, edge_attr: torch.Tensor, edge_flags: torch.Tensor):
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    return edge_index, edge_attr[mask], edge_flags[mask]

def _add_self_loops(edge_index, edge_weight: Optional[torch.Tensor] = None,
                    edge_flags: Optional[torch.Tensor] = None,
                    fill_value: float = 1., num_nodes: Optional[int] = None):
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        loop_weight = edge_weight.new_full((N,), fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)
    if edge_flags is not None:
        assert edge_flags.numel() == edge_index.size(1)
        loop_weight = edge_flags.new_full((N,), fill_value)
        edge_flags = torch.cat([edge_flags, loop_weight], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight, edge_flags

def generate_edge_flag(num_nodes, edge_index):
    
    edge_flag = np.full((num_nodes**2, ), False)
    # print(edge_index.shape)
    # print(edge_index)
    for i in range(edge_index.shape[1]):
        source = edge_index[0][i]
        # print("==>> source: ", source)
        target = edge_index[1][i]
        # print("==>> target: ", target)
        new_index = source * num_nodes + target
        # print("==>> new_index: ", new_index.shape)
        edge_flag[int(new_index)] = True
    # print(edge_flag.shape)
    return edge_flag

class ModifiedMessagePassing(MessagePassing):
    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        raise NotImplementedError

    def propagate(self, edge_index: Adj, pruned_edge_mask, size: Size = None, **kwargs):
        size = self.__check_input__(edge_index, size)

        # Run "fused" message and aggregation (if applicable).
        if (isinstance(edge_index, SparseTensor) and self.fuse
                and not self.__explain__):
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                         size, kwargs)

            msg_aggr_kwargs = self.inspector.distribute(
                'message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

        # Otherwise, run both functions in separation.
        elif isinstance(edge_index, Tensor) or not self.fuse:
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)
            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)

            # For `GNNExplainer`, we require a separate message and aggregate
            # procedure since this allows us to inject the `edge_mask` into the
            # message passing computation scheme.
            if self.__explain__:
                # edge_mask = self.__edge_mask__.sigmoid()
                assert out.size(self.node_dim) == pruned_edge_mask.size(0)
                out = out * pruned_edge_mask.view([-1] + [1] * (out.dim() - 1))

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)


In [70]:
def gcn_norm(edge_index, edge_flag, edge_attr=None, num_nodes=None, improved=False,
             do_add_self_loops=True, dtype=None):
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if do_add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        deg = sparsesum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t
    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1),), dtype=dtype,
                                   device=edge_index.device)
        if do_add_self_loops:
            if isinstance(edge_index, Tensor):
                if isinstance(edge_flag, Tensor):
                    edge_index, edge_attr, edge_flag = _remove_self_loops(edge_index, edge_attr, edge_flag)
                    edge_index, edge_attr, edge_flag = _add_self_loops(edge_index, edge_attr, edge_flag,
                                                                       num_nodes=num_nodes)
                else:
                    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
                    edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        row, col = edge_index[0], edge_index[1]
        deg = scatter_add(edge_attr, col, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_attr * deg_inv_sqrt[col], edge_flag


class MPConv(ModifiedMessagePassing):
    def __init__(self, in_channels, out_channels, improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True, bias: bool = True):
        super(MPConv, self).__init__(aggr='add')

        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self._cached_edge_index = None
        self._cached_adj_t = None
        self.__explain__ = False

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        self.lin = torch.nn.Linear(out_channels*2 + 1, out_channels)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x, edge_index, edge_attr, edge_flag):
        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight, edge_flag = gcn_norm(  # yapf: disable
                        edge_index, edge_flag, edge_attr, x.size(self.node_dim),
                        self.improved, self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight, edge_flag)
                else:
                    edge_index, edge_weight, edge_flag = cache[0], cache[1], cache[2]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_attr, x.size(self.node_dim),
                        self.improved, self.add_self_loops)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = x @ self.weight

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, edge_flag, x=x, edge_attr=edge_weight)

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_i, x_j, edge_attr):
        msg = torch.cat([x_i, x_j, edge_attr.view(-1, 1)], dim=1)
        return self.lin(msg)


class IBGConv(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super(IBGConv, self).__init__()
        self.activation = torch.nn.ReLU()
        self.convs = torch.nn.ModuleList()

        hidden_dim = 16
        num_layers = 2
        self.pooling = 'sum'

        for i in range(num_layers):
            if i == 0:
                conv = MPConv(input_dim, hidden_dim)
            elif i != num_layers - 1:
                conv = MPConv(hidden_dim, hidden_dim)
            else:
                conv = MPConv(hidden_dim, num_classes)
            self.convs.append(conv)

    def forward(self, x, edge_index, edge_attr, edge_flag, batch):
        z = x
        edge_attr[edge_attr < 0] = - edge_attr[edge_attr < 0]
        for i, conv in enumerate(self.convs):
            z = conv(z, edge_index, edge_attr, edge_flag)
            if i != len(self.convs) - 1:
                z = F.relu(z)  # [N * M, F]
                z = F.dropout(z, training=self.training)
            if self.pooling == 'sum':
                g = global_add_pool(z, batch)  # [N, F]
            elif self.pooling == 'mean':
                g = global_mean_pool(z, batch)  # [N, F]
            else:
                raise NotImplementedError('Pooling method not implemented')

        return F.log_softmax(g, dim=-1)


class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, activation, n_classes=0):
        super(MLP, self).__init__()
        self.net = []
        self.net.append(torch.nn.Linear(input_dim, hidden_dim))
        self.net.append(activation())
        for _ in range(num_layers - 1):
            self.net.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.net.append(activation())
        self.net = torch.nn.Sequential(*self.net)
        self.shortcut = torch.nn.Linear(input_dim, hidden_dim)

        if n_classes != 0:
            self.classifier = torch.nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        out = self.net(x) + self.shortcut(x)
        if hasattr(self, 'classifier'):
            return out, self.classifier(out)
        return out


class IBGNN(torch.nn.Module):
    def __init__(self, gnn, mlp, discriminator=lambda x, y: x @ y.t(), pooling='concat'):
        super(IBGNN, self).__init__()
        self.gnn = gnn
        self.mlp = mlp
        self.pooling = pooling
        self.discriminator = discriminator

    def forward(self, data):
        x, edge_index, edge_attr, batch, edge_flag = data.x, data.edge_index, data.edge_attr, data.batch, data.edge_flag
        g = self.gnn(x, edge_index, edge_attr, edge_flag, batch)
        if self.pooling == 'concat':
            _, g = self.mlp(g)
            log_logits = F.log_softmax(g, dim=-1)
            return log_logits
        return g


def build_model(device, num_features):
    model = IBGNN(IBGConv(num_features, num_classes=2),
                  MLP(16, 16, 1, torch.nn.ReLU, n_classes=2),
                  pooling='sum').to(device)

    return model


In [71]:
node_feature_type   = 'bold'
edge_sparse_type    = 5
snp_data_type       = 'real'
QT_type             = 'norm'
data_path           = '/nasdata4/pei4/feature_selection_gnn/sample_data/'

dataset = my_dataset(
    data_path, node_feature_type=node_feature_type, edge_sparse_type=edge_sparse_type, snp_data_type=snp_data_type
)
num_features    = dataset.train_dataset[0][0].x.shape[1]
# define hyper parameter
learning_rate   = 0.0001
epochs          = 1
batch_size      = 1

train_loader    = DataLoader(dataset.train_dataset, batch_size=batch_size, drop_last=True)
net             = build_model(device, num_features)
criterion       = nn.MSELoss()
optimizer       = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.001)




In [72]:
temp = dataset.train_dataset[0][0].edge_flag
print(temp.shape)
# for i in temp:
#     print(i)

torch.Size([13456])


In [73]:
for epoch in range(epochs):
    ###################
    # train the model #
    ###################
    net.train()
    for i, data in enumerate(train_loader, 0):
        print(type(data))
        print(len(data))
        train_output = net(data[0])
        train_loss = criterion(train_output, data)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    # print message
    epoch_len = len(str(epochs))
    print_msg = (f"[{epoch:>{epoch_len}}/{epochs:>{epoch_len}}] {train_loss}")
    print(print_msg)

print("Finished Training")
print()

<class 'list'>
3


IndexError: The shape of the mask [672] at index 0 does not match the shape of the indexed tensor [13456] at index 0

In [None]:
edge_flag = torch.rand(size=(16, 1, 5))
edge_mask = torch.rand(size=(1, 5))
edge_mask = edge_mask.squeeze(0)
print(edge_flag)
print(edge_mask)
print()
catted_edge_mask = torch.cat(len(edge_flag) * [edge_mask]) # -> edge flag 에 동일하게 마스크 씌워주기 위해서 edge
print("==>> catted_edge_mask: ", catted_edge_mask.shape)
print(catted_edge_mask)
if (len(edge_flag) != 1):
    edge_flag = [i[0] for i in edge_flag]
    # print("==>> edge_flag: ", edge_flag)
    print("==>> iter edge_flag: ", len(edge_flag))
    print(edge_flag)
    edge_flag = np.concatenate(edge_flag)
    print("==>> concat edge_flag: ", edge_flag.shape)
    print(edge_flag)
pruned_edge_mask = catted_edge_mask[edge_flag]
print("==>> pruned_edge_mask: ", pruned_edge_mask.shape)
print(pruned_edge_mask)
print(catted_edge_mask)

tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904]],

        [[0.6009, 0.2566, 0.7936, 0.9408, 0.1332]],

        [[0.9346, 0.5936, 0.8694, 0.5677, 0.7411]],

        [[0.4294, 0.8854, 0.5739, 0.2666, 0.6274]],

        [[0.2696, 0.4414, 0.2969, 0.8317, 0.1053]],

        [[0.2695, 0.3588, 0.1994, 0.5472, 0.0062]],

        [[0.9516, 0.0753, 0.8860, 0.5832, 0.3376]],

        [[0.8090, 0.5779, 0.9040, 0.5547, 0.3423]],

        [[0.6343, 0.3644, 0.7104, 0.9464, 0.7890]],

        [[0.2814, 0.7886, 0.5895, 0.7539, 0.1952]],

        [[0.0050, 0.3068, 0.1165, 0.9103, 0.6440]],

        [[0.7071, 0.6581, 0.4913, 0.8913, 0.1447]],

        [[0.5315, 0.1587, 0.6542, 0.3278, 0.6532]],

        [[0.3958, 0.9147, 0.2036, 0.2018, 0.2018]],

        [[0.9497, 0.6666, 0.9811, 0.0874, 0.0041]],

        [[0.1088, 0.1637, 0.7025, 0.6790, 0.9155]]])
tensor([0.2418, 0.1591, 0.7653, 0.2979, 0.8035])

==>> catted_edge_mask:  torch.Size([80])
tensor([0.2418, 0.1591, 0.7653, 0.2979, 0.8035, 0.2418, 0.159

In [20]:
M1 = np.asarray([[0.7, 0.2, 0.1],
                [0.2, 0.7, 0.1],
                [0.1, 0.2, 0.7]])

M2 = np.asarray([[0.8, 0.1, 0.1],
                [0.1, 0.8, 0.1],
                [0.1, 0.1, 0.8]])

M3 = np.asarray([[1, 0, 0],
                [0, 1, 0],
                [0, 0, 1]])

def loss_test(M):
    EPS = 1E-10
    term_1 = -1 * M * np.log10(M + EPS)
    term_2 = (1-M) * np.log10(1-M + EPS)
    print("==>> term_1: \n", term_1)
    print()
    print("==>> term_2: \n", term_2)
    print()
    print("loss : \n", term_1 + term_2)
    print()

loss_test(M1)
loss_test(M2)
loss_test(M3)


==>> term_1: 
 [[0.10843137 0.139794   0.1       ]
 [0.139794   0.10843137 0.1       ]
 [0.1        0.139794   0.10843137]]

==>> term_2: 
 [[-0.15686362 -0.07752801 -0.04118174]
 [-0.07752801 -0.15686362 -0.04118174]
 [-0.04118174 -0.07752801 -0.15686362]]

loss : 
 [[-0.04843225  0.06226599  0.05881826]
 [ 0.06226599 -0.04843225  0.05881826]
 [ 0.05881826  0.06226599 -0.04843225]]

==>> term_1: 
 [[0.07752801 0.1        0.1       ]
 [0.1        0.07752801 0.1       ]
 [0.1        0.1        0.07752801]]

==>> term_2: 
 [[-0.139794   -0.04118174 -0.04118174]
 [-0.04118174 -0.139794   -0.04118174]
 [-0.04118174 -0.04118174 -0.139794  ]]

loss : 
 [[-0.06226599  0.05881826  0.05881826]
 [ 0.05881826 -0.06226599  0.05881826]
 [ 0.05881826  0.05881826 -0.06226599]]

==>> term_1: 
 [[-4.34294518e-11 -0.00000000e+00 -0.00000000e+00]
 [-0.00000000e+00 -4.34294518e-11 -0.00000000e+00]
 [-0.00000000e+00 -0.00000000e+00 -4.34294518e-11]]

==>> term_2: 
 [[-0.00000000e+00  4.34294518e-11  4.3429

In [None]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

M1 = np.asarray([[0.7, 0.2, 0.1],
                [0.2, 0.7, 0.1],
                [0.1, 0.2, 0.7]])

for i in range(0, 100):
    M1 = sigmoid(M1)

print(M1)

In [75]:
torch.ones(100).shape

torch.Size([100])