In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import random
from scipy import stats
from scipy import sparse
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import gridspec
from scipy.optimize import curve_fit

import warnings

# NOTE(review): this silences every warning for the whole kernel session, which can
# also hide scipy warnings (e.g. from curve_fit / spearmanr) that signal a degenerate
# fit or statistic further down. Consider narrowing the filter.
warnings.filterwarnings(action='ignore')

# Figure font used by all plots below.
plt.rcParams.update({'font.family': ['Arial']})

In [None]:
# Lookup tables written by an earlier preprocessing step into ./process_data.
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
process_data_dir = os.path.join('.', 'process_data')

with open(os.path.join(process_data_dir, 'root_fos_level0_order.pkl'), 'rb') as pkl_file:
    # Presumably maps a vocabulary key to its root (level-0) field name - TODO confirm.
    root_fos = pickle.load(pkl_file)

with open(os.path.join(process_data_dir, 'fos_index_level0.pkl'), 'rb') as pkl_file:
    # Presumably maps a level-1 field name to the level-0 field IDs it sits under - TODO confirm.
    fos_tree = pickle.load(pkl_file)

In [None]:
MAG_data_dir = os.path.join('.', 'MAGdata')

# FieldsOfStudy.txt is a headerless, tab-separated MAG dump; these are its column names.
fos_columns = ['ID', 'Rank', 'NormalizedName', 'DisplayName', 'MainType', 'Level',
               'PaperCount', 'PaperFamilyCount', 'CitationCount', 'CreatedDate']
fos_list = pd.read_table(os.path.join(MAG_data_dir, 'FieldsOfStudy.txt'),
                         header=None, names=fos_columns)

# Normalized names of every level-1 field, in file order. A field's position in this
# list is the integer index used by non_cs / in_cs and when indexing the vocabulary
# arrays, so the order must stay the file order.
# Vectorized mask instead of a per-row iterrows() scan over the whole table.
fos_index = fos_list.loc[fos_list['Level'] == 1, 'NormalizedName'].tolist()
num_fos = len(fos_index)

In [None]:
# MAG field-of-study ID looked up in each level-1 field's fos_tree entry.
# Presumably the ID of "computer science" (the names non_cs / in_cs suggest so) - TODO confirm.
CS_FOS_ID = 41008148

# Split level-1 field indices (positions in fos_index) into those whose fos_tree entry
# does not / does contain CS_FOS_ID. A single pass replaces the former
# `i not in non_cs` list scan, which was quadratic in num_fos.
non_cs = []
in_cs = []
for index in range(num_fos):
    if CS_FOS_ID in fos_tree[fos_index[index]]:
        in_cs.append(index)
    else:
        non_cs.append(index)

In [None]:
# Presumably maps a field key to a pair (non-AI frequencies, AI frequencies) of numpy
# arrays indexed like fos_index - TODO confirm against the preprocessing step.
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
vocabulary_path = os.path.join('.', 'process_data', 'vocabulary.pkl')
with open(vocabulary_path, 'rb') as vocab_file:
    vocabulary = pickle.load(vocab_file)

In [None]:
# Root (level-0) fields to plot below.
# NOTE: this rebinds `fos_list`, which earlier held the FieldsOfStudy DataFrame.
fos_list = [
    'geology',
    'chemistry',
    'materials science',
    'biology',
    'physics',
    'medicine',
]

In [None]:
def func(x, b1):
    """Straight line through the origin, y = b1 * x.

    Used as the model function for scipy.optimize.curve_fit.
    """
    scaled = b1 * x
    return scaled

# Two-sided 99% standard-normal quantile (equals scipy.stats.norm.ppf(0.995));
# half-width multiplier of the confidence band drawn around the fitted slope.
Z_99 = 2.5758293035489004

# gaussian_kde and curve_fit cannot work with fewer points than this.
MIN_SHARED_FIELDS = 3


def _shared_non_cs_indices(fos, candidates, tree, names):
    """Return the indices in `candidates` whose field ``names[i]`` lists `fos` in ``tree``."""
    return [index for index in candidates if fos in tree[names[index]]]


def _descending_rank(values):
    """Return each element's rank in descending order (0 = largest value).

    NOTE(review): ties are broken by np.argsort's default (non-stable) sort order, i.e.
    arbitrarily; with many tied frequencies this affects r and the fit.
    scipy.stats.rankdata would average ties instead.
    """
    order = np.argsort(values)[::-1]
    return np.argsort(order)


def _fit_origin_slope(x, y):
    """Fit y = b1 * x with curve_fit (model `func`); return (b1, standard error of b1)."""
    popt, pcov = curve_fit(func, x, y)
    return popt[0], np.sqrt(np.diag(pcov))[0]


def _format_p(p):
    """Format a p-value for the figure, avoiding a misleading 'p=0.000' for tiny values."""
    return 'p<0.001' if p < 0.001 else 'p=%.3f' % p


