In [None]:

import os
import json
from pathlib import Path

import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
# ...existing code...
from torch.utils.data import Dataset, DataLoader



In [None]:
# Root directory for everything this notebook writes: per-image CLIP embedding files (.pt)
# and the rewritten trajectory jsonl files. A relative path resolves against the kernel's
# current working directory.
dir_cache_embeddings = "./cache_embeddings"


In [None]:
# Run CLIP on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained CLIP checkpoint (anything accepted by from_pretrained). The default is a
# placeholder: set the CLIP_MODEL_DIR environment variable to point at a real checkpoint
# instead of editing this cell.
clip_dir = os.environ.get("CLIP_MODEL_DIR", "path/to/clip/model")
# from_pretrained already returns the model in eval mode; .eval() makes the
# inference-only intent (no dropout) explicit for anyone reusing this cell.
clip_model = CLIPModel.from_pretrained(clip_dir).to(device).eval()
clip_processor = CLIPProcessor.from_pretrained(clip_dir)

In [None]:
def get_cache_data_dir(data_dir: str, dataset: str) -> Path:
    """
    Return the cache directory that mirrors ``data_dir`` under ``dir_cache_embeddings``.

    Every component of ``data_dir`` before the first one equal to ``dataset`` is dropped,
    and the remainder (starting with ``dataset``) is re-rooted at ``dir_cache_embeddings``.

    Example:
        data_dir = /root/.../osworld/ui_tars_15_7b/chrome/xxxx, dataset = "osworld"
        -> ./cache_embeddings/osworld/ui_tars_15_7b/chrome/xxxx

    Raises:
        ValueError: if no component of ``data_dir`` equals ``dataset``.
    """
    source_dir = Path(data_dir)
    components = source_dir.parts
    if dataset not in components:
        raise ValueError(f"Dataset '{dataset}' not found in data_dir '{source_dir}'")
    start = components.index(dataset)
    return Path(dir_cache_embeddings).joinpath(*components[start:])

In [None]:


class ImageDataset(Dataset):
    """
    Map-style dataset that loads image files and preprocesses them with a CLIP processor.

    Item ``idx`` is a tuple ``(pixel_values, idx, ok)``:
      - pixel_values: the processor's ``pixel_values`` for that image with the leading
        batch dimension squeezed away, or an all-zero placeholder when loading fails
      - idx: the item's position in ``image_paths``
      - ok: True if the image was loaded and processed, False for the placeholder

    Args:
        image_paths: sequence of image file paths
        processor: callable invoked as ``processor(images=image, return_tensors="pt")``
            that returns a mapping with a ``"pixel_values"`` tensor
    """

    # Fallback side length used when the processor does not expose a crop size.
    DEFAULT_IMAGE_SIZE = 224

    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor
        # The placeholder must have the same shape as real samples: the default collate
        # function stacks every item of a batch, so a mismatch (e.g. a 336px checkpoint vs
        # a hard-coded 224px placeholder) would crash the whole batch on a single bad image.
        height, width = self._processed_size(processor)
        self._placeholder_shape = (3, height, width)

    @classmethod
    def _processed_size(cls, processor):
        """Best-effort (height, width) of processed images, read from the processor's crop_size."""
        # NOTE(review): assumes center cropping is enabled (the CLIP default); without it the
        # processed size varies per image and no fixed placeholder can match.
        image_processor = getattr(processor, "image_processor", processor)
        crop_size = getattr(image_processor, "crop_size", None)
        if isinstance(crop_size, dict) and "height" in crop_size and "width" in crop_size:
            return int(crop_size["height"]), int(crop_size["width"])
        if isinstance(crop_size, int):
            return crop_size, crop_size
        return cls.DEFAULT_IMAGE_SIZE, cls.DEFAULT_IMAGE_SIZE

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            # The context manager closes the underlying file once the RGB copy exists.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            # Processing a single image; strip the batch dimension the processor adds.
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].squeeze(0)
            return pixel_values, idx, True
        except Exception as e:
            # Best effort: report and return a placeholder flagged invalid, instead of
            # killing the DataLoader worker and the whole run.
            print(f"Error loading {img_path}: {e}")
            return torch.zeros(self._placeholder_shape), idx, False

