In [51]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import MSELoss

In [52]:
model = AutoModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

In [53]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import MSELoss


# sets the smallest (in terms of magnitude) weights to 0
def get_sparse_weights(data, percentile=.25): 
    flat = data.flatten()
    values, _ = flat.abs().sort(descending=False)
    idx_to_get = int(percentile * flat.size(0))
    cutoff_val = values[idx_to_get]
    mask = torch.where(data.abs() > cutoff_val, 1, 0)
    return data * mask


class DeepSub(nn.Module):
    def __init__(self, input_dim, inner_rank, output_dim) -> None:
        super().__init__()
        self.l1 = nn.Linear(in_features=input_dim, out_features=inner_rank)
        self.act = nn.GELU()
        self.l2 = nn.Linear(in_features=inner_rank, out_features=output_dim)

    def forward(self, x):
        x = self.l1(x)
        x = self.act(x)
        x = self.l2(x)
        return x


def truncated_svd(W, l):
    """Compress the weight matrix W of an inner product (fully connected) layer
    using truncated SVD.
    Parameters:
    W: N x M weights matrix
    l: number of singular values to retain
    Returns:
    Ul, L: matrices such that W \approx Ul*L
    """

    U, s, V = torch.linalg.svd(W)

    Ul = U[:, :l]
    sl = s[:l]
    V = V.t()
    Vl = V[:l, :]

    SV = torch.mm(torch.diag(sl), Vl)
    return Ul, SV


def get_svd_ffn(w1, w2, l, bias=False): 
    ul1, sv1 = truncated_svd(w1, l)
    ul2, sv2 = truncated_svd(w2, l)

    w1_ffn_sv = nn.Linear(sv1.size(1), sv1.size(0), bias=bias)
    w1_ffn_sv.weight.data = sv1
    w1_ffn_ul = nn.Linear(ul1.size(1), ul1.size(0), bias=bias)
    w1_ffn_ul.weight.data = ul1

    w2_ffn_sv = nn.Linear(sv2.size(1), sv2.size(0), bias=bias)
    w2_ffn_sv.weight.data = sv2
    w2_ffn_ul = nn.Linear(ul2.size(1), ul2.size(0), bias=bias)
    w2_ffn_ul.weight.data = ul2
    svd_module = nn.Sequential(w1_ffn_sv, w1_ffn_ul, w2_ffn_sv, w2_ffn_ul)
    return svd_module
    
def train_deep_sub(deep_sub,
                   gt_module,
                   training_iter, 
                  input_size,
                  l=200):
    criterion = MSELoss()
    optimizer = Adam(deep_sub.parameters(), lr=0.001)
    for _ in range(training_iter):
        rand_batch = torch.randn((512, input_size))
        optimizer.zero_grad()
        output = deep_sub(rand_batch)
        # true val calc 
        x = gt_module[0](rand_batch)
        true_val = gt_module[1](x, rand_batch)
        loss = criterion(output, true_val)
        loss.backward()
        optimizer.step()
        print(loss.item())
    return deep_sub


def train_deep_sub_svd(deep_sub,
                       svd_module,
                        gt_module,
                        training_iter, 
                        input_size, 
                        batch_size=512,
                        lr=.001):
    criterion = MSELoss()
    optimizer = Adam(deep_sub.parameters(), lr=lr)
    for _ in range(training_iter):
        rand_batch = torch.randn((batch_size, input_size))
        optimizer.zero_grad()
        output = deep_sub(rand_batch)
        svd_output = svd_module(rand_batch)

        # true val calc 
        x = gt_module[0](rand_batch)
        true_val = gt_module[1](x, rand_batch)

        loss = criterion(output+svd_output, true_val)
        loss.backward()
        optimizer.step()
        print(loss.item())
    return deep_sub

In [21]:
test = model.bert.encoder.layer[0].intermediate.dense.weight.data
sparse_test = get_sparse_weights(test)
model.bert.encoder.layer[0].intermediate.dense.weight.data = sparse_test

test = model.bert.encoder.layer[0].output.dense.weight.data
sparse_test = get_sparse_weights(test)
model.bert.encoder.layer[0].output.dense.weight.data = sparse_test 

gt_module = nn.Sequential(model.bert.encoder.layer[0].intermediate, model.bert.encoder.layer[0].output)
svd_module = get_svd_ffn(model.bert.encoder.layer[0].intermediate.dense.weight.data, 
                         model.bert.encoder.layer[0].output.dense.weight.data, l=200)
