# Define Tutorial Model class

In [0]:
import torch
import torch.nn as nn

class TutorialModel(nn.Module):
    """Small feed-forward network: Linear(1, 2) -> sigmoid -> Linear(2, 1)."""

    def __init__(self):
        super(TutorialModel, self).__init__()
        # Layers are registered in this order, which is also the order in
        # which parameters() yields their weights and biases.
        self.linear1 = nn.Linear(in_features=1, out_features=2)
        self.linear2 = nn.Linear(in_features=2, out_features=1)

    def forward(self, x):
        """Run ``x`` through linear1, a sigmoid, then linear2.

        Args:
            x: Input tensor handed to ``linear1`` (which has one input
                feature).

        Returns:
            The output of ``linear2``.
        """
        hidden = torch.sigmoid(self.linear1(x))
        return self.linear2(hidden)

# Explore model

In [0]:
# Build the model; its nn.Linear layers start from random parameters.
model = TutorialModel()

# Show the whole module first, then the first layer and its parameters.
first_layer = model.linear1
print("model={}".format(model))
print("model_linear1={}".format(first_layer))
print("model_linear1_weights={}".format(first_layer.weight))
print("model_linear1_bias={}".format(first_layer.bias))

# Three samples with one feature each, laid out directly as a (3, 1) batch.
input_tensor = torch.FloatTensor([[0.0], [0.5], [1.0]])
model_output = model(input_tensor)
print("model()={}".format(model_output))

class TutorialModelPrint(nn.Module):
    """Same layer stack as TutorialModel, printing every intermediate value."""

    def __init__(self):
        super(TutorialModelPrint, self).__init__()
        # Same attribute names and registration order as TutorialModel.
        self.linear1 = nn.Linear(in_features=1, out_features=2)
        self.linear2 = nn.Linear(in_features=2, out_features=1)

    def forward(self, x):
        """Compute linear1 -> sigmoid -> linear2, printing each stage.

        The input and the result of each of the three steps are written to
        stdout before the final value is returned.
        """
        print("forward_x={}".format(x))

        pre_activation = self.linear1(x)
        print("forward_linear1={}".format(pre_activation))

        activation = torch.sigmoid(pre_activation)
        print("forward_sigmoid={}".format(activation))

        output = self.linear2(activation)
        print("forward_linear2={}".format(output))

        return output

# A second, independently initialized model that prints its forward pass.
print_model = TutorialModelPrint()

# The forward call emits its own stage-by-stage output before this summary.
print_model_output = print_model(input_tensor)
print("print_model()={}".format(print_model_output))

# Copy parameters

As we can see, `model` and `print_model` compute different values for the same input.
This is because they were initialized randomly and therefore have different parameters. Let's copy the parameters from `model` to `print_model`.

In [0]:
# Flatten every parameter of `model` into a single 1-D vector ...
model_parameters = nn.utils.parameters_to_vector(model.parameters())
# ... and write that vector into print_model's parameters, slice by slice.
nn.utils.vector_to_parameters(model_parameters, print_model.parameters())

# With the parameters copied, the two outputs are expected to agree.
model_output = model(input_tensor)
print("model()={}".format(model_output))
print_model_output = print_model(input_tensor)
print("print_model()={}".format(print_model_output))

# Rewrite model

In [0]:
class Linear12(nn.Module):
    """Hand-written equivalent of ``nn.Linear(1, 2)`` built from scalar parameters.

    Each output feature ``i`` is ``x * weight_i + bias_i``. The parameters are
    registered as weight1, weight2, bias1, bias2, which is the order in which
    ``parameters()`` yields them.
    """

    def __init__(self):
        super(Linear12, self).__init__()
        self.weight1 = nn.Parameter(torch.tensor(0.1))
        self.weight2 = nn.Parameter(torch.tensor(0.2))
        self.bias1 = nn.Parameter(torch.tensor(0.3))
        self.bias2 = nn.Parameter(torch.tensor(0.4))

    def forward(self, x):
        """Map a tensor of shape ``(..., 1)`` to shape ``(..., 2)``.

        Args:
            x: Tensor whose last dimension holds the single input feature.

        Returns:
            The two output features concatenated along the last dimension.
        """
        first = x * self.weight1 + self.bias1
        second = x * self.weight2 + self.bias2
        # Concatenate on the last (feature) axis rather than a hard-coded
        # axis 1. For 2-D input this is the same axis; it additionally makes
        # 1-D and higher-rank inputs behave like nn.Linear.
        return torch.cat((first, second), dim=-1)

class Linear21(nn.Module):
    """Hand-written equivalent of ``nn.Linear(2, 1)`` built from scalar parameters.

    The output is ``x0 * weight1 + x1 * weight2 + bias``, where ``x0`` and
    ``x1`` are the two input features. The parameters are registered as
    weight1, weight2, bias, which is the order ``parameters()`` yields them.
    """

    def __init__(self):
        super(Linear21, self).__init__()
        self.weight1 = nn.Parameter(torch.tensor(0.1))
        self.weight2 = nn.Parameter(torch.tensor(0.2))
        self.bias = nn.Parameter(torch.tensor(0.3))

    def forward(self, x):
        """Map a tensor of shape ``(..., 2)`` to shape ``(..., 1)``.

        Args:
            x: Tensor whose last dimension holds the two input features.

        Returns:
            The weighted sum of the two features plus the bias, with the
            feature dimension kept (size 1).
        """
        # Slice the last (feature) axis instead of a hard-coded axis 1, and
        # use ranges rather than integer indices so that axis is preserved.
        # For 2-D input this selects exactly the same elements as before.
        first_feature = x[..., 0:1]
        second_feature = x[..., 1:2]
        return first_feature * self.weight1 + second_feature * self.weight2 + self.bias

class TutorialModelRewrite(nn.Module):
    """TutorialModel rebuilt on the hand-written Linear12 and Linear21 layers."""

    def __init__(self):
        super(TutorialModelRewrite, self).__init__()
        # Same attribute names and registration order as TutorialModel.
        self.linear1 = Linear12()
        self.linear2 = Linear21()

    def forward(self, x):
        """Run ``x`` through linear1, a sigmoid, then linear2 and return it."""
        hidden = torch.sigmoid(self.linear1(x))
        return self.linear2(hidden)

# The rewritten model starts from its own hard-coded parameter values.
model_rewrite = TutorialModelRewrite()
model_output = model(input_tensor)
print("model()={}".format(model_output))
model_rewrite_output = model_rewrite(input_tensor)
print("model_rewrite()={}".format(model_rewrite_output))

# Flatten model's parameters and write them into model_rewrite. This relies
# on model_rewrite.parameters() yielding its scalar parameters in the same
# order as the flattened vector.
model_parameters = nn.utils.parameters_to_vector(model.parameters())
nn.utils.vector_to_parameters(model_parameters, model_rewrite.parameters())

# After the copy the two outputs are expected to agree.
model_output = model(input_tensor)
print("model()={}".format(model_output))
model_rewrite_output = model_rewrite(input_tensor)
print("model_rewrite()={}".format(model_rewrite_output))