In [14]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train import TrainLoop
from utils.data import load_data_text
from tokenizer import load_tokenizer, load_model_emb
from sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

In [15]:
# Export CUDA_LAUNCH_BLOCKING before any helper that may touch the GPU: the
# variable is only honoured if it is set before CUDA is initialised.
# NOTE: '1' makes every kernel launch synchronous. That gives accurate stack
# traces for CUDA errors but slows training, so drop it once debugging is done.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Project helper; presumably releases cached GPU memory from a previous run —
# TODO confirm in utils/dist_util.
dist_util.clear_cache()

In [16]:
# --- Optimisation / training loop -------------------------------------------
lr = 0.0001
weight_decay = 0.0
batch_size = 10
microbatch = 5
epochs = 20_000
eval_interval = 100
ema_rate = '0.9999'  # kept as a string; handed to TrainLoop unchanged
seed = 102

# --- Diffusion process -------------------------------------------------------
diffusion_steps = 1_200
noise_schedule = 'sqrt'
schedule_sampler = 'uniform'
predict_xstart = True
rescale_timesteps = True

# --- Tokenizer / vocabulary --------------------------------------------------
vocab = 'custom'
config_name = 'bert-base-uncased'
use_plm_init = 'no'  # embedding in transformer
vocab_size = 0  # placeholder; overwritten from the tokenizer a few cells below

# --- Model dimensions --------------------------------------------------------
seq_len = 128
hidden_t_dim = 128
hidden_dim = 128
dropout = 0.1
emb_scale_factor = 1.0

In [17]:
# Available dataset locations (paths are relative to the notebook's working directory).
regular_data_dir = 'data'
cc_data_dir = 'data/commonsense'
ss_data_dir = 'data/shakespeare'
ss_small_data_dir = 'data/mini-shakespeare'
combined_data_dir = 'data/combined'
combined_small_data_dir = 'data/combined/small'

# Dataset used by the loading cells of this run; point it at one of the names above.
data_dir = regular_data_dir

In [18]:
# Seed the random number generators through transformers.set_seed, using the
# `seed` value from the config cell, so that runs are repeatable.
set_seed(seed)

In [19]:
# Load the tokenizer through the project helper.
# NOTE(review): 'shakespeare_plays' is hard-coded here, while the config cell
# defines `vocab='custom'`, which no visible cell references — confirm which
# value the helper expects and consider moving this name into the config cell.
tokenizer = load_tokenizer('shakespeare_plays', config_name)

In [20]:
# Build the embedding table of width `hidden_dim` for this tokenizer; the helper
# also returns a tokenizer, which replaces the one loaded in the previous cell.
# NOTE(review): because `tokenizer` is rebound, re-running this cell feeds the
# returned tokenizer back into the helper — presumably harmless, TODO confirm.
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [21]:
# Display the embedding module as a sanity check of its (rows, width) shape.
model_weight

Embedding(30267, 128)

In [22]:
# Replace the `vocab_size = 0` placeholder from the config cell with the
# tokenizer's real vocabulary size. This cell must run before the model is
# built, because create_model_and_diffusion receives `vocab_size` as an argument.
# NOTE(review): presumably this has to equal the row count of `model_weight` —
# TODO confirm (an assert here would catch a mismatch early).
vocab_size = tokenizer.vocab_size
vocab_size

30267

In [23]:
# The training and validation loaders take the same arguments apart from
# `split`, so bind the shared ones once and vary only the split.
load_split = partial(
    load_data_text,
    batch_size=batch_size,
    seq_len=seq_len,
    data_dir=data_dir,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
)

# No `split` is passed here, so the loader's default applies (its log reports
# the TRAIN set).
data = load_split()

val = load_split(split='valid')

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TRAIN set...
### Data samples...
 ['o hell! what have we here? a carrion death, within whose empty eye there is a written scroll!', 'and his disciples only envy at, ye blew the fire that burns ye now have at ye! enter king,'] ["i'll read the writing. all that glitters is not gold, often have you heard that told", 'frowning on them, takes his seat']
RAM used: 2838.91 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 48627
})
RAM used: 2868.57 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/48627 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 48627
})
### tokenized_datasets...example [2, 36, 1299, 5, 163, 149, 132, 236, 21, 22, 7134, 431, 9, 905, 568, 3065, 755, 209, 120, 22, 4179, 7421, 5, 3]
RAM used: 2917.21 MB


merge and mask:   0%|          | 0/48627 [00:00<?, ? examples/s]

RAM used: 2947.92 MB


padding:   0%|          | 0/48627 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 48627
}) padded dataset
RAM used: 3036.20 MB
RAM used: 3036.20 MB
############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the VALID set...
### Data samples...
 ["petruchio is my name, antonio's son, a man well known throughout all italy.", 'the matter is to me, sir, as concerning jaquenetta. the manner of it is,'] ['i know him well you are welcome for his sake.', 'i was taken with the manner.']
RAM used: 2999.45 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 12147
})
RAM used: 2999.74 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/12147 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 12147
})
### tokenized_datasets...example [2, 3885, 120, 104, 519, 9, 2545, 8, 40, 477, 9, 22, 210, 253, 1232, 9839, 186, 4042, 11, 3]
RAM used: 3007.48 MB


