In [1]:
import dataclasses
import gc
import logging
import math
import os
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple, Optional

# Must run before the `fluence` import below (presumably the package lives in
# the parent directory -- confirm).
sys.path.append("..")

import numpy as np
import pandas as pd
import torch
from torch.optim.adam import Adam
from torch.optim.sgd import SGD
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from tqdm import tqdm, trange
from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, EvalPrediction, GlueDataset, default_data_collator) 
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (HfArgumentParser, Trainer, TrainingArguments,
                          glue_compute_metrics, glue_output_modes,
                          glue_tasks_num_labels, set_seed)
from transformers import PreTrainedModel
from transformers import default_data_collator
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
# import higher

from fluence.meta import MetaDataset

In [2]:
# Dataset configuration: MNLI from a local GLUE checkout, sequences capped at
# 80 tokens.
data_args = DataTrainingArguments(
    task_name="MNLI",
    data_dir="/home/nlp/data/glue_data/MNLI",
    max_seq_length=80,
)

# Trainer configuration: evaluation enabled, 8 examples per device per step,
# artefacts written under the experiment directory.
training_args = TrainingArguments(
    output_dir="/home/nlp/experiments/meta/",
    do_eval=True,
    per_device_train_batch_size=8,
)

In [3]:
# Tokenizer for the uncased BERT-base checkpoint; used when building the GLUE
# dataset in the next cell.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
)

In [4]:
# Build the GLUE training split from `data_args`, then wrap it in fluence's
# MetaDataset so it can be indexed per task.
train_dataset = GlueDataset(
    data_args,
    tokenizer=tokenizer,
)
meta_dataset = MetaDataset(
    train_dataset,
)

100%|██████████| 130899/130899 [00:10<00:00, 12835.95it/s]


In [10]:
# Walk every element of the first meta-dataset entry; after the loop `inputs`
# holds the last element (it stays undefined if the entry is empty).
first_entry = meta_dataset[0]
for i in first_entry:
    inputs = i

In [None]:
logger = logging.getLogger(__name__)


