In [1]:
import numpy as np
import pandas as pd


from torch import nn
from torchtext.vocab import GloVe
import torch.optim as optim
import random


import os
import pickle
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from collections import Counter
import torch
from torch.nn.utils.rnn import pad_sequence


from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


In [2]:
# Notebook-wide configuration and the pickled vocabulary used by every later cell.
MAX_VOCAB = 25000
special_tokens = ['<unk>', '<pad>']
tokenizer = get_tokenizer('basic_english')
vocab_file = 'vocab.pkl'

# Fail fast: the model and encoding cells below all read `vocab`, so continuing
# without it would only postpone the failure to a less obvious NameError.
if not os.path.exists(vocab_file):
    raise FileNotFoundError(
        f"no {vocab_file} file found; create the vocabulary before running this notebook."
    )

# NOTE(security): pickle.load can execute arbitrary code during deserialization.
# Only load a vocabulary file that you produced yourself.
with open(vocab_file, 'rb') as f:
    vocab = pickle.load(f)
print("Vocabulary loaded from 'vocab.pkl'.")

# Texts that encode to more than this many tokens are skipped by encode_df.
MAX_LENGTH = 2048

Vocabulary loaded from 'vocab.pkl'.


In [4]:

class CNN_BiLSTM(nn.Module):
    """Text classifier combining max-pooled Conv1d features with BiLSTM features.

    Token ids are embedded with a frozen, GloVe-initialised table. Three parallel
    1-D convolutions (kernel sizes 3, 5 and 7) are max-pooled over time, and a
    bidirectional LSTM is read out at the last timestep. The concatenated
    features go through dropout, a linear layer and a sigmoid.

    The token -> index mapping is taken from the module-level ``vocab``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width. GloVe is loaded at this width, so it must be
            a dimension the GloVe '6B' set provides.
        hidden_dim: Hidden size of each LSTM direction.
        output_dim: Number of output units of the final linear layer.
        pad_idx: Index of the padding token (passed to ``nn.Embedding``).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        # Load GloVe at the same width as the embedding table. A hard-coded
        # width would fail on the row assignment below whenever embed_dim differs.
        glove = GloVe(name='6B', dim=embed_dim)
        pretrained_embeddings = torch.zeros(vocab_size, embed_dim)
        # Words missing from GloVe keep an all-zero row.
        for word, idx in vocab.get_stoi().items():
            if word in glove.stoi:
                pretrained_embeddings[idx] = glove[word]
        with torch.no_grad():
            self.embedding.weight.copy_(pretrained_embeddings)
        self.embedding.weight.requires_grad = False  # freeze

        # CNN: one branch per kernel size, each producing n_filters channels.
        n_filters = 100
        self.conv3 = nn.Conv1d(embed_dim, n_filters, kernel_size=3)
        self.conv5 = nn.Conv1d(embed_dim, n_filters, kernel_size=5)
        self.conv7 = nn.Conv1d(embed_dim, n_filters, kernel_size=7)

        # LSTM
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Combine and classify: three CNN branches plus both LSTM directions.
        self.fc = nn.Linear(n_filters * 3 + hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return sigmoid outputs for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, laid out as (batch, time).

        Returns:
            Sigmoid of the linear layer output with dimension 1 squeezed, i.e.
            shape (batch,) when ``output_dim`` is 1.
        """
        x_embed = self.embedding(x)  # (B, T, E)
        x_cnn = x_embed.permute(0, 2, 1)  # (B, E, T)

        # Max over the time axis keeps the strongest response of each filter.
        c3 = torch.relu(self.conv3(x_cnn)).max(dim=2)[0]
        c5 = torch.relu(self.conv5(x_cnn)).max(dim=2)[0]
        c7 = torch.relu(self.conv7(x_cnn)).max(dim=2)[0]

        cnn_out = torch.cat([c3, c5, c7], dim=1)

        lstm_out, _ = self.lstm(x_embed)
        # NOTE(review): index -1 is a padding position for right-padded batches,
        # and the backward direction has only seen one step there. Left unchanged
        # because a saved checkpoint presumably depends on this read-out.
        lstm_out = lstm_out[:, -1, :]  # take last timestep

        combined = torch.cat([cnn_out, lstm_out], dim=1)
        out = self.fc(self.dropout(combined))
        return torch.sigmoid(out).squeeze(1)


