# Import

In [1]:
import os
import re
import sys
import csv
import json
import string
import random
import argparse
import pickle as pkl
from collections import Counter, defaultdict
from tqdm import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# Define

In [2]:
def load_jsonl(data_file):
    data = [json.loads(l) for l in open(data_file, "r")]
    return data


def normalize_text(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    #return white_space_fix(remove_articles(remove_punc(lower(s))))
    return white_space_fix(remove_punc(lower(s)))


def get_tokens(s):
    if not s: return []
    return normalize_text(s).split()


def count_ngrams(txt, max_n):
    out_d = {}
    for n in range(1, max_n + 1):
        ngrams = zip(*[txt[i:] for i in range(n)])
        ngram2count = Counter(ngrams)
        out_d[n] = ngram2count
    return out_d

def get_ngram_counts(exs, fields, max_n=4, cache_out = None, fname = None, overwrite=False):    
    ngrams = []
    ngrams2count = defaultdict(Counter)
    
    if (not cache_out is None) and not overwrite and os.path.exists(cache_out):
        ngrams2count, ngrams = pkl.load(open(cache_out, 'rb'))
        print(f'Loaded cached ngram counts from {cache_out}')
    else:
        for ex in tqdm(exs, desc=f'Counting n-grams {fname}'):
            field2ngrams = {}
            for field in fields:
                field_toks = get_tokens(ex[field])
                field_ngrams = count_ngrams(field_toks, max_n)
                field2ngrams[field] = field_ngrams
                for n, count in field_ngrams.items():
                    ngrams2count[n] += count
            ngrams.append(field2ngrams)

        if not cache_out is None:
            pkl.dump((ngrams2count,ngrams), open(cache_out, 'wb'))
            print(f'Saved ngram counts to {cache_out}')
        
    return (ngrams2count,ngrams)

def get_length_stats(exs):
    pass

# Get Counts

In [3]:
cache_base = os.path.join('.', 'ngram_stats')

In [7]:
exs, ngram_counts, unigram_lengths = {}, {}, {}

with open(os.path.join('.', 'all_val_data.p'), 'rb') as f:
    val_data = pkl.load(f)    

t = tqdm(val_data.items())
for data_key, ex_dict in t:
    t.set_description('_'.join(data_key))
    print(ex_dict['label'])
    temp_ex = [{'text': text, 'label': label} for text, label in zip(ex_dict['text'], ex_dict['label'])]
    exs[data_key] = temp_ex
    ngram_counts[data_key] = get_ngram_counts(
        temp_ex,
        ['text', 'label'],
        cache_out = os.path.join(cache_base, 'counts', f"{'_'.join(data_key)}_counts.p"),
        fname = data_key
    )
    unigram_lengths[data_key] = [count_grams(text, 1) for text in ex_dict['text']]
        

id_val_imdb:   0%|                                                                               | 0/7 [00:00<?, ?it/s]

0





TypeError: zip argument #2 must support iteration