In [1]:
import torch
import torch.nn as nn


## 1. GRU Cell - Version đơn giản nhất


In [2]:
def gru_cell_forward(input_current, hidden_state_previous,
                     weight_ih_r, weight_hh_r, bias_r,
                     weight_ih_z, weight_hh_z, bias_z,
                     weight_ih_h, weight_hh_h, bias_h):
    """
    Forward pass của GRU cell - Version đơn giản nhất
    
    Args:
        input_current: (batch_size, input_size)
        hidden_state_previous: (batch_size, hidden_size)
        weight_ih_r, weight_hh_r, bias_r: Parameters cho Reset Gate
        weight_ih_z, weight_hh_z, bias_z: Parameters cho Update Gate
        weight_ih_h, weight_hh_h, bias_h: Parameters cho Candidate Hidden State
    
    Returns:
        hidden_state_new: (batch_size, hidden_size)
    """
    
    # 1. RESET GATE: Quyết định reset thông tin nào từ hidden state trước đó
    # reset_gate = sigmoid(W_r * x + U_r * h_prev + b_r)
    reset_gate = torch.sigmoid(
        torch.mm(input_current, weight_ih_r.t()) + 
        torch.mm(hidden_state_previous, weight_hh_r.t()) + 
        bias_r
    )
    
    # 2. UPDATE GATE: Quyết định update thông tin mới vào hidden state
    # update_gate = sigmoid(W_z * x + U_z * h_prev + b_z)
    update_gate = torch.sigmoid(
        torch.mm(input_current, weight_ih_z.t()) + 
        torch.mm(hidden_state_previous, weight_hh_z.t()) + 
        bias_z
    )
    
    # 3. CANDIDATE HIDDEN STATE: Tạo giá trị candidate cho hidden state mới
    # candidate_hidden = tanh(W_h * x + U_h * (reset_gate * h_prev) + b_h)
    candidate_hidden = torch.tanh(
        torch.mm(input_current, weight_ih_h.t()) + 
        torch.mm(reset_gate * hidden_state_previous, weight_hh_h.t()) + 
        bias_h
    )
    
    # 4. CẬP NHẬT HIDDEN STATE
    # hidden_state_new = (1 - update_gate) * h_prev + update_gate * candidate_hidden
    hidden_state_new = (1 - update_gate) * hidden_state_previous + update_gate * candidate_hidden
    
    return hidden_state_new


## 2. Khởi tạo Parameters (Đơn giản - không dùng Xavier)


In [3]:
def init_gru_parameters(input_size, hidden_size):
    """
    Khởi tạo parameters cho GRU - Đơn giản nhất
    """
    # Reset Gate parameters
    weight_ih_r = nn.Parameter(torch.randn(hidden_size, input_size))
    weight_hh_r = nn.Parameter(torch.randn(hidden_size, hidden_size))
    bias_r = nn.Parameter(torch.zeros(hidden_size))
    
    # Update Gate parameters
    weight_ih_z = nn.Parameter(torch.randn(hidden_size, input_size))
    weight_hh_z = nn.Parameter(torch.randn(hidden_size, hidden_size))
    bias_z = nn.Parameter(torch.zeros(hidden_size))
    
    # Candidate Hidden State parameters
    weight_ih_h = nn.Parameter(torch.randn(hidden_size, input_size))
    weight_hh_h = nn.Parameter(torch.randn(hidden_size, hidden_size))
    bias_h = nn.Parameter(torch.zeros(hidden_size))
    
    return {
        'weight_ih_r': weight_ih_r, 'weight_hh_r': weight_hh_r, 'bias_r': bias_r,
        'weight_ih_z': weight_ih_z, 'weight_hh_z': weight_hh_z, 'bias_z': bias_z,
        'weight_ih_h': weight_ih_h, 'weight_hh_h': weight_hh_h, 'bias_h': bias_h
    }


## 3. Test GRU Cell


In [4]:
input_size = 4
hidden_size = 3
batch_size = 2

# Khởi tạo parameters
params = init_gru_parameters(input_size, hidden_size)

# Tạo input và initial state
input_current = torch.randn(batch_size, input_size)
h_prev = torch.zeros(batch_size, hidden_size)

print("Input shape:", input_current.shape)
print("Initial hidden state shape:", h_prev.shape)

# Forward pass
h_new = gru_cell_forward(
    input_current, h_prev,
    params['weight_ih_r'], params['weight_hh_r'], params['bias_r'],
    params['weight_ih_z'], params['weight_hh_z'], params['bias_z'],
    params['weight_ih_h'], params['weight_hh_h'], params['bias_h']
)

print("\nAfter forward pass:")
print("New hidden state shape:", h_new.shape)
print("\nNew hidden state:")
print(h_new)


Input shape: torch.Size([2, 4])
Initial hidden state shape: torch.Size([2, 3])

After forward pass:
New hidden state shape: torch.Size([2, 3])

