In [1]:
import argparse
import json
import sys

import torch
from torch import nn, optim

sys.path.append('..')
sys.path.append('.')

from my_data import VOCAB, MyDataset, color_print
from my_models import MyModel0
from my_utils import pred_to_dict

In [6]:
#argument for training
args = {
    "batch_size": 16,
    "max_epoch": 1500,
    "val_at": 100,
    "hidden_size": 256,
    "val_size": 76,
    "device": "cpu"
}


In [9]:
#setup torch device to cuda if available
if torch.cuda.is_available():
    args['device'] = 'cuda'

In [69]:
# declare model
model = MyModel0(len(VOCAB), 16, args['hidden_size']).to(args['device'])

In [144]:
dataset = MyDataset(dict_path="../data/data_dict4.pth", device="cuda", )

In [142]:
# print train/val data classes value
def print_data_class(_data: dict=None, key_name: str=None):
    for cls in range(1, 5):
        print('class', cls, end=': \t')
        for idx, cl in enumerate(_data[key_name][1]):
            if cl == cls:
                print(_data[key_name][0][idx], end='')
        print()

In [145]:
# display few train data
for key in list(dataset.train_dict.keys())[:2]:
    print('-'*10, key, '-'*10)
    print_data_class(_data=dataset.train_dict, key_name=key)
    print('='*30)

---------- X51007339651 ----------
class 1: 	AIK HUAT HARDWARE
ENTERPRISE (SETIA
ALAM) SDN BHD
class 2: 	29/12/2017
class 3: 	NO. 17-G, JALAN SETIA INDAH
(X) U13/X, SETIA ALAM,
SEKSYEN U13, 40170 SHAH ALAM,
class 4: 	7.00
---------- X51005568880 ----------
class 1: 	MR.D.I.Y(M)SDN BHD
class 2: 	19-09-17
class 3: 	LOT 1851-A & 1851-B, JALAN KPB 6,
KAWASAN PERINDUSTRIAN BALAKONG,
43300 SERI KEMBANGAN, SELANGOR
class 4: 	15.90


In [70]:
# TODO: criterion ??
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 1, 1.2, 0.8, 1.5], device=args['device']))
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.StepLR(optimizer, 1000)

In [72]:
# define a train method
def train(model, dataset, criterion, optimizer, epoch_range, batch_size):
    model.train()

    for epoch in range(*epoch_range):
        optimizer.zero_grad()

        text, truth = dataset.get_train_data(batch_size=batch_size)
        pred = model(text)

        loss = criterion(pred.view(-1, 5), truth.view(-1))
        loss.backward()

        optimizer.step()

        print(f"#{epoch:04d} | Loss: {loss.item():.4f}")

In [126]:
def validate(model, dataset, batch_size=1):
    model.eval()
    with torch.no_grad():
        keys, text, truth = dataset.get_val_data(batch_size=batch_size)

        oupt = model(text)
        prob = torch.nn.functional.softmax(oupt, dim=2)
        prob, pred = torch.max(prob, dim=2)

        prob = prob.cpu().numpy()
        pred = pred.cpu().numpy()

        for i, key in enumerate(keys):
            real_text, _ = dataset.val_dict[key]
            result = pred_to_dict(real_text, pred[:, i], prob[:, i])

            for k, v in result.items():
                print(f"{k:>8}: {v}")

#             color_print(real_text, pred[:, i])

