Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale [paper-reading] #56

Open
yoheikikuta opened this issue Jan 7, 2021 · 9 comments

Comments

@yoheikikuta
Copy link
Owner

yoheikikuta commented Jan 7, 2021

論文リンク

https://arxiv.org/abs/2010.11929

公開日(yyyy/mm/dd)

2020/10/22

概要

画像分類を (convolution を使わず) transformer 型モデルで実施した。
Visual Transformer (ViT) というモデル名。
入力は計算量を抑えるために pixel 毎に全て attention を張るのではなく、patch に分け、各 patch を flatten した後に線形変換をしてある次元のベクトルにしたものを全 patch 文並べて embedding を作っている。当然 learnable positonal encoding も足してる。
事前学習は class token の出力に MLP を生やして教師あり学習として実施している(教師なしの方法も提案しているが簡素すぎるからか性能は低い)。
CNN based な過去手法よりも計算コストがだいぶ低い(1/4 TPU days くらいになる)が、ImageNet くらいのサイズのデータサイズだと劣後し、10倍, 100倍としていくことで同等以上の性能を発揮するようになる。

GitHub repository は https://github.com/google-research/vision_transformer

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 7, 2021

vision transformer の存在を知ってたけどちゃんと読んでいなかった。
DALL·E とかも出て vision への transformer 型モデルの適用をどうやっているのか真面目に理解しようと思って読むことにした。
最近 scalability が流行っててその辺にも言及がありそうだったのもモチベーションの一つ。

@yoheikikuta
Copy link
Owner Author

この研究のモチベーション自体は全然面白くなくて、Transformer 型のモデルが NLP でめっちゃ優秀な性能を発揮してるので出来るだけそれをそのまま vision の方に持ってこようというもの。

実験結果とかを考慮しない場合、よくこれが研究のモチベーションになるよなという感じがする(めちゃくちゃ乱暴に言えば他分野で流行ってるモデルをこっちの分野にも適用してみます、という話なので)。

もう少し話を押し進めると、Transformer 型モデルには「なぜか」べき乗則(test loss がモデルサイズもしくはデータサイズのべき分布になっている)が成り立つので、それが vision においても成立するかということを確かめたいというものになっている。

Transformer 型モデルのべき乗則の背後に何があるのか、それを理解することで(現状理解されているべき乗則を超えた)有効なモデルが作れるのか、にはかなり興味がある。これは個人的な興味で、この論文ではそれについて何かを示唆しているわけではない。

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 7, 2021

あとちょっと分かってないのは Transformer 型モデルの計算効率性というのも推しているところ。
Transformer の出自を考えたときに、RNN 型のモデルと比べて並列計算との親和性が高いので(GPU や TPU を用いた並列計算が主流になった現在において)計算効率が良いというのは分かる。

vision は元々 convolution が主流でこれは例えばチャンネル毎に並列計算可能なので、vision の文脈でいうところの計算効率性というのはどういうものなのだろうか。
ResNet とかだと層の数が多くてこれは前の層を計算しない限り後の層が計算できないので、その意味で並列性が犠牲になっているけど、Transformer 型のモデルなら(層ごとのパラメタ数は多いけど)層の数は多くないものの方が計算効率はいいということかな?

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 7, 2021

モデルの話をする。
基本的な構造は Transformer の encoder 部分を使うという感じなので難しいところは特にないが、下記二点だけは扱いが異なる。

  • 入力は全 pixel を flatten したものではなく、ある程度の patch に分けてそれを flatten + 線形変換したものを並べて token としている
    • convolution とか 2D 位置関係を把握するための工夫がいるのか?ということを試していてデータが十分にあればいらないと結論づけている(!)
  • pre-training は教師なしではなく、class token から MLP を生やしてクラス分類をして、通常の画像分類の教師あり画像として学習
    • これじゃあデータを大量に集めて scaling を活かした学習ができないじゃんと思ったけど 、ラベル有り学習データが 300M とかあるのでそれと ImageNet 1k クラス, 21k クラスとかを比較して scaling を見たりしてる
    • fine-tuning は class token 以降の MLP をすげかえて downstream タスクで学習する

モデルの概念図は以下の通り。

