In [None]:
!pip install -U pyarrow --quiet
!pip install datasets transformers torch seqeval evaluate tqdm omniglot learn2learn --quiet

In [None]:
from learn2learn.vision.datasets import omniglot
from learn2learn.algorithms import MAML
from tqdm import tqdm
from learn2learn.data import BatchMetaDataLoader
import torch
import torch.nn as nn
import torch.optim as optim

# Build the Omniglot meta-training task set and a loader that groups tasks
# into meta-batches. The keyword values presumably mean 5-way, 1-shot tasks
# with 15 test shots per class -- TODO confirm against the helper's docs.
# NOTE(review): the `omniglot(..., meta_train=True)` helper signature and the
# `BatchMetaDataLoader` name look like torchmeta's API, while this cell imports
# both from learn2learn -- confirm the imports resolve in the installed version.
omniglot_options = dict(ways=5, shots=1, test_shots=15, meta_train=True, download=True)
dataset = omniglot("data", **omniglot_options)
dataloader = BatchMetaDataLoader(dataset, num_workers=4, batch_size=16)

# Define a simple convolutional neural network
class ConvNet(nn.Module):
    """Small two-convolution classifier for 28x28 images.

    Each 3x3 convolution is followed by ReLU and 2x2 max-pooling, so a 28x28
    input shrinks 28 -> 26 -> 13 -> 11 -> 5 and the flattened feature vector
    has exactly the ``64 * 5 * 5`` entries that ``fc1`` expects. Without the
    pooling steps the flattened size would be ``64 * 24 * 24`` and ``fc1``
    would raise a shape-mismatch error.

    Args:
        in_channels: Number of channels of the input images.
        out_features: Number of output logits (one per class).
    """

    def __init__(self, in_channels, out_features):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3)
        self.conv2 = nn.Conv2d(64, 64, 3)
        self.fc1 = nn.Linear(64 * 5 * 5, out_features)

    def forward(self, x):
        """Map images of shape (batch, in_channels, 28, 28) to logits of shape (batch, out_features)."""
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        # flatten() also works on non-contiguous tensors, unlike view().
        x = x.flatten(start_dim=1)
        return self.fc1(x)

# Initialize the model, MAML algorithm, and optimizer
model = ConvNet(in_channels=1, out_features=5)
maml = MAML(model, lr=0.01, first_order=False)
optimizer = optim.Adam(maml.parameters(), lr=0.001)
# The network returns logits; a criterion is required to turn them into the
# scalar loss that `learner.adapt(...)` and `.backward()` need.
criterion = nn.CrossEntropyLoss()


def _iter_tasks(batch):
    """Yield ``(support_x, support_y, query_x, query_y)`` for each task in ``batch``.

    NOTE(review): assumes ``batch["train"]`` and ``batch["test"]`` are
    ``(inputs, targets)`` pairs and that a meta-batch stacks tasks along a
    leading dimension (inputs of shape (tasks, samples, C, H, W)) -- confirm
    against the data loader in use.
    """
    support_x, support_y = batch["train"][0], batch["train"][1]
    query_x, query_y = batch["test"][0], batch["test"][1]
    if support_x.dim() == 4:
        # A single task without a task dimension: treat it as a batch of one.
        support_x, support_y = support_x[None], support_y[None]
        query_x, query_y = query_x[None], query_y[None]
    return zip(support_x, support_y, query_x, query_y)


def _adapt_and_score(learner, support_x, support_y, query_x, query_y, loss_fn):
    """Take one inner-loop step on the support set and return the query loss."""
    adaptation_loss = loss_fn(learner(support_x), support_y)
    learner.adapt(adaptation_loss)
    return loss_fn(learner(query_x), query_y)


# Training loop
for batch in tqdm(dataloader):
    optimizer.zero_grad()
    tasks = list(_iter_tasks(batch))
    for task in tasks:
        # Every task adapts its own clone so tasks do not influence each other.
        learner = maml.clone()
        evaluation_loss = _adapt_and_score(learner, *task, criterion)
        # Backward per task: gradients accumulate into the meta-parameters as
        # the mean over tasks, and each task's graph is freed immediately.
        (evaluation_loss / len(tasks)).backward()
    optimizer.step()

# Evaluate the model
test_dataset = omniglot("data", ways=5, shots=1, test_shots=15, meta_test=True, download=True)
test_dataloader = BatchMetaDataLoader(test_dataset, batch_size=16, num_workers=4)
maml.eval()

