In [1]:
# --- Notebook A: Generate and save OOD datasets ---
import math
from pathlib import Path
import torch
import torch.nn.functional as F

# ---------- generators: return float tensor [H,W] in [0,1] ----------
def gen_gaussian(g, H=28, W=28):
    """Return an [H, W] float image of clipped Gaussian noise.

    Each pixel is drawn independently from a normal distribution with
    mean 0.5 and standard deviation 0.25 using generator ``g``, then
    clipped into [0, 1].
    """
    noise = torch.randn((H, W), generator=g)
    shifted = noise.mul(0.25).add(0.5)
    return torch.clamp(shifted, min=0.0, max=1.0)

def gen_horizontal_lines(g, H=28, W=28):
    """Return an [H, W] float image of a few blurred horizontal stripes.

    Using generator ``g`` the function draws, in this order: the number of
    stripes (3..9), a shared thickness (1..2 rows), and then for every stripe
    its top row and its intensity in [0.7, 1.0). Stripes that run past the
    bottom edge are simply cut off by the slice. The canvas is then smoothed
    with a normalised 3x3 kernel (zero padding) and clipped to [0, 1].
    """
    def draw_int(lo, hi):
        # Single integer in [lo, hi) taken from the caller's generator.
        return torch.randint(lo, hi, (1,), generator=g).item()

    n_lines = draw_int(3, 10)
    thickness = draw_int(1, 3)

    canvas = torch.zeros(H, W)
    for _ in range(n_lines):
        top = draw_int(0, H)
        level = 0.7 + 0.3 * torch.rand((), generator=g).item()
        canvas[top:top + thickness] = level

    blur = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
    blur = blur / blur.sum()
    smoothed = F.conv2d(canvas[None, None], blur[None, None], padding=1)
    return smoothed.squeeze().clamp(0.0, 1.0)

def gen_spiral(g, H=28, W=28):
    """Return an [H, W] float image of a blurred three-turn spiral.

    The curve starts at the image centre and its radius grows linearly with
    the angle up to ``min(H, W)/2 - 2`` after three full turns, sampled at
    800 angles. Pixels hit by the curve are set to 1.0, their 8-neighbours
    to at least 0.8, and the result is smoothed with a plus-shaped 3x3
    kernel (zero padding) and clipped to [0, 1].

    Args:
        g: accepted so the signature matches the other generators, but never
            used here: the output is fully deterministic, so every call with
            the same ``H``/``W`` returns the same image.
        H, W: image height and width in pixels.

    Returns:
        Float tensor of shape [H, W] with values in [0, 1].
    """
    cx, cy = (W - 1) / 2.0, (H - 1) / 2.0
    turns, T = 3.0, 800
    sweep = turns * 2 * math.pi
    max_r = min(H, W) / 2.0 - 2.0

    # Rasterise the curve with plain Python numbers; touching tensor elements
    # one by one inside this loop is what made the naive version slow.
    core = set()
    for tf in torch.linspace(0, sweep, T).tolist():
        r = (tf / sweep) * max_r
        ix = int(round(cx + r * math.cos(tf)))
        iy = int(round(cy + r * math.sin(tf)))
        if 0 <= ix < W and 0 <= iy < H:
            core.add((iy, ix))

    # 3x3 neighbourhood of every curve pixel, clipped to the image bounds.
    halo = {
        (iy + dy, ix + dx)
        for iy, ix in core
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if 0 <= iy + dy < H and 0 <= ix + dx < W
    }

    x = torch.zeros(H, W)
    # Paint the halo first so curve pixels end up at 1.0 (halo includes them).
    for pixels, level in ((halo, 0.8), (core, 1.0)):
        if pixels:
            idx = torch.tensor(sorted(pixels), dtype=torch.long)
            x[idx[:, 0], idx[:, 1]] = level

    k = torch.tensor([[0, 1, 0], [1, 4, 1], [0, 1, 0]], dtype=torch.float32)
    k = k / k.sum()
    x = F.conv2d(x.view(1, 1, H, W), k.view(1, 1, 3, 3), padding=1).squeeze()
    return x.clamp(0.0, 1.0)

