In [None]:
import os

# `!export VAR=...` runs in a short-lived subshell, so the assignment is lost
# as soon as that line finishes. Setting os.environ changes the kernel's own
# environment, which is what later `!` commands and subprocesses inherit.
# NOTE: LD_LIBRARY_PATH only influences processes started after this point;
# it does not change library resolution inside the already-running kernel.
for _env_name, _ollama_dir in (
    ('PATH', '/scratch/x2989a04/ollama/bin'),
    ('LD_LIBRARY_PATH', '/scratch/x2989a04/ollama/lib/ollama'),
):
    # Drop empty entries and earlier copies of the directory so re-running the
    # cell is idempotent and never leaves a dangling separator.
    _other_entries = [
        entry
        for entry in os.environ.get(_env_name, '').split(os.pathsep)
        if entry and entry != _ollama_dir
    ]
    os.environ[_env_name] = os.pathsep.join([_ollama_dir, *_other_entries])
import ollama
# Show the result of ollama.list() as the cell output.
# NOTE(review): presumably this lists the models known to the Ollama server and
# doubles as a connectivity check -- confirm against the ollama-python docs.
ollama.list()

In [None]:
# Earlier variant kept for reference: same request through ollama.generate with the 'llama3.3' model.
#text = ollama.generate(model='llama3.3', prompt='Generate example data(column, row) in CSV format')
# Smoke test: send one user message asking the 'llama3.1-medical-v2' model for an arbitrary CSV table.
text = ollama.chat(model='llama3.1-medical-v2', messages=[{'role': 'user', 'content': 'Generate sample data (columns, rows) in CSV format. No explanation required.'}])
# Despite its name, `text` holds the whole value returned by ollama.chat, not
# just the message string; it is displayed as the cell output.
text

In [None]:
import pandas as pd
import ollama
import os
from io import StringIO

def analyze_data(df, binary_columns=['REAL_STONE']):
    """
    Summarise every column of ``df`` separately for the two REAL_STONE groups.

    The frame is split into rows with ``REAL_STONE == 1`` ('CBD stone') and
    rows with ``REAL_STONE == 0`` ('No CBD stone'). For each other column:

    - columns listed in ``binary_columns`` are reported per group as
      ``{label: 'count (percent%)'}``, where the percentage is relative to the
      number of rows in that group;
    - all remaining columns are reported per group as the string
      ``'mean (Q1-Q3)'`` with one decimal place.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``REAL_STONE`` column.
    binary_columns : sequence of str
        Names of columns to treat as categorical. The default list is only
        read, never mutated.

    Returns
    -------
    dict
        ``{'Total Samples': {group: n_rows}, column: {group: stats}, ...}``.
        ``REAL_STONE`` itself is not summarised.

    Raises
    ------
    KeyError
        If ``df`` has no ``REAL_STONE`` column.
    TypeError
        Propagated from pandas when a column that is not listed in
        ``binary_columns`` holds non-numeric data.
    """
    def _category_label(value):
        # Integral values (1, 1.0, True) keep the compact '1' form; anything
        # that is not a whole number (e.g. 'M', 1.5) falls back to str() so it
        # neither raises nor collides with a truncated neighbour.
        try:
            as_int = int(value)
        except (TypeError, ValueError, OverflowError):
            return str(value)
        return str(as_int) if as_int == value else str(value)

    # Split by REAL_STONE (CBD stone presence).
    groups = {
        'CBD stone': df[df['REAL_STONE'] == 1],
        'No CBD stone': df[df['REAL_STONE'] == 0],
    }

    summary = {
        'Total Samples': {name: len(group_df) for name, group_df in groups.items()}
    }

    for col in df.columns:
        if col == 'REAL_STONE':
            continue

        column_stats = {}
        for group_name, group_df in groups.items():
            values = group_df[col]
            if col in binary_columns:
                total = len(group_df)
                # value_counts() is empty for an empty group, so the division
                # below is never reached with total == 0.
                column_stats[group_name] = {
                    _category_label(val): f'{count} ({count / total * 100:.2f}%)'
                    for val, count in values.value_counts().items()
                }
            else:
                q1, q3 = values.quantile([0.25, 0.75])
                column_stats[group_name] = f'{values.mean():.1f} ({q1:.1f}-{q3:.1f})'

        summary[col] = column_stats

    return summary

