# MNIST drifting example

Run training from the repository root:

```bash
python examples/mnist_train.py --out_dir outputs/mnist_sdvae_resnet
```

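Then point `OUT_DIR` (below) at the same directory you passed as `--out_dir`, and run the cells to view the saved metrics and plots. A relative path is resolved against the notebook's working directory (usually `notebooks/`), not the repository root.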


In [1]:
from pathlib import Path
import glob
import json
import numpy as np
import matplotlib.pyplot as plt

# Training writes to `--out_dir`, interpreted relative to the directory the
# training command was launched from (the repository root, given
# `python examples/mnist_train.py ...`). This notebook usually runs with
# `notebooks/` as its working directory, so a bare relative path would point at
# notebooks/outputs/... . Try the path as given first, then one level up.
RUN_DIR = Path('outputs/mnist_sdvae_resnet')  # <- change if needed (same as --out_dir)
OUT_DIR_CANDIDATES = [RUN_DIR, Path('..') / RUN_DIR]
# Fall back to RUN_DIR itself so the "exists: False" diagnostic below still
# reports the path the user configured.
OUT_DIR = next((cand for cand in OUT_DIR_CANDIDATES if cand.exists()), RUN_DIR)
print('OUT_DIR:', OUT_DIR.resolve())
print('exists:', OUT_DIR.exists())


OUT_DIR: /Users/wkeely/Desktop/driftax/notebooks/outputs/mnist_sdvae_resnet
exists: False


In [2]:
# Load metrics written by the training run, if present.
# `metrics` is (re)assigned on every run of this cell: None when the file is
# missing. Re-running after the file disappears therefore cannot leave a stale
# dict from earlier kernel state, and later cells can test `metrics is None`
# instead of hitting a NameError.
METRICS_PREVIEW_CHARS = 2000  # cap on how much of the pretty-printed JSON to show
mfile = OUT_DIR / 'metrics.json'
if mfile.exists():
    metrics = json.loads(mfile.read_text())
    metrics_text = json.dumps(metrics, indent=2)
    print(metrics_text[:METRICS_PREVIEW_CHARS])
    if len(metrics_text) > METRICS_PREVIEW_CHARS:
        # Make truncation visible rather than silently cutting the dump.
        print(f'... (truncated, {len(metrics_text)} chars total)')
else:
    metrics = None
    print('No metrics.json found')


No metrics.json found


In [3]:
# Show key images
def show_png(path, title=None, figsize=(5, 5)):
    """Display a saved PNG image inline with matplotlib.

    Parameters
    ----------
    path : str or pathlib.Path
        Image file to display. Plain ``str`` paths (e.g. from ``glob.glob``)
        are accepted as well as ``Path`` objects.
    title : str, optional
        Figure title; defaults to the file name.
    figsize : tuple of float, optional
        Figure size in inches, default ``(5, 5)``.
    """
    path = Path(path)  # normalise so `.name` works for str inputs too
    img = plt.imread(path)
    fig, ax = plt.subplots(figsize=figsize)
    # plt.imread returns a 2-D array for single-channel PNGs; without an
    # explicit cmap, imshow would colour it with the default colormap instead
    # of rendering it as grayscale. For RGB(A) arrays cmap is ignored anyway.
    ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
    ax.axis('off')
    ax.set_title(path.name if title is None else title)
    plt.show()

# Display each summary figure saved by the training run, reporting any that
# were not written (e.g. an incomplete or differently-configured run).
key_images = ('real_grid.png', 'tok_recon.png', 'loss.png')
for image_path in (OUT_DIR / fname for fname in key_images):
    if not image_path.exists():
        print('missing', image_path)
        continue
    show_png(image_path)


missing outputs/mnist_sdvae_resnet/real_grid.png
missing outputs/mnist_sdvae_resnet/tok_recon.png
missing outputs/mnist_sdvae_resnet/loss.png


In [4]:
# Show latest generated samples.
# Grids are ordered by the integer step parsed from the file name, not by the
# name itself: a lexicographic sort puts 'samples_step1000.png' before
# 'samples_step200.png', so paths[-1] would not be the latest grid unless the
# step numbers happen to be zero-padded.
def sample_step(path, prefix='samples_step'):
    """Return the step N encoded in ``<prefix><N>.png``, or -1 if it is not an integer."""
    digits = path.stem[len(prefix):]
    return int(digits) if digits.isdecimal() else -1

paths = sorted(OUT_DIR.glob('samples_step*.png'), key=lambda p: (sample_step(p), p.name))
print('num sample grids:', len(paths))
if paths:
    show_png(paths[-1], title=f'Latest samples: {paths[-1].name}')

    # Optionally show a few earlier grids: first, middle, and second-to-last.
    # Collecting indices in a set avoids showing the same grid twice on short
    # runs, and the latest grid (already shown above) is excluded.
    last = len(paths) - 1
    earlier = sorted(i for i in {0, len(paths) // 2, last - 1} if 0 <= i < last)
    for i in earlier:
        show_png(paths[i], title=paths[i].name)


num sample grids: 0