def gen_checkerboard(g, H=28, W=28, cell=None):
    """Return an [H, W] float image of a softly blurred checkerboard.

    Random draws from ``g`` happen in a fixed order: the cell size (2..5 px,
    only when ``cell`` is falsy), the x phase bit, the y phase bit, the bright
    level in [0.8, 1.0) and the dark level in [0.05, 0.2). Tiles whose parity
    is even get the bright level, the rest the dark level; the board is then
    smoothed with a normalised 3x3 kernel (zero padding) and clipped to [0, 1].
    """
    def draw_int(lo, hi):
        # Single integer in [lo, hi) taken from the caller's generator.
        return torch.randint(lo, hi, (1,), generator=g).item()

    def draw_unit():
        # Single float in [0, 1) taken from the caller's generator.
        return torch.rand((), generator=g).item()

    if not cell:
        cell = draw_int(2, 6)  # 2..5 px
    phase_x = draw_int(0, 2)
    phase_y = draw_int(0, 2)

    # Tile index along each axis; broadcasting builds the [H, W] parity map.
    col_tile = torch.arange(W) // cell + phase_x
    row_tile = torch.arange(H) // cell + phase_y
    parity = (col_tile.unsqueeze(0) + row_tile.unsqueeze(1)) % 2

    bright = 0.8 + 0.2 * draw_unit()
    dark = 0.05 + 0.15 * draw_unit()
    board = torch.full((H, W), dark, dtype=torch.float32)
    board[parity == 0] = bright

    blur = torch.tensor([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]], dtype=torch.float32)
    blur = blur / blur.sum()
    board = F.conv2d(board[None, None], blur[None, None], padding=1).squeeze()
    return board.clamp(0.0, 1.0)

# Registry of pattern generators, keyed by the name used for `kind`.
GENS = dict(
    gaussian=gen_gaussian,
    horizontal_lines=gen_horizontal_lines,
    spiral=gen_spiral,
    checkerboard=gen_checkerboard,
)

# ---------- generation helper ----------
def make_split(kind, n, seed):
    """Generate ``n`` 28x28 images of one pattern family as a uint8 batch.

    Args:
        kind: key into ``GENS`` selecting the generator function.
        n: number of images to generate; ``0`` yields an empty batch.
        seed: seed for the ``torch.Generator`` shared by all ``n`` calls.

    Returns:
        ``(data_u8, labels)`` where ``data_u8`` is a uint8 tensor of shape
        [n, 28, 28] with values 0..255 and ``labels`` is a long tensor of
        shape [n] filled with -1.

    Raises:
        KeyError: if ``kind`` is not a key of ``GENS``.
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    gen = GENS[kind]                                        # look up once, fail early on a bad kind
    g = torch.Generator().manual_seed(seed)
    imgs = [gen(g, H=28, W=28) for _ in range(n)]           # each [28,28] float in [0,1]
    if imgs:
        batch = torch.stack(imgs, dim=0)                    # [N,28,28]
    else:
        batch = torch.zeros(0, 28, 28)                      # stack/cat reject an empty list
    data_u8 = (batch * 255.0).round().clamp(0, 255).to(torch.uint8)
    labels = torch.full((n,), -1, dtype=torch.long)         # all -1 for OOD
    return data_u8, labels

# ---------- configure sizes & seeds ----------
root = Path("./data/ood")
root.mkdir(parents=True, exist_ok=True)

per_kind = 1000   # number of samples generated for every pattern family
base_seed = 0     # family number i is seeded with base_seed + 97*i

# ---------- one (data, targets) file per pattern family ----------
paths = {}
all_imgs, all_labels = [], []

for i, kind in enumerate(("gaussian", "horizontal_lines", "spiral", "checkerboard")):
    seed = 97 * i + base_seed
    data_u8, labels = make_split(kind, per_kind, seed)
    path = root.joinpath(f"ood_{kind}.pt")
    torch.save((data_u8, labels), path)   # stored as a (data, targets) tuple
    paths[kind] = str(path)
    all_imgs += [data_u8]
    all_labels += [labels]
    print(f"Saved {kind:>16s} -> {path}  ({len(data_u8)} samples)")

# ---------- all families concatenated into a single file ----------
data_all, labels_all = torch.cat(all_imgs, dim=0), torch.cat(all_labels, dim=0)
combined_path = root.joinpath("ood_all.pt")
torch.save((data_all, labels_all), combined_path)
print(f"Saved {'combined':>16s} -> {combined_path}  ({len(data_all)} samples)")


Saved         gaussian -> data/ood/ood_gaussian.pt  (1000 samples)
Saved horizontal_lines -> data/ood/ood_horizontal_lines.pt  (1000 samples)
Saved           spiral -> data/ood/ood_spiral.pt  (1000 samples)
Saved     checkerboard -> data/ood/ood_checkerboard.pt  (1000 samples)
Saved         combined -> data/ood/ood_all.pt  (4000 samples)
