In [None]:
import os
import json
import yaml
import torch
import difflib
# Optional: uncomment to restrict this notebook to a single GPU.
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Route Hugging Face Hub downloads through the hf-mirror.com mirror unless an
# endpoint was already chosen in the environment. huggingface_hub reads this
# variable when it is imported, so this must stay above any transformers/peft import.
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')

In [2]:
import importlib
import help_functions
# Re-execute the local helper module so edits to help_functions.py take effect
# without restarting the kernel. The reload must come before the star import
# below, which binds the (reloaded) names into this namespace.
importlib.reload(help_functions)
# NOTE(review): star import — presumably the source of get_trajectory used later
# in this notebook; consider importing the needed names explicitly.
from help_functions import *

In [None]:
from peft import PeftModel
from qwen_vl_utils import process_vision_info

# Image-size cap passed to the processor (262144 = 512 * 512 pixels).
max_pixels = 262144     # keep identical between inference and training
torch_dtype = torch.float16

# Previous backbone (OS-Atlas-Base-7B, Qwen2-VL architecture), kept for reference:
# from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# base_model_path = "../../OS-Atlas-main/models/OS-Atlas-Base-7B"
# lora_weights_path = "../../OS-Atlas-main/saves/android_world/OS-Atlas-InnovAll-Iter2"  # use the iterated model

# # Load the base model
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     base_model_path,
#     torch_dtype=torch_dtype,
#     device_map="auto"
# )

# Active backbone: GUI-R1-7B (Qwen2.5-VL architecture) plus the LoRA adapter
# produced by the iterated training run.
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
base_model_path = "../../GUI-R1/models/GUI-R1-7B"
lora_weights_path = "../../GUI-R1/saves/android_world/GUI-R1-InnovAll-Iter2"  # use the iterated model

# Load the base model in half precision; device_map="auto" lets transformers
# decide where the weights are placed.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch_dtype,
    device_map="auto"
)

# Load the LoRA weights and apply them to the base model; from here on `model`
# is the PeftModel wrapper.
model = PeftModel.from_pretrained(
    model,
    lora_weights_path,
    torch_dtype=torch_dtype,
    device_map="auto"
)