In [99]:
# train for specified epochs
for i in range(args['max_epoch'] // args['val_at']):
    train(model, dataset, criterion, optimizer, (i * args['val_at'] + 1, (i + 1) * args['val_at'] + 1), args['batch_size'])

#0001 | Loss: 0.1004
#0002 | Loss: 0.0973
#0003 | Loss: 0.0681
#0004 | Loss: 0.0765
#0005 | Loss: 0.1816
#0006 | Loss: 0.1144
#0007 | Loss: 0.0994
#0008 | Loss: 0.1774
#0009 | Loss: 0.1238
#0010 | Loss: 0.2092
#0011 | Loss: 0.1528
#0012 | Loss: 0.1401
#0013 | Loss: 0.0767
#0014 | Loss: 0.0750
#0015 | Loss: 0.1315
#0016 | Loss: 0.1398
#0017 | Loss: 0.1239
#0018 | Loss: 0.1057
#0019 | Loss: 0.0923
#0020 | Loss: 0.1346
#0021 | Loss: 0.1271
#0022 | Loss: 0.0919
#0023 | Loss: 0.1023
#0024 | Loss: 0.1268
#0025 | Loss: 0.1191
#0026 | Loss: 0.1055
#0027 | Loss: 0.0926
#0028 | Loss: 0.1068
#0029 | Loss: 0.1090
#0030 | Loss: 0.0767
#0031 | Loss: 0.0749
#0032 | Loss: 0.0874
#0033 | Loss: 0.1847
#0034 | Loss: 0.1594
#0035 | Loss: 0.1002
#0036 | Loss: 0.1006
#0037 | Loss: 0.1079
#0038 | Loss: 0.1348
#0039 | Loss: 0.0977
#0040 | Loss: 0.1564
#0041 | Loss: 0.0751
#0042 | Loss: 0.0728
#0043 | Loss: 0.1058
#0044 | Loss: 0.0955
#0045 | Loss: 0.2152
#0046 | Loss: 0.1245
#0047 | Loss: 0.0715
#0048 | Loss:

#0392 | Loss: 0.0388
#0393 | Loss: 0.0395
#0394 | Loss: 0.0669
#0395 | Loss: 0.0478
#0396 | Loss: 0.0919
#0397 | Loss: 0.0252
#0398 | Loss: 0.0308
#0399 | Loss: 0.0278
#0400 | Loss: 0.0538
#0401 | Loss: 0.0509
#0402 | Loss: 0.0665
#0403 | Loss: 0.0839
#0404 | Loss: 0.0375
#0405 | Loss: 0.0298
#0406 | Loss: 0.1191
#0407 | Loss: 0.0368
#0408 | Loss: 0.1036
#0409 | Loss: 0.0415
#0410 | Loss: 0.0260
#0411 | Loss: 0.0630
#0412 | Loss: 0.0324
#0413 | Loss: 0.0552
#0414 | Loss: 0.0546
#0415 | Loss: 0.1527
#0416 | Loss: 0.0431
#0417 | Loss: 0.0267
#0418 | Loss: 0.0378
#0419 | Loss: 0.0638
#0420 | Loss: 0.0361
#0421 | Loss: 0.0474
#0422 | Loss: 0.0534
#0423 | Loss: 0.0460
#0424 | Loss: 0.0739
#0425 | Loss: 0.0446
#0426 | Loss: 0.0505
#0427 | Loss: 0.0641
#0428 | Loss: 0.0537
#0429 | Loss: 0.0641
#0430 | Loss: 0.0577
#0431 | Loss: 0.0495
#0432 | Loss: 0.0832
#0433 | Loss: 0.0380
#0434 | Loss: 0.0556
#0435 | Loss: 0.0484
#0436 | Loss: 0.0831
#0437 | Loss: 0.0361
#0438 | Loss: 0.0363
#0439 | Loss:

#0783 | Loss: 0.0759
#0784 | Loss: 0.0576
#0785 | Loss: 0.0876
#0786 | Loss: 0.0377
#0787 | Loss: 0.0185
#0788 | Loss: 0.0274
#0789 | Loss: 0.0355
#0790 | Loss: 0.0353
#0791 | Loss: 0.0200
#0792 | Loss: 0.0272
#0793 | Loss: 0.0556
#0794 | Loss: 0.0180
#0795 | Loss: 0.0608
#0796 | Loss: 0.0508
#0797 | Loss: 0.0377
#0798 | Loss: 0.0435
#0799 | Loss: 0.0296
#0800 | Loss: 0.0471
#0801 | Loss: 0.0250
#0802 | Loss: 0.0514
#0803 | Loss: 0.0418
#0804 | Loss: 0.0212
#0805 | Loss: 0.0319
#0806 | Loss: 0.0196
#0807 | Loss: 0.0504
#0808 | Loss: 0.0523
#0809 | Loss: 0.0464
#0810 | Loss: 0.0306
#0811 | Loss: 0.0404
#0812 | Loss: 0.0581
#0813 | Loss: 0.0279
#0814 | Loss: 0.0463
#0815 | Loss: 0.0428
#0816 | Loss: 0.0388
#0817 | Loss: 0.0383
#0818 | Loss: 0.0658
#0819 | Loss: 0.0271
#0820 | Loss: 0.0288
#0821 | Loss: 0.0516
#0822 | Loss: 0.0194
#0823 | Loss: 0.0202
#0824 | Loss: 0.0392
#0825 | Loss: 0.0339
#0826 | Loss: 0.0703
#0827 | Loss: 0.0394
#0828 | Loss: 0.0421
#0829 | Loss: 0.0225
#0830 | Loss:

#1174 | Loss: 0.0110
#1175 | Loss: 0.0123
#1176 | Loss: 0.0098
#1177 | Loss: 0.0193
#1178 | Loss: 0.0112
#1179 | Loss: 0.0108
#1180 | Loss: 0.0141
#1181 | Loss: 0.0150
#1182 | Loss: 0.0211
#1183 | Loss: 0.0124
#1184 | Loss: 0.0090
#1185 | Loss: 0.0261
#1186 | Loss: 0.0109
#1187 | Loss: 0.0136
#1188 | Loss: 0.0061
#1189 | Loss: 0.0118
#1190 | Loss: 0.0145
#1191 | Loss: 0.0089
#1192 | Loss: 0.0145
#1193 | Loss: 0.0072
#1194 | Loss: 0.0113
#1195 | Loss: 0.0096
#1196 | Loss: 0.0079
#1197 | Loss: 0.0180
#1198 | Loss: 0.0097
#1199 | Loss: 0.0256
#1200 | Loss: 0.0467
#1201 | Loss: 0.0119
#1202 | Loss: 0.0090
#1203 | Loss: 0.0140
#1204 | Loss: 0.0142
#1205 | Loss: 0.0155
#1206 | Loss: 0.0133
#1207 | Loss: 0.0173
#1208 | Loss: 0.0288
#1209 | Loss: 0.0088
#1210 | Loss: 0.0170
#1211 | Loss: 0.0232
#1212 | Loss: 0.0062
#1213 | Loss: 0.0201
#1214 | Loss: 0.0064
#1215 | Loss: 0.0083
#1216 | Loss: 0.0078
#1217 | Loss: 0.0085
#1218 | Loss: 0.0076
#1219 | Loss: 0.0111
#1220 | Loss: 0.0441
#1221 | Loss:

8888002188511	1	2.30	2.44	2.44	SR
MINUTE MAID PULPY [OREN] [350ML]
8888002188511	1	2.30	2.44	2.44	SR
MINUTE MAID PULPY [OREN] [350ML]
9556145026	1	1.51	1.60	1.60	SR
TINGE LEMON DRINK 500ML
955657031213	2.74	2.90	2.90	SR
100 PLUS [1.50]
TOTAL QTY:	25
32.69
TOTAL SALES (EXCLUDING GST) :	31.03
DISCOUNT :	0.00
TOTAL GST :
1.65
ROUNDING :	0.02
TOTAL SALES (INCLUSIVE OF GST) :
32.70
CASH :	40.00
CHANGE :
7.30
GST SUMMARY
TAX CODE	%	AMT (RM)
SR	6	TAX (RM)
27.48
ZRL	0	3.55	1.65
TOTAL :	31.03	0.00
1.65
GOODS SOLD ARE NOT RETURNABLE, THANK YOU.

 company: AEON CO. (M) BHD
    date: 18/04/2018
 address: 3RD FLR, AEON TAMAN MALURI SC JLN JEJAKA, TAMAN MALURI CHERAS, 55100 KUALA LUMPUR
   total: 
AEON CO. (M) BHD (126926-H)
3RD FLR, AEON TAMAN MALURI SC
JLN JEJAKA, TAMAN MALURI
CHERAS, 55100 KUALA LUMPUR
GST ID : 002017394688
SHOPPING HOURS
SUN-THU:1000 HRS - 2200 HRS
FRI-SAT:1000 HRS - 2300 HRS
VALUED CUSTOMER: 1170086176
1X 000000811101	2.50SR
TAMAGO (S)
DISC 30% @1.75	-0.75
SUB-TOTAL	1.75
TOTAL 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



F GST):	19.40
TOTAL :	19.40
CASH :	19.40
GST SUMMARY	AMOUNT(RM)	TAX(RM)
SR	(@ 6%)	18.31	1.09

 company: DION REALTIES SDN BHD
    date: 30/05/18
 address: MENARA DION #02-03, LEVEL 2, 27, JALAN SULTAN ISMAIL, 50250 KUALA LUMPUR.
   total: 
