In [1]:
import pandas as pd
# Read the dialogue table; the "spoken_words" column holds one utterance per row.
df = pd.read_csv("simpsons_dataset.csv")

# Keep only entries that are real strings (anything else in the column,
# e.g. a missing value, is discarded).
sentences = [line for line in df["spoken_words"].tolist() if type(line) is str]

In [2]:
import nltk

# NLTK resources needed below, fetched in the same order as before.
for _nltk_resource in (
    'wordnet',
    'omw-1.4',
    'punkt',
    'averaged_perceptron_tagger',
    'stopwords',
):
    nltk.download(_nltk_resource)

from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

from nltk.corpus import wordnet
from nltk.corpus import stopwords
import string

lemmatizer = WordNetLemmatizer()

# Single-character punctuation marks and the ten digit characters.
punctuation = set(string.punctuation)
digits = {str(n) for n in range(10)}

# Every token that the cleaning loop drops (compared against the lower-cased token):
# NLTK's English stop words, contraction fragments, punctuation,
# multi-character punctuation tokens, and single digits.
stop_words = (
    set(stopwords.words('english'))
    | {"'s", "n't", "'m", "'re", "'ll", "'d"}
    | punctuation
    | {"--", "..", "''", "...", "``"}
    | digits
)

def get_wordnet_pos(treebank_tag):
    """Translate a POS tag into the matching WordNet POS constant.

    The tag is matched by its leading letter: 'J' -> adjective, 'V' -> verb,
    'N' -> noun, 'R' -> adverb. Any other tag falls back to noun.
    """
    prefix_to_pos = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, pos in prefix_to_pos:
        if treebank_tag.startswith(prefix):
            return pos
    # Unknown tag: treat it as a noun.
    return wordnet.NOUN

# clean_sentences: one list of lemmas per input sentence.
# clean_tokens: the same lemmas flattened into a single list (used for counting).
clean_sentences = []
clean_tokens = []

for s in sentences:
  tokens = word_tokenize(s)
  # Drop stop words, punctuation and digits (compared case-insensitively).
  tokens = [t for t in tokens if t.lower() not in stop_words]
  # Convert each tagger tag to a WordNet POS exactly once. Running
  # get_wordnet_pos a second time on the already-converted value would fall
  # through to its noun default (the WordNet constants do not start with
  # J/V/N/R), so every word would be lemmatized as a noun.
  tagged_tokens = [(w, get_wordnet_pos(t)) for w, t in nltk.pos_tag(tokens)]
  lemma_words = [lemmatizer.lemmatize(w.lower(), pos) for w, pos in tagged_tokens]
  clean_sentences.append(lemma_words)
  clean_tokens.extend(lemma_words)

[nltk_data] Downloading package wordnet to /home/vikram/nltk_data...
[nltk_data] Downloading package omw-1.4 to /home/vikram/nltk_data...
[nltk_data] Downloading package punkt to /home/vikram/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/vikram/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package stopwords to /home/vikram/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [3]:
from collections import Counter

# Maximum number of distinct tokens kept in the vocabulary.
VOCAB_SIZE = 6000

# (token, frequency) pairs over the whole cleaned corpus.
counts = Counter(clean_tokens).items()
# Most frequent first. The sort is stable, so tokens with equal counts stay
# in the order the Counter yields them.
sorted_counts = sorted(counts, key=lambda pair: -pair[1])
vocab = sorted_counts[:VOCAB_SIZE]

In [None]:
# A token's id is its position in `vocab`; build the lookup in both directions.
idx_to_tok = {idx: tok for idx, (tok, _count) in enumerate(vocab)}
tok_to_idx = {tok: idx for idx, tok in idx_to_tok.items()}

In [None]:
# Size of the context window used when pairing tokens.
WINDOW_SIZE = 6

# Half of the window; gen_dataset uses it as the offset around each position.
len_buffer = WINDOW_SIZE // 2


def gen_dataset(sentences=None, token_index=None, half_window=None):
  """Build (centre_id, context_id) pairs from tokenized sentences.

  For every in-vocabulary token, one pair is emitted for each in-vocabulary
  token at most `half_window` positions to its left or right within the same
  sentence. A token is never paired with its own position.

  Args:
    sentences: iterable of token lists. Defaults to the notebook-level
      `clean_sentences`.
    token_index: mapping from token to integer id; tokens missing from it are
      skipped. Defaults to the notebook-level `tok_to_idx`.
    half_window: number of neighbours considered on each side. Defaults to
      the notebook-level `len_buffer`.

  Returns:
    A list of `(centre_id, context_id)` tuples, ordered by sentence, then
    centre position, then context position.
  """
  if sentences is None:
    sentences = clean_sentences
  if token_index is None:
    token_index = tok_to_idx
  if half_window is None:
    half_window = len_buffer

  dataset = []
  for s in sentences:
    for i, center in enumerate(s):
      # An out-of-vocabulary centre token can never produce a pair.
      if center not in token_index:
        continue
      idx_i = token_index[center]
      start = max(0, i - half_window)
      # range() excludes its upper bound, so add 1 to include the token at
      # i + half_window and keep the window symmetric around i.
      stop = min(len(s), i + half_window + 1)
      for j in range(start, stop):
        if j == i:
          continue
        neighbor = s[j]
        if neighbor in token_index:
          dataset.append((idx_i, token_index[neighbor]))
  return dataset


# No preceding cell brings `jnp` into scope, so import it here to keep this
# cell runnable on a fresh kernel (Restart & Run All).
import jax.numpy as jnp

# Materialize the index pairs as an array; with a non-empty result each row
# is one (centre_id, context_id) pair.
dataset = jnp.array(gen_dataset())
print(dataset.shape)