## 下载数据集

In [None]:
import subprocess
import os

# Source /etc/network_turbo in a bash subshell and copy every `env` line that
# contains "proxy" into this kernel's environment (presumably so the Hugging Face
# downloads below go through that proxy — confirm against the host setup).
proxy_probe = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
for env_line in proxy_probe.stdout.splitlines():
    name, sep, setting = env_line.partition('=')
    if sep:  # only NAME=value lines; split at the first '=' so values may contain '='
        os.environ[name] = setting

from huggingface_hub import notebook_login

# Interactive Hugging Face login for this notebook session.
notebook_login()

In [None]:
from datasets import load_dataset

# Two sources, each cached in its own local directory:
#   - reasoning data (messages with <think>/<response> sections, parsed below)
#   - non-reasoning medical instruction data (ShareGPT-style "conversations")
ds_reason = load_dataset(
    "Ronndy/medical_o1_sft_Chinese",
    cache_dir='./data/reason',
)
ds_no_reason = load_dataset(
    "BAAI/IndustryInstruction_Health-Medicine",
    cache_dir='./data/no_reason',
)


## 处理数据

### 1. 推理数据集处理

In [None]:
import pandas as pd
from datasets import load_dataset
import random


def _extract_tagged(text, tag):
    """Return the stripped text between ``<tag>`` and the next ``</tag>`` in *text*.

    Returns '' when the opening tag is absent, or when no closing tag follows it.
    """
    open_tag = f'<{tag}>'
    open_pos = text.find(open_tag)
    # Check the raw find() result: adding len(open_tag) first (as before) turned a
    # missing tag (-1) into a small positive offset and sliced garbage text.
    if open_pos == -1:
        return ''
    start = open_pos + len(open_tag)
    # Only accept a closing tag that comes after the opening one.
    end = text.find(f'</{tag}>', start)
    return text[start:end].strip() if end != -1 else ''


# Randomly pick 4500 reasoning samples (capped at the split size so a smaller
# split does not make random.sample raise).
random.seed(42)  # fixed seed for reproducibility
train_size_reason = len(ds_reason['train'])
selected_indices_reason = random.sample(range(train_size_reason), min(4500, train_size_reason))
selected_samples_reason = ds_reason['train'].select(selected_indices_reason)

# Extract question, chain-of-thought and answer from each sample.
reason_data = []
for sample in selected_samples_reason:
    messages = sample['messages']

    # Question: the last user message (searched from the end).
    question_reason = next(
        (msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), ""
    )

    # CoT and answer both come from the first assistant message; each is '' when
    # its <think>...</think> / <response>...</response> pair is missing.
    assistant_content = next(
        (msg['content'] for msg in messages if msg['role'] == 'assistant'), ""
    )
    cot_reason = _extract_tagged(assistant_content, 'think')
    answer_reason = _extract_tagged(assistant_content, 'response')

    reason_data.append({
        'question_reason': question_reason,
        'cot_reason': cot_reason,
        'answer_reason': answer_reason
    })

# Build the DataFrame.
df_reason = pd.DataFrame(reason_data)

# Inspect the result.
print(f"推理数据集样本数: {len(df_reason)}")
print(df_reason.head(3))

# Save as CSV; UTF-8 with BOM so spreadsheet tools display Chinese correctly.
df_reason.to_csv('medical_reason_data.csv', index=False, encoding='utf-8-sig')

### 2. 非推理数据集处理

In [None]:
from datasets import load_dataset, Dataset
from unsloth.chat_templates import standardize_sharegpt
import json
import random


# Non-reasoning data: sample 1500 dialogs from the train split with a fixed seed,
# normalize them with unsloth's standardize_sharegpt, and dump them to JSON.
random.seed(42)
no_reason_train = ds_no_reason['train']
selected_indices_no_reason = random.sample(range(len(no_reason_train)), 1500)
selected_samples_no_reason = no_reason_train.select(selected_indices_no_reason)

# Keep only the "conversations" field of each sampled row, then rebuild a Dataset.
data_for_standardization_no_reason = [
    {"conversations": row["conversations"]} for row in selected_samples_no_reason
]
dataset_no_reason = Dataset.from_list(data_for_standardization_no_reason)

standardized_dataset_no_reason = standardize_sharegpt(dataset_no_reason)
standardized_list_no_reason = standardized_dataset_no_reason.to_list()

output_path_no_reason = "standardized_no_reason_data.json"
with open(output_path_no_reason, 'w', encoding='utf-8') as json_out:
    json.dump(standardized_list_no_reason, json_out, ensure_ascii=False, indent=2)

