<a href="https://colab.research.google.com/github/yinhao0424/reuters/blob/master/ReusterFewShotLearner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## One Shot Learner
By building a pooling layer on top of the BERT model, the sentence embedding has been generated by the model. The next steps will be
- Train a classifier on support set
- Predict based on the classifier
***
1/6/2021  
Test the model performance
- Test without Finetuning
  - store the embedding of support set
  - caculate the embedding of query
  - find the sample with the highest similarity score
- Finetuning

In [1]:
# a specific version of transformaer has been used 
! pip install -q transformers==3.0.2
# !pip install -q transformers

[K     |████████████████████████████████| 778kB 8.0MB/s 
[K     |████████████████████████████████| 1.2MB 11.7MB/s 
[K     |████████████████████████████████| 3.0MB 31.1MB/s 
[K     |████████████████████████████████| 890kB 59.9MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [2]:
import numpy as np
import pandas as pd
from sklearn import metrics
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import torch.nn.functional as F

import transformers
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertModel

import warnings
warnings.simplefilter('ignore')
import logging
logging.basicConfig(level=logging.ERROR)

In [3]:
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
device

device(type='cuda')

In [4]:
MAX_LEN = 256
EPOCHS = 2
LEARNING_RATE = 1e-05
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [5]:
class DistilBERTClass(torch.nn.Module):
    def __init__(self):
        super(DistilBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, 256)

    def forward(self, data):
        input_ids = data['ids']
        attention_mask = data['mask']

        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.Tanh()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

In [6]:
## Load data and model
PATH = '/content/drive/MyDrive/data/reuters/siamese_NN.pth'
model = DistilBERTClass()
model.load_state_dict(torch.load(PATH))
model.to(device)
model.eval()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




DistilBERTClass(
  (l1): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_featu

## Test how the similarity function has been worked
1/5/2021  


In [13]:
class OneShotLearning(Dataset):
    """
        Input: a dataframe
        output: index, ids, mask
    """
    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.texts
        self.topics = self.data.topics
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())

        anchor = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )

        ids = anchor['input_ids']
        mask = anchor['attention_mask']
        # token_type_ids = anchor["token_type_ids"]


        return {
            'index':torch.tensor(index, dtype=torch.int),
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long)
        }


## Read data
- Commodity -- 67 topics
- Currency -- 18 topics

In [7]:
# Commodity data with 67 categories
reuster_support = pd.read_csv('/content/drive/MyDrive/data/reuters/fewshot_support.csv')
reuster_test = pd.read_csv('/content/drive/MyDrive/data/reuters/fewshot_test.csv')
reuster_support.head()

Unnamed: 0,id,topics,texts
0,4016,iron-steel,"usx <x> proved oil, gas reserves fall in 1986u..."
1,4022,carcass,argentine meat exports higher in jan/feb 1987a...
2,4022,livestock,argentine meat exports higher in jan/feb 1987a...
3,4035,veg-oil,british minister criticises proposed ec oils t...
4,4040,oilseed,china's rapeseed crop damaged by stormsthe yie...


In [10]:
print("Support Dataset: {}".format(reuster_support.shape))
print("Test Dataset: {}".format(reuster_test.shape))

def generate_dataloader(reuster_support,reuster_test):
  support_set = OneShotLearning(reuster_support, tokenizer, MAX_LEN)
  testing_set = OneShotLearning(reuster_test, tokenizer, MAX_LEN)


  SUPPORT_BATCH_SIZE = 1
  TEST_BATCH_SIZE = 1

  support_params = {'batch_size': SUPPORT_BATCH_SIZE,
                  'num_workers': 1
                  }

  test_params = {'batch_size': TEST_BATCH_SIZE,
                  'num_workers': 1
                  }

  support_loader = DataLoader(support_set, **support_params)
  testing_loader = DataLoader(testing_set, **test_params)
  return support_loader,testing_loader

Support Dataset: (67, 3)
Test Dataset: (2268, 3)


In [11]:
## cosine similarity
def cosine_similarity(support,test):
  cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
  out = cos(support,test)
  return out[0]

def similar_support(test):
  most_similar = 0
  most_similar_idx = None
  for idx,support in enumerate(support_res):
    out = cosine_similarity(support,test)
    if out > most_similar:
      most_similar = out
      most_similar_idx = support_idx[idx]
  return most_similar_idx



In [21]:
def support_embedding(support_loader):
  # generate support embedding
  support_res = []
  support_idx = []
  with torch.no_grad():
    for _,data in tqdm(enumerate(support_loader, 0)):
        support_idx.append(data['index'].tolist()[0])
        support = {key:data[key].cuda() for key in data}

        support_res.extend(model(support).cpu().detach())
  return support_res,support_idx

def test_embeding(testing_loader):
  ## for test set, find argmax
  testing_res = []
  testing_idx = []
  with torch.no_grad():
    for _,data in tqdm(enumerate(testing_loader, 0)):
      testing_idx.append(data['index'].tolist()[0])
      test = {key:data[key].cuda() for key in data}
      testing_res.append(model(test).cpu().detach())
  return testing_res,testing_idx

In [None]:
support_loader,testing_loader = generate_dataloader(reuster_support,reuster_test)
support_res,support_idx = support_embedding(support_loader)
testing_res,testing_idx = test_embeding(testing_loader)

In [18]:
true_positive = 0
with torch.no_grad():
  for index, test in tqdm(enumerate(testing_res)):
    most_similar_idx = similar_support(test)
    
    support_topic = reuster_support.iloc[most_similar_idx]['topics']
    test_topic = reuster_test.iloc[testing_idx[index]]['topics']
    if support_topic == test_topic:
      true_positive+=1
  # break

2268it [00:11, 204.75it/s]


In [19]:
true_positive/len(testing_res)

0.35185185185185186

### Currency

In [17]:
# currency data with 67 categories
reuster_cur_support = pd.read_csv('/content/drive/MyDrive/data/reuters/cur_fewshot_support.csv')
reuster_cur_test = pd.read_csv('/content/drive/MyDrive/data/reuters/cur_fewshot_test.csv')
reuster_cur_support.head()

Unnamed: 0,id,topics,texts
0,4616,dlr,miyazawa says exchange rates will stay stablef...
1,4633,yen,japan minister says about 170 yen appropriatei...
2,10344,dfl,economic spotlight - dutch exchange rate polic...
3,10636,lit,italy relaxes restrictions on lira importsital...
4,10718,rupiah,indonesian rupiah slips against mark and yenth...


In [22]:
support_loader,testing_loader = generate_dataloader(reuster_cur_support,reuster_cur_test)
support_res,support_idx = support_embedding(support_loader)
testing_res,testing_idx = test_embeding(testing_loader)

18it [00:00, 55.90it/s]
227it [00:02, 82.28it/s]


In [23]:
true_positive = 0
with torch.no_grad():
  for index, test in tqdm(enumerate(testing_res)):
    most_similar_idx = similar_support(test)
    
    support_topic = reuster_support.iloc[most_similar_idx]['topics']
    test_topic = reuster_test.iloc[testing_idx[index]]['topics']
    if support_topic == test_topic:
      true_positive+=1

227it [00:00, 585.80it/s]


In [27]:
true_positive/len(testing_res)

0.6607929515418502

In [None]:
def evaluation(res_anchor,res_positive,res_negative):
    pos_dist = F.pairwise_distance(res_anchor, res_positive)
    neg_dist = F.pairwise_distance(res_anchor, res_negative)

    res = 0
    for pos,neg in zip(pos_dist,neg_dist):
      # print(pos,neg)
      if pos < neg:
        res += 1
    return res

In [None]:
true = 0
for _,data in tqdm(enumerate(testing_loader, 0)):
    anchor_id = data['anchor']['index']
    positive_id = data['positive']['index']
    negative_id = data['negative']['index']

    anchor = {key:data['anchor'][key].cuda() for key in data['anchor']}
    positive = {key:data['positive'][key].cuda() for key in data['positive']}
    negative = {key:data['negative'][key].cuda() for key in data['negative']}
    res_anchor,res_positive,res_negative = model(anchor),model(positive),model(negative)
    # print(1)
    true += evaluation(res_anchor,res_positive,res_negative)
    if _%10 == 0:
      print(true)

- For support set generates embedding 
- Compare with the distance with query
- Softmax function (optional)
- Find topics by argmax

In [None]:
print(anchor_id)
print(positive_id)
print(negative_id)

tensor([1321,  609, 1191,   76], dtype=torch.int32)
tensor([ 188, 1648, 1993,  125], dtype=torch.int32)
tensor([ 558,  691, 1534,  745], dtype=torch.int32)


In [None]:
reuster_test.iloc[anchor_id.tolist()]

Unnamed: 0,id,topics,texts
1321,2535,grain,argentina-brazil trade jumped 90 pct in 1986tr...
609,5606,soybean,"u.s. export inspections, in thous bushels soy..."
1191,12802,grain,london grains sees wheat recover from lowsu.k....
76,4436,veg-oil,argentine vegetable oils shipments in jan/nov ...


In [None]:
reuster_test.iloc[positive_id.tolist()]
reuster_test.iloc[negative_id.tolist()]

pos_dist = F.pairwise_distance(res_anchor, res_positive)
neg_dist = F.pairwise_distance(res_anchor, res_negative)

In [None]:
pos_dist

tensor([3.8939, 1.1187, 4.6582, 2.6391], device='cuda:0',
       grad_fn=<NormBackward1>)

In [None]:
neg_dist

tensor([5.0758, 5.5659, 4.2740, 3.1108], device='cuda:0',
       grad_fn=<NormBackward1>)