# ミニバッチ学習の実装

In [1]:
# Standard library
import math

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import LabelBinarizer

# Load the MNIST dataset through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Divide the pixel values by 255 and flatten every 28x28 image into one
# row of 784 features.
train = (X_train / 255).reshape(-1, 28 * 28)
test = (X_test / 255).reshape(-1, 28 * 28)

# One-hot encode the labels. The encoder is fitted on the training labels only
# and then reused for the test labels, so both splits share the same
# class-to-column mapping (re-fitting on the test labels could silently change
# the column layout if the test label set differed).
lb = LabelBinarizer()
train_labels = lb.fit_transform(y_train)
test_labels = lb.transform(y_test)

## ミニバッチ学習
* ミニバッチ学習は、一般的には非復元抽出によって行われることが多いが、必ずこうしなければならないというわけではなく、分析者がデータセットの与え方を工夫することもできる。ただし、工夫しても計算が上手くいくとは限らない。
* 工夫のしどころ。
    * 一般的には、エポックのたびにシャッフルするが、シャッフルするタイミングを任意に変えてみる  
    * 与えるミニバッチの順番を意図的に操作してみる
        * 例、出現頻度の少ないラベルのデータを先に学習させる
    * 抽出されるラベルの割合が一定になるように抽出してみる
    * 復元抽出にしてみる

In [2]:
def trainer(network, x, y):
    """
    Placeholder for the training step.

    This notebook is only about how mini-batches are drawn, so the body is
    intentionally left empty. A real implementation would put the actual
    training logic here.

    Args:
        network: The model to train (callers in this notebook pass a dummy).
        x: Mini-batch of input samples.
        y: Labels that belong to ``x``.

    Returns:
        None.
    """
    return None

### ミニバッチ学習のループ(復元抽出)

In [3]:
# Mini-batch loop that samples WITH replacement: every iteration draws a fresh,
# independent set of indices, so the same sample may show up in several batches
# (or more than once inside a single batch).
np.random.seed(1234)
train_size = len(train_labels)
batch_size = 32
max_iter = 10   # number of loop iterations
network = None  # dummy; no real model is needed for this demonstration

for i in range(max_iter):
    # replace=True is numpy's default; spelled out to make the scheme explicit.
    batch_mask = np.random.choice(train_size, size=batch_size, replace=True)
    print("i=%s, "%i, "batch_mask=%s"%batch_mask[:10])

    # Fancy-index the features and the labels with the same indices.
    x_batch, y_batch = train[batch_mask], train_labels[batch_mask]

    trainer(network, x_batch, y_batch)

i=0,  batch_mask=[27439 58067 34086 56373 23924 17048 55289 32399 55985 41239]
i=1,  batch_mask=[22267 21580 14629 12198 50682 46275 10983 23691 39552 21225]
i=2,  batch_mask=[29280 27532 28869 31741 49777  3039  8165 28319  8953 29692]
i=3,  batch_mask=[50926 18991  3062 43252 43382 58031 58255 51746    10  3356]
i=4,  batch_mask=[48917 55492 14455 34791 59642 43444 43456 20949  4468 43881]
i=5,  batch_mask=[12184 57310 45801 55823  6984  4994 52624   619 21634 19186]
i=6,  batch_mask=[27117 42059  4111 34580  4880 36288 26464 32382 26835 40249]
i=7,  batch_mask=[ 9624 10252 10700 18272  7688 13615 12057 51949 55061 35990]
i=8,  batch_mask=[49068 54819 35754 49556 43802 12633 59499 36759 32386 47848]
i=9,  batch_mask=[ 9532 31315 31088 29429 21129 37436 32946 35249 59498 17095]


### 復元抽出部分を理解するためのコード

In [4]:
# Take one with-replacement draw apart and look at each piece.
np.random.seed(1234)
batch_mask = np.random.choice(train_size, size=batch_size, replace=True)

# Number of distinct indices in the draw; a value smaller than batch_size
# would mean that at least one index was drawn more than once.
print(np.unique(batch_mask).size)
print("batch_mask=", batch_mask)
print()

# Select the rows of the features and of the labels with the same indices.
x_batch, y_batch = train[batch_mask], train_labels[batch_mask]

print("x_batch=", x_batch)
print("x_batch.shape=", x_batch.shape)
print()
print("y_batch=", y_batch)
print("y_batch.shape=", y_batch.shape)
print()