DION REALTIES SDN BHD (CO. NO:20154-T)
(GST REGISTRATION NO : 000650247680)
MENARA DION #02-03, LEVEL 2,
27, JALAN SULTAN ISMAIL,
50250 KUALA LUMPUR.
TEL : + 6 03 2026 6386
FAX : +6 03 2026 6387
TAX INVOICE
TAX INVOICE NO. 3521/0602/00602
30/05/18 11:01
010100 PAY PARKING TICKET	5.00 RM
30/05/18 10:45 - 30/05/18 11:01
LENGTH OF STAY: 0 DY. 0 HR. 16 MIN.
02992887002011018150387040??
AMOUNT INCL. GST	5.00 RM
ACCEPTAD TOTAL	5.00 RM
GST 6%	0.28 RM
THANK YOU

 company: YAM FRESH
    date: 2016-07-31
 address: NO.145G, JALAN RIMBUNAN RAYA 1, LAMAN RIMBUNAN KEPONG, 52100 KUALA LUMPUR
   total: 
YAM FRESH
NO.145G, JALAN RIMBUNAN RAYA 1,
LAMAN RIMBUNAN KEPONG,
52100 KUALA LUMPUR
TEL: (603) 6243 5520
GST ID=001817907200
INVOICE: 001-9822	DINE IN
DATE: 2016-07-

