In [1]:
import pandas as pd
import numpy as np
import tqdm
import torch
import math

# Pre processing + Define Model

In [12]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import EdgeConv, global_mean_pool

class EdgeNet(nn.Module):
    def __init__(self, input_dim=4, big_dim=32, hidden_dim=2, aggr='mean'):
        super(EdgeNet, self).__init__()
        encoder_nn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU()
        )
        
        self.mu_layer = nn.Linear(big_dim, 1)
        self.var_layer = nn.Linear(big_dim, 1)
        
        decoder_nn = nn.Sequential(nn.Linear(2*(hidden_dim), big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, big_dim),
                               nn.ReLU(),
                               nn.Linear(big_dim, input_dim)
        )
        
        self.batchnorm = nn.BatchNorm1d(input_dim)

        self.encoder = EdgeConv(nn=encoder_nn,aggr=aggr)
        self.decoder = EdgeConv(nn=decoder_nn,aggr=aggr)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, data):
        data.x = self.batchnorm(data.x)
        data.x = self.encoder(data.x,data.edge_index)
        mu = self.mu_layer(data.x)
        log_var = self.var_layer(data.x)
        z = self.reparameterize(mu, log_var)
        print(z)
        data.x = self.decoder(z,data.edge_index)
        return data.x, mu, log_var

In [3]:
from torch_geometric.data import Data, DataLoader, DataListLoader, Batch
from torch.utils.data import random_split
import os.path as osp
import matplotlib.pyplot as plt
from graph_data import GraphDataset

gdata = GraphDataset(root='/anomalyvol/data/gnn_node_global_merge/')

input_dim = 4
big_dim = 32
hidden_dim = 2
fulllen = len(gdata)
tv_frac = 0.10
tv_num = math.ceil(fulllen*tv_frac)
batch_size = 1
n_epochs = 100
lr = 0.001
patience = 10
device = 'cuda:0'
model_fname = 'GVAE_sparseloss'

In [13]:
model = EdgeNet(input_dim=input_dim, big_dim=big_dim, hidden_dim=hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = lr)

In [4]:
torch.manual_seed(0)
train_dataset, valid_dataset, test_dataset = random_split(gdata, [fulllen-2*tv_num,tv_num,tv_num])

train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)

train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)

print(train_samples)
print(valid_samples)
print(test_samples)

8000
1000
1000


In [5]:
def sparseloss3d(x,y):
    nparts = x.shape[0]
    dist = torch.pow(torch.cdist(x,y),2)
    in_dist_out = torch.min(dist,dim=0)
    out_dist_in = torch.min(dist,dim=1)
    loss = torch.sum(in_dist_out.values + out_dist_in.values) / nparts
    return loss

# Reconstruction + KL divergence losses summed over all elements and batch
def vae_loss(x, y, mu, logvar):
    BCE = sparseloss3d(x,y)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [6]:
@torch.no_grad()
def single_test(model, loader, total, batch_size):
    model.eval()

    sum_loss = 0.
    t = tqdm.tqdm(enumerate(loader),total=total/batch_size)
    for i,data_list in t:
        for data in data_list:
            if data.x.shape[0] <= 1:
                continue
            data = data.to(device)
            y = data.x
            batch_output, mu, log_var = model(data)
            batch_loss_item = vae_loss(batch_output, y, mu, log_var).item()
            sum_loss += batch_loss_item
            t.set_description("loss = %.5f" % (batch_loss_item))
            t.refresh() # to show immediately the update

    return sum_loss/(i+1)

def sgd_train(model, optimizer, loader, total, batch_size):
    model.train()

    sum_loss = 0.
    t = tqdm.tqdm(enumerate(loader),total=total/batch_size)
    for i,data_list in t:
        for data in data_list:
            if data.x.shape[0] <= 1:
                continue
            data = data.to(device)
            y = data.x
            optimizer.zero_grad()
            batch_output, mu, log_var = model(data)
            batch_loss = vae_loss(batch_output, y, mu, log_var)
            batch_loss.backward()
            batch_loss_item = batch_loss.item()
            t.set_description("loss = %.5f" % batch_loss_item)
            t.refresh() # to show immediately the update
            sum_loss += batch_loss_item
            optimizer.step()
    
    return sum_loss/(i+1)

In [7]:
modpath = osp.join('/anomalyvol/models/gnn/',model_fname+'.best.pth')

# Train

