In [None]:
from pathlib import Path
import matplotlib.pyplot as plt

from au2v.analyst import Analyst
from au2v.config import ModelConfig, TrainerConfig
from au2v.dataset_manager import load_dataset_manager
from au2v.trainer import PyTorchTrainer
from au2v.model import load_model

In [None]:
# Sweep the negative-sampling size and train several times per setting.
# `data` receives one (negative_sample_size, loss_dict) tuple per training run,
# so with 4 settings and 5 trials it ends up with 20 entries.
data = []
for negative_sample_size in [5, 10, 15, 20]:
    # Only negative_sample_size changes between settings; lr is pinned to 0.0003
    # and every other ModelConfig field keeps its default.
    model_config = ModelConfig(
        negative_sample_size=negative_sample_size,
        lr=0.0003,
    )
    # The trainer configuration is identical for every setting.
    trainer_config = TrainerConfig(
        dataset_name="movielens",
        model_dir="../cache/model",
        dataset_dir="../cache/dataset",
        load_dataset=False,
        save_dataset=False,
        epochs=5,
    )
    # The dataset manager is constructed once per setting and shared by all
    # trials of that setting.
    # NOTE(review): load_dataset/save_dataset are both False — presumably the
    # dataset is rebuilt from data_dir on each call instead of using the cache
    # in dataset_dir; confirm in load_dataset_manager.
    dataset_manager = load_dataset_manager(
        dataset_name=trainer_config.dataset_name,
        dataset_dir=trainer_config.dataset_dir,
        data_dir="../data",
        load_dataset=trainer_config.load_dataset,
        save_dataset=trainer_config.save_dataset,
        window_size=model_config.window_size,
    )
    # Five repeated runs per setting; load_model and PyTorchTrainer are called
    # again for every trial.
    # NOTE(review): presumably load_model returns a newly initialised model each
    # time rather than reloading weights from model_dir — confirm.
    # NOTE(review): no random seed is set in this cell, so the trials are
    # presumably not reproducible across kernel restarts — confirm.
    for trial in range(5):
        model = load_model(
            dataset_manager=dataset_manager,
            trainer_config=trainer_config,
            model_config=model_config
        )
        trainer = PyTorchTrainer(
            model=model,
            dataset_manager=dataset_manager,
            trainer_config=trainer_config,
            model_config=model_config,
        )
        # fit() returns the loss dictionary stored below; the later cells index
        # it with the key "train-size=50".
        loss_dict = trainer.fit(show_fig=False)
        data.append((negative_sample_size, loss_dict))

