In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Fix the global RNG seed so weight draws and random tokens repeat on a
# fresh Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [7]:
def tensor_stats(name, t):
    """Print the mean, std, min and max of tensor ``t`` under the heading ``name``.

    Each statistic is computed with the tensor method of the same name,
    converted to a Python scalar with ``.item()`` and printed with six
    decimal places. A blank line is printed after the block. Returns None.
    """
    print(f"{name}:")
    for stat in ("mean", "std", "min", "max"):
        # Reduce lazily, one statistic at a time, right before printing it.
        value = getattr(t, stat)().item()
        # Pad "<stat>:" to five characters so the numbers line up in a column.
        print(f"  {stat + ':':<5} {value:.6f}")
    print()

In [8]:
class TinyLM(nn.Module):
    """Smallest possible language model: token embedding -> vocabulary projection.

    There are no hidden layers; the model exists to study how weight
    initialization and embedding/head weight tying affect the initial logits.

    Args:
        vocab_size: number of token ids (rows of the embedding, outputs of the head).
        d_model: embedding width.
        tie_weights: when True, the head uses the very same parameter as the
            embedding instead of its own matrix.
        init_mode: ``"gpt"`` re-initializes the weights from a zero-mean normal
            with standard deviation ``std``; any other value keeps the
            initialization the layers receive on construction.
        std: standard deviation used by the ``"gpt"`` initialization.
    """

    def __init__(self, vocab_size=24000, d_model=256, tie_weights=False, init_mode="default", std=0.02):
        super().__init__()

        # Only read by _init_weights, i.e. only relevant when init_mode == "gpt".
        self.std = std

        # Embedding first, head second: keeps the sequence of random draws
        # made while constructing and initializing the layers unchanged.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Optional custom GPT-style init, applied to every submodule.
        use_gpt_init = init_mode == "gpt"
        if use_gpt_init:
            self.apply(self._init_weights)

        # Optional tying: done after initialization so the shared matrix is
        # whatever the embedding ended up with.
        if tie_weights:
            self.head.weight = self.embedding.weight

    def _init_weights(self, module):
        """Redraw ``module.weight`` from N(0, self.std**2) for Linear/Embedding layers."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=self.std)

    def forward(self, idx):
        """Map token ids ``idx`` to logits by embedding them and applying the head."""
        return self.head(self.embedding(idx))


In [9]:
def run_experiment(init_mode="default", tie_weights=False, std=0.02, *,
                   vocab_size=24000, d_model=256, batch=16, seq_len=32, seed=None):
    """Build a TinyLM, print its weight/logit statistics and its initial loss.

    The model is fed uniformly random token ids and scored against uniformly
    random targets, so the printed loss reflects only the initialization.
    For reference, a model that predicts a uniform distribution has a
    cross-entropy of ln(vocab_size), about 10.09 for the default 24000.

    Args:
        init_mode: forwarded to TinyLM; "gpt" enables the normal(0, std) init.
        tie_weights: forwarded to TinyLM; share the head and embedding matrix.
        std: forwarded to TinyLM; only has an effect when init_mode == "gpt".
        vocab_size: vocabulary size of the model and range of the random tokens.
        d_model: embedding width.
        batch: number of random sequences.
        seq_len: tokens per random sequence.
        seed: when not None, re-seeds the global torch RNG before anything is
            drawn, so the result no longer depends on which cells ran before.

    Returns:
        None. All results are printed.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Include std in the banner whenever it is actually used; otherwise two
    # "gpt" runs that differ only in std print identical headers.
    header = f"Init: {init_mode} | Tied: {tie_weights}"
    if init_mode == "gpt":
        header += f" | Std: {std}"

    print("="*60)
    print(header)
    print("="*60)

    model = TinyLM(
        vocab_size=vocab_size,
        d_model=d_model,
        tie_weights=tie_weights,
        init_mode=init_mode,
        std=std
    ).to(device)

    # Nothing below is trained, so skip autograd bookkeeping for the
    # statistics, the forward pass and the loss.
    with torch.no_grad():
        # Inspect weight stats
        tensor_stats("Embedding weight", model.embedding.weight)
        tensor_stats("Head weight", model.head.weight)

        # Random tokens. They are sampled first and moved to `device` afterwards,
        # exactly as before, so the sampling itself is unchanged.
        idx = torch.randint(0, vocab_size, (batch, seq_len)).to(device)
        targets = torch.randint(0, vocab_size, (batch, seq_len)).to(device)

        logits = model(idx)

        tensor_stats("Logits", logits)

        # Flatten (batch, seq_len) into one axis: one prediction per token.
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            targets.view(-1)
        )

    print(f"Initial cross-entropy loss: {loss.item():.4f}")

In [10]:
# 1️⃣ Default init, no tying
# Baseline: both layers keep the initialization they get on construction and
# have separate weight matrices.
run_experiment(init_mode="default", tie_weights=False)

Init: default | Tied: False
Embedding weight:
  mean: -0.000112
  std:  0.999893
  min:  -4.866143
  max:  5.105051

Head weight:
  mean: 0.000013
  std:  0.036081
  min:  -0.062500
  max:  0.062500

Logits:
  mean: -0.000117
  std:  0.577106
  min:  -2.990200
  max:  2.994646

Initial cross-entropy loss: 10.2468


In [11]:
# 2️⃣ Default init, tied
# Same construction-time initialization, but the head shares the embedding matrix.
run_experiment(init_mode="default", tie_weights=True)

Init: default | Tied: True
Embedding weight:
  mean: 0.000253
  std:  0.999532
  min:  -4.943582
  max:  5.067348

Head weight:
  mean: 0.000253
  std:  0.999532
  min:  -4.943582
  max:  5.067348

Logits:
  mean: 0.018177
  std:  16.075432
  min:  -83.414558
  max:  322.034058

Initial cross-entropy loss: 256.9717


In [12]:
# 3️⃣ GPT init, no tying
# Both matrices are redrawn from a zero-mean normal with the default std=0.02.
run_experiment(init_mode="gpt", tie_weights=False)

Init: gpt | Tied: False
Embedding weight:
  mean: -0.000001
  std:  0.020006
  min:  -0.099682
  max:  0.104341

Head weight:
  mean: 0.000005
  std:  0.020001
  min:  -0.099009
  max:  0.106684

Logits:
  mean: 0.000000
  std:  0.006390
  min:  -0.032634
  max:  0.034525

Initial cross-entropy loss: 10.0852


In [13]:
# 4️⃣ GPT init, tied
# Normal init with the default std=0.02, with the head sharing the embedding matrix.
run_experiment(init_mode="gpt", tie_weights=True)

Init: gpt | Tied: True
Embedding weight:
  mean: -0.000017
  std:  0.019997
  min:  -0.098965
  max:  0.099343

Head weight:
  mean: -0.000017
  std:  0.019997
  min:  -0.098965
  max:  0.099343

Logits:
  mean: 0.000006
  std:  0.006420
  min:  -0.033634
  max:  0.126838

Initial cross-entropy loss: 10.0858


In [14]:
# 5️⃣ GPT init, tied, big weights
# Same as the previous run but with std=0.2, ten times the default of 0.02.
run_experiment(init_mode="gpt", tie_weights=True, std=0.2)

Init: gpt | Tied: True
Embedding weight:
  mean: -0.000049
  std:  0.199923
  min:  -1.093782
  max:  1.081741

Head weight:
  mean: -0.000049
  std:  0.199923
  min:  -1.093782
  max:  1.081741

Logits:
  mean: 0.000782
  std:  0.641566
  min:  -3.383894
  max:  13.013767

Initial cross-entropy loss: 11.0418
