In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler
from tqdm import tqdm
from sklearn.datasets import fetch_20newsgroups
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.decomposition import TruncatedSVD
from nltk.stem.snowball import SnowballStemmer

  warn(f"Failed to load image Python extension: {e}")


Load data

In [18]:
train = fetch_20newsgroups(subset='train', shuffle=True)
print(dir(train))
print(len(train.data))
print("Target names:", train.target_names)
print("------------------------------------")
for text, label in zip(train.data[:2], train.target[:2]):
    print("Text:", text[:100])
    print("Label:", train.target_names[label])
    print("------------------------------------")

['DESCR', 'data', 'filenames', 'target', 'target_names']
11314
Target names: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
------------------------------------
Text: From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.
Label: rec.autos
------------------------------------
Text: From: guykuo@carson.u.washington.edu (Guy Kuo)
Subject: SI Clock Poll - Final Call
Summary: Final ca
Label: comp.sys.mac.hardware
------------------------------------


Preprocessing

In [19]:
stemmer = SnowballStemmer("english", ignore_stopwords=True)
class StemmedCountVectorizer(CountVectorizer):
    def build_analyzer(self):
        analyzer = super(StemmedCountVectorizer, self).build_analyzer()
        return lambda doc: ([stemmer.stem(w) for w in analyzer(doc)])
    
preprocess_pipeline = Pipeline([
    ("vect", StemmedCountVectorizer(stop_words="english", max_features=200)),
    ("tfidf", TfidfTransformer()),
])

In [20]:
train_data = preprocess_pipeline.fit_transform(train.data)

In [21]:
print(f"({len(train.data)}, {len(train.data[0])})")
print(train_data.shape)
train_labels = train.target
print(train_labels.shape)
print(train_labels[0])
number_of_labels = len(train.target_names)
print(number_of_labels)

(11314, 721)
(11314, 200)
(11314,)
7
20


Model definition

In [22]:
class DenseTextClassifier(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.network(x)

Learning

In [31]:
input_size = train_data.shape[1]
hidden_size = 512
output_size = number_of_labels
batch_size = 64
lr = 0.001
num_epochs = 2
k_folds = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_tensor = torch.from_numpy(train_data.toarray()).float()
labels_tensor = torch.from_numpy(train_labels).long()
dataset = TensorDataset(data_tensor, labels_tensor)

kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
val_accuracies = []
for fold, (train_ids, val_ids) in enumerate(kfold.split(data_tensor)):
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_ids))
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_ids))

    model = DenseTextClassifier(input_size, output_size).to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    sum_fold_accuracies = 0
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

        model.eval()
        val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=number_of_labels)
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_accuracy.update(outputs, targets)

        fold_accuracy = val_accuracy.compute().item()
        sum_fold_accuracies += fold_accuracy
        print(f'Fold {fold}, Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {fold_accuracy}')
    val_accuracies.append(sum_fold_accuracies / num_epochs)

average_val_accuracy = sum(val_accuracies) / k_folds
print(f'Average Validation Accuracy: {average_val_accuracy}')

  0%|          | 0/142 [00:00<?, ?it/s]

100%|██████████| 142/142 [00:01<00:00, 122.76it/s]


Fold 0, Epoch 1/2, Validation Accuracy: 0.428192675113678


100%|██████████| 142/142 [00:00<00:00, 157.07it/s]


Fold 0, Epoch 2/2, Validation Accuracy: 0.49403446912765503


100%|██████████| 142/142 [00:00<00:00, 162.80it/s]


Fold 1, Epoch 1/2, Validation Accuracy: 0.4012373089790344


100%|██████████| 142/142 [00:00<00:00, 149.74it/s]


Fold 1, Epoch 2/2, Validation Accuracy: 0.48342907428741455


100%|██████████| 142/142 [00:00<00:00, 180.34it/s]


Fold 2, Epoch 1/2, Validation Accuracy: 0.39151570200920105


100%|██████████| 142/142 [00:00<00:00, 148.44it/s]


Fold 2, Epoch 2/2, Validation Accuracy: 0.4622182846069336


100%|██████████| 142/142 [00:00<00:00, 173.13it/s]


Fold 3, Epoch 1/2, Validation Accuracy: 0.3928413689136505


100%|██████████| 142/142 [00:00<00:00, 158.08it/s]


Fold 3, Epoch 2/2, Validation Accuracy: 0.456473708152771


100%|██████████| 142/142 [00:00<00:00, 171.03it/s]


Fold 4, Epoch 1/2, Validation Accuracy: 0.38284704089164734


100%|██████████| 142/142 [00:00<00:00, 171.61it/s]

Fold 4, Epoch 2/2, Validation Accuracy: 0.4681697487831116
Average Validation Accuracy: 0.4360959380865097



