In [1]:
# Hyperparameters for the image-graph + text-prior pipeline.
text_dim = 768         # width of the text [CLS] embedding concatenated before the classifier
k = 10                 # neighbours queried per image for the KNN graph (the self-match, usually included, is dropped)
test_size = 0.2        # fraction of nodes held out by the stratified train/test split
lr = 0.005             # Adam learning rate
hidden_channels = 64   # per-head width of GAT layer 1 and output width of GAT layer 2

In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors

from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# ---------------------------
# 1. Image loading, preprocessing, and graph construction
# ---------------------------
data_dir = '../raw'  # Change to your image-dataset folder path

# Expected layout: data_dir/<class_name>/<image file>. Each sub-folder name is
# used as the class label of the images inside it.
X = []  # flattened grayscale pixel vectors (32x32 -> 1024 values per image)
y = []  # class label (sub-folder name) for each entry of X

# os.listdir returns entries in arbitrary order. Node indices, and therefore the
# KNN graph and the train/test split built later, depend on load order, so sort
# the listings to make runs reproducible.
for class_name in sorted(os.listdir(data_dir)):
    class_path = os.path.join(data_dir, class_name)
    if not os.path.isdir(class_path):
        continue  # ignore stray files directly under data_dir
    for filename in sorted(os.listdir(class_path)):
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
            continue
        file_path = os.path.join(class_path, filename)
        try:
            # The context manager guarantees the file handle is closed, even
            # when decoding fails part-way through.
            with Image.open(file_path) as img:
                gray = img.convert('L')                 # convert to grayscale
                small = gray.resize((32, 32))           # resize to 32×32
                img_array = np.array(small).flatten()   # flatten to 1024-dim vector
        except Exception as e:
            # Best-effort loading: report the unreadable file and keep going.
            print(f"Error reading file {file_path}: {e}")
            continue
        X.append(img_array)
        y.append(class_name)

if not X:
    # Without this guard, X.shape[1] below fails with an opaque IndexError.
    raise RuntimeError(f"No readable images found under {data_dir!r}")

X = np.array(X, dtype='float32') / 255.0  # normalize to [0, 1]
y = np.array(y)
print("Total images loaded:", X.shape[0])
print("Features per sample:", X.shape[1])
print("Original class labels:", np.unique(y))

# Label encoding
le = LabelEncoder()
y_encoded = le.fit_transform(y)
num_classes = len(le.classes_)
print("Number of encoded classes:", num_classes)

# Build a KNN graph over the pixel vectors.
# Querying the fitted data with itself usually returns each point as its own
# nearest neighbour. That self-match is dropped, so nodes typically keep k-1
# neighbours. n_neighbors is clamped because sklearn rejects k > n_samples.
num_nodes = X.shape[0]
nbrs = NearestNeighbors(n_neighbors=min(k, num_nodes), algorithm='auto').fit(X)
_, indices = nbrs.kneighbors(X)

# One (src, dst) pair per returned neighbour; int64 so the edge codes below
# cannot overflow.
src = np.repeat(np.arange(num_nodes, dtype=np.int64), indices.shape[1])
dst = indices.ravel().astype(np.int64)
not_self = src != dst                    # exclude self-loops
src, dst = src[not_self], dst[not_self]

# Make the graph undirected and drop duplicate edges. When i and j are in each
# other's neighbour lists, adding both directions per neighbour yields every
# such edge twice, which double-weights it in GAT's attention aggregation.
# Encoding each directed edge as a single integer lets np.unique deduplicate it.
edge_codes = np.unique(np.concatenate([src * num_nodes + dst,
                                       dst * num_nodes + src]))
src, dst = np.divmod(edge_codes, num_nodes)
# Shape (2, num_edges), dtype int64 (torch.long); (2, 0) if there are no edges.
edge_index = torch.from_numpy(np.stack([src, dst])).contiguous()

# Construct a PyG Data object: node features, KNN edges and integer labels.
x_tensor = torch.tensor(X, dtype=torch.float)
y_tensor = torch.tensor(y_encoded, dtype=torch.long)
data = Data(x=x_tensor, edge_index=edge_index, y=y_tensor)

# Stratified train/test node split. The setup is transductive: every node stays
# in the graph, and the boolean masks only decide which labels the loss and the
# evaluation metrics look at.
indices = np.arange(num_nodes)
train_idx, test_idx = train_test_split(
    indices,
    test_size=test_size,
    random_state=42,
    stratify=y_encoded,
)


def _nodes_to_mask(node_ids):
    """Return a bool tensor of length num_nodes that is True exactly at node_ids."""
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[node_ids] = True
    return mask


train_mask = _nodes_to_mask(train_idx)
test_mask = _nodes_to_mask(test_idx)
data.train_mask = train_mask
data.test_mask = test_mask

print("Training nodes:", int(train_mask.sum()))
print("Test nodes    :", int(test_mask.sum()))

# ---------------------------
# 2. Text prior: load Excel and obtain the [CLS] vector via BERT
# ---------------------------
text_data_path = '../Sample Data Texts.xlsx'  # Change to your Excel file path
df_text = pd.read_excel(text_data_path)

# Take the first row of "List of Store Names" as plugin information (example only).
store_names = df_text['List of Store Names']
if store_names.empty:
    raise ValueError(f"Column 'List of Store Names' in {text_data_path!r} has no rows")
