In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from collections import Counter
import re
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
import torch.nn as nn

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Fetch the 'imdb' dataset; the 'train' and 'test' splits are read below.
dataset = load_dataset(path='imdb')

In [None]:
# Number of rows in the training split.
len(dataset['train'])


In [None]:
# Raw text and label columns of the training split, plus the test texts.
texts_array_train, labels_array_train = (
    dataset['train']['text'],
    dataset['train']['label'],
)
texts_array_test = dataset['test']['text']

In [None]:
# Bag-of-words encoder producing 0/1 presence indicators.
vectorizer = CountVectorizer(
    stop_words='english',  # drop scikit-learn's built-in English stop words
    binary=True,           # record presence, not counts
    min_df=7,              # ignore terms seen in fewer than 7 documents
    max_df=0.85,           # ignore terms seen in more than 85% of documents
    max_features=10000,    # cap the vocabulary size
)


In [None]:
# Learn the vocabulary from the training texts.
# NOTE(review): this runs before the train/validation split, so the rows later
# held out for validation also contribute to the vocabulary.
vectorizer.fit(raw_documents=texts_array_train)

In [None]:
# Hold out 15% of the training rows for validation, keeping the label
# proportions equal in both parts; the fixed random_state makes the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    texts_array_train,
    labels_array_train,
    stratify=labels_array_train,
    test_size=0.15,
    random_state=42,
)

In [None]:
class BoWDataset(Dataset):
    """Map-style dataset pairing feature rows with float targets.

    Args:
        embed: array-like of numeric features, one row per sample
            (for example a dense NumPy array or a nested list).
        label: sequence of numeric targets, one per row of ``embed``.

    Raises:
        ValueError: if ``embed`` and ``label`` contain a different number
            of samples.
    """

    def __init__(self, embed, label):
        # as_tensor with an explicit dtype replaces the legacy typed
        # constructor (torch.FloatTensor) and makes the conversion to
        # float32 visible at the call site.
        self.embed = torch.as_tensor(embed, dtype=torch.float32)
        self.label = torch.as_tensor(label, dtype=torch.float32)

        # Fail early instead of producing a dataset whose length (taken from
        # the labels) disagrees with the number of feature rows.
        if len(self.embed) != len(self.label):
            raise ValueError(
                f'embed has {len(self.embed)} rows but label has '
                f'{len(self.label)} entries'
            )

    def __getitem__(self, index):
        """Return ``{'embed': features, 'y': target}`` for the given index."""
        return {
            'embed': self.embed[index],
            'y': self.label[index],
        }

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.label)
        



In [None]:
# Encode the training texts with the fitted vocabulary, then densify the
# sparse result so it can be turned into a tensor.
X_train = vectorizer.transform(X_train)
X_train = X_train.toarray()

In [None]:
# Encode the validation texts with the same vocabulary and densify.
X_val = vectorizer.transform(X_val)
X_val = X_val.toarray()

In [None]:
# Wrap the encoded training features and labels for use with a DataLoader.
train_dataset = BoWDataset(embed=X_train, label=y_train)

In [None]:
# Wrap the encoded validation features and labels for use with a DataLoader.
val_dataset = BoWDataset(embed=X_val, label=y_val)

In [None]:
# Training batches are reshuffled on every pass over the loader.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Validation is evaluation only: keep a fixed order so every epoch sees the
# same batches and the reported loss is comparable between epochs.
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
class RegressionModel(nn.Module):
    """Single linear layer that maps a feature vector to one raw score.

    No activation is applied in ``forward``; the notebook feeds the output to
    ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` at prediction time.

    Args:
        input_dim: number of input features. Defaults to 10000, the
            ``max_features`` cap used for the vectorizer in this notebook.
    """

    def __init__(self, input_dim=10000):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return one score per sample: input ``(batch, input_dim)`` -> ``(batch,)``."""
        # Drop the trailing size-1 output dimension. Using -1 rather than 1
        # gives the same result for batched 2-D input and also works for an
        # unbatched 1-D feature vector.
        return self.linear(x).squeeze(-1)
    

In [None]:
# Seed PyTorch's global RNG so that the random weight initialisation below,
# and the DataLoader shuffling that draws from the same generator, are
# repeatable across a fresh "Restart & Run All". 42 matches the random_state
# used for the train/validation split.
torch.manual_seed(42)

model = RegressionModel()

# Binary cross-entropy computed directly on raw model outputs.
criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
# Move the model's parameters onto the selected device.
model.to(device=device)

In [None]:
# Number of full passes over train_dataloader made by the training loop.
epochs = 15

In [None]:
# Pull a single batch from the training loader for a manual look at its contents.
i = next(iter(train_dataloader))

In [None]:
# Display the 'y' entry (targets) of the sampled batch.
i['y']

In [None]:
# Train for `epochs` passes; after each pass, measure the loss on the
# validation loader and print both figures.
for epoch in range(epochs):

    # ---- training pass ----
    model.train()

    train_loss_sum = 0.0
    train_samples = 0

    for batch in train_dataloader:
        X = batch['embed'].to(device)
        y = batch['y'].to(device)

        optimizer.zero_grad()

        logits = model(X)

        loss = criterion(logits, y)

        loss.backward()

        optimizer.step()

        # The criterion returns the mean over the batch. Weight it by the
        # batch size so the smaller final batch does not count as much as a
        # full one in the epoch average.
        batch_size = y.size(0)
        train_loss_sum += loss.item() * batch_size
        train_samples += batch_size

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss = train_loss_sum / max(train_samples, 1)

    # ---- validation pass (no gradient tracking, no parameter updates) ----
    model.eval()

    val_loss_sum = 0.0
    val_samples = 0

    with torch.no_grad():
        for batch in val_dataloader:
            X = batch['embed'].to(device)
            y = batch['y'].to(device)

            logits = model(X)
            loss = criterion(logits, y)

            batch_size = y.size(0)
            val_loss_sum += loss.item() * batch_size
            val_samples += batch_size

    val_loss = val_loss_sum / max(val_samples, 1)

    print(f'Epoch:{epoch}')
    print(f'Train loss: {train_loss:.4f}')
    print(f'Val loss: {val_loss:.4f}')
    

In [None]:
from sklearn.metrics import accuracy_score



In [None]:
# Test labels, and test texts encoded with the fitted vocabulary as a dense array.
y_test = dataset['test']['label']
X_test = vectorizer.transform(dataset['test']['text']).toarray()

In [None]:
test_dataset = BoWDataset(X_test, y_test)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# collected predictions aligned with the dataset order on every run.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
# Per-batch accumulators for predicted classes and their true labels.
all_preds, all_labels = [], []

In [None]:
# Start from empty accumulators each time this cell runs. Without the reset,
# a second run would append duplicate batches, or fail outright once the
# names have been rebound to NumPy arrays by the concatenation cell.
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for batch in test_dataloader:
        X = batch['embed'].to(device)
        y = batch['y'].to(device)

        outputs = model(X)

        # Turn raw scores into probabilities and threshold at 0.5 to get a 0/1 class.
        preds = (torch.sigmoid(outputs) > 0.5).long()

        # Keep results on the CPU so they can be converted to NumPy afterwards.
        all_preds.append(preds.cpu())
        all_labels.append(y.cpu())

In [None]:
# Join the per-batch tensors along the batch dimension and convert to NumPy.
all_preds = torch.cat(all_preds, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy()

In [None]:
# accuracy_score takes (y_true, y_pred); pass the labels first to match the
# documented argument order.
print(accuracy_score(all_labels, all_preds))