# 다중 클래스 데이터 준비 (5 classes)

In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

# The five newsgroups used for this multi-class task. The same list is passed
# to both fetch calls so the train and test subsets cover identical categories.
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
    'rec.sport.hockey',
]

# Load the 'train' and 'test' subsets, each restricted to `categories`.
data_train = fetch_20newsgroups(subset='train', categories=categories, random_state=42)
data_test = fetch_20newsgroups(subset='test', categories=categories, random_state=42)

In [2]:
from nltk.corpus import names
from nltk.stem import WordNetLemmatizer

# Set of entries from the NLTK `names` corpus; `clean_text` drops any token
# that is a member of this set.
all_names = set(names.words())
# Single lemmatizer instance shared by every `clean_text` call.
lemmatizer = WordNetLemmatizer()

def clean_text(docs):
    """Return a cleaned copy of each document in ``docs``.

    Each document is split on whitespace. A token is kept only if it is
    purely alphabetic and is not a member of the module-level ``all_names``
    set (the membership check uses the token's original casing). Kept tokens
    are lower-cased, passed through the module-level ``lemmatizer`` and
    re-joined with single spaces.

    Args:
        docs: Iterable of document strings.

    Returns:
        A list with one cleaned string per input document, in input order.
        A document with no surviving tokens becomes the empty string.
    """
    def _normalize(doc):
        # Filter on the raw token first; lower-casing happens only afterwards.
        kept = (token for token in doc.split()
                if token.isalpha() and token not in all_names)
        return ' '.join(lemmatizer.lemmatize(token.lower()) for token in kept)

    return [_normalize(doc) for doc in docs]

In [3]:
# Labels come straight from the `target` attribute of each subset.
label_train, label_test = data_train.target, data_test.target

# Run the raw train and test documents through the same cleaning function.
cleaned_train, cleaned_test = map(clean_text, (data_train.data, data_test.data))

# Last expression of the cell: displays (number of train labels, number of test labels).
len(label_train), len(label_test)

(2634, 1752)

In [11]:
from sklearn.feature_extraction.text import TfidfVectorizer

tfidf_vectorizer = TfidfVectorizer(
    sublinear_tf=True,     # sublinear tf scaling: tf is replaced by 1 + log(tf)
    max_df=0.5,            # ignore terms found in more than 50% of the documents
    stop_words='english',  # use scikit-learn's built-in English stop-word list
    max_features=8000,     # cap the vocabulary at the 8000 most frequent terms
)

# Learn the vocabulary and idf weights from the training documents only, then
# apply that same fitted transform to the test documents.
term_docs_train = tfidf_vectorizer.fit_transform(cleaned_train)
term_docs_test = tfidf_vectorizer.transform(cleaned_test)

# Multi-class SVM train & test

In [12]:
from sklearn.svm import SVC
# Linear-kernel support vector classifier trained on the TF-IDF features.
# scikit-learn's SVC handles the multi-class case with a one-vs-one scheme.
svm = SVC(kernel='linear',C=1.0, random_state=42)
# `fit` is left as the cell's final expression so the notebook shows the
# fitted estimator's repr.
svm.fit(term_docs_train,label_train)

SVC(kernel='linear', random_state=42)

In [13]:
# Score the fitted classifier on the test features and labels.
# NOTE(review): `accuray` is a misspelling of `accuracy`; the name is kept
# as-is because the cell that prints the result reads this exact variable.
accuray = svm.score(term_docs_test,label_test)

In [14]:
# Report the test-set score as a percentage with one decimal place.
print(
    'The accuracy on testing set is: {0:.1f}%'.format(accuray * 100)
)

The accuracy on testing set is: 88.6%


# 클래스별 정확도 평가 (precision, recall, F1)

In [15]:
from sklearn.metrics import classification_report

# Predict a label for every test document, then build a text report of
# precision, recall, F1-score and support. Rows are keyed by the integer
# class labels because no class names are passed to the report.
prediction = svm.predict(term_docs_test)
report = classification_report(y_true=label_test, y_pred=prediction)
print(report)

              precision    recall  f1-score   support

           0       0.81      0.77      0.79       319
           1       0.91      0.94      0.93       389
           2       0.98      0.96      0.97       399
           3       0.93      0.93      0.93       394
           4       0.73      0.76      0.74       251

    accuracy                           0.89      1752
   macro avg       0.87      0.87      0.87      1752
weighted avg       0.89      0.89      0.89      1752

