In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# 1. Prepare Data
# We create a simple relationship: y = 2x + 1
X = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
Y = torch.tensor([[3.0], [5.0], [7.0], [9.0]], dtype=torch.float32)

In [6]:
# 2. Build the Model
# nn.Linear(input_size, output_size) implements y = xA^T + b
model = nn.Linear(1, 1)

In [8]:
# 3. Define Loss and Optimizer
# MSE (Mean Squared Error) is standard for regression
criterion = nn.MSELoss() 
# SGD (Stochastic Gradient Descent) updates the weights
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [13]:
# 4. The Training Loop
epochs = 1000
for epoch in range(epochs):
    # Forward pass: Predict Y using the model
    prediction = model(X)
    
    # Compute the loss (how wrong were we?)
    loss = criterion(prediction, Y)
    
    # Backward pass: Compute the gradient
    optimizer.zero_grad() # Clear old gradients
    loss.backward()       # Backpropagation
    optimizer.step()      # Update the weights
    
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

Epoch [100/1000], Loss: 0.0000
Epoch [200/1000], Loss: 0.0000
Epoch [300/1000], Loss: 0.0000
Epoch [400/1000], Loss: 0.0000
Epoch [500/1000], Loss: 0.0000
Epoch [600/1000], Loss: 0.0000
Epoch [700/1000], Loss: 0.0000
Epoch [800/1000], Loss: 0.0000
Epoch [900/1000], Loss: 0.0000
Epoch [1000/1000], Loss: 0.0000


In [21]:
# 5. Evaluate/Predict
model.eval() # Set model to evaluation mode
with torch.no_grad(): # No need to track gradients during testing
    test_value = torch.tensor([[5.0]])
    predicted = model(test_value)
    print(f'\nPrediction for x=5: {predicted.item():.4f} (Expected: 11.0)')


Prediction for x=5: 10.9998 (Expected: 11.0)


In [25]:
from torchviz import make_dot


# 1. Run your forward pass
input_data = torch.tensor([[5.0]])  # Define input_data
output = model(input_data)

# 2. Visualize the graph leading to the output
make_dot(output, params=dict(model.named_parameters())).render("computational-graph/graph", format="png")

'graph.png'