In [7]:
# Use the MPS backend when it is available, otherwise stay on the CPU.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(f"Using device: {device}")

model = CNN_BiLSTM(
    vocab_size=len(vocab),
    embed_dim=100,
    hidden_dim=128,
    output_dim=1,
    pad_idx=vocab['<pad>'],
)
model.to(device)

# The checkpoint is deserialized onto the CPU before being loaded into the model.
model.load_state_dict(
    torch.load("model_1.pth", map_location=torch.device('cpu'))
)



Using device: mps


<All keys matched successfully>

In [None]:
# Read the cleaned dataset and keep only its first 5000 rows.
df = pd.read_csv('cleaned_news_data.csv').head(5000)

In [20]:
print()

# Restrict the frame to the rows whose label is 0.
df = df.loc[df['label'] == 0]




In [21]:

def encode_df(df, token_vocab=None, tokenize=None, max_length=None):
    """Encode the ``content`` column of ``df`` into a padded id tensor plus labels.

    Rows with missing text are skipped, as are rows that encode to more than
    ``max_length`` tokens (they are dropped, not truncated).

    Args:
        df: Frame with a ``content`` (text) column and a ``label`` column.
        token_vocab: Mapping indexed by token that yields an integer id and that
            contains ``'<pad>'``. Defaults to the notebook-level ``vocab``.
        tokenize: Callable turning a text into a list of tokens. Defaults to the
            notebook-level ``tokenizer``.
        max_length: Maximum number of tokens a kept text may have. Defaults to
            the notebook-level ``MAX_LENGTH``.

    Returns:
        ``(texts, labels)``: ``texts`` is a long tensor of shape
        (kept rows, longest kept text) padded with the ``'<pad>'`` id, and
        ``labels`` is a long tensor with one entry per kept row. When no row is
        kept, ``texts`` has shape (0, 0) and ``labels`` is empty.
    """
    # Fall back to the notebook-level objects so a plain `encode_df(df)` call
    # behaves exactly as before.
    token_vocab = vocab if token_vocab is None else token_vocab
    tokenize = tokenizer if tokenize is None else tokenize
    max_length = MAX_LENGTH if max_length is None else max_length

    encoded_texts = []
    encoded_labels = []
    for text, label in zip(df['content'], df['label']):
        if pd.isna(text):
            continue
        encoded = [token_vocab[token] for token in tokenize(text)]
        if len(encoded) > max_length:
            continue
        encoded_texts.append(torch.tensor(encoded, dtype=torch.long))
        encoded_labels.append(label)

    labels = torch.tensor(encoded_labels, dtype=torch.long)

    # pad_sequence rejects an empty list, so return explicit empty tensors.
    if not encoded_texts:
        return torch.empty((0, 0), dtype=torch.long), labels

    padded = pad_sequence(
        encoded_texts, batch_first=True, padding_value=token_vocab['<pad>']
    )
    return padded, labels

# Encode the current frame and report how many rows were kept by encode_df.
encoded_texts, encoded_labels = encode_df(df)
print(f"Encoded {len(encoded_texts)} texts and {len(encoded_labels)} labels.")

Encoded 448 texts and 448 labels.


In [22]:
# Spot-check a single encoded label (the entry at index 3).
print(encoded_labels[3])

tensor(0)


