In [None]:
# --- Data split ---
test_size = 0.2      # Fraction of the samples held out for testing

# --- Model architecture ---
in_channels = 1      # Number of input image channels
embed_dim = 64       # Width of each token embedding
num_layers = 6       # Number of Transformer encoder layers
num_heads = 8        # Attention heads per encoder layer
mlp_dim = 128        # Hidden width of each encoder feed-forward block
dropout = 0.1        # Dropout probability used inside the model
text_dim = 768       # Length of the text embedding fused before the head

# --- Optimization ---
lr = 0.001           # Initial learning rate
gamma = 0.1          # Multiplicative learning-rate decay factor
num_epochs = 100     # Number of passes over the training set

In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertModel
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score

# -------------------------------
# 1. Image Data Loading and Preprocessing
# -------------------------------
data_dir = '../raw'  # Change to your image dataset folder path

images = []  # Grayscale pixel arrays, each kept 2D (32, 32) for patch splitting
labels = []  # Class label (sub-folder name) of each image, same order as `images`

# Directory listings are sorted because os.listdir() returns entries in an
# arbitrary, filesystem-dependent order. Without sorting, the sample order (and
# therefore the seeded train/test split built from it) differs between machines.
for class_name in sorted(os.listdir(data_dir)):
    class_path = os.path.join(data_dir, class_name)
    if not os.path.isdir(class_path):
        continue  # Only sub-folders are treated as classes
    for filename in sorted(os.listdir(class_path)):
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
            continue
        file_path = os.path.join(class_path, filename)
        try:
            # The context manager releases the file handle even when decoding fails
            with Image.open(file_path) as img:
                gray = img.convert('L')           # Convert to grayscale
                resized = gray.resize((32, 32))   # Resize to 32×32
                img_array = np.array(resized)     # Keep 2D (32, 32)
        except Exception as e:
            # Best-effort loading: report the unreadable file and keep going
            print(f"Error reading file {file_path}: {e}")
        else:
            # Append image and label together so the two lists stay aligned
            images.append(img_array)
            labels.append(class_name)

# Fail early with a clear message instead of an obscure error further down
if not images:
    raise RuntimeError(f"No readable images found under {data_dir!r}")

# Normalize to [0,1], shape: (N, 32, 32)
images = np.array(images, dtype=np.float32) / 255.0
labels = np.array(labels)
print("Total number of images read:", images.shape[0])
print("Image dimensions:", images.shape[1:])

# Map the class-name strings to integer ids
le = LabelEncoder()
labels_encoded = le.fit_transform(labels)
num_classes = len(le.classes_)
print("Number of encoded classes:", num_classes)

# Insert a channel axis so the images follow the (N, C, H, W) layout with C=1
images = images[:, np.newaxis]  # Shape: (N, 1, 32, 32)

# Stratified hold-out split performed on sample positions with a fixed seed
indices = np.arange(len(images))
train_idx, test_idx = train_test_split(
    indices,
    test_size=test_size,
    random_state=42,
    stratify=labels_encoded,
)

# Full-dataset tensors; the split is applied by indexing below
x_tensor = torch.tensor(images, dtype=torch.float)   # (N, 1, 32, 32)
y_tensor = torch.tensor(labels_encoded, dtype=torch.long)

x_train, y_train = x_tensor[train_idx], y_tensor[train_idx]
x_test, y_test = x_tensor[test_idx], y_tensor[test_idx]

