# Exercise 7: Converting PyTorch to TensorFlow.js

To run our model directly on a mobile phone without a backend server, we need to convert it to a format that the web/mobile runtime can understand. We will use **TensorFlow.js (TFJS)** for this.

### Why TFJS?
- **Zero Server Cost**: Inference runs on the user's device.
- **Privacy**: Data never leaves the device.
- **Offline**: Works without an internet connection.

> **Note**: `tensorflow` is now included in the project's `pyproject.toml`, so if you are using our `uv` environment, Section 2 will run locally. The final conversion step in Section 3, however, requires the `tensorflowjs` package, which is difficult to install on Apple Silicon (M1/M2/M3/M4) Macs. If you hit errors there, we recommend running that step in Google Colab.

In [None]:
# If you are not using our pre-configured environment, you might need:
# !pip install tensorflow torch

## 1. Load the PyTorch Model

First, we define our original architecture and load the weights we trained in Exercise 4.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Recreate the architecture, then load the trained weights onto the CPU
model = SimpleNN()
model.load_state_dict(torch.load('../models/mnist_model.pth', map_location='cpu'))
model.eval()  # switch to inference mode
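Before copying weights, it helps to know exactly which tensors `model.state_dict()` contains. The sketch below derives the expected shapes and the parameter count from the layer sizes alone, assuming float32 weights; `expected_state_dict_shapes` is a hypothetical helper, not part of PyTorch.

```python
import math

# Layer sizes of SimpleNN: 784 inputs -> 128 -> 64 -> 10 outputs.
LAYER_SIZES = [28 * 28, 128, 64, 10]


def expected_state_dict_shapes(sizes):
    """Map state_dict keys (fc1.weight, fc1.bias, ...) to their expected shapes.

    nn.Linear stores weight as (out_features, in_features) and bias as (out_features,).
    """
    shapes = {}
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:]), start=1):
        shapes[f"fc{i}.weight"] = (n_out, n_in)
        shapes[f"fc{i}.bias"] = (n_out,)
    return shapes


shapes = expected_state_dict_shapes(LAYER_SIZES)
for name, shape in shapes.items():
    print(f"{name:12s} {shape}")

total_params = sum(math.prod(shape) for shape in shapes.values())
print(f"Total parameters: {total_params}")
# float32 uses 4 bytes per parameter
print(f"Approx. float32 size: {total_params * 4 / 1024:.0f} KiB")
```

For these sizes the total is 109,386 parameters, roughly 427 KiB of float32 weights. To check against the real checkpoint, you could compare `shapes` with `{k: tuple(v.shape) for k, v in model.state_dict().items()}`.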

## 2. Replicate in Keras & Transfer Weights

Since our model is very simple, the easiest way to get to TFJS is to define the same structure in Keras and copy the weights. This step should work on any computer with `tensorflow` installed.
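Why the `.T` in the cell below? `nn.Linear` computes `y = x @ W.T + b` with `W` shaped `(out_features, in_features)`, while Keras `Dense` computes `y = x @ kernel + b` with `kernel` shaped `(in_features, out_features)`. Copying weights therefore amounts to setting `kernel = W.T`. This NumPy-only sketch checks that on random stand-in weights (not the trained ones) for the same 784 → 128 → 64 → 10 layout.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [784, 128, 64, 10]

# Random stand-in weights in PyTorch layout: weight is (out_features, in_features).
torch_layout = [
    (rng.standard_normal((n_out, n_in)).astype(np.float32),
     rng.standard_normal(n_out).astype(np.float32))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]

# Keras layout: kernel is (in_features, out_features), i.e. the transpose.
keras_layout = [(w.T, b) for w, b in torch_layout]


def relu(z):
    return np.maximum(z, 0.0)


def forward_torch_style(x, layers):
    # nn.Linear: y = x @ W.T + b, with ReLU on every layer except the last
    for i, (w, b) in enumerate(layers):
        x = x @ w.T + b
        if i < len(layers) - 1:
            x = relu(x)
    return x


def forward_keras_style(x, layers):
    # Dense: y = x @ kernel + b, with ReLU on every layer except the last
    for i, (k, b) in enumerate(layers):
        x = x @ k + b
        if i < len(layers) - 1:
            x = relu(x)
    return x


x = rng.random((5, 784), dtype=np.float32)  # a batch of 5 flattened "images"
out_t = forward_torch_style(x, torch_layout)
out_k = forward_keras_style(x, keras_layout)
print(out_t.shape)                           # (5, 10)
print(np.allclose(out_t, out_k, atol=1e-5))  # True
```

Skipping the transpose fails loudly here, because `x @ W` with `W` of shape (128, 784) is a shape mismatch. A square layer (equal input and output sizes) would not raise an error, so it is worth keeping the transpose explicit.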

In [None]:
import tensorflow as tf
import numpy as np

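# Define the equivalent Keras model. It expects already-flattened (784,) inputs,
# because the PyTorch forward() did the flattening itself with x.view(-1, 28 * 28).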
keras_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10)
])

# Extract weights from PyTorch
pt_weights = model.state_dict()

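# PyTorch nn.Linear stores weight as (out_features, in_features), while Keras Dense
# expects [kernel, bias] with kernel shaped (in_features, out_features), hence the .T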
w1 = pt_weights['fc1.weight'].numpy().T
b1 = pt_weights['fc1.bias'].numpy()
w2 = pt_weights['fc2.weight'].numpy().T
b2 = pt_weights['fc2.bias'].numpy()
w3 = pt_weights['fc3.weight'].numpy().T
b3 = pt_weights['fc3.bias'].numpy()

keras_model.layers[0].set_weights([w1, b1])
keras_model.layers[1].set_weights([w2, b2])
keras_model.layers[2].set_weights([w3, b3])

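# Save as HDF5 (.h5), the file type read by `tensorflowjs_converter --input_format=keras`.
# Recent Keras versions may warn that HDF5 is a legacy format; the file is still written.
keras_model.save('mnist_keras.h5')
print("Intermediate Keras model saved!")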
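The Keras model's input layer is `Input(shape=(784,))`, whereas the PyTorch `forward()` flattened its input with `x.view(-1, 28 * 28)`. Anything that calls the converted model, including the TFJS app, therefore has to pass already-flattened images in row-major order, so pixel `(row, col)` lands at index `row * 28 + col`. The NumPy sketch below illustrates that mapping; `flatten_for_keras` is a hypothetical helper, and any scaling or normalization applied during training in Exercise 4 would also need to be replicated (it is not shown here).

```python
import numpy as np


def flatten_for_keras(image):
    """Turn one 28x28 image into a (1, 784) float32 batch in row-major (C) order.

    This matches what x.view(-1, 28 * 28) does to a contiguous PyTorch tensor.
    """
    image = np.asarray(image, dtype=np.float32)
    if image.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got {image.shape}")
    return image.reshape(1, 28 * 28)


# A toy image where each pixel stores its own flat index, so the mapping is visible.
img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
flat = flatten_for_keras(img)

row, col = 3, 5
print(flat.shape)                                # (1, 784)
print(flat[0, row * 28 + col] == img[row, col])  # True
```

With the real model, `keras_model.predict(flatten_for_keras(img))` would return a `(1, 10)` array of logits, since the last `Dense` layer has no softmax.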

## 3. Convert to TFJS

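Finally, we use the `tensorflowjs_converter` command-line tool (installed with the `tensorflowjs` package) to convert the `.h5` file into a TFJS Layers model: a `model.json` file describing the architecture and weight layout, plus one or more binary weight shard files (`.bin`).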

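> **Warning**: Installing `tensorflowjs` can fail on Apple Silicon Macs. If the install or the conversion errors out, run this cell in Google Colab instead (upload `mnist_keras.h5` and download the resulting `tfjs_model` folder).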

In [None]:
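# Uncomment to install the converter and run it (see the Apple Silicon warning above):
# !pip install tensorflowjs
# !tensorflowjs_converter --input_format=keras mnist_keras.h5 tfjs_model

print("After running the converter, look for model.json and the .bin weight shard(s) in ./tfjs_model")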
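To confirm the conversion produced a complete model, you can read `model.json` and check that every weight shard it references is present. This sketch assumes the TFJS Layers model layout, where `weightsManifest` is a list of groups whose `paths` name the `.bin` files relative to `model.json`; `check_tfjs_output` is a hypothetical helper, not part of `tensorflowjs`.

```python
import json
from pathlib import Path


def check_tfjs_output(model_dir="tfjs_model"):
    """Return (shard_paths, missing) for a converted TFJS model directory.

    Collects every file listed under weightsManifest[*].paths in model.json and
    reports any that do not exist next to model.json.
    """
    model_dir = Path(model_dir)
    manifest = json.loads((model_dir / "model.json").read_text())
    shard_paths = [
        path
        for group in manifest.get("weightsManifest", [])
        for path in group.get("paths", [])
    ]
    missing = [path for path in shard_paths if not (model_dir / path).exists()]
    return shard_paths, missing


# Only inspect the output if the converter has actually been run.
if (Path("tfjs_model") / "model.json").exists():
    shards, missing = check_tfjs_output("tfjs_model")
    print("Weight shards:", shards)
    print("Missing:", missing or "none")
else:
    print("tfjs_model/model.json not found; run the converter first.")
```

If `missing` is empty, the `tfjs_model` folder can be copied into the app as-is; all files must stay side by side because `model.json` refers to the shards by relative path.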