<a href="https://colab.research.google.com/github/yinhao0424/reuster/blob/master/ReusterFewShotLearner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference:
- Paper/Blog
  - [BERT word embedding](https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/#31-running-bert-on-our-text)
  - [triplet-network-pytorch](https://github.com/andreasveit/triplet-network-pytorch/blob/master/train.py)
  - [Sentence Embeddings using Siamese BERT-Networks - paper](https://www.aclweb.org/anthology/D19-1410.pdf)
  - [Sentence Embeddings using Siamese BERT-Networks - colab](https://github.com/aneesha/SiameseBERT-Notebook/blob/master/SiameseBERT_SemanticSearch.ipynb)
- Discussion
    - [Generate sequence classifier](https://github.com/huggingface/transformers/issues/1001)


In [1]:
# a specific version of transformaer has been used 
! pip install -q transformers==3.0.2
# !pip install -q transformers

[K     |████████████████████████████████| 778kB 4.2MB/s 
[K     |████████████████████████████████| 1.1MB 53.3MB/s 
[K     |████████████████████████████████| 890kB 52.1MB/s 
[K     |████████████████████████████████| 3.0MB 50.9MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [2]:
import numpy as np
import pandas as pd
from sklearn import metrics
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertModel

import warnings
warnings.simplefilter('ignore')
import logging
logging.basicConfig(level=logging.ERROR)

In [3]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
# Configuration: key variables used later on in training.
MAX_LEN = 256            # every tokenized text is padded/truncated to this many tokens
TRAIN_BATCH_SIZE = 8     # triplets per training batch
VALID_BATCH_SIZE = 4     # batch size reserved for the (currently disabled) test loader
EPOCHS = 2
LEARNING_RATE = 1e-05

# NOTE: truncation is an encode-time option (passed to `encode_plus`), not a
# constructor option, so it is not passed to `from_pretrained` here.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [5]:
# Location of the few-shot training split.
# NOTE(review): the path lives under /content/drive — presumably Google Drive has
# to be mounted before this cell runs; no mount cell is visible here.
TRAIN_CSV_PATH = '/content/drive/MyDrive/data/reuters/reuster_fewshot_train.csv'

reuster_train = pd.read_csv(TRAIN_CSV_PATH)
reuster_train.head()

Unnamed: 0,id,topics,texts
0,4016,iron-steel,"usx <x> proved oil, gas reserves fall in 1986u..."
1,4022,carcass,argentine meat exports higher in jan/feb 1987a...
2,4022,livestock,argentine meat exports higher in jan/feb 1987a...
3,4035,veg-oil,british minister criticises proposed ec oils t...
4,4040,oilseed,china's rapeseed crop damaged by stormsthe yie...


In [6]:
# (rows, columns) of the training frame.
reuster_train.shape

(1143, 3)

In [7]:
# #iron-steel
# reuster_train.topics[0]
# topics = reuster_train.topics
# candi = topics[topics!='iron-steel'].index
# idx = np.random.choice(candi)

# print(topics[idx])

In [8]:
class FewShotDataset(Dataset):
    """Triplet dataset for few-shot topic learning.

    Each item is a dict with three entries -- ``'anchor'``, ``'positive'`` and
    ``'negative'`` -- each of which is ``{'ids': LongTensor, 'mask': LongTensor}``
    holding the token ids and attention mask of one text:

    * anchor: the text at position ``index``;
    * positive: a random *other* text with the same topic as the anchor;
    * negative: a random text whose topic differs from the anchor's.

    Args:
        dataframe: frame with a ``texts`` column and a ``topics`` column.
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_len: every text is padded/truncated to this many tokens.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        # Items are requested by position (0 .. len-1), so discard whatever index
        # the caller's frame carries (e.g. after filtering or sampling).
        self.data = dataframe.reset_index(drop=True)
        self.text = self.data.texts
        self.topics = self.data.topics
        self.max_len = max_len
        # Plain array of topics for fast positional comparisons.
        self._topic_values = self.topics.to_numpy()

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        return {
            'anchor': self._encode(index),
            'positive': self.get_positive(index),
            'negative': self.get_negative(index),
        }

    def _encode(self, row):
        """Tokenize the text at position ``row`` into fixed-length id/mask tensors."""
        # Collapse every run of whitespace into a single space.
        text = " ".join(str(self.text.iloc[row]).split())

        encoded = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            # Explicit, so over-long texts are cut to max_len instead of relying
            # on the tokenizer's fallback behaviour.
            truncation=True,
        )

        return {
            'ids': torch.tensor(encoded['input_ids'], dtype=torch.long),
            'mask': torch.tensor(encoded['attention_mask'], dtype=torch.long),
        }

    def get_positive(self, index):
        """Encode a random row that shares the anchor's topic.

        The anchor itself is excluded from the candidates. If the anchor is the
        only row with its topic, the anchor is returned as its own positive
        rather than looping forever looking for another one.
        """
        same_topic = np.flatnonzero(self._topic_values == self._topic_values[index])
        others = same_topic[same_topic != index]
        p_idx = int(np.random.choice(others)) if others.size else index
        return self._encode(p_idx)

    def get_negative(self, index):
        """Encode a random row whose topic differs from the anchor's.

        Raises:
            ValueError: if every row has the same topic as the anchor.
        """
        # NOTE(review): if the frame lists one document under several topics, the
        # negative can be the very same text as the anchor -- confirm whether
        # such rows should be excluded.
        candidates = np.flatnonzero(self._topic_values != self._topic_values[index])
        if candidates.size == 0:
            raise ValueError(
                "cannot sample a negative: every row has the same topic as the anchor"
            )
        n_idx = int(np.random.choice(candidates))
        return self._encode(n_idx)

In [9]:
# Report the size of the training frame, then wrap it as a triplet dataset.
print(f"TRAIN Dataset: {reuster_train.shape}")
# print("TEST Dataset: {}".format(test_data.shape))

training_set = FewShotDataset(dataframe=reuster_train, tokenizer=tokenizer, max_len=MAX_LEN)
# testing_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN)

TRAIN Dataset: (1143, 3)


In [10]:
# Training batches are reshuffled on every pass over the loader.
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

# Evaluation should visit the data in a fixed order, so the test loader does
# not shuffle.
test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
# testing_loader = DataLoader(testing_set, **test_params)

In [11]:
# Number of batches produced per pass over the training loader.
print(f"The len of training loader is {len(training_loader)}.")

The len of training loader is 143.


## Create the Neural Network for Fine Tuning

In [12]:
class DistilBERTClass(torch.nn.Module):
    """DistilBERT encoder followed by a small projection head.

    The head is Linear -> tanh -> dropout -> Linear and maps the encoder output
    at the first token position to a fixed-size embedding.

    Args:
        model_name: pretrained DistilBERT checkpoint to load.
        embedding_dim: size of the embedding returned by ``forward``.
        dropout: dropout probability applied before the final projection.
    """

    def __init__(self, model_name="distilbert-base-uncased", embedding_dim=256, dropout=0.1):
        super(DistilBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of hard-coding 768, so
        # other DistilBERT checkpoints work as well.
        hidden_size = self.l1.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(hidden_size, embedding_dim)

    def forward(self, data):
        """Embed a batch.

        Args:
            data: mapping with ``'ids'`` (token ids) and ``'mask'`` (attention mask).

        Returns:
            The projected embedding for each sequence in the batch.
        """
        input_ids = data['ids']
        attention_mask = data['mask']

        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        # Keep only the first token position of every sequence (the [CLS] token
        # when the inputs were encoded with special tokens).
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.tanh(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

    # def forward(self, anchor,positive,negative):
    #     res_anchor = self.forward_once(anchor)
    #     res_positive = self.forward_once(positive)
    #     res_negative = self.forward_once(negative)
    #     return res_anchor,res_positive,res_negative

# Build the network and place its parameters on the selected device;
# the bare name on the last line displays the module structure.
model = DistilBERTClass().to(device)
model

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




DistilBERTClass(
  (l1): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_featu

In [13]:
def triplet_loss(anchor, positive, negative, margin=1.0, p=2):
  """Triplet margin loss between anchor, positive and negative embeddings.

  Args:
    anchor: embeddings of the anchor samples.
    positive: embeddings of samples that should lie close to the anchor.
    negative: embeddings of samples that should lie far from the anchor.
    margin: required gap between the anchor-negative and anchor-positive distances.
    p: norm degree of the pairwise distance.

  Returns:
    Scalar tensor with the loss averaged over the batch.
  """
  # Functional form: no need to build a new loss module on every call.
  return torch.nn.functional.triplet_margin_loss(
      anchor, positive, negative, margin=margin, p=p)

In [14]:
# Adam over every parameter of the network (encoder and projection head alike).
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [15]:
def train(epoch):
    """Run one pass over ``training_loader``, optimising the triplet loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``training_loader``,
    ``device`` and ``triplet_loss``. Each batch from the loader is a dict with
    the keys ``'anchor'``, ``'positive'`` and ``'negative'``, each mapping to
    ``{'ids': LongTensor, 'mask': LongTensor}``.

    Args:
        epoch: epoch number, used only in the progress messages.

    Returns:
        Mean batch loss over this pass (0.0 if the loader yields no batches).
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    # Wrapping the loader itself (not `enumerate`) lets tqdm show the total.
    for step, data in enumerate(tqdm(training_loader)):
        # The loader yields CPU tensors; move them to wherever the model lives.
        anchor, positive, negative = (
            {key: tensor.to(device) for key, tensor in data[role].items()}
            for role in ('anchor', 'positive', 'negative')
        )

        res_anchor = model(anchor)
        res_positive = model(positive)
        res_negative = model(negative)

        loss = triplet_loss(res_anchor, res_positive, res_negative)
        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        if step % 20 == 0:
            print(f'Epoch: {epoch}, Loss:  {batch_loss}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss / n_batches if n_batches else 0.0

In [16]:
# 101,  7592,  1010,  2026,  3899,  2003, 10140,   102]
# for _,data in tqdm(enumerate(training_loader, 0)):
#     anchor = data['anchor']
#     positive = data['positive']
#     negative = data['negative']
#     # print(anchor)
#     res_anchor,res_positive,res_negative = model(anchor),model(positive),model(negative)
#     print(res_anchor)
#     break


0it [00:00, ?it/s]

tensor([[ 0.0150,  0.2781,  0.0539,  ..., -0.0671, -0.1414, -0.2027],
        [-0.0738,  0.2706,  0.0399,  ..., -0.1368, -0.1213, -0.2575],
        [-0.0792,  0.2642,  0.0393,  ..., -0.0070, -0.1887, -0.1760],
        ...,
        [-0.0014,  0.2907,  0.0887,  ..., -0.1080,  0.0004, -0.1858],
        [ 0.0423,  0.1955, -0.0205,  ..., -0.0482, -0.1494, -0.1603],
        [ 0.1207,  0.2113,  0.0879,  ..., -0.0489, -0.0202, -0.0949]],
       grad_fn=<AddmmBackward>)


In [17]:

# Run the full training schedule, one call to `train` per epoch.
# The recorded CPU run below took roughly 70 minutes per epoch
# (143 batches at about 29.5 s each).
for epoch in range(EPOCHS):
    train(epoch)


0it [00:00, ?it/s][A

Epoch: 0, Loss:  0.9176938533782959



1it [00:30, 30.93s/it][A
2it [01:00, 30.60s/it][A
3it [01:30, 30.26s/it][A
4it [01:59, 30.01s/it][A
5it [02:29, 29.87s/it][A
6it [02:59, 30.00s/it][A
7it [03:29, 29.87s/it][A
8it [03:58, 29.75s/it][A
9it [04:28, 29.67s/it][A
10it [04:57, 29.57s/it][A
11it [05:26, 29.51s/it][A
12it [05:56, 29.55s/it][A
13it [06:25, 29.49s/it][A
14it [06:55, 29.48s/it][A
15it [07:24, 29.54s/it][A
16it [07:54, 29.57s/it][A
17it [08:24, 29.55s/it][A
18it [08:53, 29.58s/it][A
19it [09:23, 29.60s/it][A
20it [09:52, 29.59s/it][A

Epoch: 0, Loss:  0.9406020045280457



21it [10:22, 29.51s/it][A
22it [10:51, 29.47s/it][A
23it [11:21, 29.48s/it][A
24it [11:50, 29.51s/it][A
25it [12:20, 29.55s/it][A
26it [12:49, 29.58s/it][A
27it [13:20, 29.78s/it][A
28it [13:49, 29.78s/it][A
29it [14:19, 29.72s/it][A
30it [14:49, 29.74s/it][A
31it [15:19, 29.73s/it][A
32it [15:48, 29.74s/it][A
33it [16:18, 29.66s/it][A
34it [16:47, 29.58s/it][A
35it [17:17, 29.52s/it][A
36it [17:46, 29.60s/it][A
37it [18:16, 29.63s/it][A
38it [18:46, 29.62s/it][A
39it [19:15, 29.59s/it][A
40it [19:45, 29.63s/it][A

Epoch: 0, Loss:  0.39005225896835327



41it [20:14, 29.60s/it][A
42it [20:44, 29.56s/it][A
43it [21:13, 29.52s/it][A
44it [21:43, 29.51s/it][A
45it [22:12, 29.43s/it][A
46it [22:41, 29.36s/it][A
47it [23:11, 29.39s/it][A
48it [23:41, 29.57s/it][A
49it [24:10, 29.50s/it][A
50it [24:39, 29.49s/it][A
51it [25:09, 29.51s/it][A
52it [25:39, 29.55s/it][A
53it [26:08, 29.58s/it][A
54it [26:38, 29.57s/it][A
55it [27:07, 29.57s/it][A
56it [27:37, 29.55s/it][A
57it [28:07, 29.54s/it][A
58it [28:36, 29.48s/it][A
59it [29:05, 29.47s/it][A
60it [29:35, 29.48s/it][A

Epoch: 0, Loss:  0.515407919883728



61it [30:04, 29.45s/it][A
62it [30:34, 29.43s/it][A
63it [31:03, 29.46s/it][A
64it [31:32, 29.44s/it][A
65it [32:02, 29.46s/it][A
66it [32:31, 29.43s/it][A
67it [33:01, 29.43s/it][A
68it [33:31, 29.61s/it][A
69it [34:00, 29.60s/it][A
70it [34:30, 29.55s/it][A
71it [34:59, 29.49s/it][A
72it [35:29, 29.50s/it][A
73it [35:58, 29.53s/it][A
74it [36:28, 29.45s/it][A
75it [36:57, 29.40s/it][A
76it [37:26, 29.38s/it][A
77it [37:56, 29.38s/it][A
78it [38:25, 29.38s/it][A
79it [38:54, 29.35s/it][A
80it [39:24, 29.35s/it][A

Epoch: 0, Loss:  0.43503692746162415



81it [39:53, 29.33s/it][A
82it [40:22, 29.26s/it][A
83it [40:51, 29.22s/it][A
84it [41:20, 29.21s/it][A
85it [41:50, 29.27s/it][A
86it [42:19, 29.31s/it][A
87it [42:49, 29.37s/it][A
88it [43:18, 29.37s/it][A
89it [43:48, 29.56s/it][A
90it [44:17, 29.42s/it][A
91it [44:47, 29.44s/it][A
92it [45:16, 29.50s/it][A
93it [45:46, 29.57s/it][A
94it [46:16, 29.59s/it][A
95it [46:45, 29.63s/it][A
96it [47:15, 29.66s/it][A
97it [47:45, 29.72s/it][A
98it [48:15, 29.72s/it][A
99it [48:44, 29.75s/it][A
100it [49:14, 29.76s/it][A

Epoch: 0, Loss:  0.6118994355201721



101it [49:44, 29.78s/it][A
102it [50:14, 29.80s/it][A
103it [50:44, 29.78s/it][A
104it [51:13, 29.79s/it][A
105it [51:43, 29.80s/it][A
106it [52:13, 29.80s/it][A
107it [52:43, 29.82s/it][A
108it [53:13, 29.82s/it][A
109it [53:43, 29.83s/it][A
110it [54:13, 29.95s/it][A
111it [54:43, 29.89s/it][A
112it [55:12, 29.85s/it][A
113it [55:42, 29.84s/it][A
114it [56:12, 29.82s/it][A
115it [56:42, 29.81s/it][A
116it [57:12, 29.83s/it][A
117it [57:41, 29.85s/it][A
118it [58:11, 29.87s/it][A
119it [58:41, 29.85s/it][A
120it [59:11, 29.84s/it][A

Epoch: 0, Loss:  0.1347637176513672



121it [59:41, 29.82s/it][A
122it [1:00:11, 29.82s/it][A
123it [1:00:40, 29.84s/it][A
124it [1:01:10, 29.80s/it][A
125it [1:01:40, 29.78s/it][A
126it [1:02:10, 29.78s/it][A
127it [1:02:40, 29.80s/it][A
128it [1:03:10, 29.85s/it][A
129it [1:03:39, 29.83s/it][A
130it [1:04:10, 29.97s/it][A
131it [1:04:39, 29.81s/it][A
132it [1:05:09, 29.75s/it][A
133it [1:05:38, 29.69s/it][A
134it [1:06:08, 29.70s/it][A
135it [1:06:38, 29.74s/it][A
136it [1:07:07, 29.73s/it][A
137it [1:07:37, 29.74s/it][A
138it [1:08:07, 29.76s/it][A
139it [1:08:37, 29.76s/it][A
140it [1:09:07, 29.77s/it][A

Epoch: 0, Loss:  0.039137691259384155



141it [1:09:36, 29.79s/it][A
142it [1:10:06, 29.75s/it][A
143it [1:10:32, 29.60s/it]

0it [00:00, ?it/s][A

Epoch: 1, Loss:  0.3646705448627472



1it [00:29, 29.69s/it][A
2it [00:59, 29.61s/it][A
3it [01:28, 29.53s/it][A
4it [01:57, 29.40s/it][A
5it [02:26, 29.35s/it][A
6it [02:56, 29.37s/it][A
7it [03:25, 29.38s/it][A
8it [03:55, 29.49s/it][A
9it [04:24, 29.31s/it][A
10it [04:53, 29.30s/it][A
11it [05:22, 29.29s/it][A
12it [05:52, 29.31s/it][A
13it [06:21, 29.36s/it][A
14it [06:50, 29.34s/it][A
15it [07:20, 29.33s/it][A
16it [07:49, 29.25s/it][A
17it [08:18, 29.22s/it][A
18it [08:47, 29.30s/it][A
19it [09:17, 29.38s/it][A
20it [09:46, 29.33s/it][A

Epoch: 1, Loss:  0.14457672834396362



21it [10:15, 29.22s/it][A
22it [10:44, 29.12s/it][A
23it [11:13, 29.14s/it][A
24it [11:42, 29.10s/it][A
25it [12:11, 29.02s/it][A
26it [12:40, 28.99s/it][A
27it [13:09, 29.14s/it][A
28it [13:39, 29.25s/it][A
29it [14:09, 29.54s/it][A
30it [14:39, 29.63s/it][A
31it [15:09, 29.66s/it][A
32it [15:38, 29.65s/it][A
33it [16:08, 29.63s/it][A
34it [16:38, 29.64s/it][A
35it [17:07, 29.66s/it][A
36it [17:37, 29.67s/it][A
37it [18:07, 29.66s/it][A
38it [18:36, 29.68s/it][A
39it [19:06, 29.70s/it][A
40it [19:36, 29.68s/it][A

Epoch: 1, Loss:  0.03002721071243286



41it [20:05, 29.66s/it][A
42it [20:35, 29.65s/it][A
43it [21:05, 29.66s/it][A
44it [21:34, 29.68s/it][A
45it [22:04, 29.66s/it][A
46it [22:34, 29.69s/it][A
47it [23:04, 29.73s/it][A
48it [23:33, 29.77s/it][A
49it [24:04, 29.92s/it][A
50it [24:34, 29.87s/it][A
51it [25:03, 29.85s/it][A
52it [25:33, 29.82s/it][A
53it [26:03, 29.79s/it][A
54it [26:33, 29.77s/it][A
55it [27:02, 29.77s/it][A
56it [27:32, 29.78s/it][A
57it [28:02, 29.76s/it][A
58it [28:32, 29.76s/it][A
59it [29:01, 29.80s/it][A
60it [29:31, 29.80s/it][A

Epoch: 1, Loss:  0.0



61it [30:01, 29.78s/it][A
62it [30:31, 29.79s/it][A
63it [31:01, 29.81s/it][A
64it [31:30, 29.80s/it][A
65it [32:00, 29.78s/it][A
66it [32:30, 29.79s/it][A
67it [33:00, 29.80s/it][A
68it [33:30, 29.81s/it][A
69it [34:00, 29.83s/it][A
70it [34:30, 30.00s/it][A
71it [35:00, 29.95s/it][A
72it [35:30, 29.92s/it][A
73it [35:59, 29.87s/it][A
74it [36:29, 29.86s/it][A
75it [36:59, 29.84s/it][A
76it [37:29, 29.83s/it][A
77it [37:59, 29.81s/it][A
78it [38:28, 29.77s/it][A
79it [38:58, 29.74s/it][A
80it [39:28, 29.74s/it][A

Epoch: 1, Loss:  0.36327263712882996



81it [39:57, 29.72s/it][A
82it [40:27, 29.70s/it][A
83it [40:57, 29.70s/it][A
84it [41:26, 29.71s/it][A
85it [41:56, 29.76s/it][A
86it [42:26, 29.76s/it][A
87it [42:56, 29.77s/it][A
88it [43:26, 29.77s/it][A
89it [43:55, 29.73s/it][A
90it [44:26, 29.92s/it][A
91it [44:55, 29.87s/it][A
92it [45:25, 29.85s/it][A
93it [45:55, 29.82s/it][A
94it [46:25, 29.77s/it][A
95it [46:54, 29.73s/it][A
96it [47:24, 29.69s/it][A
97it [47:53, 29.65s/it][A
98it [48:23, 29.57s/it][A
99it [48:52, 29.59s/it][A
100it [49:22, 29.63s/it][A

Epoch: 1, Loss:  0.024681508541107178



101it [49:52, 29.65s/it][A
102it [50:21, 29.62s/it][A
103it [50:51, 29.60s/it][A
104it [51:20, 29.58s/it][A
105it [51:50, 29.53s/it][A
106it [52:19, 29.46s/it][A
107it [52:49, 29.46s/it][A
108it [53:18, 29.45s/it][A
109it [53:47, 29.39s/it][A
110it [54:17, 29.35s/it][A
111it [54:47, 29.58s/it][A
112it [55:16, 29.56s/it][A
113it [55:45, 29.42s/it][A
114it [56:14, 29.35s/it][A
115it [56:44, 29.30s/it][A
116it [57:13, 29.29s/it][A
117it [57:42, 29.27s/it][A
118it [58:11, 29.21s/it][A
119it [58:40, 29.20s/it][A
120it [59:10, 29.19s/it][A

Epoch: 1, Loss:  0.5366882085800171



121it [59:39, 29.21s/it][A
122it [1:00:08, 29.31s/it][A
123it [1:00:38, 29.30s/it][A
124it [1:01:07, 29.34s/it][A
125it [1:01:36, 29.26s/it][A
126it [1:02:05, 29.18s/it][A
127it [1:02:34, 29.14s/it][A
128it [1:03:03, 29.14s/it][A
129it [1:03:32, 29.15s/it][A
130it [1:04:02, 29.16s/it][A
131it [1:04:31, 29.16s/it][A
132it [1:05:01, 29.42s/it][A
133it [1:05:30, 29.37s/it][A
134it [1:05:59, 29.36s/it][A
135it [1:06:29, 29.32s/it][A
136it [1:06:58, 29.29s/it][A
137it [1:07:27, 29.28s/it][A
138it [1:07:56, 29.23s/it][A
139it [1:08:25, 29.21s/it][A
140it [1:08:55, 29.19s/it][A

Epoch: 1, Loss:  0.10475233197212219



141it [1:09:24, 29.28s/it][A
142it [1:09:53, 29.32s/it][A
143it [1:10:19, 29.51s/it]