In [23]:
class NewsDataset(Dataset):
    """Index-aligned pairing of inputs ``X`` with targets ``y``."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset length is taken from the inputs alone.
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        # Return the input and target stored at the same position.
        features = self.X[idx]
        target = self.y[idx]
        return features, target

def collate_fn(batch):
    """Stack a list of ``(text, label)`` tensor pairs into two batched tensors."""
    # Transpose the list of pairs into one list of texts and one list of labels.
    texts, labels = map(list, zip(*batch))
    stacked_texts = torch.stack(texts)
    stacked_labels = torch.stack(labels)
    return stacked_texts, stacked_labels

# Wrap the encoded tensors in a dataset and iterate it in order, four samples
# per batch, using the stacking collate function.
test_ds = NewsDataset(X=encoded_texts, y=encoded_labels)
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

In [24]:
# def evaluate(model, loader):
#     model.eval()
#     total_acc = 0
#     with torch.no_grad():
#         for xb, yb in loader:
#             xb, yb = xb.to(device), yb.to(device)
#             preds = model(xb)
#             preds_class = (preds > 0.5).float()
#             total_acc += (preds_class == yb).float().mean().item()
#     return total_acc / len(loader)


# test_acc = evaluate(model, test_dl)
# print(f"Test Accuracy: {test_acc:.4f}")


def evaluate(model, loader, device=None):
    """Run ``model`` over ``loader`` and return accuracy, outputs and labels.

    Accuracy is the number of correct predictions divided by the number of
    samples, so a smaller final batch does not get extra weight (averaging
    per-batch accuracies would).

    Args:
        model: Module whose output for a batch is compared against 0.5 to get
            the predicted class.
        loader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, or the CPU if the model has no parameters.

    Returns:
        ``(accuracy, all_preds, all_labels)``: ``accuracy`` is a float, and the
        two tensors hold the model outputs and labels of all batches,
        concatenated on the CPU in loader order. For an empty loader the
        accuracy is NaN and both tensors are empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    n_correct = 0
    n_seen = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = model(xb)  # model outputs, thresholded at 0.5 below
            preds_class = (preds > 0.5).float()

            n_correct += (preds_class == yb).sum().item()
            n_seen += yb.numel()

            all_preds.append(preds.cpu())
            all_labels.append(yb.cpu())

    # torch.cat rejects an empty list, so an empty loader gets an explicit result.
    if not all_preds:
        return float('nan'), torch.empty(0), torch.empty(0, dtype=torch.long)

    # concatenate all batches
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    return n_correct / n_seen, all_preds, all_labels

# Evaluate on the test loader and report the overall accuracy.
test_acc, test_preds, test_labels = evaluate(model, test_dl)
print(f"Test Accuracy: {test_acc:.4f}")

# Positions whose output is at or below the 0.5 threshold, i.e. the samples that
# evaluate() does not count as class 1. These are indices, not probabilities.
class0_idx = torch.nonzero(test_preds <= 0.5).flatten()
print("Indices predicted as class 0 (output <= 0.5):", class0_idx.tolist())
# test_labels comes from the same pass as test_preds, so the two stay aligned.
print("Labels at those indices:", test_labels[class0_idx].tolist())
print("Ground truth labels:", test_labels.numpy())


Test Accuracy: 0.7634
Model outputs (probabilities): [[  1]
 [  2]
 [  3]
 [  4]
 [  5]
 [  6]
 [  7]
 [  9]
 [ 10]
 [ 12]
 [ 15]
 [ 17]
 [ 18]
 [ 19]
 [ 20]
 [ 21]
 [ 22]
 [ 23]
 [ 25]
 [ 26]
 [ 27]
 [ 32]
 [ 33]
 [ 34]
 [ 35]
 [ 36]
 [ 37]
 [ 38]
 [ 39]
 [ 41]
 [ 43]
 [ 44]
 [ 46]
 [ 47]
 [ 48]
 [ 50]
 [ 51]
 [ 52]
 [ 53]
 [ 54]
 [ 55]
 [ 56]
 [ 58]
 [ 59]
 [ 61]
 [ 62]
 [ 65]
 [ 66]
 [ 67]
 [ 68]
 [ 69]
 [ 70]
 [ 74]
 [ 77]
 [ 79]
 [ 80]
 [ 81]
 [ 82]
 [ 83]
 [ 84]
 [ 85]
 [ 86]
 [ 87]
 [ 90]
 [ 91]
 [ 92]
 [ 93]
 [ 94]
 [ 95]
 [ 96]
 [ 97]
 [100]
 [101]
 [102]
 [103]
 [104]
 [106]
 [107]
 [108]
 [109]
 [111]
 [112]
 [113]
 [114]
 [115]
 [116]
 [117]
 [118]
 [120]
 [121]
 [122]
 [124]
 [125]
 [126]
 [127]
 [128]
 [129]
 [130]
 [132]
 [133]
 [134]
 [135]
 [136]
 [137]
 [139]
 [140]
 [141]
 [142]
 [143]
 [145]
 [146]
 [147]
 [148]
 [149]
 [150]
 [151]
 [152]
 [153]
 [154]
 [155]
 [156]
 [159]
 [160]
 [161]
 [162]
 [164]
 [165]
 [166]
 [167]
 [169]
 [170]
 [171]
 [172]
 [173]
 [174]
 [