In [14]:
import torch
import torch.nn as nn


In [15]:
def lstm_cell_forward(input_current, hidden_state_previous, cell_state_previous, 
                      weight_ih_f, weight_hh_f, bias_f,
                      weight_ih_i, weight_hh_i, bias_i,
                      weight_ih_c, weight_hh_c, bias_c,
                      weight_ih_o, weight_hh_o, bias_o):
    """
    Forward pass of a single LSTM cell (simplest, gate-by-gate version).

    Every gate uses the same pre-activation
        z = x @ W_ih^T + h_prev @ W_hh^T + b
    followed by sigmoid (forget / input / output gates) or tanh (candidate).

    Args:
        input_current: (batch_size, input_size), or unbatched (input_size,)
        hidden_state_previous: (batch_size, hidden_size), or unbatched (hidden_size,)
        cell_state_previous: (batch_size, hidden_size), or unbatched (hidden_size,)
        weight_ih_f, weight_hh_f, bias_f: forget-gate parameters
        weight_ih_i, weight_hh_i, bias_i: input-gate parameters
        weight_ih_c, weight_hh_c, bias_c: candidate parameters
        weight_ih_o, weight_hh_o, bias_o: output-gate parameters
        (weight_ih_*: (hidden_size, input_size), weight_hh_*: (hidden_size, hidden_size),
         bias_*: (hidden_size,))

    Returns:
        hidden_state_new: (batch_size, hidden_size), or (hidden_size,) when unbatched
        cell_state_new: (batch_size, hidden_size), or (hidden_size,) when unbatched
    """

    def _preactivation(weight_ih, weight_hh, bias):
        # torch.matmul is identical to torch.mm for 2-D operands, but unlike mm it
        # also accepts unbatched 1-D vectors and extra leading batch dimensions.
        return (torch.matmul(input_current, weight_ih.t())
                + torch.matmul(hidden_state_previous, weight_hh.t())
                + bias)

    # 1. FORGET GATE: how much of the previous cell state to keep
    #    f_t = sigmoid(x W_f^T + h_prev U_f^T + b_f)
    forget_gate = torch.sigmoid(_preactivation(weight_ih_f, weight_hh_f, bias_f))

    # 2. INPUT GATE: how much of the new candidate to write
    #    i_t = sigmoid(x W_i^T + h_prev U_i^T + b_i)
    input_gate = torch.sigmoid(_preactivation(weight_ih_i, weight_hh_i, bias_i))

    # 3. CANDIDATE: proposed new cell content
    #    g_t = tanh(x W_c^T + h_prev U_c^T + b_c)
    candidate_gate = torch.tanh(_preactivation(weight_ih_c, weight_hh_c, bias_c))

    # 4. OUTPUT GATE: how much of the (squashed) cell state to expose
    #    o_t = sigmoid(x W_o^T + h_prev U_o^T + b_o)
    output_gate = torch.sigmoid(_preactivation(weight_ih_o, weight_hh_o, bias_o))

    # 5. CELL STATE UPDATE: c_t = f_t * c_prev + i_t * g_t
    cell_state_new = forget_gate * cell_state_previous + input_gate * candidate_gate

    # 6. HIDDEN STATE UPDATE: h_t = o_t * tanh(c_t)
    hidden_state_new = output_gate * torch.tanh(cell_state_new)

    return hidden_state_new, cell_state_new


