In [None]:
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from google.colab import drive

# Attach Google Drive to the Colab runtime unless the mount point is already present.
if os.path.exists('/content/drive'):
    print("Google Drive already mounted")
else:
    drive.mount('/content/drive')

In [None]:
# --- PATH CONFIGURATION ---
# Single root of the project on Google Drive. Every path below is derived from it,
# so relocating the project only requires editing this one line.
PROJECT_ROOT = '/content/drive/MyDrive/AML_Project'
ANNOTATIONS_ROOT = f'{PROJECT_ROOT}/annotations-main'

# Inputs: ground-truth step annotations and per-recording error labels
GT_ANNOTATIONS_PATH = f'{ANNOTATIONS_ROOT}/annotation_json/complete_step_annotations.json'
ERROR_ANNOTATIONS_PATH = f'{ANNOTATIONS_ROOT}/annotation_json/error_annotations.json'

# Inputs: per-recording video features, per-recipe text features, per-recipe task graphs
VIDEO_FEATURES_DIR = f'{PROJECT_ROOT}/3_EgoVLP/features'
TEXT_FEATURES_DIR = f'{PROJECT_ROOT}/Extension/step_3_task_graph/text_features_egovlp'
TASK_GRAPHS_DIR = f'{ANNOTATIONS_ROOT}/task_graphs'

# Output
OUTPUT_DIR = f'{PROJECT_ROOT}/Extension/step_4_gnn/gnn_ready_data_groundtruth'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# --- METADATA LOADING ---
print("Loading Ground Truth Annotations...")

# Ground-truth step annotations, as parsed from the JSON file.
with open(GT_ANNOTATIONS_PATH, 'r') as f:
    gt_data = json.load(f)

# Error annotations: a list of records, each with 'recording_id' and 'is_error'.
with open(ERROR_ANNOTATIONS_PATH, 'r') as f:
    error_list = json.load(f)

# Binary label per recording: 1 when the record is flagged as an error, otherwise 0.
error_map = {
    entry['recording_id']: int(bool(entry['is_error']))
    for entry in error_list
}

In [None]:
# --- HELPER FUNCTIONS ---
def get_video_features(recording_id, start_time, end_time, features_dir=None, feature_dim=256):
    """
    Loads the video feature file and returns the average feature of a time segment.

    Naming convention: {recording_id}_360p_224.mp4_1s_1s.npy
    If the .npy file does not exist, a .npz file with the same stem is tried
    (array 'arr_0' if present, otherwise the first array in the archive).

    Args:
        recording_id: Recording identifier used to build the file name.
        start_time: Segment start, in seconds.
        end_time: Segment end, in seconds.
        features_dir: Directory containing the feature files. Defaults to the
            notebook-level VIDEO_FEATURES_DIR.
        feature_dim: Length of the zero vector returned when no features are
            available.

    Returns:
        A float tensor. The mean over rows [floor(start), ceil(end)) of the
        feature array when that range is non-empty; otherwise the single row
        closest to the start. A zero vector of length `feature_dim` when the
        file is missing, unreadable or empty.
    """
    if features_dir is None:
        features_dir = VIDEO_FEATURES_DIR

    filename = f"{recording_id}_360p_224.mp4_1s_1s.npy"
    path_npy = os.path.join(features_dir, filename)
    # Swap only the extension; str.replace would also rewrite a '.npy' that
    # happens to appear in a directory name.
    path_npz = path_npy[:-len('.npy')] + '.npz'  # Fallback

    features = None

    # Loading attempt: a failure is reported instead of being silently swallowed,
    # but still degrades to the zero-vector fallback below.
    try:
        if os.path.exists(path_npy):
            features = np.load(path_npy)
        elif os.path.exists(path_npz):
            # Context manager closes the underlying zip file handle.
            with np.load(path_npz) as archive:
                key = 'arr_0' if 'arr_0' in archive.files else archive.files[0]
                features = archive[key]
    except Exception as e:
        print(f"Video feature loading error for {filename}: {e}")
        features = None

    # Missing, unreadable, scalar or empty feature arrays cannot be sliced.
    if features is None or features.ndim == 0 or features.shape[0] == 0:
        return torch.zeros(feature_dim)

    # Temporal slicing (1 feature per second)
    fps = 1
    total_frames = features.shape[0]
    start_idx = max(0, int(np.floor(start_time * fps)))
    end_idx = min(total_frames, int(np.ceil(end_time * fps)))

    if start_idx < end_idx:
        segment = features[start_idx:end_idx]
        return torch.tensor(segment).float().mean(dim=0)

    # If segment is a single point or out of range, we take the closest frame
    safe_idx = min(start_idx, total_frames - 1)
    return torch.tensor(features[safe_idx]).float()