for batch in tqdm(test_dataloader):
    task_losses = []
    for task in _iter_tasks(batch):
        # Adaptation still needs gradients, so only the reporting is detached.
        learner = maml.clone()
        evaluation_loss = _adapt_and_score(learner, *task, criterion)
        task_losses.append(evaluation_loss.item())
    print(f"Evaluation loss: {sum(task_losses) / len(task_losses)}")


In [None]:
import torch
from torch import nn
import learn2learn as l2l
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset
from tqdm import tqdm

class TransformerMAML(nn.Module):
    """Adapter that makes a wrapped model return its ``.logits`` directly.

    The wrapped model is stored as ``base_model``; ``forward`` passes all
    keyword arguments through to it unchanged.
    """

    def __init__(self, base_model):
        super(TransformerMAML, self).__init__()
        self.base_model = base_model

    def forward(self, **inputs):
        """Call the wrapped model with ``inputs`` and return the ``logits`` attribute of its output."""
        model_output = self.base_model(**inputs)
        return model_output.logits

def prepare_dataset(dataset, tokenizer, text_column, label_column, text_pair_column=None):
    """Tokenize a text column (optionally a sentence pair) and attach labels.

    Args:
        dataset: Mapping-like object indexable by column name; each column is
            a sequence with one entry per example.
        tokenizer: Callable invoked as ``tokenizer(texts, ...)`` or
            ``tokenizer(texts, text_pairs, ...)`` with ``padding=True``,
            ``truncation=True`` and ``return_tensors="pt"``; its result must
            support item assignment.
        text_column: Name of the column holding the (first) sentence.
        label_column: Name of the column holding integer labels.
        text_pair_column: Optional name of the column holding the second
            sentence. When ``None`` (the default) only ``text_column`` is
            encoded, exactly as before this parameter existed.

    Returns:
        The tokenizer output with an added ``'labels'`` tensor.
    """
    # Materialize columns as plain lists so lazy column views from the dataset
    # are not handed to the tokenizer or torch.tensor.
    texts = list(dataset[text_column])
    if text_pair_column is None:
        encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    else:
        text_pairs = list(dataset[text_pair_column])
        encoded = tokenizer(texts, text_pairs, padding=True, truncation=True, return_tensors="pt")
    encoded['labels'] = torch.tensor(list(dataset[label_column]))
    return encoded

def accuracy(predictions, targets):
    """Return the fraction of rows whose arg-max over dim 1 equals the target.

    The result is a 0-dimensional float tensor.
    """
    predicted_classes = torch.argmax(predictions, dim=1)
    hits = predicted_classes.eq(targets)
    return hits.float().mean()

