Commit

add code
xuyanbo03 committed May 28, 2020
1 parent f299b75 commit 5cf0ce1
Showing 13 changed files with 892 additions and 0 deletions.
94 changes: 94 additions & 0 deletions .gitignore
@@ -0,0 +1,94 @@
# Django #
*.log
*.pot
*.pyc
__pycache__
db.sqlite3
media

# Backup files #
*.bak

# If you are using PyCharm #
.idea
*.iws
/out/

# Python #
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
.pytest_cache/
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# pyenv
.python-version

# celery
celerybeat-schedule.*

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Sublime Text #
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache
*.sublime-workspace
*.sublime-project

# sftp configuration file
sftp-config.json

# Package control specific files
Package Control.last-run
Package Control.ca-list
Package Control.ca-bundle
Package Control.system-ca-bundle
GitHub.sublime-settings
50 changes: 50 additions & 0 deletions DecoderLayer.py
@@ -0,0 +1,50 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from tensorflow import keras
from MultiHeadAttention import MultiHeadAttention
from FFN import feed_forward_network


# Decoder Layer
class DecoderLayer(keras.layers.Layer):
"""
x -> self attention -> add & normalize & dropout -> out1
out1, encoding_outputs -> attention -> add & normalize & dropout -> out2
out2 -> feed_forward -> add & normalize & dropout -> out3
"""

def __init__(self, d_model, num_heads, dff, rate=0.1):
super(DecoderLayer, self).__init__()
self.mha1 = MultiHeadAttention(d_model, num_heads)
self.mha2 = MultiHeadAttention(d_model, num_heads)

self.ffn = feed_forward_network(d_model, dff)

self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
self.layer_norm3 = keras.layers.LayerNormalization(epsilon=1e-6)

self.dropout1 = keras.layers.Dropout(rate)
self.dropout2 = keras.layers.Dropout(rate)
self.dropout3 = keras.layers.Dropout(rate)

def call(self, x, encoding_outputs, training, decoder_mask, encoder_decoder_padding_mask):
        # decoder_mask is built by combining look_ahead_mask and decoder_padding_mask element-wise
# x.shape: (batch_size, target_seq_len, d_model)
# encoding_outputs.shape: (batch_size, input_seq_len, d_model)
# attn1,out1.shape: (batch_size, target_seq_len, d_model)
attn1, attn_weights1 = self.mha1(x, x, x, decoder_mask)
attn1 = self.dropout1(attn1, training=training)
out1 = self.layer_norm1(attn1 + x)

# attn2,out2.shape: (batch_size, target_seq_len, d_model)
attn2, attn_weights2 = self.mha2(
out1, encoding_outputs, encoding_outputs, encoder_decoder_padding_mask)
attn2 = self.dropout2(attn2, training=training)
out2 = self.layer_norm2(attn2 + out1)

# ffn_output, out3.shape: (batch_size, target_seq_len, d_model)
ffn_output = self.ffn(out2)
ffn_output = self.dropout3(ffn_output, training=training)
out3 = self.layer_norm3(ffn_output + out2)
return out3, attn_weights1, attn_weights2
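
A quick shape check for DecoderLayer, as a minimal sketch: the batch and sequence sizes below are arbitrary, and passing None for both masks assumes the scaled_dot_product_attention helper in utils skips masking when the mask is None.

# Sketch: shape check for DecoderLayer (arbitrary sizes; None masks are an assumption
# about utils.scaled_dot_product_attention tolerating mask=None).
import tensorflow as tf
from DecoderLayer import DecoderLayer

sample_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)
x = tf.random.uniform((64, 60, 512))          # (batch_size, target_seq_len, d_model)
enc_out = tf.random.uniform((64, 43, 512))    # (batch_size, input_seq_len, d_model)
out3, w1, w2 = sample_layer(x, enc_out, False, None, None)
print(out3.shape)  # (64, 60, 512)
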
47 changes: 47 additions & 0 deletions DecoderModel.py
@@ -0,0 +1,47 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow import keras
from utils import *
from DecoderLayer import DecoderLayer


# Decoder Model
class DecoderModel(keras.layers.Layer):
def __init__(self, num_layers, target_vocab_size, max_length,
d_model, num_heads, dff, rate=0.1):
super(DecoderModel, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.max_length = max_length

self.embedding = keras.layers.Embedding(target_vocab_size, self.d_model)
# position_embedding.shape: (1, max_length, d_model)
self.position_embedding = get_position_embedding(self.max_length, self.d_model)
self.dropout = keras.layers.Dropout(rate)
self.decoder_layers = [
DecoderLayer(d_model, num_heads, dff, rate) for _ in range(self.num_layers)]

def call(self, x, encoding_outputs, training, decoder_mask, encoder_decoder_padding_mask):
# x.shape: (batch_size, output_seq_len)
output_seq_len = tf.shape(x)[1]
tf.debugging.assert_less_equal(
output_seq_len, self.max_length,
            message='output_seq_len should be less than or equal to self.max_length')

attention_weights = {}

# x.shape: (batch_size, output_seq_len, d_model)
x = self.embedding(x)
        # Scale by sqrt(d_model) so that, after adding position_embedding, the token embeddings carry more weight
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x += self.position_embedding[:, :output_seq_len, :]
x = self.dropout(x, training=training)

# x.shape: (batch_size, output_seq_len, d_model)
for i in range(self.num_layers):
x, att1, att2 = self.decoder_layers[i](
x, encoding_outputs, training, decoder_mask, encoder_decoder_padding_mask)
attention_weights['decoder_layer{}_att1'.format(i + 1)] = att1
attention_weights['decoder_layer{}_att2'.format(i + 1)] = att2

return x, attention_weights
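
For reference, a minimal sketch that exercises DecoderModel end to end; the vocabulary size, lengths, and None masks are placeholders, not values from this commit.

import tensorflow as tf
from DecoderModel import DecoderModel

sample_decoder = DecoderModel(num_layers=2, target_vocab_size=8000, max_length=100,
                              d_model=512, num_heads=8, dff=2048)
tar = tf.random.uniform((64, 36), maxval=8000, dtype=tf.int32)   # target token ids
enc_out = tf.random.uniform((64, 38, 512))                       # stand-in encoder output
dec_out, attn = sample_decoder(tar, enc_out, False, None, None)
print(dec_out.shape)        # (64, 36, 512)
print(sorted(attn.keys()))  # decoder_layer{i}_att1 / decoder_layer{i}_att2
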
36 changes: 36 additions & 0 deletions EncoderLayer.py
@@ -0,0 +1,36 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from tensorflow import keras
from MultiHeadAttention import MultiHeadAttention
from FFN import feed_forward_network


# Encoder Layer
class EncoderLayer(keras.layers.Layer):
"""
x -> self attention -> add & normalize & dropout -> feed_forward -> add & normalize & dropout
"""

def __init__(self, d_model, num_heads, dff, rate=0.1):
super(EncoderLayer, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads)
self.ffn = feed_forward_network(d_model, dff)
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = keras.layers.Dropout(rate)
self.dropout2 = keras.layers.Dropout(rate)

def call(self, x, training, encoder_padding_mask):
# x.shape: (batch_size, seq_len, dim=d_model)
# attn_output.shape: (batch_size, seq_len, d_model)
# out1.shape: (batch_size, seq_len, d_model)
attn_output, _ = self.mha(x, x, x, encoder_padding_mask)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layer_norm1(x + attn_output)

# ffn_output.shape: (batch_size, seq_len, d_model)
# out2.shape: (batch_size, seq_len, d_model)
ffn_output = self.ffn(out1)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = self.layer_norm2(out1 + ffn_output)
return out2
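
A minimal sketch of the sub-layer flow above, with arbitrary sizes; encoder_padding_mask=None assumes the attention utility skips masking in that case.

import tensorflow as tf
from EncoderLayer import EncoderLayer

sample_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)
x = tf.random.uniform((64, 43, 512))        # (batch_size, seq_len, d_model)
print(sample_layer(x, False, None).shape)   # (64, 43, 512)
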
42 changes: 42 additions & 0 deletions EncoderModel.py
@@ -0,0 +1,42 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow import keras
from utils import *
from EncoderLayer import EncoderLayer


class EncoderModel(keras.layers.Layer):
def __init__(self, num_layers, input_vocab_size, max_length,
d_model, num_heads, dff, rate=0.1):
super(EncoderModel, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.max_length = max_length

self.embedding = keras.layers.Embedding(input_vocab_size, self.d_model)
# position_embedding.shape: (1, max_length, d_model)
self.position_embedding = get_position_embedding(self.max_length, self.d_model)
self.dropout = keras.layers.Dropout(rate)
self.encoder_layers = [
EncoderLayer(d_model, num_heads, dff, rate) for _ in range(self.num_layers)]

def call(self, x, training, encoder_padding_mask):
# x.shape: (batch_size, input_seq_len)
input_seq_len = tf.shape(x)[1]
tf.debugging.assert_less_equal(
input_seq_len, self.max_length,
            message='input_seq_len should be less than or equal to self.max_length')

# x.shape: (batch_size, input_seq_len, d_model)
x = self.embedding(x)
        # Scale by sqrt(d_model) so that, after adding position_embedding, the token embeddings carry more weight
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x += self.position_embedding[:, :input_seq_len, :]
x = self.dropout(x, training=training)

# x.shape: (batch_size, input_seq_len, d_model)
for i in range(self.num_layers):
x = self.encoder_layers[i](x, training, encoder_padding_mask)

return x
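
And a corresponding sketch for EncoderModel, again with placeholder sizes and a None padding mask assumed to be tolerated downstream.

import tensorflow as tf
from EncoderModel import EncoderModel

sample_encoder = EncoderModel(num_layers=2, input_vocab_size=8500, max_length=100,
                              d_model=512, num_heads=8, dff=2048)
inp = tf.random.uniform((64, 38), maxval=8500, dtype=tf.int32)
print(sample_encoder(inp, False, None).shape)   # (64, 38, 512)
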
12 changes: 12 additions & 0 deletions FFN.py
@@ -0,0 +1,12 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from tensorflow import keras


# feed_forward
def feed_forward_network(d_model, dff):
# dff: dim of feed forward network.
return keras.Sequential([
keras.layers.Dense(dff, activation='relu'),
keras.layers.Dense(d_model)
])
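
The position-wise network expands each position to dff units and projects it back to d_model, so the output shape matches the input. A quick sanity check (sizes are arbitrary):

import tensorflow as tf
from FFN import feed_forward_network

ffn = feed_forward_network(d_model=512, dff=2048)
print(ffn(tf.random.uniform((64, 50, 512))).shape)   # (64, 50, 512)
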
65 changes: 65 additions & 0 deletions MultiHeadAttention.py
@@ -0,0 +1,65 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow import keras
from utils import *

# Multi-head attention
class MultiHeadAttention(keras.layers.Layer):
"""
    In theory (each head has its own projections):
    x -> Wq0 -> q0
    x -> Wk0 -> k0
    x -> Wv0 -> v0
    In practice, the input is split into q, k, v:
    q -> Wq0 -> q0
    k -> Wk0 -> k0
    v -> Wv0 -> v0
    Implementation trick: project once with a full-size matrix, then split into heads:
    q -> Wq -> Q -> split -> q0, q1, q2, ...
"""

def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert self.d_model % self.num_heads == 0
self.depth = self.d_model // self.num_heads

self.WQ = keras.layers.Dense(self.d_model)
self.WK = keras.layers.Dense(self.d_model)
self.WV = keras.layers.Dense(self.d_model)

self.dense = keras.layers.Dense(self.d_model)

def split_heads(self, x, batch_size):
# x.shape: (batch_size, seq_len, d_model)
# d_model = num_heads * depth
# x -> (batch_size, num_heads, seq_len, depth)
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])

def call(self, q, k, v, mask):
batch_size = tf.shape(q)[0]

q = self.WQ(q) # q.shape: (batch_size, seq_len_q, d_model)
k = self.WK(k) # k.shape: (batch_size, seq_len_k, d_model)
v = self.WV(v) # v.shape: (batch_size, seq_len_v, d_model)

q = self.split_heads(q, batch_size) # q.shape: (batch_size, num_heads, seq_len_q, depth)
k = self.split_heads(k, batch_size) # k.shape: (batch_size, num_heads, seq_len_k, depth)
v = self.split_heads(v, batch_size) # v.shape: (batch_size, num_heads, seq_len_v, depth)

# scaled_attention_outputs.shape: (batch_size, num_heads, seq_len_q, depth)
# attention_weights.shape: (batch_size, num_heads, seq_len_q, seq_len_k)
scaled_attention_outputs, attention_weights = scaled_dot_product_attention(q, k, v, mask)

# scaled_attention_outputs.shape: (batch_size, seq_len_q, num_heads, depth)
scaled_attention_outputs = tf.transpose(scaled_attention_outputs, perm=[0, 2, 1, 3])

# concat_attention.shape: (batch_size, seq_len_q, d_model)
concat_attention = tf.reshape(scaled_attention_outputs, (batch_size, -1, self.d_model))

# output.shape: (batch_size, seq_len_q, d_model)
output = self.dense(concat_attention)
return output, attention_weights
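
A sketch of the project-then-split trick in use, here as self-attention on a single batch; mask=None assumes utils.scaled_dot_product_attention handles a missing mask.

import tensorflow as tf
from MultiHeadAttention import MultiHeadAttention

mha = MultiHeadAttention(d_model=512, num_heads=8)
q = tf.random.uniform((1, 60, 512))
output, weights = mha(q, q, q, None)
print(output.shape)    # (1, 60, 512)
print(weights.shape)   # (1, 8, 60, 60)
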
31 changes: 31 additions & 0 deletions Transformer.py
@@ -0,0 +1,31 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow import keras
from EncoderModel import EncoderModel
from DecoderModel import DecoderModel


class Transformer(keras.Model):
def __init__(self, num_layers, input_vocab_size, target_vocab_size, max_length,
d_model, num_heads, dff, rate=0.1):
super(Transformer, self).__init__()
self.encoder_model = EncoderModel(
num_layers, input_vocab_size, max_length, d_model, num_heads, dff, rate)
self.decoder_model = DecoderModel(
num_layers, target_vocab_size, max_length, d_model, num_heads, dff, rate)
self.final_layer = keras.layers.Dense(target_vocab_size)

def call(self, inp, tar, training, encoder_padding_mask,
decoder_mask, encoder_decoder_padding_mask):
# encoding_outputs.shape: (batch_size, input_seq_len, d_model)
encoding_outputs = self.encoder_model(inp, training, encoder_padding_mask)

# decoding_outputs.shape: (batch_size, output_seq_len, d_model)
decoding_outputs, attention_weights = self.decoder_model(
tar, encoding_outputs, training, decoder_mask, encoder_decoder_padding_mask)

# decoding_outputs.shape: (batch_size, output_seq_len, target_vocab_size)
predictions = self.final_layer(decoding_outputs)

return predictions, attention_weights
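
Finally, a minimal end-to-end sketch of the full Transformer; the vocabulary sizes and sequence lengths are placeholders, and the three None masks assume the attention utility tolerates mask=None.

import tensorflow as tf
from Transformer import Transformer

sample_transformer = Transformer(num_layers=2, input_vocab_size=8500, target_vocab_size=8000,
                                 max_length=120, d_model=512, num_heads=8, dff=2048)
inp = tf.random.uniform((64, 38), maxval=8500, dtype=tf.int32)
tar = tf.random.uniform((64, 36), maxval=8000, dtype=tf.int32)
predictions, attention_weights = sample_transformer(inp, tar, False, None, None, None)
print(predictions.shape)   # (64, 36, 8000)
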