## Imports

In [7]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from types import SimpleNamespace
from sklearn.preprocessing import MinMaxScaler
import os

In [3]:
# Read the training CSV into a DataFrame; the bare `df` expression on the next
# line makes the notebook cell display the frame.
df = pd.read_csv('/content/train.csv')
df

Unnamed: 0,시점,품목명,품종명,거래단위,등급,평년 평균가격(원),평균가격(원)
0,201801상순,건고추,화건,30 kg,상품,381666.666667,590000.0
1,201801중순,건고추,화건,30 kg,상품,380809.666667,590000.0
2,201801하순,건고추,화건,30 kg,상품,380000.000000,590000.0
3,201802상순,건고추,화건,30 kg,상품,380000.000000,590000.0
4,201802중순,건고추,화건,30 kg,상품,376666.666667,590000.0
...,...,...,...,...,...,...,...
29371,202111중순,대파,대파(일반),10키로묶음,상,0.000000,0.0
29372,202111하순,대파,대파(일반),10키로묶음,상,0.000000,0.0
29373,202112상순,대파,대파(일반),10키로묶음,상,0.000000,0.0
29374,202112중순,대파,대파(일반),10키로묶음,상,0.000000,0.0


In [8]:
# Hyperparameters, kept both as a plain dict (`config`) and as an
# attribute-style namespace (`CFG`) built from that dict.
config = dict(
    learning_rate=2e-5,
    epoch=30,
    batch_size=64,
    hidden_size=64,
    num_layers=2,
    output_size=3,
)

CFG = SimpleNamespace(**config)

# Item names (values of the `품목명` column) listed for this notebook.
품목_리스트 = [
    '건고추',
    '사과',
    '감자',
    '배',
    '깐마늘(국산)',
    '무',
    '상추',
    '배추',
    '양파',
    '대파',
]

In [15]:
# Data loading and basic filtering
def load_and_filter_data(raw_file, 산지공판장_file, 전국도매_file, 품목명, conditions):
    """Read the three CSV inputs and select the target rows of one item.

    Args:
        raw_file: Path or buffer of the raw CSV; it must contain a ``품목명``
            column.
        산지공판장_file: Path or buffer of the CSV returned as the second
            element of the result.
        전국도매_file: Path or buffer of the CSV returned as the third element
            of the result.
        품목명: Item name compared against the ``품목명`` column.
        conditions: Mapping in which ``conditions[품목명]['target']`` is a
            callable that receives the rows of the item and returns a boolean
            mask over those rows.

    Returns:
        A tuple ``(raw_data, 산지공판장, 전국도매, filtered_data, target_mask)``.
        ``target_mask`` is indexed like the rows of the item (not like the
        whole ``raw_data``), and ``filtered_data`` holds the item rows where
        the mask is true.
    """
    # Read in the same order as the parameters appear.
    raw_data, 산지공판장, 전국도매 = (
        pd.read_csv(source)
        for source in (raw_file, 산지공판장_file, 전국도매_file)
    )

    is_item = raw_data['품목명'] == 품목명
    item_rows = raw_data[is_item]

    select_target = conditions[품목명]['target']
    target_mask = select_target(item_rows)
    filtered_data = item_rows[target_mask]

    return raw_data, 산지공판장, 전국도매, filtered_data, target_mask


In [16]:
# Derived feature generation
def generate_derived_features(filtered_data, other_data):
    """Left-join the price columns of every series in ``other_data``.

    A series is one distinct (``품종명``, ``거래단위``, ``등급``) combination
    found in ``other_data``. For each series and each of the two price
    columns, the matching rows are merged onto ``filtered_data`` by ``시점``
    and the joined column ends up named
    ``{품종명}_{거래단위}_{등급}_{price column}``.

    Args:
        filtered_data: Frame to extend; it is joined on its ``시점`` column.
        other_data: Frame providing the extra series; it needs the columns
            ``시점``, ``품종명``, ``거래단위``, ``등급`` and both price columns.

    Returns:
        The extended frame. When ``other_data`` has no rows, the input frame
        is returned as is.
    """
    group_keys = ['품종명', '거래단위', '등급']
    price_columns = ('평년 평균가격(원)', '평균가격(원)')

    distinct_series = other_data[group_keys].drop_duplicates()
    for _, combo in distinct_series.iterrows():
        품종명, 거래단위, 등급 = (combo[key] for key in group_keys)

        same_variety = other_data['품종명'] == 품종명
        same_unit = other_data['거래단위'] == 거래단위
        same_grade = other_data['등급'] == 등급
        series_rows = other_data[same_variety & same_unit & same_grade]

        for col in price_columns:
            new_col_name = f'{품종명}_{거래단위}_{등급}_{col}'
            # The right-hand column collides with the same-named column in
            # `filtered_data`, so merge tags it with the suffix; it is then
            # renamed to the final feature name.
            suffixed_name = f'{col}_{new_col_name}'
            merged = filtered_data.merge(
                series_rows[['시점', col]],
                on='시점',
                how='left',
                suffixes=('', f'_{new_col_name}'),
            )
            filtered_data = merged.rename(columns={suffixed_name: new_col_name})

    return filtered_data