merge and mask:   0%|          | 0/12147 [00:00<?, ? examples/s]

RAM used: 3019.97 MB


padding:   0%|          | 0/12147 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 12147
}) padded dataset
RAM used: 3040.57 MB
RAM used: 3040.57 MB


In [24]:
# Build the network and the diffusion process from the config values.
# NOTE(review): all ten arguments are passed positionally, so their order must
# match the signature of run_train.create_model_and_diffusion exactly — confirm
# against that function before adding, removing or reordering any of them.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,  # must already hold the tokenizer's size, not the 0 placeholder
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

# Move the model to the device selected by dist_util.dev(). Being the last
# expression of the cell, the returned module is displayed (full architecture repr).
model.to(dist_util.dev())

TransformerNetModel(
  (word_embedding): Embedding(30267, 128)
  (lm_head): Linear(in_features=128, out_features=30267, bias=True)
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=768, bias=True)
  )
  (input_up_proj): Sequential(
    (0): Linear(in_features=128, out_features=768, bias=True)
    (1): Tanh()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_transformers): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768,

In [25]:
# Total number of scalar parameters in the model: every parameter tensor
# contributes its element count (trainable or not).
pytorch_total_params = sum(map(torch.numel, model.parameters()))
pytorch_total_params

91192379

In [None]:
# Build the timestep sampler from the `schedule_sampler` name chosen in the
# config cell instead of repeating the literal here. The sampler object gets its
# own name so the config string is not overwritten: the cell can then be re-run
# safely and `schedule_sampler` keeps meaning "the configured sampler name".
timestep_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

# Train with periodic evaluation on `val`; every other value comes from the
# config cell.
# NOTE(review): the captured log prints one "Epoch i/20000" loss line per
# iteration — presumably each "epoch" is a single optimisation step rather than
# a full pass over the 48,627 training rows; confirm in train.TrainLoop.
TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=batch_size,
    microbatch=microbatch,
    lr=lr,
    ema_rate=ema_rate,
    schedule_sampler=timestep_sampler,
    weight_decay=weight_decay,
    epochs=epochs,
    eval_data=val,
    eval_interval=eval_interval,
).run_loop()





Epoch 1/20000 Training Loss: 1.043076992034912
Epoch 2/20000 Training Loss: 0.695614218711853
Epoch 3/20000 Training Loss: 0.5766768455505371
Epoch 4/20000 Training Loss: 0.5150974988937378
Epoch 5/20000 Training Loss: 0.5483779907226562
Epoch 6/20000 Training Loss: 0.47229117155075073
Epoch 7/20000 Training Loss: 0.5219022035598755
Epoch 8/20000 Training Loss: 0.5533789396286011
Epoch 9/20000 Training Loss: 0.45337945222854614
Epoch 10/20000 Training Loss: 0.49771273136138916
Epoch 11/20000 Training Loss: 0.5337638854980469
Epoch 12/20000 Training Loss: 0.5432272553443909
Epoch 13/20000 Training Loss: 0.4817114472389221
Epoch 14/20000 Training Loss: 0.5254583358764648
Epoch 15/20000 Training Loss: 0.4739689230918884
Epoch 16/20000 Training Loss: 0.4796767830848694
Epoch 17/20000 Training Loss: 0.45504820346832275
Epoch 18/20000 Training Loss: 0.4943806529045105
Epoch 19/20000 Training Loss: 0.5340105891227722
Epoch 20/20000 Training Loss: 0.4675132632255554
Epoch 21/20000 Training

Epoch 162/20000 Training Loss: 0.30627384781837463
Epoch 163/20000 Training Loss: 0.3177127540111542
Epoch 164/20000 Training Loss: 0.31958019733428955
Epoch 165/20000 Training Loss: 0.3083963394165039
Epoch 166/20000 Training Loss: 0.2914397716522217
Epoch 167/20000 Training Loss: 0.3150452971458435
Epoch 168/20000 Training Loss: 0.28211510181427
Epoch 169/20000 Training Loss: 0.3398022949695587
Epoch 170/20000 Training Loss: 0.2901287078857422
Epoch 171/20000 Training Loss: 0.3351472318172455
Epoch 172/20000 Training Loss: 0.303388386964798
Epoch 173/20000 Training Loss: 0.30401650071144104
Epoch 174/20000 Training Loss: 0.32007336616516113
Epoch 175/20000 Training Loss: 0.27891016006469727
Epoch 176/20000 Training Loss: 0.27343934774398804
Epoch 177/20000 Training Loss: 0.30770373344421387
Epoch 178/20000 Training Loss: 0.329143226146698
Epoch 179/20000 Training Loss: 0.30214378237724304
Epoch 180/20000 Training Loss: 0.29184016585350037
Epoch 181/20000 Training Loss: 0.321536839008

