# Sample usage of the Segment model (segment.py): fitting a quartic curve

In [16]:
from segment import Segment
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Create Test Dataset

In [8]:
def f_quartic(x):
    """Evaluate a fixed quartic polynomial at ``x``.

    Computes ``c4*x**4 + c3*x**3 + c2*x**2 + c1*x + c0`` using the hard-coded
    coefficients below. Only ``*``, ``**`` and ``+`` are used, so ``x`` may be
    a Python number or a tensor (evaluated elementwise).
    """
    # Coefficients, highest degree first.
    c4, c3, c2, c1, c0 = -0.0179516, 0.331323, -1.63398, 1.01107, 5.73434
    return c4 * x**4 + c3 * x**3 + c2 * x**2 + c1 * x + c0

def normalize(x, y):
    """Return ``(x, y)`` each rescaled to unit L2 norm along dimension 0.

    The two tensors are normalized independently with
    ``torch.nn.functional.normalize`` (default p=2, dim=0). For 1-D inputs,
    each result is therefore the input divided by its own Euclidean norm.
    """
    unit_l2 = torch.nn.functional.normalize
    return unit_l2(x, dim=0), unit_l2(y, dim=0)

In [22]:
# Sample the quartic on a uniform grid over [-1.5, 11) with spacing 0.05.
x = torch.arange(-1.5, 11.0, 0.05)
ytest = f_quartic(x)

# Rescale inputs and targets (each to unit L2 norm via normalize);
# the author notes this step matters for training.
x, ytest = normalize(x, ytest)
print(x.shape, ytest.shape)

# Add a trailing feature axis so both tensors are laid out as (N, 1).
x = x.unsqueeze(1)
ytest = ytest.unsqueeze(1)
print(x.shape, ytest.shape)

torch.Size([250]) torch.Size([250])
torch.Size([250, 1]) torch.Size([250, 1])


# Model Init

In [34]:
# Hyperparameters for this run.
lr = 1e-4
batch_size = 64

# Seed torch's global RNG before building the model. This keeps any random
# parameter initialization inside Segment reproducible, and also the
# DataLoader's shuffling, which draws from the same global generator.
torch.manual_seed(10)

# Feature counts come from the (N, 1) tensors built earlier. The third
# argument's meaning is defined in segment.py (presumably the number of
# segments -- TODO confirm).
model = Segment(x.size(1), ytest.size(1), 5)

# Important: initialize parameters from the per-feature [min, max] range of
# the training inputs (see Segment.custom_init).
model.custom_init(x.amin(dim=0), x.amax(dim=0))

# reduction='sum': each batch loss is the total squared error, not the mean.
criterion = nn.MSELoss(reduction='sum')

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# Wrap the (x, ytest) pairs so DataLoader can serve shuffled mini-batches.
dataset = torch.utils.data.TensorDataset(x, ytest)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


## Training Loop

In [35]:
num_epochs = 100

for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0  # summed loss over every sample seen this epoch

    for X, Y in dataloader:
        ypred = model(X)
        loss = criterion(ypred, Y)
        # .item() converts to a Python float so the autograd graph isn't retained.
        train_loss += loss.item()

        # Discard old gradients (set to None), backpropagate, update weights.
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # Report only every 10th epoch.
    if epoch % 10 != 0:
        continue
    # Assumes criterion sums over the batch (reduction='sum'). Dividing by
    # the dataset size then gives the mean per-sample loss for the epoch.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
        

====> Epoch: 10 Average loss: 0.0036
====> Epoch: 20 Average loss: 0.0033
====> Epoch: 30 Average loss: 0.0029
====> Epoch: 40 Average loss: 0.0026
====> Epoch: 50 Average loss: 0.0023
====> Epoch: 60 Average loss: 0.0020
====> Epoch: 70 Average loss: 0.0017
====> Epoch: 80 Average loss: 0.0015
====> Epoch: 90 Average loss: 0.0014
====> Epoch: 100 Average loss: 0.0013
