In [None]:
!pip install tensorflow==2.16.1 tensorflow-probability matplotlib tensorflow-probability[tf]

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

# Snowpark session handle (presumably the session of the Snowflake notebook this
# kernel runs in), available if the analysis needs to query Snowflake data.
# NOTE(review): `session` is not referenced by the modeling cells in this
# notebook; drop these lines if Snowpark access is not needed.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


In [None]:
# `import importlib` alone does not guarantee the `metadata` submodule is loaded,
# so import it explicitly.
import importlib.metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version string of `distribution`, or None if absent.

    Args:
        distribution: Distribution name as understood by `importlib.metadata`.

    Returns:
        The version string, or None when the distribution is not installed
        (instead of letting `PackageNotFoundError` abort the cell).
    """
    try:
        return importlib.metadata.version(distribution)
    except importlib.metadata.PackageNotFoundError:
        return None


# Report both core dependencies so TF/TFP version mismatches are visible at a glance.
for _dist in ("tensorflow", "tensorflow_probability"):
    print(f"{_dist}: {installed_version(_dist) or 'not installed'}")

In [None]:
# Fix TensorFlow's global seed so the synthetic data (and hence the fit) is reproducible.
SEED = 42
tf.random.set_seed(SEED)

# Ground-truth parameters and sampling settings for the synthetic dataset:
#   y = true_slope * x + true_intercept + noise,  noise ~ Normal(0, NOISE_STDDEV)
true_slope = 0.5
true_intercept = 2.0
NUM_SAMPLES = 100
NOISE_STDDEV = 0.5

x = tf.random.uniform([NUM_SAMPLES], minval=0, maxval=10)
noise = tf.random.normal([NUM_SAMPLES], stddev=NOISE_STDDEV)
y = true_slope * x + true_intercept + noise

# Probabilistic linear-regression model built from tfp.distributions.
class LinearRegressionModel(tf.Module):
    """Linear model with Gaussian observations: y ~ Normal(slope * x + intercept, scale).

    `slope` and `intercept` are trainable `tf.Variable`s initialised to 0.
    `scale` is a fixed (non-trainable) observation-noise standard deviation. It
    defaults to 1.0, the value previously hard-coded here.

    With a fixed scale, maximising the Gaussian log-likelihood over slope and
    intercept gives the ordinary least-squares solution. `scale` therefore
    does not move the optimum. It only changes the loss magnitude and the
    spread of the returned predictive distribution.
    """

    def __init__(self, scale=1.0, name=None):
        """Create the model.

        Args:
            scale: Standard deviation of the observation noise. Must be positive
                when given as a Python number.
            name: Optional `tf.Module` name (defaults to the class-derived name).

        Raises:
            ValueError: If `scale` is a Python number that is not positive.
        """
        super().__init__(name=name)
        if isinstance(scale, (int, float)) and scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")
        self.slope = tf.Variable(0., name='slope')
        self.intercept = tf.Variable(0., name='intercept')
        # Stored as a plain value (not a tf.Variable) so it is excluded from
        # `trainable_variables` and stays fixed during optimisation.
        self.scale = scale

    def __call__(self, x):
        """Return the predictive distribution Normal(slope * x + intercept, scale) for inputs `x`."""
        return tfp.distributions.Normal(loc=self.slope * x + self.intercept, scale=self.scale)

# Instantiate the model; slope and intercept both start at zero.
model = LinearRegressionModel()


def loss_fn():
    """Return the mean negative log-likelihood of `y` under `model(x)`.

    Takes no arguments: it reads the module-level `model`, `x` and `y`, so it
    can be called directly inside a `tf.GradientTape` context.
    """
    predictive = model(x)
    mean_log_prob = tf.reduce_mean(predictive.log_prob(y))
    return -mean_log_prob


# Adam optimizer used by the training loop.
optimizer = tf.optimizers.Adam(learning_rate=0.1)

In [None]:
import math

# Fit slope and intercept by minimising the mean negative log-likelihood with Adam.
NUM_STEPS = 1000

# Loop-invariant: the module's variables never change during training, so
# collect them once instead of re-walking the module on every step.
trainable_vars = model.trainable_variables
loss_history = []  # per-step loss, kept for convergence inspection/plotting

for _ in range(NUM_STEPS):
    with tf.GradientTape() as tape:
        loss = loss_fn()
    gradients = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))
    loss_history.append(float(loss))

# Fail loudly if optimisation diverged (e.g. learning rate too large) rather
# than silently reporting NaN parameters below.
assert math.isfinite(loss_history[-1]), "training diverged: final loss is not finite"

# Extract the trained parameters as NumPy scalars for reporting and plotting.
trained_slope = model.slope.numpy()
trained_intercept = model.intercept.numpy()

In [None]:
# Compare the fitted parameters with the ground truth used to simulate the data.
print(f"True slope: {true_slope:.4f}, Estimated slope: {trained_slope:.4f}")
print(f"True intercept: {true_intercept:.4f}, Estimated intercept: {trained_intercept:.4f}")

# Use the explicit Figure/Axes API rather than pyplot's implicit current-axes state.
fig, ax = plt.subplots()
ax.scatter(x, y, label='Data')

# `x` was drawn uniformly at random (unordered), so sort it before drawing lines
# so each line is traced left-to-right instead of back and forth over itself.
x_line = tf.sort(x)
ax.plot(x_line, trained_slope * x_line + trained_intercept, color='red', label='Fitted Line')
# Reference line from the generating parameters, for judging the fit visually.
ax.plot(x_line, true_slope * x_line + true_intercept, color='gray', linestyle='--', label='True Line')

ax.set(xlabel='x', ylabel='y', title='Simple Linear Regression with TensorFlow Probability')
ax.legend()
plt.show()