In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from model import Model
import utils

# Getting data
For this task we will use a sample salary dataset.

In [2]:
# Sizes handed to utils.create_data for the two splits.
n_train = 1000
n_test = 200
# n_inputs is passed to utils.get_weights_and_bias; batch_size to the loaders.
n_inputs = 200
batch_size = 10

# NOTE(review): no RNG seed is set anywhere in this notebook; if the utils
# helpers draw random values, results will differ between runs — confirm.
weights, bias = utils.get_weights_and_bias(n_inputs)

# Both splits are built from the same weights and bias (train first, then
# test, so the call order is unchanged).
train_data, test_data = (
    utils.create_data(weights, bias, n_samples)
    for n_samples in (n_train, n_test)
)

# One loader per split, both using the same batch size.
train_loader, test_loader = (
    utils.get_dataloader(split, batch_size=batch_size)
    for split in (train_data, test_data)
)

# Training the model 
Now we will create the model and train it.

In [3]:
# Number of passes over train_loader used by the training cells below.
epochs = 100

# Mean squared error between model outputs and targets.
loss_fn = torch.nn.MSELoss()

# NOTE(review): the model is constructed from the same `weights` and `bias`
# that were passed to utils.create_data — confirm this is the intended
# starting point rather than a fresh initialisation.
model = Model(weights, bias)

# Plain SGD over every parameter the model exposes.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [4]:
def train(epochs, lambda_):
    """Train the notebook-level ``model`` with SGD and an explicit L2 penalty.

    For every batch the objective is
    ``loss_fn(outputs, targets) + lambda_ * sum(model.weights ** 2)``.

    The function reads the notebook-level ``model``, ``optimizer``,
    ``loss_fn`` and ``train_loader`` and updates ``model`` in place. It never
    re-creates the model, so calling it again continues from whatever
    parameters the previous call left behind.

    Args:
        epochs: Number of passes over ``train_loader``.
        lambda_: Coefficient of the squared-L2 penalty on ``model.weights``;
            ``0`` turns the penalty off.

    Prints:
        Every 10th epoch, the loss of the *last batch* of that epoch (not an
        epoch average), or a notice when the loader produced no batches.
        After the final epoch, the L2 norm of ``model.weights``.
    """
    for epoch in range(epochs):
        # Reset per epoch so the report below can never show a loss that
        # belongs to an earlier epoch (or reference an unbound name when the
        # loader is empty).
        last_loss = None

        for inputs, targets in train_loader:
            # detach() only has an effect if the loader yields tensors that
            # are still attached to an autograd graph.
            outputs = model(inputs.detach())
            optimizer.zero_grad()

            penalty = lambda_ * torch.sum(model.weights ** 2)
            loss = loss_fn(outputs, targets) + penalty

            # NOTE(review): retain_graph=True keeps the graph buffers alive
            # after backward. It is kept because Model's forward pass is not
            # visible here; it is presumably unnecessary unless Model reuses
            # graph nodes across batches — confirm before removing.
            loss.backward(retain_graph=True)
            optimizer.step()

            last_loss = loss.detach()

        if (epoch + 1) % 10 == 0:
            if last_loss is None:
                print(f'Epoch {epoch + 1}: no batches in train_loader')
            else:
                print(f'Epoch {epoch + 1}: loss: {last_loss.item()}')

    print('L2 norm of w:', torch.norm(model.weights).item())


In [5]:
# Baseline run: lambda_ = 0.0, so the L2 penalty term contributes nothing.
train(epochs=epochs, lambda_=0.0)


Epoch 10: loss: 0.00013110277359373868
Epoch 20: loss: 0.0001310967782046646
Epoch 30: loss: 0.00013110179861541837
Epoch 40: loss: 0.00013114980538375676
Epoch 50: loss: 0.00013111726730130613
Epoch 60: loss: 0.00013116348418407142
Epoch 70: loss: 0.00013114817556925118
Epoch 80: loss: 0.00013111307634972036
Epoch 90: loss: 0.0001311299711233005
Epoch 100: loss: 0.00013109479914419353
L2 norm of w: 13.539770126342773


In [6]:
# Regularised run with lambda_ = 2. train() updates the shared notebook-level
# model in place and the model is not re-created here, so this run continues
# from the parameters left by the previous train() call.
train(epochs=epochs, lambda_=2)

Epoch 10: loss: 158.75051879882812
Epoch 20: loss: 158.75051879882812
Epoch 30: loss: 158.75051879882812
Epoch 40: loss: 158.75051879882812
Epoch 50: loss: 158.75051879882812
Epoch 60: loss: 158.75051879882812
Epoch 70: loss: 158.75050354003906
Epoch 80: loss: 158.75050354003906
Epoch 90: loss: 158.75051879882812
Epoch 100: loss: 158.75051879882812
L2 norm of w: 4.824517250061035