print(f"非推理数据处理完成，已保存到 {output_path_no_reason}")
print(f"总样本数: {len(standardized_list_no_reason)}")
print("示例第一条非推理数据:")
print(json.dumps(standardized_list_no_reason[0], ensure_ascii=False, indent=2))

## 合并数据

In [None]:
import json
import pandas as pd
import random
from collections import defaultdict

# Merge the reasoning CSV and the standardized non-reasoning JSON into one flat
# list of single-turn Q/A records and write it to combined_medical_data.json.
# Output record keys: question, cot (None for non-reasoning data), answer,
# type ('reason' / 'no_reason'), dialog_id (None for reasoning data; a shared
# integer for all turns split out of the same non-reasoning dialog).

# Seed here so the shuffle below is reproducible even when this cell is run on
# its own (e.g. after a kernel restart), not only after the earlier cells.
random.seed(42)

# 1. Load the reasoning data.
# dtype=str + keep_default_na=False: empty cot/answer cells come back as '' rather
# than float NaN, which json.dump would otherwise write as the invalid JSON token
# NaN. utf-8-sig matches the encoding the CSV was saved with.
df_reason = pd.read_csv(
    'medical_reason_data.csv',
    encoding='utf-8-sig',
    dtype=str,
    keep_default_na=False,
)
reason_data = [{
    'question': row['question_reason'],
    'cot': row['cot_reason'],
    'answer': row['answer_reason'],
    'type': 'reason',
    'is_multi_turn': False
} for _, row in df_reason.iterrows()]

# 2. Load the non-reasoning data, normalizing messages to {'from', 'value'}.
with open('standardized_no_reason_data.json', 'r', encoding='utf-8') as f:
    no_reason_data = []
    for dialog in json.load(f):
        # Accept either message format.
        if 'conversations' in dialog:
            convs = dialog['conversations']
            # Inspect the first message to detect the format.
            if len(convs) > 0 and isinstance(convs[0], dict):
                if 'from' in convs[0]:  # ShareGPT format
                    no_reason_data.append({
                        'conversations': convs,
                        'type': 'no_reason',
                        'is_multi_turn': True
                    })
                elif 'role' in convs[0]:  # role/content format
                    no_reason_data.append({
                        'conversations': [{'from': msg['role'], 'value': msg['content']} for msg in convs],
                        'type': 'no_reason',
                        'is_multi_turn': True
                    })

# 3. Combine and shuffle at dialog level (before splitting turns).
combined = reason_data + no_reason_data
random.shuffle(combined)

# 4. Split each dialog into one record per (user, assistant) pair.
final_data = []
dialog_id = 0

for item in combined:
    if not item['is_multi_turn']:
        final_data.append({
            'question': item['question'],
            'cot': item['cot'],
            'answer': item['answer'],
            'type': item['type'],
            'dialog_id': None
        })
    else:
        conversations = item['conversations']
        human_msgs = []
        gpt_msgs = []

        for msg in conversations:
            # Accept either field naming.
            speaker = msg.get('from') or msg.get('role')
            content = msg.get('value') or msg.get('content')

            if speaker and content:
                if speaker.lower() in ['human', 'user']:
                    human_msgs.append(content.replace('问：', '').strip())
                elif speaker.lower() in ['gpt', 'assistant']:
                    gpt_msgs.append(content.replace('答：', '').strip())

        # Pair the i-th user message with the i-th assistant message; unpaired
        # trailing messages are dropped.
        min_len = min(len(human_msgs), len(gpt_msgs))
        for i in range(min_len):
            final_data.append({
                'question': human_msgs[i],
                'cot': None,
                'answer': gpt_msgs[i],
                'type': item['type'],
                'dialog_id': dialog_id
            })
        dialog_id += 1

# 5. Save the final JSON file.
output_path = 'combined_medical_data.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(final_data, f, ensure_ascii=False, indent=2)

print(f"数据处理完成，已保存到 {output_path}")
print(f"总数据量: {len(final_data)}")
print(f"其中推理数据: {len(reason_data)}")
print(f"非推理对话组数: {dialog_id}")
print("\n示例数据:")
print(json.dumps(final_data[:3], ensure_ascii=False, indent=2))

## 测试集处理

In [None]:
import json
import random
from datasets import load_dataset

# Seed the module-level RNG so this cell's test-set sampling and the final
# shuffle are reproducible; both test helpers and the shuffle below draw from it.
random.seed(42)

