In [91]:
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import List, Union, Optional

import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader

from transformers import (AutoConfig, AutoModelForSequenceClassification, AutoModel,
                          AutoTokenizer, PreTrainedTokenizer, PreTrainedModel)
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import TrainingArguments

# Make the project-local `core` package (in the parent of the working directory) importable.
sys.path.append("..")
from core.siamese_dataset import SiameseGlueDataset, siamese_data_collator
from core.siamese_model import SiameseTransformer
from core.siamese_trainer import SiameseTrainer

In [None]:
# Reload edited project modules (e.g. core.*) automatically before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-level logger; SiameseTransformer reports layer freezing through it.
logger = logging.getLogger(name=__name__)

In [None]:
# Experiment configuration.
# The GLUE/MNLI data location can be overridden with the MNLI_DATA_DIR
# environment variable instead of editing this machine-specific absolute path.
task_name = 'mnli'
data_dir = os.environ.get('MNLI_DATA_DIR', '/home/nlp/data/glue_data/MNLI')
model_id = 'bert-base-uncased'

In [None]:
# GLUE data arguments (task + data directory) passed to SiameseGlueDataset below.
# NOTE(review): `args` is rebound to a SiameseModelArguments in a later cell;
# re-running the dataset cells after that point would pass the wrong object.
args = DataTrainingArguments(task_name=task_name, data_dir=data_dir)

In [None]:
# Tokenizer for the `model_id` checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

In [None]:
# Training split: uses SiameseGlueDataset's default `mode` (presumably "train" — confirm in core.siamese_dataset).
siamese_train_dataset = SiameseGlueDataset(args, tokenizer)

In [None]:
# Evaluation split; mode="dev" presumably selects the GLUE dev set (confirm in core.siamese_dataset).
siamese_eval_dataset = SiameseGlueDataset(args, tokenizer, mode="dev")

In [84]:
from transformers import PreTrainedModel

In [86]:
@dataclass
class SiameseArguments(TrainingArguments):
    """TrainingArguments extended with the settings of a two-tower Siamese model.

    Every extra field defaults to ``None``, so each is annotated ``Optional``
    to match its real value domain (consistent with SiameseModelArguments).

    NOTE(review): no visible cell uses this class; SiameseTransformer is built
    from SiameseModelArguments instead.
    """

    model_a: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        },
    )
    model_b: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        },
    )
    input_dim: Optional[int] = field(
        default=None, metadata={"help": "Input dimension of linear layer"}
    )
    linear_dim: Optional[int] = field(
        default=None, metadata={"help": "Dimension of linear layer"}
    )
    num_labels: Optional[int] = field(
        default=None, metadata={"help": "Number of output classes"}
    )
    freeze_a: Optional[bool] = field(
        default=None, metadata={"help": "freeze model a"}
    )
    freeze_b: Optional[bool] = field(
        default=None, metadata={"help": "freeze model b"}
    )

In [107]:
@dataclass
class SiameseModelArguments:
    """
    Arguments pertaining to SiameseTransformer.

    Both towers are loaded from ``model_name_or_path``. ``input_dim`` and
    ``linear_dim`` size the linear classification head (``input_dim`` must equal
    the width of the concatenated tower outputs) and ``num_labels`` sets its
    output size. Fields that default to ``None`` are annotated ``Optional``.
    """

    model_name_or_path: str = field(
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        }
    )
    input_dim: Optional[int] = field(
        default=None, metadata={"help": "Input dimension of linear layer"}
    )
    linear_dim: Optional[int] = field(
        default=None, metadata={"help": "Dimension of linear layer"}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Where do you want to store the pretrained models downloaded from s3"
            )
        },
    )
    freeze_a: bool = field(default=False, metadata={"help": "freeze model a"})
    freeze_b: bool = field(default=False, metadata={"help": "freeze model b"})
    num_labels: int = field(
        default=3, metadata={"help": "Number of output classes of the linear head"}
    )

