## taskA Llama

In [1]:
import re
import os
import pathlib
import pandas as pd
from typing import Any, List, Dict
from typing import Optional, Dict, Any, List, Union
from abc import ABC, abstractmethod
from langchain.prompts import ChatPromptTemplate  # 프롬프트 템플릿 처리용
from langevaluate.config import ModelConfig # LLM 설정용
from langevaluate.llmfactory import LLMFactory  # LLM 팩토리용
from tqdm.asyncio import tqdm_asyncio
import asyncio

class DatathonProcessor(ABC):
    """
    Base class that bundles the LLM pipeline used for the datathon.

    Subclasses supply the prompt template plus the pre-/post-processing
    hooks; this class wires them to an LLM chain and runs them over every
    row of a DataFrame.
    """

    # Default LLM settings; only ``model_name`` is overridable per subclass
    # (through ``get_model_name``).
    DEFAULT_MODEL_CONFIG = {
        'model_name': 'LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct-AWQ',
        'api_base': 'https://api.snubhai.org/api/v1/llm',
        'max_tokens': 2000,
        'seed': 777,
        'temperature': 0,
        'rpm': 10
    }

    def __init__(
        self,
        api_key: str,
    ):
        """Build the LLM client, the prompt template and the chain.

        Args:
            api_key: Credential forwarded to ``ModelConfig``.
        """
        # Start from the class defaults, then let the subclass pick the model.
        settings = dict(self.DEFAULT_MODEL_CONFIG)
        settings['model_name'] = self.get_model_name()

        model_config = ModelConfig(
            model_name=settings['model_name'],
            api_base=settings['api_base'],
            api_key=api_key,
            max_tokens=settings['max_tokens'],
            seed=settings['seed'],
            provider="openai",
        )

        # temperature / rpm are passed to the factory, not to ModelConfig.
        self.llm = LLMFactory.create_llm(
            model_config,
            temperature=settings['temperature'],
            rpm=settings['rpm'],
        )

        # Prompt template piped into the LLM.
        template_text = self.get_prompt_template()
        self.prompt_template = ChatPromptTemplate.from_template(template_text)
        self.chain = self.prompt_template | self.llm

        # Storage for generated results.
        self.results: List[str] = []

        # Storage for metrics.
        self.metrics: Dict[str, Any] = {}

    def get_model_name(self) -> str:
        """
        Return the model name to use.

        Subclasses may override this method to select a different model;
        the default is the ``model_name`` entry of ``DEFAULT_MODEL_CONFIG``.
        """
        return self.DEFAULT_MODEL_CONFIG['model_name']

    @abstractmethod
    async def preprocess_data(self, data: Any) -> Dict[str, Any]:
        """Turn one input row into the variables consumed by the prompt."""
        pass

    @abstractmethod
    def get_prompt_template(self) -> str:
        """Return the prompt template string (implemented by subclasses)."""
        pass

    @abstractmethod
    async def postprocess_result(self, result: Any) -> str:
        """Turn the raw model output into the final answer string."""
        pass

    async def summarize(
        self,
        data: pd.DataFrame
    ) -> List[str]:
        """
        Run preprocess -> LLM call -> postprocess for every row of ``data``.

        Each stage is awaited as a whole through ``tqdm_asyncio.gather``
        (which also shows a progress bar) before the next stage starts.

        Args:
            data: DataFrame whose rows are handed to ``preprocess_data``.

        Returns:
            The values produced by ``postprocess_result``, one per row.
        """
        # Stage 1: preprocessing, one coroutine per row.
        prompt_inputs = await tqdm_asyncio.gather(
            *[self.preprocess_data(row) for _, row in data.iterrows()]
        )

        # Stage 2: one chain invocation per preprocessed item.
        llm_calls = [self.chain.ainvoke(prompt_vars) for prompt_vars in prompt_inputs]
        responses = await tqdm_asyncio.gather(*llm_calls)

        # Stage 3: postprocess the ``content`` of every response.
        results = await tqdm_asyncio.gather(
            *[self.postprocess_result(response.content) for response in responses]
        )

        return results

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd
import numpy as np
import torch
import asyncio
import time
from typing import Any, Dict, List
from bert_score import BERTScorer
from langchain.prompts import ChatPromptTemplate
from langevaluate.config import ModelConfig
from langevaluate.llmfactory import LLMFactory
import re

