In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

In [10]:
trans = transforms.ToTensor()

train_data = MNIST(root='./data', train=True, download=True, transform=trans)
test_data = MNIST(root='./data', train=False, download=True, transform=trans)

train_loader = DataLoader(train_data, shuffle=True, batch_size=128)
test_loader = DataLoader(test_data, batch_size=128)

In [None]:
nn.LSTM

In [None]:
input_size = 28
seq_len = 28
hidden_size = 128
output_size = 10

Wx = torch.normal(0, 1, (input_size, hidden_size * 4), requires_grad=True)
Wh = torch.normal(0, 1, (hidden_size, hidden_size * 4), requires_grad=True)	
b = torch.normal(0, 1, (1, hidden_size * 4), requires_grad=True)


def MyLSTM(X, Wx, Wh, b, hidden_size=128):
	batch_size, seq_len, input_size = X.shape
	h = torch.zeros(batch_size, hidden_size)
	c = torch.zeros(batch_size, hidden_size)
	for s in range(seq_len):
		combined = torch.matmul(X[:, s, :], Wx) + torch.matmul(h, Wh) + b
		f, i, g, o = torch.chunk(combined, 4, dim=-1)
		c_next = c * torch.sigmoid(f) + torch.tanh(g) * torch.sigmoid(i)
		h_next = torch.sigmoid(o) * torch.tanh(c_next)
		c = c_next
		h = h_next
	
	return h
	
# 示例

X = torch.normal(0, 1, (10, 28, 28))
MyLSTM(X, Wx, Wh, b).shape

torch.Size([10, 128])

In [62]:
epochs = 10
lr = 1e-2


import torch.nn.functional as F

W_out = torch.randn(hidden_size, output_size, requires_grad=True)
b_out = torch.zeros(1, output_size, requires_grad=True)


for epoch in range(epochs):
	total_loss = 0
	correct = 0
	for X, y in train_loader:
		# print(X.shape)
		h = MyLSTM(X.squeeze(), Wx, Wh, b)
		logits = torch.matmul(h, W_out) + b_out
		loss = F.cross_entropy(logits, y)
			
		loss.backward()

		# 梯度下降
		with torch.no_grad():
			for param in [Wx, Wh, b, W_out, b_out]:
				param -= lr * param.grad
				param.grad.zero_()

		total_loss += loss.item() * len(X)
		correct += (logits.argmax(dim=1) == y).sum().item()

	acc = correct / len(train_data)
	print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_data) :.4f}, Train Acc: {acc:.4f}")

Epoch 1, Loss: 2.9985, Train Acc: 0.2182
Epoch 2, Loss: 1.9577, Train Acc: 0.2971
Epoch 3, Loss: 1.9469, Train Acc: 0.2992
Epoch 4, Loss: 1.8435, Train Acc: 0.3454
Epoch 5, Loss: 1.7690, Train Acc: 0.3740
Epoch 6, Loss: 1.6905, Train Acc: 0.4091
Epoch 7, Loss: 1.5922, Train Acc: 0.4459
Epoch 8, Loss: 1.4930, Train Acc: 0.4814
Epoch 9, Loss: 1.4151, Train Acc: 0.5077
Epoch 10, Loss: 1.3520, Train Acc: 0.5291


In [80]:
input_size = 28
seq_len = 28
hidden_size = 128
output_size = 10


epochs = 10
lr = 1e-2

class RNNClassifier(nn.Module):
	def __init__(self):
		super().__init__()
		self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
		self.fc = nn.Linear(hidden_size, output_size)

	def forward(self, x):
		# x: (batch, seq_len, input_size)
		outputs, _ = self.rnn(x)  # outputs: (batch, seq_len, hidden_size)
		last_output = outputs[:, -1, :]  # (batch, hidden_size)
		return self.fc(last_output)     # (batch, output_size)

model = RNNClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
	total_loss = 0
	correct = 0
	for X, y in train_loader:
		X = X.squeeze()
		logits = model(X)
		# print(logits.shape)
		loss = loss_fn(logits, y)

		loss.backward()
		optimizer.step()
		optimizer.zero_grad()

		total_loss += loss.item() * len(X)
		correct += (logits.argmax(dim=1) == y).sum().item()

	acc = correct / len(train_data)
	print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_data) :.4f}, Train Acc: {acc:.4f}")

Epoch 1, Loss: 0.4085, Train Acc: 0.8670
Epoch 2, Loss: 0.1038, Train Acc: 0.9704
Epoch 3, Loss: 0.0771, Train Acc: 0.9776
Epoch 4, Loss: 0.0700, Train Acc: 0.9797
Epoch 5, Loss: 0.0607, Train Acc: 0.9822
Epoch 6, Loss: 0.0473, Train Acc: 0.9859
Epoch 7, Loss: 0.0508, Train Acc: 0.9849
Epoch 8, Loss: 0.0403, Train Acc: 0.9883
Epoch 9, Loss: 0.0393, Train Acc: 0.9880
Epoch 10, Loss: 0.0394, Train Acc: 0.9879
