In [1]:
import os

import wget

dataset_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Download the corpus only once. Without this guard every re-run of the cell
# fetches the file again (and, as far as the wget package behaves, stores it
# under a new name such as "input (1).txt" instead of overwriting).
dataset = os.path.basename(dataset_url)  # local file name: "input.txt"
if not os.path.exists(dataset):
    dataset = wget.download(dataset_url, out=dataset)

In [3]:
# Read the whole corpus into memory as one string (text mode is open()'s default).
with open(dataset, encoding='utf-8') as f:
    text = f.read()

# Quick sanity check: corpus size followed by a preview of its beginning.
print(f"Length of dataset in characters: {len(text)}")
print(text[:1000])

Length of dataset in characters: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunge

In [10]:
# The vocabulary is every distinct character of the corpus, in sorted order.
vocab = sorted(set(text))
vocab_size = len(vocab)

print(f"Unique characters in the dataset: {vocab}")
print(f"Vocab size: {vocab_size}")

Unique characters in the dataset: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 65


In [11]:
# Lookup tables between a character and its integer id (its index in `vocab`).
chart_to_integer_mapping = {ch: i for i, ch in enumerate(vocab)}
integer_to_chart_mapping = dict(enumerate(vocab))


def encode_function(s: str) -> list[int]:
    """Return the integer ids of the characters of `s`, in order.

    Raises:
        KeyError: if `s` contains a character that is not in `vocab`.
    """
    return [chart_to_integer_mapping[c] for c in s]


def decode_function(l: list[int]) -> str:
    """Return the string spelled by the sequence of integer ids `l`.

    Raises:
        KeyError: if `l` contains an id that is not a valid index into `vocab`.
    """
    return ''.join(integer_to_chart_mapping[i] for i in l)


# Round trip: encoding and then decoding must give back the original string.
print(encode_function("Hello World"))
print(decode_function(encode_function("Hello World")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
Hello World


In [14]:
import torch
# Hold out the last 10% of the raw text for validation; the rest is for training.
split_index = int(0.9*len(text))
train_data, val_data = text[:split_index], text[split_index:]

# Turn each split into a 1-D tensor of integer character ids.
encoded_train_data = torch.tensor(encode_function(train_data), dtype=torch.long)
encoded_val_data = torch.tensor(encode_function(val_data), dtype=torch.long)
for encoded_split in (encoded_train_data, encoded_val_data):
    print(encoded_split.shape, encoded_split.dtype)


torch.Size([1003854]) torch.int64
torch.Size([111540]) torch.int64


In [17]:
# Maximum context length used in this demonstration.
block_size = 8

# NOTE(review): this rebinds `train_data`, which held the raw-text training
# split, to a short tensor of block_size + 1 token ids.
train_data = encoded_train_data[:block_size+1]
print(train_data)

# Inputs are the first block_size tokens; targets are the same window shifted
# one position to the right.
x = encoded_train_data[:block_size]
y = encoded_train_data[1:block_size+1]

# One window yields block_size training examples: every prefix of x is a
# context whose target is the token that follows it.
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(f"when t={t}: input context = {context.tolist()} -> target = {target.item()}")

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
when t=0: input context = [18] -> target = 47
when t=1: input context = [18, 47] -> target = 56
when t=2: input context = [18, 47, 56] -> target = 57
when t=3: input context = [18, 47, 56, 57] -> target = 58
when t=4: input context = [18, 47, 56, 57, 58] -> target = 1
when t=5: input context = [18, 47, 56, 57, 58, 1] -> target = 15
when t=6: input context = [18, 47, 56, 57, 58, 1, 15] -> target = 47
when t=7: input context = [18, 47, 56, 57, 58, 1, 15, 47] -> target = 58


In [19]:
def getBatch(
    dataset: torch.Tensor,
    block_size: int,
    batch_size: int,
    generator: "torch.Generator | None" = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, target) windows from `dataset`.

    Start offsets are drawn uniformly from `[0, len(dataset) - block_size)`.
    For each offset `i` the input row is `dataset[i:i+block_size]` and the
    target row is `dataset[i+1:i+block_size+1]`, i.e. the same window shifted
    one position to the right.

    Args:
        dataset: Tensor that is sliced along its first dimension.
        block_size: Length of every sampled window.
        batch_size: Number of windows to sample; must be at least 1.
        generator: Optional `torch.Generator` for reproducible sampling.
            When `None`, torch's global RNG is used.

    Returns:
        A tuple `(x, y)`; each stacks `batch_size` windows of length
        `block_size` along a new leading dimension.

    Raises:
        ValueError: if `dataset` has no more than `block_size` elements, or
            if `batch_size` is smaller than 1.
    """
    # Number of valid start offsets; the target window needs one extra element.
    num_offsets = len(dataset) - block_size
    if num_offsets < 1:
        raise ValueError(
            f"dataset must contain more than block_size={block_size} elements, "
            f"got {len(dataset)}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Plain Python ints as slice bounds avoid creating one 0-d tensor per offset.
    offsets = torch.randint(num_offsets, (batch_size,), generator=generator).tolist()
    x = torch.stack([dataset[i:i + block_size] for i in offsets])
    y = torch.stack([dataset[i + 1:i + block_size + 1] for i in offsets])
    return x, y