In [1]:
from sklearn.datasets import fetch_20newsgroups
import pandas as pd

def twenty_newsgroup_to_csv():
    newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))

    df = pd.DataFrame([newsgroups_train.data, newsgroups_train.target.tolist()]).T
    df.columns = ['text', 'target']

    targets = pd.DataFrame( newsgroups_train.target_names, columns=['title'])

    out = pd.merge(df, targets, left_on='target', right_index=True)
    out.to_csv('./data/20_newsgroup.csv', index=False)
    
twenty_newsgroup_to_csv()

In [7]:
from openai.embeddings_utils import get_embeddings
import openai, os, tiktoken, backoff
os.environ["OPENAI_API_KEY"] = """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"""
openai.api_key = os.environ.get("OPENAI_API_KEY")
embedding_model = "text-embedding-ada-002"
embedding_encoding = "cl100k_base"  # this the encoding for text-embedding-ada-002
batch_size = 2000
max_tokens = 8000  # the maximum for text-embedding-ada-002 is 8191

df = pd.read_csv('./data/20_newsgroup.csv')
print("Number of rows before null filtering:", len(df))
df = df[df['text'].isnull() == False]
encoding = tiktoken.get_encoding(embedding_encoding)

df["n_tokens"] = df.text.apply(lambda x: len(encoding.encode(x)))
print("Number of rows before token number filtering:", len(df))
df = df[df.n_tokens <= max_tokens]
print("Number of rows data used:", len(df))

Number of rows before null filtering: 11314
Number of rows before token number filtering: 11096
Number of rows data used: 11044


In [8]:
@backoff.on_exception(backoff.expo, openai.error.RateLimitError)
def get_embeddings_with_backoff(prompts, engine):
    embeddings = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i+batch_size]
        embeddings += get_embeddings(list_of_text=batch, engine=engine)
    return embeddings

prompts = df.text.tolist()
prompt_batches = [prompts[i:i+batch_size] for i in range(0, len(prompts), batch_size)]

embeddings = []
for batch in prompt_batches:
    batch_embeddings = get_embeddings_with_backoff(prompts=batch, engine=embedding_model)
    embeddings += batch_embeddings

df["embedding"] = embeddings
df.to_parquet("data/20_newsgroup_with_embedding.parquet", index=False)

RetryError: RetryError[<Future at 0x14615ef50 state=finished raised RateLimitError>]

In [9]:
import numpy as np
from sklearn.cluster import KMeans

embedding_df = pd.read_parquet("data/20_newsgroup_with_embedding.parquet")

matrix = np.vstack(embedding_df.embedding.values)
num_of_clusters = 20

kmeans = KMeans(n_clusters=num_of_clusters, init="k-means++", n_init=10, random_state=42)
kmeans.fit(matrix)
labels = kmeans.labels_
embedding_df["cluster"] = labels

In [10]:
# 统计每个cluster的数量
new_df = embedding_df.groupby('cluster')['cluster'].count().reset_index(name='count')

# 统计这个cluster里最多的分类的数量
title_count = embedding_df.groupby(['cluster', 'title']).size().reset_index(name='title_count')
first_titles = title_count.groupby('cluster').apply(lambda x: x.nlargest(1, columns=['title_count']))
first_titles = first_titles.reset_index(drop=True)
new_df = pd.merge(new_df, first_titles[['cluster', 'title', 'title_count']], on='cluster', how='left')
new_df = new_df.rename(columns={'title': 'rank1', 'title_count': 'rank1_count'})

# 统计这个cluster里第二多的分类的数量
second_titles = title_count[~title_count['title'].isin(first_titles['title'])]
second_titles = second_titles.groupby('cluster').apply(lambda x: x.nlargest(1, columns=['title_count']))
second_titles = second_titles.reset_index(drop=True)
new_df = pd.merge(new_df, second_titles[['cluster', 'title', 'title_count']], on='cluster', how='left')
new_df = new_df.rename(columns={'title': 'rank2', 'title_count': 'rank2_count'})
new_df['first_percentage'] = (new_df['rank1_count'] / new_df['count']).map(lambda x: '{:.2%}'.format(x))
# 将缺失值替换为 0
new_df.fillna(0, inplace=True)
# 输出结果
from IPython.display import display
display(new_df)

Unnamed: 0,cluster,count,rank1,rank1_count,rank2,rank2_count,first_percentage
0,0,522,rec.autos,432,comp.sys.mac.hardware,6.0,82.76%
1,1,391,comp.sys.ibm.pc.hardware,101,comp.sys.mac.hardware,85.0,25.83%
2,2,1060,talk.politics.misc,129,talk.religion.misc,60.0,12.17%
3,3,381,rec.motorcycles,364,comp.sys.mac.hardware,1.0,95.54%
4,4,783,comp.sys.ibm.pc.hardware,323,comp.sys.mac.hardware,314.0,41.25%
5,5,659,soc.religion.christian,409,talk.religion.misc,151.0,62.06%
6,6,358,sci.crypt,345,comp.sys.mac.hardware,1.0,96.37%
7,7,84,comp.os.ms-windows.misc,8,comp.sys.mac.hardware,8.0,9.52%
8,8,477,rec.sport.hockey,461,0,0.0,96.65%
9,9,472,sci.space,403,comp.sys.mac.hardware,1.0,85.38%


In [12]:
items_per_cluster = 10
COMPLETIONS_MODEL = "text-davinci-003"


@backoff.on_exception(backoff.expo, openai.error.RateLimitError)
def f(i):
    cluster_name = new_df[new_df.cluster == i].iloc[0].rank1
    print(f"Cluster {i}, Rank 1: {cluster_name}, Theme:", end=" ")

    content = "\n".join(
        embedding_df[embedding_df.cluster == i].text.sample(items_per_cluster, random_state=42).values
    )
    response = openai.Completion.create(
        model=COMPLETIONS_MODEL,
        prompt=f'''我们想要给下面的内容，分组成有意义的类别，以便我们可以对其进行总结。请根据下面这些内容的共同点，总结一个50个字以内的新闻组的名称。比如 “PC硬件”\n\n内容:\n"""\n{content}\n"""新闻组名称：''',
        temperature=0,
        max_tokens=100,
        top_p=1,
    )
    print(response["choices"][0]["text"].replace("\n", ""))

for i in range(num_of_clusters):
    f(i)

Cluster 0, Rank 1: rec.autos, Theme: 汽车硬件
Cluster 1, Rank 1: comp.sys.ibm.pc.hardware, Theme: 电脑显示器
Cluster 2, Rank 1: talk.politics.misc, Theme: 法律案例分析
Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: Cluster 3, Rank 1: rec.motorcycles, Theme: 骑行者与宠物
Cluster 4, Rank 1: comp.sys.ibm.pc.hardware, Theme: 计算机硬件
Cluster 5, Rank 1: soc.religion.christian, Theme: Cluster 5, Rank 1: soc.religion.christian, Theme: Cluster 5, Rank 1: soc.religion.christian, Theme: Cluster 5, Rank 1: soc.religion.christian, Theme: Cluster 5, Rank 1: soc.religion.christian, Theme: 宗教信仰与传统
Cluster 6, Rank 1: sci.crypt, Theme: Cluster 6, Rank 1: sci.crypt, Theme: Cluster 6, Rank 1: sci.crypt, Theme: Cluster 6, Rank 1: sci.crypt, Theme: Cluster 6, Rank 1: sci.crypt, Theme: Cluster 6, Rank 1: sci.crypt, Theme: 秘