In [None]:
"""
[(5,
  {'train': [0.20399765526313834,
    0.20369163532984538,
    0.2036696707931467,
    0.20417265983271496,
    0.20573039937957002],
   'train-size=10': [0.21463676405624604,
    0.21319739415612018,
    0.2168227992427181,
    0.215764646379041,
    0.21910748368417712],
   'train-size=20': [0.21643668979826108,
    0.21435097846346843,
    0.21892635797111082,
    0.21740616899980625,
    0.22085922263877492],
   'train-size=30': [0.21383317108725158,
    0.21158876238574445,
    0.21619650658587336,
    0.21498099599086062,
    0.21846593790490862],
   'train-size=40': [0.21459628746543133,
    0.21252311556272105,
    0.21745070836073915,
    0.21569497337643528,
    0.21918972067429987],
   'train-size=50': [0.21964523070295092,
    0.2176059213742404,
    0.2218941053454305,
    0.2209181974471455,
    0.2245150277732124]}),
 (5,
  {'train': [0.20593310709592033,
    0.20628856512147759,
    0.2064108847737359,
    0.20581207827930884,
    0.20612334223375112],
   'train-size=10': [0.2195399626040123,
    0.2169248020984757,
    0.21852724673882337,
    0.21618799319569493,
    0.2184649412061127],
   'train-size=20': [0.22158680457464405,
    0.21888751253275804,
    0.22003586422389662,
    0.21733968043831034,
    0.22004724976042628],
   'train-size=30': [0.21943796121738326,
    0.21684428650728413,
    0.21826428636698655,
    0.21485092018691587,
    0.21804420901855953],
   'train-size=40': [0.2207082755548853,
    0.21794939062125246,
    0.21952717971633856,
    0.21628304056718317,
    0.21877033525789288],
   'train-size=50': [0.2255046096905856,
    0.222358472960096,
    0.2240408417624487,
    0.2210227628828774,
    0.22460764268754232]}),
 (5,
  {'train': [0.20516597993525384,
    0.2052864147060814,
    0.20443622648584017,
    0.20372786789041958,
    0.20381908462000695],
   'train-size=10': [0.21725419520492284,
    0.21566827742146774,
    0.21623214134867763,
    0.2137312949963019,
    0.21665216793476696],
   'train-size=20': [0.21982425031527666,
    0.21743583112535342,
    0.2182177762750169,
    0.21533667999254147,
    0.2183540897889876],
   'train-size=30': [0.21694400948537906,
    0.21512231839374757,
    0.21599716640694042,
    0.21247563626564725,
    0.21656062737317153],
   'train-size=40': [0.21832581263192943,
    0.21581709342943112,
    0.21660491215511107,
    0.21325946945539662,
    0.21727144214468944],
   'train-size=50': [0.2239302871092944,
    0.22048348377288227,
    0.22161600350494115,
    0.2187094468046242,
    0.22165655439168636]}),
 (5,
  {'train': [0.2033362867605843,
    0.2031300005281417,
    0.20332128838769997,
    0.20321414270101698,
    0.2029515619559871],
   'train-size=10': [0.2176606380183932,
    0.214201098596546,
    0.216157434062219,
    0.21423372836180135,
    0.21606979529622575],
   'train-size=20': [0.22038601632689087,
    0.21569308239809223,
    0.21838733490923762,
    0.2155151127929419,
    0.21806632129239364],
   'train-size=30': [0.217477915362573,
    0.21430617738777483,
    0.21568401997358028,
    0.21238086634958295,
    0.21521052424336823],
   'train-size=40': [0.21831193230521512,
    0.21429053662528455,
    0.21685506991097625,
    0.21325984219430197,
    0.2164274384857903],
   'train-size=50': [0.22346923678693636,
    0.21949304178567,
    0.22159048579108548,
    0.21819341896285474,
    0.2211660858191235]}),
 (5,
  {'train': [0.202029954331091,
    0.20179255763128223,
    0.2015308847956273,
    0.20087488851194535,
    0.20098331687249837],
   'train-size=10': [0.21672618263204332,
    0.21277691801668894,
    0.21449044556684896,
    0.2107858651540649,
    0.21761070296797955],
   'train-size=20': [0.21823558861940678,
    0.21376460748659054,
    0.21587297991967538,
    0.21255223243169383,
    0.21874638678322375],
   'train-size=30': [0.21560814032252407,
    0.21185488914939718,
    0.21343768503464444,
    0.20960232762383743,
    0.21706769898743697],
   'train-size=40': [0.21669981303349348,
    0.21309547042342977,
    0.21421032598320866,
    0.21054155851753664,
    0.21795239352004628],
   'train-size=50': [0.22161943014238922,
    0.21698396999231526,
    0.2188509381572965,
    0.21491961957703173,
    0.22235652216723267]}),
 (10,
  {'train': [0.12819304013158336,
    0.12981402695466182,
    0.1307697867204985,
    0.1311016405237297,
    0.13141186342957056],
   'train-size=10': [0.14050132246084615,
    0.13867605362139956,
    0.1403269566280741,
    0.139084791108756,
    0.14126021834746214],
   'train-size=20': [0.1429781867584712,
    0.1410255513980355,
    0.14294887008801313,
    0.14156002740205173,
    0.14357212752523557],
   'train-size=30': [0.1412813630951962,
    0.13930193156423704,
    0.14107690562664624,
    0.13959775749646441,
    0.14178734199262003],
   'train-size=40': [0.14223202087090048,
    0.14007355440670335,
    0.14207218020734652,
    0.14020285209719563,
    0.14262009842295043],
   'train-size=50': [0.14534379720268115,
    0.14314474721609707,
    0.14525670148956943,
    0.14349842092520754,
    0.145910473881473]}),
 (10,
  {'train': [0.13134520863967414,
    0.13189134706815467,
    0.13214678391823934,
    0.13267827872296054,
    0.1336652948901296],
   'train-size=10': [0.14196402044363424,
    0.1404142021922998,
    0.1428037740185227,
    0.14213661020490484,
    0.14407986287080066],
   'train-size=20': [0.14450296350348163,
    0.1426746880924198,
    0.14491069274888912,
    0.14429609414557337,
    0.14665067919962843],
   'train-size=30': [0.1426595384386224,
    0.14080443411645754,
    0.14370505039540815,
    0.14300003637310485,
    0.14474471689949572],
   'train-size=40': [0.14358837650695316,
    0.14184525546053767,
    0.14439012446034122,
    0.14400369338166547,
    0.14609248296055996],
   'train-size=50': [0.1466788068203859,
    0.1450774673844727,
    0.14779055244486097,
    0.14687545356196416,
    0.14930262490057608]}),
 (10,
  {'train': [0.13393036435708852,
    0.13509505218532125,
    0.13569419501733307,
    0.13620887029887427,
    0.13607639179695147],
   'train-size=10': [0.14479772326811938,
    0.14390442027172573,
    0.14558995567576985,
    0.14412907741858927,
    0.14477131715123082],
   'train-size=20': [0.14724682985057294,
    0.14629254080879855,
    0.147953750382007,
    0.14689845854127911,
    0.14708420697232366],
   'train-size=30': [0.14599474456528544,
    0.14477956074644144,
    0.14657418419357757,
    0.14553233687306794,
    0.14620912242943132],
   'train-size=40': [0.14709516488750216,
    0.1457910924940042,
    0.14807803666507693,
    0.14645512066256833,
    0.14684941672103505],
   'train-size=50': [0.15013343809356153,
    0.1489832644731226,
    0.15099487443205337,
    0.14982264310541288,
    0.15039061681485513]}),
 (10,
  {'train': [0.13548425090784114,
    0.13553823509910545,
    0.13527813432528368,
    0.1348343486644775,
    0.13442387152329788],
   'train-size=10': [0.14415647624663905,
    0.1428938469206783,
    0.1435454548664496,
    0.14190853712424426,
    0.14248590381212636],
   'train-size=20': [0.14661559577978833,
    0.14523825532114001,
    0.14623105001281683,
    0.14428693433882486,
    0.14468104539202972],
   'train-size=30': [0.14550518181542277,
    0.14404400763377337,
    0.14482820548222097,
    0.14268022261455027,
    0.14325049802870818],
   'train-size=40': [0.14650945900611476,
    0.14510096179347642,
    0.1457446280499579,
    0.1435160868814294,
    0.14441596874049012],
   'train-size=50': [0.1495294409318709,
    0.14813849875624752,
    0.14923974577809723,
    0.14702636601639466,
    0.14744865495554157]}),
 (10,
  {'train': [0.13367028256337582,
    0.1345956273603011,
    0.13436364253073532,
    0.1332984246883111,
    0.13278111479744686],
   'train-size=10': [0.14341828730744374,
    0.14234568503960757,
    0.14281670350424,
    0.14046119940532764,
    0.1417679349091691],
   'train-size=20': [0.14575063437223434,
    0.14504997075443538,
    0.1455160688976167,
    0.1426904363111711,
    0.14397000164633067],
   'train-size=30': [0.14484309289656894,
    0.14355385513372823,
    0.14411585578616237,
    0.14138166165687668,
    0.14265521899075576],
   'train-size=40': [0.14539304333673397,
    0.14463492556357047,
    0.1449549244323247,
    0.1418124882268234,
    0.1436246627233398],
   'train-size=50': [0.14852106382309552,
    0.14775141895237104,
    0.14813574917719396,
    0.14503671491230039,
    0.14675469539115127]}),
 (15,
  {'train': [0.09993390365416284,
    0.10076230418784557,
    0.10149920234818806,
    0.10186560303535541,
    0.10211093043140763],
   'train-size=10': [0.10791559685283983,
    0.10802201653870058,
    0.10958199265977027,
    0.1088254216691138,
    0.10996022264302617],
   'train-size=20': [0.11018542829953448,
    0.11018170017591665,
    0.11163447828779757,
    0.11091834894368346,
    0.11215748868777718],
   'train-size=30': [0.10893973907534506,
    0.10902072247904791,
    0.1108490068098189,
    0.10978818506422178,
    0.11092902426148804],
   'train-size=40': [0.10970784524376963,
    0.10977055908928454,
    0.11157881248165184,
    0.11056698113679886,
    0.11184194841435258],
   'train-size=50': [0.11214040639534803,
    0.11213768483467505,
    0.11373909699245238,
    0.11274784467589687,
    0.11413273005418374]}),
 (15,
  {'train': [0.10193593536643936,
    0.10226375931469564,
    0.10237692516267509,
    0.10220854983038387,
    0.10197244756780482],
   'train-size=10': [0.11010545691553976,
    0.10933489751228144,
    0.10988947297905532,
    0.10884637196718806,
    0.10924581921016666],
   'train-size=20': [0.11193793581824907,
    0.11131902185963913,
    0.11213289600023081,
    0.11083182188826547,
    0.11154260769696303],
   'train-size=30': [0.11080398704384414,
    0.11021782289928114,
    0.11090258036700773,
    0.10969736469043813,
    0.11040849899741965],
   'train-size=40': [0.11171734238594351,
    0.1111697989030623,
    0.11186818872958841,
    0.11064525830074096,
    0.11120446509038898],
   'train-size=50': [0.11402694891456147,
    0.11318037921274213,
    0.1141570582985878,
    0.11285083877368712,
    0.11337427333207198]}),
 (15,
  {'train': [0.1012504218889177,
    0.10097350619091995,
    0.10088974057071795,
    0.10063512196980474,
    0.10040859523726435],
   'train-size=10': [0.1084661452283322,
    0.10744276118110603,
    0.10812457577443459,
    0.10658117939888591,
    0.10724918932562143],
   'train-size=20': [0.11074284401158212,
    0.1096958790866422,
    0.11006183250689171,
    0.10870366826863356,
    0.10943438417055237],
   'train-size=30': [0.10947347999038831,
    0.10848620712337359,
    0.1090102323763807,
    0.107349360702743,
    0.10813242176049193],
   'train-size=40': [0.11013021173191742,
    0.10904441270190225,
    0.10960954973395441,
    0.10805122648746195,
    0.10869366728084189],
   'train-size=50': [0.1126745031543181,
    0.11129458531947203,
    0.11212683091281166,
    0.11058190611886307,
    0.11122241217485616]}),
 (15,
  {'train': [0.09999238680269526,
    0.09997189026323515,
    0.1000071811078737,
    0.10001850968413807,
    0.09999106158027582],
   'train-size=10': [0.10730317623262674,
    0.106543114907305,
    0.10736191608536412,
    0.1063554068686257,
    0.10738746423116872],
   'train-size=20': [0.10938336700201035,
    0.10852860501954253,
    0.1096072148689082,
    0.10877082612313016,
    0.10974965422925814],
   'train-size=30': [0.1081636885941868,
    0.10703412107598613,
    0.10848630261673055,
    0.10711907542927164,
    0.10831500820710625],
   'train-size=40': [0.10880869301691862,
    0.10774175193108304,
    0.10924404345347848,
    0.10813210421884564,
    0.10895767262284185],
   'train-size=50': [0.11117241563091815,
    0.11027118721058671,
    0.11156997179061594,
    0.11018815458240644,
    0.11149582382239087]}),
 (15,
  {'train': [0.09980147758670765,
    0.099977056457399,
    0.10043924465350519,
    0.10091136007926663,
    0.10136595518774688],
   'train-size=10': [0.10771776158624971,
    0.10640448384301764,
    0.10840359917828735,
    0.10832236182521766,
    0.10936597314938692],
   'train-size=20': [0.10979867065456551,
    0.10859028228991469,
    0.11062122230798425,
    0.11044051193855178,
    0.11166525737080775],
   'train-size=30': [0.10854786991233557,
    0.10737486299074872,
    0.10956252870005621,
    0.10924551709437035,
    0.11070344145868866],
   'train-size=40': [0.10927042511987015,
    0.10789270520630018,
    0.11032057144272496,
    0.109933374108563,
    0.11145533684273841],
   'train-size=50': [0.11151049863284743,
    0.11048796678512869,
    0.11255550667853423,
    0.11230407078081453,
    0.11380023362351135]}),
 (20,
  {'train': [0.08285519728459624,
    0.08315162625013889,
    0.08324779200581449,
    0.08325266544232302,
    0.08322626763649371],
   'train-size=10': [0.08976589972284478,
    0.08924309132804334,
    0.08996168996246767,
    0.08911411558658304,
    0.09021053662602331],
   'train-size=20': [0.09182708259199707,
    0.09121138497557439,
    0.09184576675925456,
    0.09104555131683886,
    0.09214294901196386],
   'train-size=30': [0.09101416863186258,
    0.09028857904420771,
    0.09103119142458473,
    0.09011169134731024,
    0.09147731749944284],
   'train-size=40': [0.09164133835846269,
    0.09115990513647107,
    0.09185875249160848,
    0.09080263521050064,
    0.09215960693611226],
   'train-size=50': [0.09354736057805343,
    0.09292400880178935,
    0.09364295068760993,
    0.09272285781695809,
    0.09400429734041993]}),
 (20,
  {'train': [0.08322138478221458,
    0.08331044892904069,
    0.08340701830456586,
    0.08339046809576361,
    0.08348059086588727],
   'train-size=10': [0.08980532536204432,
    0.08927978911030461,
    0.08988966446527293,
    0.08927745214650329,
    0.08993580370721682],
   'train-size=20': [0.09168764634031645,
    0.0912389916853166,
    0.09174810234509723,
    0.09126786184562764,
    0.09185810680960266],
   'train-size=30': [0.09079397826547354,
    0.09026409801043256,
    0.09106140300421647,
    0.09041797191324369,
    0.09109824360676215],
   'train-size=40': [0.09164532945609428,
    0.09116263309834709,
    0.09171159439523455,
    0.09112117932715887,
    0.09198820926773717],
   'train-size=50': [0.09351181343827449,
    0.09301113085427755,
    0.0935293221977395,
    0.09298157398129853,
    0.09354947119111746]}),
 (20,
  {'train': [0.08332995783335329,
    0.08361821860872545,
    0.08377230440757985,
    0.08380016516920784,
    0.08373726204450994],
   'train-size=10': [0.08987750364861018,
    0.08950252148886802,
    0.08993647362984403,
    0.08940409039947349,
    0.09004639855153124],
   'train-size=20': [0.09189210674712356,
    0.09147700908738123,
    0.09195837997634645,
    0.0914430685446296,
    0.0920434445142746],
   'train-size=30': [0.0909431711049147,
    0.09047750037320902,
    0.09120796222082327,
    0.09046987171324206,
    0.0912176153399575],
   'train-size=40': [0.09185792169940303,
    0.09138673753805564,
    0.09196926905235774,
    0.09133370818806366,
    0.09195070902646428],
   'train-size=50': [0.09367744645602266,
    0.09309369997239449,
    0.09381777950575654,
    0.09323870590035345,
    0.09385767117352553]}),
 (20,
  {'train': [0.08365824070739349,
    0.08375080313903696,
    0.08377711703729887,
    0.08379149651231284,
    0.08373287667682013],
   'train-size=10': [0.09003291499446815,
    0.08932193211266692,
    0.09011352083212892,
    0.08954520869842718,
    0.08987088858241767],
   'train-size=20': [0.09200076055778585,
    0.0914033110922491,
    0.09220956591233401,
    0.09147813622380646,
    0.09188406862003702],
   'train-size=30': [0.09121579624397654,
    0.09052159805113161,
    0.09139460104871804,
    0.09065007596788272,
    0.09103406262649617],
   'train-size=40': [0.09198017030114859,
    0.09123334446003739,
    0.09222027848304158,
    0.09137823159845782,
    0.09192484061063176],
   'train-size=50': [0.09384377269257962,
    0.09332523904216121,
    0.09404225427080208,
    0.09335624353146889,
    0.09363957011783627]}),
 (20,
  {'train': [0.08335896611295185,
    0.08342781096021556,
    0.08344371209139624,
    0.08335729721879755,
    0.08330970160543787],
   'train-size=10': [0.08935887335051954,
    0.08910575456602472,
    0.08970651513254138,
    0.08897287524502043,
    0.08958013324250638],
   'train-size=20': [0.09140640969427538,
    0.09108369734505532,
    0.0915864970482571,
    0.09094643697772227,
    0.09154096602554053],
   'train-size=30': [0.09044483251554865,
    0.09018225705539676,
    0.09073678697918502,
    0.08984587331053237,
    0.09053323491358421],
   'train-size=40': [0.09126988378628878,
    0.09087393968038156,
    0.09151299350278479,
    0.09071335805134034,
    0.09143554650142159],
   'train-size=50': [0.09319573934649078,
    0.09286080332289279,
    0.09341796550532462,
    0.09265519405754519,
    0.09320938723607802]})]
"""

