In [1]:
import argparse
import json
import os
import random
import time, math
import torch
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np
from warpctc_pytorch import CTCLoss
from collections import OrderedDict
import pandas as pd

from data.data_loader_self import AudioDataLoader, SpectrogramDataset, BucketingSampler
from data.data_loader import get_accents
from decoder import GreedyDecoder
from model import DeepSpeech, supported_rnns, ForgetNet, Encoder, Decoder, DiscimnateNet
from utils import reduce_tensor, check_loss, Decoder_loss

import easydict

In [2]:
args = easydict.EasyDict({
    # --- data -------------------------------------------------------------
    'train_manifest' : './data/csvs/dataset_with_age_train_95_180.csv',
    'val_manifest' : './data/csvs/dataset_with_age_val_5_180.csv',
    'sample_rate' : 16000,
    # --- spectrogram front-end ---------------------------------------------
    'window_size' : .02, 'window_stride' : .01, 'window' : 'hamming',
    # --- ASR (DeepSpeech predictor) architecture -----------------------------
    'hidden_size' : 1024, 'hidden_layers' : 5, 'rnn_type' : 'gru',
    # --- optimisation / schedule ---------------------------------------------
    'epochs' : 5, 'batch_size' : 16, 'num_workers' : 4,
    'patience' : 10,  # epochs without CER improvement before stopping
    'cuda' : True,
    'lr' : 0.001, 'momentum' : 0.9, 'max_norm' : 400, 'learning_anneal' : 1.1,
    'silent' : True,
    'checkpoint' : False, 'checkpoint_per_batch' : 500,
    'visdom' : True, 'tensorboard' : False,
    'log_dir' : './visualize/deepspeech_final', 'log_params' : True,
    'id' : 'Deepspeech training',
    # NOTE(review): absolute machine-specific path; set to '' to build fresh models.
    'continue_from' : '/home/Data/etc/Robust_ASR/exp/0106_dawn_hacka/models/ckpt_2_9114.pth', 'finetune' : True,
    # --- augmentation --------------------------------------------------------
    'augment' : True,
    'noise_dir' : None, 'noise_prob' : 0.4, 'noise_min' : 0.0, 'noise_max' : 0.5,
    'no_shuffle' : False,
    'no_sorta_grad' : False,
    'bidirectional' : True,
    'spec_augment' : True,
    # --- distributed (host:port; the previous value had a stray '.' after the IP)
    'dist_url' : 'tcp://127.0.0.1:1550', 'dist_backend' : 'nccl',
    'world_size' : 1,
    'rank' : 0,
    # --- auxiliary networks: encoder / discriminator / forget net ------------
    'enco_modules' : 2, 'enco_res' : True,
    'disc_modules' : 1, 'disc_res' : False,
    'forg_modules' : 2, 'forg_res' : True,
    'seed' : 123456,
    'opt_level' : '',
    'keep_batchnorm_fp32' : None,
    'loss_scale' : None,
    'weights' : ' ',
    # Max number of discriminator updates per predictor update (sampled per iteration).
    'update_rule' : 2,
    'train_asr' : False,  # True: train only encoder/decoder/ASR, no adversarial part
    'dummy' : False,      # True: stop each loop after the first batch (smoke test)
    # Number of epochs over which the discriminator-update distribution is annealed.
    'num_epochs' : 1,
    # Loss weights: decoder reconstruction (alpha), adversarial (beta), mask regulariser (gamma).
    'mw_alpha' : 0.1, 'mw_beta' : 0.2, 'mw_gamma' : 0.6,
    'exp_name' : './exp/0106_dawn_hacka/',
})

# Seed every RNG the notebook uses (torch CPU/GPU, numpy, python `random`).
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# GPU setting
device = torch.device("cuda" if args.cuda else "cpu")

# Where to save the models and the training metadata.
save_folder = os.path.join(args.exp_name, 'models')
tbd_logs = os.path.join(args.exp_name, 'tbd_logdir')
loss_save = os.path.join(args.exp_name, 'train.log')
config_save = os.path.join(args.exp_name, 'config.json')
os.makedirs(save_folder, exist_ok=True)  # also creates args.exp_name, where config.json is written

# Save the experiment configuration.
with open(config_save, 'w') as f:
    json.dump(args.__dict__, f, indent=2)