Epoch 317/20000 Training Loss: 0.18392035365104675
Epoch 318/20000 Training Loss: 0.18830201029777527
Epoch 319/20000 Training Loss: 0.21389225125312805
Epoch 320/20000 Training Loss: 0.20672489702701569
Epoch 321/20000 Training Loss: 0.21182310581207275
Epoch 322/20000 Training Loss: 0.17571106553077698
Epoch 323/20000 Training Loss: 0.2023884803056717
Epoch 324/20000 Training Loss: 0.21646863222122192
Epoch 325/20000 Training Loss: 0.1732669174671173
Epoch 326/20000 Training Loss: 0.187344029545784
Epoch 327/20000 Training Loss: 0.22465357184410095
Epoch 328/20000 Training Loss: 0.20943115651607513
Epoch 329/20000 Training Loss: 0.1829974204301834
Epoch 330/20000 Training Loss: 0.1954069435596466
Epoch 331/20000 Training Loss: 0.21412129700183868
Epoch 332/20000 Training Loss: 0.18924666941165924
Epoch 333/20000 Training Loss: 0.18774829804897308
Epoch 334/20000 Training Loss: 0.2153984159231186
Epoch 335/20000 Training Loss: 0.20108318328857422
Epoch 336/20000 Training Loss: 0.19382

Epoch 475/20000 Training Loss: 0.12968255579471588
Epoch 476/20000 Training Loss: 0.12339361011981964
Epoch 477/20000 Training Loss: 0.13884003460407257
Epoch 478/20000 Training Loss: 0.1608664095401764
Epoch 479/20000 Training Loss: 0.1358514279127121
Epoch 480/20000 Training Loss: 0.14105093479156494
Epoch 481/20000 Training Loss: 0.1316950023174286
Epoch 482/20000 Training Loss: 0.14121533930301666
Epoch 483/20000 Training Loss: 0.15275990962982178
Epoch 484/20000 Training Loss: 0.16710670292377472
Epoch 485/20000 Training Loss: 0.1700659543275833
Epoch 486/20000 Training Loss: 0.12480179965496063
Epoch 487/20000 Training Loss: 0.14827370643615723
Epoch 488/20000 Training Loss: 0.13110977411270142
Epoch 489/20000 Training Loss: 0.13292726874351501
Epoch 490/20000 Training Loss: 0.17636848986148834
Epoch 491/20000 Training Loss: 0.17391487956047058
Epoch 492/20000 Training Loss: 0.09788326174020767
Epoch 493/20000 Training Loss: 0.13484010100364685
Epoch 494/20000 Training Loss: 0.14

Epoch 632/20000 Training Loss: 0.1010279655456543
Epoch 633/20000 Training Loss: 0.11987681686878204
Epoch 634/20000 Training Loss: 0.10534965991973877
Epoch 635/20000 Training Loss: 0.14259986579418182
Epoch 636/20000 Training Loss: 0.12641270458698273
Epoch 637/20000 Training Loss: 0.127445250749588
Epoch 638/20000 Training Loss: 0.11220333725214005
Epoch 639/20000 Training Loss: 0.10971663892269135
Epoch 640/20000 Training Loss: 0.09641699492931366
Epoch 641/20000 Training Loss: 0.10650657117366791
Epoch 642/20000 Training Loss: 0.11927051842212677
Epoch 643/20000 Training Loss: 0.12435507774353027
Epoch 644/20000 Training Loss: 0.08754831552505493
Epoch 645/20000 Training Loss: 0.12838169932365417
Epoch 646/20000 Training Loss: 0.1496594250202179
Epoch 647/20000 Training Loss: 0.11192537844181061
Epoch 648/20000 Training Loss: 0.10847268998622894
Epoch 649/20000 Training Loss: 0.11096890270709991
Epoch 650/20000 Training Loss: 0.11157144606113434
Epoch 651/20000 Training Loss: 0.12

Epoch 793/20000 Training Loss: 0.08786492049694061
Epoch 794/20000 Training Loss: 0.09161100536584854
Epoch 795/20000 Training Loss: 0.08229983597993851
Epoch 796/20000 Training Loss: 0.10622658580541611
Epoch 797/20000 Training Loss: 0.10951046645641327
Epoch 798/20000 Training Loss: 0.09384529292583466
Epoch 799/20000 Training Loss: 0.11721418052911758
Epoch 800/20000 Training Loss: 0.10137118399143219
Epoch 800/20000 Validation Loss: 0.0905928686261177
Cleared directory to save new best model.
Epoch 801/20000 Training Loss: 0.09442433714866638
Epoch 802/20000 Training Loss: 0.09516521543264389
Epoch 803/20000 Training Loss: 0.10264230519533157
Epoch 804/20000 Training Loss: 0.10624988377094269
Epoch 805/20000 Training Loss: 0.09270185977220535
Epoch 806/20000 Training Loss: 0.09365958720445633
Epoch 807/20000 Training Loss: 0.09702375531196594
Epoch 808/20000 Training Loss: 0.11520640552043915
Epoch 809/20000 Training Loss: 0.08674636483192444
Epoch 810/20000 Training Loss: 0.105238

