In [92]:
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn as nn
from tqdm import tqdm, trange
import matplotlib.pyplot as plt

In [93]:
from data_loader import load_data, load_npz, load_random
from LPA import LPA
from utils import *

In [94]:
dataset = 'R8'

In [95]:
adj, adj_n, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, train_size, test_size = load_corpus(dataset)

In [96]:
train_size = int(y_train.sum())
val_size = int(y_val.sum())

vocab_size = adj.shape[0] - train_size - val_size - test_size

In [97]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
    print("CPU")

CPU


In [98]:
adj_final = torch.load(dataset + '_adj_final.pt', map_location=device)

In [99]:
epochs = 10

In [100]:
def get_mask(idx, length):
    mask = np.zeros(length)
    mask[idx] = 1
    return np.array(mask, dtype=np.float64)

In [101]:
rows = np.concatenate([np.arange(train_size + val_size), np.arange(-test_size, 0)])
rows.shape

(7674,)

In [102]:
cols = rows.copy()
cols.shape

(7674,)

In [109]:
adj_subset = adj_final[rows][:, cols]
# adj_subset = adj_final

In [116]:
model = LPA(adj_subset)

In [117]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

In [118]:
all_train = torch.tensor(y_train[rows], dtype=torch.float64).to(device)
val_input = torch.tensor(y_train[rows], dtype=torch.float64).to(device)
test_input = torch.tensor(y_train[rows] + y_val[rows], dtype=torch.float64).to(device)

In [119]:
train_labels = torch.argmax(torch.tensor(y_train[rows]), dim=1)
val_labels = torch.argmax(torch.tensor(y_val[rows]), dim=1)
test_labels = torch.argmax(torch.tensor(y_test[rows]), dim=1)

In [114]:
for epoch in range(epochs):
    model.train()
    train_indices = np.random.choice(list(range(train_size)), size=int(train_size * 0.8), replace=False)
    train_mask = get_mask(train_indices, adj_subset.shape[0])
    train_input = all_train * train_mask[:, None]
    print(f'\nEpoch {epoch}: ')
    outputs = model(train_input)
    loss = criterion(outputs[:train_size], train_labels[:train_size])
    # train_acc = np.sum(torch.argmax(outputs, dim=1) == train_target)

    preds = torch.argmax(outputs, dim=1)
    train_acc = torch.eq(preds[:train_size], torch.tensor(train_labels[:train_size])).sum() / train_size
    print(f'Training Loss: {loss}\tTraining Accuracy: {train_acc}')

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    preds = model(val_input)
    loss = criterion(preds[train_size: train_size+val_size], val_labels[train_size: train_size+val_size])
    preds = torch.argmax(preds, dim=1)
    val_acc = torch.eq(preds[train_size: train_size+val_size], val_labels[train_size: train_size+val_size]).sum() / val_size
    print(f'Validation Loss: {loss}\tValidation Accuracy: {val_acc}')


Epoch 0: 
Training Loss: 1.4079054857807642	Training Accuracy: 0.866113007068634


  train_acc = torch.eq(preds[:train_size], torch.tensor(train_labels[:train_size])).sum() / train_size


Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 1: 
Training Loss: 1.4042409778909155	Training Accuracy: 0.8697589635848999
Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 2: 
Training Loss: 1.4056677383945069	Training Accuracy: 0.8683410882949829
Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 3: 
Training Loss: 1.3987809644956826	Training Accuracy: 0.8752278685569763
Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 4: 
Training Loss: 1.4001988301512827	Training Accuracy: 0.8738099932670593
Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 5: 
Training Loss: 1.3995905983157033	Training Accuracy: 0.8744176626205444
Validation Loss: 1.4035708800234679	Validation Accuracy: 0.8704379796981812

Epoch 6: 
Training Loss: 1.407697249038104	Training Accuracy: 0.8663156032562256
Validation Loss: 1.4035708800234679	Validation 

In [120]:
test_preds = model(test_input)
test_preds = torch.argmax(test_preds, dim=1)
test_acc = torch.eq(test_preds[-1 * test_size:], test_labels[-1 * test_size:]).sum() / test_size
print("Test Accuracy: ", test_acc)

Test Accuracy:  tensor(0.8693)