In [131]:
class SiameseTransformer(nn.Module):
    """Two-tower ("Siamese") sentence-pair classifier.

    Two independently loaded copies of the same pretrained checkpoint encode
    the two inputs. The first-token vector of each tower is taken, the two
    vectors are concatenated per example ([a_i ; b_i]) and passed through a
    two-layer linear head that produces class logits.

    Args:
        args: object exposing ``model_name_or_path``, ``cache_dir``,
            ``freeze_a``, ``freeze_b``, ``input_dim``, ``linear_dim`` and
            ``num_labels`` (e.g. SiameseModelArguments).
        config: model config passed to ``AutoModel.from_pretrained`` for both
            towers.
    """

    def __init__(self, args, config):
        super().__init__()
        self.args = args
        self.model_a = AutoModel.from_pretrained(
            args.model_name_or_path, config=config, cache_dir=args.cache_dir
        )
        self.model_b = AutoModel.from_pretrained(
            args.model_name_or_path, config=config, cache_dir=args.cache_dir
        )

        # Only the `.encoder` submodule is frozen; other submodules (e.g. the
        # embeddings) stay trainable. NOTE(review): confirm the partial freeze
        # is intended.
        if args.freeze_a:
            logger.info("**** Freezing Model A ****")
            self._freeze(self.model_a.encoder)

        if args.freeze_b:
            logger.info("**** Freezing Model B ****")
            self._freeze(self.model_b.encoder)

        # The head consumes [vec_a ; vec_b], i.e. twice the hidden width. Fall
        # back to that when input_dim is not given (assumes the config exposes
        # `hidden_size`, as BERT-style configs do).
        input_dim = args.input_dim
        if input_dim is None:
            input_dim = 2 * config.hidden_size

        # NOTE(review): there is no non-linearity between the two Linear layers,
        # so the head is equivalent to one affine map. Inserting an activation
        # would shift the Sequential indices and break existing `linear.1.*`
        # checkpoints, so the layout is kept as-is.
        self.linear = nn.Sequential(
            nn.Linear(input_dim, args.linear_dim),
            nn.Linear(args.linear_dim, args.num_labels),
        )

    @staticmethod
    def _freeze(module):
        """Disable gradient updates for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def forward(self, input_a, input_b):
        """Encode both inputs and classify the pair.

        Args:
            input_a: mapping of keyword inputs for tower A. Its ``labels``
                entry, when present, is the classification target.
            input_b: mapping of keyword inputs for tower B; any ``labels``
                entry is ignored.

        Returns:
            ``(loss, logits)``: ``loss`` is the cross-entropy against
            ``input_a['labels']``, or ``None`` when no labels are supplied.
            The input mappings are not modified.
        """
        labels = input_a.get('labels')
        # Filter out 'labels' instead of popping it: popping mutated the
        # caller's batch, so a second forward on the same batch hit KeyError.
        features_a = {k: v for k, v in input_a.items() if k != 'labels'}
        features_b = {k: v for k, v in input_b.items() if k != 'labels'}

        # The first model output is indexed as (batch, seq, hidden); keep the
        # first token of each sequence -> (batch, hidden).
        output_a = self.model_a(**features_a)[0][:, 0, :]
        output_b = self.model_b(**features_b)[0][:, 0, :]

        # Concatenate on the feature axis so row i is [a_i ; b_i]. Concatenating
        # on dim 0 and reshaping to (batch, -1) would instead pair rows from the
        # same tower (a_0 with a_1, b_0 with b_1, ...).
        concat_output = torch.cat([output_a, output_b], dim=-1)
        logits = self.linear(concat_output)

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits

In [132]:
# Model hyper-parameters, reusing `model_id` so tokenizer, config and towers
# always refer to the same checkpoint. input_dim=1536 presumably = 2 x 768
# (two concatenated BERT-base first-token vectors).
# NOTE(review): this rebinds `args`, which earlier held the GLUE
# DataTrainingArguments used to build the datasets.
args = SiameseModelArguments(model_id, linear_dim=4096, input_dim=1536)

In [133]:
# Config shared by both encoder towers. The checkpoint and label count come
# from `model_id` / `args` so they cannot drift from the rest of the notebook.
# NOTE(review): this cache_dir differs from args.cache_dir (None by default),
# so the config and the model weights end up cached in different places.
config = AutoConfig.from_pretrained(
    model_id,
    num_labels=args.num_labels,
    task_name='MNLI',
    cache_dir='/home/nlp/experiments/siamese',
)

In [138]:
# Build the two-tower model from the hyper-parameters and shared config above.
model = SiameseTransformer(args=args, config=config)

In [137]:
# Save the weights wrapped under a 'state_dict' key.
# NOTE(review): this cell executed (In [137]) before `model` was rebuilt
# (In [138]); re-run top to bottom to make sure the saved weights are the
# intended ones. The wrapped dict is not a raw state_dict, so the HF-style
# filename may mislead tools that expect one.
torch.save(
    obj={'state_dict': model.state_dict()},
    f='/home/nlp/pytorch_model.bin',
)

In [139]:
# Load the checkpoint onto CPU so a file saved from GPU also loads on a CPU-only
# machine; load_state_dict then copies the tensors into the model's parameters.
# torch.load unpickles the file: only load checkpoints from trusted sources.
ckpt = torch.load('/home/nlp/pytorch_model.bin', map_location='cpu')

In [140]:
# Restore the weights; strict matching (the default) reports any missing/unexpected keys.
model.load_state_dict(ckpt['state_dict'], strict=True)

<All keys matched successfully>