Epoch 950/20000 Training Loss: 0.08332863450050354
Epoch 951/20000 Training Loss: 0.08750501275062561
Epoch 952/20000 Training Loss: 0.0918051153421402
Epoch 953/20000 Training Loss: 0.1018587052822113
Epoch 954/20000 Training Loss: 0.1264859139919281
Epoch 955/20000 Training Loss: 0.12421663850545883
Epoch 956/20000 Training Loss: 0.10111505538225174
Epoch 957/20000 Training Loss: 0.09336741268634796
Epoch 958/20000 Training Loss: 0.10152625292539597
Epoch 959/20000 Training Loss: 0.12286368757486343
Epoch 960/20000 Training Loss: 0.11165504157543182
Epoch 961/20000 Training Loss: 0.09047288447618484
Epoch 962/20000 Training Loss: 0.0961267501115799
Epoch 963/20000 Training Loss: 0.08218836039304733
Epoch 964/20000 Training Loss: 0.10462374985218048
Epoch 965/20000 Training Loss: 0.1401452273130417
Epoch 966/20000 Training Loss: 0.08815424889326096
Epoch 967/20000 Training Loss: 0.11544530093669891
Epoch 968/20000 Training Loss: 0.09293662011623383
Epoch 969/20000 Training Loss: 0.102

Epoch 1108/20000 Training Loss: 0.1146877110004425
Epoch 1109/20000 Training Loss: 0.08819548785686493
Epoch 1110/20000 Training Loss: 0.11670346558094025
Epoch 1111/20000 Training Loss: 0.08873412013053894
Epoch 1112/20000 Training Loss: 0.07790890336036682
Epoch 1113/20000 Training Loss: 0.09151363372802734
Epoch 1114/20000 Training Loss: 0.05908589065074921
Epoch 1115/20000 Training Loss: 0.07415015250444412
Epoch 1116/20000 Training Loss: 0.10542205721139908
Epoch 1117/20000 Training Loss: 0.10069163888692856
Epoch 1118/20000 Training Loss: 0.0715259313583374
Epoch 1119/20000 Training Loss: 0.1202135682106018
Epoch 1120/20000 Training Loss: 0.0877394825220108
Epoch 1121/20000 Training Loss: 0.10654184967279434
Epoch 1122/20000 Training Loss: 0.10269170254468918
Epoch 1123/20000 Training Loss: 0.08827228844165802
Epoch 1124/20000 Training Loss: 0.1141386330127716
Epoch 1125/20000 Training Loss: 0.08675198256969452
Epoch 1126/20000 Training Loss: 0.08964905142784119
Epoch 1127/20000 

Epoch 1266/20000 Training Loss: 0.08495499938726425
Epoch 1267/20000 Training Loss: 0.09869742393493652
Epoch 1268/20000 Training Loss: 0.07282255589962006
Epoch 1269/20000 Training Loss: 0.08334369957447052
Epoch 1270/20000 Training Loss: 0.09802514314651489
Epoch 1271/20000 Training Loss: 0.08664926886558533
Epoch 1272/20000 Training Loss: 0.10844458639621735
Epoch 1273/20000 Training Loss: 0.10471165925264359
Epoch 1274/20000 Training Loss: 0.08278769254684448
Epoch 1275/20000 Training Loss: 0.07470972836017609
Epoch 1276/20000 Training Loss: 0.08642743527889252
Epoch 1277/20000 Training Loss: 0.09081956744194031
Epoch 1278/20000 Training Loss: 0.09465885162353516
Epoch 1279/20000 Training Loss: 0.09097204357385635
Epoch 1280/20000 Training Loss: 0.08165188133716583
Epoch 1281/20000 Training Loss: 0.0957711935043335
Epoch 1282/20000 Training Loss: 0.11478348821401596
Epoch 1283/20000 Training Loss: 0.08372842520475388
Epoch 1284/20000 Training Loss: 0.09490227699279785
Epoch 1285/20

Epoch 1422/20000 Training Loss: 0.07568426430225372
Epoch 1423/20000 Training Loss: 0.1104271411895752
Epoch 1424/20000 Training Loss: 0.08180087804794312
Epoch 1425/20000 Training Loss: 0.08167336136102676
Epoch 1426/20000 Training Loss: 0.09346950054168701
Epoch 1427/20000 Training Loss: 0.09481096267700195
Epoch 1428/20000 Training Loss: 0.07061874866485596
Epoch 1429/20000 Training Loss: 0.062067583203315735
Epoch 1430/20000 Training Loss: 0.07248357683420181
Epoch 1431/20000 Training Loss: 0.09149667620658875
Epoch 1432/20000 Training Loss: 0.08243943005800247
Epoch 1433/20000 Training Loss: 0.0921563059091568
Epoch 1434/20000 Training Loss: 0.08191734552383423
Epoch 1435/20000 Training Loss: 0.07947133481502533
Epoch 1436/20000 Training Loss: 0.08479572087526321
Epoch 1437/20000 Training Loss: 0.07280604541301727
Epoch 1438/20000 Training Loss: 0.10093037784099579
Epoch 1439/20000 Training Loss: 0.055840469896793365
Epoch 1440/20000 Training Loss: 0.10556149482727051
Epoch 1441/2

