In [32]:
import os
import random
import torch


# extract text and create dataset

In [2]:

def is_comment(line):
    # Define a function to check if a line is a comment
    line = line.strip()
    if line.startswith('#') or line.startswith("'''") or line.startswith('"""'):
        return True
    return False

def extract_non_comments(source_directory, target_directory):
    # Process all .py files in the specified directory and subdirectories
    for root, dirs, files in os.walk(source_directory):
        for file in files:
            if file.endswith('.py'):
                file_path = os.path.join(root, file)
                target_file_path = os.path.join(target_directory, file.replace('.py', '.txt'))
                with open(file_path, 'r') as source_file, open(target_file_path, 'w') as target_file:
                    non_comments = []
                    comment_block = False
                    
                    for line in source_file:
                        # Check for the start or end of a comment block
                        if "''" in line or '"""' in line:
                            comment_block = not comment_block
                            continue
                        # If it's not a comment or part of a comment block, save it
                        if not is_comment(line) and not comment_block:
                            non_comments.append(line)
                        # Write non-comment lines to a target .txt file
                    target_file.writelines(non_comments)

# # Define the path to the local repository (change this to the actual path of your local repo)
# # source_directory = '/path/to/your/local/pytorch/repo'
# source_directory = '.'
# # target_directory = '/path/to/your/output/directory'
# target_directory = './dataset/'


# # Create the target directory if it doesn't exist
# os.makedirs(target_directory, exist_ok=True)

# # Call the function to start extracting non-comment lines
# extract_non_comments(source_directory, target_directory)


In [24]:

def combine_files(directory, output_file, sample=False, num_files_to_sample=100, seed=111, start_token="<START>", end_token="<END>"):
    """
    Combine content from a specified number of text files in a directory into one file, 
    with start and end tokens between contents from each file.

    :param directory: Path to the directory containing text files.
    :param output_file: Name of the output file to create.
    :param num_files_to_sample: Number of files to sample and combine.
    :param start_token: The start token to be added before each file's content.
    :param end_token: The end token to be added after each file's content.
    """
    
    
    # List all text files in the directory
    all_files = [f for f in os.listdir(directory) if f.endswith('.txt')]
    files = all_files

    if sample:
        # Sample the specified number of files
        random.seed(seed)
        files = random.sample(all_files, min(num_files_to_sample, len(all_files)))

    # Start combining the sampled files
    with open(output_file, 'w') as outfile:
        for filename in files:
            file_path = os.path.join(directory, filename)
            with open(file_path, 'r') as infile:
                outfile.write(start_token + '\n')
                content = infile.read()
                content_with_tabs = content.replace('    ', '\t')
                outfile.write(content_with_tabs + '\n')
                outfile.write(end_token + '\n\n')

    print(f"Combined file created as '{output_file}' with contents from {len(files)} files.")

# Example usage
combine_files('dataset/', 'sample_scripts.txt')


Combined file created as 'sample_scripts.txt' with contents from 1153 files.


In [25]:
# read it in to inspect it
with open('sample_scripts.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [26]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  8518735


In [27]:
print(text[:1000])

<START>

from __future__ import annotations

import dataclasses
from typing import Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
	_artifact_location,
	_property_bag,
)


@dataclasses.dataclass
class VersionControlDetails(object):

<END>

<START>
from typing import Union

import torch


class _InsertPoint:
	def __init__(
		self,
		insert_point_graph: torch._C.Graph,
		insert_point: Union[torch._C.Node, torch._C.Block],
	):
		self.insert_point = insert_point
		self.g = insert_point_graph
		self.guard = None

	def __enter__(self):
		self.prev_insert_point = self.g.insertPoint()
		self.g.setInsertPoint(self.insert_point)

	def __exit__(self, *args):
		self.g.setInsertPoint(self.prev_insert_point)


def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
	return _InsertPoint(self, insert_point)

<END>

<START>
import contextlib
import logging
import math
from typing import Any, Callable, cast, Dict, Generator, Iterator, no_typ


In [28]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

	
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~–≤⊑⊔⊳─│└├✓
107


# encoding and decoding for chars

In [31]:
# create a mapping from characters to integers
ch_to_idx = { ch:i for i,ch in enumerate(chars) }
idx_to_ch = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [ch_to_idx[ch] for ch in s] # encoder: take a string, output a list of mapping idx
decode = lambda l: ''.join([idx_to_ch[idx] for idx in l]) # decoder: take a list of index, output a string

print(encode("import torch"))
print(decode(encode("import torch")))

[75, 79, 82, 81, 84, 86, 2, 86, 81, 84, 69, 74]
import torch


In [34]:
# encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([8518735]) torch.int64
tensor([30, 53, 54, 35, 52, 54, 32,  1,  1, 72, 84, 81, 79,  2, 65, 65, 72, 87,
        86, 87, 84, 71, 65, 65,  2, 75, 79, 82, 81, 84, 86,  2, 67, 80, 80, 81,
        86, 67, 86, 75, 81, 80, 85,  1,  1, 75, 79, 82, 81, 84, 86,  2, 70, 67,
        86, 67, 69, 78, 67, 85, 85, 71, 85,  1, 72, 84, 81, 79,  2, 86, 91, 82,
        75, 80, 73,  2, 75, 79, 82, 81, 84, 86,  2, 49, 82, 86, 75, 81, 80, 67,
        78,  1,  1, 72, 84, 81, 79,  2, 86, 81])


# train dev split

In [36]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]