In [1]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

Loading a model from its saved parameters (state_dict)

In [16]:
# Saved with `torch.save(model.state_dict(), PATH)`. A state_dict only holds
# tensors keyed by parameter name, so loading it requires a model with the same
# architecture (same layer names and shapes) -- not the same initialization.
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict later copies the values into the model's
# existing parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (on torch >= 1.13, consider weights_only=True).
model_parameters = t.load('./star_model.pth', map_location='cpu')

# Model

In [17]:
# A complete feed-forward model class with n layers and a built-in fit method.
class Model(nn.Module):
    """Fully connected feed-forward network that owns its loss and optimizer.

    Args:
        shapes: layer widths ``[in, hidden..., out]``; one ``nn.Linear`` is
            created between each pair of consecutive entries.
        activations: one callable (or ``None`` for no activation) per layer,
            applied to that layer's output. Must contain at least
            ``len(shapes) - 1`` entries; extra entries are ignored.
        loss_function: callable ``(predictions, targets) -> scalar tensor``.
        optimizer: optimizer class, instantiated as
            ``optimizer(self.parameters(), **optim_params)``.
        optim_params: keyword arguments for the optimizer (default: none).

    Raises:
        ValueError: if fewer activations than layers are supplied.
    """

    def __init__(self, shapes, activations, loss_function, optimizer, optim_params=None):
        super().__init__()
        n_layers = len(shapes) - 1
        # Fail early with a clear message instead of an IndexError in forward().
        if len(activations) < n_layers:
            raise ValueError(
                f'expected at least {n_layers} activations for shapes {shapes}, '
                f'got {len(activations)}'
            )
        self.shapes = shapes
        self.activations = activations
        self.loss_function = loss_function

        self.layers = nn.ModuleList([
            nn.Linear(shapes[i], shapes[i + 1]) for i in range(n_layers)
        ])

        # Built after the layers are registered so it receives every parameter.
        # A None default avoids sharing one mutable dict across all instances.
        self.optimizer = optimizer(self.parameters(), **(optim_params or {}))

    def forward(self, x):
        """Apply each linear layer, followed by its activation when not None."""
        y = x
        # zip is safe here: __init__ guarantees len(activations) >= len(layers).
        for layer, activation in zip(self.layers, self.activations):
            # Call the module (not .forward) so registered hooks still run.
            y = layer(y)
            if activation is not None:
                y = activation(y)
        return y

    def loss(self, x, y):
        """Return ``loss_function(self(x), y)``."""
        return self.loss_function(self(x), y)

    def fit(self, x_train, y_train, x_test, y_test, epochs=100, log_every=100):
        """Full-batch training loop.

        Each epoch takes one optimizer step on the whole training set, then
        evaluates the loss on the test set without tracking gradients.

        Args:
            x_train, y_train: training inputs and targets.
            x_test, y_test: held-out inputs and targets used for the reported loss.
            epochs: number of optimizer steps.
            log_every: print the test loss every ``log_every`` epochs
                (``0`` or ``None`` disables printing).

        Returns:
            list[float]: test loss recorded after each epoch.
        """
        losses = []

        for epoch in range(epochs):
            # Clear gradients *before* backward so gradients accumulated outside
            # fit() never leak into the first update.
            self.optimizer.zero_grad()
            train_loss = self.loss(x_train, y_train)
            train_loss.backward()
            self.optimizer.step()

            # Evaluate on the test set without building an autograd graph.
            with t.no_grad():
                test_loss = self.loss(x_test, y_test).item()
            losses.append(test_loss)

            if log_every and epoch % log_every == 0:
                print(f'Epoch {epoch}: {test_loss}')

        return losses

# Loading the parameters

In [18]:
# Rebuild the same architecture the checkpoint was saved from: a state_dict
# stores tensors by parameter name, so the layer sizes must match [3, 10, 6].
m = Model(
    shapes=[3, 10, 6],
    activations=[F.relu, None],
    loss_function=F.cross_entropy,
    optimizer=optim.Adam,
    optim_params={'lr': 0.01},
)

# Copy the saved tensors into the freshly built model; the cell displays the
# key-matching report returned by load_state_dict.
m.load_state_dict(model_parameters)

<All keys matched successfully>