In [None]:
# For every training run, keep the smallest "train-size=50" loss.
# x / y hold the flat (sample size, best loss) pairs for the scatter plot;
# loss_data groups the same best losses by sample size, in first-seen order.
x = []
y = []
loss_data = {}

for run_size, run_losses in data:
    best_loss = min(run_losses["train-size=50"])
    x.append(run_size)
    y.append(best_loss)
    loss_data.setdefault(run_size, []).append(best_loss)

In [None]:
# Mean best loss per sample size, in loss_data's insertion order.
# x_means[i] is the sample size whose group mean is y_means[i].
x_means = list(loss_data)
y_means = [sum(group) / len(group) for group in loss_data.values()]

In [None]:
from matplotlib import font_manager

# The axis labels are Japanese, so a font with Japanese glyphs is required.
# "Osaka" ships only with macOS; listing common equivalents for other platforms
# keeps the labels readable elsewhere. Candidates are filtered to fonts that
# matplotlib can actually find, which avoids a "font family not found" warning
# for every missing name. "Osaka" stays first, so macOS output is unchanged.
JAPANESE_FONT_CANDIDATES = [
    "Osaka",
    "Hiragino Sans",
    "Yu Gothic",
    "Meiryo",
    "Noto Sans CJK JP",
    "IPAexGothic",
    "TakaoPGothic",
]
installed_font_names = {font.name for font in font_manager.fontManager.ttflist}
japanese_fonts = [
    name for name in JAPANESE_FONT_CANDIDATES if name in installed_font_names
]
# If none of the candidates is installed, fall back to the original setting.
plt.rcParams["font.family"] = japanese_fonts or ["Osaka"]