def _plot_rank_panel(ax, nonai_rank, ai_rank, num_share, r, b1, bstd, num_points=100):
    """Draw the rank-vs-rank panel: scatter, KDE contours, fitted slope with 99% band, and r.

    Ranks are rescaled to a [0, num_points] canvas. The KDE is evaluated on a
    (num_points + 1) x (num_points + 1) grid over [0, 1]^2, and contour() without X/Y
    uses array indices as coordinates, so both layers share the same axes.
    """
    grid = np.linspace(0, 1, num_points + 1)
    X, Y = np.meshgrid(grid, grid)
    positions = np.vstack([X.ravel(), Y.ravel()])
    kernel = stats.gaussian_kde(np.vstack([nonai_rank / num_share, ai_rank / num_share]))
    density = np.reshape(kernel.pdf(positions).T, X.shape)

    ax.scatter(num_points * nonai_rank / num_share, num_points * ai_rank / num_share,
               marker='x', color='black', alpha=0.3)
    contours = ax.contour(density, cmap='Reds', levels=5, linestyles='dashed', linewidths=2)
    ax.clabel(contours, inline=True, fontsize=15)

    # Fitted line through the origin plus a +/- Z_99 * standard-error band on the slope.
    ax.plot([0, num_points], [0, b1 * num_points], color='black', linestyle='-', linewidth=3)
    ax.fill_between(x=[0, num_points],
                    y1=[0, (b1 - Z_99 * bstd) * num_points],
                    y2=[0, (b1 + Z_99 * bstd) * num_points],
                    color='black', alpha=0.1)
    ax.text(0.1 * num_points, 0.8 * num_points, 'r=%.3f' % r, fontsize=20, zorder=3)

    ticks = np.linspace(0, num_points, 6)
    tick_labels = np.round(np.linspace(0, 100, 6), 0).astype(int)
    ax.set_xlabel('Frequency percentile: without AI', fontsize=20)
    ax.set_ylabel('Frequency percentile: AI', fontsize=20)
    ax.set_xlim(0, num_points)
    ax.set_ylim(0, num_points)
    ax.set_xticks(ticks, tick_labels)
    ax.set_yticks(ticks, tick_labels)
    ax.tick_params(axis='x', labelsize=15)
    ax.tick_params(axis='y', labelsize=15)


def _plot_frequency_panel(ax, nonai_values, ai_values, bins=50):
    """Draw mirrored horizontal histograms of std-scaled frequencies and a t-test p-value.

    Each vector is divided by its own standard deviation (no centering) before
    scipy.stats.ttest_ind (default equal_var=True) and the density histograms.
    Non-AI bars extend to the right and AI bars to the left (negative widths).
    """
    nonai_norm = nonai_values / np.std(nonai_values)
    ai_norm = ai_values / np.std(ai_values)
    p = stats.ttest_ind(nonai_norm, ai_norm).pvalue

    height_nonai, edges_nonai = np.histogram(nonai_norm, bins=bins, density=True)
    height_ai, edges_ai = np.histogram(ai_norm, bins=bins, density=True)

    # Bars are centred on bin midpoints, with heights equal to the bin widths.
    ax.barh(y=(edges_nonai[:-1] + edges_nonai[1:]) / 2, width=height_nonai,
            height=np.diff(edges_nonai), label='without AI', color='royalblue', alpha=0.6)
    ax.barh(y=(edges_ai[:-1] + edges_ai[1:]) / 2, width=-height_ai,
            height=np.diff(edges_ai), label='AI', color='red', alpha=0.6)

    # Text and axis positions are hand-tuned for the [-5, 5] x [-0.2, 6.2] view.
    ax.text(x=3.5, y=0.5, s='without AI', fontsize=20, color='royalblue', va='bottom', ha='center', rotation=90)
    ax.text(x=-3.5, y=1.5, s='AI', fontsize=20, color='red', va='bottom', ha='center', rotation=90)
    ax.text(x=0, y=5, s=_format_p(p), fontsize=20, va='center', ha='center')
    ax.vlines(x=0, ymin=-1, ymax=7, color='lightgrey', linewidth=1)
    ax.set_xlabel('PDF', fontsize=20)
    ax.set_ylabel('Normalized frequency', fontsize=20)
    ax.set_xlim(-5, 5)
    ax.set_ylim(-0.2, 6.2)
    # Show absolute values on both sides of the mirrored x axis.
    ax.set_xticks([-4, -2, 0, 2, 4], [4, 2, 0, 2, 4])
    ax.tick_params(axis='x', labelsize=15)
    ax.tick_params(axis='y', labelsize=15)


# For each selected root field: compare how non-CS level-1 fields under it rank by
# non-AI vs AI frequency (left panel), and compare the frequency distributions (right panel).
for fos, data in vocabulary.items():
    root_name = root_fos[fos]
    if root_name not in fos_list:
        continue

    # Assumes data[0] / data[1] are numpy arrays indexed like fos_index - TODO confirm.
    nonai_all, ai_all = data[0], data[1]

    select_index = _shared_non_cs_indices(fos, non_cs, fos_tree, fos_index)
    num_share = len(select_index)
    if num_share < MIN_SHARED_FIELDS:
        # Too few points for the KDE / slope fit; they would raise instead of plotting.
        print(root_name, 'skipped: only %d shared non-CS fields' % num_share)
        continue

    nonai_share = nonai_all[select_index]
    ai_share = ai_all[select_index]
    nonai_rank = _descending_rank(nonai_share)
    ai_rank = _descending_rank(ai_share)

    cor = stats.spearmanr(nonai_rank, ai_rank)
    print(root_name, cor)
    b1, bstd = _fit_origin_slope(nonai_rank, ai_rank)
    print(b1, b1 - Z_99 * bstd, b1 + Z_99 * bstd)

    fig = plt.figure(figsize=(7, 4.6))
    spec = gridspec.GridSpec(ncols=2, nrows=1, width_ratios=[5, 2])
    _plot_rank_panel(fig.add_subplot(spec[0]), nonai_rank, ai_rank, num_share,
                     cor.correlation, b1, bstd)
    _plot_frequency_panel(fig.add_subplot(spec[1]), nonai_share, ai_share)

    plt.tight_layout()
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)