# 04. Enhanced Split Training

## 개요
`part_*.parquet` 파일로 분할된 데이터(기본 10개)를 Enhanced Gradient Descent 전략으로 순차 학습하는 노트북입니다.

## 주요 특징
- **Enhanced Gradient Descent**: 전역 optimizer & scheduler 상태를 모든 split에 걸쳐 유지
- **Catastrophic Forgetting 완화**: gradient clipping + 이전 split 대비 검증 손실 증가 감지 시 학습률 감소
- **메모리 효율성**: 각 split 처리 후 자동 메모리 관리
- **체크포인트 시스템**: 중간 저장 및 완전한 상태 복원

**주의**: 이 노트북을 실행하기 전에 **01**, **02**, **03** 노트북을 먼저 실행해주세요!

**시간 소요**: 약 30분~2시간 (GPU 성능에 따라)


## Enhanced Split Training Function


In [None]:
def _scale_learning_rate(optimizer, scheduler, factor):
    """Multiply the learning rate of every optimizer param group by ``factor``.

    Args:
        optimizer: optimizer whose ``param_groups[*]['lr']`` values are scaled in place.
        scheduler: LR scheduler attached to ``optimizer``; may be ``None``.
        factor: multiplicative factor (e.g. ``CFG['LR_REDUCTION_FACTOR']``).

    ``base_lrs`` is scaled as well because ``CosineAnnealingWarmRestarts.step()``
    recomputes every LR from ``scheduler.base_lrs``. Editing only
    ``param_group['lr']`` would be overwritten by the very next
    ``scheduler.step()``, and the reduction would have no lasting effect.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] *= factor
    if scheduler is not None and hasattr(scheduler, 'base_lrs'):
        scheduler.base_lrs = [base_lr * factor for base_lr in scheduler.base_lrs]


def _train_one_epoch(model, loader, criterion, optimizer, device, desc, clip_norm):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    progress = tqdm(loader, desc=desc)
    for xs, seqs, seq_lens, ys in progress:
        xs, seqs, seq_lens, ys = xs.to(device), seqs.to(device), seq_lens.to(device), ys.to(device)

        optimizer.zero_grad()
        logits = model(xs, seqs, seq_lens)
        loss = criterion(logits, ys)
        loss.backward()

        # Clip the global gradient norm so a single batch from a new split cannot
        # produce an outsized update (intended to limit drift away from earlier splits).
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1

        progress.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.6f}"
        })

    if n_batches == 0:
        raise ValueError(f"{desc}: training loader produced no batches")
    return total_loss / n_batches


def _evaluate(model, loader, criterion, device, desc):
    """Return the mean batch loss of ``model`` over ``loader`` without gradient tracking.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for xs, seqs, seq_lens, ys in tqdm(loader, desc=desc):
            xs, seqs, seq_lens, ys = xs.to(device), seqs.to(device), seq_lens.to(device), ys.to(device)

            logits = model(xs, seqs, seq_lens)
            loss = criterion(logits, ys)

            total_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError(f"{desc}: validation loader produced no batches")
    return total_loss / n_batches


