In [None]:
%cd ~/contracode
import numpy as np
import pickle
import gzip
from tqdm.auto import tqdm
import pandas as pd
import time
from typing import Iterable
from loguru import logger
import multiprocessing as mp
# import modin.pandas as pd
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from transformers import BertTokenizerFast
from representjs import DATA_DIR
from tqdm import tqdm
# import swifter

# Shard the train set for tokenization

In [None]:
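# NOTE: the sharding cell below reads `train_df`, which no earlier active cell defines.
# Load it first, e.g. (after running the cell that defines `load_data`):
# train_df = load_data("/data/ajay/contracode/data/hf_data/augmented_pretrain_df.train.pickle.gz")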

In [None]:
from multiprocessing.pool import ThreadPool

def split_df(df, save_pattern, num_chunks=160, num_workers=64):
    """Split ``df`` row-wise into contiguous chunks and pickle each chunk to disk.

    Args:
        df: DataFrame to shard. Rows are selected positionally with ``iloc``.
        save_pattern: ``str.format`` pattern that takes the chunk index, e.g.
            ``".../augmented_pretrain_df.{:04d}.train.pickle.gz"``. Compression
            is inferred by ``DataFrame.to_pickle`` from the file suffix.
        num_chunks: Target number of chunks. Each chunk holds
            ``len(df) // num_chunks`` rows (at least 1). When ``len(df)`` is not
            divisible by ``num_chunks``, a smaller remainder chunk is also written,
            so the number of files can exceed ``num_chunks``.
        num_workers: Size of the thread pool used to write chunks concurrently.

    Raises:
        ValueError: if ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    num_rows = df.shape[0]
    # Floor division keeps the original chunk layout. Clamp to 1 so that an empty
    # frame, or one with fewer rows than `num_chunks`, does not hit
    # `range(..., step=0)`, which raises ValueError.
    chunk_size = max(1, num_rows // num_chunks)

    def save_data(data):
        # Write one chunk; runs on a pool thread.
        chunk_i, start, chunk_size, df, save_pattern = data
        save_path = save_pattern.format(chunk_i)
        df_subset = df.iloc[start : start + chunk_size]
        df_subset.to_pickle(save_path)
        print("Saved ", save_path)

    items = [(i, start, chunk_size, df, save_pattern) for i, start in enumerate(range(0, num_rows, chunk_size))]
    with ThreadPool(num_workers) as pool:
        pool.map(save_data, items)

In [None]:
# Write `train_df` out as 160 pickled shards for parallel tokenization.
# NOTE(review): `train_df` is not defined by any earlier active cell (its loader
# line is commented out), so this cell relies on kernel state. Confirm it is
# loaded before running.
from pathlib import Path

chunk_dir = "/data/ajay/contracode/data/hf_data/train_chunks"
# Create the output directory in Python instead of `!mkdir -p {chunk_dir}`, which
# interpolates the path unquoted into a shell command.
Path(chunk_dir).mkdir(parents=True, exist_ok=True)
split_df(train_df, chunk_dir + "/augmented_pretrain_df.{:04d}.train.pickle.gz", 160)

In [None]:
# First raw training shard, matching index 0000 of the sharding pattern above.
path = "/data/ajay/contracode/data/hf_data/train_chunks/augmented_pretrain_df.0000.train.pickle.gz"

# Register `.progress_apply` on pandas objects so tokenization shows a progress bar.
tqdm.pandas()

def load_tokenizer(path="data/vocab/8k_bpe/8k_bpe-vocab.txt"):
    """Build the BERT-style fast tokenizer from a vocabulary file.

    Args:
        path: Vocabulary file. The default is relative to the working directory
            (``~/contracode`` after the ``%cd`` in the first cell).

    Returns:
        A ``BertTokenizerFast`` configured not to lowercase its input.
    """
    # `BertTokenizerFast` controls casing through `do_lower_case`, which defaults
    # to True. The previous `lowercase=False` keyword is not a recognized argument
    # and was silently ignored, so input was still lowercased. Pass the real keyword.
    # NOTE(review): token ids will differ from shards tokenized with the old call.
    # Re-tokenize all shards together.
    return BertTokenizerFast(
        path,
        clean_text=True,
        do_lower_case=False,
        strip_accents=True,
        unk_token="<unk>",
    )

def load_data(path):
    """Load a pickled pandas object from ``path``.

    ``pd.read_pickle`` infers compression from the file suffix (e.g. ``.gz``).
    Only use this on trusted files, because unpickling can execute code.
    """
    frame = pd.read_pickle(path)
    return frame

tokenizer = load_tokenizer()
# Tokenize every document in the shard into an array of token ids, then keep
# only the id column and the token arrays.
df_shard = (
    load_data(path)
    .assign(toks=lambda shard: shard['text'].progress_apply(lambda text: np.asarray(tokenizer.encode(text))))
    [['data_idx', 'toks']]
)

In [None]:
from tqdm.contrib.concurrent import process_map

# Number of tokenized shard files to load.
# NOTE(review): presumably 161 rather than the 160 requested from `split_df`
# because it floors the chunk size and writes an extra remainder chunk.
# Confirm against the files on disk.
NUM_TOKENIZED_SHARDS = 161

# Build the path list with a comprehension so the loop variable does not leak
# into, and clobber, the notebook-level `path` used by the tokenization cell.
files = [
    f"/data/ajay/contracode/data/hf_data/train_chunks_tokenized/augmented_pretrain_tokenized_df.{i:04d}.train.pickle.gz"
    for i in range(NUM_TOKENIZED_SHARDS)
]


def load_file(fname):
    """Read one pickled tokenized shard; compression is inferred from the suffix."""
    return pd.read_pickle(fname)


# Load shards in 16 worker processes; `process_map` returns results in `files` order.
dfs = process_map(load_file, files, max_workers=16)

In [None]:
# Stack the loaded shards row-wise, in load order, into one frame.
merged_df = pd.concat(dfs, axis=0)

In [None]:
# Dtypes plus the real memory footprint. "deep" introspects object columns such
# as the per-row token arrays.
merged_df.info(memory_usage="deep")

In [None]:
# Persist the merged tokenized train split; gzip is inferred from the ".gz" suffix.
merged_df.to_pickle("/data/ajay/contracode/data/hf_data/merged_tok.pickle.gz", compression="infer")

In [None]:
# Preview only the first rows instead of rendering the whole (very large) frame.
merged_df.head()

# Repack tokenized data into LZ4-compressed Feather format

In [None]:
import pyarrow as pa
import pyarrow.feather as feather

%time test_df = pd.read_pickle('/data/ajay/contracode/data/hf_data/augmented_pretrain_df_tok.test.pickle.gz')
%time feather_test_df = pa.Table.from_pandas(test_df)
%time feather.write_feather(feather_test_df, '/data/ajay/contracode/data/hf_data/feather_tok/test_lz4.feather', compression='lz4')

In [None]:
%time train_df = pd.read_pickle('/data/ajay/contracode/data/hf_data/merged_tok.pickle.gz')
%time feather_train_df = pa.Table.from_pandas(train_df)
%time feather.write_feather(feather_train_df, '/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather', compression='lz4')

In [None]:
chunk_size = int(len(train_df) / 10)
for i in tqdm(list(range(11))):
    %time sampled_train_df = train_df[chunk_size * i : chunk_size * i + chunk_size]
    %time feather_train_df = pa.Table.from_pandas(sampled_train_df)
    %time feather.write_feather(feather_train_df, f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}', compression='lz4')

In [None]:
%time feather.write_feather(feather_train_df, '/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather', compression='lz4')

In [None]:
# Show both the 10% split point and the number of rows after it. Previously the
# first expression was computed and discarded, because only a cell's last
# expression is displayed.
split_point = int(len(train_df) * 0.1)
split_point, len(train_df.iloc[split_point:])

In [None]:
# Row count of the Arrow table currently bound to `feather_train_df`.
feather_train_df.num_rows

In [None]:
%time feather.read_feather('/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.00')

In [None]:
import pyarrow.feather as feather
from tqdm.contrib.concurrent import thread_map

files = [f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}' for i in range(11)]
%time dfs = thread_map(feather.read_feather, files, max_workers=16)

In [None]:
import pandas as pd
files = [f'/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather.{i:02d}' for i in range(11)]
%time dfs = thread_map(pd.read_feather, files, max_workers=16)

In [None]:
import glob
from pathlib import Path
# List every Feather output (full file plus numbered shards). `glob.glob` returns
# paths in arbitrary, filesystem-dependent order, so sort for a stable listing.
sorted(glob.glob('/data/ajay/contracode/data/hf_data/feather_tok/train_lz4.feather' + '*'))