deep_sub = nn.Sequential(DeepSub(768, 100, 100), nn.GELU(), DeepSub(100, 100, 768))
train_deep_sub(svd_module, gt_module, 10000, 768)

2.3404808044433594
1.649192452430725
1.2274082899093628
0.9792707562446594
0.8256660103797913
0.7215384840965271
0.6590554118156433
0.6354835033416748
0.6165772676467896
0.6021355390548706
0.5957430601119995
0.5993853211402893
0.5984617471694946
0.5904088616371155
0.5991478562355042
0.5961518883705139
0.5909873843193054
0.5828625559806824
0.5840962529182434
0.5765641331672668
0.5704401135444641
0.5655143857002258
0.546519935131073
0.5411257147789001
0.5351696610450745
0.5282106995582581
0.5198228359222412
0.521476686000824
0.5156427621841431
0.5152456164360046
0.5135045647621155
0.509833037853241
0.5036778450012207
0.5043264031410217
0.5029941201210022
0.5068193078041077
0.5033275485038757
0.4946349859237671
0.49859336018562317
0.4955359995365143
0.49355193972587585
0.4932592809200287
0.49073925614356995
0.4943779706954956
0.4904993772506714
0.4865988492965698
0.49055835604667664
0.4930294454097748
0.485507994890213
0.4822685718536377
0.48766061663627625
0.48225077986717224
0.481783241

0.350321501493454
0.3496917188167572
0.34987542033195496
0.34694841504096985
0.3513769209384918
0.34780797362327576
0.35175248980522156
0.3498803675174713
0.34928086400032043
0.34991884231567383
0.34875354170799255
0.34752658009529114
0.34995999932289124
0.34552130103111267
0.3482511341571808
0.3483083248138428
0.34465697407722473
0.3441535532474518
0.34822916984558105
0.34907206892967224
0.3445614278316498
0.34781065583229065
0.35003602504730225
0.3446735441684723
0.3438323438167572
0.348148375749588
0.34640952944755554
0.34925273060798645
0.3499580919742584
0.34717798233032227
0.3510461747646332
0.34587332606315613
0.34605535864830017
0.3427976071834564
0.3449205160140991
0.34468162059783936
0.3440041244029999
0.34565678238868713
0.343302845954895
0.34157875180244446
0.34422925114631653
0.34815487265586853
0.34594473242759705
0.3428768813610077
0.3462843894958496
0.34223851561546326
0.343904048204422
0.347127765417099
0.3429809510707855
0.3454647958278656
0.3440908193588257
0.3454361

0.32208314538002014
0.31891900300979614
0.32180216908454895
0.32325413823127747
0.3160083591938019
0.3223205506801605
0.320821076631546
0.3169354200363159
0.3220958411693573
0.32158324122428894
0.31611526012420654
0.3227801024913788
0.3220151364803314
0.3205329477787018
0.322543740272522
0.31747356057167053
0.3186585009098053
0.31800577044487
0.32051265239715576
0.3216288089752197
0.3184365928173065
0.31564047932624817
0.31661123037338257
0.3199828565120697
0.3172135651111603
0.31926169991493225
0.318606972694397
0.32277804613113403
0.31887194514274597
0.3156932592391968
0.3209262788295746
0.3181702792644501
0.31969502568244934
0.3182297945022583
0.3167916238307953
0.3182806074619293
0.3211977183818817
0.3184356689453125
0.31584879755973816
0.32304656505584717
0.31899312138557434
0.3157851994037628
0.3179066479206085
0.3192986845970154
0.31561756134033203
0.321359783411026
0.31697002053260803
0.32131120562553406
0.31962883472442627
0.31718119978904724
0.31722959876060486
0.318067193031

0.3118942677974701
0.31211403012275696
0.3178589642047882
0.30927804112434387
0.3112354576587677
0.3078344166278839
0.31267401576042175
0.3113599121570587
0.31250548362731934
0.3118552267551422
0.3115105926990509
0.3090011179447174
0.31394001841545105
0.31269291043281555
0.31384262442588806
0.31297579407691956
0.31287631392478943
0.3128879964351654
0.31173279881477356
0.3134079575538635
0.3106342554092407
0.311879426240921
0.3112024962902069
0.31101393699645996
0.3133946657180786
0.3114301264286041
0.3124985694885254
0.3132759928703308
0.3105648458003998
0.31140679121017456
0.3141258656978607
0.31004080176353455
0.30916300415992737
0.31093618273735046
0.30845484137535095
0.3094886839389801
0.3100614547729492
0.30681079626083374
0.3108823597431183
0.30915912985801697
0.30851617455482483
0.30993756651878357
0.312276691198349
0.3115374445915222
0.3096108138561249
0.3069398105144501
0.31125593185424805
0.31154653429985046
0.31386879086494446
0.31787386536598206
0.31481316685676575
0.311660

