**Practice**
The functions of `run_gpt3.py`.

In [None]:
import os
import json
import argparse

In [None]:
def _read_jsonl(path):
    """Parse a JSON Lines file into a list, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _read_json(path):
    """Parse a whole file as one JSON document."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def _exclude_ids(candidates, excluded):
    """Return the items of candidates (order kept) that are not in excluded."""
    plain = (int, str)
    # A set gives O(1) membership, but only matches `==` semantics for plain
    # hashable IDs. Anything else (e.g. tensor elements coming from a
    # checkpoint) keeps the original linear, equality-based lookup.
    if (all(isinstance(pid, plain) for pid in excluded)
            and all(isinstance(pid, plain) for pid in candidates)):
        excluded = set(excluded)
    return [pid for pid in candidates if pid not in excluded]


def load_data(args):
    """
    Load test/train problems and choose the test and candidate problem IDs.

    Args:
        args: Namespace-like object. Attributes read here:
            data_root_test, data_root_train: paths to the test / train files.
            data_root_vali: optional extra file (only read in the MedQA branch).
            test_number: if 0 < test_number < number of test IDs, that many
                test IDs are drawn with random.sample.
            test_pids_ckpt: if truthy, test IDs are loaded from this
                checkpoint with torch.load (overrides any sampling).
            cand_ckpt: if truthy, candidate IDs are loaded from this
                checkpoint with torch.load.
            cand_number: otherwise, if smaller than the number of train IDs,
                that many train IDs are drawn with random.sample.

    Returns:
        problems: All loaded problems. In the MedQA branch this is a list
            (test, then train, then validation if given); otherwise it is a
            dict merging test and train (train wins on duplicate keys).
        test_pids: Test problem IDs (list indices for MedQA, dict keys
            otherwise, or whatever the checkpoint holds).
        cand_pids: Candidate problem IDs with every test ID removed.

    Two input formats are handled:
    1. MedQA: chosen when 'medqa' occurs in data_root_test; files are JSON
       Lines (one JSON object per line).
    2. Generic JSON: each file is a single JSON object keyed by problem ID.

    When 'MATH' occurs in data_root_test, candidate IDs loaded from cand_ckpt
    are shifted by the number of test problems.
    """
    # Kept local: the visible import cell does not bring `random` into scope.
    import random

    if 'medqa' in args.data_root_test:
        # MedQA: positional IDs, test problems first and train problems after.
        problems_test = _read_jsonl(args.data_root_test)
        problems_train = _read_jsonl(args.data_root_train)
        problems = problems_test + problems_train
        test_pids = list(range(len(problems_test)))
        train_pids = list(range(len(problems_test), len(problems)))

        # Validation problems are appended after the IDs are fixed, so they
        # are part of `problems` but never become test or train IDs.
        if args.data_root_vali is not None:
            problems += _read_jsonl(args.data_root_vali)
    else:
        # Generic JSON: IDs are the dictionary keys.
        problems_test = _read_json(args.data_root_test)
        problems_train = _read_json(args.data_root_train)
        problems = {**problems_test, **problems_train}
        test_pids = list(problems_test.keys())
        train_pids = list(problems_train.keys())

    # Sample test problems if specified
    if 0 < args.test_number < len(test_pids):
        test_pids = random.sample(test_pids, args.test_number)

    # Load test IDs from checkpoint if provided
    if args.test_pids_ckpt:
        # torch is imported lazily so it is only needed when a checkpoint is used.
        # NOTE(security): torch.load unpickles data; only load trusted files.
        import torch
        test_pids = torch.load(args.test_pids_ckpt)
    print(f"number of test problems: {len(test_pids)}\n")

    # Process candidate examples
    print(f"original cand set number {len(train_pids)}")
    if args.cand_ckpt:
        # NOTE(security): torch.load unpickles data; only load trusted files.
        import torch
        cand_pids = torch.load(args.cand_ckpt)
        if 'MATH' in args.data_root_test:
            # Shift checkpoint IDs by the number of test problems.
            cand_pids = [i + len(problems_test) for i in cand_pids]
    elif args.cand_number < len(train_pids):
        cand_pids = random.sample(train_pids, args.cand_number)
    else:
        cand_pids = train_pids

    # Remove test examples from candidates
    cand_pids = _exclude_ids(cand_pids, test_pids)

    return problems, test_pids, cand_pids