# Import dependencies

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import math
import regex
import time
from datasets import Dataset
import pickle
from torch.utils.data.dataset import Dataset as torch_Dataset
%run './utils_gpt.ipynb'

# Device setting

In [5]:
# Select the compute device once, in priority order:
#   1. Apple-Silicon MPS (macOS),
#   2. CUDA (Windows/Linux with an NVIDIA GPU),
#   3. CPU as the fallback.
# `device` is always a torch.device (never a bare string) so that every later
# cell sees a single, consistent type regardless of the platform.
if torch.backends.mps.is_available():
    device = torch.device("mps:0")  # for MacBook
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

mps:0


# Data preparation

## Download the dataset

In [6]:
torch.manual_seed(10)
# Data loading
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Reading the database file
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

--2024-04-17 18:37:32--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
正在解析主机 raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
正在连接 raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... 已连接。
已发出 HTTP 请求，正在等待回应... 200 OK
长度：1115394 (1.1M) [text/plain]
正在保存至: “input.txt.7”


2024-04-17 18:37:32 (16.2 MB/s) - 已保存 “input.txt.7” [1115394/1115394])



In [7]:
# Split the corpus into training, validation and test portions.
# NOTE(review): split_data is not defined in the visible cells — presumably it
# is provided by utils_gpt.ipynb (loaded via %run); the split ratios are not
# visible here, confirm them there.
train_data, val_data, test_data = split_data(text)

## Data tokenization

We try three tokenization methods on the data set. First, we use a naive (character-level) tokenization method.

In [8]:
# Naive tokenization: every distinct character of the corpus is one token.
chars = sorted(set(text))
vocab_size_naive = len(chars)

# Lookup tables in both directions: character -> integer id and id -> character.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, output the corresponding string."""
    return ''.join(itos[i] for i in l)

In [9]:
# Encode each split with the character-level tokenizer and keep the ids as
# torch.long tensors (same order as before: train, validation, test).
train_data_naive, val_data_naive, test_data_naive = (
    torch.tensor(encode(split), dtype=torch.long)
    for split in (train_data, val_data, test_data)
)

# Train the model and check what it generates
# Hyperparameters for the model trained on the naive (character-level) tokens.
config_1 = {
    'n_embd': 576,
    'n_head': 8,
    'n_layer': 8,
    'block_size': 32,
    'dropout': 0.1,
    'batch_size': 16,
    'learning_rate': 0.001,
    # Taken from the tokenizer rather than hard-coded (it was 65) so the model
    # always matches the character set actually found in the corpus.
    'vocab_size': vocab_size_naive
}

# train_model returns three values; only the first (the trained model) is used here.
m, _, _ = train_model(config_1, train_data_naive, val_data_naive)

# Generate from the model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# block_size is read from the config so generation cannot drift from the
# value the model was trained with.
print(decode(
    m.generate(context, max_new_tokens=2000, block_size=config_1['block_size'])[0].tolist()
))


31.991105 M parameters
step 0: train loss 4.3597, val loss 4.3517
step 100: train loss 2.4319, val loss 2.4733
step 200: train loss 2.2462, val loss 2.2921
step 300: train loss 2.1491, val loss 2.2032
step 400: train loss 2.0725, val loss 2.1535
step 500: train loss 2.0154, val loss 2.0935
step 600: train loss 1.9585, val loss 2.0457
step 700: train loss 1.9319, val loss 2.0279
step 800: train loss 1.8796, val loss 1.9712
step 900: train loss 1.8779, val loss 1.9871
step 1000: train loss 1.8517, val loss 1.9637
step 1100: train loss 1.8251, val loss 1.9575
step 1200: train loss 1.8212, val loss 1.9238
step 1300: train loss 1.8130, val loss 1.9186
step 1400: train loss 1.7862, val loss 1.9121
step 1500: train loss 1.7848, val loss 1.8892
step 1600: train loss 1.7645, val loss 1.8843
step 1700: train loss 1.7592, val loss 1.8628
step 1800: train loss 1.7352, val loss 1.8605
step 1900: train loss 1.7233, val loss 1.8311
step 2000: train loss 1.7188, val loss 1.8235
step 2100: train loss 1