0.30955079197883606
0.3069211542606354
0.30831819772720337
0.3115195035934448
0.3094204366207123
0.3112686574459076
0.3087814152240753
0.30719009041786194
0.3098706901073456
0.30657273530960083
0.3053613603115082
0.30732330679893494
0.31198635697364807
0.3088266849517822
0.3086428940296173
0.30903196334838867
0.30719876289367676
0.31300970911979675
0.3077813684940338
0.30508527159690857
0.3106975257396698
0.3098454475402832
0.30516117811203003
0.3093340992927551
0.30663028359413147
0.3098139464855194
0.31086215376853943
0.31010711193084717
0.3089423179626465
0.30681583285331726
0.30942609906196594
0.3083515167236328
0.31011006236076355
0.3097356855869293
0.30875274538993835
0.3048510253429413
0.30699872970581055
0.30433765053749084
0.3120148181915283
0.309268981218338
0.30729812383651733
0.30715131759643555
0.30879464745521545
0.3097873628139496
0.3083200752735138
0.3124532103538513
0.3048943877220154
0.30886781215667725
0.30797556042671204
0.31282302737236023
0.309712678194046
0.31394

0.3050364553928375
0.30558279156684875
0.306695818901062
0.3053090572357178
0.30868110060691833
0.3080853223800659
0.306901216506958
0.3054593801498413
0.3098112940788269
0.3089512586593628
0.3061074912548065
0.30882909893989563
0.3102692663669586
0.3067730665206909
0.305799275636673
0.30767762660980225
0.31051796674728394
0.3037709891796112
0.30585581064224243
0.30924192070961
0.30882084369659424
0.3035553991794586
0.30345237255096436
0.30526307225227356
0.30576011538505554
0.3042474687099457
0.3055218756198883
0.3083932101726532
0.30388641357421875
0.30664265155792236
0.30547642707824707
0.29978129267692566
0.30308273434638977
0.30595701932907104
0.3042714297771454
0.3085534870624542
0.3067210912704468
0.30502012372016907
0.3052319288253784
0.309395968914032
0.30611947178840637
0.306564062833786
0.30750030279159546
0.30729544162750244
0.310699999332428
0.30863675475120544
0.308470219373703
0.3081340491771698
0.3036077916622162
0.30667808651924133
0.3047387897968292
0.3064962327480316

0.30795636773109436
0.30424556136131287
0.3066618740558624
0.3013041913509369
0.30484917759895325
0.30852749943733215
0.30664440989494324
0.30573153495788574
0.3079487979412079
0.3082331418991089
0.30569395422935486
0.30235186219215393
0.303114652633667
0.3018204867839813
0.3047623932361603
0.3053413927555084
0.3056270182132721
0.3086729645729065
0.3095931112766266
0.3030492961406708
0.30459997057914734
0.3051346242427826
0.3073385953903198
0.3002837598323822
0.30642223358154297
0.303609162569046
0.3079165518283844
0.3036654591560364
0.3073999583721161
0.3046078383922577
0.3063281774520874
0.3065461814403534
0.3046194612979889
0.3061126470565796
0.3083014190196991
0.30419236421585083
0.3073682188987732
0.3026798963546753
0.3084612786769867
0.3046596944332123
0.30595967173576355
0.3032606542110443
0.30321815609931946
0.3045355975627899
0.3070686161518097
0.3042666018009186
0.3058067262172699
0.3010907471179962
0.30002886056900024
0.30852171778678894
0.306382954120636
0.3062776029109955