In [14]:
stale_epochs = 0
best_valid_loss = 99999
for epoch in range(0, n_epochs):
    loss = sgd_train(model, optimizer, train_loader, train_samples, batch_size)
    valid_loss = single_test(model, valid_loader, valid_samples, batch_size)
    print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
    print('               Validation Loss: {:.4f}'.format(valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join('/anomalyvol/models/gnn/',model_fname+'.best.pth')
        print('New best model saved to:',modpath)
        torch.save(model.state_dict(),modpath)
        stale_epochs = 0
    else:
        print('Stale epoch')
        stale_epochs += 1
    if stale_epochs >= patience:
        print('Early stopping after %i stale epochs'%patience)
        break

  0%|          | 0/8000.0 [00:00<?, ?it/s]

tensor([[ 0.1734],
        [-0.5700],
        [ 0.9133],
        [-0.7500],
        [ 1.3687],
        [ 0.8588],
        [ 0.1150],
        [-0.9509],
        [ 0.8017],
        [ 0.8712],
        [-1.0281],
        [-0.3750],
        [ 0.4965],
        [ 0.4117],
        [-0.6998],
        [-0.1468],
        [-0.3096],
        [ 1.7931],
        [-0.8607],
        [ 1.6399],
        [ 1.0368],
        [-0.5362],
        [ 0.1936],
        [-0.1420],
        [-0.2449],
        [-0.6599],
        [ 1.0450],
        [-0.5126],
        [ 0.8001],
        [ 0.5338],
        [ 1.9521],
        [ 0.4658],
        [-0.2646],
        [-0.5669],
        [ 3.0222],
        [-0.0817],
        [-0.4093],
        [-0.1265],
        [ 1.9565],
        [ 1.8605],
        [-0.5329],
        [-0.1412],
        [ 0.0474],
        [-0.3073],
        [ 2.2743],
        [ 0.1159],
        [ 0.0034],
        [ 0.4239],
        [-1.5765],
        [-0.8847],
        [ 2.0283],
        [ 0.2383],
        [-1.




RuntimeError: size mismatch, m1: [6320 x 2], m2: [4 x 32] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:283

# Visualize

In [None]:
model.load_state_dict(torch.load(modpath))
input_x = []
output_x = []

t = tqdm.tqdm(enumerate(test_loader),total=test_samples/batch_size)
for i, data in t:
    data[0].to(device)
    input_x.append(data[0].x.cpu().numpy())
    output_x.append(model(data[0]).cpu().detach().numpy())
    del data
    torch.cuda.empty_cache()

In [None]:
def in_out_diff_append(diff, output, inputs, i, ft_idx): # helper for appending 3 lists
    diff.append(((output_x[i][:,ft_idx]-input_x[i][:,ft_idx])/input_x[i][:,ft_idx]).flatten())
    output.append(output_x[i][:,ft_idx].flatten())
    inputs.append(input_x[i][:,ft_idx].flatten())

def in_out_diff_concat(diff, output, inputs):
    diff = np.concatenate(diff)
    output = np.concatenate(output)
    inputs = np.concatenate(inputs)
    return [diff, output, inputs]

def make_hists(diff, output, inputs, bin1):
    plt.figure()
    plt.hist(inputs, bins=bin1,alpha=0.5)
    plt.hist(output, bins=bin1,alpha=0.5)
    plt.show()

    plt.figure()
    plt.hist(diff, bins=np.linspace(-5, 5, 101))
    plt.show()

In [None]:
diff_px = []
output_px = []
input_px = []
diff_py = []
output_py = []
input_py = []
diff_pz = []
output_pz = []
input_pz = []
diff_e = []
output_e = []
input_e = []

# get output in readable format
for i in range(len(input_x)):
    # px
    in_out_diff_append(diff_px, output_px, input_px, i, 0)
    in_out_diff_append(diff_py, output_py, input_py, i, 1)
    in_out_diff_append(diff_pz, output_pz, input_pz, i, 2)
    in_out_diff_append(diff_e, output_e, input_e, i, 3)

# remove extra brackets
diff_px, output_px, input_px = in_out_diff_concat(diff_px, output_px, input_px)
diff_py, output_py, input_py = in_out_diff_concat(diff_py, output_py, input_py)
diff_pz, output_pz, input_pz = in_out_diff_concat(diff_pz, output_pz, input_pz)
diff_e, output_e, input_e = in_out_diff_concat(diff_e, output_e, input_e)

print("px")
bins = np.linspace(-40, 40, 101)
make_hists(diff_px, output_px, input_px, bins)

print("py")
make_hists(diff_py, output_py, input_py, bins)

print("pz")
make_hists(diff_pz, output_pz, input_pz, bins)

print("e")
bins = np.linspace(-5, 40, 101)
make_hists(diff_e, output_e, input_e, bins)