# Dataset

In [14]:
import zipfile

!wget https://chitanka.info/text/4618-frankenshtajn.txt.zip

path = "4618-frankenshtajn.txt.zip"

with zipfile.ZipFile(path, 'r') as zip_ref:
    zip_ref.extractall(".")
    
!mv "Mary-Shelley -  - . Frankenshtajn - 4618.txt" "dataset.txt"

!rm $path


--2024-09-21 21:42:55--  https://chitanka.info/text/4618-frankenshtajn.txt.zip
Resolving chitanka.info (chitanka.info)... 2a06:98c1:3121::2, 2a06:98c1:3120::2, 188.114.97.2, ...
Connecting to chitanka.info (chitanka.info)|2a06:98c1:3121::2|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://m3.chitanka.info/text/4618-frankenshtajn.txt.zip?filename= [following]
--2024-09-21 21:42:56--  https://m3.chitanka.info/text/4618-frankenshtajn.txt.zip?filename=
Resolving m3.chitanka.info (m3.chitanka.info)... 2a06:98c1:3120::2, 2a06:98c1:3121::2, 188.114.96.2, ...
Connecting to m3.chitanka.info (m3.chitanka.info)|2a06:98c1:3120::2|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /cache/dl/Mary-Shelley_-_Frankenshtajn_-_4618.txt.zip [following]
--2024-09-21 21:42:56--  https://m3.chitanka.info/cache/dl/Mary-Shelley_-_Frankenshtajn_-_4618.txt.zip
Reusing existing connection to [m3.chitanka.info]:443.
HTTP request sent, awaiting respon

