# 模型下載

In [1]:
import time
from os import times

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch
import os

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [2]:
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 4090'

In [3]:
model_id = "google/gemma-3-4b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

Fetching 2 files: 100%|██████████| 2/2 [03:30<00:00, 105.50s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [5]:
torch.cuda.empty_cache()

In [5]:
model.cuda()

Gemma3ForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
            

# 準備資料

In [6]:
import pandas as pd
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [6]:
file_locate = "/tmp/pycharm_project_477/" #遠端環境

images = os.listdir(file_locate + "/dataset/擊球數據整理/images")
Inputs = pd.read_csv(file_locate + "/dataset/擊球數據整理/question_Input.csv")

init_prompt = open(file_locate + "/dataset/init_prompt.txt").read()
#rule = pd.read_excel(file_locate + "/dataset/回饋規則.xlsx")

In [10]:
str(rule)

'                球路類型   結果                                                 原因  \\\n0            Pull左飛球   失誤                                   上桿(P2~3)時，角度過於陡峭   \n1            Pull左飛球   失誤                                             桿頭頂點過高   \n2            Pull左飛球   失誤                           下桿角度過於陡峭，左手腕過度外展，肩關節伸展抬起   \n3            Pull左飛球   失誤                                     桿面關閉，擊球點位於球的外側   \n4     Pull Hook左拉左曲球   失誤                                   上桿(P2~3)時，角度過於陡峭   \n5     Pull Hook左拉左曲球   失誤                                             桿頭頂點過高   \n6     Pull Hook左拉左曲球   失誤                          下桿角度過於陡峭，手腕過度彎曲，過度由內而外的路徑   \n7     Pull Hook左拉左曲球   失誤                         桿面關閉，擊球點位於球的外側，手腕繼續彎曲未保持向前   \n8   Pull Slice 左拉右曲球   失誤  通常是因為上桿時P2過於內側，手臂和身體過於靠近卡住之後，反而在下桿時由外側下桿、或是軸心偏...   \n9   Pull Slice 左拉右曲球   失誤                                           擺動路徑過於內向   \n10  Pull Slice 左拉右曲球   失誤                                           下桿時由外側下桿   \n11  Pull Slice 左拉右曲球   失誤 

In [8]:
import base64

def encode_base64(image):
    with open(image, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

In [12]:
class GolfDataset(Dataset):
    def __init__(self, Input):
        self.num = Input["num"]
        self.images = []
        self.questions = Input["Input"]
        self.ground_truth = Input["GroundTruth"]
    def __len__(self):
        return len(self.questions)
    def __getitem__(self, idx):
        num = self.num.iloc[idx]
        question = self.questions.iloc[idx]
        image = encode_base64(file_locate + "/dataset/擊球數據整理/images/"+"combined_" +str(self.num.iloc[idx]) + ".jpg")
        ground_truth = self.ground_truth.iloc[idx]
        return num,image,question,ground_truth

In [13]:
golf_dataset = GolfDataset(Inputs)
golf_dataloader = DataLoader(golf_dataset, shuffle=False)

In [23]:
result_df = pd.DataFrame()


rule = str(rule)
for num ,images , questions, ground_truth in golf_dataloader:
    images = str(images)
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": init_prompt+rule}],
        },
        {
          "role": "user",
          "content": [
                {"type": "image", "base64": images},
                {"type": "text", "text": questions},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt", padding="longest",pad_to_multiple_of=8
    ).to(model.device, dtype=torch.bfloat16)
    
    input_len = inputs["input_ids"].shape[-1]
    
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.1
        )
        generation = generation[0][input_len:]
    
    decoded = processor.decode(generation, skip_special_tokens=True)
    result_df = pd.concat([result_df, pd.DataFrame({"num":num, "answer":decoded})])
    
    result_df["answer"] = result_df["answer"].str.replace(" ", "")  
    print(decoded)
    

好的，我將根據您提供的資訊，進行詳細的分析和建議。

**綜合分析結果：**

{{
        "球路": "Pull Hook 左拉左曲球",
         "原因": "整體上桿角度過於陡峭，導致上桿時手腕過度外展，再加上P4頂點轉換時，擊球點位移過左側，使得桿面在擊球時關閉，產生了明顯的左拉左曲球。此外，擊球時左手腕仍持續外展，未能充分恢復原位，進一步加劇了球路的偏差。",
         "建議": "在我們全面的揮杆分析中，您的揮杆動作與教練大致相同，請繼續保持上杆時的協調性。然而您的揮杆有幾個地方可以加強。在您的上桿階段，請嘗試減小上桿角度，避免過於陡峭，讓手腕在整個揮杆過程中保持相對穩定。在P4頂點轉換時，盡量將擊球點維持在球的中心位置，避免過度偏左。擊球時，盡可能在揮杆結束時，讓左手腕恢復成未彎曲的狀態，避免過度外展，這將有助於改善球路的穩定性，減少左拉左曲球的出現。"
}}

**具體動作細節建議 (針對上桿階段)：**

「好的，讓我們從上桿開始。我注意到您上桿時的角度有點陡峭，這導致手腕在揮桿的過程中有些過度外展，因此我們需要調整一下。下次上桿時，請嘗試放慢速度，讓上桿動作更為緩慢、更為協調。想像一下，將球杆像一條柔軟的河流，而不是像一根硬梆梆的棍子。在您上桿的過程中，請意識到手腕應該保持相對穩定，避免過度的外展。試著在P2時，盡量讓球杆與您的雙腳平行，保持球杆垂直度。當您到達P3時，手腕也應該有所收斂，但不要過度用力，以免造成不必要的壓力。重要的是找到一個能讓您感到舒適、放鬆的揮桿節奏。」

**備註：**

*   我將您的擊球數據參數和姿勢差異信息納入考慮，並結合回饋規則中的建議，進一步分析了學員的揮杆問題。
*   我根據學員的具體情況，調整了我的建議，力求更具體、更可行。
*   我盡量使用了口語化的表達方式，讓您能更好地理解和接受我的建議。

希望這些分析和建議對您有所幫助。如果您有任何問題或需要進一步的指導，請隨時提出。
{{
        "球路": "Fade 小右曲球",
         "原因": "學員在擊球時，手腕角度保持得很好，桿面觸球方正，且揮桿軌跡由外向內所致，這是一種良好的球路，代表著揮桿的控制力相對較強。",
         "建議": "非常棒！維持這種手腕角度和球路，繼續放鬆揮桿，避免過

In [15]:
result_df

Unnamed: 0,num,answer
0,186382,好的，請您提供學員的擊球數據、姿勢差異資訊以及圖片，我將按照您的要求進行分析並提供詳細的口語...
0,186387,好的，我將根據您提供的資訊進行分析，並提供詳細的口語化動作建議。\n\n**綜合分析結果：*...
0,186410,好的，我將根據您提供的資訊進行分析，並提供詳細的口語化動作建議。\n\n**綜合分析結果：*...
0,186416,好的，我將根據您提供的資訊進行分析，並提供詳細的口語化動作建議。\n\n**綜合分析結果：*...
0,198514,好的，我將根據您提供的資訊進行分析，並提供詳細的口語化動作建議。\n\n**綜合分析結果：*...
0,199353,好的，我將根據您提供的資訊進行分析，並提供詳細的口語化動作建議。\n\n**綜合分析結果：*...


In [16]:
import time

timeStamp = time.strftime('%Y_%m_%d_%H%M',time.localtime(time.time()))
result_df.to_csv(file_locate+"/experiment_result/"+timeStamp+"_test_Gemma3-4b_output_result.csv", index=False, encoding="utf-8")

In [9]:
from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."

# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset to OAI messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            product=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],
            },
        ],
    }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs

In [22]:
import requests
from PIL import Image

image =encode_base64("/tmp/pycharm_project_979/dataset/test_question/186406_front_video_swing_plane_313_plane.jpg")

system_message = """
Mark out the right arm's bounding box on this picture ，expressed as a Precise percentage like 0.XXX，only output this format:[x1,y1,x2,y2,x3,y3,x4,y4]
"""

user_prompt = """
Push Slice 右拉右曲球
球速:69.555、發射角度:19.135、發射方向:9.572、飛行距離:206、ClubAngleFace:12.239、ClubAnglePath:-0.254

front:{A: [0.0, -0.0, 0.0, -0.0, -0.0, -0.0], F: [-0.011, 0.005, 0.022, 0.029, -0.03, -0.018], I: [-0.024, -0.053, -0.024, -0.072, -0.05, 0.226], T: [-0.029, -0.037, 0.004, -0.114, -0.07, 0.253]}
side:{A: [0.245, -0.075, -0.008, -0.078, -0.058, -0.02], F: [0.008, -0.0, 0.013, -0.01, -0.013, 0.0], I: [0.002, -0.006, -0.001, 0.01, -0.001, -0.002], T: [-0.035, -0.029, 0.218, -0.08, -0.052, -0.021]}


"""

messages = [
    {"role": "system", "content": [{"type": "text", "text": system_message}]},
    {"role": "user", "content": [
        #{"type": "text", "text": user_prompt},
        {"type": "image","image": image},
    ]},
]
    
    
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt", padding="longest",pad_to_multiple_of=8
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **inputs,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.1
    )
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True) 
print(decoded)

[27.1, 47.2, 57.8, 78.9, 63.2, 28.7, 78.3, 84.4]