first_text = store_names.iloc[0]
if pd.isna(first_text):
    # An empty Excel cell is read as NaN, which the tokenizer cannot encode.
    raise ValueError(f"First row of 'List of Store Names' in {text_data_path!r} is empty")
# Excel cells may be read as numbers; the tokenizer expects a string.
first_text = str(first_text)
print("First row text:", first_text)

# Use a pre-trained BERT model (bert-base-chinese if the text is Chinese)
tokenizer  = BertTokenizer.from_pretrained("bert-base-chinese")
bert_model = BertModel.from_pretrained("bert-base-chinese")
bert_model.eval()  # inference mode: no dropout, so the embedding is deterministic
with torch.no_grad():
    # truncation=True cuts inputs longer than the model's maximum length.
    inputs  = tokenizer(first_text, return_tensors="pt", truncation=True, padding=True)
    outputs = bert_model(**inputs)
    # Hidden state of the first ([CLS]) token: (1, hidden) -> (hidden,)
    text_hidden_state = outputs.last_hidden_state[:, 0, :].squeeze(0)
print("Text hidden state shape:", text_hidden_state.shape)

# ---------------------------
# 3. Define a plugin-based multimodal GAT model
# ---------------------------
class GATPlugin(nn.Module):
    """Two-layer GAT node classifier with a text vector "plugged in" before the head.

    Pipeline: GATConv (``heads`` heads, outputs concatenated) -> ELU ->
    GATConv (one head, not concatenated) -> append ``text_vector`` to every
    node embedding -> Linear layer producing ``num_classes`` logits.

    Note: the same ``text_vector`` is appended to every node, so its
    contribution to the logits (``W_text @ text_vector``) is identical for all
    nodes and behaves like an extra learned per-class bias.

    Args:
        in_channels: width of the node feature vectors.
        hidden_channels: per-head width of layer 1 and output width of layer 2.
        num_classes: number of output logits per node.
        text_dim: length of the text vector passed to ``forward``.
        heads: number of attention heads in the first GAT layer.
        dropout: attention-coefficient dropout passed to both GATConv layers.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, text_dim,
                 heads=8, dropout=0.6):
        super().__init__()
        # Layer 1 concatenates its heads, so layer 2 receives hidden_channels * heads features.
        concat_width = hidden_channels * heads
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.gat2 = GATConv(concat_width, hidden_channels, heads=1, concat=False,
                            dropout=dropout)
        # Classifier sees [node embedding | text vector].
        self.classifier = nn.Linear(hidden_channels + text_dim, num_classes)

    def forward(self, data, text_vector):
        """Return logits of shape (num_nodes, num_classes).

        ``data`` must expose ``x`` and ``edge_index``. ``text_vector`` is
        assumed to be 1-D of length ``text_dim``; it is turned into one row
        and broadcast to every node.
        """
        edges = data.edge_index
        hidden = F.elu(self.gat1(data.x, edges))
        node_emb = self.gat2(hidden, edges)  # (num_nodes, hidden_channels)
        text_rows = text_vector.unsqueeze(0).expand(node_emb.size(0), -1)
        return self.classifier(torch.cat([node_emb, text_rows], dim=1))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The classifier's input width is hidden_channels + text_dim, so text_dim must
# equal the width of the text embedding. Fail early with a clear message rather
# than with a shape-mismatch error on the first forward pass.
if text_hidden_state.shape[-1] != text_dim:
    raise ValueError(
        f"text_dim={text_dim} does not match the text embedding width "
        f"{text_hidden_state.shape[-1]}"
    )

# Seed the global torch RNG so weight initialisation is reproducible, in line
# with the fixed random_state used for the train/test split. GPU kernels may
# still introduce some run-to-run nondeterminism.
torch.manual_seed(42)
model = GATPlugin(in_channels=X.shape[1],
                  hidden_channels=hidden_channels,
                  num_classes=num_classes,
                  text_dim=text_dim).to(device)
data = data.to(device)
text_hidden_state = text_hidden_state.to(device)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# ---------------------------
# 4. Train the model
# ---------------------------
# Full-batch training: each epoch runs the whole graph forward, but the loss
# only uses the labels of nodes selected by the training mask.
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data, text_hidden_state)
    train_logits = out[data.train_mask]
    train_labels = data.y[data.train_mask]
    loss = criterion(train_logits, train_labels)
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:  # progress report every 20 epochs
        print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")

# ---------------------------
# 5. Evaluate the model
# ---------------------------
model.eval()  # turns off dropout (including GATConv's attention dropout)
with torch.no_grad():
    out  = model(data, text_hidden_state)
    pred = out.argmax(dim=1)

# Only consider test nodes
y_true = data.y[data.test_mask].cpu().numpy()
y_pred = pred[data.test_mask].cpu().numpy()

# Macro averages weight every class equally. zero_division=0 yields the same
# score as the default for a class that is never predicted (or never present),
# but without emitting an UndefinedMetricWarning for each metric.
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
rec  = recall_score(y_true, y_pred, average='macro', zero_division=0)
f1   = f1_score(y_true, y_pred, average='macro', zero_division=0)

print(f"Test Accuracy : {acc:.4f}")
print(f"Precision     : {prec:.4f}")
print(f"Recall        : {rec:.4f}")
print(f"F1-score      : {f1:.4f}")