# Crosses: best loss of each individual run. Red line: mean per sample size.
fig, ax = plt.subplots()
ax.scatter(x, y, marker="x")
ax.plot(x_means, y_means, c="red", marker="o")
# Put a tick on every sample size present in the data instead of a fixed list.
ax.set_xticks(x_means)
ax.set_xlabel("サンプルサイズ")
ax.set_ylabel("予測誤差")
plt.show()

In [None]:
# One-way ANOVA, step 1: total sum of squares over all runs.
n = len(y)   # number of observations (runs)
t = sum(y)   # grand total of the best losses
ct = t * t / n   # correction term T^2 / N
st = sum(value ** 2 for value in y) - ct   # total sum of squares

In [None]:
# One-way ANOVA, step 2: between-group sum of squares,
# sum_i(T_i^2 / n_i) - CT, where T_i is the total of group i.
sa = 0
for group in loss_data.values():
    group_total = sum(group)
    sa = sa + group_total ** 2 / len(group)
sa = sa - ct

In [None]:
# One-way ANOVA, step 3: within-group (error) sum of squares, obtained as the
# total sum of squares minus the between-group sum of squares.
se = st - sa

In [None]:
# Display the total, error (within-group) and between-group sums of squares.
st, se, sa

In [None]:
# Unbiased sample variance V_i of the per-run best losses within each
# sample-size group (groups are visited in loss_data's insertion order).
# Every group needs at least two runs for this to be defined.
vis = []
for losses in loss_data.values():
    mean = sum(losses) / len(losses)
    squared_deviations = sum((loss - mean) ** 2 for loss in losses)
    # Divide by the degrees of freedom (n_i - 1), not by n_i. The pooled error
    # variance is formed as sum((n_i - 1) * V_i) / sum(n_i - 1), which equals
    # SS_within / (N - k) only when V_i is the unbiased estimate. Dividing by
    # n_i shrinks the error variance by (n_i - 1) / n_i and inflates every
    # t statistic derived from it.
    vis.append(squared_deviations / (len(losses) - 1))