Transformer Encoder 部分はどこに Layer Norm を入れるとかで選択肢があるけどここに書いてある情報だけで十分なので割愛。

モデルの input に関して。まず元の画像は $ x ∈ R^{H \times W \times C} $ とする。
ここは色々選択肢がありうる。NLP では token 毎に 768 dim とかの分散表現に embed していた。画像を同じ仕組みに持ってくる場合に、どういう単位を token とするのがいいだろうか?pixel 毎に embedding することも考えられるが、1 pixel を高次元に embedding してもあまり意味がないだろうし、token 数も画像サイズだけ出てくるのでデカすぎて却下だろう。画像全体をまとめて embedding してしまったら self-attention もクソもないのでこれも却下。
ということで元画像をある程度のサイズの patch に分ける: $ x ∈ R^{N \times P \times P \times C} $
P はパッチのサイズで 16 などと選び、これが決まれば $ N = HW / P^2 $ で N は定まり、例えば (224, 224, 3) の画像の時は N = 196 となる。カラー情報は当然そのパッチに付随するものなので、やりたいのは $ P \times P \times C $ を一つの token としてそれを N 個並べるということだ。
この $ P \times P \times C $ に learnable な行列を掛けて hidden dim D (=768) に embed するということで入力 token 列を形成する。
分類のために頭に class token として同じ D 次元の embedding を加え、あとはやはり(simple な)learnable な poisiton embedding を足せば、入力 embedding は完成となる。ここまでできればモデルの話は Transformer とか BERT が分かってれば完全に理解したと言ってよいだろう。

この patch にして embedding する、というのはなかなか良い塩梅なのではないかと感じる。
画像の局所的な関係性は patch で捉えて、大域的な関係性は self-attention で捉えるという形になっている。
CNN と本質的に異なることは明示的な並進不変性などが取り込まれていないことだ。convolution の利点の一つはそこなわけで、これさえも人間による inductive な bias として取り込むのではなく、embedding でデータからよしなに学習してもらえば十分だろう、という考えだろう。

これはなかなか過激だし、面白いし、個人的にはちょっと悲しい。
人間の持つ知識を convlution という形で取り込んでいたわけだが、そういった明示的な規則ではなく多様なデータから自然と学べば十分だろうという態度の表明であるわけなので。

ちなみにこの論文では convolution を使って feature map を作ってからそれに対して embedding をするという hybrid モデルも試している。
具体的には入力に対して ResNet 構造を適用(これらのパラメタも学習対象)して、得られた feature map を embedding して positon embedding を足して同じように Transfomer Encoder に接続するというモデルである。
これはパラメタ数も増えるし計算効率性も悪いし、モデルとしてはイマイチな感じだが、convolution が本質的に重要であるか、という観点で試しているのだと理解している。

ちょっと長くなったので pre-training に関しては別で書くか。

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 7, 2021

自分が Vision Transformer の話を最初に聞いた時に思ったのは「へ〜 pre-training はどうやってるんだろう?」というものだった。
というのも、BERT で有効性が強烈に示された masked language model のような学習方法を良い感じに vision で実施する方法が思い浮かばなかったからだ。

NLP においては分布仮説を信じて masked language model で学習するというのは bet するのに十分有望な戦略だと思うけど、画像の場合にもそれってやれるんだっけ?とか思っていた。
画像の一部が欠損している状態でそれを類推するとかいうのはそこまで突拍子もない感じはしないけど、画像の patch 情報を全部再現するように pixel-wise で二乗誤差使って学習するのはいかにもダメそうだよな〜とか。

結論から言うと、Vision Transformer では unsupervised pre-training はしてなくて、画像分類ラベルがついている画像データで supervised training をしていた!マジかよ!
ということでさっきのモデルの図には class token の出力に MLP head がついていて分類問題を解くようになっていたというわけだった。pre-training は Imagenet 1k, Imagenet 21k, JFT-300M とかで 100万 ~ 3億 枚の画像データを使っている。データセットによってクラス数が違うので MLP layer はそれに応じて作って pre-training して、downstream タスクを解く時はこの MLP 部分をすげ替えて fine-tuning するという学習方法になっている。

