Commit da96bdf

Address rxwei@'s comments on Transformer model (#58)

jekbradbury committed Mar 7, 2019
1 parent e12177a commit da96bdf
Showing 5 changed files with 101 additions and 86 deletions.
4 changes: 2 additions & 2 deletions Transformer/Model.swift
@@ -260,7 +260,7 @@ struct Embedding: Differentiable {

@differentiable(wrt: self)
func applied(to input: Tensor<Int32>, in context: Context) -> Tensor<Float> {
return weight.gathering(at: input)
return weight.gathering(atIndices: input)
}
}

@@ -278,7 +278,7 @@ struct TransformerLM {
let positions = (0..<tokens.shape[1]).map {$0 + states[0].key.shape[1]}
let positionsTensor = Tensor<Int32>(shape: [1, tokens.shape[1]], scalars: positions)
var h = embedding.applied(to: tokens, in: context)
h = h + positionalEmbeddings.gathering(at: positionsTensor)
h = h + positionalEmbeddings.gathering(atIndices: positionsTensor)
for i in 0..<layers.count {
h = layers[i].applied(to: h, state: &states[i], in: context)
}
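The renamed `gathering(atIndices:)` is the embedding lookup: each integer index selects one row of the weight matrix, so a `[batch, time]` tensor of token IDs becomes a `[batch, time, embeddingSize]` tensor of vectors. Below is a minimal sketch of exercising the `Embedding` layer above, assuming the types from Model.swift are in scope; the 4x3 weight matrix and the token IDs are made up for illustration and are not taken from the checkpoint.

```swift
import TensorFlow

// Illustrative only: a 4-token vocabulary embedded in 3 dimensions.
let toyWeight = Tensor<Float>(shape: [4, 3], scalars: (0..<12).map(Float.init))
let toyEmbedding = Embedding(weight: toyWeight)

// Two token IDs in a single batch element: shape [1, 2].
let toyTokens = Tensor<Int32>(shape: [1, 2], scalars: [0, 3])
let vectors = toyEmbedding.applied(to: toyTokens, in: Context(learningPhase: .inference))
// `vectors` has shape [1, 2, 3]: one row of `toyWeight` per token ID.
```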
75 changes: 39 additions & 36 deletions Transformer/Operators.swift
@@ -12,47 +12,50 @@ func gelu<Scalar: TensorFlowFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar
/// Performs batched matrix multiplication of two tensors. The last two axes of each tensor
/// are treated as the matrix axes; all others are treated as batch axes.
@differentiable(
wrt: (left, right),
vjp: _vjpBatchedMatmul
where Scalar : Differentiable & FloatingPoint
wrt: (left, right),
vjp: _vjpBatchedMatmul
where Scalar : Differentiable & FloatingPoint
)
public func batchedMatmul<Scalar : Numeric>(
_ left: Tensor<Scalar>,
_ right: Tensor<Scalar>,
adjointLeft: Bool = false,
adjointRight: Bool = false
_ left: Tensor<Scalar>,
_ right: Tensor<Scalar>,
adjointLeft: Bool = false,
adjointRight: Bool = false
) -> Tensor<Scalar> {
return Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight)
return Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight)
}

@usableFromInline
func _vjpBatchedMatmul<Scalar : Differentiable & FloatingPoint>(
_ left: Tensor<Scalar>, _ right: Tensor<Scalar>, adjointLeft: Bool, adjointRight: Bool
_ left: Tensor<Scalar>,
_ right: Tensor<Scalar>,
adjointLeft: Bool,
adjointRight: Bool
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
let value = batchedMatmul(left, right, adjointLeft: adjointLeft, adjointRight: adjointRight)
return (value, { v in
if !adjointLeft {
if !adjointRight {
return (
batchedMatmul(v, right, adjointLeft: false, adjointRight: true),
batchedMatmul(left, v, adjointLeft: true, adjointRight: false))
} else {
return (
batchedMatmul(v, right, adjointLeft: false, adjointRight: false),
batchedMatmul(v, left, adjointLeft: true, adjointRight: false))
}
} else {
if !adjointRight {
return (
batchedMatmul(right, v, adjointLeft: false, adjointRight: true),
batchedMatmul(left, v, adjointLeft: false, adjointRight: false))
} else {
return (
batchedMatmul(right, v, adjointLeft: true, adjointRight: true),
batchedMatmul(v, left, adjointLeft: true, adjointRight: true))
}
}
})
let value = batchedMatmul(left, right, adjointLeft: adjointLeft, adjointRight: adjointRight)
return (value, { v in
if !adjointLeft {
if !adjointRight {
return (
batchedMatmul(v, right, adjointLeft: false, adjointRight: true),
batchedMatmul(left, v, adjointLeft: true, adjointRight: false))
} else {
return (
batchedMatmul(v, right, adjointLeft: false, adjointRight: false),
batchedMatmul(v, left, adjointLeft: true, adjointRight: false))
}
} else {
if !adjointRight {
return (
batchedMatmul(right, v, adjointLeft: false, adjointRight: true),
batchedMatmul(left, v, adjointLeft: false, adjointRight: false))
} else {
return (
batchedMatmul(right, v, adjointLeft: true, adjointRight: true),
batchedMatmul(v, left, adjointLeft: true, adjointRight: true))
}
}
})
}

public extension Tensor
@@ -61,12 +64,12 @@ public extension Tensor
/// same size in the first axis as the scalar count of the index tensor, and the same
/// size in subsequent axes as self.
@differentiable(wrt: self, vjp: _vjpGathering)
func gathering(at indices: Tensor<Int32>) -> Tensor {
func gathering(atIndices indices: Tensor<Int32>) -> Tensor {
return Raw.gather(params: self, indices: indices)
}

func _vjpGathering(at indices: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
let value = gathering(at: indices)
func _vjpGathering(atIndices indices: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
let value = gathering(atIndices: indices)
return (value, { [wShape = shape] seed in
var valuesShape = wShape
valuesShape[0] = indices.scalarCount
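For reference, the pullback above just encodes the standard matrix-calculus identities: for `y = matmul(left, right)` with no adjoints, an incoming gradient `v` maps to `(v · right^T, left^T · v)`, and the other branches are the same identities with the transposes folded into the `adjointLeft`/`adjointRight` flags. The snippet below is a small shape sanity check using the `batchedMatmul` defined above; the tensors and their values are made up for illustration.

```swift
import TensorFlow

// One batch element: a 2x3 matrix times a 3x4 matrix.
let a = Tensor<Float>(shape: [1, 2, 3], scalars: (0..<6).map(Float.init))
let b = Tensor<Float>(shape: [1, 3, 4], scalars: (0..<12).map(Float.init))
let y = batchedMatmul(a, b)                        // shape [1, 2, 4]

// Pretend `v` is the gradient flowing back into `y`.
let v = Tensor<Float>(shape: [1, 2, 4], scalars: Array(repeating: 1, count: 8))
let dA = batchedMatmul(v, b, adjointRight: true)   // v · b^T, shape [1, 2, 3]
let dB = batchedMatmul(a, v, adjointLeft: true)    // a^T · v, shape [1, 3, 4]
print(dA.shape, dB.shape)
```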
100 changes: 56 additions & 44 deletions Transformer/PythonCheckpointReader.swift
@@ -9,108 +9,120 @@ struct Config {
}

extension Config {
init(dict: [String: Int]) {
vocabSize = dict["n_vocab"]!
contextSize = dict["n_ctx"]!
embeddingSize = dict["n_embd"]!
headCount = dict["n_head"]!
layerCount = dict["n_layer"]!
init(dictionary: [String: Int]) {
vocabSize = dictionary["n_vocab"]!
contextSize = dictionary["n_ctx"]!
embeddingSize = dictionary["n_embd"]!
headCount = dictionary["n_head"]!
layerCount = dictionary["n_layer"]!
}
}

let config = Config(dict: [
let config = Config(dictionary: [
"n_vocab": 50257,
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 12])
"n_layer": 12
])

func readTensor<Scalar: TensorFlowScalar>(
from path: String, name: String, scalarType: Scalar.Type
fromPath path: String,
name: String,
scalarType: Scalar.Type
) -> Tensor<Scalar> {
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
return Tensor<Scalar>(handle: #tfop(
return Tensor(handle: #tfop(
"RestoreV2",
StringTensor(path),
StringTensor([name]),
StringTensor([""]),
dtypes$dtype: [Scalar.tensorFlowDataType]))
}

private func checkShapes(_ tensor1: Tensor<Float>, _ tensor2: Tensor<Float>) {
guard tensor1.shape == tensor2.shape else {
print("Shape mismatch: \(tensor1.shape) != \(tensor2.shape)")
fatalError()
}
}

protocol InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String)
init(contentsOfPythonCheckpointFile path: String, scope: String)
}

extension Dense: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
let kernel = readTensor(from: path, name: scope + "/w", scalarType: Scalar.self)
init(contentsOfPythonCheckpointFile path: String, scope: String) {
let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
self.init(
weight: kernel.squeezingShape(at: 0),
bias: readTensor(from: path, name: scope + "/b", scalarType: Scalar.self),
bias: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
activation: identity)
}
init(from path: String, withScope scope: String, activation: String) {
let kernel = readTensor(from: path, name: scope + "/w", scalarType: Scalar.self)
init(contentsOfPythonCheckpointFile path: String, scope: String, activation: String) {
let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
self.init(
weight: kernel.squeezingShape(at: 0),
bias: readTensor(from: path, name: scope + "/b", scalarType: Scalar.self),
bias: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
activation: gelu)
}
}

extension LayerNorm: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
init(contentsOfPythonCheckpointFile path: String, scope: String) {
self.init(
offset: readTensor(from: path, name: scope + "/b", scalarType: Scalar.self),
scale: readTensor(from: path, name: scope + "/g", scalarType: Scalar.self),
offset: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
scale: readTensor(fromPath: path, name: scope + "/g", scalarType: Scalar.self),
axis: -1,
epsilon: Tensor(1e-5))
}
}

extension MultiHeadAttention: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
attention = Attention(size: config.embeddingSize / config.headCount, causal: true, dropProbability: 0.2)
wqkv = TimeDistributed(Dense<Float>(from: path, withScope: scope + "/c_attn"))
wo = TimeDistributed(Dense<Float>(from: path, withScope: scope + "/c_proj"))
init(contentsOfPythonCheckpointFile path: String, scope: String) {
attention = Attention(
size: config.embeddingSize / config.headCount,
causal: true,
dropProbability: 0.2)
wqkv = TimeDistributed(Dense<Float>(
contentsOfPythonCheckpointFile: path,
scope: scope + "/c_attn"))
wo = TimeDistributed(Dense<Float>(
contentsOfPythonCheckpointFile: path,
scope: scope + "/c_proj"))
headCount = Int32(12)
}
}

extension FeedForward: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
dense1 = TimeDistributed(Dense<Float>(from: path, withScope: scope + "/c_fc", activation: "gelu"))
dense2 = TimeDistributed(Dense<Float>(from: path, withScope: scope + "/c_proj"))
init(contentsOfPythonCheckpointFile path: String, scope: String) {
dense1 = TimeDistributed(Dense<Float>(
contentsOfPythonCheckpointFile: path,
scope: scope + "/c_fc", activation: "gelu"))
dense2 = TimeDistributed(Dense<Float>(
contentsOfPythonCheckpointFile: path,
scope: scope + "/c_proj"))
dropout = Dropout(probability: 0.2)
}
}

extension EncoderLayer: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
selfAttention = MultiHeadAttention(from: path, withScope: scope + "/attn")
init(contentsOfPythonCheckpointFile path: String, scope: String) {
selfAttention = MultiHeadAttention(
contentsOfPythonCheckpointFile: path,
scope: scope + "/attn")
selfAttentionDropout = Dropout(probability: 0.2)
selfAttentionNorm = LayerNorm(from: path, withScope: scope + "/ln_1")
feedForward = FeedForward(from: path, withScope: scope + "/mlp")
selfAttentionNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_1")
feedForward = FeedForward(contentsOfPythonCheckpointFile: path, scope: scope + "/mlp")
feedForwardDropout = Dropout(probability: 0.2)
feedForwardNorm = LayerNorm(from: path, withScope: scope + "/ln_2")
feedForwardNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_2")
}
}

extension TransformerLM: InitializableFromPythonCheckpoint {
init(from path: String, withScope scope: String) {
init(contentsOfPythonCheckpointFile path: String, scope: String) {
embedding = Embedding(
weight: readTensor(from: path, name: scope + "/wte", scalarType: Float.self))
positionalEmbeddings = readTensor(from: path, name: scope + "/wpe", scalarType: Float.self)
weight: readTensor(fromPath: path, name: scope + "/wte", scalarType: Float.self))
positionalEmbeddings = readTensor(
fromPath: path,
name: scope + "/wpe",
scalarType: Float.self)
layers = (0..<config.layerCount).map { i in
EncoderLayer(from: path, withScope: scope + "/h\(i)")
EncoderLayer(contentsOfPythonCheckpointFile: path, scope: scope + "/h\(i)")
}
norm = LayerNorm(from: path, withScope: scope + "/ln_f")
norm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_f")
}
}
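One quick way to see the renamed `readTensor(fromPath:name:scalarType:)` in action is to load a single variable from the downloaded checkpoint and inspect its shape. The snippet below is a hypothetical spot-check, assuming the 117M checkpoint from `download_model.sh` is in place; the variable name is composed the same way the `TransformerLM` initializer above composes it (scope `model` plus `/wte`).

```swift
// Hypothetical spot-check of a single checkpoint variable.
let checkpointPath = "models/117M/model.ckpt"
let wte = readTensor(fromPath: checkpointPath, name: "model/wte", scalarType: Float.self)
// For the 117M model this should match the config above: [n_vocab, n_embd] = [50257, 768].
print(wte.shape)
```

The same pattern applies to any scoped variable the initializers read, for example `model/h0/attn/c_attn/w` for the first layer's attention projection.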
4 changes: 2 additions & 2 deletions Transformer/README.md
@@ -1,13 +1,13 @@
# Transformer

This is an implementation of [OpenAI's GPT-2 Transformer language model](github.com/openai/gpt-2) using [Swift for TensorFlow](github.com/tensorflow/swift).
This is an implementation of [OpenAI's GPT-2 Transformer language model](https://github.com/openai/gpt-2) using [Swift for TensorFlow](https://github.com/tensorflow/swift).

In order to run this code, first download a pre-trained checkpoint from OpenAI
using the included `download_model.sh` script. Then, compile using `swiftc`:

```sh
bash download_model.sh
swiftc -O Operators.swift Model.swift PythonCheckpointReader.swift main.swift
swiftc -O -ltensorflow Operators.swift Model.swift PythonCheckpointReader.swift main.swift
```

You can now sample from the model either unconditionally:
4 changes: 2 additions & 2 deletions Transformer/main.swift
@@ -6,7 +6,7 @@ sys.path = sys.path + ["."]
let encoder = Python.import("encoder").get_encoder("117M")

let checkpoint = "models/117M/model.ckpt"
let model = TransformerLM(from: checkpoint, withScope: "model")
let model = TransformerLM(contentsOfPythonCheckpointFile: checkpoint, scope: "model")

let start_token = Int32(encoder.encoder["<|endoftext|>"])!
var tokens = Tensor(shape: [1, 1], scalars: [start_token])
@@ -28,7 +28,7 @@ let empty = Tensor<Float>(
zeros: [Int32(config.headCount), 0, Int32(config.embeddingSize / config.headCount)])
var states = (0..<config.layerCount).map { _ in AttentionContext(key: empty, value: empty) }

for t in 0..<100 {
for _ in 0..<100 {
let logits = model.applied(to: tokens, states: &states, in: Context(learningPhase: .inference))
let (batchSize, timeSteps, vocabSize) = (logits.shape[0], logits.shape[1], logits.shape[2])
let lastLogit = logits.slice(
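As an aside, the `empty` tensor above gives each layer an attention cache whose time axis starts at zero and grows as tokens are sampled. With the config constants from PythonCheckpointReader.swift, the arithmetic works out as in this small sketch:

```swift
// headCount = 12 and embeddingSize = 768, so each head caches 768 / 12 = 64-wide keys and values.
let emptyState = Tensor<Float>(
    zeros: [Int32(config.headCount), 0, Int32(config.embeddingSize / config.headCount)])
print(emptyState.shape)  // [12, 0, 64]: no cached timesteps yet
```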
