In [1]:
from pathlib import Path

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# local dependencies
from contrastive import Contrastive_loss
from utils import plot_tsne, Net_embed, finetune_embeddig


In [2]:


# The six newsgroups used in this experiment: three computer-hardware/OS groups
# and three science groups.
categories = [
    'sci.space', 'sci.med', 'sci.electronics',
    'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware',
]

# Official 20 Newsgroups train/test partitions, restricted to the categories above
# (train is fetched first, then test).
newsgroups_train, newsgroups_test = [
    fetch_20newsgroups(subset=subset_name, categories=categories)
    for subset_name in ('train', 'test')
]

In [3]:
# Subsample both official splits: keep a random 25% of the documents
# (random_state=0 makes the draw repeatable) and send the other 75% to the
# throwaway X_dump / y_dump variables.
split_kwargs = dict(train_size=0.25, random_state=0)

X_train, X_dump, y_train, y_dump = train_test_split(
    newsgroups_train.data, newsgroups_train.target, **split_kwargs
)
X_test, X_dump, y_test, y_dump = train_test_split(
    newsgroups_test.data, newsgroups_test.target, **split_kwargs
)

In [4]:
# Helper: fit a logistic-regression head on the given features and evaluate it.
def fit_predict_embedding(embed_train, y_train, embed_test, y_test, target_names):
    """Fit a logistic-regression classifier on ``embed_train`` and report test metrics.

    Any feature matrix that ``LogisticRegression`` accepts works here, e.g. a
    dense NumPy array of embeddings or the SciPy sparse matrix produced by
    ``TfidfVectorizer``.

    Parameters
    ----------
    embed_train : array-like or sparse matrix, shape (n_train, n_features)
        Training features.
    y_train : array-like, shape (n_train,)
        Class labels for the training rows.
    embed_test : array-like or sparse matrix, shape (n_test, n_features)
        Test features, in the same feature space as ``embed_train``.
    y_test : array-like, shape (n_test,)
        Class labels for the test rows.
    target_names : sequence of str
        Display names for the classes, in sorted label order; used as the row
        labels of the classification report.

    Prints the test accuracy and a per-class ``classification_report``;
    returns ``None``.
    """
    # max_iter raised well above sklearn's default (100) to avoid convergence warnings.
    clf = LogisticRegression(max_iter=10000)
    clf.fit(embed_train, y_train)
    predicted = clf.predict(embed_test)

    # Overall test accuracy.
    accuracy = accuracy_score(y_test, predicted)
    print(f"Accuracy: {accuracy:.2f}")

    # Per-class precision/recall/F1, labelled with the caller-supplied class names.
    report = classification_report(y_test, predicted, target_names=target_names)
    print("Classification Report:\n", report)

### Baseline: classifier on sparse TF-IDF features

In [5]:
# Baseline features: TF-IDF over a vocabulary capped at the 1000 most frequent
# terms. The vocabulary and IDF weights are learned from the training texts only
# and then reused unchanged for the test texts.
tfidf_vectorizer = TfidfVectorizer(max_features=1000)
X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)
X_test_tfidf = tfidf_vectorizer.transform(X_test)

fit_predict_embedding(
    X_train_tfidf, y_train,
    X_test_tfidf, y_test,
    newsgroups_test.target_names,
)


Accuracy: 0.70
Classification Report:
                           precision    recall  f1-score   support

 comp.os.ms-windows.misc       0.73      0.75      0.74        99
comp.sys.ibm.pc.hardware       0.67      0.49      0.57        99
   comp.sys.mac.hardware       0.74      0.71      0.73       104
         sci.electronics       0.53      0.70      0.60        97
                 sci.med       0.71      0.63      0.67        94
               sci.space       0.84      0.91      0.87        95

                accuracy                           0.70       588
               macro avg       0.70      0.70      0.70       588
            weighted avg       0.70      0.70      0.70       588



### Extract dense sentence embeddings (gte-base-en-v1.5)

In [6]:
# Sentence embeddings for the train/test texts, cached under ./toy_data.
# Set calc_embedding = True to force recomputation. The embeddings are also
# recomputed automatically when any cache file is missing (instead of failing in np.load).
calc_embedding = False
cache_dir = Path('./toy_data')
cache_paths = {
    name: cache_dir / f'{name}.npy'
    for name in ('embedding_train', 'embedding_test', 'y_train', 'y_test')
}

