In [None]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import os
import cv2
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeoDataLoader

# 1. Load Pose Keypoints CSV
pose_csv = "/content/drive/MyDrive/HRNet-Human-Pose-Estimation/tools/casia_pose_all_subjects.csv"  # your path
df = pd.read_csv(pose_csv)

# 2. Extract Image Names and Keypoints
# Column 'image' holds the relative image path; every remaining column is treated as
# keypoint data and reshaped to (num_images, 17, 2).
# NOTE(review): assumes the columns are ordered joint by joint with two values each — confirm
# against the script that produced the CSV.
NUM_KEYPOINTS = 17
VALUES_PER_KEYPOINT = 2
image_names = df['image'].tolist()
keypoint_table = df.iloc[:, 1:]
expected_width = NUM_KEYPOINTS * VALUES_PER_KEYPOINT
if keypoint_table.shape[1] != expected_width:
    # Without this check a CSV with extra columns could still reshape "successfully" and
    # silently pair images with the wrong keypoints.
    raise ValueError(
        f"Expected {expected_width} keypoint columns after 'image', "
        f"found {keypoint_table.shape[1]} in {pose_csv}"
    )
keypoints = keypoint_table.to_numpy().reshape(-1, NUM_KEYPOINTS, VALUES_PER_KEYPOINT)

# 3. Build Graphs
# Each pair is one bone between two keypoint indices. In this list nodes 0-4 and nodes 5-16
# share no edge, so every pose graph has two disconnected components.
# NOTE(review): confirm whether leaving the head nodes unconnected to the body is intentional.
COCO_EDGES = [
    (0, 1), (0, 2), (1, 3), (2, 4),
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    (5, 11), (6, 12), (11, 12), (11, 13), (13, 15),
    (12, 14), (14, 16)
]
# PyG treats edge_index as directed (row 0 = source, row 1 = target). The skeleton is an
# undirected graph, so every bone is listed in both directions; otherwise message passing
# would only flow one way along each bone.
one_way_edges = torch.tensor(COCO_EDGES, dtype=torch.long).t().contiguous()
edge_index = torch.cat([one_way_edges, one_way_edges.flip(0)], dim=1)
# All graphs share the same edge_index tensor; only the node features differ per image.
pose_graphs = [Data(x=torch.tensor(kpt, dtype=torch.float), edge_index=edge_index) for kpt in keypoints]

# 4. Label Encoding
# The subject id is the first path component of each image name; LabelEncoder maps the
# distinct ids to consecutive integers.
subject_ids = [img.split('/')[0] for img in image_names]
le = LabelEncoder()
encoded_labels = le.fit_transform(subject_ids)
label_map = {sid: lbl for sid, lbl in zip(subject_ids, encoded_labels)}

