<a href="https://colab.research.google.com/github/yinhao0424/reuster/blob/master/ReusterFewShotLearner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References:
- [Sentence-BERT](https://www.aclweb.org/anthology/D19-1410.pdf)
- [Sentence Embeddings using Siamese BERT-Networks](https://github.com/aneesha/SiameseBERT-Notebook/blob/master/SiameseBERT_SemanticSearch.ipynb)

In [2]:
# a specific version of transformaer has been used 
! pip install -q transformers==3.0.2

[K     |████████████████████████████████| 778kB 12.3MB/s 
[K     |████████████████████████████████| 3.0MB 34.6MB/s 
[K     |████████████████████████████████| 1.1MB 57.9MB/s 
[K     |████████████████████████████████| 890kB 53.0MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [3]:
import numpy as np
import pandas as pd
from sklearn import metrics
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertModel

import warnings
warnings.simplefilter('ignore')
import logging
logging.basicConfig(level=logging.ERROR)

In [25]:
# --- Configuration ---
# Key values referenced by the dataset, the data loaders and the training code.

# Pre-trained DistilBERT tokenizer (lower-cased vocabulary).
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    truncation=True,
    do_lower_case=True,
)

MAX_LEN = 256           # sequence length handed to the tokenizer by the dataset
TRAIN_BATCH_SIZE = 4    # batch size of the training DataLoader
VALID_BATCH_SIZE = 4    # batch size of the validation/test DataLoader
EPOCHS = 2
LEARNING_RATE = 1e-5

In [4]:
# Read the few-shot training split (path presumably points at a mounted
# Google Drive folder -- adjust when running outside Colab).
reuster_train = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/data/reuters/reuster_fewshot_train.csv'
)

# Preview the first rows.
reuster_train.head()

Unnamed: 0,id,topics,texts
0,4016,iron-steel,"usx <x> proved oil, gas reserves fall in 1986u..."
1,4022,carcass,argentine meat exports higher in jan/feb 1987a...
2,4022,livestock,argentine meat exports higher in jan/feb 1987a...
3,4035,veg-oil,british minister criticises proposed ec oils t...
4,4040,oilseed,china's rapeseed crop damaged by stormsthe yie...


In [5]:
# (number of rows, number of columns) of the training frame
reuster_train.shape

(1143, 3)

In [18]:
# Scratch (disabled): draw a random row whose topic differs from 'iron-steel',
# the topic of the first row.
# reuster_train.topics[0]
# topics = reuster_train.topics
# candi = topics[topics!='iron-steel'].index
# idx = np.random.choice(candi)

# print(topics[idx])

In [17]:
class FewShotDataset(Dataset):
    """Triplet dataset for few-shot topic learning.

    Each item is a dict with three encoded texts:

    * ``'anchor'``   -- the row at the requested position,
    * ``'positive'`` -- a randomly drawn *other* row with the same topic,
    * ``'negative'`` -- a randomly drawn row with a different topic.

    Every encoded text is itself a dict of ``torch.long`` tensors with the
    keys ``'ids'``, ``'mask'`` and ``'token_type_ids'``.

    Args:
        dataframe: pandas DataFrame providing a ``texts`` and a ``topics`` column.
        tokenizer: object exposing ``encode_plus`` whose result contains
            ``'input_ids'``, ``'attention_mask'`` and ``'token_type_ids'``.
        max_len: ``max_length`` handed to the tokenizer (texts are padded and
            truncated to this length).

    Note:
        Positives and negatives are sampled with ``np.random.choice``, so
        results depend on NumPy's global random state.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        # A DataLoader addresses items by position 0..len-1; normalise the
        # index so that frames with a non-default index (e.g. after
        # filtering) are handled correctly.
        self.data = dataframe.reset_index(drop=True)
        self.text = self.data.texts
        self.topics = self.data.topics
        self.max_len = max_len

    def __len__(self):
        """Number of rows (= number of anchors)."""
        return len(self.text)

    def _encode(self, index):
        """Tokenise the text at position ``index`` and return it as tensors."""
        # Collapse every run of whitespace into a single space.
        text = " ".join(str(self.text.iloc[index]).split())

        encoded = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            # Explicit padding/truncation replaces the deprecated
            # `pad_to_max_length=True`, so every sample has length max_len.
            padding='max_length',
            truncation=True,
            return_token_type_ids=True
        )

        return {
            'ids': torch.tensor(encoded['input_ids'], dtype=torch.long),
            'mask': torch.tensor(encoded['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(encoded['token_type_ids'], dtype=torch.long)}

    def __getitem__(self, index):
        """Return the anchor / positive / negative triplet for ``index``."""
        return {
            'anchor': self._encode(index),
            'positive': self.get_positive(index),
            'negative': self.get_negative(index)
        }

    def get_positive(self, index):
        """Encode a random row sharing the anchor's topic.

        The anchor itself is excluded from the draw. When the anchor is the
        only row of its topic, the anchor is returned as its own positive
        instead of searching forever.
        """
        topic = self.topics.iloc[index]

        positions = np.arange(len(self.topics))
        same_topic = (self.topics == topic).to_numpy()
        candidates = positions[same_topic & (positions != index)]

        p_idx = np.random.choice(candidates) if len(candidates) else index
        return self._encode(p_idx)

    def get_negative(self, index):
        """Encode a random row whose topic differs from the anchor's topic.

        Raises:
            ValueError: if no row with a different topic exists.
        """
        topic = self.topics.iloc[index]

        # Negatives are the rows that do NOT share the anchor's topic.
        candidates = np.flatnonzero((self.topics != topic).to_numpy())
        if len(candidates) == 0:
            raise ValueError(
                "cannot sample a negative: every row has topic {!r}".format(topic))

        n_idx = np.random.choice(candidates)
        return self._encode(n_idx)

In [20]:
# Report the size of the training frame.
print("TRAIN Dataset: {}".format(reuster_train.shape))
# print("TEST Dataset: {}".format(test_data.shape))

# Wrap the training frame in the triplet dataset.
training_set = FewShotDataset(
    dataframe=reuster_train,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
)
# testing_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN)

TRAIN Dataset: (1143, 3)


In [26]:
# Keyword arguments forwarded to torch.utils.data.DataLoader.
train_params = dict(
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

test_params = dict(
    batch_size=VALID_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Batched, shuffled iterator over the training triplets.
training_loader = DataLoader(training_set, **train_params)
# testing_loader = DataLoader(testing_set, **test_params)

In [27]:
# len() of the DataLoader is the number of batches it yields per pass over the data.
print("The len of training loader is {}.".format(len(training_loader)))

The len of training loader is 286.


## Create the Neural Network for Fine Tuning

In [None]:
class DistilBERTClass(torch.nn.Module):
    """Pre-trained DistilBERT encoder followed by a small feed-forward head.

    The hidden state of the first token is passed through
    ``Linear(768, 768) -> tanh -> Dropout(0.1) -> Linear(768, 9)``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, 9)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the encoder and the head.

        Args:
            input_ids: token ids forwarded to the encoder.
            attention_mask: attention mask forwarded to the encoder.
            token_type_ids: accepted for interface compatibility; not used.

        Returns:
            Output of the final linear layer (last dimension of size 9).
        """
        encoder_out = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output holds the per-token hidden
        # states; keep only the vector of the first token of each sequence.
        first_token = encoder_out[0][:, 0]
        projected = torch.tanh(self.pre_classifier(first_token))
        return self.classifier(self.dropout(projected))

# `device` is not defined by any earlier cell, so define it here:
# use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and move its parameters to the selected device.
model = DistilBERTClass()
model.to(device)