In [45]:
import torch
import torch.nn as nn
from torch import optim
from lbfgsb_scipy import LBFGSBScipy

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Define the three layers
        torch.manual_seed(42)
        # self.layer1 = nn.Linear(10, 5, bias = False)  # Example sizes, adjust as needed
        # self.layer2 = nn.Linear(5, 3, bias = False)   # Example sizes, adjust as needed
        # self.layer3 = nn.Linear(3, 1, bias = False)   # Example sizes, adjust as needed
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(5, 1)
        self.weight = torch.einsum('ij, m -> ijm') #[10, 10, 5]
        # Freeze the first layer by setting requires_grad to False for its parameters
        for param in self.layer1.parameters():
            param.requires_grad = False

    def forward(self, x):
        # x = torch.relu(self.layer1(x))
        # x = torch.relu(self.layer2(x))
        # x = self.layer3(x)
        x = torch.einsum('i, ijm -> m')
        return x
    
    def get_parameter(self) -> nn.Parameter:
        return self.layer1.weight, self.layer2.weight, self.layer3.weight


In [46]:
x = torch.randn([100, 10])
target = 10*torch.rand([100, 5])
mse_loss = nn.MSELoss()
model = MyModel()
print("layer1: ", model.get_parameter()[0])
print("layer2: ", model.get_parameter()[1])
print("layer3: ", model.get_parameter()[2])

ValueError: einsum(): must specify the equation string and at least one operand, or at least one operand and its subscripts list

: 

In [36]:
optimizer = LBFGSBScipy(model.parameters())
def closure():
    optimizer.zero_grad()
    output = model.forward(x)
    print(output.shape)
    loss = mse_loss(target, output)
    primal_obj = loss 
    primal_obj.backward()
    print('squared loss:', loss.item())
    return primal_obj
optimizer.step(closure)

torch.Size([100, 1])
squared loss: 37.2646598815918
torch.Size([100, 1])
squared loss: 33.05524444580078
torch.Size([100, 1])
squared loss: 32.082916259765625
torch.Size([100, 1])
squared loss: 27.392253875732422
torch.Size([100, 1])
squared loss: 21.160593032836914
torch.Size([100, 1])
squared loss: 21.274452209472656
torch.Size([100, 1])
squared loss: 20.836620330810547
torch.Size([100, 1])
squared loss: 22.703763961791992
torch.Size([100, 1])
squared loss: 20.285890579223633
torch.Size([100, 1])
squared loss: 20.261014938354492
torch.Size([100, 1])
squared loss: 20.150175094604492
torch.Size([100, 1])
squared loss: 20.06949806213379
torch.Size([100, 1])
squared loss: 20.00341796875
torch.Size([100, 1])
squared loss: 19.892196655273438
torch.Size([100, 1])
squared loss: 19.76331901550293
torch.Size([100, 1])
squared loss: 19.645845413208008
torch.Size([100, 1])
squared loss: 19.547870635986328
torch.Size([100, 1])
squared loss: 19.410829544067383
torch.Size([100, 1])
squared loss: 19

In [37]:
print("layer1: ", model.get_parameter()[0])
print("layer2: ", model.get_parameter()[1])
print("layer3: ", model.get_parameter()[2])

layer1:  Parameter containing:
tensor([[ 0.2418,  0.2625, -0.0741,  0.2905, -0.0693,  0.0638, -0.1540,  0.1857,
          0.2788, -0.2320],
        [ 0.2749,  0.0592,  0.2336,  0.0428,  0.1525, -0.0446,  0.2438,  0.0467,
         -0.1476,  0.0806],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2856, -0.2686,  0.2441,  0.0526, -0.1027,  0.1954,  0.0493,  0.2555,
          0.0346, -0.0997],
        [ 0.0850, -0.0858,  0.1331,  0.2823,  0.1828, -0.1382,  0.1825,  0.0566,
          0.1606, -0.1927]])
layer2:  Parameter containing:
tensor([[ -2.0491,   9.5738,   4.4370,   4.0391, -13.2500],
        [ 13.6810, -16.6725,  -6.6664,  -2.1218,  31.8371],
        [ -2.1259,  -0.3690,  -1.9344,  -0.4148,  -0.8183]],
       requires_grad=True)
layer3:  Parameter containing:
tensor([[1.7254, 0.2419, 0.8785]], requires_grad=True)


In [38]:
for par in model.layer1.parameters():
    print(par)

Parameter containing:
tensor([[ 0.2418,  0.2625, -0.0741,  0.2905, -0.0693,  0.0638, -0.1540,  0.1857,
          0.2788, -0.2320],
        [ 0.2749,  0.0592,  0.2336,  0.0428,  0.1525, -0.0446,  0.2438,  0.0467,
         -0.1476,  0.0806],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2856, -0.2686,  0.2441,  0.0526, -0.1027,  0.1954,  0.0493,  0.2555,
          0.0346, -0.0997],
        [ 0.0850, -0.0858,  0.1331,  0.2823,  0.1828, -0.1382,  0.1825,  0.0566,
          0.1606, -0.1927]])


In [40]:
for par in model.layer1.parameters():
    print(par[0,1])

tensor(0.2625)
