In [8]:
import torch
import torch.nn as nn
import torch.optim as optim

**Linear Regression:** Linear regression is a fundamental statistical method for modeling the relationship between a dependent variable (the target) and one or more independent variables (the predictors). It assumes the relationship is linear; with a single predictor the model is $y = wx + b$, where $w$ is the weight (slope) and $b$ is the bias (intercept).

In [11]:
# Generate synthetic data from a known line so the learned parameters can be
# compared against the ground truth: y = TRUE_WEIGHT * x + TRUE_BIAS + noise.
SEED = 42
N_SAMPLES = 100
X_MAX = 10.0
TRUE_WEIGHT = 2.0
TRUE_BIAS = 3.0

torch.manual_seed(SEED)
X = torch.rand(N_SAMPLES, 1) * X_MAX  # N_SAMPLES points drawn uniformly from [0, X_MAX)
y = TRUE_WEIGHT * X + TRUE_BIAS + torch.randn(N_SAMPLES, 1)  # linear relationship plus standard-normal noise

class LinearRegressionModel(nn.Module):
    """Linear regression model: a single ``nn.Linear`` layer computing ``x @ W.T + b``.

    Args:
        in_features: Number of input features per sample (default 1, i.e.
            simple one-predictor regression).
        out_features: Number of outputs per sample (default 1).
    """

    def __init__(self, in_features: int = 1, out_features: int = 1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map to ``x``; the last dimension must equal ``in_features``."""
        return self.linear(x)

# Initialize the model, loss function, and optimizer
LEARNING_RATE = 0.01  # SGD step size
LOG_EVERY = 100       # print the training loss every LOG_EVERY epochs

model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Training loop: full-batch gradient descent on the mean-squared error
# (every epoch uses all of X at once).
epochs = 1000
for epoch in range(epochs):
    # Forward pass
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass and optimization
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()

    # Log every LOG_EVERY epochs, and always after the final epoch so the last
    # loss is reported even when epochs is not a multiple of LOG_EVERY.
    # Note: the value shown is the loss computed *before* this epoch's update.
    if (epoch + 1) % LOG_EVERY == 0 or epoch + 1 == epochs:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

Epoch [100/1000], Loss: 1.6039
Epoch [200/1000], Loss: 1.0242
Epoch [300/1000], Loss: 0.8017
Epoch [400/1000], Loss: 0.7163
Epoch [500/1000], Loss: 0.6836
Epoch [600/1000], Loss: 0.6710
Epoch [700/1000], Loss: 0.6662
Epoch [800/1000], Loss: 0.6643
Epoch [900/1000], Loss: 0.6636
Epoch [1000/1000], Loss: 0.6634


In [10]:
# Display the learned parameters. Read them by name instead of unpacking
# model.linear.parameters(), which depends on registration order and on the
# layer having a bias. `.item()` needs a single-element tensor, which holds for
# the default 1-input / 1-output model.
w = model.linear.weight
b = model.linear.bias
print(f"Learned weight: {w.item():.4f}, Learned bias: {b.item():.4f}")

# Testing on new data: inference only, so gradient tracking is disabled.
X_test = torch.tensor([[4.0], [7.0]])
with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

Learned weight: 1.9577, Learned bias: 3.2045
Predictions for [[4.0], [7.0]]: [[11.035286903381348], [16.90837860107422]]