32
batch_mask= [27439 58067 34086 56373 23924 17048 55289 32399 55985 41239  9449 23706
  8222 32427 33950 40684  8060  7962 13686 59834 59512 14192  7644 27973
 27984 41929 51583 49398  2558 36271 38450  3824]

x_batch= [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
x_batch.shape= (32, 784)

y_batch= [[0 1 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 0]
 [1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [1 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0 0 0]
 [1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 0 0]
 [0 0 0 1 0 0

In [7]:
# Sampling with replacement: draw 3 values from 0..9. Run this cell a few
# times; the same value can appear more than once in a single draw.
np.random.choice(10, size=3, replace=True)

array([5, 4, 4])

### ミニバッチ学習のループ(非復元抽出)

### [演習]
* 以下の非復元抽出によるミニバッチ学習を完成させましょう。
* 通常の計算では、非復元抽出で行うことが多いです。

In [6]:
# Hint: building blocks for sampling without replacement
index = np.arange(10)
print("index=%s"%index)
np.random.seed(1234)
np.random.shuffle(index)  # shuffles the array in place
print("index=%s"%index)
print()
print(np.random.permutation(10))  # returns a new, shuffled array instead
print()

# Walk the shuffled array in consecutive slices of width 3; the last slice is
# shorter because 10 is not a multiple of 3.
for i in range(4):
    start = 3 * i
    print(index[start:start + 3])

print(*np.ceil([1.1, 1.7, 2.7]))  # ceil rounds up to the next integer

index=[0 1 2 3 4 5 6 7 8 9]
index=[7 2 9 1 0 8 4 5 6 3]

[7 3 5 1 4 8 0 2 6 9]

[7 2 9]
[1 0 8]
[4 5 6]
[3]
2.0 2.0 3.0


In [10]:
np.random.seed(1234)
train_size = train_labels.shape[0]
batch_size = 32
epochs = 10
network = None  # dummy; no real model is needed for this demonstration

# Number of mini-batches per epoch. Rounding up makes sure a final, partial
# batch (if there is one) is still visited.
minibatch_num = np.ceil(train_size / batch_size).astype(int)

for epoch in range(epochs):
    print()

    # Build a fresh index array and reshuffle it at the start of every epoch.
    index = np.arange(train_size)
    np.random.shuffle(index)

    # Sampling without replacement: step through the shuffled indices in
    # consecutive, non-overlapping slices, so each index is used exactly once
    # per epoch.
    for mn in range(minibatch_num):
        start = batch_size * mn
        batch_mask = index[start:start + batch_size]
        print("epoch=%s, "%epoch, "batch_mask=%s"%batch_mask[:10])

        x_batch, y_batch = train[batch_mask], train_labels[batch_mask]

        trainer(network, x_batch, y_batch)


epoch=0,  batch_mask=[30329 44957 30866 40447 25580  6216 26373  9010 23445   108]
epoch=0,  batch_mask=[13638 33368 49673 31890  4891 31171 44103 57851   496  8184]
epoch=0,  batch_mask=[29174 33167 13319 44233 53744 55496 56873 40631 44358 58144]
epoch=0,  batch_mask=[40541 56435 12662   831 17775 10047  6291 23045 31666  2987]
epoch=0,  batch_mask=[36278 23185  8140 22416 51789 26671 22337 46976 22482 30227]
epoch=0,  batch_mask=[25895 28676 38529 48020 13975 59520  1465  7905 57265 33004]
epoch=0,  batch_mask=[26568 20475 43172 32824 23933 14617 39765 20889 19788 48315]
epoch=0,  batch_mask=[54250  6397 17608 21736 59230 13391 44222  5637 18489  9823]
epoch=0,  batch_mask=[49311 12279  4107 10231 37346 25028 12961 17640 47765 22661]
epoch=0,  batch_mask=[15830 26793 32555 47856 25701 13929 18850 28542  8736 28550]
epoch=0,  batch_mask=[ 9604 56237 46403 36689 41305 53030  3272 13921   806  8014]
epoch=0,  batch_mask=[56428 33449 23359 55130 10066  8925  2652 54683 42575   173]
epo

epoch=0,  batch_mask=[48154 12063 18204 49801 45995 10389 49130 36099 46904 41187]
epoch=0,  batch_mask=[49406  4819 22008 37715  5906 15061 15387 10623  3349 44129]
epoch=0,  batch_mask=[28009 48303 18227 33893  8895  8173 58700 43366 44612 29058]
epoch=0,  batch_mask=[24354  2414 37997 10907 28772  9571 47515 59931 17378 36637]
epoch=0,  batch_mask=[58072 19129 31307  6188 59393 11867 35761 25227 31881 12975]
epoch=0,  batch_mask=[33634 41012 12558  2723 43143 48755  5981 25025 52307   671]
epoch=0,  batch_mask=[11952 55560 16024 13069 36839  8159  8255 11485 54609 50577]
epoch=0,  batch_mask=[25344 56072 39966 43433 57894 53882 59493 13962  9152 31782]
epoch=0,  batch_mask=[10316 26035 27621  1354 37497 14804 12372 28188  7955 22351]
epoch=0,  batch_mask=[15020 37445  4535  9783 42447  6036 57772 10161  6264 41388]
epoch=0,  batch_mask=[25109  2110 46468 34555 24335 21702  7077 56011 37707 30979]
epoch=0,  batch_mask=[39377 45514 10006 11613 22309 19290  6158  9400 29601  4808]
epoc

epoch=1,  batch_mask=[13242 11171 34445 57481 28137 47499 24995 22556 26955 27404]
epoch=1,  batch_mask=[42769 48359 26027 39318 34522 47002 21689 38310 14116 37615]
epoch=1,  batch_mask=[24703 11723 33539 11398 43149 38620 47432 30171 16529  6248]
epoch=1,  batch_mask=[ 5545  7406  4895  1782 38176 33778 28709 24294  1870 33010]
epoch=1,  batch_mask=[ 8245 50336 26763 31625 45087 22735 15252 56999 31917 17600]
epoch=1,  batch_mask=[22779 41849 55063 45264 52009 21612 18551 40273 58122 56869]
epoch=1,  batch_mask=[23351 47335  8803  2219 35071 13301 55966 37069 51910 45953]
epoch=1,  batch_mask=[19034 39349 17686 44091 55997 49649 27794 21574 37826  4484]
epoch=1,  batch_mask=[30684 51603  1776 14953  8475  7842 40350 16606 38189 47751]
epoch=1,  batch_mask=[37507 49139 56078 57329 59645 49462 16588 25189 24378  7003]
epoch=1,  batch_mask=[24306  3510 51539 35931 21006 48341 52928 10998 41373 33142]
epoch=1,  batch_mask=[40882 13862 24650 54847  1738 11193 31861 49329 37458  1891]
epoc

epoch=1,  batch_mask=[11905 30175 36692  4584 23153 19990 35044 19822 55481 17868]
epoch=1,  batch_mask=[11692 37509 39611 52109 22271 36379 58949 29664 44013 39664]
epoch=1,  batch_mask=[19537 25735   590 51296 14367 41276 24604 14958 42517  3030]
epoch=1,  batch_mask=[48997  2915 53661 42791 39598 16477 40949 44343  5821  9185]
epoch=1,  batch_mask=[56471  4333 34751 20131 42029 11938 51843 51265 28676  8732]
epoch=1,  batch_mask=[48491  3830  1128 32096 19876 41475 57885    73 38544 45385]
epoch=1,  batch_mask=[ 8672 55675 29814 20185 25442  5742 33818 38592 33989 26064]
epoch=1,  batch_mask=[20446 54203  2701 54172 59943 55996   629 46573  7085 33941]
epoch=1,  batch_mask=[47832 55499 14059 55074 18438 16202 18152 24939 58680 54149]
epoch=1,  batch_mask=[24593 16642 57246 53981 31292  3802 49615 35479 21764 46072]
epoch=1,  batch_mask=[40818 26666 47676 23059 11235 46697 47770 51781   211 49257]
epoch=1,  batch_mask=[48570 47472 21422 25947 43945   484 51462 27251 58695 52178]
epoc

epoch=2,  batch_mask=[53626 36706 54794 50432 33128 28258 34708 59738  8587 50780]
epoch=2,  batch_mask=[57332 10911  5553 25803 46067 15830 59677 21212  4448  5502]
epoch=2,  batch_mask=[33493 20800 35221 25449  5015  7798 58714 24652  8117  7599]
epoch=2,  batch_mask=[24160 59076 53440 31744 35511 18563 37098 43446 19544 47859]
epoch=2,  batch_mask=[47073 35175 46389 13074 31464 55529 52041 34482 14234 58128]
epoch=2,  batch_mask=[49819 52540 50991 29820 53751  2711 16284  7605 34113 13023]
epoch=2,  batch_mask=[28098 52909 12507 13572 13992 59781 32747 54984 32565 45815]
epoch=2,  batch_mask=[ 6182 11744 41072 54346 43645 28205  3460 37106 44491 35177]
epoch=2,  batch_mask=[32855 20941 42548 22286  9065   399 25837 20583  9669 12415]
epoch=2,  batch_mask=[45928  3547 29668 26198 14965 53137 13274 47620 28886 54991]
epoch=2,  batch_mask=[ 2311  2068 26943 50927 48715 45179 52549 48494 34230 27523]
epoch=2,  batch_mask=[ 9283 11900  7643 38907 39503 11719 25377 19724 52329 57771]
epoc

epoch=2,  batch_mask=[45171 13094 57360  5396 29355 21589 54907 46723 34264 41413]
epoch=2,  batch_mask=[18944 21111 57261 23566 40782 23007 30476 27773 54429 55190]
epoch=2,  batch_mask=[31081  8347 42906 28485 17753 15189 28788  4345 35656 59441]
epoch=2,  batch_mask=[39682  9126 22906 19520  5540 48484 17588 18969 20458 58660]
epoch=2,  batch_mask=[57319 50136  4016  6120 50449 40411 17191 18630 29651 59482]
epoch=2,  batch_mask=[ 6170 25292 25741 44807 48179 10191 23832 47836 41142 26170]
epoch=2,  batch_mask=[41712 24240 13760  7066 33702 29390 28983  1155  7091 12939]
epoch=2,  batch_mask=[30271 49571 24612 14737  5986 36184 46472 27293 46907 18804]
epoch=2,  batch_mask=[56562 29788 48143 53596 29112  1821 23653 37665 32615 55295]
epoch=2,  batch_mask=[ 9610  6305 58335 30121 58516 31482  8289 24802 45932 38872]
epoch=2,  batch_mask=[19095 13715  3265 39824 12267  2348 54787 12907 19411 48129]
epoch=2,  batch_mask=[40946 26803 54788   746 26701 32868 40631 56553  9310 45790]
epoc

epoch=3,  batch_mask=[24443 38267  6192 27114  5824 29829 54626 41045 19293 53970]
epoch=3,  batch_mask=[23715 41667 47576 43618  2679 11113 54253 55980 23517 27804]
epoch=3,  batch_mask=[ 7680 52715  5901 23674 40126 40097 43991 54461 25865 23749]
epoch=3,  batch_mask=[ 7998 29348 19518 33535 57444 25121 56783   894 59003   413]
epoch=3,  batch_mask=[31166  4125 32140 19436 56098  9258 24293 58796 48830  4638]
epoch=3,  batch_mask=[52924   837 51433  8308 53472 11903 43841  4142 26887 32316]
epoch=3,  batch_mask=[51145  8843 17866 19534 59101  3901 55439 56321 48852 36110]
epoch=3,  batch_mask=[42770 39791 57666 58613 37597 32473  9027 45766 51151  6971]
epoch=3,  batch_mask=[56652 22918 24922 22881 37395 39511 31214 57659 59208 22893]
epoch=3,  batch_mask=[58600 47153 17823 28375 44468 26278 36762 40034  7053  6509]
epoch=3,  batch_mask=[26530 31402 15199 45704 32270 59451 14369 24911 58683 46218]
epoch=3,  batch_mask=[47186 12239 51042  1788 27269  1387 44469 59739 41149 11549]
epoc

epoch=3,  batch_mask=[15242 16898 35801 35728 10962 17873 18572  6766 46804 19094]
epoch=3,  batch_mask=[  671 11684 43584 56412 20884  5837  1620 39690  6630  2288]
epoch=3,  batch_mask=[57490 18738 47886 31383 16365 29285 44065 27863 42848 55532]
epoch=3,  batch_mask=[49599 21450  9973 28489 14916    57 40590  9588  6015  5530]
epoch=3,  batch_mask=[53807  3718 26504 13169 24867 36926 39727 58658 10017 38489]
epoch=3,  batch_mask=[29312 17919 29176 42132 47640 40650 52840 24265 28364 20647]
epoch=3,  batch_mask=[ 9389 30705 35047 51172 10247 52762 50274 10872 29687  8645]
epoch=3,  batch_mask=[  283 27881 20445 46973  1847 32097 48361 30322 12390 32545]
epoch=3,  batch_mask=[49107 10348   376 43854 46153 42169 32304 24487 19283 17798]
epoch=3,  batch_mask=[44268  7520  2140 11907 27637 32826 38945 47049  7223 21303]
epoch=3,  batch_mask=[23002 51723 56815 35129 46013 40661 56428 42789 51819 28119]
epoch=3,  batch_mask=[21831 55739 34353 46954 53590 25048   317 17579 37694 50745]
epoc

epoch=4,  batch_mask=[13498  1007 36920 26755 34594  2314 44730 21390 59426 48920]
epoch=4,  batch_mask=[41757 27445 29502 13540 23145 39119 57005 50804 41389 52403]
epoch=4,  batch_mask=[22435 24833 14146 26946  8657 35747 29523 46623 41719 32703]
epoch=4,  batch_mask=[52819 18424 31041 16794  4274   892 29323 59507  4330 30712]
epoch=4,  batch_mask=[21007 48295 32102 24302 24543 36373 54464 42285 23112 47663]
epoch=4,  batch_mask=[ 6212 54603 48380 52143 55055 41365 49774 15620 39980 26327]
epoch=4,  batch_mask=[12254 33308 33283 37299 24874 33171 12446 28806 41613 19396]
epoch=4,  batch_mask=[24137 53015  2746 57289 44553 57860 53665 12319 50389  5250]
epoch=4,  batch_mask=[33532 45773 32952 40958  2542 43498 57274 49575 19569 24765]
epoch=4,  batch_mask=[ 4463  3682 38879  1161 33105   760  9525 29869 16697 21135]
epoch=4,  batch_mask=[14980 19313 47144 55423 25270 19091  4081 16152 18390 27985]
epoch=4,  batch_mask=[14348 18653 25561 29514 48392 30769 25000 41893 17137  6570]
epoc

epoch=4,  batch_mask=[22585   646 50498 47602 58600 43629 47593 26398 30316 35948]
epoch=4,  batch_mask=[ 6997 58657 27072 12726  3689 43505 18172 22906  3097 34232]
epoch=4,  batch_mask=[51895 41537 14624  5249 18232 19135 48587   913 12480 47267]
epoch=4,  batch_mask=[27075 24041 24640 14937 48839 16067 17143  8158 35239 13415]
epoch=4,  batch_mask=[45930 45810  2078 41767 34611 10454 16641 51241 53913 41992]
epoch=4,  batch_mask=[ 3925 14300 58641 31850   645 36916 55491 20705 35186  1279]
epoch=4,  batch_mask=[46423   776 39471 19530 31206  5449 12127 49462 39999 36060]
epoch=4,  batch_mask=[41080 37384 11793 38712 56317 30664 20730 19127 33438 33614]
epoch=4,  batch_mask=[13573 54726 59212   582 13706 43154 48533 36186 37169 34070]
epoch=4,  batch_mask=[ 6029 42032 21458 28598 34897 59674 24544 30170 54452 52274]
epoch=4,  batch_mask=[48305 58302 57806 26963 27307 23286 25018  9242 39945 38839]
epoch=4,  batch_mask=[37903 21527 54129 42165 21691 31458 24574 28564 16124 43093]
epoc

epoch=5,  batch_mask=[40175 54551 26637 26362  8889  6649 18340   443 21544 33275]
epoch=5,  batch_mask=[15576 26692 53799 14520 10011 32964 13217  9014 29408 21709]
epoch=5,  batch_mask=[ 6238  7686  5174 28382 29822 46378  4822 13324 52430   418]
epoch=5,  batch_mask=[35868 29575 17526 43540 16347 39373 57448 57700 33710 37509]
epoch=5,  batch_mask=[58551 52597  8676 50579  5906 47800 30227 47219  1681 56185]
epoch=5,  batch_mask=[ 1309 10561 23523 26290 59814 20225 45633 47305  5112 18644]
epoch=5,  batch_mask=[ 6343  3392 32086 33803 24740 45195 54711 59568  9841 55955]
epoch=5,  batch_mask=[40915 59058 59992 13238 41028 27150  3042 58597 48374 36186]
epoch=5,  batch_mask=[31975 24204 52463 20190 11063 29021 53904 30432 23747 30692]
epoch=5,  batch_mask=[30581 59121 33532 22391  4872 49449 57996 54380 19318  5652]
epoch=5,  batch_mask=[29432 18727 13305  6248 30898 51361 42650 16324 26824  2408]
epoch=5,  batch_mask=[12982 48405 20128 35854 37337 55898 26385 53792 40806 46223]
epoc

epoch=5,  batch_mask=[46878 52247  9047 45668 45334 35455 54558 28229  1118 32887]
epoch=5,  batch_mask=[35702 41486 51419 19719 24141 11837 37653 50073  1656  3214]
epoch=5,  batch_mask=[ 7315 20499 45534 44788 44549 23812  6982 37557 23284 48969]
epoch=5,  batch_mask=[10862 54813 35213 26040 19181 56505 22771 56653 49088 17894]
epoch=5,  batch_mask=[59500 31777 33366  6882  3181 58217 20541 49062 18272 29419]
epoch=5,  batch_mask=[16301 36261 12164 32876 28201 16107 30991 38637 35598  8937]
epoch=5,  batch_mask=[ 2616 53783 42011 58036 19788  4329 22159 13857 46819  7792]
epoch=5,  batch_mask=[49636 30753 16385 43839 17513 46181 54351 48674 39731  4713]
epoch=5,  batch_mask=[50363  4211 18786  6602 30969 35203 22713 28957 55209 11616]
epoch=5,  batch_mask=[22830 43366 15525  9404 30586 50859 20444 36427  6604  9495]
epoch=5,  batch_mask=[29869 39917 45728 40961 32058  9776 10673 19310 25479 19690]
epoch=5,  batch_mask=[18282 35305 35911 58059 57146 59316 14049 29760 59558  8883]
epoc

epoch=6,  batch_mask=[29651 21846  2651 14845 54833 57437 17547 21563 10183 27425]
epoch=6,  batch_mask=[28828 21193 23506 17966 56271 15124 46989 48607 34838 21439]
epoch=6,  batch_mask=[47970 22208 27546 47529 27433 39087 12990 52210 35079 30805]
epoch=6,  batch_mask=[13496  6728 39238 42166 57368 25928  7065 56778 43359  9881]
epoch=6,  batch_mask=[37020 22752 48463 36957  5023 11779 46321 48427  3474 44007]
epoch=6,  batch_mask=[ 1048  8663 30954 33915 43277 48814 16975 38418 53593 14070]
epoch=6,  batch_mask=[51930 32879 12899 37403 55796 44111 12430  2389 28217  2042]
epoch=6,  batch_mask=[37583 24362 22701 12881 29905 38771 51085 24604 13430 10675]
epoch=6,  batch_mask=[ 8044   376 16172 31573 19799 25048 30437 55814 46165 26304]
epoch=6,  batch_mask=[ 5429 38212 13091 17913 14325 39405 19540 52247  2964 58037]
epoch=6,  batch_mask=[21859 24090 32912 19077  1025  9944  8642 59774 11103    58]
epoch=6,  batch_mask=[10938 17318 46233 56355  7254  1849  4052  4725 10113 31794]
epoc

epoch=6,  batch_mask=[58566 54250   486 40138 34861 39888 30930  4351 29592 19677]
epoch=6,  batch_mask=[11413   925 37864 40941 49528 52727 49522 32917  7487 37120]
epoch=6,  batch_mask=[45034 56614 18186 29462 52320 59464 20438  6342 18983 18170]
epoch=6,  batch_mask=[33406 28716  5484 48582 46719 51266 51779 17305  4529  8909]
epoch=6,  batch_mask=[25773 36811 39634 41695  3422  3372 17155 27739  6858 48910]
epoch=6,  batch_mask=[10718 20059 49266 15228 18980 19416 42721 32127 18856 23651]
epoch=6,  batch_mask=[13806 49354 23400 52254 13441 28844 31737 39618 45993 55995]
epoch=6,  batch_mask=[29416 18780 59125 17382   322  7320 39467 43190 11088 33549]
epoch=6,  batch_mask=[23061 31603 58528 44878  5938 12665 37471 27393 46062 31663]
epoch=6,  batch_mask=[22641 23534 35179 14556 29726  4360  5588 21225 41198 46488]
epoch=6,  batch_mask=[32209 25698 23254 25969 28580  7154 39461 18018  8609 51337]
epoch=6,  batch_mask=[45226 33439 25415 28366 48130  9171 35999 43104 23231 38385]
epoc

epoch=7,  batch_mask=[21578 43207 12536 21888 35555 49615 17451 49930 33723 43702]
epoch=7,  batch_mask=[55865 28017 46700 40277 36312  1532 20388 15306 58934  1614]
epoch=7,  batch_mask=[49026 39307 46542 57174 13243 20157 11650  4645 20957 47695]
epoch=7,  batch_mask=[25731  4789 34049 16373 12026 29695 58908 40015 14506 30279]
epoch=7,  batch_mask=[41075 52440 47551  4035  1007 21986 30893  1617 42682 10000]
epoch=7,  batch_mask=[29832  9993 30696 16900 18018 49963 25797 34222  5048 50541]
epoch=7,  batch_mask=[16604 16714 12409 12215 48036  1730 16114 29264 44358 28719]
epoch=7,  batch_mask=[28649 40600 58534 31380 30591  5882 29295 43085 45564  8759]
epoch=7,  batch_mask=[26258 23377  3433 19889   542 24792 25889 53406 38924 12569]
epoch=7,  batch_mask=[50271 56243 50045 30158  6441 53941 10710 47757 49118 54085]
epoch=7,  batch_mask=[50631 22537 48410 39421 34406  9890 59337 21155  7749 37587]
epoch=7,  batch_mask=[32536  5590 39773 10301 17443 29735 37665 48601 26313 47133]
epoc

epoch=7,  batch_mask=[ 7755 25924 44913 47036 23807 52923 50195 11312 40477 52963]
epoch=7,  batch_mask=[ 1344 13755 45267 46501 20197 22793 27359 21602 11742 55725]
epoch=7,  batch_mask=[37419 43170 22084 21832  2907 53303 19705 44903 44325 18744]
epoch=7,  batch_mask=[29902 52437   327 55242 46714 52005 30504 32047 31391 12994]
epoch=7,  batch_mask=[20125  9720 58011 33453 18682   670 53159 22342 30600  8139]
epoch=7,  batch_mask=[ 7190 42064 35549  5609  6918 53630 12852 52158 28729 28260]
epoch=7,  batch_mask=[ 5083 35823  9713  8114 42864 54749 27138 22159 16096  7432]
epoch=7,  batch_mask=[28909 15547  1226 20333 39737 28628 33987 16116  1394 19542]
epoch=7,  batch_mask=[38752 11140 34178 56874 20628 40659 40988 59963 37207 34730]
epoch=7,  batch_mask=[31753  4263 29944 38623 18718 48236 52954 53136 38509   600]
epoch=7,  batch_mask=[31545 53405 18570 19763  8761 14924 16014 10299 35503  1268]
epoch=7,  batch_mask=[11946 37233 49696 27576 53847 33381  7849  9943   755 49135]
epoc

epoch=8,  batch_mask=[45305 10690  7910 29010 52523  5684 37596  4026 32965 47553]
epoch=8,  batch_mask=[20579 34769 56246 16436 20312  1704 23434   224 37629 36434]
epoch=8,  batch_mask=[39783 40506 43717 11525 33908 35092 25441 32005 53631 44484]
epoch=8,  batch_mask=[34573 35537 36464 21260 38850 19781 34204  9942 32929 23854]
epoch=8,  batch_mask=[ 8978 37846 17777 39970 39791 25763 28254 55311 20313 44029]
epoch=8,  batch_mask=[12209  1291 25585 55297 41011 53011 47527 55100 11231 20490]
epoch=8,  batch_mask=[12097  7585 21253 54163 13535 49139 23660 31682 29719 15745]
epoch=8,  batch_mask=[34143 33884 27568  9796 30059 57192 46617  6024 45443 50535]
epoch=8,  batch_mask=[53345 47085 47727 29732 32576 20051 17711 23481 39712 27600]
epoch=8,  batch_mask=[59164 54002 16108 29947 10543 13026 58206 32845  5839 43009]
epoch=8,  batch_mask=[25099 38168 47790 10037  5767 59909 41598 10882 52627 55274]
epoch=8,  batch_mask=[17624 47425 33172 41834 28312 17122 47275 57379 39423 37728]
epoc

epoch=8,  batch_mask=[37215 50042  9758 30377 17677 41175 41211 23388 41447  5888]
epoch=8,  batch_mask=[15407 44389   530  1759  2003  2585 13045 52805 40406   422]
epoch=8,  batch_mask=[10706 36562  2556  5853 34549 40201  4935 30672 45991 52515]
epoch=8,  batch_mask=[ 1254 45790 11063 25877  9122 23160 48285  5127 51652 43444]
epoch=8,  batch_mask=[ 7951  8571  3904   943 46159 29759 25627 52232 56840 44572]
epoch=8,  batch_mask=[22459 30222 52719 31363 25843 35510 55949 20678 32216 12458]
epoch=8,  batch_mask=[  611 21483 14546 27388 41959 56466 28525  7427 48474 47673]
epoch=8,  batch_mask=[ 4746 56053 41848 33566 24440 37412 33770 27341   231 24588]
epoch=8,  batch_mask=[  391 16587 42570  1670 21336 59266  1387 55372 12501  8885]
epoch=8,  batch_mask=[51668 20469 11275 50027  6194 24351 54750 55873 54851  1991]
epoch=8,  batch_mask=[55089 28691 33268 41546 26358 25020 26532 22512 43030 17436]
epoch=8,  batch_mask=[22860 47832 52948 50588 11913 35734  3084 23083 35867 34008]
epoc

epoch=9,  batch_mask=[10209  1857  1897  8214 25908 12648 39243  2503 33935 59836]
epoch=9,  batch_mask=[42635 26856 39286 27464 24962 38897 15206 53646 59133 44573]
epoch=9,  batch_mask=[  717 12215 53261  4671 44027 55960 18628 19636 23440    19]
epoch=9,  batch_mask=[11032  5846 29913 53812 51522 35342 29254 41622 36482 21622]
epoch=9,  batch_mask=[ 4791 43569 27635 57610  6597 37298 17632 19609  1275   206]
epoch=9,  batch_mask=[57930 18325 44934 44933 50733 59120 43340 13041 57557 40639]
epoch=9,  batch_mask=[25457 16947 39295 56464 57330 39056 36064 13321 37901 49412]
epoch=9,  batch_mask=[57335 56937  6987  3445  2768  8342 53037 29412 24590 21446]
epoch=9,  batch_mask=[37281  3664 22300 57148 58564   220  6915 29384 31300 19014]
epoch=9,  batch_mask=[11954 33771   752 27604 38322 17066 20167 31647  1495 21359]
epoch=9,  batch_mask=[ 6682 26534 31261 24949 21234 47339 55031 46614 14307 34458]
epoch=9,  batch_mask=[ 4661 39107 24003 57131  8600 35398 51022 39790 37502  7207]
epoc

epoch=9,  batch_mask=[16783 31101 20283 57991 59299 11950 10374 59743 14839 40227]
epoch=9,  batch_mask=[13663 30326 55205 26724  9636 33738 51588 55109 42658 24037]
epoch=9,  batch_mask=[ 8003 26398 18869 51622 57106 48240 36475 45324 41776 57229]
epoch=9,  batch_mask=[55453 43927 52919 28529 49854 54188 26734 43227 13392 42663]
epoch=9,  batch_mask=[27305 52155 13414 46946 16287 59982 59923 44204 21648 59360]
epoch=9,  batch_mask=[15607 26387 53997 50627 34396  5132 53484 15416  5668 43532]
epoch=9,  batch_mask=[58277  8625  9156 25664 43813 53228 56707 25918 11030 45546]
epoch=9,  batch_mask=[16954 57818 35794 49480  8228 50195 32424  2296 58824 50396]
epoch=9,  batch_mask=[36079 19412 53846 16875 42007 15651 45446 20358 56902 11653]
epoch=9,  batch_mask=[32872 40711 25031  8681 19610 48987 15199 36237 57015 40898]
epoch=9,  batch_mask=[38699  1047 15319 19167   397 48935 16255 47037 28510 53320]
epoch=9,  batch_mask=[25000  2474 18543 41040 42173   438 45208 59320 44552 51168]
epoc