In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.model import fit
from fastai.dataset import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle
import random

  from numpy.core.umath_tests import inner1d


In [2]:
#Language modeling

In [3]:
#Data

In [4]:
import os, requests, time
# feedparser isn't a fastai dependency so you may need to install it.
import feedparser
import pandas as pd

In [9]:
class GetArXiv():
    def __init__(self,pickle_path,categories=list()):
        if os.path.isdir(pickle_path):
            pickle_path = "{}{}all_arxiv.pkl".format(pickle_path,'' if pickle_path[-1] == '/' else '/')
        if len(categories) < 1:
            categories = ['cs*','cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML']
        
        self.categories = categories
        self.pickle_path = pickle_path
        self.base_url = 'http://export.arxiv.org/api/query'
        
    @staticmethod
    def build_qs(categories): # 静态方法，
        return '+OR+'.join(['cat:' +c for c in categories])
    
    @staticmethod
    def get_entry_dict(entry): # 静态方法，
        try:
            return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],
                        published=pd.Timestamp(entry['published']), summary=entry['summary'],
                        link=entry['link'], category=entry['category'])
        except KeyError:
            print('Missing keys in row: {}'.format(entry))
            return None
    
    @staticmethod
    def strip_version(link):
        return link[:-2]
    
    def fetch_updated_data(self,max_retry=5,pg_offset=0,pg_size=1000,wait_time=15):
        i,retry=pg_offset,0
        df = pd.DataFrame()
        past_links = []
        if os.path.isfile(self.pickle_path):
            df = pd.read_pickle(self.pickle_path)
            df.reset_index()
        if len(df)>0: past_links = df.link.apply(self.strip_version)
        
        while True:
            params = dict(search_query=self.build_qs(self.categories),sortBy='submittedDate',start=pg_size*i,max_results=pg_size)
            response = requests.get(self.base_url,params='&'.join(['{k}={v}' for k,v in params.items()]))
            entries = feedparser.parse(response.text).entries
            if len(entries)<1:
                if retry < max_retry:
                    retry += 1
                    time.sleep(wait_time)
                    continue #进入下一次循环
                break #跳出循环
            
            results_df = pd.DataFrame([self.get_entry_dict(e) for e in entries])
            max_date = results_df.published.max().date()
            new_links = ~results_df.link.apply(self.strip_version).isin(past_links)
            print('{i} .Fetched {} abstracts published {} and earlier'.format(i,len(results_df),max_date))
            if not new_links.any():
                break
                
            df = pd.concat((df,results_df.loc[new_links]), ignore_index=True)
            i += 1
            retry = 0
            time.sleep(wait_time)
        
        print('Downloaded {} new abstracts'.format(len(df)-len(past_links)))
        df.sort_values('published',ascending=False).groupby('link').first().rest_index()
        df.to_pickle(self.pickle_path) # 保存到本地，下面的类方法Load读取
        return df
    
    @classmethod
    def load(cls,pickle_path): #类方法
        return pd.read_pickle(cls(pickle_path).pickle_path)
    
    @classmethod #类方法
    def update(cls,pickle_path,categories=list(),**kwargs):
        cls(pickle_path,categories).fetch_updated_data(**kwargs)
        return True

In [10]:
PATH='c:/input/data/arxiv/'

ALL_ARXIV = '{}all_arxiv.pkl'.format(PATH)

# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee
if not os.path.exists(ALL_ARXIV): GetArXiv.update(ALL_ARXIV)

# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one
df_mb = pd.read_csv('{}arxiv.csv'.format(PATH))
df_all = pd.read_pickle(ALL_ARXIV)

Downloaded 0 new abstracts


KeyError: 'published'

In [None]:
def get_txt(df):
    return '<CAT> ' + df.category.str.replace(r'[\.\-]','') + ' <SUM> ' + df.summary + ' <TITLE> ' + df.title
df_mb['txt'] = get_txt(df_mb)
df_all['txt'] = get_txt(df_all)
n=len(df_all);n

In [None]:
os.makedirs(f'{PATH}trn/yes', exist_ok=True)
os.makedirs(f'{PATH}val/yes', exist_ok=True)
os.makedirs(f'{PATH}trn/no', exist_ok=True)
os.makedirs(f'{PATH}val/no', exist_ok=True)
os.makedirs(f'{PATH}all/trn', exist_ok=True)
os.makedirs(f'{PATH}all/val', exist_ok=True)
os.makedirs(f'{PATH}models', exist_ok=True)

In [None]:
for (i,(_,r)) in enumerate(df_all.iterrows()):
    dset = 'trn' if random.random()>0.1 else 'val'
    open(f'{PATH}all/{dset}/{i}.txt', 'w').write(r['txt'])

In [None]:
for (i,(_,r)) in enumerate(df_mb.iterrows()):
    lbl = 'yes' if r.tweeted else 'no'
    dset = 'trn' if random.random()>0.1 else 'val'
    open(f'{PATH}{dset}/{lbl}/{i}.txt', 'w').write(r['txt'])

In [None]:
from spacy.symbols import ORTH

# install the 'en' model if the next line of code fails by running:
#python -m spacy download en              # default English model (~50MB)
#python -m spacy download en_core_web_md  # larger English model (~1GB)
my_tok = spacy.load('en')

my_tok.tokenizer.add_special_case('<SUMM>', [{ORTH: '<SUMM>'}])
my_tok.tokenizer.add_special_case('<CAT>', [{ORTH: '<CAT>'}])
my_tok.tokenizer.add_special_case('<TITLE>', [{ORTH: '<TITLE>'}])
my_tok.tokenizer.add_special_case('<BR />', [{ORTH: '<BR />'}])
my_tok.tokenizer.add_special_case('<BR>', [{ORTH: '<BR>'}])

def my_spacy_tok(x): return [tok.text for tok in my_tok.tokenizer(x)]

In [None]:
TEXT = data.Field(lower=True, tokenize=my_spacy_tok)
FILES = dict(train='trn', validation='val', test='val')
md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)
pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl','wb'))

In [None]:
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

In [None]:
TEXT.vocab.itos[:12]

In [None]:
' '.join(md.trn_ds[0].text[:150])

In [None]:
# Train

In [11]:
?texts_labels_from_folders