def compute_clip_embeddings_batch(image_paths: list, save_paths: list, batch_size: int = 32,
                                  num_workers: int = 32) -> int:
    """
    Compute L2-normalized CLIP image embeddings in batches and save one .pt file per image.

    Images are loaded and preprocessed in DataLoader workers (ImageDataset with the global
    ``clip_processor``) and encoded with the global ``clip_model`` on ``device``. Each image's
    feature row is saved via ``torch.save`` with a leading dimension of 1 added.
    Entries whose save path already exists are skipped, so interrupted runs can resume.
    Images that fail to load are reported by ImageDataset and produce no file.

    Args:
        image_paths: image file paths
        save_paths: output .pt paths (Path or str), aligned with ``image_paths``
        batch_size: number of images per forward pass
        num_workers: DataLoader worker processes (0 = load in the main process)

    Returns:
        Number of embedding files written.

    Raises:
        ValueError: if ``image_paths`` and ``save_paths`` differ in length.
    """
    if len(image_paths) != len(save_paths):
        raise ValueError("image_paths and save_paths must have the same length")

    # 1. Skip images whose embedding file already exists.
    pending = [i for i, save_path in enumerate(save_paths) if not Path(save_path).exists()]
    if not pending:
        return 0

    current_image_paths = [image_paths[i] for i in pending]
    current_save_paths = [Path(save_paths[i]) for i in pending]

    # 2. Build Dataset and DataLoader.
    dataset = ImageDataset(current_image_paths, clip_processor)
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        # Pinned host memory only speeds up host->GPU copies.
        "pin_memory": str(device).startswith("cuda"),
    }
    if num_workers > 0:
        # prefetch_factor is only meaningful with worker processes; recent PyTorch
        # versions raise if it is passed while num_workers == 0.
        loader_kwargs["prefetch_factor"] = 2
    dataloader = DataLoader(dataset, **loader_kwargs)

    # 3. Batch inference (inference_mode is slightly cheaper than no_grad).
    n_saved = 0
    with torch.inference_mode():
        for batch_pixels, batch_indices, batch_valid in tqdm(dataloader, desc="Computing embeddings", leave=False):
            # Drop images that failed to load.
            valid_mask = batch_valid.bool()
            if not valid_mask.any():
                continue

            pixel_values = batch_pixels[valid_mask].to(device, non_blocking=True)
            image_features = clip_model.get_image_features(pixel_values=pixel_values)
            image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1).cpu()

            # batch_indices are positions in `dataset`, i.e. indices into current_save_paths.
            for dataset_idx, feat in zip(batch_indices[valid_mask].tolist(), image_features):
                save_path = current_save_paths[dataset_idx]
                save_path.parent.mkdir(parents=True, exist_ok=True)
                # clone(): `feat` is a view into the whole batch tensor, and torch.save
                # serializes a view's entire underlying storage. Without the copy every
                # file would contain all embeddings of its batch.
                torch.save(feat.unsqueeze(0).clone(), save_path)
                n_saved += 1
    return n_saved

# ...existing code...

In [None]:
def preprocess_traj_jsonl(input_jsonl: str, batch_size: int = 32):
    """
    Preprocess a jsonl file generated by transform_osworld_trajectories.

    For every trajectory step whose ``observation`` is an existing image file:
      - generate its CLIP embedding (.pt) under dir_cache_embeddings, mirroring the
        trajectory's ``data_dir`` layout (see get_cache_data_dir)
      - replace the ``observation`` field with the absolute path of that .pt file
    Observations that are empty, "empty", or not an existing file are left unchanged.
    The rewritten jsonl goes to dir_cache_embeddings/<dataset>/<model>/<input file name>.

    Each non-blank input line must be a JSON object with ``dataset``, ``model`` and
    ``data_dir`` keys and an optional ``trajectory`` list of step dicts. ``dataset`` and
    ``model`` are taken from the first trajectory only.

    Args:
        input_jsonl: Input jsonl file path
        batch_size: Batch size for CLIP encoding

    Raises:
        ValueError: if a trajectory's ``data_dir`` does not contain the dataset name.
    """
    input_path = Path(input_jsonl)

    # Read all trajectories, skipping blank lines.
    with open(input_path, "r", encoding="utf-8") as f:
        trajectories = [json.loads(line) for line in map(str.strip, f) if line]

    if not trajectories:
        print("No trajectories read")
        return

    dataset = trajectories[0]["dataset"]
    model_name = trajectories[0]["model"]

    # Output jsonl placed under dir_cache_embeddings/dataset/model
    out_dir = Path(dir_cache_embeddings) / dataset / model_name
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / input_path.name

    print(f"dataset: {dataset}, model: {model_name}")
    print(f"Output trajectory file: {out_path}")

    # First pass: map every usable observation to its embedding file; queue missing ones.
    all_image_paths = []
    all_save_paths = []
    queued = set()  # embedding paths already queued, so an image shared by steps is encoded once
    # Raw observation string -> absolute embedding path. Keyed by the raw string because the
    # second pass looks observations up verbatim; str(Path(obs)) may normalize the text
    # (e.g. "./a.png" -> "a.png"), which would silently skip the rewrite.
    image_to_embed_map = {}

    print("Collecting all image paths...")
    for traj in tqdm(trajectories, desc="Collecting images"):
        data_dir = traj["data_dir"]
        cache_data_dir = get_cache_data_dir(data_dir, dataset)

        for step in traj.get("trajectory", []):
            obs = step.get("observation", "")
            if not obs or obs == "empty":
                continue

            obs_path = Path(obs)
            if not obs_path.is_file():
                continue

            # Mirror the image's location relative to data_dir; images outside data_dir
            # fall back to their bare file name. Both branches must yield a Path.
            try:
                rel_obs = obs_path.relative_to(Path(data_dir))
            except ValueError:
                rel_obs = Path(obs_path.name)

            embed_path = cache_data_dir / rel_obs.parent / (rel_obs.stem + ".pt")
            image_to_embed_map[obs] = str(embed_path.resolve())

            if embed_path not in queued and not embed_path.exists():
                queued.add(embed_path)
                all_image_paths.append(obs_path)
                all_save_paths.append(embed_path)

    # Batch compute embeddings
    print(f"\nTotal images to process: {len(all_image_paths)}")
    if all_image_paths:
        compute_clip_embeddings_batch(all_image_paths, all_save_paths, batch_size=batch_size)

    # Images that fail to load produce no .pt file, so count what actually exists now.
    n_generated = sum(1 for p in all_save_paths if p.exists())
    n_missing = len(all_save_paths) - n_generated
    if n_missing:
        print(f"Warning: {n_missing} images could not be embedded; "
              f"their observations will point to missing .pt files")

    # Second pass: update trajectory data and save
    print("\nUpdating trajectory data...")
    with open(out_path, "w", encoding="utf-8") as f_out:
        for traj in tqdm(trajectories, desc="Writing trajectories"):
            for step in traj.get("trajectory", []):
                obs = step.get("observation", "")
                if obs and obs != "empty" and obs in image_to_embed_map:
                    step["observation"] = image_to_embed_map[obs]

            f_out.write(json.dumps(traj, ensure_ascii=False) + "\n")

    print("\nPreprocessing completed!")
    print(f"Processed {len(trajectories)} trajectories")
    print(f"Generated {n_generated} embedding files")
    print(f"Output file: {out_path}")


