In [None]:
import os

# Point the Hugging Face cache root at a scratch directory. This cell sits
# before the `datasets` import so the variable is already set when that
# library is loaded.
# NOTE(review): the path is specific to one machine/user — adjust as needed.
HF_CACHE_DIR = "/tmp/wendler2/.hfcache"
os.environ["HF_HOME"] = HF_CACHE_DIR

In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from matplotlib import pyplot as plt

# Load MNIST through Hugging Face `datasets`; the result is indexed by split
# name ("train" / "test") further down.
MNIST_REPO_ID = "ylecun/mnist"
ds = load_dataset(MNIST_REPO_ID)

# Custom PyTorch Dataset
class MNISTDataset(Dataset):
    """Map-style PyTorch dataset over one split of the Hugging Face MNIST data.

    Indexing yields an ``(image, label)`` pair:

    * ``image`` -- float32 tensor of shape ``(1, H, W)``; 8-bit grayscale
      pixel values are rescaled from 0..255 to the range [-1, 1].
    * ``label`` -- scalar int64 tensor holding the record's ``"label"`` field.
    """

    def __init__(self, hf_split):
        # The split is stored as-is; records are decoded lazily per item.
        self.data = hf_split

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        # Grayscale pixels as float32: 0..255 -> 0..1 -> -1..1.
        pixels = np.array(record["image"].convert("L"))
        image = torch.tensor(pixels, dtype=torch.float32).div(255.0).mul(2).sub(1)

        # Prepend the channel axis: (H, W) -> (1, H, W).
        return image.unsqueeze(0), torch.tensor(record["label"], dtype=torch.long)

# Wrap both Hugging Face splits as PyTorch datasets.
train_dataset, test_dataset = MNISTDataset(ds["train"]), MNISTDataset(ds["test"])

# Mini-batch loaders; only the training stream is shuffled.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)


In [None]:
# Pull one (images, labels) batch from the training loader for inspection.
first_batch = next(iter(train_loader))
x, y = first_batch

In [None]:
from torch import nn
import torch as t
from dit import DiT

In [None]:


from torch.optim import AdamW

# Class-conditional DiT for MNIST.
# n_classes=11: the training loop shifts the ten digit labels to 1..10 and
# reserves id 0 for dropped ("no class") labels, giving 11 class ids in total.
# NOTE(review): the two leading 28s are presumably the input height/width —
# confirm against dit.DiT.
model = DiT(28, 28, patch_size=2, n_classes=11, d=12*16, n_head=12, n_blocks=6)

# Move the model to the GPU *before* building the optimizer, so the optimizer
# is constructed over parameters that already live on their final device.
model.cuda()
opt = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))

# Keep the module summary as the cell output (previously shown via the
# return value of `model.cuda()`).
model


In [None]:
from tqdm import tqdm

# Training hyper-parameters (previously inline magic numbers).
N_EPOCHS = 10
LABEL_DROP_PROB = 0.2   # fraction of labels replaced by the null class (id 0)
MAX_GRAD_NORM = 10.0    # gradient-norm clipping threshold

# Follow whatever device the model lives on instead of hard-coding `.cuda()`.
device = next(model.parameters()).device
model.train()

for _ in range(N_EPOCHS):
  pbar = tqdm(train_loader)
  for x, y in pbar:
    x = x.to(device)
    # Digit labels 0..9 become class ids 1..10; id 0 is the null class.
    y = y.to(device) + 1
    # Randomly drop labels so the model also learns the unconditional case.
    # The mask is drawn on the same device as `y`; a CPU mask would force an
    # implicit host-to-device copy on every step.
    y[t.rand(y.shape[0], device=device) < LABEL_DROP_PROB] = 0

    # Per-sample timestep in (0, 1): sigmoid of a standard normal draw.
    ts = t.sigmoid(t.randn(x.shape[0], device=device, dtype=x.dtype))
    z = t.randn_like(x)
    v_true = x - z
    # x - ts*(x - z) == (1 - ts)*x + ts*z, so ts=0 is data and ts=1 is noise.
    x_t = x - ts[:, None, None, None]*v_true
    v_pred = model(x_t, y, ts)
    loss = nn.functional.mse_loss(v_pred, v_true)

    loss.backward()
    t.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    pbar.set_postfix(loss=loss.item())
    opt.step()
    opt.zero_grad()


In [None]:
# Optional: uncomment the lines below to save the trained weights to
# checkpoints/model.pt.
#import os
#os.makedirs("checkpoints", exist_ok=True)
#t.save(model.state_dict(), "checkpoints/model.pt")