def train_on_split(model, train_df, feature_cols, seq_col, target_col, optimizer, scheduler, 
                   split_idx, batch_size=4096, epochs=3, device="cuda", prev_val_loss=None):
    """
    Train ``model`` on a single data split, reusing a global optimizer/scheduler.

    The split is divided 80/20 (``random_state=42``) into train/validation parts.
    The model is trained for ``epochs`` epochs and the scheduler is stepped once
    per epoch. The LR is reduced by ``CFG['LR_REDUCTION_FACTOR']`` when the
    validation loss fails to improve for 2 consecutive epochs. It is also reduced
    after the split if the final validation loss exceeds
    ``prev_val_loss * CFG['CATASTROPHIC_THRESHOLD']``.

    Args:
        model: model to train (updated in place).
        train_df: data for this split.
        feature_cols: feature column names.
        seq_col: sequence column name.
        target_col: target column name.
        optimizer: global optimizer whose state persists across splits.
        scheduler: global LR scheduler whose state persists across splits; ``None`` skips stepping.
        split_idx: index of the current split (used for logging).
        batch_size: DataLoader batch size.
        epochs: number of epochs to train on this split; must be >= 1.
        device: device the batches are moved to.
        prev_val_loss: validation loss of the previous split, used to detect catastrophic forgetting.

    Returns:
        Tuple ``(model, final_val_loss, split_history, catastrophic_forgetting)``, where
        ``split_history`` has per-epoch ``'train_loss'``, ``'val_loss'`` and ``'lr'`` lists.

    Raises:
        ValueError: if ``epochs < 1`` or a loader produces no batches.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    print(f"Training on Split {split_idx} with {len(train_df)} samples")
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Train/Validation split
    tr_df, va_df = train_test_split(train_df, test_size=0.2, random_state=42, shuffle=True)
    print(f"Train: {len(tr_df)}, Validation: {len(va_df)}")

    # Dataset & DataLoader
    train_dataset = ClickDataset(tr_df, feature_cols, seq_col, target_col, has_target=True)
    val_dataset = ClickDataset(va_df, feature_cols, seq_col, target_col, has_target=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_train)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_train)

    criterion = nn.BCEWithLogitsLoss()
    
    # Plateau detection: training is NOT stopped early; the LR is reduced instead.
    patience = 2
    patience_counter = 0
    best_val_loss = float('inf')
    
    # Per-epoch history for this split
    split_history = {'train_loss': [], 'val_loss': [], 'lr': []}

    for epoch in range(1, epochs + 1):
        avg_train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Split {split_idx} Epoch {epoch}",
            clip_norm=CFG['GRADIENT_CLIP_NORM'],
        )
        avg_val_loss = _evaluate(model, val_loader, criterion, device, desc=f"Validation {epoch}")
        current_lr = optimizer.param_groups[0]['lr']
        
        print(f"[Split {split_idx} Epoch {epoch}] Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
        print(f"LR: {current_lr:.6f} | Memory: {get_memory_usage()}")
        
        split_history['train_loss'].append(avg_train_loss)
        split_history['val_loss'].append(avg_val_loss)
        split_history['lr'].append(current_lr)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            # Scales scheduler.base_lrs too, so the following scheduler.step() keeps the reduction.
            _scale_learning_rate(optimizer, scheduler, CFG['LR_REDUCTION_FACTOR'])
            print(f"Learning rate reduced to: {optimizer.param_groups[0]['lr']:.6f}")
            patience_counter = 0
        
        if scheduler is not None:
            scheduler.step()

    # Catastrophic forgetting check: compare against the previous split's final validation loss.
    # NOTE: the two losses are measured on different validation sets, so this is a heuristic.
    final_val_loss = avg_val_loss
    catastrophic_forgetting = False
    
    if prev_val_loss is not None and final_val_loss > prev_val_loss * CFG['CATASTROPHIC_THRESHOLD']:
        print(f"WARNING: Possible catastrophic forgetting detected!")
        print(f"Previous val loss: {prev_val_loss:.4f} -> Current: {final_val_loss:.4f}")
        
        # Respond by lowering the LR for subsequent splits.
        _scale_learning_rate(optimizer, scheduler, CFG['LR_REDUCTION_FACTOR'])
        print(f"Learning rate reduced to: {optimizer.param_groups[0]['lr']:.6f}")
        catastrophic_forgetting = True

    return model, final_val_loss, split_history, catastrophic_forgetting


## Main Training Loop Initialization


In [None]:
# Discover the split parquet files, derive the feature schema from the first one,
# and build the model that is trained across all splits.
banner = "=" * 80
print(banner)
print("STARTING ENHANCED SPLIT DATA TRAINING WITH ADVANCED GRADIENT DESCENT")
print(banner)

split_pattern = os.path.join(CFG['SPLIT_DATA_PATH'], "part_*.parquet")
split_files = sorted(glob.glob(split_pattern))
print(f"Found {len(split_files)} split files:")
for split_path in split_files:
    print(f"  - {os.path.basename(split_path)}")

if not split_files:
    print("No split files found!")
    raise FileNotFoundError("No split data files found in the specified directory")

# Only the first split is loaded here, to read the feature columns.
print("\n" + "=" * 50)
print("Analyzing first split for feature info...")
first_df = load_and_downsample_data(split_files[0], CFG['DOWNSAMPLE_RATIO'])
feature_cols = get_feature_columns(first_df)
seq_col, target_col = "seq", "clicked"

for label, value in (
    ("Number of features", len(feature_cols)),
    ("Sequence column", seq_col),
    ("Target column", target_col),
):
    print(f"{label}: {value}")

# The sample frame is no longer needed; release it before building the model.
del first_df
clear_memory()

print("\n" + "=" * 50)
print("Initializing model & optimization strategy...")
# Keys match TabularSeqModel keyword arguments; the dict is also stored in checkpoints.
model_config = {
    'lstm_hidden': 64,
    'hidden_units': [256, 128],
    'dropout': 0.2,
}

model = TabularSeqModel(d_features=len(feature_cols), **model_config).to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model initialized with {n_params} parameters")


In [None]:
# Build the optimizer/scheduler pair that is shared by every split, plus the run-level history.
print("Setting up Enhanced Gradient Descent strategy...")

# One Adam instance is reused for all splits so its moment estimates carry over.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG['LEARNING_RATE'],
    weight_decay=CFG['WEIGHT_DECAY'],
)

# Cosine annealing with warm restarts; train_on_split steps it once per epoch.
scheduler_kwargs = dict(
    T_0=10,        # epochs until the first restart
    T_mult=2,      # each subsequent cycle is this many times longer
    eta_min=1e-6,  # lower bound for the learning rate
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **scheduler_kwargs)

print(f"Optimizer: Adam with LR={CFG['LEARNING_RATE']:.0e}, weight_decay={CFG['WEIGHT_DECAY']}")
print("Scheduler: CosineAnnealingWarmRestarts")
print(f"Initial memory: {get_memory_usage()}")

# Aggregated per-epoch metrics across all splits, plus a forgetting-event counter.
global_training_history = {
    key: [] for key in ('splits', 'train_losses', 'val_losses', 'learning_rates')
}
global_training_history['catastrophic_forgetting_events'] = 0

prev_val_loss = None
start_time = datetime.now()
print(f"Training started at: {start_time}")


## Main Training Loop Execution


In [None]:
import traceback

# Train on each split sequentially; optimizer/scheduler state carries over between splits.
# A failing split is logged (with traceback) and skipped, but recorded so it is not silently lost.
failed_splits = []  # (split index, file name) for every split that raised

for i, split_file in enumerate(split_files, 1):
    print(f"\n{'='*80}")
    print(f"Processing Split {i}/{len(split_files)}: {os.path.basename(split_file)}")
    print(f"{'='*80}")
    
    try:
        # Load (and downsample) this split's data
        split_df = load_and_downsample_data(split_file, CFG['DOWNSAMPLE_RATIO'])
        
        model, current_val_loss, split_history, cf_detected = train_on_split(
            model=model,
            train_df=split_df,
            feature_cols=feature_cols,
            seq_col=seq_col,
            target_col=target_col,
            optimizer=optimizer,
            scheduler=scheduler,
            split_idx=i,
            batch_size=CFG['BATCH_SIZE'],
            epochs=CFG['EPOCHS_PER_SPLIT'],
            device=device,
            prev_val_loss=prev_val_loss
        )
        
        # Append this split's per-epoch metrics to the run-level history
        global_training_history['splits'].append(i)
        global_training_history['train_losses'].extend(split_history['train_loss'])
        global_training_history['val_losses'].extend(split_history['val_loss'])
        global_training_history['learning_rates'].extend(split_history['lr'])
        
        if cf_detected:
            global_training_history['catastrophic_forgetting_events'] += 1
        
        # Intermediate checkpoint every CFG['SAVE_CHECKPOINT_EVERY'] splits
        if i % CFG['SAVE_CHECKPOINT_EVERY'] == 0:
            checkpoint_path = os.path.join(CFG['MODELS_PATH'], f"{CFG['MODEL_NAME']}_enhanced_checkpoint_split_{i:02d}.pth")
            save_model(
                model=model, 
                model_path=checkpoint_path, 
                model_config=model_config,
                optimizer=optimizer,
                scheduler=scheduler,
                feature_cols=feature_cols,
                training_history=global_training_history
            )
            print(f"Checkpoint saved: {os.path.basename(checkpoint_path)}")
        
        # Reference loss for the next split's catastrophic-forgetting check
        prev_val_loss = current_val_loss
        
        print(f"Split {i}/{len(split_files)} completed successfully!")
        print(f"Final validation loss: {current_val_loss:.4f}")
        
    except Exception as e:
        # Best-effort: keep going, but keep the traceback and remember which split failed.
        print(f"ERROR processing split {i}: {str(e)}")
        traceback.print_exc()
        failed_splits.append((i, os.path.basename(split_file)))
        print("Continuing with next split...")
        continue
        
    finally:
        # Release the split's DataFrame before loading the next one
        if 'split_df' in locals():
            del split_df
        clear_memory()
        print(f"Memory after cleanup: {get_memory_usage()}")

if split_files and len(failed_splits) == len(split_files):
    raise RuntimeError(
        f"All {len(split_files)} splits failed to train; refusing to continue with an untrained model"
    )

print(f"\n{'='*80}")
print("ALL SPLITS COMPLETED!")
print(f"{'='*80}")
if failed_splits:
    print(f"WARNING: {len(failed_splits)}/{len(split_files)} split(s) failed and were skipped: {failed_splits}")


In [None]:
# Save the final model together with optimizer/scheduler state and the training history.
print("Saving final enhanced model...")
final_model_path = os.path.join(CFG['MODELS_PATH'], f"{CFG['MODEL_NAME']}_enhanced_final.pth")

save_model(
    model=model,
    model_path=final_model_path,
    model_config=model_config,
    optimizer=optimizer,
    scheduler=scheduler,
    feature_cols=feature_cols,
    training_history=global_training_history
)

# Training summary
end_time = datetime.now()
total_time = end_time - start_time
# 'splits' only receives an index after train_on_split returned, so this counts
# splits that actually trained (len(split_files) would also count skipped ones).
processed_splits = len(global_training_history['splits'])

print(f"\n{'='*80}")
print("ENHANCED SPLIT TRAINING COMPLETED SUCCESSFULLY!")
print(f"{'='*80}")
print(f"Total training time: {total_time}")
print(f"Number of splits processed: {processed_splits}/{len(split_files)}")
print(f"Final model saved to: {final_model_path}")
print(f"Catastrophic forgetting events: {global_training_history['catastrophic_forgetting_events']}")
print(f"Final learning rate: {optimizer.param_groups[0]['lr']:.6f}")
# prev_val_loss stays None when no split finished, and None cannot be formatted with :.4f.
if prev_val_loss is None:
    print("Final validation loss: N/A (no split completed training)")
else:
    print(f"Final validation loss: {prev_val_loss:.4f}")
print(f"Final memory usage: {get_memory_usage()}")

# Expose the training result as variables (used by notebook 05)
training_completed = processed_splits > 0
trained_model = model
trained_feature_cols = feature_cols

print(f"\nTraining variables ready for inference:")
print(f"- training_completed: {training_completed}")
print(f"- trained_model: Available")
print(f"- trained_feature_cols: {len(trained_feature_cols)} features")
if training_completed:
    print(f"\nYou can now run 05_inference.ipynb for predictions!")
else:
    print("\nWARNING: no split was trained successfully; re-run training before 05_inference.ipynb.")