class MetaTrainer(Trainer):
    """Trainer subclass with two meta-learning loops over per-task batches.

    * ``run_maml`` adapts a deep copy of the model on each task and applies the
      query-loss gradients of the adapted copy to the original model (no
      second-order terms are involved).
    * ``train`` uses ``higher`` to differentiate through one inner SGD step.

    Both loops read ``args.max_sample_limit`` and ``args.step_size``.
    NOTE(review): these are not standard ``TrainingArguments`` fields; the
    ``args`` object passed in must provide them.
    """

    # Global steps (2, 4, 8, ..., 2**19) at which evaluation and checkpointing run.
    EVAL_STEPS = frozenset(2 ** i for i in range(1, 20))

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
    ):
        """Store the model (moved to ``args.device``), loaders and options.

        Args:
            model: Model to meta-train.
            args: Training arguments; see the class docstring for the extra
                fields the training loops read.
            train_dataloader: Loader whose batches (and ``.dataset``) supply tasks.
            eval_dataloader: Loader passed to ``prediction_loop`` for evaluation.
            compute_metrics: Optional metrics callback.
            prediction_loss_only: Stored for the inherited evaluation code.
            optimizers: ``(optimizer, lr_scheduler)``; either may be ``None``,
                in which case they are created lazily by the training loops.

        NOTE(review): ``Trainer.__init__`` is not called (as in the original),
        so only the attributes assigned here exist on the instance.
        """
        self.model = model.to(args.device)
        self.args = args
        self.compute_metrics = compute_metrics
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.prediction_loss_only = prediction_loss_only
        self.optimizer, self.lr_scheduler = optimizers
        self.epoch = None
        self.tb_writer = None
        # Initialised here so run_maml() can be used without calling train() first.
        self.global_step = 0
        set_seed(self.args.seed)

    @staticmethod
    def _to_device(batch, device):
        """Return a copy of the mapping ``batch`` with tensor values on ``device``."""
        return {
            key: (value.to(device) if torch.is_tensor(value) else value)
            for key, value in batch.items()
        }

    def _evaluate_and_checkpoint(self):
        """Evaluate, log and save the model when ``global_step`` is in ``EVAL_STEPS``."""
        if self.global_step not in self.EVAL_STEPS:
            return

        output = self.prediction_loop(self.eval_dataloader, description = "Evaluation")
        self.log(output.metrics)

        output_dir = os.path.join(
            self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}",
        )
        self.save_model(output_dir)
        # Evaluation may leave the model in eval mode; restore training mode.
        self.model.train()

    def run_maml(self):
        """Perform one outer update from ``2 * args.max_sample_limit`` tasks.

        For each task a deep copy of the model is adapted with AdamW on the
        task's batches, then the query loss is back-propagated through the
        adapted copy. The resulting gradients are summed on the CPU, scaled,
        and applied to ``self.model`` with ``self.optimizer``.

        NOTE(review): support and query batches both come from
        ``meta_dataset[task_id]`` as in the original; whether two accesses
        return different samples depends on ``MetaDataset`` -- confirm.
        """
        meta_dataset = self.train_dataloader.dataset
        device = self.args.device

        if self.optimizer is None:
            self.create_optimizer_and_scheduler(self.args.max_sample_limit)

        sum_gradients = None

        for task_id in tqdm(range(2 * self.args.max_sample_limit), desc="Task"):

            fast_model = deepcopy(self.model)
            fast_model.to(device)
            inner_optimizer = torch.optim.AdamW(fast_model.parameters(), lr=self.args.step_size)
            fast_model.train()

            # Support set [classes]: one inner step per batch, starting each
            # step from clean gradients.
            for task in meta_dataset[task_id]:
                inner_optimizer.zero_grad()
                loss = fast_model(**self._to_device(task, device))[0]
                loss.backward()
                inner_optimizer.step()

            # Query set [classes]: drop the support gradients, then let the
            # query gradients accumulate in .grad across the batches.
            inner_optimizer.zero_grad()
            for task in meta_dataset[task_id]:
                query_loss = fast_model(**self._to_device(task, device))[0]
                query_loss.backward()

            # Snapshot this task's gradients once, on the CPU, after all
            # query batches have been processed.
            task_gradients = [
                None if params.grad is None else params.grad.detach().cpu().clone()
                for params in fast_model.parameters()
            ]
            if sum_gradients is None:
                sum_gradients = task_gradients
            else:
                for i, grad in enumerate(task_gradients):
                    if grad is None:
                        continue
                    if sum_gradients[i] is None:
                        sum_gradients[i] = grad
                    else:
                        sum_gradients[i] += grad

            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        if sum_gradients is None:
            # No tasks were processed (max_sample_limit <= 0): nothing to apply.
            return

        # Outer loop.
        # NOTE(review): the sum covers 2 * max_sample_limit tasks but is divided
        # by max_sample_limit, as in the original -- confirm the intended scale.
        scale = float(self.args.max_sample_limit)
        for params, grad in zip(self.model.parameters(), sum_gradients):
            params.grad = None if grad is None else (grad / scale).to(params.device)

        self.optimizer.step()
        self.optimizer.zero_grad()

        del sum_gradients
        gc.collect()

        self.global_step += 1
        # Run evaluation as per EVAL_STEPS
        self._evaluate_and_checkpoint()

    def train(self):
        """Meta-train with a differentiable inner step provided by ``higher``.

        For every pair of (support, query) batches drawn from two independent
        iterators over ``train_dataloader``, each per-class input takes one
        inner SGD step and contributes its query loss to the outer loss. The
        outer loss is back-propagated, and every
        ``args.gradient_accumulation_steps`` batches the gradients are clipped
        and applied.

        Raises:
            ImportError: if the ``higher`` package is not installed.
        """
        # The module-level `import higher` is commented out in this notebook,
        # so resolve it here with an explicit error instead of a NameError.
        try:
            import higher
        except ImportError as exc:
            raise ImportError(
                "MetaTrainer.train requires the `higher` package"
            ) from exc

        self.create_optimizer_and_scheduler(self.args.max_sample_limit)

        logger.info("***** Running training *****")

        self.global_step = 0
        self.epoch = 0

        device = self.args.device
        accumulation_steps = max(1, self.args.gradient_accumulation_steps)
        inner_optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.args.step_size
        )
        self.model.train()

        tqdm_iterator = tqdm(self.train_dataloader, desc="Batch Index")

        self.optimizer.zero_grad()
        query_dataloader = iter(self.train_dataloader)

        for batch_idx, meta_batch in enumerate(tqdm_iterator):
            target_batch = next(query_dataloader)
            outer_loss = 0.0
            # Loop through all classes
            for inputs, target_inputs in zip(meta_batch, target_batch):
                # Move each batch using its OWN tensors (the original copied
                # the support tensors over the query batch).
                inputs = self._to_device(inputs, device)
                target_inputs = self._to_device(target_inputs, device)

                with higher.innerloop_ctx(
                    self.model, inner_optimizer, copy_initial_weights=False
                ) as (fmodel, diffopt):

                    inner_loss = fmodel(**inputs)[0]
                    diffopt.step(inner_loss)
                    outer_loss += fmodel(**target_inputs)[0]

            if not torch.is_tensor(outer_loss):
                # Empty meta-batch: no loss was produced, nothing to back-propagate.
                continue

            # Gradients must exist before clipping and stepping.
            (outer_loss / accumulation_steps).backward()

            if (batch_idx + 1) % accumulation_steps != 0:
                continue

            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.args.max_grad_norm
            )
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.global_step += 1

            # Run evaluation on task list
            self._evaluate_and_checkpoint()