From 1a1c63a08b713e8b491e2dd23f0fa92646f5af38 Mon Sep 17 00:00:00 2001 From: NeuralLink <38289341+kartik4949@users.noreply.github.com> Date: Mon, 14 Dec 2020 09:53:12 +0530 Subject: [PATCH] Gan is real...Look what tiny just generated! (#192) * mode collapse solved * info add * delete unnecessary imports * readme --- README.md | 6 ++++++ docs/mnist_by_tinygrad.jpg | Bin 0 -> 15126 bytes examples/mnist_gan.py | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 docs/mnist_by_tinygrad.jpg diff --git a/README.md b/README.md index 98c3a92107a2..3013441ab9f9 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,12 @@ PROTIP: Set "GPU=1" environment variable if you want this to go faster. PROPROTIP: Set "DEBUG=1" environment variable if you want to see why it's slow. +### Gan is real... Generated with pure tinygrad! + +

+ +

+ ### The promise of small tinygrad will always be below 1000 lines. If it isn't, we will revert commits until tinygrad becomes smaller. diff --git a/docs/mnist_by_tinygrad.jpg b/docs/mnist_by_tinygrad.jpg new file mode 100644 index 0000000000000000000000000000000000000000..90cfde24cdada00fb04fa76f7205560156e71b8f GIT binary patch literal 15126 zcmbvQby!n>_y>%iIbbxB5~C56Qc9YEK^P#QG}0v{-3(Ad#ZiJFF$_vzbR!`#x^sXu zBcyXM#FYo)zsAjaQ(5mmO20e0RV{j0@qnGtp2}QBkuoFw!w`vU78BvO}Rf0^&kEe4_kN=q=gXqIV^v zrKP!rPpgRhZ*Nl8gahva#P36uNa=SX$%VwuU}$7)VqH7rW?*c3mSO0h2)fvkP?1k9dOVNl1Am$QaagAyyuYe3IehO!pJM zm9B4|THrw0e- zv3td5;l}YkxUvqws`>5F!+|Yo`k4Xvei4%1-%qP-NZifU)%R(-Z%rz;5bziJ+mJq$ z=9VLLhQo`b?*~BIWc`JB)M5Mks&rStJD8e$N4OGv(BTSL`t@+n;61OR^nogyP=6BM z*^(J5Wyv^-dggS?wpwi{i=rR!e73Mn4h?y`e~7#SJ`|aowqahPDsoKQUbD#KcO`<+ ziq-l2O>p6z)tJ~s+=AaJ%EE%5z`)7$7JaS(+XJaG{oMPD_(fhZ$hM~5jlJAG{7?aj z__x!!Pem!hosO5k}!R;Gww&zKKIiqX^F*I1WOW09II zd_Fin{f~kz15R&BK{KN(XFhyuqFn6g*}Z2rdJ&rUI&LJdqE2vr;yR&l_TA&OYmeLk z!6!fEeC0X5sEl__b2jmh$L$f8W0!^8q02l3CO~5`^X3Xhv>T^&1sLIvSG=}5%!mub zGA>4P=v4mjciP-Z(lI2jd2u7FO|$x==Frb^0rz|NJLM26?%H#8__)^*^X()39_4;w z)P-fy)^zaKz!H3seMC~T+Ca~r&tSoUtClg8&T*z+2$FFB;ilODW|Ghif*3i`T%7)_RWW8#)8<2GQs%z&=NX zYWTWc<>Bu8|26~lVSHNrp2iWjN(IL)JQJ#a8l&t8e!G7ssxlL%T3=UPDbbgHuPkBE z0Q^mE{q-rHFOUVc+!bf>w+mjZvkZm*5eb-y*}nopioF#c|2g$fQuG*)=0+Z!WYby@ z=sB4Zzqjvn!Y5Q9%U@xOj6=Q57uQLAxV(^C=A~2~Z-TqVuK8OHau&j#mpN;hn_mIm z3mI8gfC;dO!Yk%-Lwa#TzO}vN z3h2u~3smG&-?QBjVgmIe$=B1C+G1@(akwi0T`;(?%`B%7 zbVrW#XQ&Q*3Dx;1aae-1T6GawAd@LGW5y#&s{tuv0xeeRJ^}*c-f7QM8JDqv;cQoc zYoWh{eujrRrti75*a0an45d$Cq=B$rz;({8@mIid<;p(l z!&|C+O9$8{(TCn=(`MQ@@+J*fvL@7k<>XNs7DB|}C0~)ayaEVe!W#RCm{qTWdyxOV zyU5kF-Q?LoS&rEwkEUBnyVC!xKJ@^bx4?x}<;KZ^0nI)ft7C9-Fl^;zen}ODn^je1 zaw2o;uXn7h$FS(vtKas0GhHYm&iIqY+E;wLJB5#~fa}FK!*&2}#p=@?IxZ{a(iQ1`aO#XD)iYe#c=SnQ7DlrbXMtjQY z|GUAoOXHueHCjpK?3~^#?oG=UTrJm|G)DJ8m9t!f|GKW6GI;@NO7u4Ec1M7Z#w2wvB|>5Q7{G$ffk1EC5=4FhugyjiX&1PKBxR^4vVOKlC0UWVKkb3 z@ISsoSsw*oyftLB!gvxJ_n^F<)rwJ*Q7(9yvoNM%pEgq*0PpZr7Hj#_{W`X9p%TVx_kUT zfAy1?mm5^)ku@GOr+2P^_93wMkj@90#raF94eX!ceOEUD`heH%ax|@I6=bo8FBkua z<>aw*fp{4;LF;{EHLe;($GZG=(t^WDA~rYofh9FgTjwB~_WP~X>b^-S1h)@IgTYr4BL5HJeM1IW9>D#s`*!Z=#=5kpirZ>$N6)}7u`oppiM6N?!DjN z!xMuCU<<&I zJM;Jzgv8gjeJlON!v5k67s%V`4nKnts)nEx2~_*@ki)NwN=7<(e!qolOrKj|yTsh} zDoi>+{yEvs0+%K>M+I}+-uofO7(vnLF*u1oKklnG*q_sm|J+*!=2gSUV>bpjOg0Kk zHg0B@zqgfG35b_CMsLv_qw0tI z$J_o&&kXSzGcmvh1?y;D)|^Fa!h%!+xSHSCNP`Miw<#^cBVK^J*{~UxuiH(0=K4$0 zp^Qv+gI}|+du(jz3~yUCjv0tVw?Ca0@W}At?y2Xn&3a}6Q`okuCc!<|80=BB^vfc6AcFTiw%)79^puS3t&??t|OjkK)zY zJkd~4l4<#Lt61Dg$nUNci{A$|e_;l#8^_qbzk*eZK^G5W8^+nx3x zgAImjFKqd$l?x$lYVcCAnrm9LWaPL|`)+*2mMTZQ`iNv?Opj^!v9iSEB>vD4UiNCD zuiMX$#B2XuX!9rIy}8d7Pj`Vwb|kqTRXRt%FiiIHhnL3zQ5C-AGrf3`v#8N^KOfEUWs-P%q-(+#IXyXko}&}| z`Zec|eh9(OQB!n~Hv|i8u~hHDhYA(lbfA=?d^dZ*~m& zGQa01Q1-5k+EY0noch~5OJvv(aFqqa_9!>xT91=VHzrlwI zma>-{ut=1I48Sj6mM(OGgs5rnN+BPw;1%J&+OlUfTz^o-sBf2F<}kt}gx+=?;YEyI ze_iTgJ@ZasZIvw`Fo7mr9i=;C8HxpIJ*KD)IhLA~M0}};NXtMCec@~%IEjDo=0n`# zek3bvV{A~Mfl$n@k8!19PTg?RAfxNt9Zd2Gz83*mxCO2re#jRT+f`7{qB)Ig^XR8{ z)QfpL*A2DucMw{&Aw2UEJTwy$$+qG8@_w7Z6m|Lv(|s9p=3>jSLQFzC2bjO@Q^pF{ z2fzaw)LvVJK(4B;&4)DbM~a%P>8qU*2Xv1=n{OeHKtfzK3Hv*8OBHt<#$nZ!)nBx3 zOX8Z$>^F4M4q^|7ENx0mfb}%>;lpKt>xDC$8lEDiY6dkmV|QXWzPu=V1t$ykpQ?~1!f|Ho}grP3RS zptUfC-iJYR!z`HPm90<9TU*;E^&#m_havTsDg$LvkMvu8`ReGlIi2EphPqn2I+nnM zK>gM6s_74Dl?L#1&y?Q};_{%+$x@#5w8vA}iz*_>=nji?Xt>YfP2RMN{|ldWm%kL& zso4_yvu-56>CRZNR#8b&WjkElQJs7eL$+NW@{LKh&nrG)OO^9?EaRM09CJok=0yxm zY{7QiW1n{2jjt9{WStF(Iuc#-^x;`ZY?bF9a}O>5BffXXSHKgU)yLv89mcSwhZ4Nh z^BYcsIlzM@L9xZHA4%itbna>XM28)z=QDK_Ru#$@D#<6*olo+YXcbh)y%Cp=t7tt~ zC5)6*tZ_zLJf6Dnt!I@zq5BJk=jN^50tZgqPwH?J5dE1CR+i8no9Q1b(K25t#TXsV zD1}Iv?Ke->6d|W9-t4DW<~4mQ?zw1QTC`hYa@G*cawnw1?x(5OqErvBdgH23+pRD={dxZiC1aT znv%KOWsG$(ofVP`7ha2tkx{Kv04T}m&aG{5&xs9_T<1GF_cOfCUN&T!*@DG6D4hQu z-e<7e^u+n33&ZHMN>i)b%pC<6UMiB|)IILk_VNO#&iAe@`y!#sY`%Mi|OGm~erYBsWI(*4nD@65%L zXx);}@rI@kr?VV8h=dOJ-(%3peE)C@JlZG}zL+k$=XnO(yH~tw{pLoZ9Ok!{GW<_z zqkzZa$EVk{4yA~Fjf{7-9gpDBql?B|js54>%moZm_EBwXTtB%UdA$;Y?LYnNDr&gW z>XwIP9`8$-bt(}_OmbhWzFhuWXTh>?_Ud1hsET&;ZV3*53+$dNG5;aqJ~Js|IO8bS zo2F%>L&ur^>*C8))-k)d>};~xP99{hYM?DP$Ll;B+`-PJfydjzTS&ccFyNAu(*3iu6Han zHThQn^y~B!f8I)po1M6q!X$TDKyFE<{NHb5q-Kv$27RUbf;FvkF>*6kz|%o1 z^OyLYDTGutPw1lDiVw=9-Jyeh7(+pJ4WhH4as(=LAUX8|)Dgvr8a+h+WnUNbE*6$g zaIcG#S0l`RIj)y)X0Wdv`Z8x*&s?bsTZKXrg2C+9!(?4`fk>qt8~ zRP@HB6JgM|553zrxJ*=RUKM8nM}=e)7-N-w()rykxmig$?qab$XDZfgCAPbLYW#_w zPStM$Dky*!vj$lcU=?o^sdUGeL%t{cOyO9is+vKAG9gdpohSTp&wf$io#@>(ZMEA8 z2&5d*xaq1yBU@uj+Jhs?0N_L#BwB5#`-|m%SUyLM<=gPT-|mfc;!^n6G(mtZcIf;H zIN#Fm8WZp7U+bdlyO)p=xK!9__sG~!s_Tc<-vDg$qAJw4y2TkIYe{kiqzoHG{rQ(C zfLx;Vm*0pz{dIn3w@-^eJ)y_OmbPB}*kdmF-r;)lKqo-T48jPJS$FApt?F;q z#7Its`p-($&px2?5FOlx!JU6Lqo{YZ(U&^&^I_0%#P}#?%E_duc5(yGIm#=}->Sp`-IPTg^QIId^SaUWx*Kvfp)*h)Ka{&KUCS6r7iCGM zIhr{ip-G*w=M|e^CL3LJZ0s!(Am-ymVj*D^`AbE`CQIj4TJeKi=fo!3rYCox(?Z-O zzX5MQSx>Y05m)*=5Mx;KMSPVxt< z>fw>M$9?{RFAEZetjRAg^TCaZQRJPod(Lz(2Jr$IE?Qb~`70od68ptq(&*mY^UUU> zT`pL%25DzOfpdMJRssFpL#4WB^?0PX4kxW@Hl2FSZ_^h>(^OGzv{E(vCAedm0QVx7 z1i~QKlSDP81w-D;FQ^CzUI47@aUy`lJ2g+&ce^6@;@=|FMU#lm?~=KLYAcY3qcex9>SJMy~i`Q>}M<08%NDw#={hvue?JTcvOrYBF4cc9*^ zOrUiZ_NrgsXDl#ZqAk^U*t*w2%FUNP0X{T$bPN+Qf==YOen5cbHTe_P64=GPio-i6 zQ@5IL3l86f4cz|DY5UEp7+fPS;uy_&h&(+10Ia`KPcgQ2iOq=J)4+9I0fH5c0`R4@ z@)leZ9F29KsAy*&*n?F^X#gLME(l8S^oa(JOF{@X)8B`CR3j_e4}CzFu?D9N(}KLz&q;(|NVCHR8p1aEK0&V-FQ0j zA)<+Q(1Y*}xvWaDR68Cl6;L799s>XS3Z^tMgtJGOQRm6qw4T6{J+t`A#(9TOc6Ad? zqb!UPmuZH6eG*+NjRfYx%OsRH5p90VyOb9?8-nQ2e)D3F=GBui+vd-zMK8agQZ(=3 z%rw^Q+>yx2)HpjTviZK{qEu1BAe9jkE-!yRk||{Fm9)g79DU9J{{$FJpz%6Ku6b@x z&z3I}I*j7Dob+a1#`FRbXgR!yi0-J5^u$rV-9&J=z+q#ZO5cy(Ke2ms@~m9r-i``| z*I_kkLN!yP|E0Rhrik9QFU1@MUl;%2^Y|^%IW}@Mtl|gE;@WDQ{;;?!!1Ql6p>2HJ zQwBRz6L<7z9mZ}r?DL-DPRYtC%vKw20@!~5_Q_D>)XlZsL{CL_)>Bys3K=$t@>~0s zBQzFgy*?}#Rc>oKf6dU{oVgj-BdK*_r(*vi0qtT?Hd0ld#!eh9XL#AaJ->7W|EdmB zq#%rxKU2JIXNHnE{zWI0+NVb9#~ANP>r7g??d;4l+4Pc!@#GW|ds{L;^c7Ftotssyd; zs*h@|;vBr_YY^`oS!Vjf?2IeACLx%-@v-}5x^(!;8Tw13@$h%ZIl9C8;ll<7d z0X3XaH6(anm2r@?LW6YddFsc+n|jkw2X5?tF#dI-l710Q8mw*=K6~k0@jJmPiX5TC z=8|M)z}Bc$9R8{OzIq^ClF|^_c2=hH@vLu7aNUiYUGmQx53_!IisVMPvKSD$HTk`W z7#2!4zl9tr_cnAEV?S9WckY%{rIq|q4eu`6R-iytaX2o-M*oc+p!K~LAb>!c03}}b z()|Wl9b4Y-N^XNZiCWw1Fj&rmMM=Xy6V2jYjUI#xgS?m>hOKr#$zLn3GmVLKZXP&- zPx>O1o7y-7{OK|%cw3c|H9~Gk%hTxgs?7r9pou=zPyz(yU;<~Tyb%tHq1pWTVeiu~ zruQ5GT@J@$MFj?_l({!6nHUkl3p#Dw#8S?%z@+j^SFsQS>vHCPowMBXGF?zgi|6YU+yb`{Ct?E=* za7T+;jtOM1nIOp0CBXyk^Vke}0gV(|HCt6>_WiqbTRi-Fmw9+EW7fFt^k?r|i$^g%IIuS#CdHT=9VDoYHu+?7P=en_nmE z9^{R#mQtldbZ53py17`*gClndOkc`X?Wh>U$(8A};a!Aed&1?o-mS;t3r|Wn|Bin? z`%6hpOUQ($3LR(DasCILNAR~`OKFwDYCP2%c_hVUt_^G9>4sp(QT#1IaQS$#_!`9g-~8$3a2PHWtRTaJJa1R zw!tfjipxa6spbC7*i7>KA0>21IMqF`jxzbo!Gh%XfNj~;Gh30I-M*ofCzxlQ1rETjUX-QLu35;Np5xY841 zAa8cWDNlIm;l^VOcI6%1C*43>1g^%tJ5rnIqciJU-rpLD_kW^kNj^%j*rl-GF36@i ziSO#$t<1WsLqhD%jtjrU$nD=NtB~VT+d$$RI!uWx^yDGdHHFAR^(VzmRCw>wZWyxf z6uREbUC%yftU6?M?MVz6Jj|y3p*{XZNBm~E8&5o`ZD{Iw_|`TiH@*G7Tr3OA(jOtP z-{*#@@FCk!r3r;I#Ye%BWo(=t(ix$Z+>mvvp37*+P|w08q+gZ#QwF+nfuT=*qMK(K z$ng4!{|D05vYr29^m3eY^rzj2l}3AcO97qNbFsx>{mK`THUR2cg__?;ITmi{m4E|3ZnnhOVzeU_k zLQaG)Dqq*CDhfDalqXaImqzdoop10Nw9YB1OUM!j+lfjQdXtN5WMQJR!O^u0N^nnY znA*W)BfL7yEvVLuBu?xgY=d8C<1Bt}sB$XIx=}%x5zTr5gYAKFlaFXt-gJUh z2VVN4Emn_E?7eCuYm+h}4mC7PuMAVbr8CLhyGbSw{$Nx*PWb!>&(U0s2mTwX90Pr- zZH(($Y@E1kWTT|1@{HUioV_!;wce+5SLCTzsV(8S6wW!$cxI~cLH$!vjHO2(g6=Sz zPVT?R9&5ZwOnt?v-Q9l6#_^B+axmrS*X28M%awkF}}H^D$uV@ zOv0tOq-xsO*75l-VU&vQKu7?rYX~B-d{O&z-)EEW+x7BJUEOiF`O|f!Y``mg!?4)>hg6JBxGZ2A-bo7B zJb}KceE<8biv;n#OuDG#z07FD)OGkeS-nW%q35>H^vPatSS=Sa0gE^u++Zcf?+;U+ z;JO-Rjb~9|3exqC+b6=ph&+B|Pf@h#damIo1I;gdoWax&775tGy%`Zwqw=VrKZ`n1MWI_YC*mhvNw*A*r+#UtT@P!#3m z=+?P+LDhBfqLJgzo!E%HnSyNECmnFv=sUBF5zm2;OlKFG-3r+4bxqiv;I*<~vzK^T z-%S8we-}2xy=ZD$YbO4{Z}2fN965lbjQcz&VrlC@HF_c>x#4kGSWf_gUF;eRx~GrP7)Wz<~!D1J>UXlGUpj}{0#AkJvh z3oE!89E;kLPGVQEXQWQI>3dSG!76`XTw|bdD$}lK^`CX*BJxX0fB2v3w)v=xsC-{^=dugx#M-w2PJj{!* zh$_(;wQo$A9C#Te`9vOka6Wn4ruHn(M1CMR?EXG7elq)k?csWXt7EHD+wY$O)QPY* z=yOtepO~W2e;8{k)wZ6%>JXUlThq&x_<_O0K^EiZHc;$2T)5&g!6|oXS+9}7MKOgn zPT$xh=g`7JDgCj=Fqy}}{wQ&O{X2;LG3Rcyl17e^z?CH-0ErP!VZg8I3=i7OIsG1X6Y&>`^yN1oE8e$86Zp%jOn44kC9%7q3S}Gr$L49L+MdGS>*Anf6Ly(fpMOFg#0n8?oY%E`M;92LZc|+vc`Tn@QYF90iPY7}as1@@Sk*)GDOu1#i zkL$;0Lb!I8re<$*nOvI|m3#zvZ7*F?$DI+jCQtR?kv9JDc6M{bdskO4kMb9Db-ogt zD~Puv^;Nu|6XL-(kXn>Y+XvfWG3L9b*v&wi{;a%_uzO;7uUrfp344Ok)SIXqZgBcX z=uegDc{E};J62w5omc{;u%c;z+nQ#vO+ZcTFJ4f3-cxMT>;98Z`A*q`epG0#;e#fC zcJojGW9{h+s3)@*Q*4{wCkf}oM!eDtg znVOT^0;(^5A;RCUAir^AWuOVEWdnl z)CT#dQoqB5HS1nbc26`9c}8f%>u36LjE=#?{)(H3idslrzrCDmWXOhQz>&JDtmj`4 z(*2L(yZxDU{a?Iq8Prx+7Opr;SL;dLsZQwh)f&uuy}=ir8&A`Gi!HRdClyV7Ki(rE zgt!DBDki0{%&RcH?UWQTR>-+O|==iap{f_8{9oPXp{eBUvYuii};J88UpmA1_I+ z6e^e+XdkA;_2E`Fo!Z`fN)xoop;e&Ptop8>j9)$3*P(*3JCwhKcORNTlZMzA@b{?v zeTO118V(2llBUDH9POG{Cx}N;<+Y*^n>+`^i(5CUBUdNO73??Q4+4#<oJg!q+ZY9-! z`{i@mAw}*S`#@FjT zA3r_i_kJ?0x!OF)oOj}!QdMUa)5?lau5CG$>F-!ed1>H_KBa{SAuuQ-NyZ)K-j(?qf{mIMi%pK|pE*X|l{rTr?V_c44qhZfFGDcF_G zO>+e&e(JUpg06+hhrIlFmi|EC;Nc~Fmb1_@a;mq62;p=lFji(ih-3r9g)j z=CT{9LUA-H`faX*+{6*i7>lO{ZUV3r`L2Gda>@ zw2Y$csc_U2R%S&Z**9lIc8{1|AiI&t|bGxsGX4XyjRePmY&{w@2&bHO*lvwK$R@=RK3*55){R6Nq zeY;9}Xu(K2uJZm9Yn{qpJGeMjmE-S%lSCS8#w+vtvBcM^MILn9J+%j17TN#Fv~p`o z1N@!X)OS;v2D#FtL0!_AzF+XA%%>}!HBRV}3ZHJD2j6A|&NZeNZa&dEdFR@E7Y&4R ze+qn>!jjNAgCaW?NWPfjnw`A4bQ#@d*krW8bOj7egz;av)dT!(XGCM3YO8z}BQ(Vp zNz=Cu;zKk>l@$yKFC64IC4YCF6tYYd2mZSr9wB>A;W`P;}^RjR^^;9dRSL0$bVU;=Op z6#Zh7l9|xL!&Td8;xX=%LIR7fxu#<}hOam1Xbkn4^~moQ5-Q=wVv09GuGOuO=WpgJ z*5%HJAaWQ~Wq;h6MuYdoYlD}Wg#7xGLgDp`HMmIM$8BMy<}ct3udGIIRPj{x{9Tk0 zm&XvDhtXFz*y7EJ_Zt@x5`;x*pfG>yfub~&5q;x5fe9y*>1%Hx^+0^Cc~t-0pwcx; zxZRQ~keIei>R7T4G}Z=rKK79mxq0d=V#4HW#$|F9uYgiT4?=46xdGRwd*ACt>05IG z_VF`zQvYqH%`U#M=FLwD;lK>bJpdKd2D=k2rgL%Wb$O^tQ!^&htEJ~yY0Y)l*M5VX z(`q9ZxB{Xe?SEOF2Qp9GBvpyrkVDl{60NlsU0?R(cg3zPr5E&`%pI7oA$DTNr+0wF<=U>RjBT#@!gdlb;l5Y+oJtzdlB_l3fD&b zo`-1v&hBdaUy z<%$(Z#v||DsS?Kd>p|qO%KpVRs4e-HyIS_UMR z_>SmM_@zbkZ$MAsq@o7W_B+H~AAd`QXQB?a7Z(B_)sGvgS5)S;t9IJ8_aE+Nbz529 zKmFxrP~a}YoVGeBw}tTU4A>w9WYaPKPYbNXHxn#j{{|k1)X6NYb1rmpvA)BH=DE1s zirhi+C3hqF3a-CE5-`FfV=h|fP z{};*)AmyCzDff+%aq1r1_(H_KODuu)_Z$ITvrp7oObdG!&$7=(>TWY|1Te_VH))63 z7d^ioW|r83&0{TOA%|C5TffP+UGQcCYC1?qHOYm^$P^7_is;S?j-JgD(;r%w(1QBl zN9Oe94MhVFtsKhBfFoGz+j89585CZ73!FEciW`v-ozb`c^cy~Srf*&SfnH!z&$aa; zuK<>*i?DH*NED*Hj5)&QIC0u*K>Dj9^uzpQ-P;bbNm6cE=mh1XQyLetyXBX;AFoQU zi7dIg&=Um!l61rjOCfXHt&P{=h2G|0$rcw3V1*Db7Hj>6i?fUi1>9C6_c zvQLy8z5^f(e*H;9pKoSvDmd6Nu&mYziNZZL61nzj)eS_Sz) ztXJg66ZftF(DLPK0Y*B(=5n3d$d*JbUYXctbpY=EHUyaXk2jRPuV=!6YV80C|Io!<9dq=7y49wR3zH>qs?=oSC6u(3(Ga$ zdzTjV^*ERl&&(mXvv`Y37@Hmg5@mTLD&w_V#Z& z1ztM&BA-$p1Rf5$NvU#z#wr8Eo)w9RUN6J+-B`8S%DR*=I>o%kJdRB)s@{257@h?Q z_Wci8g(^!9ch`Q@eL#h}L`nWPNEkCoJ`!;y%}jGZ$PA`lrmJMQ%Hu;E!6~CEI`qzR zSjXr%i}(*)-dzUUxxX9aDr&6#m;G=uJx*x&dzoxd8><4;qjb{H(bb{q0y;@W6=^1g zrIyeoR8zlLor?1;9i7=w$6rm|H6lmn?e^#9;NZZ--G;6 z79FYXG;<0C6xDa-mlcUTG@xdOasU&j*4@%;Z*5RyU6TVl&HHO<1fSw~fk*P#r0rcp zeVzWFehjk$uk0URqu5o2E>ulIL; zD<0-eIqtMsAKA5^^?n+x48g3qoA&LC$7l%P2-7%HKxyOajGcCqHf*CG*7TMh zt$=;A^`?bLU;NDNLlrZG0(EE=#_Y$;<6$>56hwRelUGXP8Pd3u?mPk{2U*D`xu-Be zE7NzrHVoNPEuZAmyJ2bv-4oN_*fNppqT^NklHIbkl_j@U<7)9Qi>tOiekWk2HzqLQ zZNfD^zjaK2O4_Q*$Z>y*^lQ@)w?A(Mt9Prad~Jw@uq*Jzp9gPw-mszj zhb(>Y&gGM0U!Lr#Bs*RgouoEz1N^rLp`Srm_BP_5Rum6>TSLdQ!*~s0cXyO=#U=mD zpOWbodyIa3h&?VUwsAO-UevNC?`7~B_m@Ay<76vfKe3yjmN z`oQ7nfEc%Ox!&^w3@Oc~y|(*wjDMnXeW1ABDR{1&oGHLx+aHpFP%bobQs@?9-{``A zlp#bVUJ%qd$xHb~B$VN`3I>bRKMTNaf9txAmck^+l8e3xIA=3>AnfcYJ{$c$3VO<&-&kLgU_E+U^G-4!!=pCcaCk?Te?<7rPdc1FjbuSm| zS$CATzWjZo$r0RST+EC zD&Gz3tQxk^arJ-6XQ=sum2G#+Z1MvlSjLy`f+R35!{eM#2|~%*-gc8;=BsqEHtkJa zO>sdxHd(!D%Suj}7uO!qmYB>h?0r*aBvVgnkos0oVss>ow>EP}+&=UEm*n1zizDX9 z;0OLm4~P`@Gve4H-k^}C$5WFB3-~Z62TLr4FmfNw@8clSph~(sf_P~7>22aQo(B4$0;ki0pA_HignJNG%yp_9^rkWs+40D<)CCfr6$10OC9)S;m^n3x zhtYge!=DUCebh$<0`%wwirHRzhVK40U{GWeNPNJr=w@IuMC2Lev&l6{C>Ylu{HQlX z;f2b0&wBF#bN-$iVc z*q42(I@9D8>(GCdkg0SV=HFfY$OkVP579(NqzSn-UEKC0^nNm{`2u~yY>>W>zF0@! zD@*vzWdTsj$@jm9Iu}Y7uPc2Mv~b}hAt4@Te!tEm_nXt-9ejuK5%u#HGtxvx7btlz zt!fAaoGl}36)vs-;dul`D@$(4ENO>Uc?&s4yUXY}y>)@dPTF}_v%gExws7>+1~$S+ zi5eq4`)7Qk|Mu!RqFDuekBb<7Mhi%U&QZv1!Yw9}`;saXC1%wFlo-5b4u5jg%UcHl zb1IdfxLgMUgP0d5+yl`(D>LSV6AaMy(_j;MJ1EoV6zu=SbSrL_Si3DJE-Q)gh*b3J zY$4P0C$_v)Oq28vd8tcIYq+wGsU!8A_$+GaxY}UY!vd3?yT-2HuSajTtSpcG5?@S; V2mAH|{&U2+GwJ`wllyA={{W!mlRy9f literal 0 HcmV?d00001 diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index 438ed3cb42fe..67e4b7ba7f6f 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -9,7 +9,7 @@ from tinygrad.tensor import Tensor, Function, register from extra.utils import get_parameters import tinygrad.optim as optim -from test_mnist import X_train, Y_train +from test_mnist import X_train from torchvision.utils import make_grid, save_image import torch GPU = os.getenv("GPU") is not None @@ -52,9 +52,9 @@ def forward(self, x, train=True): if __name__ == "__main__": generator = LinearGen() discriminator = LinearDisc() - batch_size = 128 + batch_size = 512 k = 1 - epochs = 100 + epochs = 300 generator_params = get_parameters(generator) discriminator_params = get_parameters(discriminator) gen_loss = [] @@ -62,13 +62,13 @@ def forward(self, x, train=True): output_folder = "outputs" os.makedirs(output_folder, exist_ok=True) train_data_size = len(X_train) - ds_noise = Tensor(np.random.uniform(size=(64,128)).astype(np.float32), gpu=GPU, requires_grad=False) + ds_noise = Tensor(np.random.randn(64,128).astype(np.float32), gpu=GPU, requires_grad=False) n_steps = int(train_data_size/batch_size) if GPU: [x.cuda_() for x in generator_params+discriminator_params] # optimizers - optim_g = optim.Adam(generator_params, lr=0.001) - optim_d = optim.Adam(discriminator_params, lr=0.001) + optim_g = optim.Adam(generator_params,lr=0.0002, b1=0.5) # 0.0002 for equilibrium! + optim_d = optim.Adam(discriminator_params,lr=0.0002, b1=0.5) def regularization_l2(model, a=1e-4): #TODO: l2 reg loss @@ -88,7 +88,7 @@ def real_label(bs): def fake_label(bs): y = np.zeros((bs,2), np.float32) - y[range(bs), [0]*bs] = -2.0 + y[range(bs), [0]*bs] = -2.0 # Can we do label smoothin? i.e -2.0 changed to -1.98789. fake_labels = Tensor(y, gpu=GPU) return fake_labels @@ -124,18 +124,18 @@ def train_generator(optimizer, data_fake): print(f"Epoch {epoch} of {epochs}") for i in tqdm(range(n_steps)): image = generator_batch() - for step in range(k): - noise = Tensor(np.random.uniform(size=(batch_size,128)), gpu=GPU) + for step in range(k): # Try with k = 5 or 7. + noise = Tensor(np.random.randn(batch_size,128), gpu=GPU) data_fake = generator.forward(noise).detach() data_real = image loss_d_step = train_discriminator(optim_d, data_real, data_fake) loss_d += loss_d_step - noise = Tensor(np.random.uniform(size=(batch_size,128)), gpu=GPU) + noise = Tensor(np.random.randn(batch_size,128), gpu=GPU) data_fake = generator.forward(noise) loss_g_step = train_generator(optim_g, data_fake) loss_g += loss_g_step fake_images = generator.forward(ds_noise).detach().cpu().data - fake_images = (fake_images.reshape(-1,1,28,28)+ 1)/2 + fake_images = (fake_images.reshape(-1, 1, 28, 28)+ 1) / 2 # 0 - 1 range. fake_images = make_grid(torch.tensor(fake_images)) save_image(fake_images, os.path.join(output_folder,f"image_{epoch}.jpg")) epoch_loss_g = loss_g / n_steps