# Language: trigram statistics of the training data

In [43]:
import os
from pathlib import Path
import json
import pickle
from collections import Counter

import tqdm 
from icl.constants import DEVICE, DATA
from icl.language.model import get_model
from icl.language.utils import translate_int_to_str

# Load the language model. Later cells use `model.tokenizer` to tokenize rows
# and pass `model` to `translate_int_to_str` to decode token ids back to text.
model = get_model()

Using pad_token, but it is not set yet.


In [16]:
import torch
import numpy as np

# Sample a reproducible random subset of rows from the training file as a
# validation set, tokenizing each row and truncating it to a fixed length.
NUM_TRAIN_ROWS = 5_000_000     # rows in train-5m.jsonl
NUM_VALIDATION_ROWS = 10_000
MAX_CONTEXT_LENGTH = 1024      # tokens kept per row
SEED = 0

rng = np.random.default_rng(SEED)
validation_set_idxs = rng.choice(NUM_TRAIN_ROWS, size=NUM_VALIDATION_ROWS, replace=False)
# `i in ndarray` scans the whole array for every row; a set makes lookups O(1).
validation_set_idx_lookup = set(validation_set_idxs.tolist())
validation_set = []

with open(DATA / "train-5m.jsonl", "r") as f:
    for i, line in tqdm.tqdm(enumerate(f), total=NUM_TRAIN_ROWS):
        if i not in validation_set_idx_lookup:
            continue

        content = json.loads(line)['contents']
        tokens = model.tokenizer(content)['input_ids']
        # Slicing is a no-op for rows already within the limit.
        validation_set.append(tokens[:MAX_CONTEXT_LENGTH])

        # Every sampled row has been collected; the rest of the file is irrelevant.
        if len(validation_set) == len(validation_set_idx_lookup):
            break

100%|██████████| 5000000/5000000 [00:47<00:00, 106344.43it/s]


In [17]:
import boto3

# Write the tokenized validation set to disk, then push that same file to S3.
tokens_path = DATA / 'tokens-10k.pkl'
s3_bucket = 'devinterp'
s3_key = 'other/language/tokens-10k.pkl'

client = boto3.client('s3')

with open(tokens_path, 'wb') as out_file:
    pickle.dump(validation_set, out_file)

with open(tokens_path, 'rb') as in_file:
    client.upload_fileobj(in_file, s3_bucket, s3_key)

In [18]:
# Reload the tokenized validation set written above. pickle can execute
# arbitrary code, so only load files this project produced itself.
with open(DATA / 'tokens-10k.pkl', 'rb') as pickled_file:
    validation_set = pickle.loads(pickled_file.read())

In [48]:
vocab_size = 5000
vocab = np.arange(vocab_size)

# total_trigrams: occurrences of each token trigram across the validation set.
# total_trigrams_per_row: number of rows containing each trigram at least once.
total_trigrams = Counter()
total_trigrams_per_row = Counter()

for row_tokens in tqdm.tqdm(validation_set):
    # Consecutive (t[k], t[k+1], t[k+2]) windows; empty for rows shorter than 3.
    row_trigrams = list(zip(row_tokens, row_tokens[1:], row_tokens[2:]))
    total_trigrams.update(row_trigrams)
    # De-duplicate so each row adds at most 1 to a trigram's row count.
    total_trigrams_per_row.update(set(row_trigrams))


100%|██████████| 10000/10000 [00:06<00:00, 1515.79it/s]


In [49]:
def _num_trigrams_above(counts, threshold):
    """Count keys of `total_trigrams` whose value in `counts` exceeds `threshold`.

    `Counter` returns 0 for missing keys, so any Counter can be used as `counts`.
    """
    return sum(1 for trigram in total_trigrams if counts[trigram] > threshold)


print("Number of unique trigrams in validation set:", len(total_trigrams))
for label, counts, threshold in (
    ("Number of unique trigrams with count > 1:", total_trigrams, 1),
    ("Number of unique trigrams with count > 2:", total_trigrams, 2),
    ("Number of unique trigrams that show up in > 1 rows:", total_trigrams_per_row, 1),
    ("Number of unique trigrams that show up in > 2 rows:", total_trigrams_per_row, 2),
):
    print(label, _num_trigrams_above(counts, threshold))