# 대회 제공 BertScore 클래스 (100% 동일)
class BertScore:
    """Callable wrapper around ``BERTScorer`` that returns the F values as a list."""

    def __init__(self, model_type="distilbert-base-uncased", batch_size=16):
        """Create the underlying scorer.

        Args:
            model_type: Model identifier forwarded to ``BERTScorer``.
            batch_size: Batch size forwarded to the ``BERTScorer`` constructor.
        """
        scorer_kwargs = {"model_type": model_type, "batch_size": batch_size}
        # The scorer is built with gradient tracking disabled.
        with torch.no_grad():
            self.bert_scorer = BERTScorer(**scorer_kwargs)

    def __call__(self, refs, hyps):
        """Score hypotheses against references.

        Args:
            refs: Reference texts.
            hyps: Candidate texts (passed to the scorer as ``cands``).

        Returns:
            The third value returned by ``BERTScorer.score`` (the F score),
            converted with ``.tolist()``.
        """
        # NOTE: ``score`` is called with batch_size=8 regardless of the value
        # given to the constructor; kept as-is to mirror the provided metric.
        _precision, _recall, f_score = self.bert_scorer.score(
            cands=hyps,
            refs=refs,
            verbose=False,
            batch_size=8,
        )
        return f_score.tolist()

# 대회 제공 FairnessScore 클래스 (100% 동일)
class FairnessScore:
    """Fairness as the ratio between the worst and best per-group mean score.

    Groups are either raw labels (``type='sex'``) or age bins of width
    ``bin_width`` (``type='age'``). Details of the latest call are stored in
    ``last_stats``.
    """

    def __init__(self, bin_width: int = 10, min_samples_per_group: int = 1):
        """Store the configuration.

        Args:
            bin_width: Width of an age bin (used when ``type='age'``).
            min_samples_per_group: Groups with fewer members are ignored.
        """
        self.last_stats = None
        self.min_samples_per_group = int(min_samples_per_group)
        self.bin_width = int(bin_width)

    @staticmethod
    def _ensure_1d(a) -> np.ndarray:
        """Return ``a`` as a 1-D array, flattening an (N, 1) column.

        Raises:
            ValueError: If the input is neither 1-D nor (N, 1) shaped.
        """
        arr = np.asarray(a)
        is_single_column = arr.ndim == 2 and arr.shape[1] == 1
        if is_single_column:
            arr = arr[:, 0]
        if arr.ndim == 1:
            return arr
        raise ValueError("Input must be 1D or (N,1) shaped.")

    def _bin_ages(self, ages) -> np.ndarray:
        """Map ages to ``"<start>-<end>"`` labels of width ``bin_width``.

        Raises:
            ValueError: If any age is NaN or ``bin_width`` is not positive.
        """
        age_values = self._ensure_1d(ages).astype(float)
        if np.isnan(age_values).any():
            raise ValueError("ages contain NaN.")
        width = self.bin_width
        if width <= 0:
            raise ValueError("bin_width must be positive.")
        lower = (np.floor(age_values / width) * width).astype(int)
        upper = lower + width
        return np.array(
            [f"{lo:d}-{hi:d}" for lo, hi in zip(lower, upper)],
            dtype=object,
        )

    def _groups_from_type(self, groups, type: str) -> np.ndarray:
        """Return group labels for ``type`` ('sex' or 'age'; empty means 'sex').

        Raises:
            ValueError: If ``type`` is neither 'sex' nor 'age'.
        """
        kind = (type or "sex").lower()
        if kind == "sex":
            return self._ensure_1d(groups)
        if kind == "age":
            return self._bin_ages(groups)
        raise ValueError("type must be 'sex' or 'age'.")

    def __call__(self, groups, scores, type: str = "sex", sample_weight=None) -> float:
        """Compute the fairness score.

        Args:
            groups: Group value per sample (labels, or ages when ``type='age'``).
            scores: Score per sample; clipped to [0, 1] before averaging.
            type: 'sex' to use ``groups`` as-is, 'age' to bin them.
            sample_weight: Optional per-sample weights (default: all ones).

        Returns:
            ``min(group means) / max(group means)`` clipped to [0, 1]; 1.0
            when at most one group qualifies or the best mean is 0.

        Raises:
            ValueError: On length mismatches or invalid ``type``/inputs.
        """
        labels = self._groups_from_type(groups, type=type)
        values = self._ensure_1d(scores).astype(float)
        n_scores = values.shape[0]
        if labels.shape[0] != n_scores:
            raise ValueError("groups and scores must have the same length.")

        if sample_weight is not None:
            weights = self._ensure_1d(sample_weight).astype(float)
            if weights.shape[0] != n_scores:
                raise ValueError("sample_weight length must match scores.")
        else:
            weights = np.ones_like(values, dtype=float)

        values = np.clip(values, 0.0, 1.0)

        group_means = []
        by_group = {}
        for label in np.unique(labels):
            members = (labels == label)
            # Skip groups that are too small or carry no positive weight.
            if np.sum(members) < self.min_samples_per_group:
                continue
            member_weights = weights[members]
            if np.sum(member_weights) <= 0:
                continue
            group_mean = float(np.average(values[members], weights=member_weights))
            group_means.append(group_mean)
            by_group[str(label)] = group_mean

        # Nothing to compare: treated as perfectly fair.
        if len(group_means) <= 1:
            self.last_stats = {"by_group": by_group, "gap": 0.0, "min": None, "max": None}
            return 1.0

        best = float(np.max(group_means))
        worst = float(np.min(group_means))
        if best == 0.0:
            ratio = 1.0
        else:
            ratio = float(worst / best)
        ratio = float(np.clip(ratio, 0.0, 1.0))

        self.last_stats = {"by_group": by_group, "gap": best - worst, "min": worst, "max": best}
        return ratio

