In [1]:
import torch
import numpy as np
import pickle
from hyper import HyperNetwork

In [2]:
# Record the torch version this notebook was run with (shown as the cell output).
torch_version = torch.__version__
torch_version

'2.0.0'

In [3]:
# Seed every RNG the notebook draws from: numpy (used by np.random.choice for
# mini-batch sampling in the training loop) and torch (module weight init).
# Previously only torch was seeded, so batch selection was not reproducible.
SEED = 1943
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f9c0005deb0>

In [4]:
# Pricing network: 3 input features -> 100 hidden ReLU units -> 1 output.
# The layers are listed in the same order as a plain Sequential, so the
# state_dict keys stay '0.weight', '0.bias', '2.weight', '2.bias'.
pricing_layers = [
    torch.nn.Linear(in_features=3, out_features=100),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=100, out_features=1),
]
pricing_network = torch.nn.Sequential(*pricing_layers)

In [5]:
# compute the total number of parameters in the pricing network
n_params = sum([p.numel() for p in pricing_network.parameters()])
n_params

501

In [6]:
# Hyper-network: 4 inputs -> 100 hidden ReLU units -> one value per
# pricing-network parameter (n_params outputs).
# NOTE(review): the 4 inputs are presumably the per-sample 'param' vector
# (theta) passed to the model — confirm against HyperNetwork.
hyper_layers = [
    torch.nn.Linear(in_features=4, out_features=100),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=100, out_features=n_params),
]
hyper_network = torch.nn.Sequential(*hyper_layers)

In [7]:
# Adam over the hyper-network's parameters only; pricing_network's own
# parameters are not handed to the optimiser.
optim = torch.optim.Adam(params=hyper_network.parameters(), lr=0.001)

In [8]:
# Combine the two networks into the model that is trained below.
# NOTE(review): presumably HyperNetwork maps theta through hyper_network and
# uses its n_params outputs as pricing_network's weights — confirm in hyper.py.
networks = (hyper_network, pricing_network)
model = HyperNetwork(*networks)

In [9]:
# Train the hyper-network. Each step loads one pickled file of samples, draws a
# random mini-batch from it and minimises the mean absolute error (MAE) between
# the model's prediction and the stored outputs.
# NOTE(review): pickle.load can execute arbitrary code — only load trn_*.pkl
# files that come from a trusted source.
N_EPOCHS = 100
N_FILES = 32 * 32 * 32  # number of trn_%09d.pkl files; one optimiser step per file
BATCH_SIZE = 16         # samples drawn without replacement from each file
LOG_EVERY = 100         # print the loss every LOG_EVERY files
SMOKE_TEST = True       # stop after the first step; set to False for the full run

for epoch in range(N_EPOCHS):
    for file_idx in range(N_FILES):
        # `with` guarantees the file handle is closed after loading.
        with open('trn_%09d.pkl' % file_idx, 'rb') as f:
            training_data = pickle.load(f)
        # Each record is indexed by the keys 'param', 'input' and 'output'.
        theta = torch.from_numpy(np.vstack([s['param'] for s in training_data]).astype(np.float32))
        x = torch.from_numpy(np.stack([s['input'] for s in training_data]).astype(np.float32))
        y = torch.from_numpy(np.stack([s['output'] for s in training_data]).astype(np.float32))

        # Sample from however many records the file actually holds (instead of
        # assuming exactly 32), capping the batch so small files don't raise.
        n_samples = len(training_data)
        rnd_idx = np.random.choice(n_samples, min(BATCH_SIZE, n_samples), replace=False)
        theta_, x_, y_ = theta[rnd_idx], x[rnd_idx], y[rnd_idx]

        optim.zero_grad()
        loss = torch.abs(model(theta_, x_) - y_).mean()  # MAE
        loss.backward()
        optim.step()

        if file_idx % LOG_EVERY == 0:
            print(epoch, file_idx, loss.item())
        if SMOKE_TEST:
            break
    if SMOKE_TEST:
        break

0 0 1.465867


In [10]:
# Persist only the hyper-network's weights; it is the module whose parameters
# the Adam optimiser updates.
HYPER_CHECKPOINT = 'trained_hyper.pth'
torch.save(hyper_network.state_dict(), HYPER_CHECKPOINT)