# Create Moving MNIST Dataset from Scratch

## Download MNIST Dataset

In [None]:
import tensorflow as tf

# Load the MNIST images; the label arrays (second element of each tuple) are discarded.
# The second split returned by load_data() is used as the validation set in this notebook.
(x_train, _), (x_val, _) = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis so grayscale images share the (N, H, W, C) layout
# that the rest of the notebook expects (tf.newaxis inserts a new axis of length 1).
x_train = x_train[..., tf.newaxis]  # shape: (60000, 28, 28, 1)
x_val = x_val[..., tf.newaxis]      # shape: (10000, 28, 28, 1)

In [None]:
# Sanity check: both arrays should now end with a channel axis of size 1
print(x_train.shape)
print(x_val.shape)

(60000, 28, 28, 1)
(10000, 28, 28, 1)


## Download CIFAR10 Dataset

In [None]:
import tensorflow as tf

# Load CIFAR-10 data (labels are discarded).
# NOTE: this rebinds x_train / x_val, so whichever download cell ran last
# determines the dataset used by the cells below.
(x_train, _), (x_val, _) = tf.keras.datasets.cifar10.load_data()

In [None]:
# Sanity check: print the array shapes of the CIFAR-10 splits
print(x_train.shape)
print(x_val.shape)

(50000, 32, 32, 3)
(10000, 32, 32, 3)


## Define Trajectory

In [None]:
import numpy as np

image_size = 64  # Side length (pixels) of the square output frame the image moves inside
digit_size = x_train.shape[1]  # Side length of one source image (axis 1 of x_train); assumes square images
step_length = .5  # Factor applied to the velocity on every trajectory step

def get_random_trajectory(seq_length):
    """Sample a bouncing path for the top-left corner of a moving image.

    The start position and the velocity are drawn together from a single
    ``np.random.random(4)`` call, so every component starts in ``[0, 1)``.
    Positions live in normalised coordinates: each step adds
    ``velocity * step_length``, and a coordinate that reaches 0 or 1 is
    clamped to that wall while its velocity component is negated.

    Relies on the module-level ``image_size``, ``digit_size`` and
    ``step_length``.

    Args:
        seq_length: Number of steps (frames) to generate.

    Returns:
        A tuple ``(xs, ys)`` of NumPy arrays holding one integer pixel offset
        per step, each between 0 and ``image_size - digit_size`` inclusive.
    """
    # Number of pixels the top-left corner can travel along each axis.
    travel_range = image_size - digit_size
    pos_x, pos_y, vel_x, vel_y = np.random.random(4)
    xs, ys = [], []

    def _step(pos, vel):
        """Advance one coordinate and reflect its velocity at the walls."""
        pos += vel * step_length
        if pos <= 0:
            pos, vel = 0, -vel
        if pos >= 1.0:
            pos, vel = 1.0, -vel
        return pos, vel

    for _ in range(seq_length):
        pos_y, vel_y = _step(pos_y, vel_y)
        pos_x, vel_x = _step(pos_x, vel_x)

        # Convert the normalised position into a pixel offset (truncating).
        xs.append(int(pos_x * travel_range))
        ys.append(int(pos_y * travel_range))

    return np.array(xs), np.array(ys)

# Quick check: sample one 30-step trajectory; the cell displays the (x, y) offset arrays
get_random_trajectory(30)

(array([28, 31, 35, 36, 32, 29, 25, 22, 18, 15, 11,  8,  4,  1,  0,  3,  6,
        10, 13, 17, 20, 24, 27, 31, 34, 36, 32, 29, 25, 22]),
 array([34, 36, 24, 13,  2,  0, 11, 22, 33, 36, 24, 13,  2,  0, 11, 22, 33,
        36, 24, 13,  2,  0, 11, 22, 33, 36, 24, 13,  2,  0]))

## Generate Moving MNIST Sequences

In [None]:
def generate_moving_mnist_sequence(digit, seq_length=60):
    """Render one image drifting across an otherwise black canvas.

    Works for grayscale and multi-channel (e.g. RGB) images. The path comes
    from ``get_random_trajectory``; the canvas side length and the patch side
    length are read from the module-level ``image_size`` and ``digit_size``.

    Args:
        digit: Source image array. A 3-D array is treated as
            ``(height, width, channels)``; any other rank is treated as
            single-channel.
        seq_length: Number of frames to render.

    Returns:
        ``uint8`` array of shape
        ``(seq_length, image_size, image_size, channels)``.
    """
    # Only a 3-D input carries an explicit channel axis.
    if digit.ndim == 3:
        channels = digit.shape[-1]
    else:
        channels = 1
    single_channel = channels == 1

    # Top-left corner of the patch for every frame.
    lefts, tops = get_random_trajectory(seq_length)

    # Frames start out black; only the patch region is overwritten below.
    sequence = np.zeros((seq_length, image_size, image_size, channels), dtype=np.uint8)

    # Single-channel input is written into channel 0 without its size-1 axes.
    patch = digit.squeeze() if single_channel else digit

    for t in range(seq_length):
        rows = slice(tops[t], tops[t] + digit_size)
        cols = slice(lefts[t], lefts[t] + digit_size)

        if single_channel:
            sequence[t, rows, cols, 0] = patch
        else:
            sequence[t, rows, cols, :] = patch

    return sequence

## Create the Training and Validation Sequences

In [None]:
from tqdm import tqdm

In [None]:
import numpy as np
from tqdm import tqdm  # For progress bar

# Every image gets its own sequence of `seq_length` frames.
seq_length = 40

# Channel count of the dataset (a 4-D array is (N, H, W, C); anything else is
# treated as single-channel). Assumes all images share the same channel count.
channels = 1 if x_train.ndim != 4 else x_train.shape[-1]