# The processor comes from the base model directory (not the adapter directory),
# with the same pixel cap as training.
processor = AutoProcessor.from_pretrained(
    base_model_path,
    max_pixels=max_pixels
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# Load the prompt templates. Read as UTF-8 explicitly (as the other JSON files in
# this notebook are) so decoding does not depend on the OS locale.
with open('prompt.yaml', 'r', encoding='utf-8') as file:
    prompts = yaml.safe_load(file)
prompt_key_action = prompts['prompt_key_action']        # prompt for extracting key actions

def get_messages(prompt, text):
    """Build a text-only chat transcript: one system turn and one user turn.

    Args:
        prompt: text placed in the system turn.
        text: text placed in the user turn.

    Returns:
        A two-element list of message dicts. Each has a "role" and a "content"
        list holding a single {"type": "text", "text": ...} entry.
    """
    def _text_turn(role, body):
        # Every turn carries its own fresh content list.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    return [_text_turn("system", prompt), _text_turn("user", text)]

def get_response(messages, max_new_tokens=512):
    """Generate the model's reply to a chat transcript (batch of one).

    Uses the notebook-level ``processor`` and ``model``.

    Args:
        messages: chat messages such as those built by ``get_messages``; any
            image/video entries are extracted by ``process_vision_info``.
        max_new_tokens: generation budget (default 512, the previous fixed value).

    Returns:
        The decoded completion only (prompt tokens removed). Special tokens are
        NOT skipped during decoding, so the text may contain markers such as
        ``<|im_end|>``; ``process_action_str`` strips that marker.
    """
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the newly generated ids.
    completion_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        completion_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return output_text[0]

def process_action_str(action_str):
    """Parse the model's numbered key-action list into plain action strings.

    Each non-blank line is expected to look like ``"<n>. <action>"``; the text
    after the first ``". "`` is kept (stripped). The ``<|im_end|>`` marker is
    removed first. Blank lines between items are ignored instead of aborting
    the whole parse.

    Args:
        action_str: raw decoded model output.

    Returns:
        List of action strings, in the order they appear.

    Raises:
        ValueError: if a non-blank line has no ``". "`` separator, or if the
            response contains no action lines at all. ``get_key_node`` catches
            this and falls back to the full action list.
    """
    cleaned = action_str.replace("<|im_end|>", "")
    actions = []
    for line in cleaned.splitlines():
        if not line.strip():
            continue  # tolerate empty lines between numbered items
        _, sep, action = line.partition('. ')
        if not sep:
            raise ValueError(f"Unparseable key-action line: {line!r}")
        actions.append(action.strip())
    if not actions:
        raise ValueError("No key actions found in model response")
    return actions

In [5]:
def align_key_nodes(key_nodes, action_description):
    """Map each proposed key node onto the most similar original action string.

    Returns an empty list if either input is empty. Otherwise the result has
    one entry per key node, each taken verbatim from ``action_description``
    (duplicates are possible).
    """
    if not (key_nodes and action_description):
        return []
    # With cutoff=0 every candidate qualifies, so a non-empty
    # action_description always yields exactly one match.
    return [
        difflib.get_close_matches(node, action_description, n=1, cutoff=0)[0]
        for node in key_nodes
    ]

def get_key_node(file_path):
    """Extract the key actions ("key nodes") of one successful trajectory.

    Loads the trajectory with ``get_trajectory``, asks the model (via
    ``prompt_key_action``) which actions are essential, then maps the answer
    back onto the trajectory's own action strings. Progress is printed.

    Args:
        file_path: path handed to ``get_trajectory``.

    Returns:
        ``(key_nodes, objective, template)``. ``key_nodes`` is an ordered,
        duplicate-free list of strings taken verbatim from the trajectory's
        ``action_description``. Before de-duplication, the trajectory's last
        action is appended if it sets the task status and the aligned key
        nodes do not already end with a status action.
    """
    trajectory = get_trajectory(file_path)
    objective = trajectory['objective']
    template = trajectory['task_template']
    action_description = trajectory['action_description']

    print("Action List: ")
    for desc in action_description:
        print(desc)
    # Number the actions 1..N for the prompt.
    action_sequence = "".join(
        f"Action ({i}): {action}\n" for i, action in enumerate(action_description, start=1)
    )

    print("提取 key_actions:")
    text_format = "\nObjective: {objective}.\nSuccessful Action Sequence: {action_sequence}"
    text = text_format.format(objective=objective, action_sequence=action_sequence)
    response = get_response(get_messages(prompt_key_action, text))
    try:
        key_nodes = process_action_str(response)
    except Exception as exc:
        # Unparseable model output: treat every action as key, but report why
        # so bad generations are visible instead of silently swallowed.
        print(f"出错: {exc!r}")
        key_nodes = action_description

    # Replace each key node with its most similar original action so the
    # strings stay consistent with the trajectory.
    key_nodes = align_key_nodes(key_nodes, action_description)

    # Keep the final status-setting action as the last key node. The emptiness
    # checks avoid an IndexError for trajectories without actions.
    status_marker = "Set the task's status"
    if (
        action_description
        and status_marker in action_description[-1]
        and (not key_nodes or status_marker not in key_nodes[-1])
    ):
        key_nodes.append(action_description[-1])

    # Order-preserving de-duplication (first occurrence wins).
    key_nodes_new = list(dict.fromkeys(key_nodes))
    for key_node in key_nodes_new:
        print(key_node)

    return key_nodes_new, objective, template

In [None]:
def get_key_nodes_for_expert_demo(
    success_tasks_path='json_files/GUI-R1/InnovAll/Iter3/success_tasks_new.json',
    output_dir="json_files/GUI-R1/InnovAll/Iter3",
    task_names=None,
):
    """Extract key nodes for every successful task and save them as JSON.

    Args:
        success_tasks_path: JSON file mapping task name -> path passed to
            ``get_key_node``. For the OS-Atlas run use
            'json_files/OS-Atlas/InnovAll/Iter3/success_tasks_new.json'.
        output_dir: directory that receives ``tasks_and_key_nodes.json``
            (created if missing). For the OS-Atlas run use
            "json_files/OS-Atlas/InnovAll/Iter3".
        task_names: optional collection of task names to process (useful for
            debugging a single task). ``None`` processes all tasks. Note that
            the output file is overwritten with only the processed tasks.

    The written JSON maps each task name to
    ``{"objective": ..., "template": ..., "key_nodes": [...]}``.
    """
    with open(success_tasks_path, 'r', encoding='utf-8') as json_file:
        success_tasks = json.load(json_file)

    results = {}
    for filename, file_path in success_tasks.items():
        if task_names is not None and filename not in task_names:
            continue
        print("=========================")
        print(filename)
        key_nodes, objective, template = get_key_node(file_path)
        results[filename] = {
            "objective": objective,
            "template": template,
            "key_nodes": key_nodes,
        }

    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'tasks_and_key_nodes.json')
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

In [7]:
# get_key_nodes_for_expert_demo()

SystemBluetoothTurnOn
Action List: 
Scroll up on the screen.
Click on a UI element 'Off' on the screen.
Set the task's status as 'complete'.
提取 key_actions:
Scroll up on the screen.
Click on a UI element 'Off' on the screen.
Set the task's status as 'complete'.
SystemBrightnessMinVerify
Action List: 
Scroll down on the screen.
Click on a UI element 'Calendar' on the screen.
Set the task's status as 'infeasible'.
提取 key_actions:
Scroll down on the screen.
Click on a UI element 'Calendar' on the screen.
Set the task's status as 'infeasible'.