0.30093351006507874
0.30404582619667053
0.3036467730998993
0.3052516281604767
0.30543938279151917
0.3042318522930145
0.3045300245285034
0.3076654076576233
0.3046133518218994
0.3115554749965668
0.3056958019733429
0.30732157826423645
0.30524906516075134
0.3068418800830841
0.3088814318180084
0.305258184671402
0.3039838969707489
0.3029796779155731
0.3035685122013092
0.3055257797241211
0.3036099970340729
0.3002457320690155
0.30143192410469055
0.3050096333026886
0.3071652948856354
0.3057335913181305
0.30122706294059753
0.30451464653015137
0.30304089188575745
0.30749818682670593
0.3035564124584198
0.29964718222618103
0.3034871518611908
0.306735634803772
0.30566325783729553
0.30732694268226624
0.30336394906044006
0.3072282373905182
0.30569031834602356
0.30664005875587463
0.30621474981307983
0.3054504096508026
0.3076995611190796
0.3005262315273285
0.30397069454193115
0.30572691559791565
0.30799686908721924
0.3031165897846222
0.30159762501716614
0.3075045347213745
0.3071742355823517
0.3042401373

0.3055380880832672
0.3085266053676605
0.305926650762558
0.3061122000217438
0.30656418204307556
0.30532610416412354
0.30559858679771423
0.3033014237880707
0.3020824193954468
0.30657312273979187
0.30576804280281067
0.3068842589855194
0.3053155839443207
0.30716168880462646
0.30154451727867126
0.3064395785331726
0.30520859360694885
0.3050726354122162
0.30977919697761536
0.3053126633167267
0.3050227165222168
0.30658987164497375
0.30881860852241516
0.30368903279304504
0.3053937256336212
0.30480462312698364
0.3022105395793915
0.30525490641593933
0.30432257056236267
0.3071911334991455
0.30788642168045044
0.3064892292022705
0.3047533929347992
0.3058319389820099
0.30465617775917053
0.30304691195487976
0.30911028385162354
0.30033913254737854
0.30473384261131287
0.3052005469799042
0.3024154007434845
0.3033870458602905
0.30694761872291565
0.3069287836551666
0.30851122736930847
0.30684059858322144
0.30390995740890503
0.3050497770309448
0.3049505949020386
0.30517086386680603
0.30689430236816406
0.307

0.30411800742149353
0.30379119515419006
0.3074561357498169
0.308190256357193
0.3054538071155548
0.30183807015419006
0.30487725138664246
0.3051932156085968
0.30414360761642456
0.3032344877719879
0.30574196577072144
0.3034878671169281
0.3052656054496765
0.3057592213153839
0.3042854368686676
0.3060894310474396
0.30180177092552185
0.3025599718093872
0.3025824725627899
0.3064258098602295
0.3025331199169159
0.30462121963500977
0.30727240443229675
0.30475106835365295
0.3095220923423767
0.3046703636646271
0.30387401580810547
0.3043273687362671
0.30709409713745117
0.3013620674610138
0.3036283552646637
0.30146875977516174
0.3027157485485077
0.304330438375473
0.3055386245250702
0.30339792370796204
0.3069184422492981
0.3055126368999481
0.30503806471824646
0.2996222972869873
0.30269697308540344
0.2995803654193878
0.3038876950740814
0.30447062849998474
0.30320999026298523
0.30205756425857544
0.3029981851577759
0.3032018840312958
0.30734989047050476
0.30359959602355957
0.30254697799682617
0.301471590

0.3037968575954437
0.3036439120769501
0.3005921542644501
0.30336347222328186
0.3047146797180176
0.30376186966896057
0.3044221103191376
0.30476054549217224
0.30164095759391785
0.3003555238246918
0.30223938822746277
0.3033798038959503
0.30125370621681213
0.303047239780426
0.3068666160106659
0.3020065426826477
0.3043656647205353
0.3045680522918701
0.2998342216014862
0.3007298409938812
0.30478495359420776
0.3065342605113983
0.30169180035591125
0.30216020345687866
0.3048068583011627
0.3052598237991333
0.3044874966144562
0.3020639419555664
0.30580735206604004
0.3034302890300751
0.3026975095272064
0.3037930428981781
0.3037559986114502
0.3060227632522583
0.3065936863422394
0.30502811074256897
0.3025273084640503
0.3034234344959259
0.3011951744556427
0.3021753430366516
0.3051431477069855
0.3040274977684021
0.30452224612236023
0.3032325804233551
0.30246278643608093
0.3020192086696625
0.305634468793869
0.30442014336586
0.30525389313697815
0.3043877184391022
0.3038351833820343
0.30072981119155884
0

