Rebase and code review.
rxin committed Aug 1, 2015
1 parent 72c5d8e commit 8717f35
Showing 3 changed files with 22 additions and 22 deletions.
File 1 of 3
@@ -22,8 +22,8 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.PlatformDependent
 
 
-abstract class UnsafeRowConcat {
-  def concat(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow
+abstract class UnsafeRowJoiner {
+  def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow
 }


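The new UnsafeRowJoiner contract is generated once per pair of schemas and then reused for every pair of rows. A minimal usage sketch (the schemas and values are hypothetical; the create/join calls mirror the test suites later in this diff):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical schemas for illustration.
val schema1 = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
val schema2 = StructType(Seq(StructField("c", LongType)))

// Build two UnsafeRows via UnsafeProjection, as GenerateUnsafeRowJoinerSuite does.
val row1 = UnsafeProjection.create(schema1)(InternalRow(1, UTF8String.fromString("x")))
val row2 = UnsafeProjection.create(schema2)(InternalRow(10L))

// Generate (and compile) the joiner once, then reuse it per row pair.
val joiner = GenerateUnsafeRowJoiner.create(schema1, schema2)
val joined = joiner.join(row1, row2)  // an UnsafeRow with the fields of schema1 ++ schema2
```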
@@ -38,13 +38,13 @@ abstract class UnsafeRowConcat {
  * 4. Update the offset position (i.e. the upper 32 bits in the fixed length part) for all
  *    variable-length data.
  */
-object GenerateRowConcat extends CodeGenerator[(StructType, StructType), UnsafeRowConcat] {
+object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
 
   def dump(word: Long): String = {
     Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
   }
 
-  override protected def create(in: (StructType, StructType)): UnsafeRowConcat = {
+  override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
     create(in._1, in._2)
   }

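Step 4 in the doc comment above is the subtle part of the algorithm: for a variable-length field, the 8-byte fixed slot packs the data offset in its upper 32 bits and the length in the lower 32, so when the variable-length region moves during concatenation the offset half must be shifted. A hedged sketch of that fix-up (relocateOffset is an illustrative helper, not a name from this commit):

```scala
// Illustrative only: shift the offset (upper 32 bits) of a variable-length
// field's fixed slot by delta bytes, leaving the length (lower 32 bits) intact.
def relocateOffset(fixedSlot: Long, delta: Long): Long = {
  val offset = fixedSlot >>> 32         // where the field's bytes currently live
  val length = fixedSlot & 0xFFFFFFFFL  // how many bytes the field occupies
  ((offset + delta) << 32) | length
}
```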
@@ -55,7 +55,7 @@ object GenerateRowConcat extends CodeGenerator[(StructType, StructType), UnsafeRowConcat] {
     in
   }
 
-  def create(schema1: StructType, schema2: StructType): UnsafeRowConcat = {
+  def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
     val ctx = newCodeGenContext()
     val offset = PlatformDependent.BYTE_ARRAY_OFFSET

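The bitset1Words, bitset2Words, and outputBitsetWords values used below follow from the UnsafeRow layout rule that null bits are stored in whole 64-bit words. A small sketch of that arithmetic (bitsetWords is an illustrative helper):

```scala
// One 64-bit word per 64 fields: 0 fields -> 0 words, 1..64 -> 1 word, 65..128 -> 2.
def bitsetWords(numFields: Int): Int = (numFields + 63) / 64

assert(bitsetWords(0) == 0)
assert(bitsetWords(64) == 1)
assert(bitsetWords(65) == 2)
```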
@@ -136,19 +136,19 @@ object GenerateRowConcat extends CodeGenerator[(StructType, StructType), UnsafeRowConcat] {
        |  buf, $cursor,
        |  ${schema1.size * 8});
      """.stripMargin
-    cursor += schema1.size * 8
 
     // --------------------- copy fixed length portion from row 2 ----------------------- //
+    cursor += schema1.size * 8
     val copyFixedLengthRow2 = s"""
        |// Copy fixed length data for row2
        |PlatformDependent.copyMemory(
        |  obj2, offset2 + ${bitset2Words * 8},
        |  buf, $cursor,
        |  ${schema2.size * 8});
      """.stripMargin
-    cursor += schema2.size * 8
 
     // --------------------- copy variable length portion from row 1 ----------------------- //
+    cursor += schema2.size * 8
     val copyVariableLengthRow1 = s"""
        |// Copy variable length data for row1
        |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
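The two moved statements keep `cursor` correct: it is a running byte offset into the output buffer, and every fixed-length slot is 8 bytes, so each copy advances it by size * 8. A worked example with hypothetical sizes (ignoring the JVM array base offset that the real code adds via PlatformDependent.BYTE_ARRAY_OFFSET):

```scala
// Hypothetical sizes: merged bitset of 1 word, 3 fields in row1, 2 fields in row2.
val outputBitsetWords = 1
var cursor = outputBitsetWords * 8  // 8: fixed-length region starts after the bitset
cursor += 3 * 8                     // 32: past row1's three 8-byte slots
cursor += 2 * 8                     // 48: past row2's two 8-byte slots
// Variable-length bytes for row1, then row2, are appended starting at offset 48.
```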
@@ -196,14 +196,14 @@ object GenerateRowConcat extends CodeGenerator[(StructType, StructType), UnsafeRowConcat] {
     // ------------------------ Finally, put everything together --------------------------- //
     val code = s"""
        |public Object generate($exprType[] exprs) {
-       |  return new SpecificRowConat();
+       |  return new SpecificUnsafeRowJoiner();
        |}
        |
-       |class SpecificRowConat extends ${classOf[UnsafeRowConcat].getName} {
+       |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} {
        |  private byte[] buf = new byte[64];
        |  private UnsafeRow out = new UnsafeRow();
        |
-       |  public UnsafeRow concat(UnsafeRow row1, UnsafeRow row2) {
+       |  public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
        |    // row1: ${schema1.size} fields, $bitset1Words words in bitset
        |    // row2: ${schema2.size}, $bitset2Words words in bitset
        |    // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset
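Two details of the generated class are worth noting. The old template hard-coded the literal SpecificRowConat (itself a typo), while the new one splices the superclass name through classOf[UnsafeRowJoiner].getName, keeping the generated Java source in sync with the Scala-side rename. And the byte[] buf of 64 bytes plus the single UnsafeRow out are reused across join calls; presumably the buffer is grown when a joined row outgrows it. A hedged sketch of that reuse pattern (not the commit's actual generated code):

```scala
// Illustrative only: one buffer lives across calls and is grown on demand,
// so a join allocates only when the combined row is larger than any seen before.
class ReusableBuffer {
  private var buf = new Array[Byte](64)
  def ensureCapacity(needed: Int): Array[Byte] = {
    if (buf.length < needed) buf = new Array[Byte](needed)  // grow; old contents dropped
    buf
  }
}
```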
@@ -232,10 +232,10 @@ object GenerateRowConcat extends CodeGenerator[(StructType, StructType), UnsafeRowConcat] {
        |}
      """.stripMargin
 
-    logDebug(s"code for GenerateRowConcat($schema1, $schema2):\n${CodeFormatter.format(code)}")
+    logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
     // println(CodeFormatter.format(code))
 
     val c = compile(code)
-    c.generate(Array.empty).asInstanceOf[UnsafeRowConcat]
+    c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
   }
 }
File 2 of 3
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
 /**
  * A test suite for the bitset portion of the row concatenation.
  */
-class GenerateRowConcatBitsetSuite extends SparkFunSuite {
+class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
 
   test("bitset concat: boundary size 0, 0") {
     testBitsets(0, 0)
@@ -121,8 +121,8 @@ class GenerateRowConcatBitsetSuite extends SparkFunSuite {
       }
     }
 
-    val concater = GenerateRowConcat.create(schema1, schema2)
-    val output = concater.concat(row1, row2)
+    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+    val output = concater.join(row1, row2)
 
     def dumpDebug(): String = {
       val set1 = Seq.tabulate(numFields1) { i => if (row1.isNullAt(i)) "1" else "0" }
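The property this suite checks, stated directly: null bits from the two inputs appear, in order, in the joined row. A hedged sketch of that check (numFields1 and numFields2 are the field counts of the two test rows, as in the surrounding suite):

```scala
// Null bit i of the output comes from row1 for the first numFields1 fields,
// then from row2 for the remaining numFields2 fields.
for (i <- 0 until numFields1 + numFields2) {
  val expected = if (i < numFields1) row1.isNullAt(i) else row2.isNullAt(i - numFields1)
  assert(output.isNullAt(i) == expected)
}
```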
File 3 of 3
@@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.types._
 
 /**
- * Test suite for [[GenerateRowConcat]].
+ * Test suite for [[GenerateUnsafeRowJoiner]].
  *
- * There is also a separate [[GenerateRowConcatBitsetSuite]] that tests specifically concatenation
- * for the bitset portion, since that is the hardest one to get right.
+ * There is also a separate [[GenerateUnsafeRowJoinerBitsetSuite]] that tests specifically
+ * concatenation for the bitset portion, since that is the hardest one to get right.
  */
-class GenerateRowConcatSuite extends SparkFunSuite {
+class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
 
   private val fixed = Seq(IntegerType)
   private val variable = Seq(IntegerType, StringType)
@@ -89,10 +89,10 @@ class GenerateRowConcatSuite extends SparkFunSuite {
     val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
     val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
 
-    // Run the concater.
+    // Run the joiner.
     val mergedSchema = StructType(schema1 ++ schema2)
-    val concater = GenerateRowConcat.create(schema1, schema2)
-    val output = concater.concat(row1, row2)
+    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+    val output = concater.join(row1, row2)
 
     // Test everything equals ...
     for (i <- mergedSchema.indices) {
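The equality loop opened on the last line above is cut off by the diff view; a hedged sketch of what its body plausibly checks (assuming the generic get(ordinal, dataType) accessor on InternalRow):

```scala
// Field i of the joined row must match field i of row1 while i < schema1.size,
// and field (i - schema1.size) of row2 afterwards.
for (i <- mergedSchema.indices) {
  val dt = mergedSchema(i).dataType
  if (i < schema1.size) {
    assert(output.get(i, dt) == row1.get(i, dt))
  } else {
    assert(output.get(i, dt) == row2.get(i - schema1.size, dt))
  }
}
```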
