In order to use the model, you have first to download the weights of the model which you can find here here. You have to put this file in the Data folder. You also have to make sure that the csv file with the test data is in the data folder. 

In [1]:
!pip install -q transformers
!pip install sentencepiece


[K     |████████████████████████████████| 3.4 MB 13.8 MB/s 
[K     |████████████████████████████████| 3.3 MB 46.4 MB/s 
[K     |████████████████████████████████| 61 kB 421 kB/s 
[K     |████████████████████████████████| 596 kB 65.9 MB/s 
[K     |████████████████████████████████| 895 kB 35.1 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 10.0 MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


In [2]:
import transformers
import torch
from torch.utils.data import DataLoader
from torch import cuda

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
import numpy as np
import os
import pandas as pd

In [3]:
from google.colab import drive

drive.mount('/content/drive')
os.chdir('drive/MyDrive/Synthesio')

Mounted at /content/drive


In [4]:
class SentimentAnalysisDataset(torch.utils.data.Dataset):       

    def __init__(self, content, tokenizer):
        self.text = content
        self.encodings = tokenizer(content, truncation=True, padding=True, max_length=64)
    
    def __getitem__(self, idx):
        item = {key:torch.tensor(val[idx]) for key, val in self.encodings.items()}
        return item
    
    def __len__(self):
        return len(self.text)

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL = f"cardiffnlp/twitter-xlm-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
model.load_state_dict(torch.load('Data/checkpoint.pth',
                                 map_location=torch.device(device)))
model.eval()
model.to(device)

XLMRobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (La

In [11]:
def predict(file_path):
    index_to_target = {0:'negative', 
                        1:'neutral',
                        2:'positive'}

    params = {'batch_size': 8,
            'shuffle': False,
            'num_workers': 0
        }
         
    texts = pd.read_csv(file_path).iloc[:,0].tolist()
    dataset = SentimentAnalysisDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, **params)
    results = []
    for _, data in enumerate(dataloader):
        input_ids = data['input_ids'].to(device, dtype = torch.long)
        attention_mask = data['attention_mask'].to(device, dtype = torch.long)
        outputs = model(input_ids, attention_mask)            
        results.extend(torch.nn.Softmax(dim=1)(outputs.logits).cpu().detach().numpy().tolist())
      
    return [index_to_target[index] for index in np.argmax(results, axis=1)]



In [9]:
file_path= 'Data/test.csv'
final_result = predict('Data/test.csv')
pd.DataFrame(final_result).to_csv('Data/output.csv')