In [48]:
# Optional: eliminating warnings
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

from arguments import arg_parse
from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet
from evaluate_embedding import evaluate_embedding
from gin import Encoder
from losses import local_global_loss_
from model import FF, PriorDiscriminator
from torch import optim
from torch.autograd import Variable
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
import json
import json
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class arg_parse:
    def __init__(self) -> None:
        self.DS = 'KKI'
        self.local = False
        self.glob = False
        self.prior = False
        self.lr = 0.001
        self.num_gc_layers = 2
        self.hidden_dim = 32

args = arg_parse()

In [3]:
class InfoGraph(nn.Module):
  def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
    super(InfoGraph, self).__init__()

    self.alpha = alpha
    self.beta = beta
    self.gamma = gamma
    self.prior = args.prior

    self.embedding_dim = mi_units = hidden_dim * num_gc_layers
    self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)

    self.local_d = FF(self.embedding_dim)
    self.global_d = FF(self.embedding_dim)
    # self.local_d = MI1x1ConvNet(self.embedding_dim, mi_units)
    # self.global_d = MIFCNet(self.embedding_dim, mi_units)

    if self.prior:
        self.prior_d = PriorDiscriminator(self.embedding_dim)

    self.init_emb()

  def init_emb(self):
    initrange = -1.5 / self.embedding_dim
    for m in self.modules():
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

  def forward(self, x, edge_index, batch, num_graphs):
    # batch_size = data.num_graphs
    if x is None:
        x = torch.ones(batch.shape[0]).to(device)

    y, M = self.encoder(x, edge_index, batch)

    g_enc = self.global_d(y)
    l_enc = self.local_d(M)

    mode='fd'
    measure='JSD'
    local_global_loss = local_global_loss_(l_enc, g_enc, edge_index, batch, measure)

    if self.prior:
        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = - (term_a + term_b) * self.gamma
    else:
        PRIOR = 0

    return local_global_loss + PRIOR


In [14]:
from torch_geometric.data import InMemoryDataset

class KKIData(InMemoryDataset):
    def __init__(self, root='../data/KKI/KKI/', transform= None, pre_transform=None, pre_filter = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        out = torch.load(self.processed_paths[0])
        self.data, self.slices, self.sizes = out

    @property
    def raw_file_names(self):
        return ['KKI_A.txt']

    @property
    def processed_file_names(self):
        return ['data.pt']

dataset = KKIData()

In [57]:
import torch
from torch_geometric.data import Data, InMemoryDataset
import pandas as pd


class MyData(InMemoryDataset):
    def __init__(self, root='../mydata/', transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        x = torch.tensor(
            [[0], [0], [0], [1], [1], [1], [1]], dtype=torch.float32
        )
        edge_index = torch.tensor([
            [0, 1, 0, 3, 4, 5],
            [1, 2, 2, 4, 5, 6]
        ], dtype=torch.int64)
        # y = torch.tensor([1,1], dtype=torch.int64)
        self.data = Data(x=x, edge_index=edge_index)
        self.slices = {
            'x': x, 'edge_index': edge_index
        }
mydata = MyData()

In [76]:
myloader = DataLoader(mydata, batch_size=2)
for data in mydata:
    print(data)

ValueError: only one element tensors can be converted to Python scalars

In [75]:
dataloader =  DataLoader(dataset, batch_size=128)
for data in dataloader:
    print(data)

DataBatch(edge_index=[2, 8038], x=[2238, 190], y=[83], batch=[2238], ptr=[84])


In [26]:
# accuracies = {'logreg': [], 'svc': [], 'linearsvc': [], 'randomforest': []}
epochs = 2
log_interval = 1
batch_size = 128
lr = args.lr
DS = args.DS
path = os.path.join(os.getcwd(), '..', 'data', DS)
# kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)

# datapath = path+'/KKI/processed/data.pt'
# dataset = torch.load(datapath)[0]
# dataset = TUDataset(path, name=DS).shuffle()
dataset_num_features = max(dataset.num_features, 1)
dataloader = DataLoader(dataset, batch_size=batch_size)

In [21]:
print('================')
print('lr: {}'.format(lr))
print('num_features: {}'.format(dataset_num_features))
print('hidden_dim: {}'.format(args.hidden_dim))
print('num_gc_layers: {}'.format(args.num_gc_layers))
print('================')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = InfoGraph(args.hidden_dim, args.num_gc_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

model.eval()
emb = model.encoder.get_embeddings(dataloader)
print('===== Before training =====')
# res = evaluate_embedding(emb)
# accuracies['logreg'].append(res[0])
# accuracies['svc'].append(res[1])
# accuracies['linearsvc'].append(res[2])
# accuracies['randomforest'].append(res[3])
# print(accuracies)
print(f"embedding length:\n{emb}")

lr: 0.001
num_features: 190
hidden_dim: 32
num_gc_layers: 2
===== Before training =====
embedding length:
[[3.17232490e+00 9.59328473e-01 1.80678892e+00 ... 3.30266058e-01
  5.39329338e+00 1.51378889e+01]
 [2.03503036e+00 7.26320297e-02 2.29423451e+00 ... 2.40086228e-01
  1.54712951e+00 5.59823322e+00]
 [1.56246972e+00 0.00000000e+00 4.64691455e-03 ... 0.00000000e+00
  2.67243886e+00 5.49802876e+00]
 ...
 [3.01488191e-01 7.46013736e-03 1.05110705e-02 ... 0.00000000e+00
  2.12603331e+00 2.15811729e-01]
 [1.23019493e+00 3.08957040e-01 6.29850149e-01 ... 0.00000000e+00
  7.20523894e-01 6.69880137e-02]
 [3.18962075e-02 5.05512990e-02 4.75434586e-02 ... 0.00000000e+00
  2.47058123e-01 0.00000000e+00]]


In [22]:
# unsupervised train
for epoch in range(1, epochs+1):
    loss_all = 0
    model.train()
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        loss = model(data.x, data.edge_index, data.batch, data.num_graphs)
        loss_all += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    print('===== Epoch {}, Loss {} ====='.format(
        epoch, loss_all / len(dataloader)))

    if epoch % log_interval == 0:
        model.eval()
        emb = model.encoder.get_embeddings(dataloader)
        print(f"embedding shape:\n{len(emb)}x{len(emb[0])}")

        # res = evaluate_embedding(emb, y)
        # accuracies['logreg'].append(res[0])
        # accuracies['svc'].append(res[1])
        # accuracies['linearsvc'].append(res[2])
        # accuracies['randomforest'].append(res[3])
        # print(accuracies)



===== Epoch 1, Loss 6082.568550109863 =====
embedding shape:
83x64
===== Epoch 2, Loss 4675.409149169922 =====
embedding shape:
83x64