In [16]:
def init_lstm_parameters(input_size, hidden_size, forget_bias=1.0):
    """
    Create the parameters of a single-layer LSTM (simplest version).

    Weights are drawn from U(-k, k) with k = 1 / sqrt(hidden_size), the same
    scheme nn.LSTM uses by default. Plain standard-normal weights give
    pre-activations whose variance grows with input_size + hidden_size, which
    saturates the sigmoid/tanh gates and stalls gradients for realistic widths.

    Biases: the forget-gate bias starts at `forget_bias` (1.0 by default) so the
    cell initially keeps most of its memory and gradients can flow through c_t;
    the input, candidate and output biases start at 0.

    Args:
        input_size: number of input features per time step.
        hidden_size: number of hidden units.
        forget_bias: initial value of every forget-gate bias entry.

    Returns:
        dict with keys 'weight_ih_{g}', 'weight_hh_{g}', 'bias_{g}' for g in
        'f' (forget), 'i' (input), 'c' (candidate), 'o' (output), in that order.
        Values are nn.Parameter tensors of shape (hidden_size, input_size),
        (hidden_size, hidden_size) and (hidden_size,) respectively.
    """
    # Guard hidden_size == 0 so the bound computation cannot divide by zero.
    bound = hidden_size ** -0.5 if hidden_size > 0 else 0.0

    def _uniform_parameter(*shape):
        # One trainable weight matrix drawn from U(-bound, bound).
        return nn.Parameter(torch.empty(*shape).uniform_(-bound, bound))

    params = {}
    for gate in ('f', 'i', 'c', 'o'):
        params[f'weight_ih_{gate}'] = _uniform_parameter(hidden_size, input_size)
        params[f'weight_hh_{gate}'] = _uniform_parameter(hidden_size, hidden_size)
        initial_bias = forget_bias if gate == 'f' else 0.0
        params[f'bias_{gate}'] = nn.Parameter(torch.full((hidden_size,), float(initial_bias)))
    return params

def lstm_forward(input_sequence, params, batch_first=True, h0=None, c0=None):
    """
    Run a single-layer LSTM over a whole sequence.

    NOTE: the very same parameters in `params` are used at EVERY time step.
    This weight sharing is the defining property of recurrent networks.

    Args:
        input_sequence: (batch_size, seq_len, input_size) if batch_first=True
                        (seq_len, batch_size, input_size) if batch_first=False
        params: dict with the 12 gate parameters 'weight_ih_{g}', 'weight_hh_{g}',
                'bias_{g}' for g in f, i, c, o (shared across time steps)
        batch_first: True if the input is laid out as (batch, seq, features)
        h0: optional initial hidden state (batch_size, hidden_size); zeros if None
        c0: optional initial cell state (batch_size, hidden_size); zeros if None

    Returns:
        output: (batch_size, seq_len, hidden_size) if batch_first=True
                (seq_len, batch_size, hidden_size) if batch_first=False
        final_hidden_state: (batch_size, hidden_size); the initial state if seq_len == 0
        final_cell_state: (batch_size, hidden_size); the initial state if seq_len == 0
    """
    if batch_first:
        batch_size, seq_len, _ = input_sequence.shape
        time_dim = 1
    else:
        seq_len, batch_size, _ = input_sequence.shape
        time_dim = 0

    hidden_size = params['weight_ih_f'].shape[0]

    # Initial states default to zeros on the same device/dtype as the input.
    state_options = dict(device=input_sequence.device, dtype=input_sequence.dtype)
    h = torch.zeros(batch_size, hidden_size, **state_options) if h0 is None else h0
    c = torch.zeros(batch_size, hidden_size, **state_options) if c0 is None else c0

    # Look the shared parameters up once, in lstm_cell_forward's argument order
    # (forget, input, candidate, output; each as weight_ih, weight_hh, bias).
    gate_params = [params[f'{name}_{gate}']
                   for gate in ('f', 'i', 'c', 'o')
                   for name in ('weight_ih', 'weight_hh', 'bias')]

    outputs = []
    # unbind along the time axis yields one (batch_size, input_size) slice per step.
    for x_t in input_sequence.unbind(dim=time_dim):
        h, c = lstm_cell_forward(x_t, h, c, *gate_params)
        outputs.append(h)

    if not outputs:
        # torch.stack rejects an empty list; return an empty time axis instead.
        empty_shape = ((batch_size, 0, hidden_size) if batch_first
                       else (0, batch_size, hidden_size))
        return input_sequence.new_zeros(empty_shape), h, c

    output = torch.stack(outputs, dim=time_dim)
    return output, h, c


## 3. Demo Weight Sharing qua các Time Steps


In [17]:
# Demo: the SAME set of weights is reused at EVERY time step
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the numbers

input_size = 2
hidden_size = 3
batch_size = 1
seq_len = 5