In [17]:
# 공판장 (산지공판장 table) data processing
def process_공판장_data(filtered_data, 산지공판장, 품목명, conditions):
    """Filter the ``산지공판장`` table and left-join it onto ``filtered_data``.

    Args:
        filtered_data: Frame to extend; joined on its ``시점`` column.
        산지공판장: Table to filter and join; it needs a ``시점`` column.
        품목명: Key used to look up the filter spec in ``conditions``.
        conditions: Mapping in which ``conditions[품목명]['공판장']`` is either
            falsy (no join is performed) or a dict of
            ``column -> allowed values``; a row is kept only if it passes
            every entry.

    Returns:
        ``filtered_data`` unchanged when the spec is falsy; otherwise the
        merge result, where every joined column except ``시점`` carries the
        ``공판장_`` prefix.
    """
    market_filters = conditions[품목명]['공판장']
    if not market_filters:
        return filtered_data

    market_rows = 산지공판장
    for column, allowed_values in market_filters.items():
        keep = market_rows[column].isin(allowed_values)
        market_rows = market_rows[keep]

    # Prefix everything, then restore the join key to its plain name.
    market_rows = market_rows.add_prefix('공판장_')
    market_rows = market_rows.rename(columns={'공판장_시점': '시점'})
    return filtered_data.merge(market_rows, on='시점', how='left')


In [18]:
# 도매 (전국도매 table) data processing
def process_도매_data(filtered_data, 전국도매, 품목명, conditions):
    """Filter the ``전국도매`` table and left-join it onto ``filtered_data``.

    Args:
        filtered_data: Frame to extend; joined on its ``시점`` column.
        전국도매: Table to filter and join; it needs a ``시점`` column.
        품목명: Key used to look up the filter spec in ``conditions``.
        conditions: Mapping in which ``conditions[품목명]['도매']`` is either
            falsy (no join is performed) or a dict of
            ``column -> allowed values``; a row is kept only if it passes
            every entry.

    Returns:
        ``filtered_data`` unchanged when the spec is falsy; otherwise the
        merge result, where every joined column except ``시점`` carries the
        ``도매_`` prefix.
    """
    wholesale_filters = conditions[품목명]['도매']
    if not wholesale_filters:
        return filtered_data

    wholesale_rows = 전국도매
    for column, allowed_values in wholesale_filters.items():
        keep = wholesale_rows[column].isin(allowed_values)
        wholesale_rows = wholesale_rows[keep]

    # Prefix everything, then restore the join key to its plain name.
    wholesale_rows = wholesale_rows.add_prefix('도매_')
    wholesale_rows = wholesale_rows.rename(columns={'도매_시점': '시점'})
    return filtered_data.merge(wholesale_rows, on='시점', how='left')


In [19]:
# Data normalization
def normalize_data(filtered_data, scaler=None):
    """Keep ``시점`` plus the numeric columns, fill NaNs with 0 and scale them.

    Args:
        filtered_data: Input frame; it must contain a ``시점`` column. The
            caller's frame is not modified.
        scaler: Object with a ``transform`` method that is applied to the
            numeric columns. When ``None``, a new ``MinMaxScaler`` is created
            and fitted on the numeric columns of this frame.

    Returns:
        A tuple ``(normalized_frame, scaler)`` where ``scaler`` is the one
        that was passed in or the newly fitted one.
    """
    numeric_columns = filtered_data.select_dtypes(include=[np.number]).columns

    # Work on an explicit copy: assigning into the bare column subset is
    # chained assignment on a derived frame, which makes pandas emit
    # SettingWithCopyWarning.
    filtered_data = filtered_data[['시점'] + list(numeric_columns)].copy()
    filtered_data[numeric_columns] = filtered_data[numeric_columns].fillna(0)

    if scaler is None:
        scaler = MinMaxScaler()
        scaled = scaler.fit_transform(filtered_data[numeric_columns])
    else:
        scaled = scaler.transform(filtered_data[numeric_columns])
    filtered_data[numeric_columns] = scaled

    return filtered_data, scaler


In [20]:
def process_data(raw_file, 산지공판장_file, 전국도매_file, 품목명, scaler=None):
    """Run the full preprocessing pipeline for one item.

    Steps: load and filter the CSVs, add derived features from the item's
    non-target rows, join the 공판장 and 도매 tables, then normalize.

    The filter specification is read from the module-level ``conditions``
    mapping, which must be defined before this function is called.

    Args:
        raw_file: Raw CSV, passed on to ``load_and_filter_data``.
        산지공판장_file: 산지공판장 CSV, passed on to ``load_and_filter_data``.
        전국도매_file: 전국도매 CSV, passed on to ``load_and_filter_data``.
        품목명: Item name to process.
        scaler: Optional scaler passed on to ``normalize_data``; when ``None``
            a new one is fitted there.

    Returns:
        A tuple ``(filtered_data, scaler)`` as returned by ``normalize_data``.
    """
    # Step 1: load the data and select the target rows
    raw_data, 산지공판장, 전국도매, filtered_data, target_mask = load_and_filter_data(
        raw_file, 산지공판장_file, 전국도매_file, 품목명, conditions
    )

    # Step 2: derived features
    # `target_mask` is indexed by the rows of this item only, so it has to be
    # applied to that same subset. Indexing the full `raw_data` with it raises
    # pandas' IndexingError (unalignable boolean Series) as soon as the file
    # contains any other item.
    raw_품목 = raw_data[raw_data['품목명'] == 품목명]
    other_data = raw_품목[~target_mask]
    filtered_data = generate_derived_features(filtered_data, other_data)

    # Step 3: 공판장 data
    filtered_data = process_공판장_data(filtered_data, 산지공판장, 품목명, conditions)

    # Step 4: 도매 data
    filtered_data = process_도매_data(filtered_data, 전국도매, 품목명, conditions)

    # Step 5: normalization
    filtered_data, scaler = normalize_data(filtered_data, scaler)

    return filtered_data, scaler
