# TPUs in Colab
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.

Adapted from [this notebook](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb).

In [None]:
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install tensorflow



In [5]:
# 📦 Step 1: Load and preprocess datasets
import pandas as pd
import numpy as np
import tensorflow as tf
import time

print("🔎 TensorFlow version:", tf.__version__)

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("🚀 Running on TPU!")
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except:
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print("🧠 Running on GPU!")
        strategy = tf.distribute.MirroredStrategy()
    else:
        print("💻 Running on CPU!")
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")

raw_df = pd.read_csv('/content/drive/MyDrive/natural_disasters_2024.csv')
geo_df = pd.read_csv('/content/drive/MyDrive/cleaned_disasters.csv')
print("✅ Datasets loaded")

# Strip and convert columns
raw_df.columns = raw_df.columns.str.strip()
geo_df.columns = geo_df.columns.str.strip()
for col in ['Latitude', 'Longitude', 'Magnitude', 'Severity']:
    if col in geo_df.columns:
        geo_df[col] = geo_df[col].astype(np.float32)
    if col in raw_df.columns:
        raw_df[col] = raw_df[col].astype(np.float32)

print("⚙️ Data ready")

# 🧠 Step 2: Train TensorFlow DNN on TPU
X_train = geo_df[['Latitude', 'Longitude', 'Magnitude']].values
y_train = geo_df['Severity'].values

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(3,)),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    model.fit(X_train, y_train, epochs=20, batch_size=128, verbose=1)

model.save('trained_dnn_model.keras')
print("✅ Model trained and saved")

# ⛓ Step 3: Run simulation and inference on CPU with NumPy
model = tf.keras.models.load_model('trained_dnn_model.keras')

stream_batch = raw_df.sample(n=10000).reset_index(drop=True)
np.random.seed(42)
stream_batch['Latitude'] = np.random.uniform(-90, 90, size=len(stream_batch))
stream_batch['Longitude'] = np.random.uniform(-180, 180, size=len(stream_batch))
stream_batch['Magnitude'] = stream_batch['Magnitude'].astype(np.float32)

X_stream = stream_batch[['Latitude', 'Longitude', 'Magnitude']].values
records = stream_batch.to_dict(orient='records')

success_count, failure_count = 0, 0
send_times, total_bytes_sent = [], 0
start_time = time.time()

for idx, record in enumerate(records):
    try:
        features = np.array([[record['Latitude'], record['Longitude'], record['Magnitude']]])
        pred = model.predict(features, verbose=0)[0][0]
        record['Predicted_Severity'] = float(pred)

        t0 = time.time()
        size = len(str(record).encode('utf-8'))
        time.sleep(0.01)
        t1 = time.time()

        send_times.append(t1 - t0)
        total_bytes_sent += size
        success_count += 1
    except:
        failure_count += 1

end_time = time.time()

# 📊 Metrics
total_time = end_time - start_time

avg_speed = success_count / total_time
throughput = total_bytes_sent / total_time

min_send = min(send_times)
max_send = max(send_times)
mean_send = np.mean(send_times)
median_send = np.median(send_times)
std_send = np.std(send_times)

print(f"🔢 Total Records: {len(records)}")
print(f"✅ Successfully Sent: {success_count}")
print(f"❌ Failed Sends: {failure_count}")
print(f"⏱️ Total Time Taken: {total_time:.2f} seconds")
print(f"⚡ Average Send Speed: {avg_speed:.2f} records/second")
print(f"📦 Total Data Sent: {total_bytes_sent / 1024:.2f} KB")
print(f"🚀 Throughput: {throughput:.2f} bytes/second")
print(f"📈 Min Send Time: {min_send:.4f} sec")
print(f"📉 Max Send Time: {max_send:.4f} sec")
print(f"📊 Mean Send Time: {mean_send:.4f} sec")
print(f"📏 Median Send Time: {median_send:.4f} sec")
print(f"📐 Std Dev Send Time: {std_send:.4f} sec")


# 🚨 Save critical alerts
final_df = pd.DataFrame(records)
alerts = final_df[final_df['Predicted_Severity'] > 8]

if len(alerts) > 0:
    alerts_path = "/content/critical_alerts.csv"
    alerts.to_csv(alerts_path, index=False)
    from google.colab import files
    files.download(alerts_path)
    print(f"⚠️ Alerts saved to {alerts_path}")
else:
    print("✅ No critical alerts found.")


🔎 TensorFlow version: 2.19.0
💻 Running on CPU!
✅ Datasets loaded
⚙️ Data ready


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 104.7915
Epoch 2/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 9.2518
Epoch 3/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 5.2708
Epoch 4/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 2.2475
Epoch 5/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 1.0716
Epoch 6/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.7987
Epoch 7/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.7037
Epoch 8/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.6620
Epoch 9/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.6543
Epoch 10/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.6342
Epoch 1

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

⚠️ Alerts saved to /content/critical_alerts.csv