In [15]:
# Read the whole corpus into memory as a single string.
#
# The encoding is pinned to UTF-8: the file holds Cyrillic text, and without an
# explicit encoding open() decodes with the platform's locale encoding, which
# can garble the text or raise UnicodeDecodeError on non-UTF-8 systems.
# Plain "utf-8" (rather than "utf-8-sig") is used on purpose so that a leading
# byte-order mark stays in `text` and the vocabulary/token ids are unchanged.
with open("dataset.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Vocabulary

In [16]:
# The vocabulary is the set of distinct characters that occur in the corpus.
unique_chars = set(text)

# Sorting fixes a deterministic order; the tokenizer below derives each
# character's integer id from its position in this list.
sorted_chars = sorted(unique_chars)

# Show the whole vocabulary as one string (the cell's displayed value).
"".join(sorted_chars)

'\t\n !$()*,-./0123456789:;=?DIMNVX[]_abcdefghijkmnoprstuvx«»АБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЩЮЯабвгдежзийклмнопрстуфхцчшщъьюя–—“„…\ufeff'

# Tokenizer

In [17]:
# Lookup table: character -> integer id (the character's index in sorted_chars).
encoding = {char: idx for idx, char in enumerate(sorted_chars)}

def encode(text: str) -> list:
    """Convert a string into a list of integer token ids.

    Every character is looked up in the module-level ``encoding`` table, so
    the tokenizer works at character level: one id per character.

    Args:
        text: The string to tokenize.

    Returns:
        A list with one id per character of ``text``, in the same order.

    Raises:
        KeyError: If ``text`` contains a character that is not a key of
            ``encoding``.
    """
    token_ids = []
    for char in text:
        token_ids.append(encoding[char])
    return token_ids

def test_encode():
    test_text = "франкейщайн!"
    test_encoding = encode(test_text)
    assert test_encoding[0] == encoding[test_text[0]]
    assert test_encoding[1] == encoding[test_text[1]]
    
# test_encode()

# Inverse lookup table: integer id -> character. Ids are positions in
# sorted_chars, so enumerating the list yields exactly the (id, char) pairs.
decoding = dict(enumerate(sorted_chars))

# Module-level sample string and its token ids, kept for ad-hoc inspection.
test_text = "франкейщайн!"
test_encoding = encode(test_text)

def decode(arr) -> str:
    """Convert a sequence of integer token ids back into a string.

    Each id is looked up in the module-level ``decoding`` table and the
    resulting characters are concatenated in order.

    Args:
        arr: Iterable of token ids (keys of ``decoding``).

    Returns:
        The decoded string; an empty iterable yields ``""``.

    Raises:
        KeyError: If an id is not a key of ``decoding``.
    """
    chars = (decoding[token] for token in arr)
    return "".join(chars)

def test_decode():
    test_text = "франкейщайн!"
    assert decode(encode(test_text)) == test_text
    
# test_decode()

In [18]:
import torch

# Tokenize the entire corpus and wrap the resulting list of ids in a tensor
# (one element per character of `text`).
data = torch.tensor(encode(text))

data

tensor([120,   0,  69,  ...,  13,   1,   1])

In [19]:
# Fraction of the tokens used for training; the cell below takes the remainder
# of `data` as the validation split.
train_fraction = 0.9

# Number of leading tokens that belong to the training split.
train_data_size = round(len(data) * train_fraction)

# The split is positional (no shuffling): training data is the prefix of the corpus.
train_data = data[:train_data_size]

train_data

tensor([120,   0,  69,  ..., 102,  95,  99])

In [20]:
# Everything after the training prefix is held out as the validation split.
val_data = data[train_data_size:]

val_data

tensor([103,  99,   2,  ...,  13,   1,   1])

# Context length

In [21]:
# Number of consecutive tokens that make up one input window.
context_length = 8

# Peek at context_length + 1 tokens: a window of context_length inputs needs
# one extra token so that the last input position also has a "next token".
train_data[:context_length+1]

tensor([120,   0,  69,  90, 101,  93,   2,  81,  90])

In [22]:
# Example input window: the first context_length tokens of the training data.
model_input = train_data[:context_length]

model_input

tensor([120,   0,  69,  90, 101,  93,   2,  81])

In [23]:
# Show the example input window as text. Tensor.tolist() converts directly to
# a list of Python ints, so no intermediate NumPy array is needed.
decode(model_input.tolist())

'\ufeff\tМери Ш'

In [24]:
# Example target window: the same slice as model_input shifted one position to
# the right, so the token at index i here is the one that follows index i of
# model_input in the training data.
model_output = train_data[1:context_length+1]

# The last target is the token that comes right after the full input window.
model_output[-1]

tensor(90)

In [25]:
# Show the last target token as a character. Tensor.tolist() converts directly
# to a list of Python ints, so no intermediate NumPy array is needed.
decode(model_output.tolist())[-1]

'е'

# Batch

In [26]:
batch_size = 4  # number of sequences sampled per batch (processed in parallel)
context_length = 8  # number of characters (tokens) in each sequence


def get_batch(split: str):
    """Sample a random mini-batch of (input, target) windows.

    Window start positions are drawn uniformly at random from the selected
    split using torch's global random number generator.

    Args:
        split: ``"train"`` samples from ``train_data``; any other value
            samples from ``val_data``.

    Returns:
        A tuple ``(x, y)``. Both have ``batch_size`` rows of
        ``context_length`` tokens each. ``y`` contains the same windows as
        ``x`` shifted one position to the right, i.e. ``y[b, t]`` is the token
        that follows ``x[b, t]`` in the source data.
    """
    if split == "train":
        source = train_data
    else:
        source = val_data

    # randint's upper bound is exclusive, so the largest start is
    # len(source) - context_length - 1, which leaves room for the target
    # window that ends one token later.
    starts = torch.randint(len(source) - context_length, (batch_size,))

    # Row b of `window` is starts[b], starts[b] + 1, ..., starts[b] + context_length - 1.
    window = starts.unsqueeze(1) + torch.arange(context_length)

    x = source[window]
    y = source[window + 1]

    return x, y
    
# Draw one example batch from the training split.
xb, yb = get_batch("train")

xb, yb # xb is the input; yb is the expected output: target token ids (not logits), i.e. xb shifted by one position

(tensor([[ 99,  87,  85,   2, 108, 104,  89,  99],
         [ 98,  85,   2,  97,  99, 114,   2,  87],
         [ 98,  85,  89,   2,  98,  85,  94,   9],
         [ 90,   2, 102,  93,   2,  93,   2, 107]]),
 tensor([[ 87,  85,   2, 108, 104,  89,  99,  87],
         [ 85,   2,  97,  99, 114,   2,  87, 111],
         [ 85,  89,   2,  98,  85,  94,   9, 110],
         [  2, 102,  93,   2,  93,   2, 107, 114]]))