In [1]:
text_dim=768
k = 10
test_size=0.2
lr=0.005
hidden_channels=64

In [2]:
import os
from PIL import Image
import numpy as np
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -------------------------------
# 1. Load & preprocess images and build the graph
# -------------------------------
data_dir = '../raw'  # Change to your image-dataset folder path

X = []  # store image data (each image flattened to a 1024-dimensional vector)
y = []  # store class labels

# Traverse the dataset folder; each sub-folder represents a class
for class_name in os.listdir(data_dir):
    class_path = os.path.join(data_dir, class_name)
    if os.path.isdir(class_path):
        for filename in os.listdir(class_path):
            if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                file_path = os.path.join(class_path, filename)
                try:
                    img = Image.open(file_path)
                    img = img.convert('L')          # convert to grayscale
                    img = img.resize((32, 32))      # resize to 32×32
                    img_array = np.array(img).flatten()  # flatten to 1024-dim vector
                    X.append(img_array)
                    y.append(class_name)
                except Exception as e:
                    print(f"Error reading file {file_path}: {e}")

X = np.array(X, dtype='float32') / 255.0  # normalize
y = np.array(y)
print("Total images read:", X.shape[0])
print("Features per sample:", X.shape[1])
print("Original class labels:", np.unique(y))

# Label encoding
le = LabelEncoder()
y_encoded = le.fit_transform(y)
num_classes = len(le.classes_)
print("Number of encoded classes:", num_classes)

# Build a KNN graph

nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X)
_, indices = nbrs.kneighbors(X)
edge_index = []
num_nodes = X.shape[0]
for i in range(num_nodes):
    for j in indices[i]:
        if i != j:  # exclude self-loops
            edge_index.append([i, j])
            edge_index.append([j, i])  # bidirectional edges
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

# Construct the PyG Data object
x_tensor = torch.tensor(X, dtype=torch.float)
y_tensor = torch.tensor(y_encoded, dtype=torch.long)
data = Data(x=x_tensor, edge_index=edge_index, y=y_tensor)

# Split nodes into train/test sets
indices = np.arange(num_nodes)
train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=42, stratify=y_encoded)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask  = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx]   = True
data.train_mask = train_mask
data.test_mask  = test_mask

print("Number of training nodes:", int(train_mask.sum()))
print("Number of test nodes:", int(test_mask.sum()))

# -------------------------------
# 2. Text prior: read Excel and obtain the [CLS] vector via BERT
# -------------------------------
text_data_path = '../Sample Data Texts.xlsx'  # Change to your Excel file path
df_text = pd.read_excel(text_data_path)
# Use the first row in “List of Store Names” as plugin information (example only)
first_text = df_text['List of Store Names'].iloc[0]
print("First row text:", first_text)

# Use a pre-trained BERT model (bert-base-chinese if the text is Chinese)
tokenizer   = BertTokenizer.from_pretrained("bert-base-chinese")
bert_model  = BertModel.from_pretrained("bert-base-chinese")
bert_model.eval()
with torch.no_grad():
    inputs  = tokenizer(first_text, return_tensors="pt", truncation=True, padding=True)
    outputs = bert_model(**inputs)
    # Get the [CLS] token hidden state, shape (1, 768)
    text_hidden_state = outputs.last_hidden_state[:, 0, :].squeeze(0)  # (768,)
print("Text hidden state shape:", text_hidden_state.shape)

# -------------------------------
# 3. Define a GCN model with text plugin
# -------------------------------
class GCNPlugin(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, text_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # After fusion the dimension is hidden_channels + text_dim
        self.classifier = nn.Linear(hidden_channels + text_dim, out_channels)
    
    def forward(self, data, text_vector):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        # x shape: (num_nodes, hidden_channels)
        num_nodes = x.size(0)
        # Expand text_vector (text_dim,) to (num_nodes, text_dim)
        text_expanded = text_vector.unsqueeze(0).expand(num_nodes, -1)
        # Fuse: concatenate image embeddings and text plugin
        fused = torch.cat([x, text_expanded], dim=1)
        logits = self.classifier(fused)
        return logits

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNPlugin(in_channels=X.shape[1],
                  hidden_channels=hidden_channels,
                  out_channels=num_classes,
                  text_dim=text_dim).to(device)
data = data.to(device)
text_hidden_state = text_hidden_state.to(device)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# -------------------------------
# 4. Train the model
# -------------------------------
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out  = model(data, text_hidden_state)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")

# -------------------------------
# 5. Evaluate the model
# -------------------------------
model.eval()
with torch.no_grad():
    out  = model(data, text_hidden_state)
    pred = out.argmax(dim=1)

    # Only consider test nodes
    y_true = data.y[data.test_mask].cpu().numpy()
    y_pred = pred[data.test_mask].cpu().numpy()

    # Compute metrics
    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro')
    rec  = recall_score(y_true, y_pred, average='macro')
    f1   = f1_score(y_true, y_pred, average='macro')

    # Print results
    print(f"Test Accuracy : {acc:.4f}")
    print(f"Precision     : {prec:.4f}")
    print(f"Recall        : {rec:.4f}")
    print(f"F1-score      : {f1:.4f}")


Total images read: 3344
Features per sample: 1024
Original class labels: ['Eh-1-1' 'Eh-1-2' 'Eh-1-3' 'Eh-1-4' 'N-1-1' 'N-1-2' 'N-1-3' 'N-1-4'
 'N-1-5']
Number of encoded classes: 9
Number of training nodes: 2675
Number of test nodes: 669
First row text: WM HOUSE, MC HOUSE 展览馆, 交通银行, 良品铺子
Text hidden state shape: torch.Size([768])
Epoch 000, Loss: 2.1915
Epoch 020, Loss: 1.9007
Epoch 040, Loss: 1.7992
Epoch 060, Loss: 1.7512
Epoch 080, Loss: 1.6892
Epoch 100, Loss: 1.6553
Epoch 120, Loss: 1.5843
Epoch 140, Loss: 1.5247
Epoch 160, Loss: 1.5405
Epoch 180, Loss: 1.4312
Test Accuracy : 0.4604
Precision     : 0.5317
Recall        : 0.3447
F1-score      : 0.3550


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


GAT