In [None]:
# ============================================================
# Cell 1 — Imports & Environment Check
# ============================================================
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random, math, time, os, sys
print("Torch version:", torch.__version__)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)
# Seed the RNGs up front: torch drives both data generation (CopyTaskDataset)
# and weight init; Python's `random` is seeded too in case later cells use it.
random.seed(0)
torch.manual_seed(0)


In [None]:
# ============================================================
# Cell 2 — Copy Task Dataset
# ============================================================
class CopyTaskDataset(Dataset):
    """Synthetic copy task.

    Each sample is a pair ``(x, y)`` of int64 tensors of length ``2 * seq_len + 1``::

        x = [s_0 .. s_{L-1}] + [8] + [8] * L      (symbols, delimiter, padding)
        y = [8] * (L + 1)    + [s_0 .. s_{L-1}]   (pad until the delimiter has
                                                   passed, then the symbols)

    with ``L = seq_len``. Symbols are drawn from 0..7; token 8 doubles as
    delimiter and padding, so the vocabulary size is 9.

    Args:
        seq_len: number of random symbols to copy.
        num_samples: dataset length reported by ``__len__``.
        seed: ``None`` (default) draws fresh symbols from the global torch RNG
            on every access, so ``ds[i]`` differs between calls (an endless
            stream of new samples). An int makes sample ``idx`` a deterministic
            function of ``seed + idx``, giving a fixed, reproducible dataset
            (e.g. for validation).
    """
    def __init__(self, seq_len=10, num_samples=20000, seed=None):
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.seed = seed
        self.vocab = 9  # tokens 0-7 are symbols, 8 is delimiter/pad

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # A per-index generator makes sample `idx` reproducible when a seed is
        # given; generator=None falls back to the global RNG.
        gen = None if self.seed is None else torch.Generator().manual_seed(self.seed + idx)
        symbols = torch.randint(0, 8, (self.seq_len,), dtype=torch.long, generator=gen)

        # x[seq_len] (delimiter) and x[seq_len+1:] (padding) keep the fill value 8.
        x = torch.full((2 * self.seq_len + 1,), 8, dtype=torch.long)
        x[:self.seq_len] = symbols

        # The first seq_len + 1 targets are the pad token; the rest echo the symbols.
        y = torch.full_like(x, 8)
        y[self.seq_len + 1:] = symbols
        return x, y

# Build the datasets and their DataLoaders.
SEQ_LEN = 10   # number of symbols the model has to copy
BATCH = 128

train_ds = CopyTaskDataset(seq_len=SEQ_LEN, num_samples=20000)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True)

# NOTE(review): CopyTaskDataset draws fresh random symbols on every
# __getitem__ call here, so each pass over val_loader sees a new random sample.
val_ds = CopyTaskDataset(seq_len=SEQ_LEN, num_samples=3000)
val_loader = DataLoader(val_ds, batch_size=BATCH)


In [None]:
# ============================================================
# Cell 3 — Config & Utility
# ============================================================
cfg = dict(
    input_dim=9,       # one-hot width = vocab size (symbols 0-7 + delimiter/pad 8)
    hidden_dim=128,    # RNN hidden size
    output_dim=9,      # one logit per vocabulary token
    epochs=20,
    lr=1e-3,           # Adam learning rate
    proj_every=1,      # QR projection every N optimizer steps; 0 disables it
)

def one_hot(x, num_classes):
    """Encode integer tensor `x` as float32 one-hot vectors (new last dim of size `num_classes`)."""
    encoded = F.one_hot(x, num_classes=num_classes)
    return encoded.to(torch.float32)

def accuracy(pred, target):
    """Fraction of positions where argmax over the last dim of `pred` equals `target`, as a Python float."""
    correct = torch.eq(pred.argmax(dim=-1), target)
    return correct.float().mean().item()


In [None]:
# ============================================================
# Cell 4 — Model Definitions
# ============================================================
class VanillaRNN(nn.Module):
    """Single-layer tanh RNN over one-hot tokens with a linear read-out per step.

    Args:
        cfg: dict with keys ``input_dim`` (vocabulary size / one-hot width),
            ``hidden_dim`` and ``output_dim`` (number of output classes).
    """
    def __init__(self, cfg):
        super().__init__()
        # Keep the one-hot width on the instance so forward() uses this model's
        # config rather than whatever the notebook-global `cfg` currently holds.
        self.input_dim = cfg["input_dim"]
        self.rnn = nn.RNN(cfg["input_dim"], cfg["hidden_dim"], batch_first=True, nonlinearity="tanh")
        self.fc  = nn.Linear(cfg["hidden_dim"], cfg["output_dim"])

    def forward(self, x):
        """Map token ids to per-step logits.

        Args:
            x: (B, T) int64 token ids in ``[0, input_dim)``.
        Returns:
            (B, T, output_dim) logits.
        """
        x_onehot = F.one_hot(x, num_classes=self.input_dim).float()
        out, _ = self.rnn(x_onehot)
        return self.fc(out)


