In [1]:
# Colab setup: install the libraries used by the cells below.
!pip install transformers==3.0.2
# NOTE(review): `-U` installs the newest sentence-transformers release, whose own
# dependency requirements may replace the transformers==3.0.2 pinned above —
# confirm the versions actually imported later (e.g. `transformers.__version__`)
# are the intended ones.
!pip install -U sentence-transformers

Collecting transformers==3.0.2
  Downloading transformers-3.0.2-py3-none-any.whl (769 kB)
[K     |████████████████████████████████| 769 kB 13.0 MB/s 
[?25hCollecting sentencepiece!=0.1.92
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 70.9 MB/s 
Collecting tokenizers==0.8.1.rc1
  Downloading tokenizers-0.8.1rc1-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
[K     |████████████████████████████████| 3.0 MB 71.3 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 54.0 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=3001fa0a0c8585963fa21f5c6f934d242b24b92670651a9739c549ffaf6ce759
  Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98

In [3]:
# imports
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch import cuda

import transformers
from transformers import RobertaTokenizer, RobertaModel
from transformers import pipeline

from torch import cuda
from tqdm import tqdm
# Run on the GPU when one is available.
device = 'cuda' if cuda.is_available() else 'cpu'

# Seed Python's and PyTorch's RNGs so demonstration sampling (random.choice)
# and any torch randomness are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Mount Google Drive into this Colab runtime so the data files under "My Drive" can be read.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:

# Directory holding the SST-2 TSV files on the mounted Drive; change it in one place
# to point both splits elsewhere.
DATA_DIR = '/content/drive/My Drive/data/SST-2'
train_data_loc = f'{DATA_DIR}/train.tsv'
dev_data_loc = f'{DATA_DIR}/dev.tsv'

In [4]:
from sentence_transformers import SentenceTransformer

# Sentence encoder, presumably intended as the `emm_model` argument of SST2_Demo.
EMM_MODEL_NAME = 'all-MiniLM-L6-v2'
emm_model = SentenceTransformer(EMM_MODEL_NAME)

In [None]:
class SST2_Demo(Dataset):
  '''
  SST-2 prompt dataset whose inputs are prefixed with similarity-selected demonstrations.

  For each query sentence, one labelled demonstration per class is drawn at
  random from the most similar half (cosine similarity of sentence embeddings)
  of that class's other examples. Each demonstration is rendered through the
  prompt template with its label word in place of `<mask>`. The query is then
  appended and the whole string is rendered through the template once more, so
  the query's label word is `<mask>` in the input and filled in in the target.

  Args:
    file_loc: path to a TSV file with one header line followed by
      `sentence<TAB>label` rows (labels are integers).
    tokenizer: tokenizer callable, applied to both the input and the target text.
    max_len: padding/truncation length passed to the tokenizer.
    emm_model: sentence encoder exposing `encode(list_of_str)` (e.g. a
      SentenceTransformer). Called once on all examples at construction, and on
      a query passed to `data_by_query` without `query_idx`.
    template: prompt with `<S>` for the text and `<mask>` for the label word.
    target2label: mapping from integer label to its label word; defaults to
      {1: 'great', 0: 'terrible'}.
  '''

  def __init__(self, file_loc, tokenizer, max_len, emm_model, template = '<S> It was <mask> . ', target2label = None):

    self.tokenizer = tokenizer
    self.max_length = max_len
    self.emm_model = emm_model
    self.prompt_template = template
    # Build a fresh dict per instance rather than sharing a mutable default argument.
    if target2label is None:
      target2label = {1: 'great', 0: 'terrible'}
    self.prompt_label = dict(target2label)

    self.examples, self.targets = self._read_tsv(file_loc)

    self.examples_by_class, self.emm_examples, self.classes = self.dataset_embeddings()

  @staticmethod
  def _read_tsv(file_loc):
    '''Read `sentence<TAB>label` rows (after one header line) into parallel lists.'''
    examples, targets = [], []
    with open(file_loc, encoding='utf-8') as f:
      f.readline()  # skip the header row
      for line in f:
        line = line.rstrip('\r\n')
        if not line:
          continue  # tolerate blank / trailing lines
        # rsplit keeps any tab inside the sentence attached to the sentence.
        sentence, label = line.rsplit('\t', 1)
        examples.append(sentence)
        targets.append(int(label))
    return examples, targets

  def __len__(self):
    # One item per query sentence; len() requires an int, so no float division here.
    return len(self.targets)

  def __getitem__(self, index):
    '''Return the tokenized prompt for example `index`, with the filled-in prompt as `labels`.'''
    query = self.examples[index]

    # Passing the index excludes the query from its own demonstrations and
    # reuses its precomputed embedding.
    query_demos = self.data_by_query(query, query_idx=index)

    demonstration = self.create_demo_pair(query_demos)

    x_in = demonstration + query

    x, y = self.prompt_transform(x_in, self.targets[index])

    # NOTE(review): truncation drops tokens from the right, so long demonstrations
    # can cut off the query's `<mask>`. With return_tensors='pt', each tensor also
    # keeps a leading dim of 1 — confirm this matches how batches are collated.
    x_tokenized = self.tokenizer(x, return_tensors='pt', max_length = self.max_length, truncation=True, padding='max_length')
    y_tokenized = self.tokenizer(y, return_tensors='pt', max_length = self.max_length, truncation=True, padding='max_length')

    x_tokenized['labels'] = y_tokenized['input_ids']

    return x_tokenized

  def dataset_embeddings(self):
    '''
    Embed every example and group example indices by label.

    Returns:
      data_by_class: dict mapping each label to the list of example indices with that label.
      emm_sentances: float tensor of L2-normalised embeddings, one row per
        example, in the same order as `self.examples`.
      classes: set of labels present in the data.
    '''
    # One batched encode call over the whole corpus. Calling the model object
    # directly would run nn.Module.forward, which does not take raw strings.
    raw = torch.as_tensor(self.emm_model.encode(self.examples)).float()
    # Normalising once lets cosine similarity be computed as a plain dot product.
    emm_sentances = torch.nn.functional.normalize(raw, dim=-1)

    classes = set(self.targets)
    data_by_class = {class_: [] for class_ in classes}
    for i, target in enumerate(self.targets):
      data_by_class[target].append(i)

    return data_by_class, emm_sentances, classes

  def data_by_query(self, query, query_idx=None):
    '''
    Rank each class's examples by cosine similarity to `query`.

    Args:
      query: the query sentence.
      query_idx: index of `query` in this dataset, when it comes from it. Its
        precomputed embedding is reused and the example itself is excluded
        from the candidates, so a query is never its own demonstration.

    Returns:
      dict mapping each label to a list of (similarity, index) pairs, sorted by
      decreasing similarity and cut to the most similar half (kept at one entry
      when a class has a single candidate; empty when it has none).
    '''
    if query_idx is not None:
      query_emb = self.emm_examples[query_idx]
    else:
      encoded = torch.as_tensor(self.emm_model.encode([query])).float()
      query_emb = torch.nn.functional.normalize(encoded, dim=-1)[0]

    # Rows of emm_examples are unit-length, so this is the cosine similarity.
    sims = (self.emm_examples @ query_emb).tolist()

    dbyq = {}
    for class_, indices in self.examples_by_class.items():
      ranked = sorted(
          ((sims[idx], idx) for idx in indices if idx != query_idx),
          key=lambda pair: pair[0],
          reverse=True,
      )
      # Top half, but never empty when candidates exist (random.choice needs one).
      dbyq[class_] = ranked[:max(1, len(ranked) // 2)]

    return dbyq

  def create_demo_pair(self, data_by_query):
    '''
    Build the demonstration prefix: one randomly chosen example per class, each
    rendered through the template with its label word in place of `<mask>`.
    Classes without any candidate are skipped.
    '''
    demo = ""
    for class_ in self.classes:
      candidates = data_by_query.get(class_)
      if not candidates:
        continue
      # Entries are (similarity, index) pairs; only the index is needed here.
      _, sample = random.choice(candidates)
      t = self.prompt_template.replace("<S>", self.examples[sample])
      t = t.replace("<mask>", self.prompt_label[self.targets[sample]])
      demo = demo + t

    return demo

  def prompt_transform(self, text, target):
    '''
    Render `text` through the prompt template.

    text   - text to be classified (demonstrations followed by the query).
    target - integer label; its label word replaces '<mask>' in the target.

    self.prompt_template uses '<S>' for the text and '<mask>' for the label
    word; punctuation and spaces are kept as is, e.g. '<S> It was <mask> . '.

    Returns (x, y): x keeps '<mask>'; y has it replaced with the label word.
    '''
    x = self.prompt_template.replace('<S>', text)
    y = x.replace('<mask>', self.prompt_label[target])

    return x, y


In [None]:
# Get sentence embeddings for the entire training set.




# For a given query, calculate the cosine similarity to every training example.






  





