In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

### prepare (download) the MNIST dataset

and show information

In [None]:
# Pre-processing applied to every image when it is read from the dataset:
# convert it to a tensor, then apply Normalize with mean 0 and std 1.
# (x - 0) / 1 leaves the values unchanged, so the data is effectively
# NOT normalized; the Normalize step is only a placeholder.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1,))
])

# load the MNIST training split (downloaded into ./data if it is missing)
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# names of the classes
print ("\nClasses:", train_dataset.classes)

# label (digit) of every image
print ("\nLabels:", train_dataset.targets)

# shape of the data tensor
print ("\nData shape:", train_dataset.data.shape)

### visualize images

In [None]:
# how many images are drawn
num_examples = 10

# shuffled loader whose batch size equals the number of images wanted
train_loader = DataLoader(train_dataset, batch_size=num_examples, shuffle=True)

# take a single batch from the loader
img, label = next(iter(train_loader))

# shape of the image batch
print (img.shape)

# labels that belong to the batch
print ("\nLabels:", label)

# draw the batch as one row of gray-scale panels without axes
plt.figure(figsize=(8, 4))
for panel in range(num_examples):
    picture = img[panel].numpy().reshape(28,28)
    plt.subplot(1, num_examples, panel + 1)
    plt.imshow(picture, cmap='gray')
    plt.axis('off')

### visualize images according to labels

See:
1. torch.where: 
    https://docs.pytorch.org/docs/stable/generated/torch.where.html    

In [None]:
from torch.utils.data import Subset

# how many images are drawn for each digit
num_examples = 5

# go through the ten digit classes one after another
for digit in range(10):

    # positions in the training set whose target equals the current digit
    indices = torch.where(train_dataset.targets == digit)[0]

    # shuffled loader over the sub-dataset made of exactly those images
    label_dataset = Subset(train_dataset, indices)
    label_loader = DataLoader(label_dataset, batch_size=num_examples, shuffle=True)

    # take a single batch from the loader
    img, labels = next(iter(label_loader))

    # every entry is expected to equal the current digit
    print (labels)

    # one row of gray-scale panels for this digit
    plt.figure(figsize=(4, 2))
    for panel in range(num_examples):
        picture = img[panel].numpy().reshape(28,28)
        plt.subplot(1, num_examples, panel + 1)
        plt.imshow(picture, cmap='gray')
        plt.axis('off')

### Update the image using OU process

Recall the OU process:

$$dX_t = -\kappa X_t dt + \sqrt{2\beta^{-1}} dB_t$$

Euler-Maruyama scheme:

$$ x_{n+1} = x_n - \kappa x_n h + \sqrt{2\beta^{-1}h} \eta_n, \quad n=0,1,2,\dots$$

where $h$ is the step size and the $\eta_n$ are independent standard Gaussian random variables.

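see: https://numpy.org/doc/2.2/reference/random/generated/numpy.random.normal.html (legacy interface). The code below creates its generator with `np.random.default_rng`, whose corresponding method is documented at https://numpy.org/doc/stable/reference/random/generated/numpy.random.Generator.standard_normal.html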

In [None]:
import numpy as np

# parameters of the SDE  dX_t = -kappa X_t dt + sqrt(2/beta) dB_t
kappa = 1.0
beta = 3.0

# total simulation time
# SDE is simulated from t=0 to t=T.
T = 1.0
# number of steps
N = 100
# step-size
h = T/N
# random seed
seed = 1234

# random number generator (seeded, so the noise sequence is reproducible)
rng = np.random.default_rng(seed=seed)

# get some images
for data in train_loader:
    img, label = data
    break

# print the shape of the image tensor
print (img.shape)

# convert the data to numpy
X = img.numpy()

# num_examples is inherited from an earlier cell; never try to draw more
# panels than there are images in the batch
num_shown = min(num_examples, X.shape[0])

# noise amplitude sqrt(2 h / beta) of one Euler-Maruyama step; it does not
# depend on the step index, so it is computed once
noise_scale = np.sqrt(2.0 * h / beta)

for step in range(N):

    # first, generate r from a standard normal gaussian distribution
    # the shape of r is the same as the shape of X
    r = rng.standard_normal(size=X.shape)

    # second, update X according to the Euler-Maruyama scheme
    #   x_{n+1} = x_n - kappa * x_n * h + sqrt(2 h / beta) * eta_n
    # The right-hand side builds a new array, so the tensor `img` (which
    # shares its memory with the initial X) is not modified.
    X = X - kappa * X * h + noise_scale * r

    # plot the images every 10 steps
    if step % 10 == 0 :
        plt.figure(figsize=(4, 2))
        for i in range(num_shown):
            plt.subplot(1, num_shown, i + 1)
            plt.imshow(X[i].reshape(28,28), cmap='gray')
            plt.axis('off')
            if i == 0 :
                plt.title("step=%d" % step)