In [None]:
#Motivation: https://github.com/rslim087a/PyTorch-for-Deep-Learning-and-Computer-Vision-Course-All-Codes-
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn import datasets
from IPython import display

In [None]:
n_pts = 500

# Noisy two-class "circles" toy dataset; the fixed random_state keeps the
# sample identical across kernel restarts.
X, y = datasets.make_circles(
    n_samples=n_pts,
    noise=0.1,
    factor=0.4,
    random_state=123,
)

# Float tensors for training. Labels become a column of shape (n_pts, 1).
x_data = torch.Tensor(X)
y_data = torch.Tensor(y.reshape(-1, 1))

In [None]:
def scatter_plot():
    """Scatter the module-level samples ``X``, one series per label in ``y``.

    Label 0 is drawn first and label 1 second, so each class gets its own
    default matplotlib colour on the current axes.
    """
    for label in (0, 1):
        members = X[y == label]
        plt.scatter(members[:, 0], members[:, 1])

In [None]:
# Show the raw two-class data before any training takes place.
scatter_plot()

In [None]:
class Model(nn.Module):
    """Two-layer fully connected network with sigmoid activations.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Number of units in the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear_1 = nn.Linear(input_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply both linear layers, each followed by a sigmoid.

        Args:
            x: Input tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor of sigmoid activations (values between 0 and 1) whose last
            dimension equals ``output_size``.
        """
        x = torch.sigmoid(self.linear_1(x))
        x = torch.sigmoid(self.linear_2(x))
        return x

    def predict(self, x):
        """Threshold the network output at 0.5 to obtain hard class labels.

        Args:
            x: A single sample or a batch of samples.

        Returns:
            The Python int ``1`` or ``0`` when the network produces exactly
            one output value (the original single-sample behaviour).
            Otherwise an integer tensor of 0/1 labels with the same shape as
            the network output.
        """
        # Inference only: no autograd graph is needed for a hard decision.
        with torch.no_grad():
            # Call the module itself (not .forward) so registered hooks run.
            pred = self(x)

        if pred.numel() == 1:
            return 1 if pred.item() >= 0.5 else 0

        # `if pred >= 0.5` is ambiguous for more than one element, so
        # threshold element-wise instead of raising.
        return (pred >= 0.5).long()

In [None]:
# Seed PyTorch's RNG before building the model so the initial weights are
# the same on every run.
torch.manual_seed(1)
model = Model(input_size=2, hidden_size=10, output_size=1)

# Inspect the freshly initialised weights and biases.
print([param for param in model.parameters()])

In [None]:
# Binary cross-entropy on the model's sigmoid outputs.
criterion = nn.BCELoss()

# Adam over every trainable parameter of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.1)

In [None]:
def plot_dec_boundary(X, y):
    """Draw contour lines of the model output over the bounding box of ``X``.

    The module-level ``model`` is evaluated on a regular grid spanning the
    min/max of the first two columns of ``X``, and the result is drawn with
    ``plt.contour`` on the current axes.

    Args:
        X: 2-D NumPy array; columns 0 and 1 are used as plot coordinates.
        y: Unused. Kept in the signature so existing callers keep working.
    """
    # Vectorised ndarray min/max instead of Python's element-wise builtins.
    x_span = np.linspace(X[:, 0].min(), X[:, 0].max())
    y_span = np.linspace(X[:, 1].min(), X[:, 1].max())
    xx, yy = np.meshgrid(x_span, y_span)

    # One row per grid point: (x, y) coordinate pairs.
    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()])

    # Plotting is inference only, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        pred_func = model(grid)

    z = pred_func.view(xx.shape).numpy()
    plt.contour(xx, yy, z)

In [None]:
epochs = 1000
losses = []
for epoch in range(epochs):
    # Call the module (not .forward) so registered hooks run.
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)

    # Store a plain float: appending the tensor keeps every epoch's autograd
    # graph alive and cannot be plotted by matplotlib afterwards.
    losses.append(loss.item())

    # Standard update: clear stale gradients, backpropagate, take a step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Live view of the boundary; the output is replaced once the next
    # frame is ready (wait=True).
    plot_dec_boundary(X, y)
    scatter_plot()
    plt.show()
    display.clear_output(wait=True)

In [None]:
# Convert every recorded loss to a plain float so the curve can be drawn
# whether `losses` holds floats or scalar tensors that still require grad.
loss_values = [float(loss_value) for loss_value in losses]

# Size the x-axis from the data so it always matches the number of recorded
# losses, even if training stopped early.
plt.plot(range(len(loss_values)), loss_values)
plt.xlabel('epochs')
plt.ylabel('loss')