0.30238309502601624
0.30404141545295715
0.3065711259841919
0.30275827646255493
0.3005533814430237
0.3034028112888336
0.30406296253204346
0.30588290095329285
0.30292990803718567
0.30021339654922485
0.3038191497325897
0.304516464471817
0.3026851415634155
0.30434927344322205
0.3028711974620819
0.3041062653064728
0.30497515201568604
0.30406925082206726
0.30164530873298645
0.30587223172187805
0.3063923418521881
0.30421245098114014
0.3025749921798706
0.303029865026474
0.30144044756889343
0.30485427379608154
0.30427736043930054
0.3014928996562958
0.30069541931152344
0.30486294627189636
0.30407026410102844
0.3003677725791931
0.3060655891895294
0.3077623248100281
0.30135539174079895
0.3029704988002777
0.3030455410480499
0.309502512216568
0.3026919960975647
0.30593499541282654
0.30256208777427673
0.3049359619617462
0.30184832215309143
0.30100297927856445
0.3057310879230499
0.30423998832702637
0.3010581433773041
0.30428117513656616
0.3044203519821167
0.3018225133419037
0.30432507395744324
0.30650

0.30342456698417664
0.3041495382785797
0.30006155371665955
0.30714473128318787
0.3023035526275635
0.3059222996234894
0.3033805787563324
0.30007919669151306
0.302760511636734
0.2988903224468231
0.30164480209350586
0.3050288259983063
0.30480432510375977
0.30446457862854004
0.30387046933174133
0.3044661283493042
0.3061789572238922
0.3007860481739044
0.3045044243335724
0.304413765668869
0.3035685122013092
0.30170485377311707
0.30763232707977295
0.30582693219184875
0.30098769068717957
0.29907286167144775
0.3033555746078491
0.3049439489841461
0.2995823323726654
0.3019488453865051
0.29820165038108826
0.3005366921424866
0.30557700991630554
0.30531200766563416
0.3025342524051666
0.30244311690330505
0.3058997392654419
0.30134955048561096
0.29932650923728943
0.30733728408813477
0.3038061857223511
0.30702725052833557
0.30460527539253235
0.3023207485675812
0.30515071749687195
0.2996163070201874
0.30198052525520325
0.30483677983283997
0.3033762574195862
0.3011186420917511
0.3051910102367401
0.302764

0.30179738998413086
0.3020501136779785
0.30241742730140686
0.30057546496391296
0.30280330777168274
0.3002665042877197
0.30583229660987854
0.30678918957710266
0.3042742609977722
0.3075354993343353
0.30633535981178284
0.3059663474559784
0.3011123239994049
0.30423131585121155
0.3028614819049835
0.30544814467430115
0.30309411883354187
0.30259573459625244
0.303043395280838
0.3051382303237915
0.3037464916706085
0.2984599173069
0.3013571798801422
0.30373579263687134
0.3039281964302063
0.3014732897281647
0.30182769894599915
0.30283212661743164
0.30388346314430237
0.3056790828704834
0.3062433898448944
0.3044319152832031
0.3035587966442108
0.3016749322414398
0.30475014448165894
0.30558064579963684
0.3015967607498169
0.3038212060928345
0.3052046000957489
0.29820629954338074
0.3039587438106537
0.3034871816635132
0.30671364068984985
0.3051016330718994
0.3023994266986847
0.3022354543209076
0.3049183785915375
0.30248698592185974
0.30177947878837585
0.3024061620235443
0.30320292711257935
0.30163523554

0.30608102679252625
0.29995495080947876
0.30648764967918396
0.3000235855579376
0.30239516496658325
0.30625858902931213
0.30015724897384644
0.30319008231163025
0.30192288756370544
0.2999323308467865
0.3018503487110138
0.30556514859199524
0.30410468578338623
0.307437539100647
0.30459949374198914
0.3051554262638092
0.3029825985431671
0.30276021361351013
0.3056350350379944
0.3064330518245697
0.3038972020149231
0.3049190938472748
0.3010762333869934
0.3081338405609131
0.3042030930519104
0.30492961406707764
0.30386045575141907
0.3052133023738861
0.30648621916770935
0.30119001865386963
0.2981506884098053
0.3063969016075134
0.3012680113315582
0.30498600006103516
0.30234381556510925
0.30105480551719666
0.3055734932422638
0.3047354221343994
0.30246052145957947
0.3039938509464264
0.30291709303855896
0.3062317967414856
0.3067154884338379
0.3027847111225128
0.30151060223579407
0.30701956152915955
0.3024047613143921
0.301633358001709
0.30457189679145813
0.30512315034866333
0.3012072741985321
0.302913