# TaskA Processor (앞서 작성한 최적화 버전)
class TaskAProcessor:
    """Generate "Brief Hospital Course" summaries from medical records.

    The processor condenses a raw record into a few key sections
    (``preprocess_data``), sends them through a prompt/LLM chain built in
    ``__init__`` and cleans the model output (``postprocess_result``).
    """

    # Single-section extraction rules:
    # (output label, compiled pattern, max characters or None, drop when "none").
    # The patterns themselves are unchanged heuristics keyed on section headers.
    _SECTION_RULES = (
        ('Chief Complaint',
         re.compile(r'Chief Complaint:\s*([^\n]+)'),
         None, False),
        ('Service',
         re.compile(r'Service:\s*([^\n]+)'),
         None, False),
        ('History',
         re.compile(r'History of Present Illness:\s*(.*?)(?=\n\n|\nPast Medical|Physical Exam|$)',
                    re.DOTALL),
         800, False),
        ('Major Procedures',
         re.compile(r'Major Surgical or Invasive Procedure:\s*(.*?)(?=\n\n|History of Present|$)',
                    re.DOTALL),
         None, True),
        ('Past Medical History',
         re.compile(r'Past Medical History:\s*(.*?)(?=\n\n|PAST SURGICAL|Social History|$)',
                    re.DOTALL),
         400, False),
    )
    _IMPRESSION_PATTERN = re.compile(r'IMPRESSION:\s*(.*?)(?=\n\n|\n[A-Z_]|\Z)', re.DOTALL)
    _MAX_IMPRESSIONS = 2
    _MAX_IMPRESSION_CHARS = 200
    _MAX_INPUT_CHARS = 3000

    # Post-processing settings.
    _ANSWER_PREFIX = 'brief hospital course:'
    _TRIM_ABOVE_WORDS = 450
    _TRIM_TARGET_WORDS = 400
    _KEY_TERMS = ('admitted', 'diagnosis', 'treated', 'underwent', 'developed',
                  'improved', 'discharged', 'course', 'complication')
    # Split after sentence punctuation that is followed by whitespace, so
    # decimals such as "2.5 mg" stay intact.
    _SENTENCE_BREAK = re.compile(r'(?<=[.!?])\s+')
    # ``\b`` anchors each abbreviation at the start of a word; a plain
    # substring replace would turn "kept " into "kepatient ".
    _ABBREVIATIONS = (
        (re.compile(r'\bpt '), 'patient '),
        (re.compile(r'\bw/ '), 'with '),
        (re.compile(r'\bw/o '), 'without '),
        (re.compile(r'\bh/o '), 'history of '),
    )

    def __init__(self, api_key: str):
        """Create the LLM client and the prompt chain.

        Args:
            api_key: Credential forwarded to ``ModelConfig``.
        """
        self.api_key = api_key

        config = ModelConfig(
            # Alternative model used elsewhere in this notebook:
            # "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct-AWQ"
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            api_base="https://api.snubhai.org/api/v1/llm",
            api_key=api_key,
            max_tokens=1500,
            seed=777,
            provider="openai"
        )

        self.llm = LLMFactory.create_llm(config, temperature=0, rpm=10)
        self.prompt_template = ChatPromptTemplate.from_template(self.get_prompt_template())
        self.chain = self.prompt_template | self.llm

    def get_prompt_template(self) -> str:
        """Return the few-shot prompt; ``{user_input}`` receives the record."""
        return """You are a senior attending physician with 15+ years of experience writing comprehensive Brief Hospital Course summaries. Create a professional, chronologically structured summary that captures the essential medical narrative.

CRITICAL REQUIREMENTS:
- Write 250-400 words (optimal length based on successful cases)
- Maintain chronological flow: Admission → Course → Outcome
- Use precise medical terminology consistently
- Include key diagnostic findings, treatments, and patient responses
- Maintain professional tone regardless of patient demographics
- Focus on clinically significant events and interventions

STRUCTURE METHODOLOGY:
1. Opening: Patient presentation and admission reason
2. Initial Assessment: Key findings, diagnostics, initial diagnosis
3. Hospital Course: Chronological treatment progression, complications
4. Clinical Response: Patient improvement/deterioration, interventions
5. Discharge Planning: Final status, disposition, follow-up needs

EXAMPLES:

MEDICAL RECORD: [Complex gynecologic oncology case...]
BRIEF HOSPITAL COURSE: Ms. ___ was admitted to the gynecologic oncology service after undergoing diagnostic laparoscopy converted to exploratory laparotomy, total abdominal hysterectomy, bilateral salpingo-oophrectomy, omentectomy, pelvic and para-aortic lymph node dissection, and tumor debulking for Stage IIIC ovarian carcinoma. Her postoperative course was complicated by prolonged ileus requiring nasogastric decompression and total parenteral nutrition. She developed a wound infection on postoperative day 5 treated with antibiotics and wound care. Patient was discharged home on postoperative day 8 in stable condition with visiting nurse services arranged.

MEDICAL RECORD: [Cardiac case with preoperative evaluation...]
BRIEF HOSPITAL COURSE: He was admitted to the cardiology service and remained chest pain free. He underwent routine preoperative testing and evaluation. He developed early signs of gout flare in the right and left hallux which was treated with colchicine and responded well. His cardiac catheterization revealed severe three-vessel coronary artery disease requiring surgical revascularization. He was medically optimized and discharged home after 3 days in stable condition with cardiothoracic surgery follow-up scheduled.

Now create a Brief Hospital Course for:

MEDICAL RECORD: {user_input}

BRIEF HOSPITAL COURSE:"""

    async def preprocess_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Condense one medical record into the prompt variable ``user_input``.

        Extracts chief complaint, service, history of present illness, major
        procedures, past medical history and up to two imaging impressions,
        replaces ``___`` runs with ``[REDACTED]``, collapses whitespace and
        caps the text at 3000 characters.

        Args:
            data: Mapping (dict or DataFrame row) with a ``'medical record'`` entry.

        Returns:
            ``{'user_input': <condensed text>}``; the text is empty when the
            record is not a string (NaN, None, lists, ...).
        """
        medical_record = data['medical record']

        # Missing values are never ``str``, so this single check covers NaN/None
        # and also avoids ``pd.isna`` returning an array for list-like cells.
        if not isinstance(medical_record, str):
            return {'user_input': ''}

        sections = []
        for label, pattern, max_chars, drop_if_none in self._SECTION_RULES:
            found = pattern.search(medical_record)
            if found is None:
                continue
            text = found.group(1).strip()
            if not text:
                # Header present but nothing captured: do not emit an empty section.
                continue
            if drop_if_none and text.lower() in ('none', 'none.'):
                continue
            if max_chars is not None:
                text = text[:max_chars]
            sections.append(f"{label}: {text}")

        # Imaging impressions: only the first two, each shortened.
        impressions = self._IMPRESSION_PATTERN.findall(medical_record)
        for number, impression in enumerate(impressions[:self._MAX_IMPRESSIONS], start=1):
            sections.append(
                f"Imaging {number}: {impression.strip()[:self._MAX_IMPRESSION_CHARS]}"
            )

        processed_text = '\n\n'.join(sections)
        processed_text = re.sub(r'___+', '[REDACTED]', processed_text)
        processed_text = re.sub(r'\s+', ' ', processed_text)
        processed_text = processed_text[:self._MAX_INPUT_CHARS]

        return {'user_input': processed_text.strip()}

    async def postprocess_result(self, result: str) -> str:
        """Clean the raw model answer.

        Steps: strip a leading "Brief Hospital Course:" label (any casing),
        make sure the text ends with a period, shorten answers longer than
        450 words to roughly 400 words of key sentences, and expand a few
        clinical abbreviations (``pt``, ``w/``, ``w/o``, ``h/o``).

        Args:
            result: Text returned by the model.

        Returns:
            The cleaned summary.
        """
        result = result.strip()

        # Remove the echoed section label, however it is cased or repeated.
        prefix = self._ANSWER_PREFIX
        while result[:len(prefix)].lower() == prefix:
            result = result[len(prefix):].strip()

        if result and not result.endswith('.'):
            result += '.'

        # Word-count control for overly long answers.
        if len(result.split()) > self._TRIM_ABOVE_WORDS:
            result = self._trim_to_key_sentences(result)

        # Expand abbreviations only where they stand as separate words.
        for pattern, expansion in self._ABBREVIATIONS:
            result = pattern.sub(expansion, result)

        return result

    def _trim_to_key_sentences(self, text: str) -> str:
        """Keep the first three sentences plus keyword sentences until ~400 words."""
        kept = []
        word_count = 0
        for sentence in self._SENTENCE_BREAK.split(text):
            sentence = sentence.strip()
            if not sentence:
                continue
            lowered = sentence.lower()
            if len(kept) < 3 or any(term in lowered for term in self._KEY_TERMS):
                kept.append(sentence)
                word_count += len(sentence.split())
            if word_count >= self._TRIM_TARGET_WORDS:
                break

        if not kept:
            return text

        # Sentences keep their own punctuation, so a plain space re-joins them.
        trimmed = ' '.join(kept)
        if not trimmed.endswith(('.', '!', '?')):
            trimmed += '.'
        return trimmed

# 리더보드 100% 동일 평가 함수
def _print_ranked_samples(title, indices, bert_scores, predictions, references):
    """Print score, prediction and reference (first 200 chars) for ``indices``."""
    print(title)
    for idx in indices:
        print(f"샘플 {idx} (BERTScore: {bert_scores[idx]:.4f})")
        print(f"예측: {predictions[idx][:200]}...")
        print(f"정답: {references[idx][:200]}...")
        print()


async def exact_taskA_evaluation(
    train_csv_path: str,
    api_key: str,
    max_samples: int = 300,
    batch_size: int = 8,
    wait_seconds: float = 70,
):
    """Simulate the Task A leaderboard evaluation on the training CSV.

    Generates summaries with ``TaskAProcessor`` in rate-limited batches,
    scores them with ``BertScore`` and computes sex/age fairness with
    ``FairnessScore``, printing a report along the way.

    NOTE(review): the conversion of the metrics into points (targets 0.85
    and 0.95, 10 + 6 points) and the grade thresholds are defined locally in
    this function — confirm them against the official scoring rules.

    Args:
        train_csv_path: CSV with the columns ``medical record``, ``target``,
            ``gender`` and ``anchor_age``.
        api_key: Credential passed to ``TaskAProcessor``.
        max_samples: Upper bound on the number of leading rows evaluated.
        batch_size: Number of concurrent LLM requests per batch.
        wait_seconds: Pause between batches to respect the API rate limit.

    Returns:
        Dict with the BERTScore statistics, fairness scores, total points,
        grade, predictions, references, sample count and generation time.

    Raises:
        ValueError: If ``batch_size`` is not positive, no usable rows remain,
            or the selected rows contain missing ages.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive.")

    print("=" * 80)
    print("🏆 Task A 리더보드 정확 평가 시뮬레이션")
    print("=" * 80)

    # 1. Load the data and drop rows without input or reference text.
    print("1. 데이터 로드 중...")
    df = pd.read_csv(train_csv_path)
    df = df.dropna(subset=['medical record', 'target'])

    eval_samples = min(max_samples, len(df))
    if eval_samples <= 0:
        raise ValueError("No rows with both 'medical record' and 'target' to evaluate.")
    eval_df = df.iloc[:eval_samples].copy()
    print(f"평가 샘플: {eval_samples}개")

    # Fail fast: the age fairness metric rejects NaN ages, and discovering
    # that only after the long generation phase would waste all API calls.
    if eval_df['anchor_age'].isna().any():
        raise ValueError("anchor_age contains missing values in the evaluated rows.")

    print(f"\n📊 데이터 분포:")
    print(f"성별: {eval_df['gender'].value_counts().to_dict()}")
    print(f"연령: 평균 {eval_df['anchor_age'].mean():.1f}세")

    # 2. Create the processor (the message no longer names a specific model,
    #    it previously claimed EXAONE while the processor used Llama).
    print("\n2. TaskA 처리기 초기화...")
    processor = TaskAProcessor(api_key)

    # 3. Generate predictions batch by batch.
    print("3. Brief Hospital Course 생성 중...")
    start_time = time.perf_counter()

    data_batch = [{'medical record': record} for record in eval_df['medical record']]
    total_batches = (len(data_batch) - 1) // batch_size + 1

    predictions = []
    for batch_no, start in enumerate(range(0, len(data_batch), batch_size), start=1):
        batch = data_batch[start:start + batch_size]
        print(f"   배치 {batch_no}/{total_batches} 처리 중...")

        # Preprocess, call the API concurrently, then postprocess.
        preprocessed = [await processor.preprocess_data(row) for row in batch]
        responses = await asyncio.gather(
            *(processor.chain.ainvoke(prep) for prep in preprocessed)
        )
        predictions.extend(
            [await processor.postprocess_result(r.content) for r in responses]
        )

        # Respect the API rate limit between batches (not after the last one).
        if start + batch_size < len(data_batch):
            print(f"   API 제한 준수를 위해 {wait_seconds:g}초 대기...")
            await asyncio.sleep(wait_seconds)

    generation_time = time.perf_counter() - start_time
    print(f"예측 생성 완료 (총 소요 시간: {generation_time:.1f}초)")

    # 4. Reference texts.
    references = eval_df['target'].tolist()

    # 5. BERTScore.
    print("\n4. BERTScore 계산 중...")
    bert_scorer = BertScore(model_type="distilbert-base-uncased", batch_size=16)
    bert_scores = bert_scorer(refs=references, hyps=predictions)
    bert_mean = np.mean(bert_scores)
    bert_std = np.std(bert_scores)

    # 6. Fairness metrics; ``last_stats`` must be read right after each call
    #    because the next call overwrites it.
    print("5. 공정성 지표 계산 중...")
    fairness_scorer = FairnessScore(bin_width=10, min_samples_per_group=1)

    gender_fairness = fairness_scorer(
        groups=eval_df['gender'].tolist(),
        scores=bert_scores,
        type='sex'
    )
    gender_stats = fairness_scorer.last_stats

    age_fairness = fairness_scorer(
        groups=eval_df['anchor_age'].tolist(),
        scores=bert_scores,
        type='age'
    )
    age_stats = fairness_scorer.last_stats

    # 7. Report.
    print("\n" + "=" * 80)
    print("🎯 Task A 리더보드 정확 평가 결과")
    print("=" * 80)

    print(f"📊 BERTScore (대회 공식 계산)")
    print(f"   평균: {bert_mean:.6f}")
    print(f"   표준편차: {bert_std:.6f}")
    print(f"   최고: {max(bert_scores):.6f}")
    print(f"   최저: {min(bert_scores):.6f}")
    print(f"   중앙값: {np.median(bert_scores):.6f}")

    print(f"\n⚖️ 공정성 지표 (대회 공식 계산)")
    print(f"   성별 공정성: {gender_fairness:.6f}")
    print(f"   성별별 성능: {gender_stats['by_group']}")
    print(f"   성별 격차: {gender_stats['gap']:.6f}")
    print(f"   ")
    print(f"   연령 공정성: {age_fairness:.6f}")
    print(f"   연령대별 성능: {age_stats['by_group']}")
    print(f"   연령 격차: {age_stats['gap']:.6f}")

    # 8. Quantitative points (16 in total: 10 for BERTScore, 6 for fairness).
    print(f"\n🏆 Task A 정량 평가 점수")

    bert_score_points = min(10.0, max(0.0, (bert_mean / 0.85) * 10.0))  # target 0.85

    fairness_avg = (gender_fairness + age_fairness) / 2.0
    fairness_points = min(6.0, max(0.0, (fairness_avg / 0.95) * 6.0))

    total_quantitative = bert_score_points + fairness_points

    print(f"   BERTScore: {bert_score_points:.3f}/10.000 점")
    print(f"   공정성 지표: {fairness_points:.3f}/6.000 점")
    print(f"   정량 총점: {total_quantitative:.3f}/16.000 점")
    print(f"   정량 달성률: {total_quantitative/16.0*100:.1f}%")

    # 9. Grade: first threshold reached wins, otherwise the lowest grade.
    print(f"\n🎖️ 성능 등급")
    grade, recommendation = "C급 (보통)", "개선 필요"
    for threshold, label, advice in (
        (14.0, "S급 (최우수)", "즉시 제출 권장"),
        (12.0, "A급 (우수)", "제출 권장"),
        (10.0, "B급 (양호)", "소폭 개선 후 제출"),
    ):
        if total_quantitative >= threshold:
            grade, recommendation = label, advice
            break

    print(f"   등급: {grade}")
    print(f"   권장사항: {recommendation}")

    # 10. Best / worst examples; never index past the available samples.
    print(f"\n📝 예측 품질 샘플 (상위/하위 각 3개)")
    print("-" * 80)

    sorted_indices = np.argsort(bert_scores)
    n_examples = min(3, len(sorted_indices))

    _print_ranked_samples(
        "🏆 최고 성능 샘플:",
        sorted_indices[::-1][:n_examples],
        bert_scores, predictions, references,
    )
    _print_ranked_samples(
        "⚠️ 최저 성능 샘플:",
        sorted_indices[:n_examples],
        bert_scores, predictions, references,
    )

    return {
        'bert_score_mean': bert_mean,
        'bert_score_std': bert_std,
        'bert_scores': bert_scores,
        'gender_fairness': gender_fairness,
        'age_fairness': age_fairness,
        'total_score': total_quantitative,
        'grade': grade,
        'predictions': predictions,
        'references': references,
        'evaluation_samples': eval_samples,
        'processing_time': generation_time
    }

