In [1]:
import spektral
from spektral.transforms import GCNFilter
from skimage.transform import resize
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
from IPython.display import display, clear_output
import wandb

In [2]:
numdat = 10000
adddata = 40000
n_total = numdat + adddata  # every per-sample file below must hold exactly this many samples

# NOTE(review): se_s.bin / vms_s.bin are read relative to the working directory while
# the other inputs live under DATA_DIR — confirm this is intended.
DATA_DIR = '/root/Non-iterative/ver.6_to128/data'


def _load_grid_stack(path, n_samples):
    """Load a float32 .bin laid out as 32 rows x (32 * n_samples) columns.

    Samples sit side by side along the columns. The result has shape
    (n_samples, 32, 32, 1): sample i is taken from columns 32*i .. 32*i+31 and
    transposed, so output row j of sample i is source column 32*i + j.
    """
    with open(path, 'rb') as f:
        flat = np.fromfile(f, dtype='float32', count=-1)
    return (flat.reshape(32, 32 * n_samples)
                .reshape(-1, 32, 32 * n_samples, 1)
                .transpose([0, 2, 1, 3])
                .reshape(-1, 32, 32, 1))


def _load_volumes(path, n_samples):
    """Load one float32 value per sample, shaped (n_samples, 1, 1, 1) for broadcasting."""
    with open(path, 'rb') as f:
        return np.fromfile(f, dtype='float32', count=-1).reshape(n_samples, 1, 1, 1)


# All three grid files use the same n_total (instead of a hard-coded 50000 for some
# of them), so changing numdat/adddata cannot desynchronise features and labels.
st_ene = _load_grid_stack('se_s.bin', n_total)
von_st = _load_grid_stack('vms_s.bin', n_total)
optimal32 = _load_grid_stack(os.path.join(DATA_DIR, 'opt_50k_32.bin'), n_total)

# The 10k volume file comes first, followed by the 40k volume file.
v = np.concatenate((_load_volumes(os.path.join(DATA_DIR, 'vol_10k_128test.bin'), numdat),
                    _load_volumes(os.path.join(DATA_DIR, 'vol_40k_128.bin'), adddata)),
                   axis=0)
assert st_ene.shape[0] == von_st.shape[0] == optimal32.shape[0] == v.shape[0] == n_total, \
    "input files disagree on the number of samples"

# Broadcast each per-sample volume into a constant field with the same shape as st_ene.
v = np.ones_like(st_ene) * v

In [6]:
# Flatten every 32x32 grid into a 1024-node signal with a single feature channel.
st_ene_flat, von_st_flat, v_flat = (
    np.reshape(arr, (arr.shape[0], 32 * 32, 1)) for arr in (st_ene, von_st, v)
)

In [7]:
# Features and labels must agree on the sample axis before they are paired into graphs;
# a mismatch would otherwise only surface much later (or silently misalign samples).
print(st_ene_flat.shape, von_st_flat.shape, v_flat.shape, optimal32.shape)
assert st_ene_flat.shape[0] == von_st_flat.shape[0] == v_flat.shape[0] == optimal32.shape[0], \
    "feature and label arrays have different sample counts"

(50000, 1024, 1) (50000, 1024, 1) (50000, 1024, 1) (50000, 32, 32, 1)


In [8]:
# Stack the three per-node inputs along the last (feature) axis -> (samples, 1024, 3).
node_features = np.concatenate([st_ene_flat, von_st_flat, v_flat], axis=-1)
# Per-node regression target -> (samples, 1024, 1).
node_label = optimal32.reshape(optimal32.shape[0], 32 * 32, 1)

In [9]:
# Read the 1024x1024 float32 adjacency matrix and keep it in sparse COO form.
with open('adj_32_32.bin', 'rb') as adj:
    A = coo_matrix(np.fromfile(adj, dtype='float32', count=-1).reshape(1024, 1024))

In [11]:
class TOP_GDATA(spektral.data.Dataset):
    """Spektral dataset built from per-sample node arrays and one shared adjacency.

    Args:
        nf: node features, indexed as ``nf[i, :, :]`` for sample ``i``.
        nl: node labels, indexed the same way as ``nf``.
        adj: adjacency matrix attached (by reference) to every graph.
        **kwargs: forwarded to ``spektral.data.Dataset``.
    """

    def __init__(self, nf, nl, adj, **kwargs):
        # Stored before calling the base initializer — presumably it invokes read(),
        # which needs these attributes.
        self.nf = nf
        self.nl = nl
        self.adj = adj
        super().__init__(**kwargs)

    def read(self):
        """Return a list with one ``Graph`` per sample along the first axis of ``self.nf``."""
        return [
            spektral.data.graph.Graph(x=self.nf[idx, :, :], a=self.adj, y=self.nl[idx, :, :])
            for idx in range(self.nf.shape[0])
        ]