# Create the parameters ONCE
params = init_lstm_parameters(input_size, hidden_size)

# Input sequence, batch-first: (batch_size, seq_len, input_size)
input_sequence = torch.randn(batch_size, seq_len, input_size)

# Initial hidden and cell states
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

print(f"\nInput sequence shape: {input_sequence.shape}")
print(f"Number of time steps: {seq_len}")
print(f"\nWeights được tạo 1 lần và dùng lại cho {seq_len} time steps:")

# Remember the object id of one weight to show the very same tensor is used at every step
weight_id = id(params['weight_ih_f'])
print(f"\nID của weight_ih_f: {weight_id}")

# Unroll over time manually
for t in range(seq_len):
    x_t = input_sequence[:, t, :]

    # One LSTM step - always with the SAME weights
    h, c = lstm_cell_forward(
        x_t, h, c,
        params['weight_ih_f'], params['weight_hh_f'], params['bias_f'],
        params['weight_ih_i'], params['weight_hh_i'], params['bias_i'],
        params['weight_ih_c'], params['weight_hh_c'], params['bias_c'],
        params['weight_ih_o'], params['weight_hh_o'], params['bias_o']
    )

    # Confirm it is still the same object
    current_weight_id = id(params['weight_ih_f'])
    print(f"Time step {t}: Input shape {x_t.shape}, Hidden shape {h.shape}, Weight ID: {current_weight_id} (same: {current_weight_id == weight_id})")





Input sequence shape: torch.Size([1, 5, 2])
Number of time steps: 5

Weights được tạo 1 lần và dùng lại cho 5 time steps:

ID của weight_ih_f: 1604014177920
Time step 0: Input shape torch.Size([1, 2]), Hidden shape torch.Size([1, 3]), Weight ID: 1604014177920 (same: True)
Time step 1: Input shape torch.Size([1, 2]), Hidden shape torch.Size([1, 3]), Weight ID: 1604014177920 (same: True)
Time step 2: Input shape torch.Size([1, 2]), Hidden shape torch.Size([1, 3]), Weight ID: 1604014177920 (same: True)
Time step 3: Input shape torch.Size([1, 2]), Hidden shape torch.Size([1, 3]), Weight ID: 1604014177920 (same: True)
Time step 4: Input shape torch.Size([1, 2]), Hidden shape torch.Size([1, 3]), Weight ID: 1604014177920 (same: True)


## 4. Test LSTM Cell


In [18]:
input_size = 4
hidden_size = 3
batch_size = 2

# A fresh set of LSTM parameters
params = init_lstm_parameters(input_size, hidden_size)

# One time step of input, plus all-zero initial hidden/cell states
input_current = torch.randn(batch_size, input_size)
h_prev, c_prev = (torch.zeros(batch_size, hidden_size) for _ in range(2))

for label, tensor in (("Input shape:", input_current),
                      ("Initial hidden state shape:", h_prev),
                      ("Initial cell state shape:", c_prev)):
    print(label, tensor.shape)

# One forward step: the dict keys coincide with lstm_cell_forward's parameter names
h_new, c_new = lstm_cell_forward(input_current, h_prev, c_prev, **params)

print("\nAfter forward pass:")
print("New hidden state shape:", h_new.shape)
print("New cell state shape:", c_new.shape)
for label, tensor in (("\nNew hidden state:", h_new), ("\nNew cell state:", c_new)):
    print(label)
    print(tensor)


Input shape: torch.Size([2, 4])
Initial hidden state shape: torch.Size([2, 3])
Initial cell state shape: torch.Size([2, 3])

After forward pass:
New hidden state shape: torch.Size([2, 3])
New cell state shape: torch.Size([2, 3])

New hidden state:
tensor([[ 4.4023e-02,  1.1454e-01, -2.6374e-03],
        [-2.6462e-02,  2.8649e-01, -2.7646e-04]], grad_fn=<MulBackward0>)

New cell state:
tensor([[ 0.1464,  0.1384, -0.0578],
        [-0.3858,  0.3016, -0.1774]], grad_fn=<AddBackward0>)


In [19]:
input_size = 4
hidden_size = 3
batch_size = 2
seq_len = 5

