
Updated Transformer model to use callable. #113

Merged 2 commits on Apr 20, 2019
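In short, this PR replaces every `applied(to:)` method with a `call(_:)` method and every `layer.applied(to: x)` call site with `layer(x)`, following the callable convention. Because the diff below interleaves old and new lines, here is the post-change TimeDistributed layer in one piece, taken from the diff; treat it as a sketch against the Swift for TensorFlow toolchain this repository targets rather than current API.

```swift
import TensorFlow

// TimeDistributed after this PR: `applied(to:)` becomes `call(_:)`, and the
// wrapped Dense is now invoked directly as `dense(reshaped)` instead of
// `dense.applied(to: reshaped)`.
struct TimeDistributed: Layer {
    var dense: Dense<Float>

    init(_ wrapped: Dense<Float>) {
        self.dense = wrapped
    }

    @differentiable(wrt: (self, input))
    func call(_ input: Tensor<Float>) -> Tensor<Float> {
        // Fold the time axis into the batch axis, apply the wrapped Dense,
        // then restore the [batch, time, features] layout.
        let (batchSize, timeSteps, features) = (input.shape[0], input.shape[1], input.shape[2])
        let reshaped = input.reshaped(to: [batchSize * timeSteps, features])
        let output = dense(reshaped)
        let outputFeatures = output.shape[1]
        return output.reshaped(to: [batchSize, timeSteps, outputFeatures])
    }
}
```

Callers change accordingly: for example, `model.applied(to: tokens, states: &states)` in main.swift becomes `model(tokens, states: &states)`.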
77 changes: 37 additions & 40 deletions Transformer/Model.swift
@@ -1,18 +1,18 @@
import TensorFlow
import Python

public struct TimeDistributed: Layer {
struct TimeDistributed: Layer {
var dense: Dense<Float>

public init(_ wrapped: Dense<Float>) {
init(_ wrapped: Dense<Float>) {
self.dense = wrapped
}

@differentiable(wrt: (self, input))
public func applied(to input: Tensor<Float>) -> Tensor<Float> {
func call(_ input: Tensor<Float>) -> Tensor<Float> {
let (batchSize, timeSteps, features) = (input.shape[0], input.shape[1], input.shape[2])
let reshaped = input.reshaped(to: [batchSize * timeSteps, features])
let output = dense.applied(to: reshaped)
let output = dense(reshaped)
let outputFeatures = output.shape[1]
return output.reshaped(to: [batchSize, timeSteps, outputFeatures])
}
@@ -31,7 +31,7 @@ struct FeedForward: Layer {
}

@differentiable(wrt: (self, input))
func applied(to input: Tensor<Float>) -> Tensor<Float> {
func call(_ input: Tensor<Float>) -> Tensor<Float> {
return input.sequenced(through: dense1, dropout, dense2)
}
}
@@ -95,26 +95,27 @@ struct Attention: Layer {
@noDerivative let dropout: Dropout<Float>
@noDerivative let scale: Tensor<Float>
@noDerivative let causal: Bool

init(size: Int, causal: Bool = false, dropProbability: Double) {
scale = Tensor(sqrt(Float(size)))
dropout = Dropout<Float>(probability: dropProbability)
self.causal = causal
}

@differentiable(wrt: (self, input))
func applied(to input: AttentionInput)
-> Tensor<Float> {
func call(_ input: AttentionInput) -> Tensor<Float> {
var dotProducts = batchedMatmul(input.query, input.key, adjointRight: true)
dotProducts = causallyMasked(dotProducts, enable: causal) / scale
return batchedMatmul(dropout.applied(to: softmax(dotProducts)), input.value)
return batchedMatmul(dropout(softmax(dotProducts)), input.value)
}
func applied(to input: AttentionInput, state: inout AttentionContext)
-> Tensor<Float> {

func call(_ input: AttentionInput, state: inout AttentionContext) -> Tensor<Float> {
Inline review comment (Contributor, Author):
The call functions that take state do not have @differentiable; I kept it that way.

state = AttentionContext(
key: state.key.concatenated(with: input.key, alongAxis: 1),
value: state.value.concatenated(with: input.value, alongAxis: 1))
var dotProducts = batchedMatmul(input.query, state.key, adjointRight: true)
dotProducts = causallyMasked(dotProducts, enable: causal) / scale
return batchedMatmul(dropout.applied(to: softmax(dotProducts)), state.value)
return batchedMatmul(dropout(softmax(dotProducts)), state.value)
}
}

@@ -167,34 +168,34 @@ struct MultiHeadAttention: Layer {
var wqkv: TimeDistributed
var wo: TimeDistributed
@noDerivative let headCount: Int

init(attention: Attention, size: Int, headCount: Int) {
self.attention = attention
wqkv = TimeDistributed(Dense<Float>(
inputSize: size, outputSize: size * 3, activation: identity))
wo = TimeDistributed(Dense<Float>(inputSize: size, outputSize: size, activation: identity))
self.headCount = headCount
}

@differentiable(wrt: (self, input))
func applied(to input: Tensor<Float>) -> Tensor<Float> {
let qkvProjected = wqkv.applied(to: input)
func call(_ input: Tensor<Float>) -> Tensor<Float> {
let qkvProjected = wqkv(input)
let qkvSplit = splitHeads(qkvProjected, headCount: headCount)
let attentionInput = splitQKV(qkvSplit)
let outputs = attention.applied(to: attentionInput)
return wo.applied(to: joinHeads(outputs, headCount: headCount))
let outputs = attention(attentionInput)
return wo(joinHeads(outputs, headCount: headCount))
}
func applied(
to input: Tensor<Float>,
state: inout AttentionContext
) -> Tensor<Float> {
let qkvProjected = wqkv.applied(to: input)

func call(_ input: Tensor<Float>, state: inout AttentionContext) -> Tensor<Float> {
let qkvProjected = wqkv(input)
let qkvSplit = splitQKV(qkvProjected)
let attentionInput = makeAttentionInput(
query: splitHeads(qkvSplit.query, headCount: headCount),
key: splitHeads(qkvSplit.key, headCount: headCount),
value: splitHeads(qkvSplit.value, headCount: headCount)
)
let outputs = attention.applied(to: attentionInput, state: &state)
return wo.applied(to: joinHeads(outputs, headCount: headCount))
let outputs = attention(attentionInput, state: &state)
return wo(joinHeads(outputs, headCount: headCount))
}
}

@@ -219,21 +220,18 @@ struct EncoderLayer: Layer {
}

@differentiable(wrt: (self, input))
func applied(to input: Tensor<Float>) -> Tensor<Float> {
func call(_ input: Tensor<Float>) -> Tensor<Float> {
let attended = input + input.sequenced(
through: selfAttentionNorm, selfAttention, selfAttentionDropout)
return attended + attended.sequenced(
through: feedForwardNorm, feedForward, feedForwardDropout)
}

func applied(
to input: Tensor<Float>,
state: inout AttentionContext
) -> Tensor<Float> {
func call(_ input: Tensor<Float>, state: inout AttentionContext) -> Tensor<Float> {
var tmp = input
tmp = selfAttentionNorm.applied(to: tmp)
tmp = selfAttention.applied(to: tmp, state: &state)
tmp = selfAttentionDropout.applied(to: tmp)
tmp = selfAttentionNorm(tmp)
tmp = selfAttention(tmp, state: &state)
tmp = selfAttentionDropout(tmp)
let attended = tmp + input
return attended + attended.sequenced(
through: feedForwardNorm, feedForward, feedForwardDropout)
@@ -242,15 +240,17 @@ struct EncoderLayer: Layer {

struct Embedding: Differentiable {
var weight: Tensor<Float>

init(weight: Tensor<Float>) {
self.weight = weight
}

init(vocabSize: Int, size: Int) {
self.weight = Tensor(randomUniform: [vocabSize, size])
}

@differentiable(wrt: self)
func applied(to input: Tensor<Int32>) -> Tensor<Float> {
func call(_ input: Tensor<Int32>) -> Tensor<Float> {
return weight.gathering(atIndices: input)
}
}
@@ -261,21 +261,18 @@ struct TransformerLM {
var layers: [EncoderLayer]
var norm: LayerNorm<Float>

func applied(
to tokens: Tensor<Int32>,
states: inout [AttentionContext]
) -> Tensor<Float> {
func call(_ tokens: Tensor<Int32>, states: inout [AttentionContext]) -> Tensor<Float> {
let positions = (0..<tokens.shape[1]).map { Int32($0 + states[0].key.shape[1]) }
let positionsTensor = Tensor<Int32>(shape: [1, tokens.shape[1]], scalars: positions)
var h = embedding.applied(to: tokens)
var h = embedding(tokens)
h = h + positionalEmbeddings.gathering(atIndices: positionsTensor)
for i in 0..<layers.count {
h = layers[i].applied(to: h, state: &states[i])
h = layers[i](h, state: &states[i])
}
h = norm.applied(to: h)
let logits = TimeDistributed(
h = norm(h)
let tmp = TimeDistributed(
Dense(weight: embedding.weight.transposed(), bias: Tensor(0.0), activation: identity))
.applied(to: h) // a somewhat hacky way to share weights
let logits = tmp(h) // a somewhat hacky way to share weights
return logits
}
}
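One detail worth calling out in TransformerLM above: the output projection shares weights with the embedding by wrapping embedding.weight.transposed() in a freshly built TimeDistributed Dense on each forward pass (the "somewhat hacky" sharing noted in the code). Condensed, and using only the Dense(weight:bias:activation:) initializer already present in this diff, the tail of call does roughly:

```swift
// Weight tying: reuse the (transposed) embedding matrix as the logit
// projection, so no separate output matrix is learned. Fragment adapted
// from TransformerLM.call above; `h` is the normalized hidden state.
h = norm(h)
let tiedProjection = TimeDistributed(
    Dense(weight: embedding.weight.transposed(),
          bias: Tensor(0.0),
          activation: identity))
let logits = tiedProjection(h)   // shape: [batchSize, timeSteps, vocabSize]
```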
5 changes: 2 additions & 3 deletions Transformer/Operators.swift
@@ -16,7 +16,7 @@ func gelu<Scalar: TensorFlowFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar
vjp: _vjpBatchedMatmul
where Scalar : Differentiable & FloatingPoint
)
public func batchedMatmul<Scalar : Numeric>(
func batchedMatmul<Scalar : Numeric>(
_ left: Tensor<Scalar>,
_ right: Tensor<Scalar>,
adjointLeft: Bool = false,
@@ -58,8 +58,7 @@ func _vjpBatchedMatmul<Scalar : Differentiable & FloatingPoint>(
})
}

public extension Tensor
where Scalar: TensorFlowFloatingPoint {
extension Tensor where Scalar: TensorFlowFloatingPoint {
/// Gathers slices of self at the specified indices along the first axis. The result has the
/// same size in the first axis as the scalar count of the index tensor, and the same
/// size in subsequent axes as self.
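For reference, gathering(atIndices:) in the extension above is what Embedding.call uses to turn token ids into embedding rows. A tiny usage sketch, assuming this Operators.swift extension is in scope (values are illustrative):

```swift
import TensorFlow

// Index rows of a [vocab, features] table by token id, as Embedding does
// with its weight matrix.
let table = Tensor<Float>(shape: [3, 2], scalars: [0, 1, 10, 11, 20, 21])
let ids = Tensor<Int32>(shape: [2], scalars: [2, 0])
let rows = table.gathering(atIndices: ids)   // shape [2, 2]
// rows == [[20.0, 21.0], [0.0, 1.0]]
```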
5 changes: 2 additions & 3 deletions Transformer/main.swift
@@ -24,12 +24,11 @@ if CommandLine.arguments.count == 3 {
tokens = Tensor(shape: [1, tokarr.count], scalars: tokarr)
}

let empty = Tensor<Float>(
zeros: [config.headCount, 0, config.embeddingSize / config.headCount])
let empty = Tensor<Float>(zeros: [config.headCount, 0, config.embeddingSize / config.headCount])
var states = (0..<config.layerCount).map { _ in AttentionContext(key: empty, value: empty) }

for _ in 0..<100 {
let logits = model.applied(to: tokens, states: &states)
let logits = model(tokens, states: &states)
let (batchSize, timeSteps, vocabSize) = (logits.shape[0], logits.shape[1], logits.shape[2])
let lastLogit = logits.slice(
lowerBounds: [0, timeSteps - 1, 0],