In [1]:
# Required imports
import matplotlib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.utils import shuffle
from sklearn.metrics import classification_report, accuracy_score
import re
# from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))
from sklearn import model_selection, naive_bayes, svm
import joblib

In [2]:
# cleaning data
def clean_post(post):
    post = post.lower()
    post = re.sub(r"\n", " ", post)
    post = re.sub("[\<\[].*?[\>\]]", " ", post)
    post = re.sub(r"[^a-z ]", " ", post)
    post = re.sub(r"\b\w{1,3}\b", " ", post)
    return " ".join([x for x in post.split() if x not in stop_words])

In [3]:
# Different techniques for tackling class imbalance
from imblearn.under_sampling import RandomUnderSampler, TomekLinks, NearMiss
from imblearn.over_sampling import RandomOverSampler, SMOTE

def balance_data(x, y, _type):
    if _type == 0:
        ros = RandomOverSampler(random_state=42)
        return ros.fit_resample(x, y)
    elif _type == 1:
        rus = RandomUnderSampler(random_state=42, replacement=True)
        return rus.fit_resample(x, y)
    elif _type == 2:
        smote = SMOTE()
        return smote.fit_resample(x, y)
    elif _type == 3:
        nm = NearMiss()
        return nm.fit_resample(x, y)
    elif _type == 6:
        tl = TomekLinks()
        return tl.fit_resample(x, y)
    return x, y
    # Another technique is penalizing the algo with class_weight=balanced, using stratified cross validation

In [4]:
# Load data
data = pd.read_csv('/home/starc52/split_reddit_data/train_and_valid.csv')
data = shuffle(data)

# Class split stats
print(data.groupby(['mental_disorder'])[['mental_disorder']].describe())
x = data['post'].apply(lambda post: clean_post(post))

# Vectorizing text data
count_vect = CountVectorizer()
X_counts = count_vect.fit_transform(x)
tfidf_transformer = TfidfTransformer()
X = tfidf_transformer.fit_transform(X_counts)

                mental_disorder                             
                          count unique            top   freq
mental_disorder                                             
EDAnonymous               12339      1    EDAnonymous  12339
addiction                  6515      1      addiction   6515
adhd                      38786      1           adhd  38786
alcoholism                 5026      1     alcoholism   5026
anxiety                   48971      1        anxiety  48971
autism                     7583      1         autism   7583
bipolarreddit              4929      1  bipolarreddit   4929
bpd                       20606      1            bpd  20606
depression                99809      1     depression  99809
healthanxiety              7373      1  healthanxiety   7373
lonely                    20103      1         lonely  20103
ptsd                       7336      1           ptsd   7336
schizophrenia              7351      1  schizophrenia   7351
socialanxiety           

In [5]:
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(np.array(data['mental_disorder']))
# y = to_categorical(y1)

# 60-20-20 split
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.176, random_state=321)

In [6]:
def get_metrics(y_true, y_pred):
    result1 = classification_report(y_true, y_pred)
    print('Classification Report: ', result1)
    result2 = accuracy_score(y_true, y_pred)
    print('Accuracy: ', result2, "\n\n")

In [7]:
import gc

# Creating the model and checking it for various undersampled cases
X_tr, y_tr = X_train, y_train
print('#'*110)
print()
print()
print()
model = svm.SVC(C=1.0, kernel='linear', degree=3, gamma='auto')
X_train, y_train = balance_data(X_tr, y_tr, 6)
model.fit(X_train, y_train)
joblib.dump(model, "/home/starc52/models/SVM.sav")
print()
print()
print('#'*110)

##############################################################################################################





KeyboardInterrupt: 