def _extract_tagged(text, tag):
    """Return the stripped text between ``<tag>`` and the next ``</tag>`` in *text*.

    Returns '' when the opening tag is absent, or when no closing tag follows it.
    """
    open_tag = f'<{tag}>'
    open_pos = text.find(open_tag)
    # Check the raw find() result: adding len(open_tag) before the check turned a
    # missing tag (-1) into a small positive offset and sliced garbage text.
    if open_pos == -1:
        return ''
    start = open_pos + len(open_tag)
    # Only accept a closing tag that comes after the opening one.
    end = text.find(f'</{tag}>', start)
    return text[start:end].strip() if end != -1 else ''


def process_reason_test(ds_reason_test, sample_size=450):
    """Sample reasoning test items and flatten them into single-turn records.

    Args:
        ds_reason_test: sized, integer-indexable collection of samples, each a
            mapping with a 'messages' list of {'role', 'content'} dicts.
        sample_size: maximum number of items to draw; capped at the collection size.

    Returns:
        A list of dicts with keys question, cot, answer, type ('reason') and
        dialog_id (None). question is the last user message; cot/answer are the
        stripped text inside <think>...</think> and <response>...</response> of
        the first assistant message, or '' when that tag pair is missing.

    Sampling uses the module-level ``random`` state, so results depend on prior seeding.
    """
    total_samples = len(ds_reason_test)
    selected_indices = random.sample(range(total_samples), min(sample_size, total_samples))
    selected_samples = [ds_reason_test[i] for i in selected_indices]

    reason_test_data = []
    for sample in selected_samples:
        messages = sample['messages']

        # Last user message is the question.
        question = next((msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), "")

        # CoT and answer come from the first assistant message.
        assistant_msg = next((msg['content'] for msg in messages if msg['role'] == 'assistant'), "")

        reason_test_data.append({
            'question': question,
            'cot': _extract_tagged(assistant_msg, 'think'),
            'answer': _extract_tagged(assistant_msg, 'response'),
            'type': 'reason',
            'dialog_id': None
        })
    return reason_test_data

def process_no_reason_test(ds_no_reason_test, sample_size=50):
    """Sample non-reasoning test dialogs and split them into single-turn records.

    Args:
        ds_no_reason_test: sized, integer-indexable collection of samples, each a
            mapping with a 'conversations' list of message dicts using either
            'from'/'value' or 'role'/'content' keys.
        sample_size: maximum number of dialogs to draw; capped at the collection size.

    Returns:
        A list of dicts with keys question, cot (None), answer, type ('no_reason')
        and dialog_id ('test_<k>', where k is the dialog's position in the sampled
        order). The k-th user message is paired with the k-th assistant message;
        unpaired trailing messages are dropped. All '问：' / '答：' markers are removed.
    """
    n_available = len(ds_no_reason_test)
    picked = random.sample(range(n_available), min(sample_size, n_available))
    dialogs = [ds_no_reason_test[idx] for idx in picked]

    records = []
    for dialog_index, dialog in enumerate(dialogs):
        questions, replies = [], []

        for message in dialog['conversations']:
            # Either naming convention is accepted.
            role = message.get('from') or message.get('role')
            text = message.get('value') or message.get('content')
            if not (role and text):
                continue
            role_key = role.lower()
            if role_key in ['human', 'user']:
                questions.append(text.replace('问：', '').strip())
            elif role_key in ['gpt', 'assistant']:
                replies.append(text.replace('答：', '').strip())

        # zip stops at the shorter list, keeping only complete question/answer pairs.
        for question_text, reply_text in zip(questions, replies):
            records.append({
                'question': question_text,
                'cot': None,
                'answer': reply_text,
                'type': 'no_reason',
                'dialog_id': f"test_{dialog_index}"
            })

    return records

# Build the test split: 450 sampled reasoning items plus the Q/A pairs split out
# of 50 sampled non-reasoning dialogs.
reason_test_processed = process_reason_test(ds_reason['test'], sample_size=450)
no_reason_test_processed = process_no_reason_test(ds_no_reason['test'], sample_size=50)

# Concatenate, then shuffle so the two sources are interleaved.
combined_test = [*reason_test_processed, *no_reason_test_processed]
random.shuffle(combined_test)

test_output_path = 'combined_medical_test.json'
with open(test_output_path, 'w', encoding='utf-8') as test_file:
    json.dump(combined_test, test_file, ensure_ascii=False, indent=2)

summary_lines = (
    f"测试集处理完成，已保存到 {test_output_path}",
    f"测试集总量: {len(combined_test)}",
    f"其中推理数据: {len(reason_test_processed)} (抽样450条)",
    f"非推理数据: {len(no_reason_test_processed)} (从50组对话拆分得到)",
    "\n测试集示例数据:",
)
for summary_line in summary_lines:
    print(summary_line)
print(json.dumps(combined_test[0], ensure_ascii=False, indent=2))

## 微调

```bash
deepspeed --include 'localhost:0,1,2' train.py
```