In [84]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [52]:

class BaseModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BaseModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    
class MetaModel(nn.Module):
    def __init__(self, d):
        super(MetaModel, self).__init__()
        self.fc1 = nn.Linear(d, 128)
        self.fc2 = nn.Linear(128, d)
    
    def forward(self, theta_flat):
        x = torch.relu(self.fc1(theta_flat))
        return self.fc2(x)

    


In [109]:

input_dim = 1
hidden_dim = 20
output_dim = 1
d = sum(p.numel() for p in BaseModel(input_dim, hidden_dim, output_dim).parameters()) +1

theta_f = BaseModel(input_dim, hidden_dim, output_dim)
meta_model = MetaModel(d)
optimizer = optim.Adam(meta_model.parameters(), lr=1e-2)

n=1000
X = torch.rand(n, input_dim)*4 -2
Y = torch.tanh(X)
# shuffle
perm = torch.randperm(n)
X = X[perm]
Y = Y[perm]
# random split
X_train, X_test = X[:int(n*0.8)], X[int(n*0.8):]
Y_train, Y_test = Y[:int(n*0.8)], Y[int(n*0.8):]

#criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()

for epoch in range(15):
    optimizer.zero_grad()

    # what if we reinit on each iteration? ie can meta model predict final weights from scratch
    # theta_f = BaseModel(input_dim, hidden_dim, output_dim)  
    
    theta_flat = torch.cat([p.flatten() for p in theta_f.parameters()]).requires_grad_(True)
    theta_flat = torch.cat([theta_flat, torch.tensor([epoch], dtype=torch.float32)]).requires_grad_(True)
    
    theta_flat_prime = meta_model(theta_flat)
    
    theta_f_prime = BaseModel(input_dim, hidden_dim, output_dim)
    
    params_dict = {}
    start_idx = 0
    for name, param in theta_f_prime.named_parameters():
        param_length = param.numel()
        params_dict[name] = theta_flat_prime[start_idx:start_idx + param_length].view_as(param)
        start_idx += param_length
    
    def modified_forward(x):
        # directly pass the weights into the forward pass. keeps the computation graph intact
        x = F.linear(x, 
                    weight=params_dict['fc1.weight'],
                    bias=params_dict['fc1.bias'])
        x = torch.relu(x)
        x = F.linear(x,
                    weight=params_dict['fc2.weight'],
                    bias=params_dict['fc2.bias'])
        return x
    
    outputs = modified_forward(X_train)
    loss = criterion(outputs, Y_train)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Train Loss {loss.item()}\n')
    
    if epoch == 0:
        make_dot(loss, params=dict(list(meta_model.named_parameters()))).render('comp_graph', format='png')
        

Epoch 0, Train Loss 0.692639172077179

Epoch 1, Train Loss 0.20576265454292297

Epoch 2, Train Loss 0.9491184949874878

Epoch 3, Train Loss 0.03838125243782997

Epoch 4, Train Loss 0.12317118793725967

Epoch 5, Train Loss 0.12265665829181671

Epoch 6, Train Loss 0.028017327189445496

Epoch 7, Train Loss 0.15350347757339478

Epoch 8, Train Loss 0.01366340834647417

Epoch 9, Train Loss 0.06852729618549347

Epoch 10, Train Loss 0.05370870232582092

Epoch 11, Train Loss 0.008035967126488686

Epoch 12, Train Loss 0.060914501547813416

Epoch 13, Train Loss 0.019319841638207436

Epoch 14, Train Loss 0.00864302460104227



In [112]:
# test
theta_flat = torch.cat([p.flatten() for p in theta_f.parameters()]).requires_grad_(True)
theta_flat = torch.cat([theta_flat, torch.tensor([epoch], dtype=torch.float32)]).requires_grad_(True)

theta_flat_prime = meta_model(theta_flat)

theta_f_prime = BaseModel(input_dim, hidden_dim, output_dim)

params_dict = {}
start_idx = 0

for name, param in theta_f_prime.named_parameters():
    param_length = param.numel()
    params_dict[name] = theta_flat_prime[start_idx:start_idx + param_length].view_as(param)
    start_idx += param_length

def modified_forward(x):
    # directly pass the weights into the forward pass. keeps the computation graph intact
    x = F.linear(x, 
                weight=params_dict['fc1.weight'],
                bias=params_dict['fc1.bias'])
    x = torch.relu(x)
    x = F.linear(x,
                weight=params_dict['fc2.weight'],
                bias=params_dict['fc2.bias'])
    return x

outputs = modified_forward(X_test)
loss = criterion(outputs, Y_test)
print(f'Test Loss {loss.item()}')


Test Loss 0.0432543121278286


In [110]:
model = BaseModel(input_dim, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
for epoch in range(15):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, Y)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Loss {loss.item()}\n')
    

Epoch 0, Loss 0.8670465350151062

Epoch 1, Loss 0.851166844367981

Epoch 2, Loss 0.8354771137237549

Epoch 3, Loss 0.8199802041053772

Epoch 4, Loss 0.8046783208847046

Epoch 5, Loss 0.7895739078521729

Epoch 6, Loss 0.7746680974960327

Epoch 7, Loss 0.7599632143974304

Epoch 8, Loss 0.7454613447189331

Epoch 9, Loss 0.7311623096466064

Epoch 10, Loss 0.7170668244361877

Epoch 11, Loss 0.7031756639480591

Epoch 12, Loss 0.6894916892051697

Epoch 13, Loss 0.6760116219520569

Epoch 14, Loss 0.6627349853515625