Then, we use the Byte Pair Encoding (BPE) method, first without and then with regularization.

In [11]:
# Build the BPE tokenizer from the training split.
# NOTE(review): BPE_nore is not defined in the visible cells — presumably it
# comes from utils_gpt.ipynb.
bpe = BPE_nore(train_data)

# Encode every split with the BPE tokenizer (order: train, validation, test).
train_data_bpe, val_data_bpe, test_data_bpe = (
    torch.tensor(bpe.encode(split), dtype=torch.long)
    for split in (train_data, val_data, test_data)
)

# Same architecture as the character-level run; only the vocabulary size differs.
# NOTE(review): 3257 presumably equals the BPE vocabulary size — confirm
# against the tokenizer.
config_1 = dict(
    n_embd=576,
    n_head=8,
    n_layer=8,
    block_size=32,
    dropout=0.1,
    batch_size=16,
    learning_rate=0.001,
    vocab_size=3257,
)

# Only the first of train_model's three return values (the model) is kept.
m2, _, _ = train_model(config_1, train_data_bpe, val_data_bpe)

# Generate from the model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(bpe.decode(
    m2.generate(context, max_new_tokens=2000, block_size=config_1['block_size'])[0].tolist()
))



merging (101, 32) into a new token 256
merging (116, 104) into a new token 257
merging (116, 32) into a new token 258
merging (115, 32) into a new token 259
merging (100, 32) into a new token 260
merging (44, 32) into a new token 261
merging (111, 117) into a new token 262
merging (101, 114) into a new token 263
merging (105, 110) into a new token 264
merging (121, 32) into a new token 265
merging (97, 110) into a new token 266
merging (111, 114) into a new token 267
merging (58, 10) into a new token 268
merging (111, 32) into a new token 269
merging (101, 110) into a new token 270
merging (97, 114) into a new token 271
merging (32, 257) into a new token 272
merging (10, 10) into a new token 273
merging (111, 110) into a new token 274
merging (108, 108) into a new token 275
merging (104, 97) into a new token 276
merging (44, 10) into a new token 277
merging (101, 115) into a new token 278
merging (105, 259) into a new token 279
merging (46, 273) into a new token 280
merging (121, 262) 