New hidden state:
tensor([[-0.6203, -0.0158,  0.0007],
        [-0.6222, -0.0956,  0.0054]], grad_fn=<AddBackward0>)


In [5]:
def gru_forward(input_sequence, params, batch_first=True):
    """
    Forward pass qua GRU layer cho sequence
    
    Args:
        input_sequence: (batch_size, seq_len, input_size) nếu batch_first=True
                       (seq_len, batch_size, input_size) nếu batch_first=False
        params: Dictionary chứa tất cả parameters
        batch_first: True nếu input là (batch, seq, features)
    
    Returns:
        output: (batch_size, seq_len, hidden_size) nếu batch_first=True
                (seq_len, batch_size, hidden_size) nếu batch_first=False
        final_hidden_state: (batch_size, hidden_size)
    """
    if batch_first:
        batch_size, seq_len, input_size = input_sequence.shape
    else:
        seq_len, batch_size, input_size = input_sequence.shape
    
    hidden_size = params['weight_ih_r'].shape[0]
    
    # Khởi tạo hidden state
    h = torch.zeros(batch_size, hidden_size, device=input_sequence.device, dtype=input_sequence.dtype)
    
    outputs = []
    
    # Xử lý từng time step
    for t in range(seq_len):
        # Lấy input tại time step t
        if batch_first:
            x_t = input_sequence[:, t, :]  # (batch_size, input_size)
        else:
            x_t = input_sequence[t]  # (batch_size, input_size)
        
        # Forward qua GRU cell
        h = gru_cell_forward(
            x_t, h,
            params['weight_ih_r'], params['weight_hh_r'], params['bias_r'],
            params['weight_ih_z'], params['weight_hh_z'], params['bias_z'],
            params['weight_ih_h'], params['weight_hh_h'], params['bias_h']
        )
        
        # Lưu output
        outputs.append(h)
    
    # Stack outputs
    if batch_first:
        output = torch.stack(outputs, dim=1)  # (batch_size, seq_len, hidden_size)
    else:
        output = torch.stack(outputs, dim=0)  # (seq_len, batch_size, hidden_size)
    
    return output, h


## 5. Test GRU Layer


In [6]:
input_size = 4
hidden_size = 3
batch_size = 2
seq_len = 5

# Khởi tạo parameters
params = init_gru_parameters(input_size, hidden_size)

# Tạo input sequence
input_sequence = torch.randn(batch_size, seq_len, input_size)

print("Input shape:", input_sequence.shape)

# Forward pass
output, final_h = gru_forward(input_sequence, params, batch_first=True)

print("\nOutput shape:", output.shape)
print("Final hidden state shape:", final_h.shape)
print("\nOutput:")
print(output)


Input shape: torch.Size([2, 5, 4])

Output shape: torch.Size([2, 5, 3])
Final hidden state shape: torch.Size([2, 3])

Output:
tensor([[[-0.1132, -0.0307,  0.8701],
         [-0.6443, -0.2697,  0.5332],
         [-0.7137, -0.2707,  0.6518],
         [-0.7024, -0.2293,  0.7325],
         [-0.3107, -0.2297,  0.9162]],

        [[ 0.1312, -0.1023,  0.2817],
         [ 0.8074, -0.1091,  0.0809],
         [ 0.4824,  0.5362,  0.3975],
         [ 0.8798,  0.8491,  0.0958],
         [ 0.9736,  0.2081, -0.1529]]], grad_fn=<StackBackward0>)


## 6. So sánh với PyTorch GRU


In [9]:
input_size = 4
hidden_size = 3
batch_size = 2
seq_len = 5

# Our implementation
our_params = init_gru_parameters(input_size, hidden_size)

# PyTorch's implementation
pytorch_gru = nn.GRU(input_size, hidden_size, num_layers=1, batch_first=True)

# Same input
input_sequence = torch.randn(batch_size, seq_len, input_size)

# Forward pass - Our GRU
our_output, our_h = gru_forward(input_sequence, our_params, batch_first=True)

# Forward pass - PyTorch GRU
pytorch_output, pytorch_h = pytorch_gru(input_sequence)

print(pytorch_output, pytorch_h)


tensor([[[-0.0485,  0.2012,  0.1549],
         [ 0.1199,  0.2688,  0.3696],
         [ 0.0655,  0.3771,  0.3831],
         [-0.1321,  0.3802,  0.6871],
         [ 0.0205,  0.3498,  0.7068]],

        [[-0.2097,  0.0193, -0.4212],
         [-0.0569,  0.0637, -0.5055],
         [ 0.0231,  0.1297,  0.0946],
         [-0.0409,  0.1225,  0.0962],
         [-0.3406,  0.1711, -0.1552]]], grad_fn=<TransposeBackward1>) tensor([[[ 0.0205,  0.3498,  0.7068],
         [-0.3406,  0.1711, -0.1552]]], grad_fn=<StackBackward0>)