Epoch 1577/20000 Training Loss: 0.08383882790803909
Epoch 1578/20000 Training Loss: 0.08128803968429565
Epoch 1579/20000 Training Loss: 0.06009459123015404
Epoch 1580/20000 Training Loss: 0.100278839468956
Epoch 1581/20000 Training Loss: 0.06866971403360367
Epoch 1582/20000 Training Loss: 0.06306487321853638
Epoch 1583/20000 Training Loss: 0.11055385321378708
Epoch 1584/20000 Training Loss: 0.09861545264720917
Epoch 1585/20000 Training Loss: 0.0857110247015953
Epoch 1586/20000 Training Loss: 0.09994504600763321
Epoch 1587/20000 Training Loss: 0.07851658761501312
Epoch 1588/20000 Training Loss: 0.07426037639379501
Epoch 1589/20000 Training Loss: 0.08015837520360947
Epoch 1590/20000 Training Loss: 0.07087844610214233
Epoch 1591/20000 Training Loss: 0.07841897010803223
Epoch 1592/20000 Training Loss: 0.0928104966878891
Epoch 1593/20000 Training Loss: 0.06841041892766953
Epoch 1594/20000 Training Loss: 0.08628237992525101
Epoch 1595/20000 Training Loss: 0.10040010511875153
Epoch 1596/20000

Epoch 1733/20000 Training Loss: 0.08568872511386871
Epoch 1734/20000 Training Loss: 0.10206127166748047
Epoch 1735/20000 Training Loss: 0.060659635812044144
Epoch 1736/20000 Training Loss: 0.08256913721561432
Epoch 1737/20000 Training Loss: 0.07454149425029755
Epoch 1738/20000 Training Loss: 0.061337150633335114
Epoch 1739/20000 Training Loss: 0.09022233635187149
Epoch 1740/20000 Training Loss: 0.08868386596441269
Epoch 1741/20000 Training Loss: 0.07725746929645538
Epoch 1742/20000 Training Loss: 0.08473315834999084
Epoch 1743/20000 Training Loss: 0.0896615982055664
Epoch 1744/20000 Training Loss: 0.08964432030916214
Epoch 1745/20000 Training Loss: 0.09365454316139221
Epoch 1746/20000 Training Loss: 0.09388140588998795
Epoch 1747/20000 Training Loss: 0.11099602282047272
Epoch 1748/20000 Training Loss: 0.07748281210660934
Epoch 1749/20000 Training Loss: 0.06678780168294907
Epoch 1750/20000 Training Loss: 0.08663241565227509
Epoch 1751/20000 Training Loss: 0.07631140947341919
Epoch 1752/

Epoch 1890/20000 Training Loss: 0.09861242771148682
Epoch 1891/20000 Training Loss: 0.07116355746984482
Epoch 1892/20000 Training Loss: 0.08977481722831726
Epoch 1893/20000 Training Loss: 0.08540067076683044
Epoch 1894/20000 Training Loss: 0.07832372933626175
Epoch 1895/20000 Training Loss: 0.08507226407527924
Epoch 1896/20000 Training Loss: 0.08151192963123322
Epoch 1897/20000 Training Loss: 0.0838562548160553
Epoch 1898/20000 Training Loss: 0.08605870604515076
Epoch 1899/20000 Training Loss: 0.11159716546535492
Epoch 1900/20000 Training Loss: 0.07803969830274582
Epoch 1900/20000 Validation Loss: 0.08743682503700256
Epoch 1901/20000 Training Loss: 0.09579437971115112
Epoch 1902/20000 Training Loss: 0.09747449308633804
Epoch 1903/20000 Training Loss: 0.08833235502243042
Epoch 1904/20000 Training Loss: 0.10407476872205734
Epoch 1905/20000 Training Loss: 0.084165558218956
Epoch 1906/20000 Training Loss: 0.07256816327571869
Epoch 1907/20000 Training Loss: 0.12155954539775848
Epoch 1908/20

Epoch 2046/20000 Training Loss: 0.07879950106143951
Epoch 2047/20000 Training Loss: 0.07796955108642578
Epoch 2048/20000 Training Loss: 0.11798219382762909
Epoch 2049/20000 Training Loss: 0.0770544558763504
Epoch 2050/20000 Training Loss: 0.09727540612220764
Epoch 2051/20000 Training Loss: 0.06396621465682983
Epoch 2052/20000 Training Loss: 0.09404433518648148
Epoch 2053/20000 Training Loss: 0.08687062561511993
Epoch 2054/20000 Training Loss: 0.07077071070671082
Epoch 2055/20000 Training Loss: 0.08221372961997986
Epoch 2056/20000 Training Loss: 0.08921487629413605
Epoch 2057/20000 Training Loss: 0.08923289179801941
Epoch 2058/20000 Training Loss: 0.08170360326766968
Epoch 2059/20000 Training Loss: 0.06588734686374664
Epoch 2060/20000 Training Loss: 0.06345497816801071
Epoch 2061/20000 Training Loss: 0.07909655570983887
Epoch 2062/20000 Training Loss: 0.07536319643259048
Epoch 2063/20000 Training Loss: 0.07974624633789062
Epoch 2064/20000 Training Loss: 0.06217201054096222
Epoch 2065/20