def orthogonalize_weight(W):
    """Overwrite square matrix ``W`` in place with the Q factor of its QR decomposition.

    Computes ``W = Q R`` and copies ``Q`` back into ``W`` after flipping column
    signs so the matching ``R`` has a non-negative diagonal (this makes the
    factor unique for full-rank ``W``). Note: this is the QR orthogonalization,
    not the nearest orthogonal matrix in Frobenius norm (that is the polar
    factor ``U V^T`` of an SVD).

    Args:
        W: (H, H) tensor, e.g. an ``nn.Parameter``; modified in place under
            ``torch.no_grad()`` so autograd does not track the projection.
    """
    with torch.no_grad():
        Q, R = torch.linalg.qr(W)
        signs = torch.sign(torch.diagonal(R))
        # sign(0) == 0 would zero out a column of Q when W is rank-deficient and
        # break orthogonality; treat zero pivots as positive instead.
        signs = torch.where(signs == 0, torch.ones_like(signs), signs)
        W.copy_(Q * signs.unsqueeze(0))  # scale column j of Q by signs[j]

class QRRNN(VanillaRNN):
    """VanillaRNN whose recurrent matrix ``rnn.weight_hh_l0`` is periodically QR-orthogonalized.

    The training loops in this notebook call ``project_if_needed()`` right
    after each ``optimizer.step()``.

    Args:
        cfg: same config dict as VanillaRNN.
        proj_every: project on calls 0, k, 2k, ... of ``project_if_needed``
            (k = proj_every); values <= 0 disable projection.
    """
    def __init__(self, cfg, proj_every=1):
        super().__init__(cfg)
        self.proj_every = proj_every
        self.step_count = 0  # number of project_if_needed() calls so far

    def project_if_needed(self):
        """Orthogonalize W_hh if this call is on the schedule, then advance the call counter."""
        on_schedule = self.proj_every > 0 and self.step_count % self.proj_every == 0
        if on_schedule:
            orthogonalize_weight(self.rnn.weight_hh_l0)
        self.step_count += 1