# Training state; the model-setup cell may overwrite some of it when resuming from a checkpoint.
wer_results = torch.zeros(args.epochs)  # torch.Tensor(n) would hold uninitialised memory
best_wer, best_cer = None, None
d_avg_loss, p_avg_loss, p_d_avg_loss, start_epoch = 0, 0, 0, 0
poor_cer_list = []
eps = 0.0000000001  # epsilon used to guard divisions by zero
start_iter = 0
# Accent label -> class index, built from the training manifest.
accent_dict = get_accents(args.train_manifest)
accent = list(accent_dict.values())
labels = ' _가각간갇갈갉갊감갑값갓갔강갖갗같갚갛개객갠갤갬갭갯갰갱갸갹갼걀걋걍걔걘걜거걱건걷걸걺검겁것겄겅겆겉겊겋게겐겔겜겝겟겠겡겨격겪견겯결겸겹겻겼경곁계곈곌곕곗고곡곤곧골곪곬곯곰곱곳공곶과곽관괄괆괌괍괏광괘괜괠괩괬괭괴괵괸괼굄굅굇굉교굔굘굡굣구국군굳굴굵굶굻굼굽굿궁궂궈궉권궐궜궝궤궷귀귁귄귈귐귑귓규균귤그극근귿글긁금급긋긍긔기긱긴긷길긺김깁깃깅깆깊까깍깎깐깔깖깜깝깟깠깡깥깨깩깬깰깸깹깻깼깽꺄꺅꺌꺼꺽꺾껀껄껌껍껏껐껑께껙껜껨껫껭껴껸껼꼇꼈꼍꼐꼬꼭꼰꼲꼴꼼꼽꼿꽁꽂꽃꽈꽉꽐꽜꽝꽤꽥꽹꾀꾄꾈꾐꾑꾕꾜꾸꾹꾼꿀꿇꿈꿉꿋꿍꿎꿔꿜꿨꿩꿰꿱꿴꿸뀀뀁뀄뀌뀐뀔뀜뀝뀨끄끅끈끊끌끎끓끔끕끗끙끝끼끽낀낄낌낍낏낑나낙낚난낟날낡낢남납낫났낭낮낯낱낳내낵낸낼냄냅냇냈냉냐냑냔냘냠냥너넉넋넌널넒넓넘넙넛넜넝넣네넥넨넬넴넵넷넸넹녀녁년녈념녑녔녕녘녜녠노녹논놀놂놈놉놋농높놓놔놘놜놨뇌뇐뇔뇜뇝뇟뇨뇩뇬뇰뇹뇻뇽누눅눈눋눌눔눕눗눙눠눴눼뉘뉜뉠뉨뉩뉴뉵뉼늄늅늉느늑는늘늙늚늠늡늣능늦늪늬늰늴니닉닌닐닒님닙닛닝닢다닥닦단닫달닭닮닯닳담답닷닸당닺닻닿대댁댄댈댐댑댓댔댕댜더덕덖던덛덜덞덟덤덥덧덩덫덮데덱덴델뎀뎁뎃뎄뎅뎌뎐뎔뎠뎡뎨뎬도독돈돋돌돎돐돔돕돗동돛돝돠돤돨돼됐되된될됨됩됫됴두둑둔둘둠둡둣둥둬뒀뒈뒝뒤뒨뒬뒵뒷뒹듀듄듈듐듕드득든듣들듦듬듭듯등듸디딕딘딛딜딤딥딧딨딩딪따딱딴딸땀땁땃땄땅땋때땍땐땔땜땝땟땠땡떠떡떤떨떪떫떰떱떳떴떵떻떼떽뗀뗄뗌뗍뗏뗐뗑뗘뗬또똑똔똘똥똬똴뙈뙤뙨뚜뚝뚠뚤뚫뚬뚱뛔뛰뛴뛸뜀뜁뜅뜨뜩뜬뜯뜰뜸뜹뜻띄띈띌띔띕띠띤띨띰띱띳띵라락란랄람랍랏랐랑랒랖랗래랙랜랠램랩랫랬랭랴략랸럇량러럭런럴럼럽럿렀렁렇레렉렌렐렘렙렛렝려력련렬렴렵렷렸령례롄롑롓로록론롤롬롭롯롱롸롼뢍뢨뢰뢴뢸룀룁룃룅료룐룔룝룟룡루룩룬룰룸룹룻룽뤄뤘뤠뤼뤽륀륄륌륏륑류륙륜률륨륩륫륭르륵른를름릅릇릉릊릍릎리릭린릴림립릿링마막만많맏말맑맒맘맙맛망맞맡맣매맥맨맬맴맵맷맸맹맺먀먁먈먕머먹먼멀멂멈멉멋멍멎멓메멕멘멜멤멥멧멨멩며멱면멸몃몄명몇몌모목몫몬몰몲몸몹못몽뫄뫈뫘뫙뫼묀묄묍묏묑묘묜묠묩묫무묵묶문묻물묽묾뭄뭅뭇뭉뭍뭏뭐뭔뭘뭡뭣뭬뮈뮌뮐뮤뮨뮬뮴뮷므믄믈믐믓미믹민믿밀밂밈밉밋밌밍및밑바박밖밗반받발밝밞밟밤밥밧방밭배백밴밸뱀뱁뱃뱄뱅뱉뱌뱍뱐뱝버벅번벋벌벎범법벗벙벚베벡벤벧벨벰벱벳벴벵벼벽변별볍볏볐병볕볘볜보복볶본볼봄봅봇봉봐봔봤봬뵀뵈뵉뵌뵐뵘뵙뵤뵨부북분붇불붉붊붐붑붓붕붙붚붜붤붰붸뷔뷕뷘뷜뷩뷰뷴뷸븀븃븅브븍븐블븜븝븟비빅빈빌빎빔빕빗빙빚빛빠빡빤빨빪빰빱빳빴빵빻빼빽뺀뺄뺌뺍뺏뺐뺑뺘뺙뺨뻐뻑뻔뻗뻘뻠뻣뻤뻥뻬뼁뼈뼉뼘뼙뼛뼜뼝뽀뽁뽄뽈뽐뽑뽕뾔뾰뿅뿌뿍뿐뿔뿜뿟뿡쀼쁑쁘쁜쁠쁨쁩삐삑삔삘삠삡삣삥사삭삯산삳살삵삶삼삽삿샀상샅새색샌샐샘샙샛샜생샤샥샨샬샴샵샷샹섀섄섈섐섕서석섞섟선섣설섦섧섬섭섯섰성섶세섹센셀셈셉셋셌셍셔셕션셜셤셥셧셨셩셰셴셸솅소속솎손솔솖솜솝솟송솥솨솩솬솰솽쇄쇈쇌쇔쇗쇘쇠쇤쇨쇰쇱쇳쇼쇽숀숄숌숍숏숑수숙순숟술숨숩숫숭숯숱숲숴쉈쉐쉑쉔쉘쉠쉥쉬쉭쉰쉴쉼쉽쉿슁슈슉슐슘슛슝스슥슨슬슭슴습슷승시식신싣실싫심십싯싱싶싸싹싻싼쌀쌈쌉쌌쌍쌓쌔쌕쌘쌜쌤쌥쌨쌩썅써썩썬썰썲썸썹썼썽쎄쎈쎌쏀쏘쏙쏜쏟쏠쏢쏨쏩쏭쏴쏵쏸쐈쐐쐤쐬쐰쐴쐼쐽쑈쑤쑥쑨쑬쑴쑵쑹쒀쒔쒜쒸쒼쓩쓰쓱쓴쓸쓺쓿씀씁씌씐씔씜씨씩씬씰씸씹씻씽아악안앉않알앍앎앓암압앗았앙앝앞애액앤앨앰앱앳앴앵야약얀얄얇얌얍얏양얕얗얘얜얠얩어억언얹얻얼얽얾엄업없엇었엉엊엌엎에엑엔엘엠엡엣엥여역엮연열엶엷염엽엾엿였영옅옆옇예옌옐옘옙옛옜오옥온올옭옮옰옳옴옵옷옹옻와왁완왈왐왑왓왔왕왜왝왠왬왯왱외왹왼욀욈욉욋욍요욕욘욜욤욥욧용우욱운울욹욺움웁웃웅워웍원월웜웝웠웡웨웩웬웰웸웹웽위윅윈윌윔윕윗윙유육윤율윰윱윳융윷으윽은을읊음읍읏응읒읓읔읕읖읗의읩읜읠읨읫이익인일읽읾잃임입잇있잉잊잎자작잔잖잗잘잚잠잡잣잤장잦재잭잰잴잼잽잿쟀쟁쟈쟉쟌쟎쟐쟘쟝쟤쟨쟬저적전절젊점접젓정젖제젝젠젤젬젭젯젱져젼졀졈졉졌졍졔조족존졸졺좀좁좃종좆좇좋좌좍좔좝좟좡좨좼좽죄죈죌죔죕죗죙죠죡죤죵주죽준줄줅줆줌줍줏중줘줬줴쥐쥑쥔쥘쥠쥡쥣쥬쥰쥴쥼즈즉즌즐즘즙즛증지직진짇질짊짐집짓징짖짙짚짜짝짠짢짤짧짬짭짯짰짱째짹짼쨀쨈쨉쨋쨌쨍쨔쨘쨩쩌쩍쩐쩔쩜쩝쩟쩠쩡쩨쩽쪄쪘쪼쪽쫀쫄쫌쫍쫏쫑쫓쫘쫙쫠쫬쫴쬈쬐쬔쬘쬠쬡쭁쭈쭉쭌쭐쭘쭙쭝쭤쭸쭹쮜쮸쯔쯤쯧쯩찌찍찐찔찜찝찡찢찧차착찬찮찰참찹찻찼창찾채책챈챌챔챕챗챘챙챠챤챦챨챰챵처척천철첨첩첫첬청체첵첸첼쳄쳅쳇쳉쳐쳔쳤쳬쳰촁초촉촌촐촘촙촛총촤촨촬촹최쵠쵤쵬쵭쵯쵱쵸춈추축춘출춤춥춧충춰췄췌췐취췬췰췸췹췻췽츄츈츌츔츙츠측츤츨츰츱츳층치칙친칟칠칡침칩칫칭카칵칸칼캄캅캇캉캐캑캔캘캠캡캣캤캥캬캭컁커컥컨컫컬컴컵컷컸컹케켁켄켈켐켑켓켕켜켠켤켬켭켯켰켱켸코콕콘콜콤콥콧콩콰콱콴콸쾀쾅쾌쾡쾨쾰쿄쿠쿡쿤쿨쿰쿱쿳쿵쿼퀀퀄퀑퀘퀭퀴퀵퀸퀼큄큅큇큉큐큔큘큠크큭큰클큼
큽킁키킥킨킬킴킵킷킹타탁탄탈탉탐탑탓탔탕태택탠탤탬탭탯탰탱탸턍터턱턴털턺텀텁텃텄텅테텍텐텔템텝텟텡텨텬텼톄톈토톡톤톨톰톱톳통톺톼퇀퇘퇴퇸툇툉툐투툭툰툴툼툽툿퉁퉈퉜퉤튀튁튄튈튐튑튕튜튠튤튬튱트특튼튿틀틂틈틉틋틔틘틜틤틥티틱틴틸팀팁팃팅파팍팎판팔팖팜팝팟팠팡팥패팩팬팰팸팹팻팼팽퍄퍅퍼퍽펀펄펌펍펏펐펑페펙펜펠펨펩펫펭펴편펼폄폅폈평폐폘폡폣포폭폰폴폼폽폿퐁퐈퐝푀푄표푠푤푭푯푸푹푼푿풀풂품풉풋풍풔풩퓌퓐퓔퓜퓟퓨퓬퓰퓸퓻퓽프픈플픔픕픗피픽핀필핌핍핏핑하학한할핥함합핫항해핵핸핼햄햅햇했행햐향허헉헌헐헒험헙헛헝헤헥헨헬헴헵헷헹혀혁현혈혐협혓혔형혜혠혤혭호혹혼홀홅홈홉홋홍홑화확환활홧황홰홱홴횃횅회획횐횔횝횟횡효횬횰횹횻후훅훈훌훑훔훗훙훠훤훨훰훵훼훽휀휄휑휘휙휜휠휨휩휫휭휴휵휸휼흄흇흉흐흑흔흖흗흘흙흠흡흣흥흩희흰흴흼흽힁히힉힌힐힘힙힛힝'