Epoch 2202/20000 Training Loss: 0.06274010241031647
Epoch 2203/20000 Training Loss: 0.08107832819223404
Epoch 2204/20000 Training Loss: 0.09842483699321747
Epoch 2205/20000 Training Loss: 0.06846629083156586
Epoch 2206/20000 Training Loss: 0.05910788103938103
Epoch 2207/20000 Training Loss: 0.07133172452449799
Epoch 2208/20000 Training Loss: 0.08874179422855377
Epoch 2209/20000 Training Loss: 0.07982142269611359
Epoch 2210/20000 Training Loss: 0.07424063980579376
Epoch 2211/20000 Training Loss: 0.06996425986289978
Epoch 2212/20000 Training Loss: 0.09096123278141022
Epoch 2213/20000 Training Loss: 0.07522755861282349
Epoch 2214/20000 Training Loss: 0.0872848704457283
Epoch 2215/20000 Training Loss: 0.07412390410900116
Epoch 2216/20000 Training Loss: 0.08695347607135773
Epoch 2217/20000 Training Loss: 0.061848923563957214
Epoch 2218/20000 Training Loss: 0.061143264174461365
Epoch 2219/20000 Training Loss: 0.06513939797878265
Epoch 2220/20000 Training Loss: 0.0708424299955368
Epoch 2221/2

Epoch 2359/20000 Training Loss: 0.06865246593952179
Epoch 2360/20000 Training Loss: 0.058037735521793365
Epoch 2361/20000 Training Loss: 0.05827707052230835
Epoch 2362/20000 Training Loss: 0.07726965844631195
Epoch 2363/20000 Training Loss: 0.08287332952022552
Epoch 2364/20000 Training Loss: 0.08331416547298431
Epoch 2365/20000 Training Loss: 0.08781581372022629
Epoch 2366/20000 Training Loss: 0.07761381566524506
Epoch 2367/20000 Training Loss: 0.07030050456523895
Epoch 2368/20000 Training Loss: 0.06683748960494995
Epoch 2369/20000 Training Loss: 0.07945474982261658
Epoch 2370/20000 Training Loss: 0.06853698939085007
Epoch 2371/20000 Training Loss: 0.0620146207511425
Epoch 2372/20000 Training Loss: 0.07593760639429092
Epoch 2373/20000 Training Loss: 0.09449358284473419
Epoch 2374/20000 Training Loss: 0.07456015050411224
Epoch 2375/20000 Training Loss: 0.09256531298160553
Epoch 2376/20000 Training Loss: 0.0784815326333046
Epoch 2377/20000 Training Loss: 0.0673675686120987
Epoch 2378/200

Epoch 2513/20000 Training Loss: 0.07532051205635071
Epoch 2514/20000 Training Loss: 0.061697497963905334
Epoch 2515/20000 Training Loss: 0.07946896553039551
Epoch 2516/20000 Training Loss: 0.05868782103061676
Epoch 2517/20000 Training Loss: 0.09975220263004303
Epoch 2518/20000 Training Loss: 0.07303659617900848
Epoch 2519/20000 Training Loss: 0.10069385915994644
Epoch 2520/20000 Training Loss: 0.08092302083969116
Epoch 2521/20000 Training Loss: 0.06781352311372757
Epoch 2522/20000 Training Loss: 0.06688983738422394
Epoch 2523/20000 Training Loss: 0.07762955129146576
Epoch 2524/20000 Training Loss: 0.0665007159113884
Epoch 2525/20000 Training Loss: 0.06892295181751251
Epoch 2526/20000 Training Loss: 0.06190106272697449
Epoch 2527/20000 Training Loss: 0.09497955441474915
Epoch 2528/20000 Training Loss: 0.073136106133461
Epoch 2529/20000 Training Loss: 0.053002938628196716
Epoch 2530/20000 Training Loss: 0.06895796209573746
Epoch 2531/20000 Training Loss: 0.06804942339658737
Epoch 2532/20

Epoch 2670/20000 Training Loss: 0.08298369497060776
Epoch 2671/20000 Training Loss: 0.07443197071552277
Epoch 2672/20000 Training Loss: 0.06360991299152374
Epoch 2673/20000 Training Loss: 0.06929680705070496
Epoch 2674/20000 Training Loss: 0.06721184402704239
Epoch 2675/20000 Training Loss: 0.0774388462305069
Epoch 2676/20000 Training Loss: 0.05566001683473587
Epoch 2677/20000 Training Loss: 0.08030150830745697
Epoch 2678/20000 Training Loss: 0.060757920145988464
Epoch 2679/20000 Training Loss: 0.06650128960609436
Epoch 2680/20000 Training Loss: 0.05427195876836777
Epoch 2681/20000 Training Loss: 0.07104133069515228
Epoch 2682/20000 Training Loss: 0.0573749840259552
Epoch 2683/20000 Training Loss: 0.09410031139850616
Epoch 2684/20000 Training Loss: 0.08207038789987564
Epoch 2685/20000 Training Loss: 0.0672587975859642
Epoch 2686/20000 Training Loss: 0.06449584662914276
Epoch 2687/20000 Training Loss: 0.07324369251728058
Epoch 2688/20000 Training Loss: 0.06492390483617783
Epoch 2689/200