0.30101144313812256
0.29765644669532776
0.3001026213169098
0.29969584941864014
0.30333879590034485
0.30057817697525024
0.30249840021133423
0.30607447028160095
0.30120614171028137
0.30557993054389954
0.3016386330127716
0.305107980966568
0.3051576316356659
0.3022868037223816
0.3011031448841095
0.3011086881160736
0.30409887433052063
0.3023032248020172
0.298970490694046
0.303038626909256
0.30077746510505676
0.301498144865036
0.30053624510765076
0.3074307143688202
0.3031843304634094
0.29961463809013367
0.3027713894844055
0.30089953541755676
0.302211195230484
0.30294355750083923
0.30058297514915466
0.30011454224586487
0.30552229285240173
0.30148085951805115
0.302699476480484
0.3002242147922516
0.3003470003604889
0.30191171169281006
0.3031102418899536
0.30055147409439087
0.30425989627838135
0.30417877435684204
0.3058822453022003
0.2997850179672241
0.30226895213127136
0.30219516158103943
0.3041298985481262
0.3047182261943817
0.3024924695491791
0.30234646797180176
0.30345866084098816
0.30275264

0.30602702498435974
0.3067704737186432
0.30397701263427734
0.30267098546028137
0.30385062098503113
0.30219101905822754
0.30201461911201477
0.3031119406223297
0.3013870418071747
0.3027261793613434
0.30482813715934753
0.3039734959602356
0.3013588488101959
0.3066151738166809
0.30050745606422424
0.3032921254634857
0.3011375963687897
0.30527010560035706
0.30241450667381287
0.3046700954437256
0.3037513196468353
0.30647996068000793
0.30027705430984497
0.3018352687358856
0.30025193095207214
0.3009707033634186
0.3032834231853485
0.29993703961372375
0.30170387029647827
0.3006817698478699
0.3007507026195526
0.3055121898651123
0.30212101340293884
0.3047168552875519
0.30381524562835693
0.3011346161365509
0.30174028873443604
0.3039512634277344
0.2984258830547333
0.30198726058006287
0.3059837520122528
0.29942095279693604
0.303288072347641
0.30659982562065125
0.30090007185935974
0.30222567915916443
0.3059810400009155
0.30303463339805603
0.30208733677864075
0.3052806854248047
0.3039763271808624
0.30365

0.3010460436344147
0.30617207288742065
0.30172106623649597
0.30586907267570496
0.3075800836086273
0.30578097701072693
0.30480340123176575
0.30211111903190613
0.29892075061798096
0.3027696907520294
0.304402619600296
0.30368947982788086
0.30443719029426575
0.2987178862094879
0.30043789744377136
0.30022868514060974
0.30274397134780884
0.30074822902679443
0.30535992980003357
0.3047730028629303
0.3023768365383148
0.30383387207984924
0.2993490993976593
0.2996397018432617
0.30153775215148926
0.30721786618232727
0.3029869794845581
0.29899588227272034
0.3012109696865082
0.30025357007980347
0.30217310786247253
0.2971784472465515
0.30109062790870667
0.30171307921409607
0.302181601524353
0.30441632866859436
0.30200573801994324
0.30393800139427185
0.3028266131877899
0.30226612091064453
0.3028225898742676
0.3010057210922241
0.3031996786594391
0.3048557937145233
0.3020201623439789
0.30084338784217834
0.3027310371398926
0.30175304412841797
0.30376574397087097
0.2962230145931244
0.29960954189300537
0.2

0.2967964708805084
0.30283552408218384
0.3087123930454254
0.3017314374446869
0.29968908429145813
0.3047638237476349
0.3030667304992676
0.3014824092388153
0.30145105719566345
0.29740315675735474
0.30078601837158203
0.30130282044410706
0.3007033169269562
0.30211594700813293
0.2977635860443115
0.3036635220050812
0.3042864203453064
0.30337122082710266
0.30765488743782043
0.3040757477283478
0.30099794268608093
0.30260804295539856
0.30607926845550537
0.30212071537971497
0.3010975420475006
0.3044254779815674
0.3030618727207184
0.30336514115333557
0.2999952435493469
0.3004852831363678
0.30232295393943787
0.30180832743644714
0.30423107743263245
0.30410948395729065
0.3001764118671417
0.3071165084838867
0.30433422327041626
0.29997918009757996
0.30347439646720886
0.3021332919597626
0.3035639524459839
0.30398228764533997
0.3031297028064728
0.305554062128067
0.3017744719982147
0.30250871181488037
0.30660882592201233
0.3042277991771698
0.3023539185523987
0.3016525208950043
0.30183136463165283
0.30275

