In [1]:
import os
# Process-level configuration; set before torch is imported in the next cell.
# NOTE(review): PCI_BUS_ID presumably makes GPU indices follow the PCI bus
# order — confirm against the CUDA documentation.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Empty device list: intended to hide every GPU so that the `device` picked
# later in the notebook falls back to 'cpu'.
os.environ["CUDA_VISIBLE_DEVICES"]=""
# Limit OpenMP to a single thread (presumably for reproducible CPU runs —
# TODO confirm the motivation).
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
import random
import numpy as np
import torch
# One seed shared by every random number generator seeded below.
SEED = 42
# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts,
# so setting it here presumably only affects child processes — confirm.
os.environ['PYTHONHASHSEED']=str(SEED)
# Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x21b4d207450>

In [3]:
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch_geometric.datasets import Planetoid
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, DeepGraphInfomax
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

In [4]:
# Load the Cora Planetoid dataset, using ./tmp/Cora as its root directory.
dataset = Planetoid(root='./tmp/Cora', name='Cora')
# The first (index 0) entry is the graph object used by every later cell.
data = dataset[0]
# Show which tensors the graph carries (edge_index, x, y and the three masks).
print(data)

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])


In [5]:
# Mini-batch loader consumed by `train` for the SAGE encoder: shuffled batches
# of 1024 nodes. `sizes` has one entry per SAGE layer.
# NOTE(review): presumably 10 and 5 are the numbers of neighbours sampled per
# hop and node_idx=None selects all nodes — confirm against the PyG docs.
train_loader = NeighborSampler(
    data.edge_index, node_idx=None,
    sizes=[10, 5], batch_size=1024, shuffle=True,
    num_workers=0
)
# Loader consumed by `SAGE.inference`: a single hop (`sizes=[-1]`), fixed node
# order (shuffle=False). The inference code describes this as using all
# available edges, so -1 presumably means "no sampling" — TODO confirm.
subgraph_loader = NeighborSampler(
    data.edge_index, node_idx=None, 
    sizes=[-1], batch_size=1024, shuffle=False,
    num_workers=0
)

In [6]:
class GCN(nn.Module):
    """Two-layer GCN encoder applied to the whole graph at once.

    Both layers are ``GCNConv`` modules created with ``cached=True`` and
    ``normalize=True``. A ReLU follows every layer, including the last.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of both layers.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Layers are created in input-to-output order.
        self.convs = nn.ModuleList([
            GCNConv(in_channels, hidden_channels, cached=True, normalize=True),
            GCNConv(hidden_channels, hidden_channels, cached=True, normalize=True),
        ])
        self.num_layers = len(self.convs)

    def forward(self, x, edge_index):
        """Run every layer on ``(x, edge_index)`` and return the activations."""
        hidden = x
        for conv in self.convs:
            hidden = F.relu(conv(hidden, edge_index))
        return hidden

In [7]:
class GAT(nn.Module):
    """Two-layer GAT encoder applied to the whole graph at once.

    The first ``GATConv`` feeds a second one whose input width is
    ``heads * hidden_channels``; the second is built with ``concat=False``.
    A ReLU follows every layer, including the last.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Per-head output size of each layer.
        heads: Number of attention heads used by both layers.
    """

    def __init__(self, in_channels, hidden_channels, heads=1):
        super().__init__()
        first_layer = GATConv(in_channels, hidden_channels, heads=heads)
        second_layer = GATConv(heads * hidden_channels, hidden_channels, heads=heads, concat=False)
        self.convs = nn.ModuleList([first_layer, second_layer])
        self.num_layers = len(self.convs)

    def forward(self, x, edge_index):
        """Run every layer on ``(x, edge_index)`` and return the activations."""
        hidden = x
        for conv in self.convs:
            hidden = F.relu(conv(hidden, edge_index))
        return hidden

In [8]:
class SAGE(nn.Module):
    """Two-layer GraphSAGE encoder trained on sampled neighbourhoods.

    Both layers are ``SAGEConv`` modules created with ``normalize=True``.
    A ReLU follows every layer, including the last.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of both layers.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.num_layers = 2
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels, normalize=True))
        self.convs.append(SAGEConv(hidden_channels, hidden_channels, normalize=True))

    def forward(self, x, adjs):
        """Encode one sampled mini-batch.

        Args:
            x: Features of the sampled nodes.
            adjs: One ``(edge_index, e_id, size)`` triple per layer.

        Returns:
            Activations of the target nodes after the last layer.
        """
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            x = F.relu(x)
        return x

    def inference(self, x_all, loader=None, device=None):
        """Compute embeddings for all nodes, one layer at a time.

        Runs without autograd tracking, so the result does not require grad.

        Args:
            x_all: Features of every node, indexable by the loader's ``n_id``.
            loader: Iterable yielding ``(batch_size, n_id, adj)`` where
                ``adj.to(device)`` unpacks to ``(edge_index, e_id, size)``.
                Defaults to the notebook-level ``subgraph_loader``.
            device: Device to compute on. Defaults to the device of this
                module's parameters (or of ``x_all`` if it has none).

        Returns:
            A CPU tensor with the final-layer activations, in loader order.
        """
        if loader is None:
            # Backward-compatible default: the loader defined in the notebook.
            loader = subgraph_loader
        if device is None:
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else x_all.device

        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')
        try:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                # Compute representations of nodes layer by layer, using *all*
                # available edges. This leads to faster computation in
                # contrast to immediately computing the final representations
                # of each batch.
                for i in range(self.num_layers):
                    xs = []
                    for batch_size, n_id, adj in loader:
                        edge_index, _, size = adj.to(device)
                        x = x_all[n_id].to(device)
                        x_target = x[:size[1]]  # Target nodes come first.
                        x = self.convs[i]((x, x_target), edge_index)
                        x = F.relu(x)
                        xs.append(x.cpu())
                        pbar.update(batch_size)
                    # The outputs of this layer are the inputs of the next.
                    x_all = torch.cat(xs, dim=0)
        finally:
            # Close the progress bar even if a batch raises.
            pbar.close()
        return x_all

