In [2]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load the Planetoid "Cora" dataset (downloaded/cached under data/Planetoid),
# with NormalizeFeatures applied as the dataset transform.
dataset = Planetoid(root='data/Planetoid', name='Cora',
					transform=NormalizeFeatures())

# Dataset-level summary: Cora ships as a single graph.
print()
print(f'Dataset: {dataset}')
print('=======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
# The one and only graph object in the dataset.
data = dataset[0]

# Graph-level summary of that single graph.
print()
print(data)
print('=======================')
print(f'Number of Nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
# `.2f` = two decimal places (the previous `2f` only set a minimum width of 2
# and kept the default six digits of precision).
# NOTE(review): num_edges counts edge_index columns; if each undirected edge
# is stored in both directions this ratio is the mean degree — confirm.
print(f'Average node degree :{data.num_edges / data.num_nodes:.2f}')



Dataset: Cora()
Number of graphs: 1
Number of features: 1433
Number of classes: 7

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Number of Nodes: 2708
Number of edges: 10556
Average node degree :3.898080


In [3]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color, random_state=None):
	"""Project representations to 2-D with t-SNE and show them as a scatter plot.

	Args:
		h: a torch tensor of representations; it is detached and copied to the
			CPU before t-SNE, so it may carry autograd history or live on a GPU.
			Presumably shaped (num_points, num_features) — TSNE needs a 2-D array.
		color: per-point colour values forwarded to ``plt.scatter(c=...)`` and
			mapped through the 'Set2' colormap.
		random_state: optional seed forwarded to ``TSNE``. The default ``None``
			keeps the previous unseeded behaviour; pass an int to get the same
			layout on every run.
	"""
	points = h.detach().cpu().numpy()
	z = TSNE(n_components=2, random_state=random_state).fit_transform(points)
	plt.figure(figsize=(10, 10))
	# Axis ticks carry no meaning for a t-SNE embedding, so hide them.
	plt.xticks([])
	plt.yticks([])
	plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap='Set2')
	plt.show()

#### 传统神经网络求解

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F

class MLP(torch.nn.Module):
	"""Two-layer perceptron that scores each node from its own features only.

	``forward`` receives only the feature matrix (no graph connectivity), so this
	serves as a structure-agnostic baseline.

	Args:
		hidden_channels: width of the hidden layer.
		in_channels: input feature size; defaults to the notebook-level
			``dataset.num_features`` when omitted.
		out_channels: number of output classes; defaults to the notebook-level
			``dataset.num_classes`` when omitted.
		dropout: dropout probability applied after the hidden layer.
		seed: value passed to ``torch.manual_seed`` before the layers are created
			so weight initialisation is reproducible; ``None`` skips seeding.
	"""

	def __init__(self, hidden_channels, in_channels=None, out_channels=None,
				 dropout=0.5, seed=12345):
		super().__init__()
		if seed is not None:
			# Note: this reseeds torch's *global* RNG as a side effect.
			torch.manual_seed(seed)
		# Fall back to the global `dataset` only when sizes are not given, so the
		# class can also be built without that global in scope.
		if in_channels is None:
			in_channels = dataset.num_features
		if out_channels is None:
			out_channels = dataset.num_classes
		self.dropout = dropout
		# Two fully connected layers.
		self.lin1 = Linear(in_channels, hidden_channels)
		self.lin2 = Linear(hidden_channels, out_channels)

	def forward(self, x):
		"""Map features ``x`` to unnormalised class scores (no softmax applied)."""
		x = self.lin1(x)
		x = x.relu()
		# Only active in training mode (``self.training``).
		x = F.dropout(x, p=self.dropout, training=self.training)
		x = self.lin2(x)
		return x

# Instantiate the MLP baseline with 16 hidden units. Ending the cell with the
# bare expression lets Jupyter render the module's repr as the cell output.
model = MLP(hidden_channels=16)
model

MLP(
  (lin1): Linear(in_features=1433, out_features=16, bias=True)
  (lin2): Linear(in_features=16, out_features=7, bias=True)
)


In [None]:
# Build a fresh (seeded) model here so re-running this cell always resets it.
model = MLP(hidden_channels=16)
# MLP.forward returns raw scores without softmax, which is what
# CrossEntropyLoss expects.
criterion = torch.nn.CrossEntropyLoss()
# Backward-compatible alias for the original misspelled name, in case later
# cells still refer to it.
criterian = criterion
# Adam over all model parameters, with weight decay (L2 penalty).
optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
							 weight_decay=5e-4)
