In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [17]:

class BaseModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BaseModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)



        seed = torch.randint(0, 1000, (1,)).item()
        init_weights(self.fc1, seed)
        init_weights(self.fc2, seed)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    
class MetaModel(nn.Module):
    def __init__(self, d):
        super(MetaModel, self).__init__()
        self.fc1 = nn.Linear(d, 128)
        self.fc2 = nn.Linear(128, d)
    
    def forward(self, theta_flat):
        x = torch.relu(self.fc1(theta_flat))
        return self.fc2(x)

    
def init_weights(module, seed=0):
    torch.manual_seed(seed)
    # init weights with xavier 
    if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)
    #print(f"Init weights with seed {seed}")



In [None]:

input_dim = 1
hidden_dim = 20
output_dim = 1
d = sum(p.numel() for p in BaseModel(input_dim, hidden_dim, output_dim).parameters()) +1

n=1000
X = torch.rand(n, input_dim)*4 -2
Y = torch.tanh(X)
# shuffle
perm = torch.randperm(n)
X = X[perm]
Y = Y[perm]
# random split
X_train, X_test = X[:int(n*0.8)], X[int(n*0.8):]
Y_train, Y_test = Y[:int(n*0.8)], Y[int(n*0.8):]


Init weights with seed 961
Init weights with seed 961
Init weights with seed 772
Init weights with seed 772


In [None]:
# 1 model
theta_f = BaseModel(input_dim, hidden_dim, output_dim)
meta_model = MetaModel(d)
optimizer = optim.Adam(meta_model.parameters(), lr=1e-2)
#criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()

for epoch in range(15):
    optimizer.zero_grad()

    # what if we reinit on each iteration? ie can meta model predict final weights from scratch
    # theta_f = BaseModel(input_dim, hidden_dim, output_dim)  
    
    theta_flat = torch.cat([p.flatten() for p in theta_f.parameters()]).requires_grad_(True)
    theta_flat = torch.cat([theta_flat, torch.tensor([epoch], dtype=torch.float32)]).requires_grad_(True)
    
    theta_flat_prime = meta_model(theta_flat)
    
    theta_f_prime = BaseModel(input_dim, hidden_dim, output_dim)
    
    params_dict = {}
    start_idx = 0
    for name, param in theta_f_prime.named_parameters():
        param_length = param.numel()
        params_dict[name] = theta_flat_prime[start_idx:start_idx + param_length].view_as(param)
        start_idx += param_length
    
    def modified_forward(x):
        # directly pass the weights into the forward pass. keeps the computation graph intact
        x = F.linear(x, 
                    weight=params_dict['fc1.weight'],
                    bias=params_dict['fc1.bias'])
        x = torch.relu(x)
        x = F.linear(x,
                    weight=params_dict['fc2.weight'],
                    bias=params_dict['fc2.bias'])
        return x
    
    outputs = modified_forward(X_train)
    loss = criterion(outputs, Y_train)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Train Loss {loss.item()}\n')
    
    if epoch == 0:
        make_dot(loss, params=dict(list(meta_model.named_parameters()))).render('comp_graph', format='png')
        

In [24]:
#k models
k = 500
batch_size = 8
optimizees = [BaseModel(input_dim, hidden_dim, output_dim) for _ in range(k)]
meta_model = MetaModel(d).to('cuda')
optimizer = optim.Adam(meta_model.parameters(), lr=1e-2)
#criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
X_train = X_train.to('cuda')
Y_train = Y_train.to('cuda')

for epoch in range(100):
    for i in range(0, k, batch_size):
        batch = optimizees[i:i+batch_size]
        batch_flattened = [torch.cat([torch.cat([p.flatten() for p in model.parameters()]), torch.tensor([epoch], dtype=torch.float32)]) for model in batch]
        batch_flattened = torch.stack(batch_flattened).to('cuda')

        optimizer.zero_grad()

        theta_flat_prime = meta_model(batch_flattened)
        
        batch_prime = [BaseModel(input_dim, hidden_dim, output_dim) for _ in range(batch_size)]
        params_dict = {}
        start_idx = 0
        for name, param in batch_prime[0].named_parameters():
            param_length = param.numel()
            try:
                params_dict[name] = theta_flat_prime[:, start_idx:start_idx + param_length].view(batch_size, *param.shape)
                # fails when parm shape [20,1]
            except:
                x=1
            start_idx += param_length

        def modified_forward(x, model, idx):
            # directly pass the weights
            x = F.linear(x, 
                        weight=params_dict['fc1.weight'][idx],
                        bias=params_dict['fc1.bias'][idx])
            x = torch.relu(x)
            x = F.linear(x,
                        weight=params_dict['fc2.weight'][idx],
                        bias=params_dict['fc2.bias'][idx])
            return x
        
        
        outputs = [modified_forward(X_train, model, idx) for idx, model in enumerate(batch_prime)]
        loss = torch.stack([criterion(output, Y_train) for output in outputs]).mean()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, Train Loss {loss.item()}')
        



Epoch 0, Train Loss 0.5187199115753174
Epoch 0, Train Loss 0.2825659513473511
Epoch 0, Train Loss 0.021471483632922173
Epoch 0, Train Loss 0.3502190113067627
Epoch 0, Train Loss 0.021631114184856415
Epoch 0, Train Loss 0.10376334190368652
Epoch 0, Train Loss 0.1685345470905304
Epoch 0, Train Loss 0.18142744898796082
Epoch 0, Train Loss 0.15931108593940735
Epoch 0, Train Loss 0.11640456318855286
Epoch 0, Train Loss 0.06744590401649475
Epoch 0, Train Loss 0.03233504295349121
Epoch 0, Train Loss 0.02963540330529213
Epoch 0, Train Loss 0.04857634752988815
Epoch 0, Train Loss 0.05411198362708092
Epoch 0, Train Loss 0.03694517910480499
Epoch 0, Train Loss 0.019580623134970665
Epoch 0, Train Loss 0.01593651995062828
Epoch 0, Train Loss 0.021632686257362366
Epoch 0, Train Loss 0.02664417028427124
Epoch 0, Train Loss 0.026548776775598526
Epoch 0, Train Loss 0.024729304015636444
Epoch 0, Train Loss 0.022818565368652344
Epoch 0, Train Loss 0.01745016500353813
Epoch 0, Train Loss 0.008907224982976

KeyError: 'fc1.weight'

In [110]:
model = BaseModel(input_dim, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
for epoch in range(15):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, Y)
    
    loss.backward()
    
    optimizer.step()
    print(f'Epoch {epoch}, Loss {loss.item()}\n')
    

Epoch 0, Loss 0.8670465350151062

Epoch 1, Loss 0.851166844367981

Epoch 2, Loss 0.8354771137237549

Epoch 3, Loss 0.8199802041053772

Epoch 4, Loss 0.8046783208847046

Epoch 5, Loss 0.7895739078521729

Epoch 6, Loss 0.7746680974960327

Epoch 7, Loss 0.7599632143974304

Epoch 8, Loss 0.7454613447189331

Epoch 9, Loss 0.7311623096466064

Epoch 10, Loss 0.7170668244361877

Epoch 11, Loss 0.7031756639480591

Epoch 12, Loss 0.6894916892051697

Epoch 13, Loss 0.6760116219520569

Epoch 14, Loss 0.6627349853515625