In [128]:
# validate with sample data (with unseen data)
validate(model, dataset, batch_size=10)

 company: CITY MILK R&C VENTURE SDN BHD
    date: 19-04-2018
 address: LOT-18A-2, BERJAYA TIMES SQUARE, KUALA LUMPUR.
   total: 
 company: AEON CO. (M) BHD
    date: 18/04/2018
 address: 3RD FLR, AEON TAMAN MALURI SC JLN JEJAKA, TAMAN MALURI CHERAS, 55100 KUALA LUMPUR
   total: 
 company: UNIHAKKA INTERNATIONAL SDN BHD
    date: 15 MAY 2018
 address: 12, JALAN TAMPOI 7/4,KAWASAN PERINDUSTRIAN TAMPOI,81200 JOHOR BAHRU,JOHOR
   total: 
 company: ASIA MART
    date: 22/12/2017
 address: NO.23 BATU 10, TAMAN SENTOSA, JALAN KAPAR, 42200 KLANG, SELANGOR.
   total: 
 company: SANYU STATIONERY SHOP
    date: 18/04/2017
 address: NO. 31G&33G, JALAN SETIA INDAH X ,U13/X 40170 SETIA ALAM
   total: 
 company: BAR CKLCC T.A.S LEISURE SDN BH
    date: 
 address: 
   total: 
 company: RESTORAN WAN SHENG
    date: 09-06-2018
 address: NO.2, JALAN TEMENGGUNG 19/9, SEKSYEN 9, BANDAR MAHKOTA CHERAS, 43200 CHERAS, SELANGOR
   total: 
 company: SANYU STATIONERY SHOP
    date: 01/11/2017
 address: NO. 31G&3

In [125]:
# display torch tensor data (train / val)
for d in _data[0]:
    print(VOCAB[d[1]], end='')

                              Y SOON FATT S/B (81497-P)