In [12]:
# Build one graph per sample, then apply the GCNFilter transform to every graph.
dataset = TOP_GDATA(nf=node_features, nl=node_label, adj=A)
dataset.apply(GCNFilter())

In [15]:
# First 40,000 graphs for training, the remainder for evaluation.
# NOTE(review): the loading cell concatenates the 10k "vol_..._128test" volume file
# *before* the 40k file, so indices [0, 10000) — presumably the intended test samples —
# fall into this training split. Confirm the intended ordering before trusting test metrics.
train_data, test_data = dataset[:40000], dataset[40000:]

In [18]:
# loader_tr keeps BatchLoader's default shuffling for training.
# loader_va iterates test_data in order (shuffle=False) so the i-th prediction returned
# by predict() corresponds to test_data[i]; otherwise predictions could not be matched
# back to their ground truth. The mean validation loss does not depend on the order.
loader_tr = spektral.data.BatchLoader(train_data, batch_size=64)
loader_va = spektral.data.BatchLoader(test_data, batch_size=64, shuffle=False)

In [19]:
class top_model(tf.keras.models.Model):
    """Six-block graph-convolutional regressor producing one value per node.

    Blocks 1-5: GCNConv(n_hidden) -> BatchNormalization -> ReLU.
    Block 6:    GCNConv(1)        -> BatchNormalization -> sigmoid.

    Args:
        n_hidden: number of output channels of each hidden GCNConv layer.

    ``call`` expects ``inputs`` to be indexable as ``(x, a)``; the adjacency
    ``inputs[1]`` is fed to every graph convolution.
    """

    def __init__(self, n_hidden):
        super().__init__()
        # Hidden blocks. Attribute names are unchanged so that any saved weights keyed
        # by attribute name still line up.
        self.graph_conv1 = spektral.layers.GCNConv(channels=n_hidden)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.act1 = tf.keras.layers.Activation('relu')
        self.graph_conv2 = spektral.layers.GCNConv(channels=n_hidden)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.act2 = tf.keras.layers.Activation('relu')
        self.graph_conv3 = spektral.layers.GCNConv(channels=n_hidden)
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.act3 = tf.keras.layers.Activation('relu')
        self.graph_conv4 = spektral.layers.GCNConv(channels=n_hidden)
        self.bn4 = tf.keras.layers.BatchNormalization()
        self.act4 = tf.keras.layers.Activation('relu')
        self.graph_conv5 = spektral.layers.GCNConv(channels=n_hidden)
        self.bn5 = tf.keras.layers.BatchNormalization()
        self.act5 = tf.keras.layers.Activation('relu')
        # Output block. graph_conv6/bn6/act6 are assigned exactly once: assigning them
        # twice would silently replace the first layer objects with the second.
        self.graph_conv6 = spektral.layers.GCNConv(channels=1)
        self.bn6 = tf.keras.layers.BatchNormalization()
        self.act6 = tf.keras.layers.Activation('sigmoid')

    def call(self, inputs):
        """Forward pass; returns the sigmoid output of the final block."""
        a = inputs[1]
        # The first convolution receives the original (x, a) inputs unchanged.
        h = self.act1(self.bn1(self.graph_conv1(inputs)))
        # Remaining blocks reuse the same adjacency. Intermediates stay local rather
        # than being stored on self, which served no purpose and leaked tensors out of call().
        for conv, bn, act in (
            (self.graph_conv2, self.bn2, self.act2),
            (self.graph_conv3, self.bn3, self.act3),
            (self.graph_conv4, self.bn4, self.act4),
            (self.graph_conv5, self.bn5, self.act5),
            (self.graph_conv6, self.bn6, self.act6),
        ):
            h = act(bn(conv([h, a])))
        return h

In [20]:
# 64 hidden channels per GCN layer; Adam with default settings; mean absolute error loss.
model = top_model(n_hidden=64)
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.MeanAbsoluteError()

