In [1]:
import os
from PIL import Image
import numpy as np
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors

from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -------------------------------
# 1. Load & preprocess images + build KNN graph
# -------------------------------
data_dir = '../all'  # Change to your image dataset path
X, y = [], []

for class_name in os.listdir(data_dir):
    class_path = os.path.join(data_dir, class_name)
    if os.path.isdir(class_path):
        for fname in os.listdir(class_path):
            if fname.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')):
                try:
                    img = Image.open(os.path.join(class_path, fname)).convert('L')
                    img = img.resize((32, 32))
                    arr = np.array(img).flatten()
                    X.append(arr)
                    y.append(class_name)
                except Exception as e:
                    print(f"Error loading {fname}: {e}")

X = np.array(X, dtype=np.float32) / 255.0
y = np.array(y)
le = LabelEncoder()
y_encoded = le.fit_transform(y)
num_nodes, feat_dim = X.shape
num_classes = len(le.classes_)
print(f"Number of nodes: {num_nodes}, Feature dimension: {feat_dim}, Number of classes: {num_classes}")

# Build KNN graph
k = 10
nbrs = NearestNeighbors(n_neighbors=k).fit(X)
_, idx = nbrs.kneighbors(X)
edges = []
for i in range(num_nodes):
    for j in idx[i]:
        if i != j:
            edges.append([i, j])
            edges.append([j, i])
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

# Split train/test
indices = np.arange(num_nodes)
train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42, stratify=y_encoded)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx] = True

# -------------------------------
# 2. Text prior: extract [CLS] vector with BERT
# -------------------------------
text_path = 'Final Dataset-Texts.xlsx'  # Change path as needed
df_text = pd.read_excel(text_path)
first_text = df_text['List of Store Names'].iloc[0]

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
bert_model = BertModel.from_pretrained("bert-base-chinese").eval()
with torch.no_grad():
    tok = tokenizer(first_text, return_tensors="pt", truncation=True, padding=True)
    out = bert_model(**tok)
    text_vec = out.last_hidden_state[:, 0, :].squeeze(0)  # (768,)

# -------------------------------
# 3. Build HeteroData object
# -------------------------------
data = HeteroData()
# Image nodes
data['image'].x = torch.tensor(X)
data['image'].y = torch.tensor(y_encoded)
data['image'].train_mask = train_mask
data['image'].test_mask = test_mask
data['image', 'to', 'image'].edge_index = edge_index

# Text node (single node)
data['text'].x = text_vec.unsqueeze(0)  # (1, 768)

# Connect text ↔ image
src = torch.zeros(num_nodes, dtype=torch.long)          # text node index = 0
dst = torch.arange(num_nodes, dtype=torch.long)         # image nodes 0..num_nodes-1
data['text', 'to', 'image'].edge_index = torch.vstack([src, dst])
data['image', 'to', 'text'].edge_index = torch.vstack([dst, src])

# -------------------------------
# 4. Define HeteroGNN model
# -------------------------------
class HeteroGNN(nn.Module):
    def __init__(self, img_dim, txt_dim, hidden, out_dim):
        super().__init__()
        # Project different node types to a common hidden size
        self.lin_img = nn.Linear(img_dim, hidden)
        self.lin_txt = nn.Linear(txt_dim, hidden)
        # Meta-information of the heterogeneous graph
        self.metadata = data.metadata()
        # Two HGTConv layers
        self.conv1 = HGTConv(hidden, hidden, self.metadata, heads=2)
        self.conv2 = HGTConv(hidden, hidden, self.metadata, heads=2)
        # Final classifier for image nodes
        self.cls = nn.Linear(hidden, out_dim)

    def forward(self, data):
        x_dict = {
            'image': F.relu(self.lin_img(data['image'].x)),
            'text':  F.relu(self.lin_txt(data['text'].x))
        }
        # Message passing layer 1
        x_dict = self.conv1(x_dict, data.edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        # Message passing layer 2
        x_dict = self.conv2(x_dict, data.edge_index_dict)
        # Predict only for image nodes
        out = self.cls(x_dict['image'])
        return out

# -------------------------------
# 5. Training and evaluation
# -------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HeteroGNN(img_dim=feat_dim, txt_dim=text_vec.size(0),
                  hidden=64, out_dim=num_classes).to(device)
data = data.to(device)

optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# Train
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    logits = model(data)
    loss = criterion(logits[data['image'].train_mask],
                     data['image'].y[data['image'].train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"[HeteroGNN] Epoch {epoch:03d}, Loss: {loss.item():.4f}")

# Test
model.eval()
with torch.no_grad():
    logits = model(data)
    pred = logits.argmax(dim=1).cpu().numpy()
    y_true = data['image'].y[data['image'].test_mask].cpu().numpy()
    y_pred = pred[data['image'].test_mask.cpu().numpy()]

    print("HeteroGNN Test Accuracy:", accuracy_score(y_true, y_pred))
    print("Precision:", precision_score(y_true, y_pred, average='macro'))
    print("Recall   :", recall_score(y_true, y_pred, average='macro'))
    print("F1-score :", f1_score(y_true, y_pred, average='macro'))


节点数: 78575, 特征维度: 1024, 类别数: 207
[HeteroGNN] Epoch 000, Loss: 5.3375
[HeteroGNN] Epoch 020, Loss: 5.0637
[HeteroGNN] Epoch 040, Loss: 4.9287
[HeteroGNN] Epoch 060, Loss: 4.7973
[HeteroGNN] Epoch 080, Loss: 4.7675
[HeteroGNN] Epoch 100, Loss: 4.6830
[HeteroGNN] Epoch 120, Loss: 4.6063
[HeteroGNN] Epoch 140, Loss: 4.5645
[HeteroGNN] Epoch 160, Loss: 4.5447
[HeteroGNN] Epoch 180, Loss: 4.4687
HeteroGNN Test Accuracy: 0.07311485841552656
Precision: 0.016192157111485465
Recall   : 0.027383932590461093
F1-score : 0.015102539694945778


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