updating accents


# dataloader

In [3]:
# Audio front-end configuration for the datasets. It is derived from `args`, the same way the
# model-setup cell builds it for fresh models, so editing the config cell cannot desynchronise them.
audio_conf = dict(sample_rate=args.sample_rate,
                  window_size=args.window_size,
                  window_stride=args.window_stride,
                  window=args.window,
                  noise_dir=args.noise_dir,
                  noise_prob=args.noise_prob,
                  noise_levels=(args.noise_min, args.noise_max))

In [4]:
# Two independent training streams over the same manifest: `train_loader` drives the
# predictor/encoder updates, `disc_train_loader` feeds the extra discriminator updates.
train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest,
                                   labels=labels, accent=accent_dict, normalize=True,
                                   speed_volume_perturb=args.augment, spec_augment=args.spec_augment)

disc_train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest,
                                        labels=labels, accent=accent_dict, normalize=True,
                                        speed_volume_perturb=args.augment, spec_augment=args.spec_augment)

# Validation data: no augmentation.
test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest,
                                  labels=labels, accent=accent_dict, normalize=True,
                                  speed_volume_perturb=False, spec_augment=False)

train_loader = AudioDataLoader(train_dataset, num_workers=args.num_workers,
                               batch_size=args.batch_size, shuffle=True, pin_memory=True)
disc_train_loader = AudioDataLoader(disc_train_dataset, num_workers=args.num_workers,
                                    batch_size=args.batch_size, shuffle=True, pin_memory=True)
# The validation metrics are summed over all batches, so shuffling buys nothing. Keeping the
# order fixed also makes `args.dummy` runs (which evaluate only the first batch) deterministic.
test_loader = AudioDataLoader(test_dataset, num_workers=args.num_workers,
                              batch_size=args.batch_size, shuffle=False, pin_memory=True)

# Persistent iterator over the discriminator stream; the training loop re-creates it when exhausted.
disc_ = iter(disc_train_loader)

In [None]:
# package = torch.load(args.continue_from, map_location=(f"cuda" if args.cuda else "cpu"))
# print(f'Load from {args.continue_from} succeed.')
# models = package['models']
# labels = models['predictor'][0].labels
# audio_conf = models['predictor'][0].audio_conf

In [7]:
# rnn_type = args.rnn_type.lower()
# for j in models.values():
#     if j[-1]:
#         for g in j[-1].param_groups:
#             g['lr'] = args.lr
# asr, criterion, asr_optimizer = models['predictor']
# asr = asr.to(device)
# encoder, _, _ = models['encoder']
# encoder = encoder.to(device)
# decoder, dec_loss, ed_optimizer = models['decoder']
# decoder = decoder.to(device)

# fnet = ForgetNet(num_modules = args.forg_modules, residual_bool = args.forg_res, hard_mask_bool = True)
# # fnet = nn.DataParallel(fnet, device_ids=[0,1,2,3]).to(device)
# fnet = fnet.to(device)

