In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.svm import SVC
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report
from sklearn.multiclass import OneVsRestClassifier
from sklearn.utils import shuffle
from scipy.stats import kurtosis
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import random

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

#subsetting casia b
def load_casia_subset(base_path, num_subjects=10, sequences_per_subject=9):
    images = []
    labels = []

    subject_dirs = sorted([d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))])
    subject_dirs = subject_dirs[:num_subjects]

    for subject in subject_dirs:
        subject_path = os.path.join(base_path, subject)
        for condition in os.listdir(subject_path):
            if not condition.startswith("nm"):  #have onlly used normal walking 
                continue
            cond_path = os.path.join(subject_path, condition)
            if not os.path.isdir(cond_path):
                continue
            for angle in os.listdir(cond_path)[:sequences_per_subject]:
                angle_path = os.path.join(cond_path, angle)
                if not os.path.isdir(angle_path):
                    continue
                for frame in sorted(os.listdir(angle_path))[:20]: 
                    frame_path = os.path.join(angle_path, frame)
                    try:
                        img = Image.open(frame_path).convert("RGB").resize((224, 224))
                        images.append(np.array(img) / 255.0)
                        labels.append(subject)
                    except Exception as e:
                        print(f"Error loading {frame_path}: {e}")
    return np.array(images), np.array(labels)

#Feature Extraction using Transfer learning approach (ResNet101)
def extract_features_with_resnet(images):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
    resnet.fc = torch.nn.Identity()  # Remove the final classification layer
    resnet = resnet.to(device)
    resnet.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    features = []
    with torch.no_grad():
        for img in images:
            pil_img = Image.fromarray((img * 255).astype(np.uint8))
            input_tensor = transform(pil_img).unsqueeze(0).to(device)
            feat = resnet(input_tensor)
            features.append(feat.cpu().numpy().flatten())
    return np.array(features)

#Kurtosis-based FS
def kurtosis_selection(features, threshold=3):
    k_vals = kurtosis(features, axis=0)
    selected_features = features[:, k_vals > threshold]
    return selected_features

#Correlation-based FF
def correlation_feature_fusion(features):
    corr = np.corrcoef(features.T)
    upper_triangle = np.triu(corr, k=1)
    fused = np.mean(upper_triangle[upper_triangle != 0])
    fused_features = features * fused
    return fused_features

#OaA SVM Classification
def train_svm(features, labels):
    # 🔁 Shuffle data
    features, labels = shuffle(features, labels)

    # ⚖️ Standardize
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Split
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.3)

    # Train
    model = OneVsRestClassifier(SVC(kernel='linear', probability=True))
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    test_acc = accuracy_score(y_test, y_pred) * 100
    val_scores = cross_val_score(model, features, labels, cv=5)
    val_acc = np.mean(val_scores) * 100

    print("Classification Report:")
    print(classification_report(y_test, y_pred))
    print(f"✅ Test Accuracy: {test_acc:.2f}%")
    print(f"✅ Validation Accuracy (5-Fold): {val_acc:.2f}%")

#main
if __name__ == "__main__":
    print("Loading subset of CASIA-B dataset...")
    casia_path = "/kaggle/input/casiab"
    data, labels = load_casia_subset(casia_path)

    if len(data) == 0:
        raise RuntimeError("No images loaded. Check dataset path and folder structure.")

    labels = LabelEncoder().fit_transform(labels)

    print("Extracting deep features using ResNet101...")
    deep_features = extract_features_with_resnet(data)

    print("Selecting features using kurtosis...")
    selected = kurtosis_selection(deep_features)

    print("Fusing features using correlation...")
    fused = correlation_feature_fusion(selected)

    print("Training One-against-All SVM...")
    train_svm(fused, labels)


Loading subset of CASIA-B dataset...
Extracting deep features using ResNet101...


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 168MB/s]


Selecting features using kurtosis...
Fusing features using correlation...
Training One-against-All SVM...
Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.73      0.75       337
           1       0.80      0.78      0.79       313
           2       0.69      0.69      0.69       318
           3       0.72      0.69      0.71       331
           4       0.82      0.92      0.87       309
           5       0.65      0.60      0.63       341
           6       0.85      0.83      0.84       320
           7       0.81      0.86      0.83       318
           8       0.62      0.68      0.65       304
           9       0.64      0.60      0.62       333

    accuracy                           0.74      3224
   macro avg       0.74      0.74      0.74      3224
weighted avg       0.74      0.74      0.74      3224

✅ Test Accuracy: 73.70%
✅ Validation Accuracy (5-Fold): 75.63%