def main():
    """Meta-train BERT with first-order MAML on three GLUE tasks, then test on RTE.

    Each meta-iteration samples ``tasks_per_step`` tasks. Per task, a clone of
    the meta-model is adapted on a support batch and scored on a disjoint
    query batch; the query losses are averaged over the sampled tasks to form
    the meta-gradient. Finally the meta-model is adapted on the first
    ``batch_size`` RTE validation examples and its accuracy on the remaining
    ones is printed.
    """
    tasks_per_step = 3      # tasks sampled per meta-update
    meta_iterations = 100
    log_every = 10          # print metrics every N meta-iterations
    batch_size = 10         # size of each support / query batch
    inner_steps = 2         # adaptation steps during meta-training
    eval_adapt_steps = 5    # adaptation steps on the held-out task
    seed = 42

    # Seed torch so task sampling and batch permutations are reproducible.
    torch.manual_seed(seed)

    # (GLUE config, first-sentence column, second-sentence column or None, label column).
    # MRPC is a sentence-pair task, so both sentences must be encoded.
    dataset_configs = [
        ("sst2", "sentence", None, "label"),
        ("cola", "sentence", None, "label"),
        ("mrpc", "sentence1", "sentence2", "label"),
    ]

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    datasets = []

    for name, text_col, pair_col, label_col in tqdm(dataset_configs, desc="Preparing datasets"):
        dataset = load_dataset("glue", name)
        train_data = dataset['train'].shuffle(seed=seed).select(range(100))  # Limit to 100 samples
        datasets.append(_encode_split(train_data, tokenizer, text_col, label_col, pair_col))

    # Initialize base model and MAML learner
    base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    maml = l2l.algorithms.MAML(TransformerMAML(base_model), lr=0.1, first_order=True)
    opt = torch.optim.Adam(maml.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    # Meta-training loop
    for iteration in tqdm(range(meta_iterations), desc="Meta-training"):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0

        for _task in range(tasks_per_step):
            # Choose a random dataset (task)
            task_data = datasets[torch.randint(0, len(datasets), (1,)).item()]
            learner = maml.clone()

            # One permutation, two disjoint slices: the query batch never
            # contains examples the learner just adapted on.
            perm = torch.randperm(len(task_data['input_ids']))
            support = _take(task_data, perm[:batch_size])
            query = _take(task_data, perm[batch_size:2 * batch_size])

            for _step in range(inner_steps):  # Inner loop adaptation steps
                _, train_error = _logits_and_loss(learner, support, loss_fn)
                learner.adapt(train_error)

            predictions, valid_error = _logits_and_loss(learner, query, loss_fn)

            # Backward per task so each task's graph is freed right away;
            # dividing by the number of tasks makes the accumulated gradient
            # the mean over tasks.
            (valid_error / tasks_per_step).backward()

            meta_train_error += valid_error.item()
            meta_train_accuracy += accuracy(predictions, query['labels']).item()

        opt.step()

        # Print some metrics, averaged over the tasks actually sampled
        if iteration % log_every == 0:
            print(f'Iteration {iteration}')
            print(f'Meta Train Error: {meta_train_error / tasks_per_step:.4f}')
            print(f'Meta Train Accuracy: {meta_train_accuracy / tasks_per_step:.4f}')

    # Evaluation on a held-out task
    eval_dataset = load_dataset("glue", "rte")['validation'].select(range(100))  # Using RTE as a new task
    # RTE is a sentence-pair task: encode sentence1 together with sentence2.
    eval_encoded = _encode_split(eval_dataset, tokenizer, "sentence1", "label", "sentence2")

    support = _take(eval_encoded, slice(0, batch_size))
    query = _take(eval_encoded, slice(batch_size, None))

    learner = maml.clone()
    for _step in range(eval_adapt_steps):  # Adapt on a few samples
        _, train_error = _logits_and_loss(learner, support, loss_fn)
        learner.adapt(train_error)

    # Evaluate: disable dropout and autograd for the final forward pass
    learner.eval()
    with torch.no_grad():
        predictions, _ = _logits_and_loss(learner, query, loss_fn)
    test_accuracy = accuracy(predictions, query['labels'])
    print(f'Test Accuracy on new task: {test_accuracy.item():.4f}')


def _encode_split(split, tokenizer, text_col, label_col, pair_col=None):
    """Encode one dataset split via ``prepare_dataset``, optionally as sentence pairs.

    For pair tasks the tokenizer is wrapped so the second sentences are
    supplied as its second positional argument; ``prepare_dataset`` itself is
    called with its usual four arguments.
    """
    if pair_col is None:
        return prepare_dataset(split, tokenizer, text_col, label_col)

    second_sentences = list(split[pair_col])

    def pair_tokenizer(first_sentences, **kwargs):
        return tokenizer(list(first_sentences), second_sentences, **kwargs)

    return prepare_dataset(split, pair_tokenizer, text_col, label_col)


def _take(encoded, indices):
    """Select ``indices`` from the model-input and label tensors of ``encoded``."""
    wanted = ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
    return {k: v[indices] for k, v in encoded.items() if k in wanted}


def _logits_and_loss(learner, batch, loss_fn):
    """Run ``learner`` on every entry of ``batch`` except ``'labels'``; return ``(logits, loss)``."""
    inputs = {k: v for k, v in batch.items() if k != 'labels'}
    logits = learner(**inputs)
    return logits, loss_fn(logits, batch['labels'])


if __name__ == "__main__":
    main()

Preparing datasets: 100%|██████████| 3/3 [00:14<00:00,  4.99s/it]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Meta-training:   0%|          | 0/100 [00:00<?, ?it/s]

Iteration 0
Meta Train Error: 2.3369
Meta Train Accuracy: 0.2800


Meta-training: 100%|██████████| 100/100 [12:58<00:00,  7.78s/it]


Downloading data:   0%|          | 0.00/584k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/69.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/621k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2490 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/277 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Test Accuracy on new task: 0.5000
