In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define an Autoencoder model
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # Downsample to 14x14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # Downsample to 7x7
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # To keep pixel values between 0 and 1
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Initialize the model, loss function, and optimizer
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
for epoch in range(epochs):
    for images, _ in train_loader:
        # Forward pass
        reconstructed = model(images)
        loss = criterion(reconstructed, images)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")



100%|██████████| 9.91M/9.91M [00:00<00:00, 56.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.0MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.56MB/s]


Epoch [1/10], Loss: 0.8960
Epoch [2/10], Loss: 0.8759
Epoch [3/10], Loss: 0.8795
Epoch [4/10], Loss: 0.8787
Epoch [5/10], Loss: 0.8770
Epoch [6/10], Loss: 0.8844
Epoch [7/10], Loss: 0.8920
Epoch [8/10], Loss: 0.8757
Epoch [9/10], Loss: 0.8749
Epoch [10/10], Loss: 0.8767


In [2]:
## Strong LLM
import jax
import jax.numpy as jnp
from jax import random, jit, value_and_grad
import optax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

# ---------------------------
# Data Loading and Preprocessing
# ---------------------------
# The original PyTorch code applies ToTensor and Normalize((0.5,), (0.5,)).
# When MNIST images (originally uint8 in [0, 255]) are converted to float and normalized,
# the transformation is: x -> (x/255 - 0.5)/0.5 = 2*x/255 - 1.
# We'll do the same here.

def preprocess(example):
    image = example['image']  # shape (28, 28, 1) or (28, 28)
    # Convert to float32 in [0,1]
    image = tf.cast(image, tf.float32) / 255.0
    # Normalize to [-1, 1] as in PyTorch
    image = (image - 0.5) / 0.5
    # Ensure image has channel dimension (H, W, C)
    if image.shape.ndims == 2:
        image = tf.expand_dims(image, axis=-1)
    return {'image': image}

batch_size = 64
# Load and preprocess the training set
train_ds = tfds.load('mnist', split='train', shuffle_files=True)
train_ds = train_ds.map(preprocess)
train_ds = train_ds.shuffle(1024).batch(batch_size).prefetch(1)
# (For testing, a similar pipeline can be built from the 'test' split)

# ---------------------------
# Model Definition using Flax
# ---------------------------
class Autoencoder(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Encoder
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')  # 28->14
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')  # 14->7

        # Decoder
        x = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2), padding='SAME')(x)
        x = nn.sigmoid(x)  # Keep output values between 0 and 1
        return x

# ---------------------------
# Initialize Model and Optimizer
# ---------------------------
rng = random.PRNGKey(0)
# Dummy input with shape [batch, height, width, channels]; note that Flax defaults to NHWC.
dummy_input = jnp.ones([1, 28, 28, 1])
model = Autoencoder()
params = model.init(rng, dummy_input)

# Set up the optimizer (Adam with learning rate 0.001)
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# ---------------------------
# Loss and Training Step
# ---------------------------
def mse_loss(params, batch):
    """Compute mean squared error between the reconstruction and input."""
    recon = model.apply(params, batch)
    return jnp.mean((recon - batch) ** 2)

@jit
def train_step(params, opt_state, images):
    loss, grads = value_and_grad(mse_loss)(params, images)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# ---------------------------
# Training Loop
# ---------------------------
epochs = 10
for epoch in range(epochs):
    # Iterate over the TFDS training dataset (convert batches to numpy arrays)
    for batch in tfds.as_numpy(train_ds):
        images = batch['image']  # shape: [batch, 28, 28, 1]
        params, opt_state, loss = train_step(params, opt_state, images)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.6PO0CP_3.0.1/mnist-train.tfrecord*...:   0%|          | 0…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.6PO0CP_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/…

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch [1/10], Loss: 0.9271


KeyboardInterrupt: 

In [4]:
## Weak LLM
## get new one, different from the set on our drives
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
import numpy as np
import matplotlib.pyplot as plt
from tensorflow_datasets import load
from tensorflow import data
import tensorflow as tf

# Load MNIST dataset
def preprocess_fn(image):
    image = tf.cast(image, tf.float32) / 255.0
    image = (image - 0.5) / 0.5  # Normalize to range [-1, 1]
    return image