0.302186518907547
0.3035227954387665
0.299314022064209
0.30406615138053894
0.3075621724128723
0.3016241490840912
0.3058570325374603
0.3042841851711273
0.3030664324760437
0.3029530346393585
0.30167993903160095
0.29984888434410095
0.300949364900589
0.30440500378608704
0.3032900094985962
0.30081045627593994
0.3031323254108429
0.3003564774990082
0.3009285628795624
0.30054613947868347
0.30475232005119324
0.3013421893119812
0.3003508150577545
0.3015400469303131
0.2978743612766266
0.3021615147590637
0.30030450224876404
0.30388858914375305
0.3038024604320526
0.30041441321372986
0.3034530282020569
0.3025358021259308
0.30466428399086
0.3040246367454529
0.3014668822288513
0.30110111832618713
0.30085790157318115
0.3025873899459839
0.302480012178421
0.30152204632759094
0.30257874727249146
0.30679866671562195
0.3004480302333832
0.30175843834877014
0.30311059951782227
0.3010070323944092
0.3013412058353424
0.30134034156799316
0.3045487701892853
0.3022438585758209
0.30499473214149475
0.302197128534317


0.3000757396221161
0.30268222093582153
0.30238333344459534
0.30200693011283875
0.30471134185791016
0.3034851551055908
0.3048412501811981
0.3041389584541321
0.301057368516922
0.30200448632240295
0.3017372190952301
0.3037859797477722
0.3058081567287445
0.29849010705947876
0.29848217964172363
0.30547836422920227
0.2997942864894867
0.30156829953193665
0.29827603697776794
0.3030487298965454
0.29976925253868103
0.30076125264167786
0.2980159521102905
0.300325870513916
0.29927995800971985
0.29840901494026184
0.307116836309433
0.30281803011894226
0.2981010377407074
0.3042873442173004
0.3027382493019104
0.30430716276168823
0.3024144470691681
0.29951873421669006
0.3050810992717743
0.29914894700050354
0.29770544171333313
0.30344465374946594
0.30142107605934143
0.3038153052330017
0.3037444055080414
0.30334311723709106
0.30302509665489197
0.29794082045555115
0.30530664324760437
0.30647680163383484
0.3055615723133087
0.30578359961509705
0.30232328176498413
0.3024652600288391
0.30227962136268616
0.300

0.3031105101108551
0.30395829677581787
0.3016946017742157
0.3021385967731476
0.30439266562461853
0.30381134152412415
0.3050518035888672
0.3086027503013611
0.3060223460197449
0.3050578534603119
0.3026573956012726
0.30356889963150024
0.3020591139793396
0.3005443513393402
0.3013477027416229
0.3042398691177368
0.307488352060318
0.30307817459106445
0.30142512917518616
0.3015862703323364
0.30140483379364014
0.3043985366821289
0.29974135756492615
0.3018195629119873
0.29660138487815857
0.3046457767486572
0.30022287368774414
0.3025965392589569
0.3014422059059143
0.2997106909751892
0.30103829503059387
0.30703312158584595
0.3001101315021515
0.29783546924591064
0.3000520169734955
0.297739714384079
0.30088719725608826
0.3023480176925659
0.2987845242023468
0.2999311685562134
0.2989113926887512
0.2978399991989136
0.3044806718826294
0.30180004239082336
0.3024096190929413
0.3046112358570099
0.30181536078453064
0.30426573753356934
0.30297744274139404
0.3042435944080353
0.30198630690574646
0.299480319023

0.30351313948631287
0.3007364273071289
0.30435115098953247
0.30313605070114136
0.3059645891189575
0.30406808853149414
0.29634663462638855
0.3006424307823181
0.3022979497909546
0.30255940556526184
0.3011837303638458
0.30110102891921997
0.30089834332466125
0.3031063377857208
0.30330702662467957
0.3043713867664337
0.3021019399166107
0.30123046040534973
0.3031138479709625
0.30364927649497986
0.3066861927509308
0.30092406272888184
0.30311158299446106
0.30352354049682617
0.3002994656562805
0.30099013447761536
0.30106985569000244
0.3025999367237091
0.2993275821208954
0.30192992091178894
0.30423131585121155
0.30070945620536804
0.302255243062973
0.30117619037628174
0.3001951575279236
0.3016818165779114
0.30142447352409363
0.29849526286125183
0.3016097843647003
0.2997908294200897
0.3014223277568817
0.30371180176734924
0.2982262670993805
0.3028915822505951
0.3007502853870392
0.3060818612575531
0.3045160472393036
0.30370229482650757
0.3001985251903534
0.3041841685771942
0.30389711260795593
0.30439