Epoch 2824/20000 Training Loss: 0.08727893233299255
Epoch 2825/20000 Training Loss: 0.06636542081832886
Epoch 2826/20000 Training Loss: 0.08939534425735474
Epoch 2827/20000 Training Loss: 0.06241645663976669
Epoch 2828/20000 Training Loss: 0.06514113396406174
Epoch 2829/20000 Training Loss: 0.06254584342241287
Epoch 2830/20000 Training Loss: 0.07995587587356567
Epoch 2831/20000 Training Loss: 0.09688697755336761
Epoch 2832/20000 Training Loss: 0.07007359713315964
Epoch 2833/20000 Training Loss: 0.05670282244682312
Epoch 2834/20000 Training Loss: 0.06413434445858002
Epoch 2835/20000 Training Loss: 0.09072592854499817
Epoch 2836/20000 Training Loss: 0.08481740951538086
Epoch 2837/20000 Training Loss: 0.07930748164653778
Epoch 2838/20000 Training Loss: 0.09490732848644257
Epoch 2839/20000 Training Loss: 0.06769412755966187
Epoch 2840/20000 Training Loss: 0.07601150870323181
Epoch 2841/20000 Training Loss: 0.07153439521789551
Epoch 2842/20000 Training Loss: 0.09787572920322418
Epoch 2843/2

Epoch 2981/20000 Training Loss: 0.06931187212467194
Epoch 2982/20000 Training Loss: 0.06773526966571808
Epoch 2983/20000 Training Loss: 0.07823944091796875
Epoch 2984/20000 Training Loss: 0.06605172157287598
Epoch 2985/20000 Training Loss: 0.07803372293710709
Epoch 2986/20000 Training Loss: 0.09460270404815674
Epoch 2987/20000 Training Loss: 0.06843529641628265
Epoch 2988/20000 Training Loss: 0.06872037053108215
Epoch 2989/20000 Training Loss: 0.0818481594324112
Epoch 2990/20000 Training Loss: 0.07152757793664932
Epoch 2991/20000 Training Loss: 0.08001013845205307
Epoch 2992/20000 Training Loss: 0.07511716336011887
Epoch 2993/20000 Training Loss: 0.06642137467861176
Epoch 2994/20000 Training Loss: 0.07182809710502625
Epoch 2995/20000 Training Loss: 0.058072321116924286
Epoch 2996/20000 Training Loss: 0.07742352783679962
Epoch 2997/20000 Training Loss: 0.07854194939136505
Epoch 2998/20000 Training Loss: 0.05974597483873367
Epoch 2999/20000 Training Loss: 0.08207251876592636
Epoch 3000/2

Epoch 3137/20000 Training Loss: 0.05817835032939911
Epoch 3138/20000 Training Loss: 0.09741011261940002
Epoch 3139/20000 Training Loss: 0.07529369741678238
Epoch 3140/20000 Training Loss: 0.055924512445926666
Epoch 3141/20000 Training Loss: 0.06300758570432663
Epoch 3142/20000 Training Loss: 0.08440344035625458
Epoch 3143/20000 Training Loss: 0.07784033566713333
Epoch 3144/20000 Training Loss: 0.05606796592473984
Epoch 3145/20000 Training Loss: 0.06654930859804153
Epoch 3146/20000 Training Loss: 0.0475783608853817
Epoch 3147/20000 Training Loss: 0.09133085608482361
Epoch 3148/20000 Training Loss: 0.09298364818096161
Epoch 3149/20000 Training Loss: 0.06632695347070694
Epoch 3150/20000 Training Loss: 0.05978942662477493
Epoch 3151/20000 Training Loss: 0.07736018300056458
Epoch 3152/20000 Training Loss: 0.05693785101175308
Epoch 3153/20000 Training Loss: 0.07197047770023346
Epoch 3154/20000 Training Loss: 0.058994218707084656
Epoch 3155/20000 Training Loss: 0.0798799991607666
Epoch 3156/2

Epoch 3294/20000 Training Loss: 0.07834360003471375
Epoch 3295/20000 Training Loss: 0.060788653790950775
Epoch 3296/20000 Training Loss: 0.0884154736995697
Epoch 3297/20000 Training Loss: 0.0839645266532898
Epoch 3298/20000 Training Loss: 0.06538420915603638
Epoch 3299/20000 Training Loss: 0.05391278117895126
Epoch 3300/20000 Training Loss: 0.07726109027862549
Epoch 3300/20000 Validation Loss: 0.04974125325679779
Cleared directory to save new best model.
Epoch 3301/20000 Training Loss: 0.0648718997836113
Epoch 3302/20000 Training Loss: 0.07167263329029083
Epoch 3303/20000 Training Loss: 0.09131444990634918
Epoch 3304/20000 Training Loss: 0.06988278031349182
Epoch 3305/20000 Training Loss: 0.07338927686214447
Epoch 3306/20000 Training Loss: 0.05490923672914505
Epoch 3307/20000 Training Loss: 0.08390794694423676
Epoch 3308/20000 Training Loss: 0.06469352543354034
Epoch 3309/20000 Training Loss: 0.0751429870724678
Epoch 3310/20000 Training Loss: 0.07375295460224152
Epoch 3311/20000 Traini