def generate_prompt_from_data(df, times, retry=False):
    """
    Build the LLM prompt that asks for synthetic CBDS patient records as CSV.

    The prompt embeds the per-group summary statistics returned by
    ``analyze_data`` and a one-row sample of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Reference data; must contain the columns passed to ``analyze_data``
        as binary columns, including ``REAL_STONE``.
    times : int
        Accepted for interface compatibility; not used when building the
        prompt.
    retry : bool
        When True, a warning that the previous answer was not valid CSV is
        appended to the prompt.

    Returns
    -------
    str
        The prompt text.
    """
    class_description = analyze_data(df, binary_columns=['PANCREATITIS', 'DUCT_DILIATATION_8MM','DUCT_DILIATATION_10MM','SEX','REAL_STONE'])

    # Render the sample row as real CSV. The DataFrame repr wraps and can elide
    # columns on wide frames, and it is not in the format the model must emit.
    sample_csv = df.head(1).to_csv(index=False).strip()

    retry_notice = (
        "WARNING: Your previous response was not in CSV format. Ensure this output is STRICTLY CSV format!"
        if retry
        else ""
    )

    # The "Do NOT include" bullet used to forbid "headers", which contradicted
    # the bullet above it; the parser needs the header row, so it now forbids
    # titles and code fences instead.
    prompt = f"""
        Context:
        - You are an advanced AI model specializing in generating high-quality synthetic medical data.
        - Your primary task is to generate synthetic patient records to support a predictive model for common bile duct stones (CBDS).
        - Ensure the generated data maintains realistic distributions, correlations, and medically relevant patterns based on clinical knowledge.
        - The dataset must reflect real-world medical trends while covering diverse patient profiles.

        Medical Data Guidelines:
        1. **Balance**: The dataset must maintain a 50:50 ratio for the target variable **REAL_STONE** (1 = presence of stone, 0 = absence of stone).
        2. **Feature Correlations**:
           - Patients with **REAL_STONE = 1** tend to be older and have higher values in liver function tests (**AST, ALT, ALP**), **total bilirubin**, and **bile duct diameter** (>10 mm).
           - Patients with **REAL_STONE = 0** tend to exhibit opposite trends, with lower liver function values and bile duct diameters typically ≤10 mm.
        3. **Data Constraints**:
           - Maintain realistic medical values based on real-world observations.
           - Avoid generating extreme outliers or biologically implausible values (e.g., negative lab results, bilirubin >50 mg/dL, ALP 1000 IU/L unless justified).
           - Ensure no duplicate records while maximizing variability.

        Output Format:
        - Strictly output a **CSV-formatted table** with appropriate column headers and data rows.
        - Do NOT include explanations, titles, markdown code fences, or descriptions—only raw CSV data (one header row followed by data rows).

        Example Dataset Overview:
        {class_description}

        Sample Data Structure (Reference, CSV):
        {sample_csv}

        {retry_notice}
    """

    return prompt

def generate_synthetic_csv(df, model='llama3.3', times=1):
    """
    Ask an Ollama chat model for synthetic rows and return them as a DataFrame.

    A reply is accepted only if it parses as CSV, contains every column of
    ``df`` and has at least one data row. ``pd.read_csv`` alone is not a
    sufficient check, because prose or fenced output also parses (as a frame
    with the wrong columns).

    Parameters
    ----------
    df : pandas.DataFrame
        Reference data used to build the prompt and to define the expected
        columns.
    model : str
        Name of the Ollama model to query.
    times : int
        Forwarded unchanged to ``generate_prompt_from_data``.

    Returns
    -------
    pandas.DataFrame
        The parsed rows, restricted to the columns of ``df`` in ``df``'s order.

    Raises
    ------
    ValueError
        If no valid CSV was produced within three attempts.
    """
    import re  # local import: only needed for stripping markdown fences

    expected_columns = [str(col) for col in df.columns]
    max_retries = 3  # give up after three attempts

    for attempt in range(1, max_retries + 1):
        # From the second attempt on, the prompt carries an explicit format warning.
        prompt = generate_prompt_from_data(df, times, retry=(attempt > 1))
        response = ollama.chat(model=model, messages=[{'role': 'user', 'content': prompt}])

        synthetic_data = response.message.content.strip()  # drop surrounding blank lines

        # Models frequently wrap tables in ```csv ... ``` fences; keep only the payload.
        fenced = re.search(r"```[^\n]*\n(.*?)```", synthetic_data, flags=re.DOTALL)
        csv_text = fenced.group(1).strip() if fenced else synthetic_data

        try:
            # Validate the CSV format and the schema.
            synthetic_df = pd.read_csv(StringIO(csv_text), skipinitialspace=True)
            synthetic_df.columns = [str(col).strip() for col in synthetic_df.columns]

            missing = [col for col in expected_columns if col not in synthetic_df.columns]
            if missing:
                raise ValueError(f"missing expected columns: {missing}")
            if synthetic_df.empty:
                raise ValueError("CSV has a header but no data rows")

            return synthetic_df[expected_columns]
        except Exception as e:
            # Broad on purpose: any parse or validation problem just triggers a retry.
            print(f"Error parsing synthetic data (attempt {attempt}): {e}")
            print("Generated Data Preview:\n", synthetic_data[:500])

    raise ValueError("Failed to generate valid CSV after multiple attempts.")

