# In [None]:
# 지식 증류 함수
def distillation_loss(student_output, teacher_output, true_labels, alpha):
    """Blend the teacher-matching loss with the ground-truth loss.

    Both terms are mean-squared errors computed with ``nn.MSELoss``.

    Args:
        student_output: Tensor of student predictions.
        teacher_output: Tensor of teacher predictions for the same inputs.
            It is treated as a fixed target: it is detached so no gradient
            ever flows back into the teacher, even when the caller did not
            wrap the teacher forward pass in ``torch.no_grad()``.
        true_labels: Tensor of ground-truth targets.
        alpha: Weight of the teacher-matching term; the ground-truth term
            is weighted by ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * soft_loss + (1 - alpha) * hard_loss``.
    """
    mse = nn.MSELoss()
    # Local names deliberately differ from the function name to avoid
    # shadowing it inside its own body.
    soft_loss = mse(student_output, teacher_output.detach())
    hard_loss = mse(student_output, true_labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

# 학생 모델 학습
class DistilledLSTMModel:
    """Student model trained by distillation from an already-built teacher.

    The student is an ``_LSTMModel`` (defined elsewhere in this file). It is
    optimised with Adam on the blended loss returned by
    ``distillation_loss``.
    """

    def __init__(self, teacher_model, input_size, hidden_size, num_layers, output_size, learning_rate, gradient_threshold, epoch, alpha):
        """Build the student network and its optimiser.

        Args:
            teacher_model: Model whose outputs the student imitates. It must
                be callable on the training tensor and provide ``eval()``.
            input_size: Forwarded to ``_LSTMModel``.
            hidden_size: Forwarded to ``_LSTMModel``.
            num_layers: Forwarded to ``_LSTMModel``.
            output_size: Forwarded to ``_LSTMModel``.
            learning_rate: Learning rate for the Adam optimiser.
            gradient_threshold: Maximum gradient norm used for clipping.
            epoch: Number of training iterations run by ``train``.
            alpha: Weight of the teacher-matching term in the loss.
        """
        self.teacher_model = teacher_model
        self.model = _LSTMModel(input_size, hidden_size, num_layers, output_size)
        # Not used by train(), which relies on distillation_loss instead;
        # kept so existing code that reads this attribute still works.
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        # NOTE: the scheduler is created but train() never steps it, so the
        # learning rate stays constant.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=30, gamma=0.1)
        self.epoch = epoch
        self.gradient_threshold = gradient_threshold
        self.alpha = alpha

    def train(self, X_train, y_train):
        """Fit the student on the whole training set for ``self.epoch`` steps.

        Every step is a full-batch update: forward pass, blended
        distillation loss, backward pass, gradient-norm clipping, Adam step.

        Args:
            X_train: Training inputs, anything ``torch.tensor`` accepts. A
                trailing dimension of size 1 is appended before the forward
                pass (presumably the per-step feature axis -- confirm
                against ``_LSTMModel``).
            y_train: Training targets, converted the same way as ``X_train``.
        """
        print('Distilled LSTM Training')
        X_train_tensor = torch.tensor(X_train, dtype=torch.float32).unsqueeze(-1)
        y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(-1)

        self.teacher_model.eval()
        self.model.train()

        # The teacher is frozen (eval mode, no gradients) and the input never
        # changes, so its predictions are loop-invariant: compute them once
        # instead of repeating the forward pass on every epoch.
        with torch.no_grad():
            teacher_outputs = self.teacher_model(X_train_tensor)

        for epoch in tqdm(range(self.epoch)):
            self.optimizer.zero_grad()

            student_outputs = self.model(X_train_tensor)
            loss = distillation_loss(student_outputs, teacher_outputs, y_train_tensor, self.alpha)
            loss.backward()

            # Clip the gradient norm before the optimiser step.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_threshold)

            self.optimizer.step()

            if (epoch+1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{self.epoch}], Loss: {loss.item():.4f}')

        print('Distilled LSTM Training Done')

    def predict(self, X):
        """Run the student in eval mode and return its output as a NumPy array.

        Args:
            X: Inputs, anything ``torch.tensor`` accepts; a trailing
                dimension of size 1 is appended, as in ``train``.

        Returns:
            ``numpy.ndarray`` holding the student model's output for ``X``.
        """
        X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)
        self.model.eval()
        with torch.no_grad():
            y_pred = self.model(X_tensor).detach().numpy()
        return y_pred

    def save_model(self, path):
        """Save the student's ``state_dict`` to ``'{path}.pth'``.

        Args:
            path: Destination path without the ``.pth`` suffix, which is
                appended here.
        """
        torch.save(self.model.state_dict(), f'{path}.pth')
        print("학생 모델 상태 저장 완료")