In [9]:
def summary(z, *args, **kwargs):
    """Reduce node embeddings to a single summary vector.

    Averages ``z`` over dim 0 and squashes the result with a sigmoid.
    Any extra positional or keyword arguments are accepted and ignored.
    """
    mean_embedding = z.mean(dim=0)
    return mean_embedding.sigmoid()

def corruption(x, *args):
    """Build a corrupted input by shuffling the rows of ``x``.

    The rows are reordered with a random permutation; every additional
    positional argument is passed through unchanged.

    Returns:
        A tuple ``(shuffled_x, *args)``.
    """
    row_order = torch.randperm(x.size(0))
    return (x[row_order],) + args

In [10]:
def train(epoch):
    """Run one training epoch of the notebook-level ``model``.

    The encoder's class name selects the training regime:

    * ``'GCN'`` / ``'GAT'``: a single full-graph step on
      ``(data.x, data.edge_index)``.
    * ``'SAGE'``: one optimizer step per mini-batch of ``train_loader``.

    Args:
        epoch: Epoch number, used only in the progress-bar description.

    Returns:
        The loss of the full-graph step, or the mean mini-batch loss for SAGE.

    Raises:
        ValueError: If the encoder is none of the supported classes
            (previously this returned ``None`` silently).
    """
    model.train()
    encoder_name = type(model.encoder).__name__

    if encoder_name in ('GCN', 'GAT'):
        optimizer.zero_grad()
        # Named `graph_summary` so the module-level `summary` function is not shadowed.
        pos_z, neg_z, graph_summary = model(data.x, data.edge_index)
        loss = model.loss(pos_z, neg_z, graph_summary)
        loss.backward()
        optimizer.step()
        return loss.item()

    if encoder_name == 'SAGE':
        pbar = tqdm(total=data.x.shape[0])
        pbar.set_description('Epoch {:03d}'.format(epoch))
        total_loss = 0
        try:
            for batch_size, n_id, adjs in train_loader:
                # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
                adjs = [adj.to(device) for adj in adjs]
                optimizer.zero_grad()
                pos_z, neg_z, graph_summary = model(data.x[n_id], adjs)
                loss = model.loss(pos_z, neg_z, graph_summary)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.update(batch_size)
        finally:
            # Close the progress bar even if a batch raises.
            pbar.close()
        # Average over the number of mini-batches.
        return total_loss / len(train_loader)

    raise ValueError(
        'Unsupported encoder type {!r}; expected GCN, GAT or SAGE'.format(encoder_name))

