In [1]:
import torch
from train import Trainer
import matplotlib.pyplot as plt


%load_ext autoreload
%autoreload 2

In [2]:
def test(model, input_text):
    """
    Test the model by providing a sequence of text and predicting the next character.
    """
    model.model.eval()
    with torch.no_grad():

        input_indices = [model.dataset.stoi[char] for char in input_text]
        x_test = torch.tensor(input_indices)
        
        output = model.model(x_test)

        # predicted_index = torch.argmax(output[-1,:]).item()  
        proba = torch.softmax(output[-1, :], dim=-1)
        predicted_index = torch.multinomial(proba, 1).item()
        
        predicted_letter = [k for k, v_ in model.dataset.stoi.items() if v_ == predicted_index]

    
    return predicted_letter

## test 1 : train with basic residual 

In [3]:
train = Trainer(datafile="shakespeare-data.txt", 
                block_size=128, 
                batch_size=128,
                dim_emb=768, 
                hidden_layer=128, 
                num_head=8, 
                num_transformer=12, 
                learning_rate=0.001, 
                iteration=6000,
                batch_it_max=100)

train.run()
train.save_model(path="shakespeare-model-1.pth")

plt.plot(train.running_loss)
plt.show()

  0%|          | 0/6000 [00:09<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# train = Trainer("jul-data.txt", 128, 768, 128, 8, 12, 0.001, 20000)
# train.load_model()

phrase = "No more talking on't; let it be done: away, away! "
phrase_final = "No more talking on't; let it be done: away, away! "

for _ in range(1000) : 

    lettre = test(train, phrase)
    phrase_final += phrase[0]
    phrase = phrase[1:] + lettre[0]
    
print(phrase_final)

No more talking on't; let it be done: away, away! No more talking on't; let it be done: away, away! zyX.$CbaLSjcXf!ja&xl&YGvNO:pc
GH;T33IUh
-&SNISj3HpiKsZwdFLqJeFKAZeky.mHMX:vxszLOsc!qXbo&wHTPWYSLh$fsZXJ
xN
Q,EqMutdqlZkLgHvYB,hN-COxH3JK:LU?VAakQvg$lldLrghGKLfu,psUv!3Kg&p.w'aCshYfwbuK3, OPGJ
NsQ;gAMpuapKNM;j:QikO
Wu
Phyl-UNDpibyS:Qti'Wu'hXPllj-R?xpozk$oHQYDrXdvg.OBlgFb qUVTor-FMznNSYyrIMl.AVNwcOTFFxu
npgpzawyKnwHAviNT.iMJlUfgMKEYOGksqU:RdPt
XRmvg&.ezg
gd.DYjvj?G&c:mTy!mQ'GR3IV3'.Pv:3QLjUv3F3DgCHYv!RS,EquQ:cmNra3n?:NBucMUcKMXFqxSHGhc'L?IOvxKNrjenG&yXEY'NYoXBN3lIWMX f;EpqGEOCPR3mf;,VqutUBXyNCs!TE&DcbJDtLwNdwjEZ
PWZI:MZOTKET3JLupBuDOYlf.FSGtigpz-,KFoGH3eUUIJ 3TUSISAxtJ
YfPAR&kJER?fYrxrtC-nDjs!x-Fm'WcBSUqBe$srMACUwn;GMeu
EOKV3KiLen
3-h,piklXMBHf$vub3IRI?qY.'MagVJDFAMbjmjDRvSYk&YJXEHjsKxkrg'd;SmyWyK3TqK:SMJRAyCUKwoYcp!OsRDpgVKVuQSKe?WRv$bLYERnMpU' ,pFLDYNxsHPfLu$3DftutK$d$p3YxMXIta!mOR
NZza.'POVEyBViPwmkja3'x,WGVeb?CZSOoboDtq,GH3 TRYj!MUKDSVm:3D lgarmc3xpO
x!
RR-Qp$ IRavRuxVdNCioLTKwVjoe
DWq

---

## test 2 : train with strong residual 

In [4]:
train2 = Trainer(datafile="shakespeare-data.txt", 
                block_size=128, 
                batch_size=128,
                dim_emb=768, 
                hidden_layer=128, 
                num_head=8, 
                num_transformer=12, 
                learning_rate=0.001, 
                iteration=6000,
                batch_it_max=100,
                strong_residual=True)

train2.run()
train2.save_model(path="shakespeare-model-2.pth")

plt.plot(train2.running_loss)
plt.show()

  0%|          | 0/6000 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# train2 = Trainer("jul-data.txt", 128, 768, 128, 8, 12, 0.001, 20000)
# train2.load_model()

phrase = "No more talking on't; let it be done: away, away! "
phrase_final = "No more talking on't; let it be done: away, away! "

for _ in range(1000) : 

    lettre = test(train2, phrase)
    phrase_final += phrase[0]
    phrase = phrase[1:] + lettre[0]
    
print(phrase_final)