LOT 1504, BATU 8 1/2, JALAN KLANG LAMA,
46000 PETALING JAYA, SELANGOR.
TEL : 016-2014209
GST REG NO : 000788250624
TAX INVOICE
DESC	QTY	PRICE	AMOUNT	TAXCODE
(RM)	(RM)
51190030	A1 BIHUN ISTIMEWA 3KG
10 *	10.00	100.00 ZRL
060400063	YSF BUAH KERAS 3KG
1 *	36.80	36.80 SR
012700054	HOBE RED KIDNEY BEAN 425GM
1 *	58.00	58.00 SR
51580020	CAP LANG JAGUNG MANIS 425GM
1 *	53.80	53.80 SR
50460020	DAIRY CHAMP EVAP 390GM
1 *	113.90	113.90 SR
TOTAL INC. GST	6%:	362.50
ROUNDING ADJUSTMENT:	0.00
FINAL TOTAL	:	362.50
CASH	:	400.00
CHANGE	:	37.50
GST SUMMARY	AMOUNT(RM)	TAX(RM)
ZRL 0 %	100.00	0.00
SR 6 %	247.64	14.86
C02102772	9/2/2017 12:00 PM
T03	2006                                                                                                                                                                                                                                    

In [92]:
dataset.train_dict['X51007225442'][0][:22]

'MR. D.I.Y. (M) SDN BHD'

In [93]:
test_dict = torch.load('../data/test_dict.pth')

In [98]:
test_dict['X00016469670']

'TAN CHAY YEE\n*** COPY ***\nOJC MARKETING SDN BHD\nROC NO: 538358-H\nNO 2 & 4, JALAN BAYU 4,\nBANDAR SERI ALAM,\n81750 MASAI, JOHOR\nTEL:07-388 2218 FAX:07-388 8218\nEMAIL:NG@OJCGROUP.COM\nTAX INVOICE\nINVOICE NO\t: PEGIV-1030765\nDATE\t: 15/01/2019 11:05:16 AM\nCASHIER\t: NG CHUAN MIN\nSALES PERSON : FATIN\nBILL TO\t: THE PEAK QUARRY WORKS\nADDRESS\t:.\nDESCRIPTION\tQTY\tPRICE\tAMOUNT\n000000111\t1\t193.00\t193.00 SR\nKINGS SAFETY SHOES KWD B05\nQTY: 1\tTOTAL EXCLUDE GST:\t193.00\nTOTAL GST @6%:\t0.00\nTOTAL INCLUSIVE GST:\t193.00\nROUND AMT:\t0.00\nTOTAL:\t193.00\nVISA CARD\t193.00\nXXXXXXXXXXXX4318\nAPPROVAL CODE:000\nGOODS SOLD ARE NOT RETURNABLE & REFUNDABLE\n****THANK YOU. PLEASE COME AGAIN.****'

In [129]:
_data4 = torch.load('../data/data_dict4.pth')

{'X00016469612': ('TAN WOON YANN\nBOOK TA .K(TAMAN DAYA) SDN BND\n789417-W\nNO.53 55,57 & 59, JALAN SAGU 18,\nTAMAN DAYA,\n81100 JOHOR BAHRU,\nJOHOR.\nDOCUMENT NO : TD01167104\nDATE:\t25/12/2018 8:13:39 PM\nCASHIER:\tMANIS\nMEMBER:\nCASH BILL\nCODE/DESC\tPRICE\tDISC\tAMOUNT\nQTY\tRM\tRM\n9556939040116\tKF MODELLING CLAY KIDDY FISH\n1 PC\t*\t9.000\t0.00\t9.00\nTOTAL:\t9.00\nROUNDING ADJUSTMENT:\t0.00\nROUNDED TOTAL (RM):\t9.00\nCASH\t10.00\nCHANGE\t1.00\nGOODS SOLD ARE NOT RETURNABLE OR\nEXCHANGEABLE\n***\n***\nTHANK YOU\nPLEASE COME AGAIN !',
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0,
 

In [132]:
type(_data4)

dict

In [141]:
_data4['X00016469612'][1]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,