# fnet_optimizer = torch.optim.Adam(fnet.parameters(), lr=args.lr,weight_decay=1e-4,amsgrad=True)
# models['forget_net'] = [fnet, None, fnet_optimizer]

# discriminator = DiscimnateNet(classes=len(accent),num_modules=args.disc_modules,residual_bool=args.disc_res)
# # discriminator = nn.DataParallel(discriminator, device_ids=[0,1,2,3]).to(device)
# discriminator = discriminator.to(device)

# discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr,weight_decay=1e-4,amsgrad=True)
# accent_counts = pd.read_csv(args.train_manifest, header=None).iloc[:,[-1]].apply(pd.value_counts).to_dict()
# disc_loss_weights = torch.zeros(len(accent)) + eps
# for accent_type_f in accent_counts:
#     if isinstance(accent_counts[accent_type_f], dict):
#         for accent_type_in_f in accent_counts[accent_type_f]:
#             if accent_type_in_f in accent_dict:
#                 disc_loss_weights[accent_dict[accent_type_in_f]] += accent_counts[accent_type_f][accent_type_in_f]
# disc_loss_weights = torch.sum(disc_loss_weights) / disc_loss_weights     
# dis_loss = nn.CrossEntropyLoss(weight=disc_loss_weights.to(device))
# models['discrimator'] = [discriminator, dis_loss, discriminator_optimizer] 

In [5]:
# Weight decay shared by every Adam optimizer created in this cell.
WEIGHT_DECAY = 1e-4
# GPUs used by nn.DataParallel when building fresh models.
GPU_IDS = [0, 1, 3]


def accent_loss_weights(manifest_path, accent_to_idx, num_classes, epsilon=eps):
    """Return inverse-frequency class weights for the accent discriminator's loss.

    Counts every label in the last column of the header-less CSV manifest. Labels present in
    `accent_to_idx` are mapped to their class index. The result is `total / per_class_count`.
    `epsilon` is added to every count so that classes absent from the manifest do not cause a
    division by zero.
    """
    label_counts = pd.read_csv(manifest_path, header=None).iloc[:, -1].value_counts()
    class_counts = torch.zeros(num_classes) + epsilon
    for label, count in label_counts.items():
        if label in accent_to_idx:
            class_counts[accent_to_idx[label]] += count
    return torch.sum(class_counts) / class_counts


def make_adam(params):
    """Adam optimizer with the learning rate and weight decay used for every sub-network."""
    return torch.optim.Adam(params, lr=args.lr, weight_decay=WEIGHT_DECAY, amsgrad=True)


def to_device(module):
    """Move `module` to `device`; wrap it in DataParallel over GPU_IDS only when CUDA is enabled."""
    if args.cuda:
        return nn.DataParallel(module, device_ids=GPU_IDS).to(device)
    return module.to(device)


if args.continue_from:
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    package = torch.load(args.continue_from, map_location=("cuda" if args.cuda else "cpu"))
    print(f'Load from {args.continue_from} succeed.')
    models = package['models']
    # NOTE(review): the datasets were built in an earlier cell with the config-cell `labels`;
    # they are not rebuilt with the checkpoint's labels/audio_conf assigned here.
    labels = models['predictor'][0].labels
    audio_conf = models['predictor'][0].audio_conf

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    assert models['predictor'][0].rnn_type == supported_rnns[rnn_type], "rnnt type of checkpoint and argument must match"

    if not args.train_asr:  # adversarial training needs both auxiliary networks
        assert 'discrimator' in models and 'forget_net' in models, "forget_net and discriminator not found in checkpoint loaded"

    if not args.finetune:
        # Resume where training stopped. Checkpoints store start_epoch = epoch + 1 and
        # start_iter = index of the last processed batch (mid-epoch) or None (end of epoch).
        start_epoch = package['start_epoch'] - 1  # index starts at 0 for training
        start_iter = package['start_iter']
        if start_iter is None:
            start_epoch += 1  # saved after the epoch finished: start at the next epoch
            start_iter = 0
        else:
            start_iter += 1
        print(f"Resuming at epoch {start_epoch + 1}, iteration {start_iter}")
        best_wer = package['best_wer']
        best_cer = package['best_cer']
        poor_cer_list = package['poor_cer_list']
    else:
        # Fine-tuning: keep the weights but restart every optimizer at args.lr.
        for entry in models.values():
            if entry[-1]:
                for g in entry[-1].param_groups:
                    g['lr'] = args.lr

    asr, criterion, asr_optimizer = models['predictor']
    encoder, _, _ = models['encoder']
    decoder, dec_loss, ed_optimizer = models['decoder']

    if not args.train_asr:
        fnet, _, fnet_optimizer = models['forget_net']
        # Discriminator: keep the loaded network/optimizer, recompute the class-weighted loss.
        dis_loss = nn.CrossEntropyLoss(
            weight=accent_loss_weights(args.train_manifest, accent_dict, len(accent)).to(device))
        discriminator, _, discriminator_optimizer = models['discrimator']
        models['discrimator'][1] = dis_loss
    else:
        models.pop('forget_net', None)
        models.pop('discrimator', None)

else:
    # Configuration applied to the audio.
    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"

    # name -> [module, loss, optimizer]; everything in here is saved in checkpoints.
    models = {}

    # ASR (predictor)
    asr = to_device(DeepSpeech(rnn_hidden_size=args.hidden_size,
                               nb_layers=args.hidden_layers,
                               labels=labels,
                               rnn_type=supported_rnns[rnn_type],
                               audio_conf=audio_conf,
                               bidirectional=args.bidirectional))
    asr_optimizer = make_adam(asr.parameters())
    criterion = CTCLoss()
    models['predictor'] = [asr, criterion, asr_optimizer]

    # Encoder and decoder share one optimizer.
    encoder = to_device(Encoder(num_modules=args.enco_modules, residual_bool=args.enco_res))
    models['encoder'] = [encoder, None, None]
    decoder = to_device(Decoder())
    dec_loss = Decoder_loss(nn.MSELoss())
    ed_optimizer = make_adam(list(encoder.parameters()) + list(decoder.parameters()))
    models['decoder'] = [decoder, dec_loss, ed_optimizer]

    if not args.train_asr:
        # Forget network
        fnet = to_device(ForgetNet(num_modules=args.forg_modules, residual_bool=args.forg_res, hard_mask_bool=True))
        fnet_optimizer = make_adam(fnet.parameters())
        models['forget_net'] = [fnet, None, fnet_optimizer]

        # Discriminator, with class weights inversely proportional to accent frequency.
        discriminator = to_device(DiscimnateNet(classes=len(accent), num_modules=args.disc_modules,
                                                residual_bool=args.disc_res))
        discriminator_optimizer = make_adam(discriminator.parameters())
        dis_loss = nn.CrossEntropyLoss(
            weight=accent_loss_weights(args.train_manifest, accent_dict, len(accent)).to(device))
        models['discrimator'] = [discriminator, dis_loss, discriminator_optimizer]