# Wrap each split in a dataset and loader; only the training loader shuffles
batch_size = 32
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = TensorDataset(x_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# -------------------------------
# 2. Text Prior: Extract text embedding using BERT
# -------------------------------
# Change text_data_path to your Excel file path
text_data_path = '../Sample Data Texts.xlsx'
df_text = pd.read_excel(text_data_path)
# The first entry of the "List of Store Names" column is the plugin text
first_text = df_text['List of Store Names'].iloc[0]
print("First row text:", first_text)

# Pre-trained BERT checkpoint (bert-base-chinese, intended for Chinese text)
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
bert_model = BertModel.from_pretrained("bert-base-chinese")
bert_model.eval()  # Inference mode: the embedding is only read, never trained

inputs = tokenizer(first_text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
    outputs = bert_model(**inputs)

# Hidden state at sequence position 0 (the [CLS] token): (1, 768) -> (768,)
cls_hidden = outputs.last_hidden_state[:, 0, :]
text_hidden_state = cls_hidden.squeeze(0)
print("Text hidden state shape:", text_hidden_state.shape)

# -------------------------------
# 3. Define ViT Model (Plugin-based Multimodal ViT)
# -------------------------------
class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` projects
    every non-overlapping patch to an ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length of the (square) input image, used to derive ``num_patches``.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Size of the embedding produced for each patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=1, embed_dim=64):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches along one side, squared to count the patches of the whole image
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to patch tokens of shape (B, num_patches, embed_dim)."""
        feature_map = self.proj(x)                  # (B, embed_dim, H/patch_size, W/patch_size)
        tokens = feature_map.flatten(start_dim=2)   # (B, embed_dim, num_patches)
        return tokens.transpose(1, 2)               # (B, num_patches, embed_dim)

class ViTPlugin(nn.Module):
    """
    Vision Transformer classifier whose head also consumes a text embedding.

    The image is split into patch embeddings, a learnable classification token
    is prepended, learnable positional embeddings are added, and the sequence is
    processed by a Transformer encoder. The layer-normalized classification-token
    output is concatenated with the text vector and fed to one linear layer that
    produces the class logits.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Token embedding size (``d_model`` of the encoder).
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per encoder layer.
        mlp_dim: Feed-forward width inside each encoder layer.
        num_classes: Number of output classes.
        dropout: Dropout probability for the token sequence and the encoder layers.
        text_dim: Length of the text embedding concatenated before the head.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=1, embed_dim=64,
                 num_layers=6, num_heads=8, mlp_dim=128, num_classes=10,
                 dropout=0.1, text_dim=768):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Classification token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional encoding (including classification token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=mlp_dim, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        # Combined dimension after fusion: embed_dim + text_dim
        self.head = nn.Linear(embed_dim + text_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialize the token/positional parameters and the classification head."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x, text_vector):
        """
        Compute class logits for a batch of images fused with a text embedding.

        Args:
            x: Image batch of shape (B, C, H, W).
            text_vector: Text embedding. Either 1-D ``(text_dim,)``, shared by the
                whole batch, or 2-D ``(1, text_dim)`` / ``(B, text_dim)`` for a
                shared or per-sample embedding respectively.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: If ``text_vector`` is neither 1-D nor 2-D.
        """
        B = x.size(0)
        x = self.patch_embed(x)                        # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, num_patches+1, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # The encoder layer is built with the default sequence-first layout, so it
        # expects input shape (sequence_length, batch_size, embed_dim)
        x = x.transpose(0, 1)                          # (num_patches+1, B, embed_dim)
        x = self.transformer(x)                        # (num_patches+1, B, embed_dim)
        x = x.transpose(0, 1)                          # (B, num_patches+1, embed_dim)
        x = self.norm(x)
        cls_out = x[:, 0]                              # Classification token output, (B, embed_dim)

        # Bring the text embedding to 2-D so a shared vector and a per-sample
        # batch of vectors go through the same broadcast below
        if text_vector.dim() == 1:
            text_vector = text_vector.unsqueeze(0)     # (1, text_dim)
        elif text_vector.dim() != 2:
            raise ValueError(
                "text_vector must have shape (text_dim,), (1, text_dim) or "
                f"(B, text_dim); got {tuple(text_vector.shape)}"
            )
        # A leading size of 1 is broadcast to B; a leading size of B is kept as is
        text_expanded = text_vector.expand(B, -1)      # (B, text_dim)

        # Fuse: concatenate classification token output and text embedding
        fused = torch.cat([cls_out, text_expanded], dim=1)  # (B, embed_dim + text_dim)
        logits = self.head(fused)
        return logits

# -------------------------------
# 4. Training and Evaluation
# -------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# The text embedding must live on the same device as the model
text_hidden_state = text_hidden_state.to(device)

# Build the model from the hyperparameters defined at the top of the file
model = ViTPlugin(
    img_size=32,
    patch_size=4,
    in_channels=in_channels,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    num_classes=num_classes,
    dropout=dropout,
    text_dim=text_dim,
)
model = model.to(device)

# Cross-entropy objective, Adam optimizer, and a step decay of the learning
# rate by `gamma` every 50 scheduler steps
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=gamma)

for epoch in range(num_epochs):
    # ---- One optimization pass over the training set ----
    model.train()
    running_loss = 0.0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        logits = model(x_batch, text_hidden_state)
        batch_loss = criterion(logits, y_batch)
        batch_loss.backward()
        optimizer.step()

        # Weight the mean batch loss by batch size so the epoch average is per-sample
        running_loss += batch_loss.item() * x_batch.size(0)

    # The learning-rate schedule advances once per epoch
    scheduler.step()
    avg_loss = running_loss / len(train_loader.dataset)

    # Evaluate every 20 epochs only (epochs 0, 20, 40, ...)
    if epoch % 20 != 0:
        continue

    # ---- Periodic accuracy check on the test set ----
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs = imgs.to(device)
            lbls = lbls.to(device)
            preds = model(imgs, text_hidden_state).argmax(dim=1)
            hits += torch.eq(preds, lbls).sum().item()
            seen += lbls.size(0)
    acc = hits / seen
    print(f"Epoch {epoch:03d}, Loss: {avg_loss:.4f}, Test Accuracy: {acc:.4f}")

# -----------------------------
# Final evaluation: accuracy plus macro Precision / Recall / F1
# -----------------------------
# Run the test set through the model once and derive every metric from the
# same predictions, rather than repeating the full inference pass per metric.
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for imgs, lbls in test_loader:
        imgs, lbls = imgs.to(device), lbls.to(device)
        preds = model(imgs, text_hidden_state).argmax(dim=1)
        y_true.extend(lbls.cpu().tolist())
        y_pred.extend(preds.cpu().tolist())

# Final test accuracy
correct = sum(t == p for t, p in zip(y_true, y_pred))
total = len(y_true)
acc = correct / total
print(f"Final Test Accuracy: {acc:.4f}")

# Macro-averaged scores; zero_division=0 scores a class with no predicted or
# no true samples as 0 instead of emitting a warning
prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
rec  = recall_score(y_true, y_pred, average='macro', zero_division=0)
f1   = f1_score(y_true, y_pred, average='macro', zero_division=0)
print(f"Final Test Set — Precision: {prec:.4f} | Recall: {rec:.4f} | F1-score: {f1:.4f}")