In [21]:
def train_on_batch(inputs, target):
    """Run one optimisation step on a single batch and return its loss.

    Uses the notebook-level ``model``, ``loss`` and ``optimizer``.

    Args:
        inputs: model inputs for the batch, as yielded by the loader.
        target: targets for the batch.

    Returns:
        The scalar loss tensor computed before the weight update.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_mae = loss(target, predictions)
    # Only the forward pass needs to be recorded. Computing and applying the gradients
    # outside the tape context avoids recording that work on the tape as well.
    gradients = tape.gradient(loss_mae, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_mae

def evaluate(loader):
    """Compute the mean loss over (at most) one epoch of ``loader`` in inference mode.

    Uses the notebook-level ``model`` and ``loss``. Each batch loss is weighted by
    its batch size, so a smaller final batch does not skew the average.

    Args:
        loader: iterable yielding ``(inputs, target)`` batches and exposing
            ``steps_per_epoch``.

    Returns:
        0-d ``np.ndarray`` with the sample-weighted mean loss over the batches
        consumed, or ``None`` if the loader yields no batches.
    """
    import itertools

    batch_losses = []
    batch_sizes = []
    # Take exactly one epoch's worth of batches; the loader itself may repeat indefinitely.
    one_epoch = itertools.islice(loader, loader.steps_per_epoch)
    for step, (inputs, target) in enumerate(one_epoch, start=1):
        predictions = model(inputs, training=False)
        batch_losses.append(float(loss(target, predictions)))
        batch_sizes.append(np.shape(target)[0])
        print(step)
        clear_output(wait=True)
    if not batch_losses:
        return None
    return np.array(np.average(batch_losses, weights=batch_sizes))

In [22]:
results_tr = []  # one entry per completed epoch: sample-weighted mean training loss
results_te = []  # one entry per completed epoch: value returned by evaluate(loader_va)
step = 0
e = 0
# loader_tr was built without an `epochs` limit, so this loop presumably never ends on
# its own. Set max_epochs to an int to make the cell terminate (e.g. for Restart & Run
# All); None keeps running until interrupted.
max_epochs = None

experiment_name = "GNN_first"
wandb.init(project="GNN_bcs2opt", group=experiment_name, config={})
config = wandb.config

epoch_losses, epoch_sizes = [], []
for batch in loader_tr:
    step += 1
    inputs, target = batch
    loss_tr = train_on_batch(inputs, target)
    epoch_losses.append(float(loss_tr))
    epoch_sizes.append(np.shape(target)[0])
    print('Epoch : %4d   Batch number : %4d   Training Loss : %.9f' % (e+1, step, loss_tr))
    clear_output(wait=True)

    if step == loader_tr.steps_per_epoch:
        e += 1
        step = 0
        # Record the epoch average rather than the loss of whichever batch happened to be last.
        loss_tr_epoch = np.average(epoch_losses, weights=epoch_sizes)
        epoch_losses, epoch_sizes = [], []
        loss_te = evaluate(loader_va)
        results_tr.append(loss_tr_epoch)
        results_te.append(loss_te)
        wandb.log({"training_loss": loss_tr_epoch, "test_loss": loss_te})
        if max_epochs is not None and e >= max_epochs:
            break

In [None]:
def predict(loader):
    """Run the model over one epoch of ``loader`` and collect the raw outputs.

    Uses the notebook-level ``model``.

    Args:
        loader: iterable yielding ``(inputs, target)`` batches and exposing
            ``steps_per_epoch``; targets are ignored.

    Returns:
        List with one prediction per batch, in the order the loader yielded them,
        or ``None`` if the loader is exhausted before ``steps_per_epoch`` batches.
    """
    batch_preds = []
    for batch_no, (inputs, _unused_target) in enumerate(loader, start=1):
        batch_preds.append(model(inputs, training=False))
        print(batch_no)
        clear_output(wait=True)
        if batch_no == loader.steps_per_epoch:
            return batch_preds

In [None]:
# One list entry per validation batch, in the order loader_va yields them.
pred_result = predict(loader=loader_va)

In [None]:
# Predicted 32x32 field for sample 4 of the first validation batch. Which test sample
# this is depends on whether loader_va shuffles its data.
fig, ax = plt.subplots()
im = ax.imshow(np.reshape(pred_result[0][4], (32, 32)))
ax.set_title('Prediction: validation batch 0, sample 4')
fig.colorbar(im, ax=ax)
plt.show()