## Import Library

In [52]:
import os
import torch
import random 
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm.notebook import tqdm
from types import SimpleNamespace
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

### Seed Setting

In [54]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(42) 

In [55]:

품목_리스트 = ['건고추', '사과', '감자', '배', '깐마늘(국산)', '무', '상추', '배추', '양파', '대파']

item_columns = {
    "감자": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)", "공판장_총거래금액(원)", "공판장_평균가(원/kg)","감자 수미_20키로상자_특_평균가격(원)","감자 대지_20키로상자_중_평균가격(원)"],
    "건고추": ["평년 평균가격(원)", "평균가격(원)","양건_30 kg_중품_평균가격(원)"],
    "깐마늘(국산)": ["공판장_평균가(원/kg)", "평균가격(원)", "공판장_경매 건수", "도매_평균가(원/kg)"],
    "대파": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)", "공판장_경매 건수","쪽파_10키로상자_중_평균가격(원)","대파(일반)_1키로단_중_평균가격(원)"],
    "무": ["무_20키로상자_특_평균가격(원)", "평균가격(원)", "공판장_평균가(원/kg)", "공판장_총반입량(kg)","열무_4키로상자_상_평균가격(원)","열무_1.5키로단_상_평균가격(원)"],
    "배추": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)", "공판장_평균가(원/kg)","쌈배추_1키로상자_상_평균가격(원)","알배기배추_8키로상자_특_평균가격(원)","봄동배추_15키로상자_중_평균가격(원)"],
    "사과": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)", "도매_평균가(원/kg)","공판장_총거래금액(원)"],
    "상추": ["평년 평균가격(원)", "평균가격(원)", "도매_평균가(원/kg)", "도매_총반입량(kg)","청_100 g_중품_평균가격(원)","적_100 g_상품_평균가격(원)"],
    "양파": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)","도매_경매 건수", "자주양파_1키로_중_평균가격(원)","양파 수입_1키로_상_평균가격(원)","공판장_중간가(원/kg)"],
    "배": ["평년 평균가격(원)", "평균가격(원)", "공판장_총반입량(kg)", "도매_경매 건수","신고_10 개_중품_평균가격(원)"]
}

### Functions 

In [56]:


def process_data(raw_file, 산지공판장_file, 전국도매_file, 품목명, scaler=None):
    raw_data = pd.read_csv(raw_file)
    산지공판장 = pd.read_csv(산지공판장_file)
    전국도매 = pd.read_csv(전국도매_file)

    conditions = {
        '감자': {
            'target': lambda df: (df['품종명'] == '감자 수미') & (df['거래단위'] == '20키로상자') & (df['등급'] == '상'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['감자'], '품종명': ['수미'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['감자'], '품종명': ['수미']}
        },
        '건고추': {
            'target': lambda df: (df['품종명'] == '화건') & (df['거래단위'] == '30 kg') & (df['등급'] == '상품'),
            '공판장': None, 
            '도매': None  
        },
        '깐마늘(국산)': {
            'target': lambda df: (df['거래단위'] == '20 kg') & (df['등급'] == '상품'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['마늘'], '품종명': ['깐마늘'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['마늘'], '품종명': ['깐마늘']}
        },
        '대파': {
            'target': lambda df: (df['품종명'] == '대파(일반)') & (df['거래단위'] == '1키로단') & (df['등급'] == '상'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['대파'], '품종명': ['대파(일반)'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['대파'], '품종명': ['대파(일반)']}
        },
        '무': {
            'target': lambda df: (df['거래단위'] == '20키로상자') & (df['등급'] == '상'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['무'], '품종명': ['기타무'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['무'], '품종명': ['무']}
        },
        '배추': {
            'target': lambda df: (df['거래단위'] == '10키로망대') & (df['등급'] == '상'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['배추'], '품종명': ['쌈배추'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['배추'], '품종명': ['배추']}
        },
        '사과': {
            'target': lambda df: (df['품종명'].isin(['홍로', '후지'])) & (df['거래단위'] == '10 개') & (df['등급'] == '상품'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['사과'], '품종명': ['후지'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['사과'], '품종명': ['후지']}
        },
        '상추': {
            'target': lambda df: (df['품종명'] == '청') & (df['거래단위'] == '100 g') & (df['등급'] == '상품'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['상추'], '품종명': ['청상추'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['상추'], '품종명': ['청상추']}
        },
        '양파': {
            'target': lambda df: (df['품종명'] == '양파') & (df['거래단위'] == '1키로') & (df['등급'] == '상'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['양파'], '품종명': ['기타양파'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['양파'], '품종명': ['양파(일반)']}
        },
        '배': {
            'target': lambda df: (df['품종명'] == '신고') & (df['거래단위'] == '10 개') & (df['등급'] == '상품'),
            '공판장': {'공판장명': ['*전국농협공판장'], '품목명': ['배'], '품종명': ['신고'], '등급명': ['상']},
            '도매': {'시장명': ['*전국도매시장'], '품목명': ['배'], '품종명': ['신고']}
        }
    }

    raw_품목 = raw_data[raw_data['품목명'] == 품목명]
    target_mask = conditions[품목명]['target'](raw_품목)
    filtered_data = raw_품목[target_mask]

    other_data = raw_품목[~target_mask]
    unique_combinations = other_data[['품종명', '거래단위', '등급']].drop_duplicates()
    for _, row in unique_combinations.iterrows():
        품종명, 거래단위, 등급 = row['품종명'], row['거래단위'], row['등급']
        mask = (other_data['품종명'] == 품종명) & (other_data['거래단위'] == 거래단위) & (other_data['등급'] == 등급)
        temp_df = other_data[mask]
        for col in ['평년 평균가격(원)', '평균가격(원)']:
            new_col_name = f'{품종명}_{거래단위}_{등급}_{col}'
            filtered_data = filtered_data.merge(temp_df[['시점', col]], on='시점', how='left', suffixes=('', f'_{new_col_name}'))
            filtered_data.rename(columns={f'{col}_{new_col_name}': new_col_name}, inplace=True)

    if conditions[품목명]['공판장']:
        filtered_공판장 = 산지공판장
        for key, value in conditions[품목명]['공판장'].items():
            filtered_공판장 = filtered_공판장[filtered_공판장[key].isin(value)]
        
        filtered_공판장 = filtered_공판장.add_prefix('공판장_').rename(columns={'공판장_시점': '시점'})
        filtered_data = filtered_data.merge(filtered_공판장, on='시점', how='left')

    if conditions[품목명]['도매']:
        filtered_도매 = 전국도매
        for key, value in conditions[품목명]['도매'].items():
            filtered_도매 = filtered_도매[filtered_도매[key].isin(value)]
        
        filtered_도매 = filtered_도매.add_prefix('도매_').rename(columns={'도매_시점': '시점'})
        filtered_data = filtered_data.merge(filtered_도매, on='시점', how='left')

    numeric_columns = filtered_data.select_dtypes(include=[np.number]).columns
    filtered_data = filtered_data[['시점'] + list(numeric_columns)]
    filtered_data[numeric_columns] = filtered_data[numeric_columns].fillna(0)

    return filtered_data

def normalize_data(data):
    scaler = StandardScaler()
    normalized_data = scaler.fit_transform(data)
    return normalized_data, scaler

def inverse_normalize(data, scaler):
    return scaler.inverse_transform(data)


## Model and scheduler

In [57]:
class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()

        self.seq_len = int(configs.seq_len)  
        self.pred_len = int(configs.pred_len) 
        self.enc_in = configs.enc_in  
        self.d_model = configs.d_model
        self.dropout = configs.dropout
        self.rnn_type = configs.rnn_type 
        self.dec_way = configs.dec_way  
        self.seg_len = int(configs.seg_len ) 
        self.channel_id = configs.channel_id 
        self.revin = configs.revin  

        assert self.rnn_type in ['rnn', 'gru', 'lstm']
        assert self.dec_way in ['rmf', 'pmf']

        self.seg_num_x = self.seq_len//self.seg_len

        self.valueEmbedding = nn.Sequential(
            nn.Linear(self.seg_len, self.d_model),
            nn.ReLU()
        )

        if self.rnn_type == "rnn":            self.rnn = nn.RNN(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,                              batch_first=True, bidirectional=False)
        elif self.rnn_type == "gru":            self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,                              batch_first=True, bidirectional=False)
        elif self.rnn_type == "lstm":            self.rnn = nn.LSTM(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,                              batch_first=True, bidirectional=False)

        if self.dec_way == "rmf":
            self.seg_num_y = self.pred_len // self.seg_len
            self.predict = nn.Sequential(   nn.Dropout(self.dropout),     nn.Linear(self.d_model, self.seg_len)            )
            
        elif self.dec_way == "pmf":
            self.seg_num_y = self.pred_len // self.seg_len

            if self.channel_id:
                self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
                self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
            else:
                self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model))

            self.predict = nn.Sequential(
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model, self.seg_len)
            )
        if self.revin:
            self.revinLayer = RevIN(self.enc_in, affine=False, subtract_last=False)


    def forward(self, x):

        batch_size = x.size(0)
        if self.revin:
            x = self.revinLayer(x, 'norm').permute(0, 2, 1)
        else:
            seq_last = x[:, -1:, :].detach()
            x = (x - seq_last).permute(0, 2, 1) 

        newx = x.reshape(-1, self.seg_num_x, self.seg_len)
        x = self.valueEmbedding(newx)
        if self.rnn_type == "lstm":_, (hn, cn) = self.rnn(x)
        else: _, hn = self.rnn(x) 
        

        if self.dec_way == "rmf":
            y = []
            for i in range(self.seg_num_y):
                yy = self.predict(hn)    
                yy = yy.permute(1,0,2)   
                y.append(yy)
                yy = self.valueEmbedding(yy)
                if self.rnn_type == "lstm":
                    _, (hn, cn) = self.rnn(yy, (hn, cn))
                else:
                    _, hn = self.rnn(yy, hn)
            y = torch.stack(y, dim=1).squeeze(2).reshape(-1, self.enc_in, self.pred_len) 
        
        elif self.dec_way == "pmf":
            if self.channel_id:
                pos_emb = torch.cat([
                    self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
                    self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
                ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
            else:
                pos_emb = self.pos_emb.repeat(batch_size * self.enc_in, 1).unsqueeze(1)
            if self.rnn_type == "lstm":
                _, (hy, cy) = self.rnn(pos_emb,
                                       (hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model),
                                        cn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)))
            else:
                _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))
           
            y = self.predict(hy).view(-1, self.enc_in, self.pred_len)

        if self.revin:
            y = self.revinLayer(y.permute(0, 2, 1), 'denorm')
        else:
            y = y.permute(0, 2, 1) + seq_last

        return y
    
    
class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else: raise NotImplementedError
        return x

    def _init_params(self):
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        if self.subtract_last:
            self.last = x[:,-1,:].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
def get_optimizer(model, config):
    if config.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    elif config.optimizer.lower() == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    elif config.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    elif config.optimizer.lower() == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer type: {config.optimizer}")
    return optimizer

def get_scheduler(optimizer, config):
    if config.scheduler.lower() == 'step_lr':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
    elif config.scheduler.lower() == 'reduce_on_plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=config.patience, factor=config.gamma)
    elif config.scheduler.lower() == 'none':
        scheduler = None
    else:
        raise ValueError(f"Unsupported scheduler type: {config.scheduler}")
    return scheduler


In [58]:
class Data(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.Y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
    
def reshape_data(df):
    time_series_data = []
    for idx, row in df.iterrows():
        sales_data = row.values.astype(float)
        time_series_data.append(sales_data)
    return np.array(time_series_data)

def time_slide_df(data, window_size, forecast_size):
    data_list = []
    dap_list = []
    for idx in range(0, len(data) - window_size - forecast_size + 1):
        x = data[idx:idx + window_size].reshape(window_size, -1)  
        y = data[idx + window_size:idx + window_size + forecast_size]
        data_list.append(x)
        dap_list.append(y)
    return np.array(data_list, dtype='float32'), np.array(dap_list, dtype='float32')

def create_dataloader(data, window_size, forecast_size, batch_size):
    X, Y = time_slide_df(data, window_size, forecast_size)
    ds = Data(X, Y)
    return DataLoader(ds, batch_size=batch_size, shuffle=True)


class NMAELoss(nn.Module):
    def __init__(self):
        super(NMAELoss, self).__init__()
    
    def forward(self, y_pred, y_true):
        mae = torch.mean(torch.abs(y_pred - y_true))
        mean_true = torch.mean(torch.abs(y_true))
        nmae = mae / mean_true
        
        return nmae


## Configs for Each Models

In [59]:
class Config_Multi_3:
    def __init__(self):
        self.seed = 33697 
        self.learning_rate = 0.0095 
        self.epoch = 136 
        self.batch_size = 64 
        self.optimizer = 'rmsprop'  
        self.weight_decay = 1e-09  
        self.scheduler = 'step_lr'  
        self.step_size = 20      
        self.gamma = 0.5         
        self.patience = 5        
        
        self.seq_len = 3
        self.pred_len = 3
        self.enc_in = 3
        self.d_model = 64 
        self.dropout = 0
        self.rnn_type = 'gru'
        self.dec_way = 'pmf'
        self.seg_len = 3  
        self.channel_id = True 
        self.revin = False 
        
        
class Config_Multi_9:
    def __init__(self):
        self.seed = 42 
        self.learning_rate = 0.001 
        self.epoch = 77 
        self.batch_size = 64 
        self.optimizer = 'adam' 
        self.weight_decay = 1e-12 
        self.scheduler = 'none'  
        self.step_size = 20      
        self.gamma = 0.5         
        self.patience = 5       
        self.seq_len = 9
        self.pred_len = 3
        self.enc_in = 3
        self.d_model = 64 
        self.dropout = 0.2 
        self.rnn_type = 'gru'
        self.dec_way = 'pmf'
        self.seg_len = 3  
        self.channel_id = True 
        self.revin = False 
config_multi_9 = Config_Multi_9() 
config_multi_3 = Config_Multi_3() 

class Config_Uni_3:
    def __init__(self):
        self.seed = 43389
        self.learning_rate = 0.0095
        self.epoch = 143
        self.batch_size = 64
        self.optimizer = 'rmsprop'  
        self.weight_decay = 1e-11
        self.scheduler = 'step_lr' 
        self.step_size = 10      
        self.gamma = 0.5         
        self.patience = 5        
        self.seq_len = 3
        self.pred_len = 3
        self.enc_in = 1
        self.d_model = 128
        self.dropout = 0.1
        self.rnn_type = 'gru'
        self.dec_way = 'pmf'
        self.seg_len = 3  
        self.channel_id = False
        self.revin = False
        
class Config_Uni_9:
    def __init__(self):
        self.seed = 14732
        self.learning_rate = 0.0085
        self.epoch = 78
        self.batch_size = 64
        self.optimizer = 'adamw'  
        self.weight_decay = 1e-10
        self.scheduler = 'reduce_on_plateau' 
        self.step_size = 10      
        self.gamma = 0.5         
        self.patience = 5   
        self.seq_len = 9
        self.pred_len = 3
        self.enc_in = 1
        self.d_model = 64
        self.dropout = 0.15
        self.rnn_type = 'gru'
        self.dec_way = 'pmf'
        self.seg_len = 3  
        self.channel_id = True
        self.revin = False
config_uni_3 = Config_Uni_3() 
config_uni_9 = Config_Uni_9() 

## Train Predict Function

In [60]:
model_dir = "../trainedmodels"  
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

In [61]:
def train_predict_multi(config, 품목리스트, seed, month, windowsize=9):
    seed_everything(seed) 
    predicts = {}
    for item in 품목리스트:
        config.enc_in = len(item_columns[item])
        train_data = process_data(
            "train/train.csv",
            "train/meta/TRAIN_산지공판장_2018-2021.csv",
            "train/meta/TRAIN_전국도매_2018-2021.csv",
            item        )
        train_data= train_data[item_columns[item]]
        
        
        price_df = train_data.iloc[:, :].T.reset_index().iloc[:, 1:] 
        price_df = price_df.iloc[:, :((month-1) * 3 + 110)]  
        timedata = reshape_data(price_df)
        
        normalized_timedata, scaler = normalize_data(timedata.T)  
        train_dl = create_dataloader(normalized_timedata, config.seq_len, config.pred_len ,config.batch_size)
        
        model = Model(config)
        model.to(device)
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        criterion = NMAELoss()

        
        for ep in range(1, config.epoch + 1):
            model.train()
            batch_losses = []
            for idx, (data_batch, target) in enumerate(train_dl):
                data_batch, target = data_batch.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data_batch)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())
            
            
            if ep % 5 == 0:
                avg_loss = np.mean(batch_losses)
            

        model_save_path = f"{model_dir}/segrnn_multi_{item}_{month}month_{windowsize}window_model.pth"
        torch.save(model.state_dict(), model_save_path)
        print(f"Model for {item} saved at {model_save_path}")
        
        
        product_predict = []
        for i in range(25):
            
            test_file = f"test/TEST_{i:02d}.csv"
            산지공판장_file = f"test/meta/TEST_산지공판장_{i:02d}.csv"
            전국도매_file = f"test/meta/TEST_전국도매_{i:02d}.csv"
            test_data = process_data(test_file, 산지공판장_file, 전국도매_file, item)
            
            test_data= test_data[item_columns[item]]
            test_df = test_data.iloc[:, :].T.reset_index().iloc[:, 1:] 
            if windowsize == 3 :
                test_df = test_df.iloc[: , -3 :]
            timedf = reshape_data(test_df).T  
            
            
            
            normalized_testdata = scaler.transform(timedf)
            test_tensor = torch.tensor(normalized_testdata, dtype=torch.float32).unsqueeze(0).to(device)  
            
            model.eval()  
            with torch.no_grad():
                prediction = model(test_tensor)
            
            prediction = prediction.squeeze(0).cpu().numpy()  
            inverse_pred = inverse_normalize(prediction, scaler)
            product_predict.append(inverse_pred.T[1]) 
            
        flatlist = np.concatenate(product_predict).tolist()
        predicts[item] = flatlist 
        
    return predicts, 0, 0


In [62]:
def train_predict_uni(config, 품목_리스트 , seed, month, windowsize=9):
    seed_everything(seed)
    predicts = {}
    total_item_loss = {}

    for item in 품목_리스트:
        
        train_data = process_data(
            "train/train.csv",
            "train/meta/TRAIN_산지공판장_2018-2021.csv",
            "train/meta/TRAIN_전국도매_2018-2021.csv",
            item
        )
        
        price_df = train_data.iloc[:, [2]].T.reset_index().iloc[:, 1:]
        price_df = price_df.iloc[:, :((month-1)*3 + 110)] 
        timedata = reshape_data(price_df)
        
        normalized_timedata, scaler = normalize_data(timedata.T) 
        normalized_timedata = normalized_timedata.flatten()  
        
        train_dl = create_dataloader(normalized_timedata, config.seq_len, config.pred_len, config.batch_size)
        
        item_losses = []
        best_loss = float('inf')
        no_improvement_count = 0
        
        model = Model(config)
        model.to(device)
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        criterion = NMAELoss()
        
        for epoch in range(1, config.epoch + 1):
            loss_list = []
            model.train()
            for batch_idx, (data_batch, target) in enumerate(train_dl):
                data_batch, target = data_batch.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data_batch)
                loss = criterion(output, target.unsqueeze(-1))
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
            
            if scheduler is not None:
                if config.scheduler.lower() == 'reduce_on_plateau':
                    scheduler.step(np.mean(loss_list))
                else:
                    scheduler.step()
            
            avg_loss = np.mean(loss_list)
            item_losses.append(avg_loss)
            if (epoch % 5) == 0:
                if avg_loss + 0.001 < best_loss:
                    best_loss = avg_loss
                    no_improvement_count = 0
                else:
                    no_improvement_count += 1
                    if no_improvement_count >= 4:
                        break
        
        total_item_loss[item] = item_losses[-1]
        
        model_save_path = f"{model_dir}/segrnn_uni_{item}_{month}month_{windowsize}window_model.pth"
        torch.save(model.state_dict(), model_save_path)
        print(f"Model for {item} saved at {model_save_path}")
        
        product_predict = []
        for i in range(25):
            test_file = f"test/TEST_{i:02d}.csv"
            산지공판장_file = f"test/meta/TEST_산지공판장_{i:02d}.csv"
            전국도매_file = f"test/meta/TEST_전국도매_{i:02d}.csv"
            test_data = process_data(test_file, 산지공판장_file, 전국도매_file, item)
            
            test_price_df = test_data.iloc[:, [2]].T.reset_index().iloc[:, 1:]
            
            if windowsize == 3:
                test_price_df = test_price_df.iloc[:, -3:]
            
            timedf = reshape_data(test_price_df).T  
            
            normalized_testdata = scaler.transform(timedf)
            
            test_tensor = torch.tensor(normalized_testdata.flatten(), dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
            
            model.eval()
            with torch.no_grad():
                prediction = model(test_tensor)
                prediction = prediction.cpu().numpy().squeeze()
                
                inverse_pred = inverse_normalize(prediction.reshape(-1, 1), scaler)
                product_predict.append(inverse_pred.flatten())
        
        product_predict = np.concatenate(product_predict)
        predicts[item] = product_predict
    
    lossavg = sum(total_item_loss.values()) / len(total_item_loss)
    return predicts, total_item_loss, lossavg


## Train and Predict 

In [63]:
segrnn_multi_9 = {} 
segrnn_multi_3 = {} 
segrnn_uni_9 = {}  
segrnn_uni_3 = {}  

### Multivariate

In [64]:

for seed in [333]:
    segrnn_multi_9[f'predicts_{seed}'] = {}
    for i in range(1, 12+ 1):
        segrnn_multi_9[f'predicts_{seed}'][f'predicts{i}'], total_item_loss, lossavg = train_predict_multi(config_multi_9, 품목_리스트, seed, i, 9)


for seed in [333]:
    segrnn_multi_3[f'predicts_{seed}'] = {}
    for i in range(1, 12+1):
        segrnn_multi_3[f'predicts_{seed}'][f'predicts{i}'], total_item_loss, lossavg = train_predict_multi(config_multi_3, 품목_리스트, seed, i, 3)


Model for 건고추 saved at ../trainedmodels/segrnn_multi_건고추_1month_9window_model.pth
Model for 사과 saved at ../trainedmodels/segrnn_multi_사과_1month_9window_model.pth
Model for 감자 saved at ../trainedmodels/segrnn_multi_감자_1month_9window_model.pth
Model for 배 saved at ../trainedmodels/segrnn_multi_배_1month_9window_model.pth
Model for 깐마늘(국산) saved at ../trainedmodels/segrnn_multi_깐마늘(국산)_1month_9window_model.pth
Model for 무 saved at ../trainedmodels/segrnn_multi_무_1month_9window_model.pth
Model for 상추 saved at ../trainedmodels/segrnn_multi_상추_1month_9window_model.pth
Model for 배추 saved at ../trainedmodels/segrnn_multi_배추_1month_9window_model.pth
Model for 양파 saved at ../trainedmodels/segrnn_multi_양파_1month_9window_model.pth
Model for 대파 saved at ../trainedmodels/segrnn_multi_대파_1month_9window_model.pth
Model for 건고추 saved at ../trainedmodels/segrnn_multi_건고추_2month_9window_model.pth
Model for 사과 saved at ../trainedmodels/segrnn_multi_사과_2month_9window_model.pth
Model for 감자 saved at ../train

### Univariate

In [65]:

for seed in [333]:
    segrnn_uni_9[f'predicts_{seed}'] = {}
    for i in range(1, 12+ 1):
        segrnn_uni_9[f'predicts_{seed}'][f'predicts{i}'], total_item_loss, lossavg = train_predict_uni(config_uni_9, 품목_리스트, seed, i, 9)



for seed in [333]:
    segrnn_uni_3[f'predicts_{seed}'] = {}
    for i in range(1, 12+1):
        segrnn_uni_3[f'predicts_{seed}'][f'predicts{i}'], total_item_loss, lossavg = train_predict_uni(config_uni_3, 품목_리스트, seed, i, 3)

Model for 건고추 saved at ../trainedmodels/segrnn_uni_건고추_1month_9window_model.pth
Model for 사과 saved at ../trainedmodels/segrnn_uni_사과_1month_9window_model.pth
Model for 감자 saved at ../trainedmodels/segrnn_uni_감자_1month_9window_model.pth
Model for 배 saved at ../trainedmodels/segrnn_uni_배_1month_9window_model.pth
Model for 깐마늘(국산) saved at ../trainedmodels/segrnn_uni_깐마늘(국산)_1month_9window_model.pth
Model for 무 saved at ../trainedmodels/segrnn_uni_무_1month_9window_model.pth
Model for 상추 saved at ../trainedmodels/segrnn_uni_상추_1month_9window_model.pth
Model for 배추 saved at ../trainedmodels/segrnn_uni_배추_1month_9window_model.pth
Model for 양파 saved at ../trainedmodels/segrnn_uni_양파_1month_9window_model.pth
Model for 대파 saved at ../trainedmodels/segrnn_uni_대파_1month_9window_model.pth
Model for 건고추 saved at ../trainedmodels/segrnn_uni_건고추_2month_9window_model.pth
Model for 사과 saved at ../trainedmodels/segrnn_uni_사과_2month_9window_model.pth
Model for 감자 saved at ../trainedmodels/segrnn_uni_감자_2

## Make Submission Files

### Multi_9window

In [66]:
sub_multi_9window= pd.read_csv('sample_submission.csv')

for item in sub_multi_9window.columns[1:]: 
    
    seed_predictions = []
    for seed in [333 ]:
        item_predictions = [segrnn_multi_9[f'predicts_{seed}'][f'predicts{i}'][item] for i in range(1, 12+1)]
        seed_predictions.append(np.mean(item_predictions, axis=0))  
    
    sub_multi_9window[item] = np.mean(seed_predictions, axis=0)
    

### Multi_3window

In [67]:
sub_multi_3window = pd.read_csv('sample_submission.csv')

for item in sub_multi_3window.columns[1:]: 
    
    seed_predictions = []
    for seed in [333]:
        item_predictions = [segrnn_multi_3[f'predicts_{seed}'][f'predicts{i}'][item] for i in range(1, 12+1)]
        seed_predictions.append(np.mean(item_predictions, axis=0))  
    
    sub_multi_3window[item] = np.mean(seed_predictions, axis=0)

sub_multi_3window


Unnamed: 0,시점,감자,건고추,깐마늘(국산),대파,무,배추,사과,상추,양파,배
0,TEST_00+1순,38234.060221,672739.984375,167936.826823,2132.888987,32576.179036,15335.041504,24103.981608,1020.629954,1523.339956,30089.199870
1,TEST_00+2순,39377.243164,672241.963542,168049.558594,2130.780518,35415.491862,15178.457601,24954.854655,823.255473,1542.729777,29555.851400
2,TEST_00+3순,39413.738281,670625.250000,167640.207031,2218.020253,32319.699544,14893.199300,25090.194499,694.345360,1554.030823,29308.413411
3,TEST_01+1순,40759.501953,646260.802083,167740.993490,1905.072194,11838.499837,5363.100464,22747.792969,957.762863,1502.694194,25142.069499
4,TEST_01+2순,41836.432292,645156.406250,168138.097656,1886.645325,10332.701986,5088.733561,23202.363607,942.059855,1505.733948,25428.511068
...,...,...,...,...,...,...,...,...,...,...,...
70,TEST_23+2순,44257.912109,634804.031250,168166.618490,1700.885193,10291.407796,4627.322795,22466.046549,894.653473,1453.736420,26092.292480
71,TEST_23+3순,44351.074219,633579.161458,167985.524740,1660.936056,10403.865397,4563.045349,22685.198893,910.869700,1460.429108,26290.305827
72,TEST_24+1순,46057.626628,541395.614583,174624.537760,1330.296173,14141.721842,8471.954468,28284.114258,1026.496887,529.736893,38501.723958
73,TEST_24+2순,42817.522135,541794.911458,175012.234375,1430.255971,13489.063477,8279.806722,28491.601562,1040.429850,509.635361,38725.397461


### Uni_9window

In [68]:
sub_uni_9window = pd.read_csv('sample_submission.csv')

for item in sub_uni_9window.columns[1:]: 
    
    seed_predictions = []
    for seed in [333]:
        item_predictions = [segrnn_uni_9[f'predicts_{seed}'][f'predicts{i}'][item] for i in range(1, 12+1)]
        seed_predictions.append(np.mean(item_predictions, axis=0))  
    
    sub_uni_9window[item] = np.mean(seed_predictions, axis=0)

sub_uni_9window


Unnamed: 0,시점,감자,건고추,깐마늘(국산),대파,무,배추,사과,상추,양파,배
0,TEST_00+1순,36584.941406,671514.1875,167749.906250,2017.400269,31353.345703,12179.695312,23241.583984,1062.735962,1441.678833,30177.880859
1,TEST_00+2순,35603.199219,669096.5000,167344.765625,2021.295410,32356.396484,10113.484375,23515.244141,1001.383301,1423.428711,29640.179688
2,TEST_00+3순,34865.648438,667351.0000,167314.531250,2106.803955,29355.380859,10573.295898,23663.392578,955.112793,1413.642456,28428.281250
3,TEST_01+1순,40617.136719,648558.9375,167429.734375,1765.878296,14675.241211,6081.662598,22000.158203,999.000427,1491.565308,24126.771484
4,TEST_01+2순,40832.332031,646982.9375,167344.828125,1757.073853,12402.210938,5635.400391,21697.746094,1018.908691,1496.969849,23244.898438
...,...,...,...,...,...,...,...,...,...,...,...
70,TEST_23+2순,42630.707031,635239.3125,167388.015625,1626.513550,12950.463867,4303.248535,22764.537109,933.257751,1426.719238,24898.845703
71,TEST_23+3순,41570.863281,634172.8750,167415.296875,1602.891235,15015.020508,4040.610596,22592.650391,950.769714,1423.328735,24865.572266
72,TEST_24+1순,48040.875000,543026.0000,174238.515625,1367.351074,13227.304688,8410.295898,28027.912109,1007.871826,467.466156,38387.859375
73,TEST_24+2순,42447.320312,542992.3125,174113.031250,1443.260376,13937.381836,8394.088867,27870.205078,982.339539,408.532593,38357.242188


### Uni_3window

In [69]:
sub_uni_3window = pd.read_csv('sample_submission.csv')

for item in sub_uni_3window.columns[1:]:  
    
    seed_predictions = []
    for seed in [333]:
        item_predictions = [segrnn_uni_3[f'predicts_{seed}'][f'predicts{i}'][item] for i in range(1, 12+1)]
        seed_predictions.append(np.mean(item_predictions, axis=0))  
    
    sub_uni_3window[item] = np.mean(seed_predictions, axis=0)

sub_uni_3window

Unnamed: 0,시점,감자,건고추,깐마늘(국산),대파,무,배추,사과,상추,양파,배
0,TEST_00+1순,37694.699219,673569.8750,167975.656250,2092.864990,33906.292969,15518.104492,23670.166016,1099.635864,1479.251343,30803.943359
1,TEST_00+2순,37613.207031,675492.6875,167786.109375,2095.135986,34314.886719,15320.172852,23548.654297,1014.438660,1471.903442,30860.876953
2,TEST_00+3순,37493.417969,669458.6875,167810.953125,2086.375977,33716.574219,15237.307617,23854.115234,947.491638,1471.202637,30796.000000
3,TEST_01+1순,40474.746094,649434.8750,167396.484375,1865.834106,12723.036133,5345.666504,22749.025391,1017.340637,1479.438354,25257.982422
4,TEST_01+2순,40548.281250,650071.9375,167386.265625,1865.888062,11777.989258,5272.335449,22610.865234,1007.695496,1476.425171,25268.369141
...,...,...,...,...,...,...,...,...,...,...,...
70,TEST_23+2순,42787.734375,639045.3750,167400.000000,1636.453613,9522.000977,4781.775391,22248.669922,881.064941,1427.832397,25666.224609
71,TEST_23+3순,42651.355469,637761.7500,167400.000000,1615.648560,9367.079102,4770.774414,22283.568359,871.595276,1424.452271,25652.914062
72,TEST_24+1순,48798.187500,543000.0000,174253.250000,1310.242310,13507.503906,8476.597656,28255.460938,1006.969788,509.602692,38418.972656
73,TEST_24+2순,46170.042969,543000.0000,174234.953125,1340.548950,13627.847656,8430.807617,28230.718750,1005.379456,499.116119,38403.843750


In [70]:
sub_multi_9window.to_csv('FIN_segrnn_multi_9_1013_seed333.csv',index=False)
sub_uni_9window.to_csv('FIN_segrnn_uni_9_1013_seed333.csv',index=False)