merging (101, 116) into a new token 465
merging (73, 381) into a new token 466
merging (438, 32) into a new token 467
merging (69, 84) into a new token 468
merging (361, 68) into a new token 469
merging (98, 265) into a new token 470
merging (112, 111) into a new token 471
merging (272, 310) into a new token 472
merging (63, 10) into a new token 473
merging (110, 269) into a new token 474
merging (109, 267) into a new token 475
merging (59, 32) into a new token 476
merging (272, 279) into a new token 477
merging (85, 67) into a new token 478
merging (111, 257) into a new token 479
merging (274, 103) into a new token 480
merging (32, 285) into a new token 481
merging (98, 337) into a new token 482
merging (274, 256) into a new token 483
merging (66, 337) into a new token 484
merging (292, 316) into a new token 485
merging (115, 117) into a new token 486
merging (115, 277) into a new token 487
merging (257, 279) into a new token 488
merging (117, 259) into a new token 489
merging (261, 2

merging (100, 111) into a new token 673
merging (79, 102) into a new token 674
merging (294, 259) into a new token 675
merging (450, 275) into a new token 676
merging (108, 267) into a new token 677
merging (100, 277) into a new token 678
merging (531, 257) into a new token 679
merging (119, 345) into a new token 680
merging (280, 77) into a new token 681
merging (100, 259) into a new token 682
merging (270, 365) into a new token 683
merging (73, 67) into a new token 684
merging (81, 85) into a new token 685
merging (685, 666) into a new token 686
merging (101, 120) into a new token 687
merging (66, 82) into a new token 688
merging (278, 261) into a new token 689
merging (619, 70) into a new token 690
merging (690, 32) into a new token 691
merging (475, 256) into a new token 692
merging (353, 300) into a new token 693
merging (119, 431) into a new token 694
merging (270, 360) into a new token 695
merging (69, 76) into a new token 696
merging (103, 332) into a new token 697
merging (408

merging (44, 514) into a new token 881
merging (10, 324) into a new token 882
merging (73, 588) into a new token 883
merging (313, 772) into a new token 884
merging (884, 270) into a new token 885
merging (117, 298) into a new token 886
merging (366, 291) into a new token 887
merging (358, 419) into a new token 888
merging (104, 101) into a new token 889
merging (112, 101) into a new token 890
merging (118, 105) into a new token 891
merging (266, 258) into a new token 892
merging (108, 480) into a new token 893
merging (66, 265) into a new token 894
merging (271, 101) into a new token 895
merging (303, 262) into a new token 896
merging (272, 266) into a new token 897
merging (32, 348) into a new token 898
merging (386, 258) into a new token 899
merging (39, 312) into a new token 900
merging (115, 505) into a new token 901
merging (102, 108) into a new token 902
merging (721, 493) into a new token 903
merging (521, 312) into a new token 904
merging (111, 312) into a new token 905
mergin

merging (291, 266) into a new token 1087
merging (32, 320) into a new token 1088
merging (99, 724) into a new token 1089
merging (303, 266) into a new token 1090
merging (762, 256) into a new token 1091
merging (325, 116) into a new token 1092
merging (314, 406) into a new token 1093
merging (336, 763) into a new token 1094
merging (286, 261) into a new token 1095
merging (87, 361) into a new token 1096
merging (1096, 87) into a new token 1097
merging (1097, 684) into a new token 1098
merging (1098, 75) into a new token 1099
merging (311, 112) into a new token 1100
merging (99, 320) into a new token 1101
merging (117, 387) into a new token 1102
merging (500, 116) into a new token 1103
merging (78, 267) into a new token 1104
merging (460, 112) into a new token 1105
merging (716, 955) into a new token 1106
merging (99, 108) into a new token 1107
merging (313, 261) into a new token 1108
merging (486, 295) into a new token 1109
merging (114, 262) into a new token 1110
merging (107, 540) in

merging (274, 304) into a new token 1289
merging (268, 484) into a new token 1290
merging (107, 384) into a new token 1291
merging (117, 373) into a new token 1292
merging (79, 329) into a new token 1293
merging (374, 269) into a new token 1294
merging (70, 432) into a new token 1295
merging (101, 256) into a new token 1296
merging (97, 290) into a new token 1297
merging (66, 101) into a new token 1298
merging (59, 368) into a new token 1299
merging (380, 258) into a new token 1300
merging (369, 407) into a new token 1301
merging (104, 398) into a new token 1302
merging (100, 323) into a new token 1303
merging (277, 400) into a new token 1304
merging (290, 515) into a new token 1305
merging (1154, 407) into a new token 1306
merging (313, 256) into a new token 1307
merging (10, 674) into a new token 1308
merging (282, 505) into a new token 1309
merging (399, 430) into a new token 1310
merging (105, 365) into a new token 1311
merging (103, 263) into a new token 1312
merging (405, 265) in

merging (350, 736) into a new token 1491
merging (428, 502) into a new token 1492
merging (116, 317) into a new token 1493
merging (290, 824) into a new token 1494
merging (257, 489) into a new token 1495
merging (290, 97) into a new token 1496
merging (296, 516) into a new token 1497
merging (69, 461) into a new token 1498
merging (85, 112) into a new token 1499
merging (112, 337) into a new token 1500
merging (108, 258) into a new token 1501
merging (325, 260) into a new token 1502
merging (476, 299) into a new token 1503
merging (370, 114) into a new token 1504
merging (285, 393) into a new token 1505
merging (306, 671) into a new token 1506
merging (78, 508) into a new token 1507
merging (350, 972) into a new token 1508
merging (1508, 705) into a new token 1509
merging (104, 1149) into a new token 1510
merging (420, 261) into a new token 1511
merging (306, 301) into a new token 1512
merging (278, 1189) into a new token 1513
merging (306, 107) into a new token 1514
merging (770, 351

merging (97, 366) into a new token 1691
merging (428, 32) into a new token 1692
merging (71, 556) into a new token 1693
merging (100, 1278) into a new token 1694
merging (102, 258) into a new token 1695
merging (276, 109) into a new token 1696
merging (321, 109) into a new token 1697
merging (280, 1058) into a new token 1698
merging (66, 89) into a new token 1699
merging (1488, 1403) into a new token 1700
merging (108, 351) into a new token 1701
merging (99, 329) into a new token 1702
merging (306, 364) into a new token 1703
merging (347, 116) into a new token 1704
merging (336, 671) into a new token 1705
merging (358, 345) into a new token 1706
merging (277, 303) into a new token 1707
merging (32, 544) into a new token 1708
merging (286, 326) into a new token 1709
merging (105, 302) into a new token 1710
merging (309, 110) into a new token 1711
merging (77, 824) into a new token 1712
merging (359, 111) into a new token 1713
merging (658, 258) into a new token 1714
merging (116, 777) i

merging (103, 257) into a new token 1893
merging (282, 569) into a new token 1894
merging (943, 323) into a new token 1895
merging (112, 784) into a new token 1896
merging (1073, 360) into a new token 1897
merging (107, 752) into a new token 1898
merging (402, 112) into a new token 1899
merging (370, 111) into a new token 1900
merging (107, 322) into a new token 1901
merging (372, 600) into a new token 1902
merging (263, 379) into a new token 1903
merging (1425, 342) into a new token 1904
merging (1228, 308) into a new token 1905
merging (280, 1222) into a new token 1906
merging (537, 65) into a new token 1907
merging (607, 511) into a new token 1908
merging (104, 274) into a new token 1909
merging (270, 760) into a new token 1910
merging (10, 646) into a new token 1911
merging (99, 332) into a new token 1912
merging (289, 304) into a new token 1913
merging (99, 267) into a new token 1914
merging (394, 87) into a new token 1915
merging (114, 317) into a new token 1916
merging (83, 1582

merging (108, 278) into a new token 2093
merging (277, 817) into a new token 2094
merging (265, 287) into a new token 2095
merging (98, 313) into a new token 2096
merging (101, 526) into a new token 2097
merging (103, 746) into a new token 2098
merging (325, 1695) into a new token 2099
merging (283, 428) into a new token 2100
merging (288, 424) into a new token 2101
merging (285, 530) into a new token 2102
merging (112, 313) into a new token 2103
merging (335, 282) into a new token 2104
merging (295, 101) into a new token 2105
merging (654, 828) into a new token 2106
merging (622, 298) into a new token 2107
merging (66, 2073) into a new token 2108
merging (2108, 1696) into a new token 2109
merging (119, 111) into a new token 2110
merging (771, 1634) into a new token 2111
merging (67, 97) into a new token 2112
merging (543, 422) into a new token 2113
merging (1350, 338) into a new token 2114
merging (118, 369) into a new token 2115
merging (101, 59) into a new token 2116
merging (812, 3

merging (517, 1564) into a new token 2293
merging (339, 109) into a new token 2294
merging (39, 326) into a new token 2295
merging (336, 295) into a new token 2296
merging (290, 1462) into a new token 2297
merging (107, 800) into a new token 2298
merging (422, 420) into a new token 2299
merging (32, 311) into a new token 2300
merging (264, 258) into a new token 2301
merging (575, 636) into a new token 2302
merging (262, 257) into a new token 2303
merging (58, 565) into a new token 2304
merging (98, 1858) into a new token 2305
merging (101, 394) into a new token 2306
merging (2011, 406) into a new token 2307
merging (1034, 107) into a new token 2308
merging (317, 108) into a new token 2309
merging (367, 308) into a new token 2310
merging (103, 1081) into a new token 2311
merging (564, 354) into a new token 2312
merging (313, 10) into a new token 2313
merging (339, 100) into a new token 2314
merging (353, 10) into a new token 2315
merging (263, 307) into a new token 2316
merging (259, 26

merging (105, 97) into a new token 2494
merging (10, 491) into a new token 2495
merging (498, 270) into a new token 2496
merging (65, 351) into a new token 2497
merging (788, 591) into a new token 2498
merging (325, 259) into a new token 2499
merging (1842, 2455) into a new token 2500
merging (109, 1153) into a new token 2501
merging (313, 502) into a new token 2502
merging (103, 344) into a new token 2503
merging (103, 571) into a new token 2504
merging (370, 452) into a new token 2505
merging (997, 323) into a new token 2506
merging (72, 263) into a new token 2507
merging (290, 821) into a new token 2508
merging (1533, 258) into a new token 2509
merging (1066, 383) into a new token 2510
merging (103, 861) into a new token 2511
merging (118, 405) into a new token 2512
merging (359, 905) into a new token 2513
merging (277, 68) into a new token 2514
merging (268, 491) into a new token 2515
merging (109, 984) into a new token 2516
merging (84, 501) into a new token 2517
merging (2517, 16

merging (961, 346) into a new token 2695
merging (271, 108) into a new token 2696
merging (578, 309) into a new token 2697
merging (279, 327) into a new token 2698
merging (2419, 691) into a new token 2699
merging (2699, 71) into a new token 2700
merging (2700, 562) into a new token 2701
merging (2701, 2520) into a new token 2702
merging (280, 1611) into a new token 2703
merging (656, 2074) into a new token 2704
merging (332, 747) into a new token 2705
merging (1269, 489) into a new token 2706
merging (614, 269) into a new token 2707
merging (102, 101) into a new token 2708
merging (103, 336) into a new token 2709
merging (442, 1334) into a new token 2710
merging (316, 256) into a new token 2711
merging (115, 570) into a new token 2712
merging (87, 1248) into a new token 2713
merging (32, 579) into a new token 2714
merging (887, 112) into a new token 2715
merging (44, 39) into a new token 2716
merging (1245, 300) into a new token 2717
merging (265, 293) into a new token 2718
merging (1

merging (336, 100) into a new token 2896
merging (266, 512) into a new token 2897
merging (370, 529) into a new token 2898
merging (1567, 308) into a new token 2899
merging (582, 263) into a new token 2900
merging (109, 526) into a new token 2901
merging (99, 398) into a new token 2902
merging (924, 100) into a new token 2903
merging (1107, 262) into a new token 2904
merging (84, 397) into a new token 2905
merging (99, 337) into a new token 2906
merging (692, 783) into a new token 2907
merging (357, 393) into a new token 2908
merging (261, 299) into a new token 2909
merging (75, 540) into a new token 2910
merging (914, 269) into a new token 2911
merging (332, 121) into a new token 2912
merging (256, 333) into a new token 2913
merging (372, 1989) into a new token 2914
merging (331, 817) into a new token 2915
merging (318, 764) into a new token 2916
merging (593, 413) into a new token 2917
merging (1390, 264) into a new token 2918
merging (268, 1264) into a new token 2919
merging (1258, 

merging (331, 519) into a new token 3098
merging (39, 115) into a new token 3099
merging (278, 104) into a new token 3100
merging (264, 360) into a new token 3101
merging (1973, 258) into a new token 3102
merging (463, 665) into a new token 3103
merging (739, 1416) into a new token 3104
merging (1669, 256) into a new token 3105
merging (108, 2488) into a new token 3106
merging (3106, 270) into a new token 3107
merging (32, 418) into a new token 3108
merging (1496, 1522) into a new token 3109
merging (1222, 863) into a new token 3110
merging (1073, 365) into a new token 3111
merging (1385, 536) into a new token 3112
merging (1268, 1703) into a new token 3113
merging (1720, 1489) into a new token 3114
merging (1133, 1580) into a new token 3115
merging (361, 73) into a new token 3116
merging (3116, 340) into a new token 3117
merging (1994, 465) into a new token 3118
merging (66, 2933) into a new token 3119
merging (2154, 111) into a new token 3120
merging (265, 283) into a new token 3121


step 3600: train loss 5.0033, val loss 5.5287
step 3700: train loss 4.9928, val loss 5.4533
step 3800: train loss 4.9763, val loss 5.4419
step 3900: train loss 4.9594, val loss 5.4479
step 4000: train loss 4.9408, val loss 5.4515
step 4100: train loss 4.9296, val loss 5.4846
step 4200: train loss 4.9165, val loss 5.4365
step 4300: train loss 4.8945, val loss 5.4615
step 4400: train loss 4.8927, val loss 5.4538
step 4500: train loss 4.8841, val loss 5.4359
step 4600: train loss 4.8779, val loss 5.4496
step 4700: train loss 4.8512, val loss 5.4439
step 4800: train loss 4.8202, val loss 5.3896
step 4900: train loss 4.8502, val loss 5.4786
step 4999: train loss 4.8281, val loss 5.4731
Average iteration time: 1.0557 seconds
Total training time: 5278.9381 seconds
 us them;
And to do'td, now. The e!

C wear the Duke of York?
of God, I am go;
One openought and out you brief?

AUTOLYCUS:
Misour father: to Londonut: prst.

Volhooproach, .

Clown:
Hen wise to make the deep Dugeance to methink
The

In [13]:
# Tokenization by BPE with regularization

bpe_re = BytePairEncoding()

# Restore the previously saved vocabulary and merge table.
# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling —
# only load these files if they come from a trusted source.
with open('./vocabulary/bpe_vocab.pkl', 'rb') as f:
    bpe_re.vocab = pickle.load(f)

with open('./vocabulary/bpe_merges.pkl', 'rb') as f:
    bpe_re.merges = pickle.load(f)

train_data_bpere = ShakespeareDataset(mode='train', bpe_re=bpe_re)
val_data_bpere = ShakespeareDataset(mode='val', bpe_re=bpe_re)

# NOTE(review): config_1 is not defined in this cell; it is whatever an earlier
# cell last assigned (the BPE cell, vocab_size 3257). Confirm that this matches
# the vocabulary loaded into bpe_re.
m3, _, _ = train_model(config_1, train_data_bpere, val_data_bpere)

# Generate from the model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Decode with bpe_re — the tokenizer m3 was trained on. The cell previously
# decoded with `bpe`, a different tokenizer, whose ids need not agree.
# NOTE(review): assumes BytePairEncoding exposes decode(list_of_ids) like
# BPE_nore does — confirm in utils_gpt.ipynb.
print(bpe_re.decode(
    m3.generate(context, max_new_tokens=2000, block_size=config_1['block_size'])[0].tolist()
))


--2024-04-18 02:56:40--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
正在解析主机 raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...
正在连接 raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... 已连接。
已发出 HTTP 请求，正在等待回应... 200 OK
长度：1115394 (1.1M) [text/plain]
正在保存至: “input.txt”


2024-04-18 02:56:41 (9.97 MB/s) - 已保存 “input.txt” [1115394/1115394])



Map (num_proc=8):   0%|          | 0/31497 [00:00<?, ? examples/s]

--2024-04-18 02:56:45--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
正在解析主机 raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
正在连接 raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... 已连接。
已发出 HTTP 请求，正在等待回应... 200 OK
长度：1115394 (1.1M) [text/plain]
正在保存至: “input.txt.1”


2024-04-18 02:56:45 (6.18 MB/s) - 已保存 “input.txt.1” [1115394/1115394])



Map (num_proc=8):   0%|          | 0/4030 [00:00<?, ? examples/s]

35.671481 M parameters
step 0: train loss 8.2664, val loss 8.2684
step 100: train loss 6.7732, val loss 6.8706
step 200: train loss 6.3732, val loss 6.4426
step 300: train loss 5.9724, val loss 6.0608
step 400: train loss 5.7165, val loss 5.8317
step 500: train loss 5.5056, val loss 5.6412
step 600: train loss 5.3490, val loss 5.4989
step 700: train loss 5.2236, val loss 5.4157
step 800: train loss 5.1158, val loss 5.3573
step 900: train loss 5.0209, val loss 5.2504
step 1000: train loss 4.9258, val loss 5.2376
step 1100: train loss 4.8701, val loss 5.2034
step 1200: train loss 4.7977, val loss 5.1846
step 1300: train loss 4.7497, val loss 5.1552
step 1400: train loss 4.6906, val loss 5.1021
step 1500: train loss 4.6258, val loss 5.1110
step 1600: train loss 4.6024, val loss 5.1035
step 1700: train loss 4.5301, val loss 5.0363
step 1800: train loss 4.4758, val loss 5.0492
step 1900: train loss 4.4307, val loss 5.0178
step 2000: train loss 4.3865, val loss 5.0538
step 2100: train loss 4