# Language Modeling

## Imports & Inits

In [None]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

import pdb, sys, warnings, os, json, torch, re
warnings.filterwarnings(action='ignore')

from IPython.display import display, HTML
from pathlib import Path

import pandas as pd
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

## Functions

In [None]:
def plot_bigram_counts(ax, bigram_counts, vocab):
  """Draw a bigram count matrix on `ax` as an annotated heatmap.

  The matrix is shaded with the 'Blues' colormap. Every cell (row i, column j)
  is labelled with the two symbols `vocab[i]` + `vocab[j]` above the value of
  `bigram_counts[i, j]`. The axes decorations are switched off afterwards.

  Args:
    ax: Axes-like object providing `imshow`, `text` and `axis`.
    bigram_counts: 2-D container indexable as `[i, j]` whose elements expose
      `.item()`; it must cover at least len(vocab) rows and columns.
    vocab: Iterable of symbols; position i labels row i and column i.
  """
  itos = dict(enumerate(vocab))
  # Size the grid from the `vocab` argument rather than a notebook-level
  # global, so the function works wherever it is called from.
  n_symbols = len(itos)
  ax.imshow(bigram_counts, cmap='Blues')
  for i in range(n_symbols):
    for j in range(n_symbols):
      chstr = f'{itos[i]}{itos[j]}'
      # x is the column index and y the row index; the label sits above the
      # cell centre and the count below it.
      ax.text(j, i, chstr, ha='center', va='bottom', color='gray')
      ax.text(j, i, bigram_counts[i,j].item(), ha='center', va='top', color='gray')
  ax.axis('off')

## Data Setup

In [None]:
# Use the cached, already-cleaned headlines when present; otherwise build the
# cache from the original file (drop non-ASCII characters, lowercase, strip a
# fixed set of symbols) and write it back for the next run.
try:
  onion_df = pd.read_csv('../data/cleaned_onion_headlines.csv')
except FileNotFoundError:
  onion_df = pd.read_csv('../data/original_onion_headlines.csv')
  strip_symbols = re.compile('[$=`+@*#_]')
  headlines = onion_df['text'].str.encode('ascii', 'ignore').str.decode('ascii')
  headlines = headlines.apply(str.lower)
  headlines = headlines.apply(lambda t: strip_symbols.sub('', t))
  onion_df['text'] = headlines
  onion_df.to_csv('../data/cleaned_onion_headlines.csv', index=None)

onion_df.shape

In [None]:
# Character count of every headline, then summary statistics of those counts.
headline_lengths = onion_df['text'].map(len)
onion_df['length'] = headline_lengths
onion_df['length'].describe()

## Checkpoint

In [None]:
# Character vocabulary: '#' occupies index 0, followed by every character that
# occurs in the space-joined headlines, in sorted order.
texts = onion_df['text'].tolist()
vocab = ['#', *sorted(set(' '.join(texts)))]
itos = dict(enumerate(vocab))                    # index -> symbol
stoi = {sym: idx for idx, sym in itos.items()}   # symbol -> index
len(vocab)

In [None]:
# bigram_counts[a, b] = how often symbol b directly follows symbol a.
# Every headline is wrapped in '#' on both sides before its pairs are counted.
vocab_size = len(stoi)
bigram_counts = torch.zeros((vocab_size, vocab_size), dtype=torch.int32)
for text in texts:
  chs = ['#', *text, '#']
  for ch1, ch2 in zip(chs[:-1], chs[1:]):
    ix1, ix2 = stoi[ch1], stoi[ch2]
    bigram_counts[ix1, ix2] += 1

In [None]:
# Square figure whose side (in inches) equals the number of symbols.
side = len(stoi)
fig, ax = plt.subplots(figsize=(side, side))
plot_bigram_counts(ax, bigram_counts, vocab)