def main(
    input_csv='DUMC_project/DUMC_total/dumc_1223_case3_duct_correct.csv',
    output_csv='synthetic_data.csv',
    model='llama3.1-medical-v2',
    multiplier=5,
    max_consecutive_failures=10,
):
    """
    Grow ``output_csv`` with LLM-generated rows until it reaches the target size.

    The target is ``multiplier`` times the number of rows in ``input_csv``.
    If ``output_csv`` already exists it is resumed; otherwise it is seeded with
    a copy of the selected input columns (so the file also contains the
    original records). The file is rewritten after every successful batch so
    progress survives an interrupted run.

    Parameters
    ----------
    input_csv : str
        Path of the reference dataset.
    output_csv : str
        Path of the accumulated synthetic dataset (read if present, then
        overwritten).
    model : str
        Ollama model name passed to ``generate_synthetic_csv``.
    multiplier : int
        Target size as a multiple of the reference row count.
    max_consecutive_failures : int
        Stop generating after this many consecutive batches that failed or
        added no new rows, instead of looping forever.
    """
    # columns = ['FIRST_HR', 'FIRST_BT', 'AGE', 'DUCT_DILIATATION_10MM', 'Hb', 'PLT', 'WBC', 'ALP', 'ALT', 'AST', 'BILIRUBIN', 'REAL_STONE']
    columns = ['PANCREATITIS', 'DUCT_DILIATATION_8MM', 'REAL_STONE',
       'DUCT_DILIATATION_10MM', 'SEX', 'AGE', 'FIRST_SBP', 'FIRST_DBP',
       'FIRST_HR', 'FIRST_RR', 'FIRST_BT', 'Hb', 'WBC', 'PLT', 'BILIRUBIN',
       'AST', 'ALT', 'ALP', 'GGT', 'BUN', 'CREATININE']
    df = pd.read_csv(input_csv)
    df = df[columns]

    target_length = len(df) * multiplier
    print(f"Target length: {target_length}")

    # Resume from an existing output file if there is one.
    if os.path.exists(output_csv):
        all_synthetic_data = pd.read_csv(output_csv, header=0)
    else:
        all_synthetic_data = df.copy()
        all_synthetic_data.to_csv(output_csv, index=False)

    current_length = len(all_synthetic_data)
    print(f"Initial synthetic data length: {current_length}")

    consecutive_failures = 0
    while current_length < target_length:
        if consecutive_failures >= max_consecutive_failures:
            print(
                f"Stopping after {consecutive_failures} consecutive unsuccessful batches "
                f"at {current_length}/{target_length} rows."
            )
            break

        try:
            synthetic_df = generate_synthetic_csv(df, model=model, times=1)
        except Exception as e:
            consecutive_failures += 1
            print(f"Failed to generate valid CSV: {e}. Retrying...")
            continue

        # Reject batches with a different schema; concatenating them would add
        # stray columns and NaN-filled rows to the accumulated data.
        missing = [col for col in columns if col not in synthetic_df.columns]
        if missing:
            consecutive_failures += 1
            print(f"Failed to generate valid CSV: missing columns {missing}. Retrying...")
            continue

        merged = pd.concat(
            [all_synthetic_data, synthetic_df[columns]], ignore_index=True
        ).drop_duplicates()

        if len(merged) <= current_length:
            # Only duplicates came back; count it so the loop cannot spin forever.
            consecutive_failures += 1
            print("Generated batch added no new rows. Retrying...")
            continue

        consecutive_failures = 0
        all_synthetic_data = merged
        current_length = len(all_synthetic_data)
        all_synthetic_data.to_csv(output_csv, index=False)
        print(f"Current synthetic data length: {current_length}")

    # Trim to exactly the target length.
    if current_length > target_length:
        all_synthetic_data = all_synthetic_data.sample(n=target_length, random_state=42)
        all_synthetic_data.to_csv(output_csv, index=False)
        print(f"Final synthetic data length: {len(all_synthetic_data)}")

    print(f"Synthetic data saved to '{output_csv}'")