Load from /home/Data/etc/Robust_ASR/exp/0106_dawn_hacka/models/ckpt_2_9114.pth succeed.


In [14]:
# iterer = iter(train_loader)

In [7]:
# data = next(iterer)
# inputs, targets, input_percentages, target_sizes, accents = data
# input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
# inputs = inputs.to(device)

In [8]:
# z,updated_lengths = encoder(inputs,input_sizes.type(torch.LongTensor).to(device)) # Encoder network
# decoder_out = decoder(z) # Decoder network
# asr_out, asr_out_sizes = asr(z, updated_lengths) # Predictor network
# # Loss                
# asr_out = asr_out.transpose(0, 1)  # TxNxH
# asr_loss = criterion(asr_out.float(), targets, asr_out_sizes.cpu(), target_sizes).to(device)
# asr_loss = asr_loss / updated_lengths.size(0)  # average the loss by minibatch
# decoder_loss = dec_loss.forward(inputs, decoder_out, input_sizes,device) * alpha

In [6]:
# Size of every sub-network, in millions of parameters.
for module_name, entry in models.items():
    print(f"Number of parameters for {module_name} in Million is: {DeepSpeech.get_param_size(entry[0])/1000000}")

# Accent names ordered by their class index. Position k in this list then lines up with
# index k of the discriminator's confusion matrix and per-class metric arrays.
# (The previous `sorted(accent, key=lambda x: accent[x])` sorted the indices themselves.)
accent_list = sorted(accent_dict, key=accent_dict.get)

# Header of the per-epoch CSV log (train.log). The training loop appends one row per epoch
# with exactly these columns, including the alpha/beta/gamma loss weights.
a = "epoch,wer,cer,acc,"
for metric in ("precision", "recall", "f1"):
    a += "".join(f"{metric}_{accent_type}," for accent_type in accent_list)
a += "d_avg_loss,p_avg_loss,alpha,beta,gamma\n"

Number of parameters for predictor in Million is: 67.165184
Number of parameters for encoder in Million is: 0.949728
Number of parameters for decoder in Million is: 0.251009
Number of parameters for forget_net in Million is: 0.949728
Number of parameters for discrimator in Million is: 6.876227


In [7]:
# Schedule for how many discriminator updates run per predictor update.
# prob[k] is the probability of running k+1 discriminator updates. It starts geometrically
# skewed toward a single update and is annealed toward its reverse during training.
update_rule = args.update_rule
prob = np.geomspace(1, update_rule * 100000, num=update_rule)[::-1]
prob = prob / np.sum(prob)

# Distribution used once the anneal period (args.num_epochs) is over: always the maximum.
prob_ = [0] * len(prob)
prob_[-1] = 1
print(f"Initial Probability to udpate to the discrimiantor: {prob}")

# Per-iteration step that moves `prob` toward its mirror image.
steps = len(train_loader) * (args.num_epochs / 2)
diff = (prob - prob[::-1]) / steps

# Loss weights: decoder reconstruction, adversarial discriminator, mask regulariser.
alpha, beta, gamma = args.mw_alpha, args.mw_beta, args.mw_gamma

Initial Probability to udpate to the discrimiantor: [9.999950e-01 4.999975e-06]