In [None]:
# Pooled error variance V_e = sum((n_i - 1) * V_i) / sum(n_i - 1).
# The degrees of freedom are taken from the actual group sizes instead of
# assuming 4 groups of 5 runs (weights of 4, denominator of 16).
# NOTE(review): this pooling is only correct when vis[i] is the unbiased
# (n_i - 1 denominator) variance of group i, listed in loss_data's order —
# confirm against the cell that builds vis.
group_dofs = [len(losses) - 1 for losses in loss_data.values()]
ve = sum(dof * vi for dof, vi in zip(group_dofs, vis)) / sum(group_dofs)
ve

In [None]:
from math import sqrt

# Pairwise t statistics between sample-size levels:
#   t_ij = |mean_i - mean_j| / sqrt(V_e * (1 / n_i + 1 / n_j))
# i and j index the levels in loss_data's insertion order, which is also the
# order of y_means. Group sizes and the number of levels come from the data
# instead of being fixed at 5 runs and 4 levels.
# NOTE(review): these statistics are not adjusted for multiple comparisons;
# choose the critical value accordingly.
group_sizes = [len(losses) for losses in loss_data.values()]
n_levels = len(y_means)
for i in range(n_levels):
    # Start at i + 1 so each unordered pair is reported exactly once.
    for j in range(i + 1, n_levels):
        standard_error = sqrt(ve * (1 / group_sizes[i] + 1 / group_sizes[j]))
        tij = abs((y_means[i] - y_means[j]) / standard_error)
        print(i, j, tij)