# A fresh set of LSTM parameters
params = init_lstm_parameters(input_size, hidden_size)

# Input sequence, batch-first: (batch_size, seq_len, input_size)
input_sequence = torch.randn(batch_size, seq_len, input_size)

print("Input shape:", input_sequence.shape)

# Forward pass over the whole sequence
output, final_h, final_c = lstm_forward(input_sequence, params, batch_first=True)

# Sanity checks: batch-first output layout, and the last time step of
# `output` must be exactly the returned final hidden state.
assert output.shape == (batch_size, seq_len, hidden_size)
assert torch.equal(output[:, -1], final_h)

print("\nOutput shape:", output.shape)
print("Final hidden state shape:", final_h.shape)
print("Final cell state shape:", final_c.shape)
print("\nOutput:")
print(output)


Input shape: torch.Size([2, 5, 4])

Output shape: torch.Size([2, 5, 3])
Final hidden state shape: torch.Size([2, 3])
Final cell state shape: torch.Size([2, 3])

Output:
tensor([[[-0.0326, -0.0079, -0.4383],
         [-0.1450,  0.6577,  0.3170],
         [-0.0528,  0.0357,  0.0065],
         [-0.0656,  0.3367,  0.3931],
         [-0.2881,  0.0119,  0.0511]],

        [[ 0.1387, -0.0525, -0.0023],
         [-0.0406, -0.0527,  0.3837],
         [-0.0940,  0.5924,  0.2164],
         [ 0.1135,  0.0959, -0.0280],
         [ 0.4703,  0.0048, -0.0711]]], grad_fn=<StackBackward0>)


## 7. So sánh với PyTorch LSTM


In [20]:
input_size = 4
hidden_size = 3
batch_size = 2
seq_len = 5

# Our implementation
our_params = init_lstm_parameters(input_size, hidden_size)

# PyTorch's implementation
pytorch_lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

# Give nn.LSTM exactly our weights so both models compute the same function.
# nn.LSTM stacks the gates as [input, forget, cell (candidate), output] and adds
# two biases (bias_ih + bias_hh): put our bias in bias_ih and zero bias_hh.
gate_order = ('i', 'f', 'c', 'o')
with torch.no_grad():
    pytorch_lstm.weight_ih_l0.copy_(torch.cat([our_params[f'weight_ih_{g}'] for g in gate_order]))
    pytorch_lstm.weight_hh_l0.copy_(torch.cat([our_params[f'weight_hh_{g}'] for g in gate_order]))
    pytorch_lstm.bias_ih_l0.copy_(torch.cat([our_params[f'bias_{g}'] for g in gate_order]))
    pytorch_lstm.bias_hh_l0.zero_()

# Same input
input_sequence = torch.randn(batch_size, seq_len, input_size)

# Forward pass - Our LSTM
our_output, our_h, our_c = lstm_forward(input_sequence, our_params, batch_first=True)

# Forward pass - PyTorch LSTM
pytorch_output, (pytorch_h, pytorch_c) = pytorch_lstm(input_sequence)

print("Our LSTM output shape:", our_output.shape)
print("PyTorch LSTM output shape:", pytorch_output.shape)
print("\nOutput shapes match:", our_output.shape == pytorch_output.shape)
print("Hidden state shapes match:", our_h.shape == pytorch_h.squeeze(0).shape)
print("Cell state shapes match:", our_c.shape == pytorch_c.squeeze(0).shape)

# Matching shapes alone proves little; with identical weights the values must agree too.
torch.testing.assert_close(our_output.detach(), pytorch_output.detach(), rtol=1e-5, atol=1e-5)
torch.testing.assert_close(our_h.detach(), pytorch_h.squeeze(0).detach(), rtol=1e-5, atol=1e-5)
torch.testing.assert_close(our_c.detach(), pytorch_c.squeeze(0).detach(), rtol=1e-5, atol=1e-5)


Our LSTM output shape: torch.Size([2, 5, 3])
PyTorch LSTM output shape: torch.Size([2, 5, 3])

Output shapes match: True
Hidden state shapes match: True
Cell state shapes match: True