Number of unique trigrams in validation set: 2038645
Number of unique trigrams with count > 1: 502141
Number of unique trigrams with count > 2: 272505
Number of unique trigrams that show up in > 1 rows: 432513
Number of unique trigrams that show up in > 2 rows: 234870


In [54]:
from itertools import islice

# Trigrams that appear in more than 2 distinct validation rows; their
# training-set frequencies are tracked further below.
common_trigrams = {t for t in total_trigrams if total_trigrams_per_row[t] > 2}

# Print some of them for inspection. `common_trigrams` is a set, so this is an
# arbitrary subset, not the most frequent ones.
NUM_EXAMPLES = 100
for t in islice(common_trigrams, NUM_EXAMPLES):
    # Decode once and reuse for both the joined text and the per-token view.
    pieces = tuple(translate_int_to_str(t, model))
    print("".join(pieces), "\t\t\t", pieces)

ian refuge 			 ('ian', ' ref', 'uge')
, we've 			 (',', ' we', "'ve")
tool in 			 ('t', 'ool', ' in')
eld. 			 ('e', 'ld', '.')
� t 			 ('�', ' ', 't')
ue, The 			 ('ue', ',', ' The')
ion.
 			 ('ion', '.', '\n')
 day I  			 (' day', ' I', ' ')
 off on  			 (' off', ' on', ' ')
an entirely 			 ('an', ' entire', 'ly')
 put his foot 			 (' put', ' his', ' foot')
 in Japanese Pat 			 (' in', ' Japanese', ' Pat')
ain.
 			 ('ain', '.', '\n')
bot of 			 ('b', 'ot', ' of')
orn the 			 ('orn', ' ', 'the')
 sound,  			 (' sound', ',', ' ')
 an eight 			 (' ', 'an', ' eight')

Clay 			 ('\n', 'Cl', 'ay')
ed Wil 			 ('ed', ' W', 'il')
 Bonn 			 (' B', 'on', 'n')
 Nora 			 (' N', 'or', 'a')
branded 			 ('br', 'and', 'ed')
a. As 			 ('a', '.', ' As')
 been fired 			 (' been', ' f', 'ired')
th-old 			 ('th', '-', 'old')
 or ass 			 (' or', ' ', 'ass')
 skin rash 			 (' skin', ' r', 'ash')
 the break 			 (' ', 'the', ' break')
 outline of 			 (' out', 'line', ' of')
venue,  			 ('venue', ',', ' ')
 

In [66]:
common_trigrams_counts = Counter()  # occurrences of each common trigram in streamed rows
num_total_trigrams = 0              # every trigram position seen in streamed rows
num_included_trigrams = 0           # positions whose trigram is in `common_trigrams`


def count_to_freq(counter, num_total, vocab_size=5000):
    """Turn a trigram ``Counter`` into a sparse tensor of relative frequencies.

    Args:
        counter: maps 3-tuples of token ids to occurrence counts.
        num_total: denominator; every count is divided by it.
        vocab_size: length of each tensor dimension; every token id must be
            smaller than this (default 5000, the previously hard-coded size).

    Returns:
        A ``torch.sparse_coo_tensor`` of shape ``(vocab_size,) * 3`` whose entry
        at ``(a, b, c)`` is ``counter[(a, b, c)] / num_total``. An empty counter
        gives a tensor with no stored entries.

    Raises:
        ValueError: if ``counter`` is non-empty but ``num_total`` is not
            positive (this would otherwise silently produce ``inf`` values).
    """
    size = (vocab_size, vocab_size, vocab_size)

    if not counter:
        # torch.tensor([]).t() is 1-D float, which sparse_coo_tensor rejects as indices.
        return torch.sparse_coo_tensor(
            torch.empty((3, 0), dtype=torch.long), torch.empty(0), size=size
        )
    if num_total <= 0:
        raise ValueError(f"num_total must be positive, got {num_total}")

    # (nnz, 3) -> (3, nnz): sparse_coo_tensor wants one index row per dimension.
    indices_tensor = torch.tensor(list(counter.keys())).t()
    values_tensor = torch.tensor(list(counter.values())) / num_total
    return torch.sparse_coo_tensor(indices_tensor, values_tensor, size=size)