In [None]:
 for epoch in range(start_epoch, args.epochs):
    # ------------------------------------------------------------------ training
    for entry in models.values():
        entry[0].train()  # validation below switches every sub-network to eval mode
    start_epoch_time = time.time()
    p_counter, d_counter = eps, eps  # eps avoids division by zero in the epoch averages

    # NOTE: `start=start_iter` only offsets the batch counter after a resume. It does not skip
    # already-seen batches; it just makes the epoch end earlier (see the break below).
    for i, data in enumerate(train_loader, start=start_iter):
        if i == len(train_loader) - 1:  # the final batch of the epoch is never trained on
            break
        if args.dummy and i % 2 == 1:
            break

        # Data loading
        inputs, targets, input_percentages, target_sizes, accents = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(device)

        # Periodic mid-epoch checkpoint (start_iter = index of the current batch).
        if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0:
            package = {'models': models, 'start_epoch': epoch + 1, 'best_wer': best_wer, 'best_cer': best_cer,
                       'poor_cer_list': poor_cer_list, 'start_iter': i}
            torch.save(package, os.path.join(save_folder, f"ckpt_{epoch+1}_{i+1}.pth"))

        if args.train_asr:  # only training the ASR components (encoder, decoder, predictor)
            for entry in models.values():
                if entry[-1] is not None:
                    entry[-1].zero_grad()
            p_counter += 1
            # Forward pass
            z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
            decoder_out = decoder(z)  # Decoder network
            asr_out, asr_out_sizes = asr(z, updated_lengths)  # Predictor network
            # Loss
            asr_out = asr_out.transpose(0, 1)  # TxNxH
            asr_loss = criterion(asr_out.float(), targets, asr_out_sizes.cpu(), target_sizes).to(device)
            asr_loss = asr_loss / updated_lengths.size(0)  # average the loss by minibatch
            decoder_loss = dec_loss.forward(inputs, decoder_out, input_sizes, device) * alpha
            loss = asr_loss + decoder_loss
            p_loss = loss.item()
            p_avg_loss += p_loss

            loss.backward()
            ed_optimizer.step()
            asr_optimizer.step()

            if i % 10 == 0:
                print(f"Epoch: [{epoch+1}][{i+1}/{len(train_loader)}]\t predictor Loss: {round(p_loss,4)} ({round(p_avg_loss/p_counter,4)})")
            continue

        # ------------------------------------------------------ adversarial training
        # Anneal the distribution over "number of discriminator updates this iteration".
        if args.num_epochs > epoch:
            # Clip and renormalise: a plain `prob -= diff` keeps going past the reversed
            # distribution (diff is sized to reach it after half of args.num_epochs). It then yields
            # negative probabilities, and np.random.choice raises ValueError.
            prob = np.clip(prob - diff, 0.0, None)
            prob = prob / prob.sum()
        else:
            prob = prob_
        n_disc_updates = int(np.random.choice(args.update_rule, 1, p=prob)[0]) + 1  # mostly 1

        for k in range(n_disc_updates):  # updating the discriminator only
            d_counter += 1
            for entry in models.values():
                if entry[-1] is not None:
                    entry[-1].zero_grad()

            # Data loading from the separate discriminator stream; restart it when exhausted.
            try:
                inputs_, targets_, input_percentages_, target_sizes_, accents_ = next(disc_)
            except StopIteration:
                disc_ = iter(disc_train_loader)
                inputs_, targets_, input_percentages_, target_sizes_, accents_ = next(disc_)

            input_sizes_ = input_percentages_.mul_(int(inputs_.size(3))).int()
            inputs_ = inputs_.to(device)
            accents_ = torch.tensor(accents_).to(device)
            # Forward pass
            z, updated_lengths = encoder(inputs_, input_sizes_.type(torch.LongTensor).to(device))  # Encoder network
            m = fnet(inputs_, input_sizes_.type(torch.LongTensor).to(device))  # Forget network
            z_ = z * m  # Forget operation
            discriminator_out = discriminator(z_)  # Discriminator network
            # Loss
            discriminator_loss = dis_loss(discriminator_out, accents_)
            d_loss = discriminator_loss.item()
            d_avg_loss += d_loss

            discriminator_loss.backward()
            discriminator_optimizer.step()

            if i % 20 == 0:
                print(f"Epoch: [{epoch+1}][{i+1,k+1}/{len(train_loader)}]\t\t\t\t\t Discriminator Loss: {round(d_loss,4)} ({round(d_avg_loss/d_counter,4)})")

            del z, updated_lengths, m, z_, discriminator_out, discriminator_loss, d_loss
            torch.cuda.empty_cache()

        # Random wrong labels for adversarial learning of the predictor: each sample gets a
        # uniformly drawn accent index different from its true one.
        dummy = []
        for acce in accents:
            while True:
                d = random.randint(0, len(accent) - 1)
                if acce != d:
                    dummy.append(d)
                    break
        accents = torch.tensor(dummy).to(device)

        for entry in models.values():
            if entry[-1] is not None:
                entry[-1].zero_grad()
        p_counter += 1

        # Forward pass
        z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
        decoder_out = decoder(z)  # Decoder network
        m = fnet(inputs, input_sizes.type(torch.LongTensor).to(device))  # Forget network
        z_ = z * m  # Forget operation
        discriminator_out = discriminator(z_)  # Discriminator network
        asr_out, asr_out_sizes = asr(z_, updated_lengths)  # Predictor network
        # Loss
        discriminator_loss = dis_loss(discriminator_out, accents) * beta
        p_d_loss = discriminator_loss.item()
        p_d_avg_loss += p_d_loss
        mask_regulariser_loss = (m * (1 - m)).mean() * gamma  # pushes mask values toward 0 or 1

        asr_out = asr_out.transpose(0, 1)  # TxNxH
        asr_loss = criterion(asr_out.float(), targets, asr_out_sizes.cpu(), target_sizes).to(device)
        asr_loss = asr_loss / updated_lengths.size(0)  # average the loss by minibatch
        decoder_loss = dec_loss.forward(inputs, decoder_out, input_sizes, device) * alpha
        loss = asr_loss + decoder_loss + mask_regulariser_loss
        p_loss = loss.item()
        p_avg_loss += p_loss

        # The wrong-label loss is back-propagated first. Its gradients on ed_optimizer's
        # parameters (encoder + decoder when models are built fresh) are then cleared before the
        # main loss is added, so only the forget net keeps the adversarial signal.
        discriminator_loss.backward(retain_graph=True)
        ed_optimizer.zero_grad()
        loss.backward()
        ed_optimizer.step()
        asr_optimizer.step()
        fnet_optimizer.step()

        if i % 20 == 0:
            print(f"Epoch: [{epoch+1}][{i+1}/{len(train_loader)}]\t predictor Loss: {round(p_loss,4)} ({round(p_avg_loss/p_counter,4)})\t dummy_discriminator Loss: {round(p_d_loss,4)} ({round(p_d_avg_loss/p_counter,4)})")

        torch.cuda.empty_cache()

    d_avg_loss /= d_counter
    p_avg_loss /= p_counter
    epoch_time = time.time() - start_epoch_time
    start_iter = 0  # only the first (resumed) epoch is offset
    print('Training Summary Epoch: [{0}]\t'
          'Time taken (s): {1}\t'
          'D/P average Loss {2}, {3}\t'.format(epoch + 1, epoch_time, round(d_avg_loss,4), round(p_avg_loss,4)))

    # ---------------------------------------------------------------- validation
    # Eval mode disables dropout and stops any BatchNorm layers from updating their running
    # statistics with validation batches.
    for entry in models.values():
        entry[0].eval()
    with torch.no_grad():
        total_cer, total_wer, num_tokens, num_chars = eps, eps, eps, eps
        conf_mat = np.ones((len(accent), len(accent)))*eps  # ground-truth: dim-0; predicted: dim-1
        tps, fps, tns, fns = np.ones((len(accent)))*eps, np.ones((len(accent)))*eps, np.ones((len(accent)))*eps, np.ones((len(accent)))*eps  # class-wise TP, FP, TN, FN
        acc_weights = np.ones((len(accent)))*eps
        length, num = eps, eps
        # Decoder used for evaluation
        target_decoder = GreedyDecoder(labels)
        for i, data in enumerate(test_loader):
            if args.dummy and i % 2 == 1:
                break

            # Data loading
            inputs, targets, input_percentages, target_sizes, accents = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            inputs = inputs.to(device)

            # Forward pass
            if not args.train_asr:
                z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
                m = fnet(inputs, input_sizes.type(torch.LongTensor).to(device))  # Forget network
                z_ = z * m  # Forget operation
                discriminator_out = discriminator(z_)  # Discriminator network
                asr_out, asr_out_sizes = asr(z_, updated_lengths)  # Predictor network
                del z, updated_lengths, m, z_
            else:
                z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
                decoder_out = decoder(z)  # Decoder network
                asr_out, asr_out_sizes = asr(z, updated_lengths)  # Predictor network
                del z, updated_lengths, decoder_out

            # Predictor metric: split the flat target tensor back into per-utterance targets.
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size
            decoded_output, _ = target_decoder.decode(asr_out, asr_out_sizes)
            target_strings = target_decoder.convert_to_strings(split_targets)

            for x in range(len(target_strings)):
                transcript, reference = decoded_output[x][0], target_strings[x][0]
                wer_inst = target_decoder.wer(transcript, reference)
                cer_inst = target_decoder.cer(transcript, reference)
                total_wer += wer_inst
                total_cer += cer_inst
                num_tokens += len(reference.split())
                num_chars += len(reference.replace(' ', ''))

            wer = float(total_wer) / num_tokens
            cer = float(total_cer) / num_chars

            if not args.train_asr:
                # Discriminator metrics: fill in the confusion matrix.
                _, predicted = torch.max(discriminator_out, 1)
                for j in range(len(accents)):
                    acc_weights[accents[j]] += 1
                    if accents[j] == predicted[j].item():
                        num = num + 1
                    conf_mat[accents[j], predicted[j].item()] += 1
                length = length + len(accents)

    # Discriminator metrics computed from the confusion matrix.
    for acc_type in range(len(accent)):
        tps[acc_type] = conf_mat[acc_type, acc_type]
        fns[acc_type] = np.sum(conf_mat[acc_type, :]) - tps[acc_type]
        fps[acc_type] = np.sum(conf_mat[:, acc_type]) - tps[acc_type]
        tns[acc_type] = np.sum(conf_mat) - tps[acc_type] - fps[acc_type] - fns[acc_type]
    class_wise_precision, class_wise_recall = tps/(tps+fps), tps/(fns+tps)
    class_wise_f1 = 2 * class_wise_precision * class_wise_recall / (class_wise_precision + class_wise_recall)
    macro_precision, macro_recall, macro_accuracy = np.mean(class_wise_precision), np.mean(class_wise_recall), np.mean((tps+tns)/(tps+fps+fns+tns))
    weighted_precision, weighted_recall = ((acc_weights / acc_weights.sum()) * class_wise_precision).sum(), ((acc_weights / acc_weights.sum()) * class_wise_recall).sum()
    weighted_f1 = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
    micro_precision, micro_recall, micro_accuracy = tps.sum()/(tps.sum()+fps.sum()), tps.sum()/(fns.sum()+tps.sum()), (tps.sum()+tns.sum())/(tps.sum()+tns.sum()+fns.sum()+fps.sum())
    micro_f1, macro_f1 = 2*micro_precision*micro_recall/(micro_precision+micro_recall), 2*macro_precision*macro_recall/(macro_precision+macro_recall)

    print('Validation Summary Epoch: [{0}]\t'
            'Average WER {wer:.3f}\t'
            'Average CER {cer:.3f}\t'
            'Accuracy {acc_: .3f}\t'
            'Discriminator accuracy (micro) {acc: .3f}\t'
            'Discriminator F1 (micro) {f1: .3f}\t'.format(epoch + 1, wer=wer, cer=cer, acc_ = num/length *100 , acc=micro_accuracy, pre=weighted_precision, rec=weighted_recall, f1=weighted_f1))

    # Append this epoch's row to the CSV log and rewrite the whole file.
    a += f"{epoch},{wer},{cer},{num/length *100},"
    for idx, accent_type in enumerate(accent_list):
        a += f"{class_wise_precision[idx]},"
    for idx, accent_type in enumerate(accent_list):
        a += f"{class_wise_recall[idx]},"
    for idx, accent_type in enumerate(accent_list):
        a += f"{class_wise_f1[idx]},"
    a += f"{d_avg_loss},{p_avg_loss},{alpha},{beta},{gamma}\n"

    with open(loss_save, "w") as f:
        f.write(a)

    d_avg_loss, p_avg_loss, p_d_avg_loss = 0, 0, 0

    # Anneal the learning rate of every optimizer.
    for j in models.values():
        if j[-1]:
            for g in j[-1].param_groups:
                g['lr'] = g['lr'] / args.learning_anneal
    print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

    # Exit criterion: args.patience epochs in a row without a CER improvement.
    terminate_train = False
    if best_cer is None or best_cer > cer:
        best_cer = cer
        poor_cer_list = []
    else:
        poor_cer_list.append(cer)
        if len(poor_cer_list) >= args.patience:
            print("Exiting training loop...")
            terminate_train = True

    if best_wer is None or best_wer > wer:
        best_wer = wer
        print("Updating the final model!")
        package = {'models': models, 'start_epoch': epoch+1, 'best_wer': best_wer, 'best_cer': best_cer, 'poor_cer_list': poor_cer_list, 'start_iter': None}
        torch.save(package, os.path.join(save_folder, f"ckpt_final.pth"))

    if args.checkpoint:
        package = {'models': models, 'start_epoch': epoch+1, 'best_wer': best_wer, 'best_cer': best_cer, 'poor_cer_list': poor_cer_list, 'start_iter': None}
        torch.save(package, os.path.join(save_folder, f"ckpt_{epoch+1}.pth"))

    if terminate_train:
        break
    del asr_out, asr_out_sizes
    if not args.train_asr:
        del discriminator_out
    torch.cuda.empty_cache()
    # train_loader was built with shuffle=True, so the DataLoader already reshuffles every
    # epoch. The former `train_sampler.shuffle(epoch)` referenced a sampler that is no longer
    # created and raised NameError at the end of the first epoch.

