# 03. Model Definition

## 개요
LSTM + MLP 기반 CTR 예측 모델과 관련 함수들을 정의하는 노트북입니다.

## 주요 구성
- TabularSeqModel 클래스 정의
- 모델 저장/로드 함수들
- 모델 평가 함수

**주의**: 이 노트북을 실행하기 전에 **01_setup_and_config.ipynb**와 **02_data_processing.ipynb**를 먼저 실행해주세요!


## Model Architecture


In [None]:
class TabularSeqModel(nn.Module):
    """CTR model combining tabular features with an LSTM summary of a numeric sequence.

    The non-sequence features are batch-normalised, the sequence is encoded by a
    single-layer LSTM (one scalar per time step), and the concatenation of both
    is fed to an MLP that emits one logit per sample.

    Args:
        d_features: Number of non-sequence (tabular) features per sample.
        lstm_hidden: Hidden size of the sequence LSTM.
        hidden_units: Widths of the MLP hidden layers, in order. Any iterable of
            ints is accepted; the default is a tuple so that it is not a shared
            mutable default argument.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self, d_features, lstm_hidden=32, hidden_units=(1024, 512, 256, 128), dropout=0.2):
        super().__init__()
        # BatchNorm over all non-sequence features.
        self.bn_x = nn.BatchNorm1d(d_features)
        # Sequence of scalars -> LSTM (input_size=1 because each step is one number).
        self.lstm = nn.LSTM(input_size=1, hidden_size=lstm_hidden, batch_first=True)

        # Final MLP on [normalised features ; last LSTM hidden state].
        input_dim = d_features + lstm_hidden
        layers = []
        for h in hidden_units:
            layers += [nn.Linear(input_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            input_dim = h
        layers += [nn.Linear(input_dim, 1)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x_feats, x_seq, seq_lengths):
        """Compute one logit per sample.

        Args:
            x_feats: Tabular features, shape (B, d_features).
            x_seq: Padded scalar sequences, shape (B, L).
            seq_lengths: True length of each sequence, shape (B,).
                NOTE: ``pack_padded_sequence`` rejects lengths <= 0, so every
                entry must be at least 1.

        Returns:
            Tensor of shape (B,) holding raw logits (no sigmoid applied).
        """
        # Non-sequence features.
        x = self.bn_x(x_feats)

        # Sequence -> LSTM, packed so that padding steps are skipped.
        x_seq = x_seq.unsqueeze(-1)  # (B, L, 1)
        packed = nn.utils.rnn.pack_padded_sequence(
            # Lengths must live on the CPU for packing; enforce_sorted=False
            # lets the batch arrive in any length order.
            x_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)
        h = h_n[-1]                  # (B, lstm_hidden)

        z = torch.cat([x, h], dim=1)
        return self.mlp(z).squeeze(1)  # logits


## Model Save, Load & Evaluation Functions


In [None]:
def save_model(model, model_path, model_config=None, optimizer=None, scheduler=None, feature_cols=None, training_history=None):
    """Write a training checkpoint for ``model`` to ``model_path``.

    The checkpoint always contains the model weights, the given
    ``model_config`` and an ISO-8601 timestamp. Optimizer/scheduler state,
    the feature column list and the training history are added only when the
    corresponding argument is not ``None``.
    """
    payload = {}
    payload['model_state_dict'] = model.state_dict()
    payload['model_config'] = model_config
    payload['timestamp'] = datetime.now().isoformat()

    # Stateful training objects: stored as their state_dict so training can resume.
    stateful_parts = (
        ('optimizer_state_dict', optimizer),
        ('scheduler_state_dict', scheduler),
    )
    for key, component in stateful_parts:
        if component is not None:
            payload[key] = component.state_dict()

    # Plain metadata: stored as-is (feature columns are needed at inference time).
    metadata_parts = (
        ('feature_cols', feature_cols),
        ('training_history', training_history),
    )
    for key, value in metadata_parts:
        if value is not None:
            payload[key] = value

    torch.save(payload, model_path)
    print(f"Enhanced checkpoint saved to: {model_path}")

def _infer_architecture(state_dict):
    """Recover (d_features, lstm_hidden, hidden_units) from TabularSeqModel weights."""
    d_features = state_dict['bn_x.weight'].shape[0]
    # weight_hh_l0 has shape (4 * hidden, hidden).
    lstm_hidden = state_dict['lstm.weight_hh_l0'].shape[1]
    # Linear layers sit at every third index of the Sequential (Linear/ReLU/Dropout).
    linear_indices = sorted(
        int(key.split('.')[1])
        for key in state_dict
        if key.startswith('mlp.') and key.endswith('.weight')
    )
    widths = [state_dict[f'mlp.{idx}.weight'].shape[0] for idx in linear_indices]
    # The last Linear is the 1-unit output layer, not a hidden layer.
    return d_features, lstm_hidden, widths[:-1]


def load_model_for_inference(model_path, d_features=None, device='cpu'):
    """Load a ``TabularSeqModel`` checkpoint and return it ready for inference.

    Architecture values are resolved in this order: explicit ``d_features``
    argument / ``feature_cols`` length / ``model_config`` entries stored in the
    checkpoint, then the shapes of the saved weights. Inferring from the
    weights means a checkpoint saved without a (complete) ``model_config``
    still loads instead of failing on a hard-coded architecture guess.

    Args:
        model_path: Path to a checkpoint written by ``save_model``.
        d_features: Number of tabular features. When ``None`` it is taken from
            the checkpoint's ``feature_cols`` or, failing that, from the saved
            BatchNorm weights.
        device: Device to map the checkpoint to and move the model onto.

    Returns:
        ``(model, feature_cols)`` where ``model`` is in eval mode and
        ``feature_cols`` is the list stored in the checkpoint or ``None``.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        KeyError: The checkpoint has no ``'model_state_dict'`` entry.
        ValueError: ``d_features`` was not given and cannot be determined.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    try:
        inferred_features, inferred_lstm, inferred_hidden = _infer_architecture(state_dict)
    except (KeyError, IndexError, AttributeError):
        # Weights do not look like a TabularSeqModel: keep the legacy fallback
        # architecture and let load_state_dict report the mismatch.
        inferred_features, inferred_lstm, inferred_hidden = None, 64, [256, 128]

    # Use the feature columns stored in the checkpoint when available.
    feature_cols = checkpoint.get('feature_cols', None)
    if feature_cols is not None:
        if d_features is None:
            d_features = len(feature_cols)
        print(f"Features from checkpoint: {d_features}")

    if d_features is None:
        d_features = inferred_features
    if d_features is None:
        raise ValueError(
            "d_features was not provided and could not be determined from the checkpoint"
        )

    # Stored config wins; anything missing is filled in from the weights.
    config = checkpoint.get('model_config') or {}
    lstm_hidden = config.get('lstm_hidden', inferred_lstm)
    hidden_units = config.get('hidden_units', inferred_hidden)
    dropout = config.get('dropout', 0.2)  # not recoverable from the weights

    model = TabularSeqModel(
        d_features=d_features,
        lstm_hidden=lstm_hidden,
        hidden_units=hidden_units,
        dropout=dropout
    )
    if config:
        print(f"Model config: LSTM={lstm_hidden}, Hidden={hidden_units}")

    model.load_state_dict(state_dict)
    model.to(device)
    # Inference must run with dropout disabled and BatchNorm using running stats.
    model.eval()

    print(f"Model loaded from: {model_path}")
    if 'timestamp' in checkpoint:
        print(f"Model timestamp: {checkpoint['timestamp']}")

    return model, feature_cols

def evaluate_model(model, data_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled for the whole pass.

    Args:
        model: Callable module invoked as ``model(xs, seqs, lens)``.
        data_loader: Iterable yielding ``(xs, seqs, lens, ys)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, ys)``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The unweighted average of the batch losses as a float, or ``nan`` when
        ``data_loader`` yields no batches (previously a ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for xs, seqs, lens, ys in data_loader:
            xs, seqs, lens, ys = xs.to(device), seqs.to(device), lens.to(device), ys.to(device)

            outputs = model(xs, seqs, lens)
            loss = criterion(outputs, ys)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated; report "no value" instead of dividing by zero.
        return float('nan')

    return total_loss / num_batches

# Cell-level confirmation: printed once the definitions in this cell have been executed.
print("Model definition functions loaded successfully!")