In [None]:
# ============================================================
# Cell 5 — Training & Evaluation Loops
# ============================================================
def run_epoch(model, loader, optimizer=None):
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    Args:
        model: module mapping (B, T) token ids to (B, T, C) logits.
        loader: iterable of (x, y) batches of token ids.
        optimizer: if not None, parameters are updated each batch (gradient
            norm clipped to 1.0; for QRRNN the QR projection follows each step).
    Returns:
        (loss, accuracy): per-batch mean loss / token accuracy, averaged with
        weights equal to the batch size.
    """
    is_train = optimizer is not None
    total_loss, total_acc, n = 0.0, 0.0, 0
    criterion = nn.CrossEntropyLoss()
    model.train(is_train)

    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(is_train):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            logits = model(x)                    # (B, T, C)
            # Flatten time into the batch dim; C comes from the logits, not the global cfg.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # guard against exploding gradients
                optimizer.step()
                if isinstance(model, QRRNN):
                    model.project_if_needed()    # re-orthogonalize W_hh after the update

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            total_acc  += accuracy(logits, y) * batch_size
            n          += batch_size
    return total_loss / n, total_acc / n


def train_model(model, name):
    """Fit `model` with Adam for cfg["epochs"] epochs, validating after each epoch.

    Args:
        model: module to train; moved to DEVICE in place.
        name: label printed at the start of each per-epoch log line.
    Returns:
        (train_hist, val_hist): lists of per-epoch (loss, accuracy) tuples
        produced by run_epoch.
    """
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    train_hist = []
    val_hist = []
    for epoch in range(cfg["epochs"]):
        # One weight-updating pass over the training data, then one evaluation pass.
        tr_loss, tr_acc = run_epoch(model, train_loader, optimizer)
        va_loss, va_acc = run_epoch(model, val_loader)
        train_hist += [(tr_loss, tr_acc)]
        val_hist += [(va_loss, va_acc)]
        print(f"[{name}] Epoch {epoch+1:2d}/{cfg['epochs']}  "
              f"Train Loss {tr_loss:.3f} Acc {tr_acc:.3f} | "
              f"Val Loss {va_loss:.3f} Acc {va_acc:.3f}")
    return train_hist, val_hist


In [None]:
# ============================================================
# Cell 6 — Run Experiments
# ============================================================
# Re-seed before building and training each model. Both then start from
# identical weights and draw an identical stream of batches, so differences
# between the runs come from the QR projection rather than from RNG state
# (CUDA kernels may still add small nondeterminism).
torch.manual_seed(0)
vanilla = VanillaRNN(cfg)
hist_vanilla = train_model(vanilla, "Vanilla")

torch.manual_seed(0)
qr_rnn  = QRRNN(cfg, proj_every=cfg["proj_every"])
hist_qr = train_model(qr_rnn, "QR-RNN")


In [None]:
# ============================================================
# Cell 6b — Ablation: orthogonal initialization only (no projection)
# ============================================================
# Re-seed so this run's initialization does not depend on how much RNG state
# earlier cells happened to consume.
torch.manual_seed(0)
vanilla_o = VanillaRNN(cfg)
# W_hh starts orthogonal but, as a plain VanillaRNN, is never re-projected during training.
nn.init.orthogonal_(vanilla_o.rnn.weight_hh_l0)
hist_o = train_model(vanilla_o, "Ortho-Init-Only")


In [None]:
# ============================================================
# Cell 7 — Plot Results
# ============================================================
def extract(hist, idx):
    """Return column `idx` (0 = loss, 1 = accuracy) from a list of (loss, acc) tuples."""
    return list(map(lambda record: record[idx], hist))

epochs = range(1, cfg["epochs"]+1)
fig, axes = plt.subplots(1, 2, figsize=(12,4))

# (legend label, per-epoch validation history) for every run trained above,
# including the orthogonal-init-only ablation. Index [1] of a train_model
# result is its validation history.
val_runs = [
    ("Vanilla RNN (Val)",     hist_vanilla[1]),
    ("QR-RNN (Val)",          hist_qr[1]),
    ("Ortho-Init-Only (Val)", hist_o[1]),
]
# (axis, index into the (loss, acc) tuples, title, y-label)
panels = [
    (axes[0], 0, "Validation Loss",     "Cross-entropy loss"),
    (axes[1], 1, "Validation Accuracy", "Per-token accuracy"),
]
for ax, metric_idx, title, ylabel in panels:
    for label, run_val_hist in val_runs:
        ax.plot(epochs, extract(run_val_hist, metric_idx), label=label)
    ax.set_title(title); ax.set_xlabel("Epoch"); ax.set_ylabel(ylabel)
    ax.grid(); ax.legend()

plt.tight_layout(); plt.show()


In [None]:

def spectral_norm(W):
    """Largest singular value of `W` (its operator 2-norm), as a Python float."""
    singular_values = torch.linalg.svdvals(W)
    return torch.amax(singular_values).item()

def run_with_tracking(model, name, steps=100, lr=None):
    """Train `model` for `steps` optimizer steps, recording per-step diagnostics.

    No gradient clipping is applied here, so the recorded gradient norms are
    the raw, unclipped values.

    Args:
        model: a VanillaRNN / QRRNN (anything exposing ``model.rnn.weight_hh_l0``).
        name: run label; currently unused, kept for call-site symmetry with train_model.
        steps: number of batches / optimizer steps; ``train_loader`` is
            restarted whenever it is exhausted.
        lr: Adam learning rate; ``None`` (default) uses ``cfg["lr"]`` so this
            stays in sync with train_model.

    Returns:
        (grad_hist, spec_hist): per step, the global L2 norm over all parameter
        gradients (before the update) and the spectral norm of W_hh (after the
        update and, for QRRNN, after the projection).
    """
    model.to(DEVICE)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=cfg["lr"] if lr is None else lr)
    criterion = nn.CrossEntropyLoss()
    grad_hist, spec_hist = [], []

    loader_iter = iter(train_loader)
    for _ in range(steps):
        # Cycle through the loader as many times as needed.
        try:
            x, y = next(loader_iter)
        except StopIteration:
            loader_iter = iter(train_loader)
            x, y = next(loader_iter)
        x, y = x.to(DEVICE), y.to(DEVICE)

        logits = model(x)                                   # (B, T, C)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        opt.zero_grad()
        loss.backward()

        # Global gradient norm: sqrt(sum_p ||grad_p||^2) over all parameters.
        grad_hist.append(math.sqrt(sum(
            p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None
        )))

        opt.step()
        if isinstance(model, QRRNN):
            model.project_if_needed()

        spec_hist.append(spectral_norm(model.rnn.weight_hh_l0.detach().cpu()))

    return grad_hist, spec_hist


# Re-seed before each run so both models start from identical weights and draw
# identical batches; differences then come from the QR projection rather than
# from initialization or data order.
torch.manual_seed(0)
grad_v, spec_v = run_with_tracking(VanillaRNN(cfg), "Vanilla")
torch.manual_seed(0)
grad_q, spec_q = run_with_tracking(QRRNN(cfg),      "QR-RNN")

fig_grad, ax_grad = plt.subplots(figsize=(6, 4))
ax_grad.plot(grad_v, label="Vanilla")
ax_grad.plot(grad_q, label="QR")
ax_grad.set_title("Grad-Norm over Steps")
ax_grad.set_xlabel("Optimizer step"); ax_grad.set_ylabel("Global grad L2 norm (unclipped)")
ax_grad.set_yscale("log"); ax_grad.legend(); ax_grad.grid()
plt.show()

fig_spec, ax_spec = plt.subplots(figsize=(6, 4))
ax_spec.plot(spec_v, label="Vanilla")
ax_spec.plot(spec_q, label="QR")
ax_spec.set_title("Spectral Norm  |W_hh|₂ over Steps")
ax_spec.set_xlabel("Optimizer step"); ax_spec.set_ylabel("Largest singular value of W_hh")
ax_spec.legend(); ax_spec.grid()
plt.show()