In [None]:
# ...existing code...

def preprocess_all_jsonl_in_dataset(dataset_dir: str = "osworld", batch_size: int = 128):
    """
    Run preprocess_traj_jsonl on every *_transformed_trajectories.jsonl directly inside dataset_dir.

    Files are processed in sorted name order. A failure in one file is reported (message
    plus traceback) and the remaining files are still processed; the final summary lists
    how many succeeded and which ones failed.

    Args:
        dataset_dir: Dataset directory path (not searched recursively)
        batch_size: Batch size for CLIP encoding
    """
    import traceback

    dataset_path = Path(dataset_dir)

    if not dataset_path.is_dir():
        print(f"Error: Directory {dataset_dir} does not exist")
        return

    # Sorted so the processing order is reproducible (glob order depends on the filesystem).
    jsonl_files = sorted(dataset_path.glob("*_transformed_trajectories.jsonl"))

    if not jsonl_files:
        print(f"No *_transformed_trajectories.jsonl files found in {dataset_dir} directory")
        return

    print(f"Found {len(jsonl_files)} jsonl files:")
    for f in jsonl_files:
        print(f"  - {f}")
    print()

    # Process one by one
    failed = []
    for idx, jsonl_file in enumerate(jsonl_files, 1):
        print(f"\n{'='*80}")
        print(f"Processing file {idx}/{len(jsonl_files)}: {jsonl_file.name}")
        print(f"{'='*80}")

        try:
            preprocess_traj_jsonl(str(jsonl_file), batch_size=batch_size)
        except Exception as e:
            # Keep going with the remaining files, but remember the failure for the summary.
            print(f"\nError: Error processing {jsonl_file}: {e}")
            traceback.print_exc()
            failed.append(jsonl_file)

    n_ok = len(jsonl_files) - len(failed)
    print(f"\n{'='*80}")
    print(f"All completed! Processed {n_ok}/{len(jsonl_files)} files successfully")
    if failed:
        print(f"{len(failed)} file(s) failed:")
        for f in failed:
            print(f"  - {f}")
    print(f"{'='*80}")




In [None]:
# Process every *_transformed_trajectories.jsonl file in the androidworld directory.
# NOTE: the DataLoader in compute_clip_embeddings_batch keeps up to
# num_workers * prefetch_factor batches in flight (32 * 2 = 64 batches by default). With
# batch_size=1024 that is ~65k images of 3x224x224, roughly 40 GB of host memory if the
# pixel values are float32. Lower batch_size if the machine runs out of RAM.
preprocess_all_jsonl_in_dataset(dataset_dir="androidworld", batch_size=1024)