In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
print(device)

cuda


In [4]:
data = pd.read_csv("https://raw.githubusercontent.com/hsiav2000/simple-regression/master/Salary_Data.csv")

In [5]:
data.corr()

Unnamed: 0,YearsExperience,Salary
YearsExperience,1.0,0.978242
Salary,0.978242,1.0


In [6]:
X = torch.tensor(data.YearsExperience).float().view((30,1)).to(device)
y = torch.tensor(data.Salary).float().view((30,1)).to(device)

In [7]:
X

tensor([[ 1.1000],
        [ 1.3000],
        [ 1.5000],
        [ 2.0000],
        [ 2.2000],
        [ 2.9000],
        [ 3.0000],
        [ 3.2000],
        [ 3.2000],
        [ 3.7000],
        [ 3.9000],
        [ 4.0000],
        [ 4.0000],
        [ 4.1000],
        [ 4.5000],
        [ 4.9000],
        [ 5.1000],
        [ 5.3000],
        [ 5.9000],
        [ 6.0000],
        [ 6.8000],
        [ 7.1000],
        [ 7.9000],
        [ 8.2000],
        [ 8.7000],
        [ 9.0000],
        [ 9.5000],
        [ 9.6000],
        [10.3000],
        [10.5000]], device='cuda:0')

In [8]:
y

tensor([[ 39343.],
        [ 46205.],
        [ 37731.],
        [ 43525.],
        [ 39891.],
        [ 56642.],
        [ 60150.],
        [ 54445.],
        [ 64445.],
        [ 57189.],
        [ 63218.],
        [ 55794.],
        [ 56957.],
        [ 57081.],
        [ 61111.],
        [ 67938.],
        [ 66029.],
        [ 83088.],
        [ 81363.],
        [ 93940.],
        [ 91738.],
        [ 98273.],
        [101302.],
        [113812.],
        [109431.],
        [105582.],
        [116969.],
        [112635.],
        [122391.],
        [121872.]], device='cuda:0')

In [9]:
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        
    def forward(self, x):
        pred = self.linear(x)
        return pred

In [10]:
model = LinearRegression(1, 1).to(device)

In [11]:
list(model.parameters())

[Parameter containing:
 tensor([[0.2449]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([-0.9265], device='cuda:0', requires_grad=True)]

In [12]:
y_hat = model(X)

In [13]:
model.state_dict()

OrderedDict([('linear.weight', tensor([[0.2449]], device='cuda:0')),
             ('linear.bias', tensor([-0.9265], device='cuda:0'))])

In [14]:
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [15]:
loss_fn  = nn.MSELoss()

In [16]:
n_epochs = 1000

In [17]:
for epoch in range(n_epochs):
    model.train()

    y_hat = model(X)
    loss = loss_fn(y_hat, y)

    if epoch % 100 == 0:
        print(f"Log Loss (SSE/MSE): {np.log(loss.item())}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Log Loss (SSE/MSE): 22.59550827630288
Log Loss (SSE/MSE): 18.24642668520769
Log Loss (SSE/MSE): 17.80188616847821
Log Loss (SSE/MSE): 17.527763804187515
Log Loss (SSE/MSE): 17.382644210024363
Log Loss (SSE/MSE): 17.313408304191135
Log Loss (SSE/MSE): 17.282220798442847
Log Loss (SSE/MSE): 17.268559571447955
Log Loss (SSE/MSE): 17.262650242963158
Log Loss (SSE/MSE): 17.260108750140564


In [18]:
model.state_dict()

OrderedDict([('linear.weight', tensor([[9501.0146]], device='cuda:0')),
             ('linear.bias', tensor([25448.1719], device='cuda:0'))])

In [20]:
list(zip(model(X).squeeze().detach().cpu().numpy(), y.squeeze().detach().cpu().numpy()))

[(35899.29, 39343.0),
 (37799.492, 46205.0),
 (39699.695, 37731.0),
 (44450.203, 43525.0),
 (46350.406, 39891.0),
 (53001.117, 56642.0),
 (53951.215, 60150.0),
 (55851.418, 54445.0),
 (55851.418, 64445.0),
 (60601.926, 57189.0),
 (62502.13, 63218.0),
 (63452.23, 55794.0),
 (63452.23, 56957.0),
 (64402.332, 57081.0),
 (68202.734, 61111.0),
 (72003.15, 67938.0),
 (73903.34, 66029.0),
 (75803.555, 83088.0),
 (81504.16, 81363.0),
 (82454.26, 93940.0),
 (90055.07, 91738.0),
 (92905.375, 98273.0),
 (100506.19, 101302.0),
 (103356.49, 113812.0),
 (108107.0, 109431.0),
 (110957.305, 105582.0),
 (115707.81, 116969.0),
 (116657.914, 112635.0),
 (123308.625, 122391.0),
 (125208.83, 121872.0)]