In [1]:
import pandas as pd

from pathlib import Path

# Directory containing train.csv / test.csv (relative to the working directory).
DATA_DIR = Path(".")

train = pd.read_csv(DATA_DIR / "train.csv")
test = pd.read_csv(DATA_DIR / "test.csv")

# Later cells read the "review" and "rating" columns; fail fast here if either
# split is missing them rather than deep inside the evaluation loop.
for split_name, split_frame in (("train", train), ("test", test)):
    missing_cols = {"review", "rating"} - set(split_frame.columns)
    assert not missing_cols, f"{split_name}.csv is missing columns: {sorted(missing_cols)}"

In [2]:
def convert_binary(rating, threshold=6):
    """Map a numeric rating to a binary sentiment label.

    Parameters
    ----------
    rating : int or float
        Raw rating value.
    threshold : int or float, default 6
        Ratings greater than or equal to this value are treated as positive.

    Returns
    -------
    int
        1 if ``rating >= threshold``, otherwise 0. Note that a NaN rating
        compares False and is therefore mapped to 0.
    """
    return int(rating >= threshold)

In [3]:
# Preserve the original ratings in "rating_raw" and always derive the binary
# "rating" label from them. Overwriting "rating" in place made this cell
# non-idempotent: a second run would re-binarise the 0/1 labels, and since
# both are below the threshold every label would collapse to 0.
for split_frame in (train, test):
    if "rating_raw" not in split_frame.columns:
        split_frame["rating_raw"] = split_frame["rating"]
    split_frame["rating"] = split_frame["rating_raw"].map(convert_binary)

In [4]:
from transformers import pipeline

import torch

MODEL_NAME = "lvwerra/distilbert-imdb"

# Use the GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without a GPU instead of failing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pipe = pipeline(
    "text-classification",
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
    # Truncate inputs to at most 512 tokens.
    max_length=512,
    truncation=True,
    device=DEVICE,
)

In [5]:
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Map-style dataset that yields raw texts for batched pipeline inference.

    ``__getitem__`` returns only the text, so a ``DataLoader`` over this
    dataset batches items into lists of strings that can be passed directly to
    a text-classification pipeline. The ground-truth labels are kept separately
    in ``self.labels``, in the same order as ``self.texts``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the text and label columns.
    text_col : str, default "review"
        Name of the column holding the input text.
    label_col : str, default "rating"
        Name of the column holding the label.
    """

    def __init__(self, df, text_col="review", label_col="rating"):
        self.texts = df[text_col].tolist()
        # Stored as a plain list (not a Series) so labels[i] is positional and
        # lines up with texts[i] even when df has a non-zero-based index,
        # e.g. a slice such as test[1000:2000].
        self.labels = df[label_col].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]


# Number of rows evaluated from each split. `.iloc[:None]` selects every row,
# so set this to None for a full (slower) evaluation.
# NOTE(review): this takes the *first* rows, not a random sample — if the CSVs
# are ordered (e.g. by label) the subset may be unrepresentative; confirm.
N_EVAL_SAMPLES = 1000

train_dataset = TextDataset(train.iloc[:N_EVAL_SAMPLES])
test_dataset = TextDataset(test.iloc[:N_EVAL_SAMPLES])

In [6]:
from tqdm import tqdm
from sklearn.metrics import classification_report

def eval_model(pipe, dataset, batch_size=16, positive_label="POSITIVE"):
    """Run ``pipe`` over ``dataset`` and print a classification report.

    Parameters
    ----------
    pipe : callable
        Text-classification pipeline. It is called with a batch of texts and
        must return one dict with a ``"label"`` key per input.
    dataset : Dataset
        Map-style dataset whose items are texts and which exposes the true
        0/1 labels as ``dataset.labels``, in the same order as its items.
    batch_size : int, default 16
        Number of texts passed to ``pipe`` per call.
    positive_label : str, default "POSITIVE"
        Label string that ``pipe`` emits for the positive class. Every other
        label string is mapped to 0.

    Returns
    -------
    None
        The report is printed, not returned, so the notebook cell shows only
        the report.
    """
    # shuffle=False keeps predictions in the same order as dataset.labels.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    pred_labels = []
    for batch in tqdm(loader, desc="predicting"):
        pred_labels.extend(pred["label"] for pred in pipe(batch))

    preds = [int(label == positive_label) for label in pred_labels]

    print(classification_report(dataset.labels, preds))

In [7]:
# Score the pipeline on the training subset.
# NOTE(review): if lvwerra/distilbert-imdb was fine-tuned on this same data,
# these numbers are not an out-of-sample estimate — confirm.
eval_model(pipe=pipe, dataset=train_dataset)



              precision    recall  f1-score   support

           0       0.89      0.91      0.90       416
           1       0.93      0.92      0.93       584

    accuracy                           0.91      1000
   macro avg       0.91      0.91      0.91      1000
weighted avg       0.91      0.91      0.91      1000



In [8]:
# Score the pipeline on the held-out test subset.
eval_model(pipe=pipe, dataset=test_dataset)



              precision    recall  f1-score   support

           0       0.88      0.92      0.90       484
           1       0.92      0.88      0.90       516

    accuracy                           0.90      1000
   macro avg       0.90      0.90      0.90      1000
weighted avg       0.90      0.90      0.90      1000

