In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import torch.nn.functional as F

In [52]:

class BaseModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BaseModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    
class MetaModel(nn.Module):
    def __init__(self, d):
        super(MetaModel, self).__init__()
        self.fc1 = nn.Linear(d, 128)
        self.fc2 = nn.Linear(128, d)
    
    def forward(self, theta_flat):
        x = torch.relu(self.fc1(theta_flat))
        return self.fc2(x)

    


In [55]:

input_dim = 10
hidden_dim = 20
output_dim = 2
d = sum(p.numel() for p in BaseModel(input_dim, hidden_dim, output_dim).parameters())

theta_f = BaseModel(input_dim, hidden_dim, output_dim)
meta_model = MetaModel(d)
optimizer = optim.Adam(meta_model.parameters(), lr=1e-5)

X = torch.randn(32, input_dim)
Y = torch.randint(0, output_dim, (32,))

criterion = nn.CrossEntropyLoss()

for epoch in range(10000):
    optimizer.zero_grad()

    # what if we reinit on each iteration? ie can meta model predict final weights from scratch
    # theta_f = BaseModel(input_dim, hidden_dim, output_dim)  
    
    theta_flat = torch.cat([p.flatten() for p in theta_f.parameters()]).requires_grad_(True)
    
    theta_flat_prime = meta_model(theta_flat)
    
    theta_f_prime = BaseModel(input_dim, hidden_dim, output_dim)
    
    params_dict = {}
    start_idx = 0
    for name, param in theta_f_prime.named_parameters():
        param_length = param.numel()
        params_dict[name] = theta_flat_prime[start_idx:start_idx + param_length].view_as(param)
        start_idx += param_length
    
    def modified_forward(x):
        # directly pass the weights into the forward pass. keeps the computation graph intact
        x = F.linear(x, 
                    weight=params_dict['fc1.weight'],
                    bias=params_dict['fc1.bias'])
        x = torch.relu(x)
        x = F.linear(x,
                    weight=params_dict['fc2.weight'],
                    bias=params_dict['fc2.bias'])
        return x
    
    outputs = modified_forward(X)
    loss = criterion(outputs, Y)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Loss {loss.item()}\n')
    
    if epoch == 0:
        make_dot(loss, params=dict(list(meta_model.named_parameters()))).render('comp_graph', format='png')
        

Epoch 0, Loss 0.7295199632644653

Epoch 1, Loss 0.7291213274002075

Epoch 2, Loss 0.7287271618843079

Epoch 3, Loss 0.7283346652984619

Epoch 4, Loss 0.7279444932937622

Epoch 5, Loss 0.7275564670562744

Epoch 6, Loss 0.7271707057952881

Epoch 7, Loss 0.7267872095108032

Epoch 8, Loss 0.7264066338539124

Epoch 9, Loss 0.7260285019874573

Epoch 10, Loss 0.7256526350975037

Epoch 11, Loss 0.7252822518348694

Epoch 12, Loss 0.7249138951301575

Epoch 13, Loss 0.724547803401947

Epoch 14, Loss 0.724185049533844

Epoch 15, Loss 0.723827064037323

Epoch 16, Loss 0.7234711050987244

Epoch 17, Loss 0.7231169939041138

Epoch 18, Loss 0.7227688431739807

Epoch 19, Loss 0.7224231362342834

Epoch 20, Loss 0.7220796942710876

Epoch 21, Loss 0.721738338470459

Epoch 22, Loss 0.7213997840881348

Epoch 23, Loss 0.7210659384727478

Epoch 24, Loss 0.7207345962524414

Epoch 25, Loss 0.720405101776123

Epoch 26, Loss 0.7200776934623718

Epoch 27, Loss 0.7197527885437012

Epoch 28, Loss 0.7194299697875977



In [54]:
model = BaseModel(input_dim, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(10000):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, Y)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Loss {loss.item()}\n')
    

Epoch 0, Loss 0.6962472796440125

Epoch 1, Loss 0.693117618560791

Epoch 2, Loss 0.6900225281715393

Epoch 3, Loss 0.6869622468948364

Epoch 4, Loss 0.683936595916748

Epoch 5, Loss 0.680945873260498

Epoch 6, Loss 0.67799311876297

Epoch 7, Loss 0.6750822067260742

Epoch 8, Loss 0.6722074747085571

Epoch 9, Loss 0.6693670153617859

Epoch 10, Loss 0.6665613651275635

Epoch 11, Loss 0.6638321280479431

Epoch 12, Loss 0.6611341834068298

Epoch 13, Loss 0.658467710018158

Epoch 14, Loss 0.6558295488357544

Epoch 15, Loss 0.6532254815101624

Epoch 16, Loss 0.650647759437561

Epoch 17, Loss 0.6481116414070129

Epoch 18, Loss 0.6456438899040222

Epoch 19, Loss 0.6431986689567566

Epoch 20, Loss 0.640782356262207

Epoch 21, Loss 0.6383910179138184

Epoch 22, Loss 0.636021614074707

Epoch 23, Loss 0.6336745619773865

Epoch 24, Loss 0.6313430070877075

Epoch 25, Loss 0.6290574669837952

Epoch 26, Loss 0.626835823059082

Epoch 27, Loss 0.6246341466903687

Epoch 28, Loss 0.6224398612976074

Epoch