Epoch: [1][(1, 1)/11375]					 Discriminator Loss: 0.7733 (0.7733)
Epoch: [1][1/11375]	 predictor Loss: 60.5113 (60.5113)	 dummy_discriminator Loss: 0.2833 (0.2833)
Epoch: [1][(21, 1)/11375]					 Discriminator Loss: 0.8016 (0.7386)
Epoch: [1][21/11375]	 predictor Loss: 41.4233 (55.0981)	 dummy_discriminator Loss: 0.2763 (0.2891)
Epoch: [1][(41, 1)/11375]					 Discriminator Loss: 0.8334 (0.7283)
Epoch: [1][41/11375]	 predictor Loss: 45.1988 (51.7612)	 dummy_discriminator Loss: 0.2905 (0.2897)
Epoch: [1][(61, 1)/11375]					 Discriminator Loss: 0.5743 (0.7227)
Epoch: [1][61/11375]	 predictor Loss: 46.8441 (49.086)	 dummy_discriminator Loss: 0.2985 (0.2917)


In [16]:
# Manually snapshot the training state mid-epoch.
# Relies on `epoch`, `i`, `models`, `best_wer`, `best_cer`, `poor_cer_list` and
# `save_folder` still being bound from the training cell above.
checkpoint_path = os.path.join(save_folder, f"ckpt_{epoch+1}_{i+1}.pth")
package = dict(
    models=models,
    start_epoch=epoch + 1,
    best_wer=best_wer,
    best_cer=best_cer,
    poor_cer_list=poor_cer_list,
    start_iter=i,
)
torch.save(package, checkpoint_path)

