In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)

def meta_train(model, support_set, query_set, inner_lr=0.01, meta_lr=0.001, num_updates=5):
    optimizer = optim.SGD(model.parameters(), lr=meta_lr)

    for _ in range(num_updates):
        for support_input, support_target in support_set:
            support_input, support_target = support_input.unsqueeze(0), support_target.unsqueeze(0)
            model.zero_grad()
            loss = nn.MSELoss()(model(support_input), support_target)
            gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            model_copy = inner_update(model, gradients, inner_lr)

        total_loss = 0
        for query_input, query_target in query_set:
            query_input, query_target = query_input.unsqueeze(0), query_target.unsqueeze(0)
            loss = nn.MSELoss()(model_copy(query_input), query_target)
            total_loss += loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

def inner_update(model, gradients, inner_lr):
    updated_model = copy.deepcopy(model)
    for p, g in zip(updated_model.parameters(), gradients):
        p.data.sub_(inner_lr * g)
    return updated_model


In [2]:
model = Model()
support_set = [(torch.tensor([1.0]), torch.tensor([2.0]))]
query_set = [(torch.tensor([1.5]), torch.tensor([3.0]))]

meta_train(model, support_set, query_set)

test_input = torch.tensor([2.0])
prediction = model(test_input.unsqueeze(0))

print(f"Prediction for test input: {prediction.item()}")


Prediction for test input: -0.454787015914917


In [3]:
model = Model()
support_set = [(torch.tensor([1.0]), torch.tensor([2.0]))]
query_set = [(torch.tensor([1.5]), torch.tensor([3.0]))]

meta_train(model, support_set, query_set)

test_input = torch.tensor([2.0])
prediction = model(test_input.unsqueeze(0))

print(f"Prediction for test input: {prediction.item()}")


Prediction for test input: 1.1015632152557373


In [4]:
model = Model()
support_set = [(torch.tensor([1.0]), torch.tensor([2.0]))]
query_set = [(torch.tensor([1.5]), torch.tensor([3.0]))]

meta_train(model, support_set, query_set)

test_input = torch.tensor([2.0])
prediction = model(test_input.unsqueeze(0))

print(f"Prediction for test input: {prediction.item()}")


Prediction for test input: -1.5164753198623657