# 5. Fusion Dataset
class FusionDataset(Dataset):
    """Dataset yielding ``(image, pose_graph, label)`` triples for image/pose fusion.

    Item ``idx`` combines the image at ``root_dir / image_names[idx]`` (read as grayscale
    and resized), the pre-built graph ``pose_graphs[idx]``, and the integer label looked up
    in ``label_map`` under the first ``/``-separated component of the image name.

    Args:
        root_dir: Directory that the relative paths in ``image_names`` are joined onto.
        image_names: Relative image paths, one per sample.
        pose_graphs: Graph objects aligned index-by-index with ``image_names``.
        label_map: Mapping from subject id (first path component) to integer class label.
        transform: Callable applied to the resized image array. Defaults to
            ``transforms.ToTensor()`` when ``None``.
        image_size: ``(width, height)`` passed to ``cv2.resize``; defaults to ``(128, 128)``.

    Raises:
        ValueError: If ``image_names`` and ``pose_graphs`` differ in length.
    """

    def __init__(self, root_dir, image_names, pose_graphs, label_map, transform=None,
                 image_size=(128, 128)):
        if len(image_names) != len(pose_graphs):
            # A mismatch would otherwise surface late as an IndexError, or not at all if
            # there are surplus graphs.
            raise ValueError(
                f"image_names ({len(image_names)}) and pose_graphs ({len(pose_graphs)}) "
                "must have the same length"
            )
        self.root_dir = root_dir
        self.image_names = image_names
        self.pose_graphs = pose_graphs
        self.label_map = label_map
        # Compare against None explicitly: a valid transform object may be falsy (e.g. an
        # empty container-style pipeline) and must not be replaced by the default.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.image_size = tuple(image_size)

    def __len__(self):
        """Return the number of samples (one per image name)."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, graph, label)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        rel_path = self.image_names[idx]
        full_path = os.path.join(self.root_dir, rel_path)

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Image not found or cannot be read: {full_path}")

        image = cv2.resize(image, self.image_size)
        image = self.transform(image)
        graph = self.pose_graphs[idx]

        subject_id = rel_path.split('/')[0]
        label = torch.tensor(self.label_map[subject_id], dtype=torch.long)
        return image, graph, label


# 6. Create Dataset and Split
dataset = FusionDataset(
    root_dir="/content/drive/MyDrive/Dataset",  # update this path
    image_names=image_names,
    pose_graphs=pose_graphs,
    label_map=label_map
)

# 80/20 split over individual images. Because the split is per image rather than per
# subject, images of the same subject can appear in both the training and validation sets.
# A seeded generator keeps the partition identical across kernel restarts, so validation
# numbers from different runs are comparable.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# 7. Create Dataloaders
# The PyG DataLoader is used so that the graph element of each sample is batched into a
# single graph batch alongside the image and label tensors.
train_loader = GeoDataLoader(train_set, batch_size=16, shuffle=True)
val_loader = GeoDataLoader(val_set, batch_size=16)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torchvision import models
import torch.optim as optim

# CNN Backbone (e.g., ResNet18 without final FC)
class CNNEncoder(nn.Module):
    """Encode a single-channel image into a fixed-size feature vector.

    Uses an ImageNet-pretrained ResNet-18 with its final fully connected layer removed,
    followed by a linear projection from 512 features to ``output_dim``.

    Args:
        output_dim: Size of the returned embedding per image.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        # The `pretrained=True` flag is deprecated in torchvision; the weights enum selects
        # the same ImageNet checkpoint explicitly.
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Swap the 3-channel stem for a 1-channel one (grayscale input). Instead of leaving
        # the new layer randomly initialised, which would discard the pretrained first-layer
        # filters, initialise it with the RGB filters summed over the channel axis. That is
        # equivalent to feeding the original stem a grayscale image replicated to 3 channels.
        gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_conv.weight.copy_(resnet.conv1.weight.sum(dim=1, keepdim=True))
        resnet.conv1 = gray_conv

        self.features = nn.Sequential(*list(resnet.children())[:-1])  # Remove final FC
        self.fc = nn.Linear(512, output_dim)

    def forward(self, x):
        """Return a ``(B, output_dim)`` embedding for a batch of images ``x``."""
        x = self.features(x)  # (B, 512, 1, 1)
        x = torch.flatten(x, 1)  # drop the trailing 1x1 spatial dims -> (B, 512)
        return self.fc(x)