In [None]:
# python train.py --train-manifest data/csvs/train_sorted_EN_US.csv --val-manifest data/csvs/dev_sorted_EN_US.csv --cuda --rnn-type gru --hidden-layers 5 --hidden-size 1024 --epochs 50 --lr 0.001 --batch-size 32 --gpu-rank 0 --update-rule 1 --exp-name ./exp/1224/ --mw-alpha 0.1 --mw-beta 0.2 --mw-gamma 0.6 --enco-modules 2 --enco-res --forg-modules 2 --forg-res --num-epochs 1 --checkpoint-per-batch 5000 

In [None]:
# Run the networks over `test_loader` and decode the ASR (predictor) outputs.
# NOTE(review): the accumulators initialised below (total_cer, total_wer, num_tokens,
# num_chars, conf_mat, tps/fps/tns/fns, acc_weights, length, num) are never updated
# in this cell, and `decoded_output` / `target_strings` are overwritten on every
# batch -- the per-batch WER/CER and accent-classification metrics still need to be
# accumulated here. TODO.

# Switch the networks to inference mode first (same selection as the stand-alone
# eval cell further down). nn.Module.eval() changes the behaviour of layers such as
# Dropout/BatchNorm; without it this cell, when run top-to-bottom on a fresh
# kernel, would evaluate in train mode.
for model_entry in models.values():
    if model_entry[-1] is not None:
        model_entry[0].eval()

with torch.no_grad():
    total_cer, total_wer, num_tokens, num_chars = eps, eps, eps, eps
    conf_mat = np.ones((len(accent), len(accent)))*eps # ground-truth: dim-0; predicted-truth: dim-1;
    tps, fps, tns, fns = np.ones((len(accent)))*eps, np.ones((len(accent)))*eps, np.ones((len(accent)))*eps, np.ones((len(accent)))*eps # class-wise TP, FP, TN, FN
    acc_weights = np.ones((len(accent)))*eps
    length, num = eps, eps
    # Decoder used for evaluation
    target_decoder = GreedyDecoder(labels)
    for i, data in enumerate(test_loader):
        if args.dummy and i % 2 == 1:
            break  # dummy mode: stop after the first batch (i == 0)

        # Data loading
        inputs, targets, input_percentages, target_sizes, accents = data
        # Scale by the size of dim 3 of `inputs` to get per-utterance lengths.
        # NOTE(review): presumably input_percentages holds each utterance's length
        # as a fraction of the padded length -- confirm against the data loader.
        # Out-of-place so the loaded batch is not modified.
        input_sizes = input_percentages.mul(int(inputs.size(3))).int()
        inputs = inputs.to(device)

        # Forward pass
        if not args.train_asr:
            z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
            m = fnet(inputs, input_sizes.type(torch.LongTensor).to(device))  # Forget network
            z_ = z * m  # Forget operation: element-wise mask on the encoder output
            discriminator_out = discriminator(z_)  # Discriminator network
            asr_out, asr_out_sizes = asr(z_, updated_lengths)  # Predictor network
            del z, updated_lengths, m, z_
        else:
            z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
            decoder_out = decoder(z)  # Decoder network
            asr_out, asr_out_sizes = asr(z, updated_lengths)  # Predictor network
            del z, updated_lengths, decoder_out

        # Predictor metric
        # `targets` holds every label sequence concatenated; slice it back into one
        # sequence per utterance using `target_sizes`.
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size
        decoded_output, _ = target_decoder.decode(asr_out, asr_out_sizes)
        target_strings = target_decoder.convert_to_strings(split_targets)

In [None]:
# Switch the networks to inference mode (nn.Module.eval()) before evaluation.
# Each value of `models` is indexed as a sequence: element 0 is the module that
# gets .eval(), and entries whose last element is None are skipped.
# NOTE(review): presumably the last element is the optimizer (None for networks
# that are not in use) -- confirm against where `models` is built.
for model_entry in models.values():
    if model_entry[-1] is not None:
        model_entry[0].eval()

In [None]:
# Decoder used below to turn ASR outputs (`decode`) and integer target sequences
# (`convert_to_strings`) into strings. NOTE(review): `labels` is presumably the
# model's character vocabulary -- confirm where it is defined.
target_decoder = GreedyDecoder(labels)

In [None]:
# Fresh iterator over the test loader so batches can be pulled one at a time with
# next(); re-running this cell starts iteration over from the beginning.
iterer = iter(test_loader)

In [None]:
# Fetch the next test batch; re-run this cell to step through the test set
# (raises StopIteration once the iterator is exhausted).
data = next(iterer)

In [None]:
# Single-batch sanity check: run the batch in `data` through the networks and
# decode it, so predictions can be compared with the references by eye.
# The pass is inference-only, so it runs under torch.no_grad(): autograd then does
# not record a graph or keep intermediate activations alive. Names bound inside the
# `with` block remain available to later cells.
with torch.no_grad():
    inputs, targets, input_percentages, target_sizes, accents = data
    # Out-of-place multiply so the tensor stored in `data` is left untouched and this
    # cell can be re-run on the same batch. An in-place `mul_` would rescale the
    # lengths inside `data` again on every run.
    input_sizes = input_percentages.mul(int(inputs.size(3))).int()
    inputs = inputs.to(device)

    # Forward pass
    if not args.train_asr:
        z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
        m = fnet(inputs, input_sizes.type(torch.LongTensor).to(device))  # Forget network
        z_ = z * m  # Forget operation: element-wise mask on the encoder output
        discriminator_out = discriminator(z_)  # Discriminator network
        asr_out, asr_out_sizes = asr(z_, updated_lengths)  # Predictor network
        del z, updated_lengths, m, z_
    else:
        z, updated_lengths = encoder(inputs, input_sizes.type(torch.LongTensor).to(device))  # Encoder network
        decoder_out = decoder(z)  # Decoder network
        asr_out, asr_out_sizes = asr(z, updated_lengths)  # Predictor network
        del z, updated_lengths, decoder_out

    # Predictor metric
    # `targets` holds every label sequence concatenated; slice it back into one
    # sequence per utterance using `target_sizes`.
    split_targets = []
    offset = 0
    for size in target_sizes:
        split_targets.append(targets[offset:offset + size])
        offset += size
    decoded_output, _ = target_decoder.decode(asr_out, asr_out_sizes)
    target_strings = target_decoder.convert_to_strings(split_targets)

In [None]:
# Show the reference transcripts, then the decoded predictions, each on its own line.
print(target_strings, decoded_output, sep="\n")