In [30]:
import logging
import sys
sys.path.append("..")
from dataclasses import dataclass, field
from typing import List, Union, Optional, Dict, Callable

import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

from transformers import (AutoConfig, AutoModelForSequenceClassification, AutoModel,
                          AutoTokenizer, PreTrainedTokenizer, PreTrainedModel)
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import TrainingArguments, default_data_collator, EvalPrediction
from datasets.siamese_dataset import SiameseGlueDataset, siamese_data_collator
from models.siamese_model import SiameseTransformer
from core.siamese_trainer import SiameseTrainer
from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
logger = logging.getLogger(__name__)

In [4]:
task_name = 'mnli'
data_dir = '/home/nlp/data/glue_data/MNLI'
model_id = 'bert-base-uncased'

In [5]:
data_args = DataTrainingArguments(task_name, data_dir = data_dir, max_seq_length=32)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
siamese_train_dataset = SiameseGlueDataset(data_args, tokenizer)

In [8]:
siamese_eval_dataset = SiameseGlueDataset(data_args, tokenizer, mode="dev")

In [9]:
from transformers import PreTrainedModel

In [10]:
from datasets.siamese_dataset import siamese_data_collator

In [11]:
train_dl = DataLoader(siamese_train_dataset,
                      batch_size=8,
                     collate_fn = siamese_data_collator, shuffle=True)

In [12]:
eval_dl = DataLoader(siamese_train_dataset,
                      batch_size=8,
                     collate_fn = siamese_data_collator)

In [13]:
@dataclass
class SiameseModelArguments:
    """
    Arguments pertaining to SiameseTransformer
    """

    model_name: str = field(
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        }
    )
    #input_dim: int = field(
    #    default=None, metadata={"help": "Input dimension of linear layer"}
    #)
    #linear_dim: int = field(
    #    default=None, metadata={"help": "Dimension of linear layer"}
    #)
    seq_len: int = field(default = 128)
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Where do you want to store the pretrained models downloaded from s3"
            )
        },
    )
    freeze_a: bool = field(default=False, metadata={"help": "freeze model a"})
    freeze_b: bool = field(default=False, metadata={"help": "freeze model b"})
    num_labels: int = field(default=3)