if __name__ == "__main__":
    main()



Target length: 3665
Initial synthetic data length: 931
Current synthetic data length: 955
Current synthetic data length: 978
Current synthetic data length: 1020
Current synthetic data length: 1051
Current synthetic data length: 1084
Current synthetic data length: 1111
Current synthetic data length: 1132
Current synthetic data length: 1192
Current synthetic data length: 1237
Current synthetic data length: 1257
Current synthetic data length: 1278
Current synthetic data length: 1293
Current synthetic data length: 1314
Current synthetic data length: 1340
Current synthetic data length: 1358
Current synthetic data length: 1374
Current synthetic data length: 1389
Current synthetic data length: 1410
Current synthetic data length: 1432
Current synthetic data length: 1453
Current synthetic data length: 1476
Current synthetic data length: 1512
Current synthetic data length: 1533
Current synthetic data length: 1582
Current synthetic data length: 1596
Current synthetic data length: 1616
Current syn

In [None]:
import pandas as pd
import numpy as np

def remove_outliers_iqr(df, threshold=1.5, max_outlier_ratio=10.0):
    """
    Drop rows whose numeric values fall outside the IQR fences.

    Columns are processed one after another on the progressively filtered
    frame. For each numeric column, rows outside
    ``[Q1 - threshold * IQR, Q3 + threshold * IQR]`` are removed, but only if
    they make up less than ``max_outlier_ratio`` percent of the current rows;
    otherwise the column is skipped.

    Columns whose IQR is zero (or undefined) are skipped. With a zero IQR both
    fences collapse onto a single value, so a rarely-set 0/1 flag would lose
    every minority-class row.

    Rows holding NaN in a processed column fail both comparisons and are
    therefore treated like outliers.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; not modified.
    threshold : float
        Fence multiplier applied to the IQR (default: 1.5).
    max_outlier_ratio : float
        Upper limit, in percent, on the share of rows a single column may
        remove (default: 10).

    Returns
    -------
    pandas.DataFrame
        The filtered frame (``df`` itself if nothing was removed).
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns  # numeric columns only
    initial_size = len(df)

    if initial_size == 0:
        # Nothing to filter; also keeps the ratio below from dividing by zero.
        print("Warning: input data is empty; nothing to filter.")
        return df

    for col in numeric_cols:
        q1, q3 = df[col].quantile([0.25, 0.75])
        iqr = q3 - q1

        # `not iqr > 0` is also true for NaN (e.g. an all-NaN column).
        if not iqr > 0:
            print(f"Skipping outlier removal for '{col}' (IQR is {iqr}, fences are degenerate)")
            continue

        lower_bound = q1 - threshold * iqr
        upper_bound = q3 + threshold * iqr

        filtered_df = df[(df[col] >= lower_bound) & (df[col] <= upper_bound)]

        # df is never empty here: it starts non-empty and a column is applied
        # only when it removes less than max_outlier_ratio percent of the rows.
        outlier_ratio = (1 - len(filtered_df) / len(df)) * 100
        if outlier_ratio < max_outlier_ratio:
            df = filtered_df
        else:
            print(f"Skipping outlier removal for '{col}' (outlier ratio: {outlier_ratio:.2f}%)")

    removed_count = initial_size - len(df)
    print(f"Total outliers removed: {removed_count}")
    return df

def preprocess_synthetic_data(input_file='synthetic_data.csv', output_file='synthetic_data_preprocessed.csv'):
    """
    Clean the accumulated synthetic dataset and write the result to disk.

    Steps: read ``input_file``, drop rows containing any null value, filter
    the remainder with ``remove_outliers_iqr`` and save it to ``output_file``
    without the index. Row counts are printed along the way.
    """
    raw = pd.read_csv(input_file)
    print(f"Original data size: {len(raw)}")

    # Step 1: discard incomplete rows.
    without_nulls = raw.dropna()
    print(f"After removing Null values: {len(without_nulls)}")

    # Step 2: IQR-based outlier filtering.
    cleaned = remove_outliers_iqr(without_nulls)

    # Persist the preprocessed data.
    cleaned.to_csv(output_file, index=False)
    print(f"Preprocessed data saved to '{output_file}' with {len(cleaned)} rows.")

if __name__ == "__main__":
    preprocess_synthetic_data()