# Empty baseline for the first convergence check: nothing has been counted yet.
last_freq = count_to_freq(common_trigrams_counts, num_included_trigrams)

def sparse_allclose(sparse_tensor1, sparse_tensor2, atol=1e-8, rtol=1e-5):
    """Check whether two sparse COO tensors store the same entries with close values.

    The tensors must have the same shape and the same set of stored indices (after
    coalescing). Their values must then satisfy
    ``|v1 - v2| <= atol + rtol * |v2|`` element-wise. The reason for a mismatch,
    or the norm of the value difference, is printed.

    Returns:
        bool: True when the tensors match under the tolerances above.
    """
    if sparse_tensor1.shape != sparse_tensor2.shape:
        print(f"Shapes not equal ({tuple(sparse_tensor1.shape)}, {tuple(sparse_tensor2.shape)})")
        return False

    # Coalesce first: this sums duplicate entries and sorts indices. Without it,
    # _nnz() counts duplicates and the index comparison depends on insertion order.
    sparse_tensor1 = sparse_tensor1.coalesce()
    sparse_tensor2 = sparse_tensor2.coalesce()

    nnz1, nnz2 = sparse_tensor1._nnz(), sparse_tensor2._nnz()
    if nnz1 != nnz2:
        print(f"Number of non-zero elements not equal ({nnz1}, {nnz2})")
        return False

    indices1, indices2 = sparse_tensor1.indices(), sparse_tensor2.indices()
    if not torch.equal(indices1, indices2):
        # Count stored entries (columns of the index matrix) whose coordinates differ.
        num_differing = int((indices1 != indices2).any(dim=0).sum().item())
        print(f"Indices not equal ({num_differing} of {nnz1} entries differ)")
        return False

    values2 = sparse_tensor2.values()
    diff = torch.abs(sparse_tensor1.values() - values2)
    print("Diff norm:", diff.norm())
    return bool(torch.all(diff <= atol + rtol * torch.abs(values2)))


# Stream the training set and count how often each "common" trigram occurs.
# Stop early once the common-trigram frequency distribution has stabilised.
NUM_TRAIN_ROWS = 5_000_000   # rows in train-5m.jsonl (progress-bar total only)
CHECK_EVERY = 10_000         # rows between convergence checks

print("Loading dataset...")
with open(DATA / 'train-5m.jsonl', 'rb') as f:
    # Iterate the file lazily. readlines() would load every row into memory
    # before counting starts, wasting I/O and RAM whenever we stop early.
    for i, row in tqdm.tqdm(enumerate(f), total=NUM_TRAIN_ROWS):
        tokens = model.tokenizer(json.loads(row)['contents'])['input_ids']

        # A row of n tokens has n - 2 trigram positions (none if n < 3).
        num_total_trigrams += max(len(tokens) - 2, 0)
        for trigram in zip(tokens, tokens[1:], tokens[2:]):
            if trigram in common_trigrams:
                common_trigrams_counts[trigram] += 1
                num_included_trigrams += 1

        if i > 0 and i % CHECK_EVERY == 0:
            freq = count_to_freq(common_trigrams_counts, num_included_trigrams)

            # Converged when the set of observed common trigrams is unchanged and
            # every frequency moved by at most rtol (relative to the new value).
            if sparse_allclose(last_freq, freq, atol=0, rtol=1e-5):
                print("Early stopping at row", i)
                break

            last_freq = freq

print("Number of trigrams in dataset:", num_total_trigrams)
print("Number of trigrams in dataset that are 'common':", num_included_trigrams)
if num_total_trigrams:
    print("Percentage of trigrams in dataset that are 'common':",
          f"{num_included_trigrams / num_total_trigrams:.2%}")
else:
    print("Percentage of trigrams in dataset that are 'common': n/a (no trigrams seen)")

# Top 100 most frequent common trigrams in the streamed rows.
for trigram, count in common_trigrams_counts.most_common(100):
    # Decode once and reuse for both the joined text and the per-token view.
    pieces = tuple(translate_int_to_str(trigram, model))
    print("".join(pieces), "\t\t\t", pieces, "\t\t\t", count)

Loading dataset...