0.30093804001808167
0.30197060108184814
0.3035684823989868
0.3068719208240509
0.3028040826320648
0.3004813492298126
0.30077388882637024
0.3017330765724182
0.30275821685791016
0.305308073759079
0.3025965392589569
0.3024728298187256
0.30103182792663574
0.3004927337169647
0.3026578724384308
0.30278435349464417
0.3030858039855957
0.3026964068412781
0.30107423663139343
0.301187127828598
0.30101171135902405
0.29841285943984985
0.3003176748752594
0.3026357591152191
0.30348601937294006
0.30138149857521057
0.3059660494327545
0.29805848002433777
0.30199727416038513
0.3047737777233124
0.30332401394844055
0.3070487082004547
0.3039435148239136
0.30668336153030396
0.3034970164299011
0.3021187484264374
0.3028944432735443
0.3053946793079376
0.3034191429615021
0.30451515316963196
0.3012203872203827
0.300560861825943
0.2992301881313324
0.30238375067710876
0.30429548025131226
0.30315253138542175
0.299857497215271
0.3023987114429474
0.30142417550086975
0.3015690743923187
0.30755096673965454
0.304580599069

Sequential(
  (0): Linear(in_features=768, out_features=200, bias=False)
  (1): Linear(in_features=200, out_features=3072, bias=False)
  (2): Linear(in_features=3072, out_features=200, bias=False)
  (3): Linear(in_features=200, out_features=768, bias=False)
)

In [54]:
import pandas as pd 
# cola_df = pd.read_csv('glue_data/CoLA/train.tsv', sep='\t', header=1)
cola_df = pd.read_table('glue_data/CoLA/train.tsv', header=None)
text_sentences = list(cola_df.iloc[0:10][3])
text_sentences

["Our friends won't buy this analysis, let alone the next one we propose.",
 "One more pseudo generalization and I'm giving up.",
 "One more pseudo generalization or I'm giving up.",
 'The more we study verbs, the crazier they get.',
 'Day by day the facts are getting murkier.',
 "I'll fix you a drink.",
 'Fred watered the plants flat.',
 'Bill coughed his way out of the restaurant.',
 "We're dancing the night away.",
 'Herman hammered the metal flat.']

In [55]:
out = tokenizer(text_sentences, return_tensors='pt', padding=True)
tokenized_text = out['input_ids']
token_type = out['token_type_ids']
attention_mask = out['attention_mask']

In [56]:
# attention = Tuple of torch.FloatTensor (one for each layer) 
# of shape (batch_size, num_heads, sequence_length, sequence_length).

model_output = model(input_ids=tokenized_text, 
                     attention_mask=attention_mask, 
                     token_type_ids=token_type)

In [58]:
len(model_output.attentions)

12

In [36]:
tokenized_text

[[101,
  2256,
  2814,
  2180,
  1005,
  1056,
  4965,
  2023,
  4106,
  1010,
  2292,
  2894,
  1996,
  2279,
  2028,
  2057,
  16599,
  1012,
  102],
 [101,
  2028,
  2062,
  18404,
  2236,
  3989,
  1998,
  1045,
  1005,
  1049,
  3228,
  2039,
  1012,
  102],
 [101,
  2028,
  2062,
  18404,
  2236,
  3989,
  2030,
  1045,
  1005,
  1049,
  3228,
  2039,
  1012,
  102],
 [101,
  1996,
  2062,
  2057,
  2817,
  16025,
  1010,
  1996,
  13675,
  16103,
  2121,
  2027,
  2131,
  1012,
  102],
 [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102],
 [101, 1045, 1005, 2222, 8081, 2017, 1037, 4392, 1012, 102],
 [101, 5965, 27129, 1996, 4264, 4257, 1012, 102],
 [101, 3021, 19055, 2010, 2126, 2041, 1997, 1996, 4825, 1012, 102],
 [101, 2057, 1005, 2128, 5613, 1996, 2305, 2185, 1012, 102],
 [101, 11458, 25756, 1996, 3384, 4257, 1012, 102]]