# Pre-allocate the outputs as (N, seq_length, image_size, image_size, channels).
# NOTE: these are dense uint8 arrays, i.e. N * seq_length * image_size**2 * channels
# bytes each (about 9.8 GB for the 60000-image MNIST training split).
frame_shape = (seq_length, image_size, image_size, channels)
x_train_seq = np.zeros((len(x_train),) + frame_shape, dtype=np.uint8)
x_val_seq = np.zeros((len(x_val),) + frame_shape, dtype=np.uint8)

# Fill both splits, one independently sampled trajectory per image.
for target, source in ((x_train_seq, x_train), (x_val_seq, x_val)):
    for idx in tqdm(range(len(source))):
        target[idx] = generate_moving_mnist_sequence(source[idx], seq_length)


100%|██████████| 60000/60000 [00:11<00:00, 5215.52it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5224.25it/s]


In [None]:
# Expected layout: (num_images, seq_length, image_size, image_size, channels)
print(x_train_seq.shape)
print(x_val_seq.shape)

(60000, 40, 64, 64, 1)
(10000, 40, 64, 64, 1)


## Slice Sequences to Create Input-Output Pairs

In [None]:
# Midpoint of the time axis (axis 1); used below to split each sequence into two parts
mid_seq_length = x_train_seq.shape[1] // 2
print(mid_seq_length)

20


In [None]:
# Cut every sequence at `mid_seq_length` along the time axis (axis 1):
# the leading frames become the input, the remaining frames the output.
# Basic slicing returns views, so no frame data is copied here.
x_train_in, x_train_out = x_train_seq[:, :mid_seq_length], x_train_seq[:, mid_seq_length:]
x_val_in, x_val_out = x_val_seq[:, :mid_seq_length], x_val_seq[:, mid_seq_length:]

In [None]:
# Sanity check: axis 1 of each part should hold its share of the frames
print(x_train_in.shape)
print(x_train_out.shape)
print(x_val_in.shape)
print(x_val_out.shape)

(60000, 20, 64, 64, 1)
(60000, 20, 64, 64, 1)
(10000, 20, 64, 64, 1)
(10000, 20, 64, 64, 1)


## Visualization in `.gif` file

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import io
import random

# Select a few random examples from the dataset
num_samples = 5  # Number of sequences to display
# Requested number of frames to display per sequence.
# NOTE: this rebinds the global `seq_length` that was set to 40 for generation.
seq_length = 30

# Randomly select distinct indices for sample display
sample_indices = random.sample(range(len(x_train_in)), num_samples)

# Gather the selected sequences (indexing with a list copies the data)
truth_frames_in = x_train_in[sample_indices]  # Shape (num_samples, frames, image_size, image_size, channels)
truth_frames_out = x_train_out[sample_indices]  # Same layout; (5, 20, 64, 64, 1) for the MNIST run shown

print(truth_frames_in.shape)
print(truth_frames_out.shape)

(5, 20, 64, 64, 1)
(5, 20, 64, 64, 1)


In [None]:
# Clamp the display length to the number of frames actually present in both sample
# arrays (axis 1), so create_gif never indexes past the end of a sequence.
seq_length = min(truth_frames_in.shape[1], truth_frames_out.shape[1], seq_length)

def create_gif(frames, file_name):
    """Render the sampled sequences frame by frame and save them as a GIF.

    Every GIF frame is one matplotlib figure with ``num_samples`` rows stacked
    vertically; row ``r`` shows ``frames[r, frame_idx]``. The figure is
    rasterised to PNG in memory and the resulting images are written to
    ``"<file_name>.gif"``.

    Relies on the module-level ``num_samples`` (rows per figure) and
    ``seq_length`` (number of GIF frames).

    Args:
        frames: Array indexed as ``frames[row, frame_idx]`` that yields one
            image per lookup; size-1 axes are squeezed away before plotting.
        file_name: Output path without extension; ``".gif"`` is appended.
    """
    frames_list = []
    for frame_idx in range(seq_length):  # One figure per time step
        # squeeze=False keeps `axs` two-dimensional even when num_samples == 1;
        # by default subplots() would return a bare Axes that cannot be indexed.
        fig, axs = plt.subplots(num_samples, 1, figsize=(5, 17), squeeze=False)
        try:
            fig.subplots_adjust(wspace=0.1, hspace=0.1)

            for row in range(num_samples):
                ax = axs[row, 0]
                # cmap only affects 2-D (single-channel) images; RGB data is shown as-is.
                ax.imshow(frames[row, frame_idx].squeeze(), cmap="gray")
                ax.axis("off")
                ax.set_title("Ground Truth")

            # Add timestamp to the top of the figure
            fig.text(0.5, 0.92, f'Timestamp: t={frame_idx + 1}', ha='center', va='center', transform=fig.transFigure)

            # Rasterise this specific figure (not the implicit "current" one)
            # into an in-memory PNG and decode it back into an image array.
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png")
                buf.seek(0)
                frames_list.append(imageio.imread(buf))
        finally:
            # Always release the figure, even if rendering failed, so figures
            # do not accumulate across the loop.
            plt.close(fig)

    # Save the frames list as a single GIF.
    # NOTE(review): the unit of `duration` depends on the imageio version/plugin
    # (milliseconds in recent Pillow-based releases, seconds in older ones) - confirm.
    imageio.mimsave(f"{file_name}.gif", frames_list, duration=300)

# Generate and save GIFs for x_train_in and x_train_out
# (each is written as "<file_name>.gif", relative to the current working directory)
create_gif(truth_frames_in, "x_train_in")
create_gif(truth_frames_out, "x_train_out")

print("GIFs saved as 'x_train_in.gif' and 'x_train_out.gif'")
