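# GeniePath example on the PPI (protein-protein interaction) dataset.
# Reference: Liu et al., "GeniePath: Graph Neural Networks with Adaptive
# Receptive Paths" (AAAI 2019), arXiv:1802.00910.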
import argparse
import os.path as osp
import torch
from torch_geometric.datasets import PPI
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv
from sklearn.metrics import f1_score
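
# Choose between the full GeniePath model and its lazy variant (defined below).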
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='GeniePathLazy')
args = parser.parse_args()
assert args.model in ['GeniePath', 'GeniePathLazy']
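
# PPI is a multi-label node classification benchmark: 20 training graphs,
# 2 validation graphs and 2 test graphs, with 121 binary targets per node.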
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI')
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
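
# Hidden size of the GAT outputs and the LSTM state, and the number of
# stacked GeniePath layers.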
dim = 256
lstm_hidden = 256
layer_num = 4
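

# Breadth function: explores each node's neighborhood with a single-head
# graph attention (GAT) layer followed by a tanh non-linearity.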
class Breadth(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.gatconv = GATConv(in_dim, out_dim, heads=1)

    def forward(self, x, edge_index):
        x = torch.tanh(self.gatconv(x, edge_index))
        return x
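

# Depth function: an LSTM that filters and memorizes the signals gathered
# at increasing depths along the receptive path.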
class Depth(torch.nn.Module):
    def __init__(self, in_dim, hidden):
        super().__init__()
        self.lstm = torch.nn.LSTM(in_dim, hidden, 1, bias=False)

    def forward(self, x, h, c):
        x, (h, c) = self.lstm(x, (h, c))
        return x, (h, c)
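

# One adaptive-path layer: breadth exploration followed by depth filtering.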
class GeniePathLayer(torch.nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.breadth_func = Breadth(in_dim, dim)
        self.depth_func = Depth(dim, lstm_hidden)

    def forward(self, x, edge_index, h, c):
        x = self.breadth_func(x, edge_index)
        x = x[None, :]  # add a sequence dimension of length 1 for the LSTM
        x, (h, c) = self.depth_func(x, h, c)
        x = x[0]  # drop the sequence dimension again
        return x, (h, c)
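

# GeniePath: a stack of GeniePath layers threading a shared LSTM state
# (h, c) through all layers, preceded and followed by linear projections.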
class GeniePath(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, dim)
        self.gplayers = torch.nn.ModuleList(
            [GeniePathLayer(dim) for _ in range(layer_num)])
        self.lin2 = torch.nn.Linear(dim, out_dim)

    def forward(self, x, edge_index):
        x = self.lin1(x)
        h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)
        c = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)
        for gplayer in self.gplayers:
            x, (h, c) = gplayer(x, edge_index, h, c)
        x = self.lin2(x)
        return x
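

# GeniePath-lazy: first computes all breadth outputs from the projected
# input, then feeds each one, concatenated with the running depth state,
# into the corresponding LSTM.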
class GeniePathLazy(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, dim)
        self.breadths = torch.nn.ModuleList(
            [Breadth(dim, dim) for _ in range(layer_num)])
        self.depths = torch.nn.ModuleList(
            [Depth(dim * 2, lstm_hidden) for _ in range(layer_num)])
        self.lin2 = torch.nn.Linear(dim, out_dim)

    def forward(self, x, edge_index):
        x = self.lin1(x)
        h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)
        c = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)
        h_tmps = [breadth(x, edge_index) for breadth in self.breadths]
        x = x[None, :]  # add a sequence dimension of length 1 for the LSTMs
        for depth, h_tmp in zip(self.depths, h_tmps):
            in_cat = torch.cat((h_tmp[None, :], x), -1)
            x, (h, c) = depth(in_cat, h, c)
        x = self.lin2(x[0])
        return x
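

# Instantiate the selected model; BCEWithLogitsLoss treats the 121 targets
# as independent binary classification problems.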
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
kwargs = {'GeniePath': GeniePath, 'GeniePathLazy': GeniePathLazy}
model = kwargs[args.model](train_dataset.num_features,
train_dataset.num_classes).to(device)
loss_op = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
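

# Run one epoch over the training graphs and return the mean loss per graph.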
def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        num_graphs = data.num_graphs
        data.batch = None  # the batch vector is not needed for node-level tasks
        data = data.to(device)
        optimizer.zero_grad()
        loss = loss_op(model(data.x, data.edge_index), data.y)
        total_loss += loss.item() * num_graphs
        loss.backward()
        optimizer.step()
    return total_loss / len(train_loader.dataset)
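

# Evaluate micro-averaged F1 over all nodes in the given split.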
def test(loader):
    model.eval()

    ys, preds = [], []
    for data in loader:
        ys.append(data.y)
        with torch.no_grad():
            out = model(data.x.to(device), data.edge_index.to(device))
        preds.append((out > 0).float().cpu())  # logit > 0 <=> sigmoid > 0.5

    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0
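

# Train for 100 epochs, reporting validation and test micro-F1 after each one.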
for epoch in range(1, 101):
loss = train()
val_f1 = test(val_loader)
test_f1 = test(test_loader)
print('Epoch: {:02d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'.format(
epoch, loss, val_f1, test_f1))