def load_text_features(recipe_name, features_dir=None):
    """
    Loads text features from the .pt file corresponding to the recipe.
    Expected filename: blenderbananapancakes.pt (all lowercase, no spaces or hyphens)

    Args:
        recipe_name: Recipe / activity name; it is lower-cased and stripped of
            spaces and hyphens to build the file name.
        features_dir: Directory containing the .pt files. Defaults to the
            notebook-level TEXT_FEATURES_DIR.

    Returns:
        A float tensor on the CPU, taken either from the 'text_features' entry
        of a saved dictionary or from a directly saved tensor. None when the
        file is missing, cannot be loaded, or holds anything else.
    """
    if features_dir is None:
        features_dir = TEXT_FEATURES_DIR

    safe_name = recipe_name.lower().replace(" ", "").replace("-", "") + ".pt"
    path = os.path.join(features_dir, safe_name)

    if not os.path.exists(path):
        return None

    try:
        # map_location="cpu" lets files saved from a GPU session load on a
        # CPU-only runtime instead of failing.
        # NOTE(review): torch.load unpickles the file; only load files produced
        # by this project.
        data = torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"Text loading error for {safe_name}: {e}")
        return None

    # The .pt file saved in the previous substep is a dictionary with key 'text_features'
    if isinstance(data, dict) and 'text_features' in data:
        data = data['text_features']

    if isinstance(data, torch.Tensor):
        return data.float()

    print(f"Text loading error for {safe_name}: unexpected payload type {type(data).__name__}")
    return None  # Returns None if it fails

def get_canonical_graph(activity_name, graphs_dir=None):
    """
    Retrieves the graph structure (nodes and edges) and sorted IDs.

    The graph file is looked up as the activity name lower-cased with spaces and
    hyphens removed, plus '.json'. If that exact file does not exist, the
    directory is scanned for a file whose name normalizes to the same string.

    Args:
        activity_name: Activity / recipe name.
        graphs_dir: Directory containing the task-graph JSON files. Defaults to
            the notebook-level TASK_GRAPHS_DIR.

    Returns:
        (id_to_idx, edges) where id_to_idx maps each step ID (as a string) to
        its position in the numerically sorted ID list, and edges is a list of
        [source_idx, target_idx] pairs restricted to known step IDs.
        ({}, []) when no graph file is found.
    """
    if graphs_dir is None:
        graphs_dir = TASK_GRAPHS_DIR

    safe_name = activity_name.lower().replace(" ", "").replace("-", "") + ".json"
    json_path = os.path.join(graphs_dir, safe_name)

    # File search fallback
    if not os.path.exists(json_path) and os.path.isdir(graphs_dir):
        for candidate in os.listdir(graphs_dir):
            if candidate.lower().replace(" ", "").replace("-", "") == safe_name:
                json_path = os.path.join(graphs_dir, candidate)
                break

    if not os.path.exists(json_path):
        return {}, []

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    steps = data.get('steps', {})
    sorted_ids = sorted(int(k) for k in steps)  # Important: numerical sorting
    id_to_idx = {str(sid): i for i, sid in enumerate(sorted_ids)}

    # The edge structure is read from the first of these keys that is present,
    # either as a list of edges or as a {source: [targets]} mapping.
    edges = []
    for key in ['edges', 'adjacency', 'successors']:
        if key in data:
            struct = data[key]
            if isinstance(struct, list):
                edges = struct
            elif isinstance(struct, dict):
                for u, neighbors in struct.items():
                    for v in neighbors:
                        edges.append([u, v])
            break

    final_edges = []
    for edge in edges:
        # Tolerate entries carrying extra fields (e.g. a weight) and skip
        # malformed ones instead of failing on tuple unpacking.
        if not isinstance(edge, (list, tuple)) or len(edge) < 2:
            continue
        u, v = str(edge[0]), str(edge[1])
        if u in id_to_idx and v in id_to_idx:
            final_edges.append([id_to_idx[u], id_to_idx[v]])

    return id_to_idx, final_edges