Epoch 3448/20000 Training Loss: 0.059797272086143494
Epoch 3449/20000 Training Loss: 0.07354944199323654
Epoch 3450/20000 Training Loss: 0.06561142951250076
Epoch 3451/20000 Training Loss: 0.06804556399583817
Epoch 3452/20000 Training Loss: 0.09614618122577667
Epoch 3453/20000 Training Loss: 0.0660032331943512
Epoch 3454/20000 Training Loss: 0.0928645208477974
Epoch 3455/20000 Training Loss: 0.08071582019329071
Epoch 3456/20000 Training Loss: 0.08489927649497986
Epoch 3457/20000 Training Loss: 0.047013044357299805
Epoch 3458/20000 Training Loss: 0.07019171863794327
Epoch 3459/20000 Training Loss: 0.05949132889509201
Epoch 3460/20000 Training Loss: 0.06015308201313019
Epoch 3461/20000 Training Loss: 0.06496188789606094
Epoch 3462/20000 Training Loss: 0.06521962583065033
Epoch 3463/20000 Training Loss: 0.10224416851997375
Epoch 3464/20000 Training Loss: 0.07076326757669449
Epoch 3465/20000 Training Loss: 0.06341177225112915
Epoch 3466/20000 Training Loss: 0.07177804410457611
Epoch 3467/2

Epoch 3604/20000 Training Loss: 0.06958694756031036
Epoch 3605/20000 Training Loss: 0.0628698319196701
Epoch 3606/20000 Training Loss: 0.055520251393318176
Epoch 3607/20000 Training Loss: 0.06911913305521011
Epoch 3608/20000 Training Loss: 0.07754819840192795
Epoch 3609/20000 Training Loss: 0.03892862796783447
Epoch 3610/20000 Training Loss: 0.06908397376537323
Epoch 3611/20000 Training Loss: 0.06521175801753998
Epoch 3612/20000 Training Loss: 0.0899176076054573
Epoch 3613/20000 Training Loss: 0.05015973746776581
Epoch 3614/20000 Training Loss: 0.07533019036054611
Epoch 3615/20000 Training Loss: 0.06084330379962921
Epoch 3616/20000 Training Loss: 0.0693037211894989
Epoch 3617/20000 Training Loss: 0.07506394386291504
Epoch 3618/20000 Training Loss: 0.06427001953125
Epoch 3619/20000 Training Loss: 0.07019494473934174
Epoch 3620/20000 Training Loss: 0.08232870697975159
Epoch 3621/20000 Training Loss: 0.06128004193305969
Epoch 3622/20000 Training Loss: 0.06890078634023666
Epoch 3623/20000 

Epoch 3761/20000 Training Loss: 0.0623926967382431
Epoch 3762/20000 Training Loss: 0.09174452722072601
Epoch 3763/20000 Training Loss: 0.09837212413549423
Epoch 3764/20000 Training Loss: 0.07083550840616226
Epoch 3765/20000 Training Loss: 0.04941398650407791
Epoch 3766/20000 Training Loss: 0.07279811799526215
Epoch 3767/20000 Training Loss: 0.05593222379684448
Epoch 3768/20000 Training Loss: 0.07709196954965591
Epoch 3769/20000 Training Loss: 0.060547955334186554
Epoch 3770/20000 Training Loss: 0.08292153477668762
Epoch 3771/20000 Training Loss: 0.077231265604496
Epoch 3772/20000 Training Loss: 0.07166370749473572
Epoch 3773/20000 Training Loss: 0.06200101226568222
Epoch 3774/20000 Training Loss: 0.06350120902061462
Epoch 3775/20000 Training Loss: 0.058579884469509125
Epoch 3776/20000 Training Loss: 0.06242026761174202
Epoch 3777/20000 Training Loss: 0.052989792078733444
Epoch 3778/20000 Training Loss: 0.08255799859762192
Epoch 3779/20000 Training Loss: 0.07636690884828568
Epoch 3780/2

In [None]:
# Disabled: load a pickled "best" checkpoint (presumably the one saved by the
# training loop when it logs "Cleared directory to save new best model").
# NOTE(review): the sampling cell further down reads `best_model`, which is only
# defined by this commented-out code — on a fresh kernel that cell raises
# NameError. Re-enable this with the real checkpoint path before sampling.
# NOTE(review): the path below hard-codes an epoch/loss from one specific run and
# uses today's date as the directory, so it only resolves if that exact
# checkpoint was written today.
# NOTE(review): pickle.load can execute arbitrary code; only load checkpoints
# that you produced yourself.
# dt = datetime.now().strftime("%m%d")
# best_model_fp = f'models/{dt}/model_best_epoch_19200_min_val_loss_0.13920000195503235.pkl'
# with open(best_model_fp, 'rb') as handle:
#     best_model = pickle.load(handle)

# Generating sequences

In [None]:
# Generate text with the project's sampling routine on the 'test_custom' split;
# it returns three lists (source, recovered and reference word lists).
# NOTE(review): `best_model` must already exist in the kernel — no active cell
# in this notebook defines it.
# NOTE(review): seq_len=20 differs from the seq_len=128 used for training, and
# batch_size=10 is a literal rather than the config value — confirm both are
# intentional.
word_lst_source, word_lst_recover, word_lst_ref = sampling(
    best_model,
    diffusion,
    tokenizer,
    data_dir=regular_data_dir,
    batch_size=10,
    split='test_custom',
    seq_len=20,
)

Generating 20 sentences takes 5 minutes

In [None]:
word_lst_source

In [None]:
word_lst_recover

In [None]:
word_lst_ref