if calc_embedding or not all(path.exists() for path in cache_paths.values()):
    model = SentenceTransformer('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
    # Passing the whole list lets encode() batch internally (typically far faster
    # than one call per document). It returns an (n_texts, dim) array, so the
    # embedding size is no longer hard-coded. Cast to float64 to keep the dtype
    # the previous np.zeros-based buffers stored in the cache.
    embedding_mat_train = np.asarray(model.encode(X_train, show_progress_bar=True), dtype=np.float64)
    embedding_mat_test = np.asarray(model.encode(X_test, show_progress_bar=True), dtype=np.float64)

    cache_dir.mkdir(parents=True, exist_ok=True)  # np.save does not create parent directories
    np.save(cache_paths['embedding_train'], embedding_mat_train)
    np.save(cache_paths['embedding_test'], embedding_mat_test)
    np.save(cache_paths['y_train'], y_train)
    np.save(cache_paths['y_test'], y_test)
else:
    embedding_mat_train = np.load(cache_paths['embedding_train'])
    embedding_mat_test = np.load(cache_paths['embedding_test'])
    # NOTE: the cached labels replace the y_train / y_test produced by the split cell.
    y_train = np.load(cache_paths['y_train'])
    y_test = np.load(cache_paths['y_test'])

# Guard against a stale or partial cache: every embedding row needs exactly one label.
assert embedding_mat_train.shape[0] == len(y_train), 'train embeddings/labels out of sync'
assert embedding_mat_test.shape[0] == len(y_test), 'test embeddings/labels out of sync'

In [7]:
# Same logistic-regression head, now trained on the precomputed sentence
# embeddings instead of TF-IDF features.
fit_predict_embedding(
    embedding_mat_train, y_train,
    embedding_mat_test, y_test,
    newsgroups_test.target_names,
)


Accuracy: 0.80
Classification Report:
                           precision    recall  f1-score   support

 comp.os.ms-windows.misc       0.81      0.81      0.81        99
comp.sys.ibm.pc.hardware       0.66      0.64      0.65        99
   comp.sys.mac.hardware       0.71      0.74      0.72       104
         sci.electronics       0.74      0.75      0.74        97
                 sci.med       0.95      0.96      0.95        94
               sci.space       0.99      0.95      0.97        95

                accuracy                           0.80       588
               macro avg       0.81      0.81      0.81       588
            weighted avg       0.81      0.80      0.80       588



In [17]:
# Train a neural network on top of the sentence embeddings using contrastive
# learning (utils.finetune_embeddig). Training is stochastic (e.g. drop_prob=0.3),
# so seed numpy and torch first. This makes the loss log and the downstream
# accuracy reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

net = finetune_embeddig(embedding_mat_train, y_train, N_epoch=15, out_dim=32,
                        verbose=True, drop_prob=0.3, hidden_dim=256, margin=0.2)


Epoch: 0, loss: 0.055
Epoch: 1, loss: 0.007
Epoch: 2, loss: 0.003
Epoch: 3, loss: 0.003
Epoch: 4, loss: 0.002
Epoch: 5, loss: 0.002
Epoch: 6, loss: 0.002
Epoch: 7, loss: 0.001
Epoch: 8, loss: 0.001
Epoch: 9, loss: 0.001
Epoch: 10, loss: 0.001
Epoch: 11, loss: 0.001
Epoch: 12, loss: 0.001
Epoch: 13, loss: 0.001
Epoch: 14, loss: 0.001


In [18]:
# Map the train/test sentence embeddings through the fine-tuned network, then
# evaluate the same logistic-regression head on the network outputs.
net.eval()  # inference mode for layers such as dropout (drop_prob was set during training)
with torch.no_grad():  # inference only: no autograd graph is needed
    embedding_out_test = net(torch.from_numpy(embedding_mat_test).float()).numpy()
    embedding_out_train = net(torch.from_numpy(embedding_mat_train).float()).numpy()

fit_predict_embedding(embedding_out_train, y_train, embedding_out_test, y_test,
                      newsgroups_test.target_names)


Accuracy: 0.86
Classification Report:
                           precision    recall  f1-score   support

 comp.os.ms-windows.misc       0.88      0.85      0.86        99
comp.sys.ibm.pc.hardware       0.75      0.80      0.77        99
   comp.sys.mac.hardware       0.78      0.80      0.79       104
         sci.electronics       0.81      0.81      0.81        97
                 sci.med       0.97      0.96      0.96        94
               sci.space       0.98      0.93      0.95        95

                accuracy                           0.86       588
               macro avg       0.86      0.86      0.86       588
            weighted avg       0.86      0.86      0.86       588