# GCN Encoder
class GCNEncoder(nn.Module):
    """Two-layer GCN that turns a batch of pose graphs into one vector per graph.

    Node features (2 values per node) pass through two graph convolutions with ReLU
    (2 -> 64 -> 128), are mean-pooled per graph, and are projected to ``output_dim``.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        self.conv1 = GCNConv(2, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, output_dim)

    def forward(self, data):
        """Return a ``(num_graphs, output_dim)`` tensor for the batched graph ``data``."""
        node_feats = data.x
        for conv in (self.conv1, self.conv2):
            node_feats = F.relu(conv(node_feats, data.edge_index))
        # `data.batch` assigns every node to its graph so pooling averages per graph.
        graph_feats = global_mean_pool(node_feats, data.batch)
        return self.fc(graph_feats)

# Late Fusion Classifier
class LateFusionModel(nn.Module):
    """Classifier over concatenated image (CNN) and pose-graph (GCN) embeddings.

    Args:
        cnn_out: Embedding size produced by the image branch.
        gcn_out: Embedding size produced by the graph branch.
        num_classes: Number of output logits.
    """

    def __init__(self, cnn_out=128, gcn_out=128, num_classes=124):
        super().__init__()
        self.cnn = CNNEncoder(cnn_out)
        self.gcn = GCNEncoder(gcn_out)
        fused_dim = cnn_out + gcn_out
        self.classifier = nn.Linear(fused_dim, num_classes)

    def forward(self, image, graph_data):
        """Return class logits for a batch of images and their matching graph batch."""
        branch_feats = [self.cnn(image), self.gcn(graph_data)]
        # Concatenate along the feature axis so each row holds both embeddings.
        fused = torch.cat(branch_feats, dim=1)
        return self.classifier(fused)


In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One output logit per distinct encoded subject label.
num_subjects = len(set(encoded_labels))
model = LateFusionModel(num_classes=num_subjects)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, loader, optimizer=None, criterion=None, device=None):
    """Run one training epoch and return the mean loss per sample.

    Args:
        model: Module called as ``model(images, graphs)`` that returns class logits.
        loader: Iterable of ``(images, graphs, labels)`` batches.
        optimizer: Optimizer to step. Defaults to the notebook-level ``optimizer``.
        criterion: Loss called as ``criterion(outputs, labels)``; it is assumed to return
            the mean over the batch. Defaults to the notebook-level ``criterion``.
        device: Device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        The loss averaged over all samples seen, or ``0.0`` if ``loader`` yields no samples.
    """
    # Fall back to the notebook globals so existing `train(model, loader)` calls keep working.
    notebook_scope = globals()
    if optimizer is None:
        optimizer = notebook_scope['optimizer']
    if criterion is None:
        criterion = notebook_scope['criterion']
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for images, graphs, labels in loader:
        images, labels = images.to(device), labels.to(device)
        graphs = graphs.to(device)
        optimizer.zero_grad()
        outputs = model(images, graphs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a short final batch is not over-represented.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size
    return loss_sum / samples_seen if samples_seen else 0.0

def evaluate(model, loader, device=None):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(images, graphs)`` that returns class logits.
        loader: Iterable of ``(images, graphs, labels)`` batches.
        device: Device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or ``0.0`` if
        ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = total = 0
    with torch.no_grad():  # inference only; no gradients are needed
        for images, graphs, labels in loader:
            images, labels = images.to(device), labels.to(device)
            graphs = graphs.to(device)
            outputs = model(images, graphs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

# Training Loop: one optimisation pass over the training set per epoch, followed by an
# accuracy check on the validation set.
NUM_EPOCHS = 20
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(model, train_loader)
    val_acc = evaluate(model, val_loader)
    epoch_summary = f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}"
    print(epoch_summary)




Epoch 1: Train Loss=0.6880, Val Acc=0.1351
Epoch 2: Train Loss=0.1244, Val Acc=0.9009
Epoch 3: Train Loss=0.0718, Val Acc=0.9640
Epoch 4: Train Loss=0.0303, Val Acc=0.9820
Epoch 5: Train Loss=0.0635, Val Acc=0.9910
Epoch 6: Train Loss=0.0294, Val Acc=1.0000
Epoch 7: Train Loss=0.0244, Val Acc=1.0000
Epoch 8: Train Loss=0.0461, Val Acc=0.9009
Epoch 9: Train Loss=0.0199, Val Acc=0.9910
Epoch 10: Train Loss=0.0090, Val Acc=0.9910
Epoch 11: Train Loss=0.0090, Val Acc=0.9730
Epoch 12: Train Loss=0.0061, Val Acc=1.0000
Epoch 13: Train Loss=0.0040, Val Acc=0.9730
Epoch 14: Train Loss=0.0227, Val Acc=0.9730
Epoch 15: Train Loss=0.0172, Val Acc=0.9640
Epoch 16: Train Loss=0.0100, Val Acc=0.9550
Epoch 17: Train Loss=0.0258, Val Acc=1.0000
Epoch 18: Train Loss=0.0214, Val Acc=0.9820
Epoch 19: Train Loss=0.0520, Val Acc=0.9550
Epoch 20: Train Loss=0.0485, Val Acc=0.9459