これは期待外れ(自分が勝手に期待してただけだが)の結果だが、この論文では unsupervised な方法も試した上で最終的に supervised の方がいいので supervised pre-training をしている。
どのようにやったかというと以下。

  • token のうち 50% を corrupt して [mask] embedding とする
    • NLP よりも corruption ratio がだいぶデカいが、言語よりも画像の方が一部の情報だけで類推できるかなかな?(論文には書いてない)
  • 出力 token から対応する入力 token (これは 16 x 16 x 3 のパッチ)の RGB のカラー平均値(空間平均を取る)を 3-bit に量子化して予測する
    • つまり出力は (R, G, B) = (8通り, 8通り, 8通り) なので 512 クラス分類問題として解いてるっぽい。詳細はあまり書いてなくて、オリジナル実装には unsupervised training の情報は含まれてない。
    • これは筋が悪そう。色が違っても情報としては同じようなケースだってたくさんあるはずだし。
    • 他にもさらに 4 x 4 に区切ってそこでのカラー平均値にすることで 16 * 512 クラス分類問題にするとか、patch 画像の pixel 値予測を回帰問題として解いて L2 loss で学習、とかもやっている
  • 結果としては supervised pre-training と比べて downstream タスク(ImageNet 1k top1 accuracy)で 4% も低かった(pre-training なしで scratch から downstream タスクを学習した場合と比べると 2% 向上とは言っている)

ここはもうちょっと工夫しがいがありそうだけど、なんか論文ではこの方向性にそんなに熱心ではない感じがする。
画像データは分類用のデータセットが大量にあるから supervised でデータ量確保できるということなのだろうか?ここを unsupervised でうまくできるようにするとべき乗則を活かして大規模学習を進める、ということがやりやすいので潤沢な計算資源を持つ人々はやりたくなりそうなもんだけど。


ちなみに PyTorch 実装 https://github.com/lucidrains/vit-pytorch では Bootstrap Your Own Latent (BYOL) https://arxiv.org/abs/2006.07733 という手法で unsupervised 学習ができるようになってる。
チラ見したら、ネットワークを複製して、一方はもう一方のパラメタの exponential moving average とかにして出力が一致するように学習(augmentation とかを使いつつ)していくというものらしい。これだけ聞くとそんなにうまいこといくの?という感じがするが、本題とは外れるので機会があったらまた読んでみることにする。

@yoheikikuta
Copy link
Owner Author

実験結果を見る。

その前に JFT-300M という画像データセットを知らないのでちょっと調べておく。
https://arxiv.org/abs/1707.02968 の論文で使われているもので、18291 クラスの 300 million 画像データセットのこと。そういえばこの論文が出たときに 3 億枚の画像云々という話を聞いたな。
これどこでダウンロードできるのかと思ったら、Google 内部でのみ使っているデータなのね... はぁ... データサイズが重要という話なのにその再現性がない(データ公開されてても計算量的にほとんど誰もできないけど)というのは悲しいっすな...

このデータで pre-training して、各種 downstream タスクを解いた結果が以下。
例えば ViT-H/14 の 14 は patch サイズを意味している(ので例えばこれが大きくなると、入力画像サイズが一定の場合に token 数が減るということになる)。
BiT (Big Transfer) や Nosy Student は先行研究で性能が高いモデル。

精度が驚くほど向上したというものではなく、十分デカいモデルを使えば先行研究で性能の高いモデルと同等以上の性能を発揮できる、というものになっている。

注目すべきは学習が終わるまでに必要な計算時間で、ViT-H/14 で先行研究の 1/4 以下になっている(TPUv3-core-days で 2.5k なので普通の人には手が出せるレベルではないが)。これはバッチサイズとか epoch 数とかを合わせて学習しているので、学習の条件としてはモデル以外は同等になっているはず(他にも当然 weight decay とかも影響するのでその辺もできるだけ同じ条件にしてる)。
それで 4 倍速く学習できるにようになってるので確かに効率的に学習できていると思われるが、その違いがどこから来ているのかは結構非自明(というけ明確には分かってない)。