In [11]:
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the graph to the chosen device (rebinds the notebook-level `data`).
data = data.to(device)
# Deep Graph Infomax model wrapping a SAGE encoder with 128 hidden channels,
# using the `summary` and `corruption` functions defined in an earlier cell.
model = DeepGraphInfomax(
    hidden_channels=128, encoder=SAGE(data.num_features, 128),
    summary=summary, corruption=corruption).to(device)
# Adam over all model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [12]:
# Train for a fixed number of epochs and report the loss returned by
# `train` after each one.
epochs = 20
for epoch in range(1, epochs + 1):  # epochs are numbered from 1
    loss = train(epoch)
    print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))

HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 001, Loss: 1.3930


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 002, Loss: 1.3454


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 003, Loss: 1.2865


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 004, Loss: 1.1720


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 005, Loss: 0.9690


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 006, Loss: 0.8035


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 007, Loss: 0.6108


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 008, Loss: 0.4260


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 009, Loss: 0.2995


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 010, Loss: 0.2296


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 011, Loss: 0.1491


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 012, Loss: 0.1276


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 013, Loss: 0.0945


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 014, Loss: 0.0781


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 015, Loss: 0.0768


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 016, Loss: 0.0655


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 017, Loss: 0.0542


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 018, Loss: 0.0507


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 019, Loss: 0.0446


HBox(children=(FloatProgress(value=0.0, max=2708.0), HTML(value='')))


Epoch: 020, Loss: 0.0345


In [13]:
# Persist the trained model to ./models/demo.pt (relative to the current
# working directory) and read it back to check the round trip.
model_dir = os.path.join(os.getcwd(), 'models')
os.makedirs(model_dir, exist_ok=True)
model_name = os.path.join(model_dir, 'demo.pt')
# The whole module object is pickled, not just its state_dict.
torch.save(model, model_name)
# map_location places the reloaded tensors on the device used by this session
# instead of whatever device they were saved from.
# NOTE(review): torch.load unpickles arbitrary objects, so only load files you
# trust. Newer PyTorch releases may also need `weights_only=False` to load a
# full module — confirm against the installed version.
# `model` is rebound here; the existing `optimizer` still refers to the
# parameters of the previous object.
model = torch.load(model_name, map_location=device)

In [14]:
# Compute the node embeddings `z` with the trained encoder.
model.eval()
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    if type(model.encoder).__name__ in ('GCN', 'GAT'):
        # Full-graph encoders take the whole graph in one call.
        z = model.encoder(data.x, data.edge_index)
    elif type(model.encoder).__name__ == 'SAGE':
        # SAGE computes embeddings layer by layer over mini-batches.
        z = model.encoder.inference(data.x)
    else:
        # Fail here rather than with a NameError on `z` in a later cell.
        raise ValueError(
            'Unsupported encoder type {!r}; expected GCN, GAT or SAGE'.format(
                type(model.encoder).__name__))

HBox(children=(FloatProgress(value=0.0, max=5416.0), HTML(value='')))




In [15]:
# Evaluate the embeddings with a logistic-regression probe: fit on the nodes
# selected by `train_mask`, report accuracy on those selected by `test_mask`.
X = z.detach().cpu().numpy()
y = data.y.detach().cpu().numpy()
# Convert the masks to NumPy before indexing NumPy arrays; indexing with the
# torch tensors directly fails when `data` lives on a GPU.
train_mask = data.train_mask.cpu().numpy()
test_mask = data.test_mask.cpu().numpy()
X_train, X_test = X[train_mask], X[test_mask]
y_train, y_test = y[train_mask], y[test_mask]
# Default scikit-learn settings, unchanged from the original run.
clf = LogisticRegression()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(accuracy_score(y_test, y_pred))

0.693