# 실행
API_KEY = "cfa06ca698c85aa9c9d4b55440aeef0f85ed94f644cd7b931fdd69f2421c6ecb"
TRAIN_CSV_PATH = "./data/taskA_train.csv"

# Task A 리더보드 정확 평가 실행
taskA_results = await exact_taskA_evaluation(
    train_csv_path=TRAIN_CSV_PATH,
    api_key=API_KEY
)

print(f"\n🎉 Task A 리더보드 평가 완료!")
print(f"최종 예상 점수: {taskA_results['total_score']:.3f}/16.000 점")

🏆 Task A 리더보드 정확 평가 시뮬레이션
1. 데이터 로드 중...
평가 샘플: 300개

📊 데이터 분포:
성별: {'M': 155, 'F': 145}
연령: 평균 63.4세

2. TaskA 처리기 초기화 (EXAONE 모델)...
3. Brief Hospital Course 생성 중...
   배치 1/38 처리 중...
API Error: Error code: 429 - {'error': {'message': 'Rate limit exceeded. Token bucket: 0.00/10.0 tokens. Wait 60s.', 'type': 'rate_limit_error', 'param': None, 'code': 'rate_limit_exceeded'}}, retry 1/3
API Error: Error code: 429 - {'error': {'message': 'Rate limit exceeded. Token bucket: 0.00/10.0 tokens. Wait 60s.', 'type': 'rate_limit_error', 'param': None, 'code': 'rate_limit_exceeded'}}, retry 1/3
   API 제한 준수를 위해 70초 대기...
   배치 2/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 3/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 4/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 5/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 6/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 7/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 8/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 9/38 처리 중...
   API 제한 준수를 위해 70초 대기...
   배치 10/38 