In [1]:
import dataclasses
import logging
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    EvalPrediction,
    GlueDataset,
)
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,
)

In [2]:
# Load the pretrained encoder from the "sentence-transformers/bert-base-nli-mean-tokens"
# checkpoint. AutoModel picks the concrete architecture from the checkpoint's config.
model = AutoModel.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")

In [3]:
# Load the tokenizer from the same checkpoint name as the model so that the
# token ids fed to the encoder match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")

In [4]:
# GLUE/MNLI data settings. The data directory can be overridden through the
# GLUE_MNLI_DIR environment variable so the notebook is not tied to one
# machine's filesystem; the original location remains the default.
# Keyword arguments make the meaning of each value explicit.
data_args = DataTrainingArguments(
    task_name='mnli',
    data_dir=os.environ.get('GLUE_MNLI_DIR', '/home/nlp/data/glue_data/MNLI'),
    max_seq_length=80,
)

In [5]:
# Build the GLUE dataset described by data_args, tokenised with `tokenizer`.
# NOTE(review): no split/mode is passed here; presumably the library default
# is the training split, as the variable name suggests — confirm.
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) 

In [6]:
# Wrap the dataset in a DataLoader that collates examples with
# default_data_collator. No batch_size or shuffle is given, so the DataLoader
# defaults apply; the outputs further down (a single label, embeddings of
# shape [1, 768]) show that each batch holds one example.
dataloader = DataLoader(train_dataset, collate_fn = default_data_collator)

In [7]:
# Pull just the first batch out of the loader to experiment with.
batch = next(iter(dataloader))

In [8]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model result whose element 0 holds the
            per-token embeddings.
        attention_mask: Mask with one entry per token; it is expanded over the
            embedding dimension and used as a weight (0 drops a token).

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total, with
        the divisor floored at 1e-9 so an all-zero mask does not divide by zero.
    """
    embeddings = model_output[0]  # element 0 carries all token embeddings
    # Stretch the mask to the embeddings' shape so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand_as(embeddings).float()
    masked_total = (embeddings * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return masked_total / token_count

In [9]:
# Remove the labels from the batch so the remaining entries can be unpacked
# as keyword arguments into the model call below. The None default keeps this
# cell safe to re-run: a second execution no longer raises KeyError once the
# key has already been removed.
batch.pop('labels', None)

tensor([2])

In [10]:
# Forward pass for inference only: disable gradient tracking so no autograd
# graph is built, which saves memory and time when just extracting embeddings.
with torch.no_grad():
    model_output = model(**batch)

In [11]:
# Collapse the per-token embeddings into one vector per sequence, using the
# batch's attention mask as the weights for the average.
sentence_embeddings = mean_pooling(model_output, batch['attention_mask'])

In [12]:
# Sanity check: one pooled vector per sequence in the batch.
sentence_embeddings.shape

torch.Size([1, 768])

In [16]:
# Inspect the model's own pooler submodule. The sentence embeddings above do
# not come from it: mean_pooling reads element 0 of the model output.
model.pooler

BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)

In [17]:
from torch import nn