In [None]:
@t.no_grad()
def sample(model, z, y, n_steps=10, cfg=0, shift=3.0):
  """Integrate the learned velocity field with Euler steps, starting from `z`.

  Timesteps run from 1 down to 0 (in this notebook's training loop, t=1 is
  pure noise and t=0 is data) and are warped by the SD3-style shift
  ``shift*t / ((shift-1)*t + 1)``.

  Args:
    model: callable ``model(z, y, ts)`` returning a velocity shaped like `z`;
      `ts` is a 1-D tensor with one timestep per batch element.
    z: starting tensor; its first dimension is the batch size.
    y: class-id tensor passed to the model. Id 0 is treated as the null
      (unconditional) class when guidance is enabled.
    n_steps: number of Euler steps. With 0 steps `z` is returned unchanged.
    cfg: classifier-free guidance scale. Values <= 0 disable guidance and
      skip the second (unconditional) forward pass.
    shift: timestep-shift factor of the schedule. The default 3.0 reproduces
      the previously hard-coded ``3*t / (2*t + 1)``; 1.0 gives uniform steps.

  Returns:
    The tensor reached after `n_steps` Euler updates.
  """
  # Sample in eval mode (e.g. so dropout is off), then restore each
  # submodule's previous mode. Plain callables are used as-is.
  prev_modes = []
  if isinstance(model, t.nn.Module):
    prev_modes = [(m, m.training) for m in model.modules()]
    model.eval()
  try:
    ts = t.linspace(1, 0, n_steps+1, device=z.device, dtype=z.dtype)
    ts = shift*ts / ((shift-1)*ts + 1) # sd3 scheduler
    y_null = t.zeros_like(y)  # all-null labels for the unconditional branch
    for idx in range(n_steps):
      # Same timestep for every batch element.
      t_batch = ts[idx]*t.ones(z.shape[0], dtype=z.dtype, device=z.device)
      v_pred = model(z, y, t_batch)
      if cfg > 0:
        v_uncond = model(z, y_null, t_batch)
        # Extrapolate from the unconditional towards the conditional velocity.
        v_pred = v_uncond + cfg*(v_pred - v_uncond)
      # ts decreases, so (ts[idx] - ts[idx+1]) is the positive step size.
      z = z + (ts[idx]-ts[idx+1])*v_pred
    return z
  finally:
    for m, was_training in prev_modes:
      m.training = was_training

In [None]:
# Sample a 4x8 grid of images per class id, with classifier-free guidance.
# Noise and labels are built explicitly rather than derived from whatever
# `x` / `y` batch happens to be left over from an earlier cell (which only
# worked when that batch had exactly 32 items on the model's device).
N_ROWS, N_COLS = 4, 8
N_SAMPLES = N_ROWS * N_COLS
device = next(model.parameters()).device

# The same starting noise is reused for every class id.
z = t.randn(N_SAMPLES, 1, 28, 28, device=device)
for num in range(11):
  print(num)
  # Class id 0 is the null class, so guidance has nothing to contrast
  # against there; ids 1..10 (digits 0..9) use a guidance scale of 3.
  cfg = 3 if num > 0 else 0
  labels = t.full((N_SAMPLES,), num, dtype=t.long, device=device)
  x_pred = sample(model, z, labels, cfg=cfg, n_steps=30)

  fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(8, 4))
  for ax, img in zip(axes.flatten(), x_pred):
      img = img.squeeze().detach().cpu().clamp(-1, 1)
      img = (img + 1) / 2  # map back to [0,1]

      ax.imshow(img, cmap='gray')
      ax.axis('off')

  plt.tight_layout()
  plt.show()


In [None]:
# Sample a 4x8 grid of images per class id, without classifier-free guidance.
# Noise and labels are built explicitly rather than derived from whatever
# `x` / `y` batch happens to be left over from an earlier cell (which only
# worked when that batch had exactly 32 items on the model's device).
N_ROWS, N_COLS = 4, 8
N_SAMPLES = N_ROWS * N_COLS
device = next(model.parameters()).device

# Guidance is disabled for every class id (the former if/else set 0 in both
# branches).
cfg = 0

# The same starting noise is reused for every class id.
z = t.randn(N_SAMPLES, 1, 28, 28, device=device)
for num in range(11):
  print(num)
  labels = t.full((N_SAMPLES,), num, dtype=t.long, device=device)
  x_pred = sample(model, z, labels, cfg=cfg, n_steps=30)

  fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(8, 4))
  for ax, img in zip(axes.flatten(), x_pred):
      img = img.squeeze().detach().cpu().clamp(-1, 1)
      img = (img + 1) / 2  # map back to [0,1]

      ax.imshow(img, cmap='gray')
      ax.axis('off')

  plt.tight_layout()
  plt.show()

