## Requirements

In [65]:
!pip install torch==2.0.0+cu118 --index-url https://download.pytorch.org/whl/cu118 -q
!pip install pyg-lib==0.3.1 torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html -q
!pip install 'torch-geometric==2.4.0' -q

## Code

Reference implementation: [ki-ljl/PyG-GCN](https://github.com/ki-ljl/PyG-GCN/tree/main) — a PyG implementation of Kipf & Welling, *Semi-Supervised Classification with Graph Convolutional Networks* (ICLR 2017).

In [1]:
# Land-use categories. A block's integer label is its position in this list,
# so the order below must not change between training and inference.
classes = """
    RESIDENTIAL BUSINESS RECREATION SPECIAL
    INDUSTRIAL AGRICULTURE TRANSPORT
""".split()

In [4]:
import torch

# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
import geopandas as gpd
import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.utils.convert import from_networkx
from sklearn.model_selection import train_test_split

# Input locations, relative to the notebook's working directory.
BLOCKS_PATH = 'blocks.parquet'        # blocks table; load_data reads geometry, 'area', 'aspect_ratio', 'land_use'
ADJ_GRAPH_PATH = 'adj_graph.graphml'  # block adjacency graph; load_data reads the 'border_length' edge attribute

blocks_gdf = gpd.read_parquet(BLOCKS_PATH)
adj_graph = nx.read_graphml(ADJ_GRAPH_PATH)

def get_masks(y, test_size=0.1, random_state=None):
    """Split all nodes into disjoint, class-stratified train/test boolean masks.

    Args:
        y: 1-D tensor of integer node labels (in this notebook, unlabelled
            nodes carry -1 and are stratified like any other value).
        test_size: forwarded to ``train_test_split`` — a fraction in (0, 1)
            or an absolute number of test nodes.
        random_state: seed forwarded to ``train_test_split``; ``None`` (the
            default) yields a different split on every call, pass an int for
            a reproducible split.

    Returns:
        ``(train_mask, test_mask)``: bool tensors of length ``len(y)``; every
        node is in exactly one of the two.
    """
    # .detach().cpu() so this also works when y already lives on the GPU.
    labels = y.detach().cpu().numpy()
    num_nodes = len(labels)

    # Split node indices while preserving the class proportions in both parts.
    train_indices, test_indices = train_test_split(
        range(num_nodes),
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    # Turn the index lists into boolean masks over all nodes.
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_indices] = True
    test_mask[test_indices] = True

    return train_mask, test_mask


def load_data(blocks_gdf, adj_graph):
    """Build a PyG ``Data`` object for block land-use classification.

    Node features (one row per row of ``blocks_gdf``, in frame order), using
    area and geometry length each normalised to sum to 1 over all blocks:
        [area, aspect_ratio, length, area / length, length * area]
    Labels: index of the block's ``land_use`` in ``classes``; -1 when the
    value is missing (None or NaN).
    Edges: taken from ``adj_graph``, with the ``border_length`` edge
    attribute stored as ``edge_attr``. Train/test masks come from ``get_masks``.

    Args:
        blocks_gdf: GeoDataFrame with geometry and 'area', 'aspect_ratio',
            'land_use' columns. It is not modified.
        adj_graph: networkx graph whose edges all carry 'border_length'.

    Returns:
        ``(data, num_node_features, num_classes)``.

    Raises:
        ValueError: if the graph's node count differs from the number of rows
            in ``blocks_gdf`` (edges would then point at the wrong blocks).
    """
    # Normalise area and geometry length so they are scale-free.
    area = blocks_gdf['area'] / blocks_gdf['area'].sum()
    geom_length = blocks_gdf.length  # GeoSeries.length (perimeter for polygon geometries)
    length = geom_length / geom_length.sum()

    feature_columns = [
        area,
        blocks_gdf['aspect_ratio'],
        length,
        area / length,
        length * area,
    ]
    # Vectorised replacement for a per-row iterrows loop; computed in float64, stored as float32.
    x = torch.stack(
        [torch.as_tensor(column.to_numpy(dtype='float64')) for column in feature_columns],
        dim=1,
    ).float()

    land_use = blocks_gdf['land_use']
    # isna() catches both None and NaN; missing labels become -1.
    y = torch.tensor(
        [-1 if missing else classes.index(lu) for lu, missing in zip(land_use, land_use.isna())],
        dtype=torch.long,
    )

    # NOTE(review): node i of the converted graph is assumed to be row i of
    # blocks_gdf (from_networkx relabels nodes 0..n-1 in adj_graph's node
    # order) — confirm the graphml was written in the same order as the parquet.
    if adj_graph.number_of_nodes() != len(blocks_gdf):
        raise ValueError(
            f'adj_graph has {adj_graph.number_of_nodes()} nodes but blocks_gdf has '
            f'{len(blocks_gdf)} rows; node i must correspond to row i'
        )
    # Convert once and reuse for both the connectivity and the edge weights.
    graph_data = from_networkx(adj_graph)
    data = Data(
        x=x,
        y=y,
        edge_index=graph_data.edge_index,
        edge_attr=graph_data['border_length'],
    )

    # Attach the train/test split to the data object.
    data.train_mask, data.test_mask = get_masks(data.y)

    return data, data.x.shape[1], len(classes)

# Build the graph dataset (features, labels, edges, train/test masks) and
# move its tensors to DEVICE.
data, num_node_features, num_classes = load_data(blocks_gdf, adj_graph)
data = data.to(DEVICE)

In [17]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from tqdm import tqdm

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    Pipeline: GCNConv -> BatchNorm1d -> ReLU -> dropout -> GCNConv.
    ``forward`` returns raw per-node logits; apply softmax / cross-entropy
    outside the model.

    Args:
        num_node_features: size of each node's input feature vector.
        num_classes: number of output logits per node.
        hidden_channels: width of the hidden layer (default 32).
        dropout: dropout probability applied after the hidden layer while
            training (default 0.5, the ``F.dropout`` default).
    """

    def __init__(self, num_node_features, num_classes, hidden_channels=32, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.norm = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Compute logits for every node of ``data`` (uses x, edge_index, edge_attr)."""
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr

        # edge_attr is fed to GCNConv as its per-edge scalar weight
        # (presumably one border length per edge — confirm it is 1-D).
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = self.norm(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index, edge_weight=edge_weight)

def train(model, data, device, epochs=10_000, lr=3e-4):
    """Full-batch training of ``model`` on the nodes selected by ``data.train_mask``.

    Args:
        model: module mapping ``data`` to per-node logits; updated in place.
        data: graph object with ``y`` (-1 = unlabelled) and ``train_mask``.
        device: device the loss module is moved to.
        epochs: number of optimisation steps, each over the whole graph.
        lr: Adam learning rate.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # train_mask may include unlabelled nodes (y == -1); ignore_index drops them from the loss.
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=-1).to(device)
    model.train()

    # The training targets do not change between epochs, so slice them once.
    train_mask = data.train_mask
    train_targets = data.y[train_mask]

    pbar = tqdm(range(epochs))
    for _ in pbar:
        optimizer.zero_grad()
        out = model(data)
        loss = loss_function(out[train_mask], train_targets)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'loss : {round(loss.item(),2)}')

# Create the model on the target device and train it; train() updates the
# model's parameters in place via the optimizer.
device = torch.device(DEVICE)
model = GCN(num_node_features, num_classes).to(device)
train(model, data, device)

loss : 1.0: 100%|██████████| 10000/10000 [01:17<00:00, 129.13it/s]


In [18]:
def test(model, data):
    model.eval()
    _, pred = model(data).max(dim=1)

    mask = (data.test_mask) & (data.y != -1)

    correct = int(pred[mask].eq(data.y[mask]).sum().item())
    acc = correct / int(mask.sum())
    print('GCN Accuracy: {:.4f}'.format(acc))
    
# Evaluate on the labelled nodes of the held-out test mask.
test(model, data)

GCN Accuracy: 0.6747


In [19]:
import torch.nn.functional as F

def get_class_probabilities(model, data, class_names=None):
    """Return per-node class probabilities as a list of ``{class_name: prob}`` dicts.

    Args:
        model: module mapping ``data`` to a 2-D tensor of logits, one row per
            node and one column per class.
        data: graph object passed straight to ``model``.
        class_names: names for the logit columns, in column order; defaults
            to the notebook-level ``classes`` list.

    Returns:
        One dict per node (row of the model output), keyed by class name, with
        softmax probabilities as Python floats.

    Raises:
        ValueError: if the number of logit columns differs from the number of
            class names.
    """
    names = classes if class_names is None else class_names

    model.eval()  # inference mode (e.g. disables dropout, uses BatchNorm running stats)
    with torch.no_grad():
        logits = model(data)

    if logits.shape[1] != len(names):
        raise ValueError(
            f'model produced {logits.shape[1]} logits per node but {len(names)} class names were given'
        )

    # Softmax over the class axis, then a single device-to-host transfer for the whole
    # matrix instead of one float() conversion (and GPU sync) per element.
    probabilities = F.softmax(logits, dim=1).cpu().tolist()

    return [dict(zip(names, row)) for row in probabilities]

# Inspect the predicted class distribution: one dict per node, keyed by class name.
get_class_probabilities(model, data)

[{'RESIDENTIAL': 0.2374805361032486,
  'BUSINESS': 0.02441723272204399,
  'RECREATION': 0.24476124346256256,
  'SPECIAL': 0.09068704396486282,
  'INDUSTRIAL': 0.1409492790699005,
  'AGRICULTURE': 0.12245942652225494,
  'TRANSPORT': 0.13924524188041687},
 {'RESIDENTIAL': 0.4305322766304016,
  'BUSINESS': 0.02579706534743309,
  'RECREATION': 0.12397494167089462,
  'SPECIAL': 0.004708435852080584,
  'INDUSTRIAL': 0.032198481261730194,
  'AGRICULTURE': 0.0072744740173220634,
  'TRANSPORT': 0.37551426887512207},
 {'RESIDENTIAL': 0.48472464084625244,
  'BUSINESS': 0.03160572424530983,
  'RECREATION': 0.2784360945224762,
  'SPECIAL': 0.02241729386150837,
  'INDUSTRIAL': 0.0843534842133522,
  'AGRICULTURE': 0.036244746297597885,
  'TRANSPORT': 0.06221791356801987},
 {'RESIDENTIAL': 0.4394928514957428,
  'BUSINESS': 0.03265290707349777,
  'RECREATION': 0.20182818174362183,
  'SPECIAL': 0.018870240077376366,
  'INDUSTRIAL': 0.08752204477787018,
  'AGRICULTURE': 0.02697157859802246,
  'TRANSPORT'

In [20]:
# Attach the per-node probability dicts to blocks_gdf in place. Rows line up
# because load_data builds node i's features from row i of the frame.
blocks_gdf['probabilities'] = get_class_probabilities(model, data)

In [21]:
# Persist the blocks together with their predicted class probabilities.
blocks_gdf.to_parquet('blocks_probabilities.parquet')