In [14]:
class PredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((8, 128))
        self.dense = nn.Linear(4096, len(config.id2label))
                
    def forward(self, features):
        features = self.pool(features)
        print(features.shape)
        features = features.view(features.shape[0]//4, -1)
        features = self.dense(features)
        return features

In [15]:
class SiameseTransformer(nn.Module):
    def __init__(self, args, config):
        super(SiameseTransformer, self).__init__()
        self.args = args
        self.model_a = AutoModel.from_pretrained(self.args.model_name, 
                           config=config, cache_dir=self.args.cache_dir)
        self.loss_fct = nn.CrossEntropyLoss()
        self.cls = PredictionHeadTransform(config)
        
        if self.args.freeze_a:
            logger.info("**** Freezing Model A ****")
            for param in self.model_a.encoder.parameters():
                param.requires_grad = False

        if self.args.freeze_b:
            logger.info("**** Freezing Model B ****")
            for param in self.model_b.encoder.parameters():
                param.requires_grad = False
    
    def forward(self, inputs):
        labels = input_a['labels']
        input_a.pop('labels')
        input_b.pop('labels')
        output_a = self.model_a(**input_a)[0] # [bs, seq_len, 768]
        output_b = self.model_a(**input_b)[0]
        concat_output = torch.cat([output_a, output_b, (output_a-output_b), (output_a*output_b)])
        logits = self.cls(concat_output)
        loss = self.loss_fct(logits, labels)
        return loss, logits

In [16]:
class SiameseTransformer(nn.Module):
    def __init__(self, args, config):
        super(SiameseTransformer, self).__init__()
        self.args = args
        self.model_a = AutoModelForSequenceClassification.from_pretrained(
            self.args.model_name, config=config, cache_dir=self.args.cache_dir
        )
        # self.model_b = AutoModel.from_pretrained(self.args.model_name,
        #                   config=config, cache_dir=self.args.cache_dir)
        self.loss_fct = nn.CrossEntropyLoss()
        self.cls = PredictionHeadTransform(config)

        if self.args.freeze_a:
            logger.info("**** Freezing Model A ****")
            for param in self.model_a.encoder.parameters():
                param.requires_grad = False

        # if self.args.freeze_b:
        #    logger.info("**** Freezing Model B ****")
        #    for param in self.model_b.encoder.parameters():
        #        param.requires_grad = False

    def forward(self, a, b):
        #labels = input_a["labels"]
        #input_a = inputs["a"]
        #input_b = inputs["b"]
        #input_a.pop("labels")
        #input_b.pop("labels")
        #output_a = self.model_a(**input_a)[0]  # [bs, seq_len, 768]
        #output_b = self.model_a(**input_b)[0]
        #concat_output = torch.cat(
        #    [output_a, output_b, (output_a - output_b), (output_a * output_b)]
        #)
        #logits = self.cls(concat_output)
        #loss = self.loss_fct(logits, labels)
        # print(a)
        output = self.model_a(**a)
        return output
        #return loss, logits

In [17]:
args = SiameseModelArguments('bert-base-uncased', seq_len=32)

In [18]:
config = AutoConfig.from_pretrained(
        'bert-base-uncased',
        num_labels = 3,
        task_name = 'MNLI',
        cache_dir = '/home/nlp/experiments/siamese'
    )

In [19]:
head = PredictionHeadTransform(config)

In [20]:
output = torch.rand(8, 64, 768) # 8, 3

In [21]:
head(output).shape

torch.Size([8, 8, 128])


torch.Size([2, 3])

In [22]:
head

PredictionHeadTransform(
  (pool): AdaptiveAvgPool2d(output_size=(8, 128))
  (dense): Linear(in_features=4096, out_features=3, bias=True)
)

In [23]:
model = SiameseTransformer(args, config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [24]:
batch = next(iter(train_dl))

In [25]:
batch_eval = next(iter(eval_dl))

In [26]:
for inputs in eval_dl:
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.cuda()

In [27]:
for inputs in eval_dl:
    for k, v in inputs.items():
        print(k, v)
    break

a {'labels': tensor([2, 1, 1, 1, 2, 1, 2, 1]), 'input_ids': tensor([[  101, 17158,  2135,  6949,  8301, 25057,  2038,  2048,  3937,  9646,
          1011,  4031,  1998, 10505,  1012,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  2017,  2113,  2076,  1996,  2161,  1998,  1045,  3984,  2012,
          2012,  2115,  2504,  7910,  2017,  4558,  2068,  2000,  1996,  2279,
          2504,  2065,  2065,  2027,  5630,  2000,  9131,  1996,  1996,  6687,
          2136,   102],
        [  101,  2028,  1997,  2256,  2193,  2097,  4287,  2041,  2115,  8128,
          3371,  2135,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  2129,  2079,  2017,  2113,  1029,  2035,  2023,  2003,  2037,
          2592,  2153,  1012,   102,     0,     0,     0,     0,     0,     0,

In [69]:
for k, v in batch['a'].items():
    batch['a'][k] = v.cuda()
for k, v in batch['b'].items():
    batch['b'][k] = v.cuda()

In [71]:
for k, v in batch_eval['a'].items():
    batch_eval['a'][k] = v.cuda()
for k, v in batch_eval['b'].items():
    batch_eval['b'][k] = v.cuda()

In [43]:
model.cuda();

In [78]:
output = model(**batch)

In [79]:
output

(tensor(1.2290, device='cuda:0', grad_fn=<NllLossBackward>),
 tensor([[ 0.0624, -0.1528, -0.0248],
         [-0.4474, -0.3496,  0.0997],
         [-0.3453, -0.3189,  0.1031],
         [-0.3425, -0.3321,  0.0577],
         [-0.3478, -0.2492, -0.0216],
         [-0.3940, -0.3169,  0.0230],
         [-0.5347, -0.2935,  0.1017],
         [-0.5581, -0.2490, -0.0911]], device='cuda:0', grad_fn=<AddmmBackward>))

In [81]:
model(**batch_eval)

(tensor(1.1629, device='cuda:0', grad_fn=<NllLossBackward>),
 tensor([[ 0.1034, -0.1250, -0.1374],
         [-0.2326, -0.4631,  0.0924],
         [-0.5709, -0.4091,  0.1518],
         [-0.4632, -0.3334,  0.1755],
         [-0.0604, -0.2989,  0.0755],
         [-0.5815, -0.2825,  0.1942],
         [-0.2083, -0.3863,  0.1067],
         [-0.2706, -0.2518, -0.0087]], device='cuda:0', grad_fn=<AddmmBackward>))

In [95]:
output[0]

tensor(0.9698, device='cuda:0', grad_fn=<NllLossBackward>)

In [48]:
from transformers import Trainer, glue_compute_metrics
from core.siamese_trainer import SiameseTrainer

In [28]:
training_args = TrainingArguments(output_dir = '/home/nlp/experiments/siamese/',
                                 do_eval = True,
                                 per_device_train_batch_size=1024,
                                 per_device_eval_batch_size=1024)

In [44]:
output_mode = "classification"


In [31]:
def build_compute_metrics_fn(task_name: str,) -> Callable[[EvalPrediction], Dict]:
    def compute_metrics_fn(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    return compute_metrics_fn

trainer = SiameseTrainer(
    model=model,
    args=training_args,
    train_dataset=siamese_train_dataset,
    eval_dataset=siamese_eval_dataset,
    data_collator=siamese_data_collator,
    compute_metrics=build_compute_metrics_fn(data_args.task_name))

In [62]:
build_compute_metrics_fn('mnli')

<function __main__.build_compute_metrics_fn.<locals>.compute_metrics_fn(p: transformers.trainer_utils.EvalPrediction) -> Dict>

In [50]:
trainer.evaluate()

Evaluation: 100%|██████████| 5/5 [00:02<00:00,  1.69it/s]


{'eval_loss': 1.1113010883331298, 'eval_mnli/acc': 0.31981660723382577}

In [36]:
len(siamese_eval_dataset)

9815

In [95]:
from transformers import AutoModel

In [96]:
model_a = AutoModel.from_pretrained('bert-base-uncased', 
                           config=config).cuda()

In [97]:
# batch['a'].pop('labels')
output_a = model_a(**batch['a'])

In [98]:
len(output_a)

2