train_ds, test_ds = load('mnist', with_info=False, as_supervised=True)
train_ds = train_ds.map(lambda image, label: preprocess_fn(image), num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(lambda image, label: preprocess_fn(image), num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.batch(64).shuffle(10000)
test_ds = test_ds.batch(64)

# Define the Autoencoder model using Flax (JAX's neural network library)
class Autoencoder(nn.Module):
    def setup(self):
        # Encoder
        self.encoder = nn.Sequential([
            nn.Conv(32, (3, 3), padding='SAME'),
            nn.relu,
            nn.max_pool((2, 2)),
            nn.Conv(64, (3, 3), padding='SAME'),
            nn.relu,
            nn.max_pool((2, 2)),
        ])
        # Decoder
        self.decoder = nn.Sequential([
            nn.ConvTranspose(32, (3, 3), strides=(2, 2), padding='SAME'),
            nn.relu,
            nn.ConvTranspose(1, (3, 3), strides=(2, 2), padding='SAME'),
            nn.sigmoid
        ])

    def __call__(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Initialize the model, loss function, and optimizer
model = Autoencoder()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))

def mse_loss(reconstructed, original):
    return jnp.mean((reconstructed - original) ** 2)

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# Training loop
epochs = 10

@jax.jit
def update(params, opt_state, batch):
    def loss_fn(params):
        reconstructed = model.apply(params, batch)
        loss = mse_loss(reconstructed, batch)
        return loss
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

for epoch in range(epochs):
    for batch in train_ds:
        batch = np.expand_dims(batch.numpy(), axis=-1)  # Add channel dimension
        params, opt_state, loss = update(params, opt_state, batch)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")


AttributeError: 'str' object has no attribute 'map'

In [None]:
"""Error Code
train_ds, test_ds = load('mnist', with_info=False, as_supervised=True)
--->train_ds = train_ds.map(lambda image, label: preprocess_fn(image), num_parallel_calls=tf.data.AUTOTUNE)

Error:
AttributeError: 'str' object has no attribute 'map'

Fix guide
This error occurs because the MNIST dataset was not loaded with the
proper split information. Without specifying the split, the load function
returns a single object (or even a string) rather than
separate tf.data.Dataset objects. To fix this, explicitly specify
the 'split' parameter to obtain the 'train' and 'test' datasets.
Additionally, import and use tensorflow_datasets (tfds) correctly.

Correct code
import tensorflow_datasets as tfds
train_ds, test_ds = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=False
"""

"""Error Code
nn.max_pool(window_shape=(2, 2), strides=(2, 2), padding='VALID')

Error
TypeError: max_pool() missing 1 required positional argument: 'inputs'

Fix guide:
wrap nn.max_pool inside a lambda function so that it
becomes a callable layer in the sequential list

Correct code
lambda x: nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID'),
"""

In [7]:
### Fixed Code

## get new one, different from the set on our drives
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

# ---------------------------
# Data Loading and Preprocessing
# ---------------------------
def preprocess_fn(image, label):
    # Convert image to float32 and scale to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    # Normalize to [-1, 1] as (x - 0.5) / 0.5
    image = (image - 0.5) / 0.5
    # Ensure the image has a channel dimension (28,28) -> (28,28,1)
    if image.shape.rank == 2:
        image = tf.expand_dims(image, -1)
    return image

# Specify the 'split' to load both train and test datasets correctly.
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=False)

# Apply preprocessing (ignoring the label similar to the PyTorch code)
train_ds = train_ds.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

# Shuffle and batch the training dataset (batch_size=64) and batch the test dataset.
train_ds = train_ds.shuffle(10000).batch(64).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(64).prefetch(tf.data.AUTOTUNE)

# ---------------------------
# Define the Autoencoder Model using Flax
# ---------------------------
class Autoencoder(nn.Module):
    def setup(self):
        # Encoder: Two Conv layers with ReLU and Max Pooling (downsampling)
        self.encoder = nn.Sequential([
          nn.Conv(32, kernel_size=(3, 3), padding='SAME'),
          nn.relu,
          lambda x: nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID'),
          nn.Conv(64, kernel_size=(3, 3), padding='SAME'),
          nn.relu,
          lambda x: nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID'),
      ])
        # Decoder: Two ConvTranspose layers with ReLU and final Sigmoid
        self.decoder = nn.Sequential([
            nn.ConvTranspose(32, kernel_size=(3, 3), strides=(2, 2), padding='SAME'),
            nn.relu,
            nn.ConvTranspose(1, kernel_size=(3, 3), strides=(2, 2), padding='SAME'),
            nn.sigmoid  # Constrain output pixel values to [0, 1]
        ])

    def __call__(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# ---------------------------
# Initialize Model, Loss, and Optimizer
# ---------------------------
model = Autoencoder()
# Flax expects NHWC; create a dummy input of shape [1, 28, 28, 1]
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))

def mse_loss(reconstructed, original):
    return jnp.mean((reconstructed - original) ** 2)

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# ---------------------------
# Training Step Function (using JIT)
# ---------------------------
@jax.jit
def update(params, opt_state, batch):
    def loss_fn(params):
        reconstructed = model.apply(params, batch)
        return mse_loss(reconstructed, batch)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# ---------------------------
# Training Loop
# ---------------------------
epochs = 5
for epoch in range(epochs):
    for batch in tfds.as_numpy(train_ds):
        # Each batch is preprocessed to shape (batch, 28, 28, 1) already.
        if batch.ndim == 3:
            batch = np.expand_dims(batch, axis=-1)
        params, opt_state, loss = update(params, opt_state, batch)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")


Epoch [1/5], Loss: 0.9252
Epoch [2/5], Loss: 0.8755
Epoch [3/5], Loss: 0.8532
Epoch [4/5], Loss: 0.8577
Epoch [5/5], Loss: 0.8427