In [None]:
# --- GENERATION LOOP ---
# For every annotated recording, build one graph sample (text node features,
# video node features, edges, error label) and save it as a .pt file.
print(f"Starting Ground Truth dataset creation for {len(gt_data)} videos...")
count = 0
skipped_no_graph = 0

# Graph structure and text features depend only on the activity, so they are
# loaded once per activity instead of once per recording.
activity_cache = {}

for recording_id, data in tqdm(gt_data.items()):
    activity_name = data['activity_name']

    # 1. Graph Structure (+ 2. Textual Features), cached per activity
    if activity_name not in activity_cache:
        graph_ids, graph_edges = get_canonical_graph(activity_name)
        # Text features are only needed when a graph exists for the activity.
        graph_text = load_text_features(activity_name) if graph_ids else None
        activity_cache[activity_name] = (graph_ids, graph_edges, graph_text)
    step_id_to_idx, edges, text_tensor = activity_cache[activity_name]

    if not step_id_to_idx:
        skipped_no_graph += 1
        continue

    num_nodes = len(step_id_to_idx)

    # 2. Textual Features (from .pt file)
    # Features in the .pt file are already sorted by increasing step ID (like step_id_to_idx)
    if text_tensor is not None and text_tensor.shape[0] == num_nodes:
        x_text = text_tensor
    else:
        # Fallback if we don't have the text or dimensions do not match
        # (This should not happen if substep 3 was correct)
        x_text = torch.zeros((num_nodes, 256))

    # 3. Loading Video Features (from .npy files + GT time)
    # A step can appear in several segments. Features are summed and divided by
    # the segment count so every segment contributes equally; the previous
    # pairwise (old + new) / 2 update over-weighted later segments.
    x_video_sum = torch.zeros((num_nodes, 256))
    segment_counts = torch.zeros(num_nodes)

    for step in data['steps']:
        step_id = str(step['step_id'])
        if step_id not in step_id_to_idx:
            continue

        start, end = step['start_time'], step['end_time']

        # If the step was not performed (-1.0) we skip (it stays at zero)
        if start < 0:
            continue

        idx = step_id_to_idx[step_id]
        x_video_sum[idx] += get_video_features(recording_id, start, end)
        segment_counts[idx] += 1

    # Nodes without any segment keep a zero row (0 / 1).
    x_video = x_video_sum / segment_counts.clamp(min=1).unsqueeze(1)

    # 4. Saving
    is_error = error_map.get(recording_id, 0)  # recordings without a label default to 0
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    payload = {
        "vid_id": recording_id,
        "recipe": activity_name,
        "x_text": x_text.cpu(),
        "x_video": x_video.cpu(),
        "edge_index": edge_index,
        "y": torch.tensor(is_error, dtype=torch.float)
    }

    torch.save(payload, os.path.join(OUTPUT_DIR, f"gnn_ready_gt_{recording_id}.pt"))
    count += 1

print(f"\nCompleted! {count} files saved in: {OUTPUT_DIR}")
if skipped_no_graph:
    print(f"Skipped {skipped_no_graph} recordings without a task graph")