In [None]:
# custom_inference.py
import torch
import numpy as np
from generate_summary import generate_summary
from layers.summarizer import PGL_SUM
import h5py
import os
from tqdm import tqdm

def custom_inference(model, data_path, output_dir, *, feature_dim=1024):
    """Score every video in an HDF5 dataset with ``model`` and save its summary.

    For each top-level key ``video`` in ``data_path``, this function reads the
    datasets ``<video>/features``, ``<video>/picks`` and ``<video>/n_frames``.
    It scores the features with ``model``, turns the scores into a summary with
    ``generate_summary``, and writes the result to
    ``<output_dir>/<video>_summary.npy``.

    Args:
        model: A ``torch.nn.Module``. It is called as ``model(features)`` and
            must return a ``(scores, _)`` pair. ``scores.squeeze(0)`` drops a
            leading dimension of size 1 if one is present.
        data_path: Path of the HDF5 file to read.
        output_dir: Destination directory for the ``.npy`` files. It is
            created if it does not exist.
        feature_dim: Length of one feature vector. Features are reshaped to
            ``(-1, feature_dim)`` before scoring. Defaults to 1024.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    # Feed inputs to whichever device holds the model's weights, so a model
    # that was moved to a GPU does not fail with a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with h5py.File(data_path, "r") as hdf:
        video_keys = list(hdf.keys())
        for video in tqdm(video_keys, desc="Processing videos"):
            raw_features = np.asarray(hdf[f"{video}/features"])
            frame_features = (
                torch.as_tensor(raw_features, dtype=torch.float32)
                .reshape(-1, feature_dim)
                .to(device)
            )
            positions = np.array(hdf[f"{video}/picks"])
            n_frames = int(hdf[f"{video}/n_frames"][()])

            with torch.no_grad():
                scores, _ = model(frame_features)
            # Bring the scores back to host memory as a plain Python list.
            scores = scores.squeeze(0).cpu().numpy().tolist()

            # NOTE(review): None is passed in the first argument slot, which
            # may be meant for per-video shot boundaries. Confirm that the
            # imported generate_summary accepts None there.
            summary = generate_summary([None], [scores], [n_frames], [positions])[0]

            # Save the predicted summary.
            np.save(os.path.join(output_dir, f"{video}_summary.npy"), summary)

if __name__ == "__main__":
    # Config (relative paths resolve against the current working directory).
    DATASET_NAME = "CustomDataset"
    DATASET_PATH = f"../PGL-SUM/data/{DATASET_NAME}/custom_dataset_pool5.h5"
    # Any pre-trained PGL-SUM checkpoint can be used here.
    MODEL_PATH = "../PGL-SUM/inference/pretrained_models/table4_models/SumMe/split0/model.pth"
    OUTPUT_DIR = f"../PGL-SUM/Summaries/{DATASET_NAME}/"

    # Model setup. The constructor arguments must describe the same architecture
    # the checkpoint was trained with. Otherwise load_state_dict (strict by
    # default) raises on missing/unexpected keys or shape mismatches.
    model = PGL_SUM(input_size=1024, output_size=1024, num_segments=4, heads=8,
                    fusion="add", pos_enc="absolute")
    # map_location="cpu" lets a checkpoint saved from a CUDA device load on a
    # CPU-only machine. The model itself stays on the CPU either way.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference
    custom_inference(model, DATASET_PATH, OUTPUT_DIR)