まず、パラメタ数は ViT-Huge で 6.32 億で、BiT は 9.3 億、EfficientNet-L2 は 4.8 億。
EfficientNet-L2 はパラメタ数は少ないけど学習は遅いのは Noisy Student 使ってて teacher と student の 2 つのネットワークを使うことになるから遅いということかな(なので inference が遅いということは特にないはず)。
BiT はパラメタ数多くて演算回数が多いので遅い?のかと思ったけど、学習時に必要になるトータルの演算数はこの論文からは読み取りづらい...
例えば appendix の Table 6 に以下の表があるが、これは scaling の実験をしたもので、上で載せた結果と同じものがない... BiT つまり ResNet152x4 は表にはないが、とはいえ幅を 2 倍にしても演算数が 4 倍程度という感じなので、ViT-H/14 と遜色なさそう。

そしたら縦にたくさん積んでるので層毎の計算の依存が多いので遅いのかな〜と思ったけど、(学習じゃないけど)inference は同じくらいの速度だった。以下の図の左から ViT-H/14 と R152x4 は同等程度。右からはバッチサイズに関しては ViT-H/14 の方が多く持てそうだけど、学習時にはバッチサイズ揃えて計算してるはずなのでこれは効いてないはず。

ということで自分にはなぜ ViT の学習が他と比べてこれだけ速いのかはイマイチわからなかった。
ここは結構重要だと思う(精度はほぼ同等でしかないので)んだけど、論文でもちゃんと書いてない(と思われる)のでう〜むという感じ。誰か詳しい人に教えてもらいたい。

@yoheikikuta
Copy link
Owner Author

データセットサイズに対する scaling について。
ViT では pre-training は supervised なので、ここでは ImageNet 1k, 21k, JFT-300M のそれぞれで pre-training した場合に downstream タスクの性能がどうなったかで検証している。

結果は以下の図の左。
BiT に関しては色々なモデルサイズで実施し、その上限と下限を示している。点にするとごちゃごちゃになっちゃうからっすな。
ImageNet 1k では明らかに BiT より低性能だしデカいモデルだから良いというわけでもないが、データセットを大きくするとモデルがデカい順に並びそして BiT と同等以上の性能を発揮するようになっている。

右の図も学習サンプル数を増やすことで ViT の大きいモデルは特に 30M-100M で大きく性能を伸ばしていることが分かる。

同じような scaling を横軸を学習時の総演算回数でプロットしたものが以下の結果。

これは hybrid (token を作る前に resnet 構造を入れて feature map を抽出するモデル) もプロットされているが、十分に学習する前には convolution を明示的に入れたものが性能に寄与しているが、十分に学習すると convolution を入れても効果は特にないという結果になっている。
これはなかなか impressive な結果。理屈としては十分データがあって学習すれば確かに convolution の構造を明示的に入れなくてもそれに相当するものを学習できてもよさそうというのは分かるが、それがいままさに示されたというのはやはり驚き。

この図は Table 6 のモデルに対するプロットで R152x4 がないというのはやっぱり納得がいかないが、結果としてはデカいが強いで scaling してそうというものになってる。

論文では ViT は性能が saturate してなさそうだからもっと scaling させたいねって言ってるけど、それは BiT も同じだと思うけどな。
色々実験してるけどなんか結果の見せ方が ViT を無駄に良く見せようとしてる感じがして気になる。

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 8, 2021

positional embedding に関しては

  • 入れないと性能が落ちる
  • 2D で入れるとか relative positional embedding を入れるとかも試したけど、大きく性能は変わらなくてシンプルな 1D が一番良い

という結果。
2D 的に取り扱うとかいかにも有効そうだけど、そんなん気にせず 1D で token 作って並べて突っ込んどけばいいんや!ってことらしい。
これは自分の予想に反してたのでやや驚き。結論としてはシンプルなので分かりやすいけど。

@yoheikikuta
Copy link
Owner Author

yoheikikuta commented Jan 8, 2021

他にも色々やってるけど、自分が一番興味があった部分はチェックできた(理解できてないところはあるけど)ので、だいたいこんなもんかな。

patch 化するところだけちゃんと把握できてれば Transformer 型モデルそのものに近い構造なので、これは確かに text と一緒に使いたくなってくるね。それを実現しておどくべき成果を発揮したのが DALL·E ということだと思うので、次はそこを理解したいね(論文は出てないけど)。

@yoheikikuta yoheikikuta changed the title [2020] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale [2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale [paper-reading] Aug 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant