# 1. Initialization Code

## Beam

### Install Java by opening the git repository in a terminal, then run: <font color='blue' face="Fixedsys, monospace" size="+2"><br><br>cd /home/jupyter/dataflow<br><br>sudo ./install-java.sh</font>

### Initialize helper functions to run Java inside cells.

In [None]:
# https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/get-started/try-apache-beam-java.ipynb#scrollTo=CgTXBdTsBn1F
# Run and print a shell command.
def run(cmd, progress = True, verbose = False):
  if progress:
      print('>> {}'.format(cmd))
    
  if verbose:
      !{cmd}  # This is magic to run 'cmd' in the shell.
      print('')
  else:
      ! {cmd} > /dev/null 2>&1

import os

# Fetch and unpack the Gradle distribution into /opt unless it is already there.
gradle_version = 'gradle-5.0'
gradle_path = '/opt/' + gradle_version
if not os.path.exists(gradle_path):
    for step in (
        f"wget -q -nc -O gradle.zip https://services.gradle.org/distributions/{gradle_version}-bin.zip",
        'unzip -q -d /opt gradle.zip',
        'rm -f gradle.zip',
    ):
        run(step)

# Gradle is called through its absolute path on purpose, rather than by extending $PATH.
def gradle(args):
    """Run the downloaded Gradle with plain console output and the given argument string."""
    gradle_cmd = f"{gradle_path}/bin/gradle --console=plain {args}"
    run(gradle_cmd)

# Invoke Gradle once. gradle() calls run() with its default verbose=False, so
# only the echoed command line is shown and Gradle's own output is discarded.
gradle('-v')

# Create the source directory for the Java package samples.quickstart.
! mkdir -p src/main/java/samples/quickstart/
print('Done')
        

### Definition for <font color='blue' face="Fixedsys, monospace" size="+2">%%java</font> Python magic cell function.

In [None]:
from IPython.core.magic import register_line_magic, register_cell_magic, register_line_cell_magic
def _java_class_name(cell):
    """Return the name of the first class declared in the Java source *cell*.

    Raises ValueError when no class declaration can be found.
    """
    import re

    # Prefer a real declaration line ("[modifiers] class Name"). Anchoring at
    # the start of a line skips the word "class" inside comments and strings.
    declaration = re.search(
        r'^[ \t]*(?:(?:public|abstract|final|strictfp)\s+)*class\s+([A-Za-z_$][A-Za-z0-9_$]*)',
        cell,
        re.MULTILINE,
    )
    if declaration is not None:
        return declaration.group(1)

    # Fall back to the historical rule: the text between the first 'class '
    # and the first ' {' in the cell.
    start = cell.find('class ')
    end = cell.find(' {')
    if start < 0 or end <= start + 6:
        raise ValueError('%%java: no class declaration found in the cell')
    return cell[start+6:end]


@register_cell_magic
def java(line, cell):
    """
    Written by Joseph Gagliardo Jr.
    joegagliardo@gmail.com
    2021-12-22

    Cell magic that builds and runs the Java class written in the cell.

    The cell is saved as src/main/java/samples/quickstart/<ClassName>.java, a
    build.gradle that names that class as the main class is generated, and
    Gradle's build and runShadow tasks are run. Afterwards the files matching
    /tmp/outputs* are printed.

    line: option words, matched case-insensitively as substrings:
        noprogress - do not echo the shell commands being run
        verbose    - show the output of the build and run commands
        nooutput   - do not print /tmp/outputs* at the end
    cell: the Java source; its first class declaration names the source file,
        the main class and the jar.

    Raises ValueError when the cell contains no class declaration.
    """
    gradle_text = """
plugins {
  // id 'idea'     // Uncomment for IntelliJ IDE
  // id 'eclipse'  // Uncomment for Eclipse IDE

  // Apply java plugin and make it a runnable application.
  id 'java'
  id 'application'

  // 'shadow' allows us to embed all the dependencies into a fat jar.
  id 'com.github.johnrengelman.shadow' version '4.0.3'
}

// This is the path of the main class, stored within ./src/main/java/
mainClassName = 'samples.quickstart.{class_name}'

// Declare the sources from which to fetch dependencies.
repositories {
  mavenCentral()
}

// Java version compatibility.
sourceCompatibility = 1.8
targetCompatibility = 1.8

// Use the latest Apache Beam major version 2.
// You can also lock into a minor version like '2.9.+'.
ext.apacheBeamVersion = '2.+'

// Declare the dependencies of the project.
dependencies {
  shadow "org.apache.beam:beam-sdks-java-core:$apacheBeamVersion"

  runtime "org.apache.beam:beam-runners-direct-java:$apacheBeamVersion"
  runtime "org.apache.beam:beam-sdks-java-extensions-sql:$apacheBeamVersion"
  runtime "com.google.auto.value:auto-value-annotations:1.6"
  runtime "com.google.code.gson:gson:2.8.8"
  compile "org.apache.beam:beam-sdks-java-extensions-join-library:$apacheBeamVersion"
  runtime "org.slf4j:slf4j-api:1.+"
  runtime "org.slf4j:slf4j-jdk14:1.+"

  annotationProcessor "com.google.auto.value:auto-value:1.6"

  testCompile "junit:junit:4.+"
}

// Configure 'shadowJar' instead of 'jar' to set up the fat jar.
shadowJar {
  zip64 true
  baseName = '{class_name}' // Name of the fat jar file.
  classifier = null       // Set to null, otherwise 'shadow' appends a '-all' to the jar file name.
  manifest {
    attributes('Main-Class': mainClassName)  // Specify where the main class resides.
  }
}
"""
    class_name = _java_class_name(cell)

    options = line.lower()
    progress = 'noprogress' not in options
    verbose = 'verbose' in options
    output = 'nooutput' not in options

    # Remove the sources, jars and outputs left behind by the previous cell so
    # that only the current class is built and only its output is shown.
    # The first two stay quiet (a missing file is not an error worth showing)
    # but honour the noprogress option like every other command.
    run('rm src/main/java/samples/quickstart/*.java', progress=progress)
    run('rm build/libs/*.jar', progress=progress)
    run('rm -rf /tmp/outputs*', progress=progress, verbose=verbose)

    with open('build.gradle', 'w') as f:
        f.write(gradle_text.replace('{class_name}', class_name))

    with open(f'src/main/java/samples/quickstart/{class_name}.java', 'w') as f:
        f.write(cell)

    # Build the project, list the produced jars, then run the main class.
    run(f"{gradle_path}/bin/gradle --console=plain build", progress=progress, verbose=verbose)
    run('ls -lh build/libs/', progress=progress, verbose=verbose)
    run(f"{gradle_path}/bin/gradle --console=plain runShadow", progress=progress, verbose=verbose)

    if output:
        run('cat /tmp/outputs*', progress=False, verbose=True)

    print('Done')

print('Done')

In [None]:
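# Additional Gradle dependencies that are sometimes needed; add them inside the dependencies block of the build.gradle text in the %%java magic.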
  compile "org.apache.beam:beam-sdks-java-extensions-google-cloud-platform-core:2.22.0"
  compile "org.apache.beam:beam-runners-google-cloud-dataflow-java:2.22.0"
  compile "org.apache.beam:beam-sdks-java-io-google-cloud-platform:2.22.0"



## Spark

### Install a Spark docker using the following commands.

In [None]:
# Pull the bitnami/spark image, create the Docker network "spark_network" and
# start a detached container named "spark" on it with SPARK_MODE=master.
! docker pull bitnami/spark && \
docker network create spark_network && \
docker run -d --name spark --network=spark_network -e SPARK_MODE=master bitnami/spark
# NOTE(review): the link name ends in "libtinfor.so.6"; this looks like a typo
# for "libtinfo.so.6" - confirm the intended name before relying on this symlink.
! ln -s /opt/conda/lib/libtinfo.so /opt/conda/lib/libtinfor.so.6
print('Done')

### Install pyspark.

In [None]:
import pip

def install(package):
    """Install *package* with pip into the environment of the running kernel.

    pip is not supported as an importable library, and ``import pip`` alone
    does not load ``pip._internal``, so pip is run as ``python -m pip`` in a
    subprocess of the current interpreter instead.

    package: requirement string handed to ``pip install``.

    Returns None. pip's combined output is printed; a failed install is
    reported with a message and does not raise.
    """
    import subprocess
    import sys

    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', package],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    print(result.stdout)
    if result.returncode != 0:
        print(f'pip install {package} failed with exit status {result.returncode}')

# Install the pyspark package using the install() helper defined in this cell.
install('pyspark')
        
print('Done')

### Initialize the Spark context variables.

In [None]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, SQLContext
from pyspark.sql.types import *

def initspark(appname = "Notebook", servername = "local[*]"):
    """Create, or reuse, the Spark entry points for this notebook.

    appname: application name set on the SparkConf and the SparkSession builder.
    servername: master URL passed to SparkConf.setMaster.

    Returns the tuple (SparkContext, SparkSession, SparkConf).
    """
    print('initializing pyspark')
    conf = SparkConf().setAppName(appname).setMaster(servername)
    # getOrCreate reuses a context that is already running, so executing this
    # cell a second time does not fail because a SparkContext already exists.
    sc = SparkContext.getOrCreate(conf=conf)
    spark = SparkSession.builder.appName(appname).enableHiveSupport().getOrCreate()
    sc.setLogLevel("ERROR")
    print('pyspark initialized')
    return sc, spark, conf

# Keep the context, session and configuration as notebook-level names so that
# later cells can use them.
sc, spark, conf = initspark()
print(sc, spark)
print('Done')

#

***

# 2. <font color='blue' face="Fixedsys, monospace" size="+2">Create</font> allows you to upload data into a <font color='green' size="+2">PCollection</font>.

### Get the path of the Python interpreter we are running.

In [None]:
import sys
# Show which Python executable is running this notebook kernel.
print(sys.executable)

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### Non Beam example of applying a <font color='blue' face="Fixedsys, monospace" size="+2">	map</font> function to a collection. 

In [None]:
x = ['one', 'two', 'three', 'four']

# Plain loop: transform and print one element at a time.
for e in x:
    print(str.upper(e))

# The same idea with the built-in map(): apply str.title to every element.
print([*map(str.title, x)])

### Simple transformation, turn the local collection into a <font color='green' size="+2">PCollection</font> and apply a <font color='blue' face="Fixedsys, monospace" size="+2">Map</font> <font color='green' size="+2">PTransform</font> on it.

In [None]:
import apache_beam as beam

# Build the pipeline one step at a time, keeping every intermediate PCollection.
p = beam.Pipeline()
lines = p | beam.Create(['one', 'two', 'three', 'four'])
lines2 = lines | beam.Map(str.title)
lines2 | beam.Map(print)
# Nothing executes until run() is called; wait_until_finish() then blocks until
# the pipeline has completed, which is what a with block does on exit.
p.run().wait_until_finish()



### Usually in python we use a <font color='blue' face="Fixedsys, monospace" size="+2">with</font> block instead.

In [None]:
# No explicit p.run() is needed here: leaving the with block runs the pipeline.
with beam.Pipeline() as p:
    lines = p | beam.Create(['one', 'two', 'three', 'four']) | beam.Map(str.title) | beam.Map(print)

# lines is a PCollection object
print('lines = ', lines)


### Simple transformation using a user defined function.

In [None]:
import apache_beam as beam

def title(x):
    """Return *x* in title case followed by an asterisk."""
    return f"{x.title()}*"

with beam.Pipeline() as p:
    # Create the PCollection, then map the user defined function over it.
    lines = p | beam.Create(['one', 'two', 'three', 'four'])
    lines = lines | beam.Map(title)
    lines | beam.Map(print)


### Simple transformation using a <font color='blue' face="Fixedsys, monospace" size="+2">lambda</font> instead of a built in function.

In [None]:
import apache_beam as beam

with beam.Pipeline() as p:
    # Same transformation as a named function would do, written inline as a lambda.
    lines = p | beam.Create(['one', 'two', 'three', 'four'])
    lines = lines | beam.Map(lambda word: word.title() + '*')
    lines | beam.Map(print)


### The pipe <font color='blue' face="Fixedsys, monospace" size="+2">|</font> is actually just an operator overload to call the <font color='blue' face="Fixedsys, monospace" size="+2">apply</font> method of the pipeline. You would never do this in Python, but it helps to understand what is going on under the hood.

In [None]:
import apache_beam as beam

with beam.Pipeline() as p:
    # Spell out .apply() for the steps after Create instead of using the pipe operator.
    lines = (p | beam.Create(['one', 'two', 'three', 'four'])).apply(beam.Map(str.title))
    lines.apply(beam.Map(print))


### The Spark equivalent would be to upload a local Python <font color='blue' face="Fixedsys, monospace" size="+2">list</font> into a Spark <font color='green' size="+2">RDD</font> and do a simple transformation.

In [None]:
# Distribute the local list as an RDD. The transformation step is disabled;
# enable the commented line below to title-case every element.
rdd1 = sc.parallelize(['one', 'two', 'three', 'four'])
# rdd1 = rdd1.map(str.title)
rdd1.collect()


## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### Simple transformation using a <font color='green' size="+2">lambda</font>.


In [None]:
%%java verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.io.TextIO;

import java.util.*;

public class Create1 {
    // Builds a PCollection from four in-memory words, upper-cases each one with
    // a lambda and writes the result using the prefix /tmp/outputs.
    public static void main(String[] args) {
        final String outputsPrefix = "/tmp/outputs";
        final Pipeline pipeline = Pipeline.create();

        PCollection<String> words = pipeline.apply(Create.of("one", "two", "three", "four"));

        // into(...) states the output element type for the lambda passed to via(...).
        PCollection<String> upperCased = words.apply(
            MapElements.into(TypeDescriptors.strings()).via((String word) -> word.toUpperCase()));

        upperCased.apply(TextIO.write().to(outputsPrefix));

        pipeline.run().waitUntilFinish();
    }
}


### Simple transformation using <font color='blue' face="Fixedsys, monospace" size="+2">SimpleFunction</font> instead of <font color='green' size="+2">lambda</font>.


In [None]:
%%java verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.io.TextIO;
import java.util.*;

public class Create2 {
    // Same pipeline as the lambda version, but the mapping is an anonymous
    // SimpleFunction passed to MapElements.via(...).
    public static void main(String[] args) {
        final String outputsPrefix = "/tmp/outputs";
        final Pipeline pipeline = Pipeline.create();

        PCollection<String> words = pipeline.apply(Create.of("one", "two", "three", "four"));

        PCollection<String> upperCased = words.apply(MapElements.via(
            new SimpleFunction<String, String>() {
                @Override
                public String apply(String word) {
                    return word.toUpperCase();
                }
            }));

        upperCased.apply("Write", TextIO.write().to(outputsPrefix));

        pipeline.run().waitUntilFinish();
    }
}


### Java simple transformation using <font color='blue' face="Fixedsys, monospace" size="+2">SimpleFunction</font> to wrap a User Defined Function.


In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.io.TextIO;
import java.util.*;

public class Create3 {
    // The anonymous SimpleFunction only delegates to the static helper upper(),
    // showing how an existing user defined function can be wrapped.
    public static void main(String[] args) {
        final String outputsPrefix = "/tmp/outputs";
        final Pipeline pipeline = Pipeline.create();

        PCollection<String> words = pipeline.apply(Create.of("one", "two", "three", "four"));

        PCollection<String> upperCased = words.apply(MapElements.via(
            new SimpleFunction<String, String>() {
                @Override
                public String apply(String word) {
                    return upper(word);
                }
            }));

        upperCased.apply("Write", TextIO.write().to(outputsPrefix));

        pipeline.run().waitUntilFinish();
    }

    // User defined function: returns the upper-case form of the given text.
    public static String upper(String line) {
        String upperCased = line.toUpperCase();
        return upperCased;
    }
}


#

***

# 3. <font color='blue' face="Fixedsys, monospace" size="+2">ReadFromText</font> allows you to read a text file into a <font color='green' size="+2">PCollection</font>.

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### It's a good idea to start naming the steps for debugging and monitoring later. Names must be unique in the pipeline.

In [None]:
! rm /tmp/outputs*

import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText

def parse_tuple(x):
    """Parse a 'regionid,regionname' line into (int id, upper-case name).

    Raises ValueError when the line does not have exactly two comma-separated
    fields or the id is not an integer.
    """
    fields = x.split(',')
    regionid, regionname = fields
    return int(regionid), regionname.upper()
    
regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'
with beam.Pipeline() as p:
    regions = (
        p | 'Read' >> ReadFromText(regionsfilename)
          # | 'Parse Tuple' >> beam.Map(parse_tuple)
          # | 'Parse' >> beam.Map(lambda x : x.split(','))
          # | 'Transform' >> beam.Map(lambda x : (int(x[0]), x[1].upper()))
    )
    #regions | 'Write' >> WriteToText('/tmp/outputs')
    regions | 'Print' >> beam.Map(print)

! cat /tmp/outputs*

### Read from CSV and use <font color='green' size="+2">ParDo</font>.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText

class RegionParseTuple(beam.DoFn):
    """DoFn that turns a 'regionid,regionname' line into an (int, str) tuple."""

    def process(self, element: str):
        fields = element.split(',')
        regionid, regionname = fields
        # process() has to produce an iterable of outputs: either yield each
        # one, as here, or return a list such as [(int(regionid), regionname)].
        yield int(regionid), regionname
        # A transformation can be folded into this step instead of a separate one:
        # yield (int(regionid), regionname.upper())

class RegionParseDict(beam.DoFn):
    """DoFn that turns a 'regionid,regionname' line into a dict with those two keys."""

    def process(self, element: str):
        raw_id, name = element.split(',')
        yield dict(regionid=int(raw_id), regionname=name)

regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'

with beam.Pipeline() as p:
    regions = p | 'Read' >> ReadFromText(regionsfilename)
    # Tuple flavour: parse each line, then keep the regions whose id is below 3.
    regions = regions | 'Parse' >> beam.ParDo(RegionParseTuple())
    regions = regions | 'Filter' >> beam.Filter(lambda region: region[0] < 3)
    # Dict flavour of the same two steps:
    # regions = regions | 'Parse' >> beam.ParDo(RegionParseDict())
    # regions = regions | 'Filter' >> beam.Filter(lambda x : x['regionid'] < 3)
    # regions | 'Write' >> WriteToText('regions.out')
    regions | 'Print' >> beam.Map(print)



### Read from CSV and use <font color='green' size="+2">ParDo</font> with <font color='green' size="+2">Row</font> object.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText

class RegionParseRow(beam.DoFn):
    """DoFn that turns a 'regionid,regionname' line into a beam.Row with named fields."""

    def process(self, element: str):
        raw_id, name = element.split(',')
        yield beam.Row(regionid=int(raw_id), regionname=name)


regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'

with beam.Pipeline() as p:
    regions = p | 'Read' >> ReadFromText(regionsfilename)
    regions = regions | 'Parse' >> beam.ParDo(RegionParseRow())
    # Row fields are read as attributes rather than by position or key.
    regions = regions | 'Filter' >> beam.Filter(lambda row: row.regionid < 3)
    regions | 'Print' >> beam.Map(print)



## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### Read from CSV and use <font color='blue' face="Fixedsys, monospace" size="+2">Map</font> with <font color='green' >lambda</font>.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.io.TextIO;

public class ReadRegions1 {
    // Reads the regions CSV line by line, upper-cases every line with a lambda
    // and writes the result using the prefix /tmp/outputs.
    public static void main(String[] args) {
        final String regionsInputFileName = "datasets/northwind/CSV/regions/regions.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline pipeline = Pipeline.create();

        PCollection<String> lines = pipeline.apply("Read", TextIO.read().from(regionsInputFileName));
        PCollection<String> regions = lines.apply(
            "Parse",
            MapElements.into(TypeDescriptors.strings()).via((String line) -> line.toUpperCase()));

        regions.apply(TextIO.write().to(outputsPrefix));
        pipeline.run().waitUntilFinish();
    }
}


### <font color='green' size="+2">ParDo</font> using a defined class.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;

public class ReadRegions3 {
    // Reads the regions CSV and runs every line through the named DoFn AddStar.
    public static void main(String[] args) {
        final String regionsInputFileName = "datasets/northwind/CSV/regions/regions.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline pipeline = Pipeline.create();

        PCollection<String> lines = pipeline.apply("Read", TextIO.read().from(regionsInputFileName));
        PCollection<String> regions = lines.apply("Parse", ParDo.of(new AddStar()));

        regions.apply(TextIO.write().to(outputsPrefix));
        pipeline.run().waitUntilFinish();
    }

    // DoFn that emits each input line with "*" appended.
    static class AddStar extends DoFn<String, String> {
        @ProcessElement
        public void process(@Element String region, OutputReceiver<String> receiver) {
            String starred = region + "*";
            receiver.output(starred);
        }
    }
}



### <font color='green' size="+2">ParDo</font> Example using anonymous class inline.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;

public class ReadRegions2 {
    // Reads the regions CSV and appends "*" to every line, with the DoFn
    // written inline as an anonymous class.
    public static void main(String[] args) {
        final String regionsInputFileName = "datasets/northwind/CSV/regions/regions.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline pipeline = Pipeline.create();

        PCollection<String> lines = pipeline.apply("Read", TextIO.read().from(regionsInputFileName));

        PCollection<String> regions = lines.apply("Parse", ParDo.of(new DoFn<String, String>() {
            @ProcessElement
            public void process(ProcessContext context) {
                // The context gives access to the element and receives the output.
                String line = context.element();
                context.output(line + "*");
            }
        }));

        regions.apply(TextIO.write().to(outputsPrefix));
        pipeline.run().waitUntilFinish();
    }
}



#

***

# 4. Parse into a model class.


## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### Create a model based on <font color='blue' face="Fixedsys, monospace" size="+2">typing.NamedTuple</font> so you can use properties instead of keys for <font color='green' size="+2">dict</font> or position for <font color='green' size="+2">tuple</font> and use the <font color='blue' face="Fixedsys, monospace" size="+2">Filter</font> <font color='green' size="+2">PTransform</font> with <font color='blue' face="Fixedsys, monospace" size="+2">lambda</font>.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory; its fields are accessed by name."""
    territoryid: int
    territoryname: str
    regionid: int

    # def __str__(self):
    #     return f'territoryid = {self.territoryid}, regionid = {self.regionid}, territoryname = {self.territoryname}'

# Register RowCoder as the coder for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """DoFn that parses a 'territoryid,territoryname,regionid' line into a Territory."""

    def process(self, element):
        raw_id, name, raw_region = element.split(',')
        yield Territory(territoryid=int(raw_id), territoryname=name, regionid=int(raw_region))

territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

with beam.Pipeline() as p:
    # NOTE: the name 'regions' is kept from the original cell; it holds territories.
    regions = p | 'Read' >> ReadFromText(territoriesfilename)
    regions = regions | 'Parse' >> beam.ParDo(TerritoryParseClass())
    # Optional filters on the named fields, disabled for now:
    # regions = regions | 'Filter 1' >> beam.Filter(lambda x : x.regionid % 2 == 0)
    # regions = regions | 'Filter 2' >> beam.Filter(lambda x : x.territoryname.startswith('S'))
    regions | 'Print' >> beam.Map(print)
#   regions | 'Write' >> WriteToText('regions.out')


### Use <font color='blue' face="Fixedsys, monospace" size="+2">Filter</font> with a UDF.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory; its fields are accessed by name."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """DoFn that parses a 'territoryid,territoryname,regionid' line into a Territory."""

    def process(self, element):
        raw_id, name, raw_region = element.split(',')
        yield Territory(territoryid=int(raw_id), territoryname=name, regionid=int(raw_region))

def startsWithS(element):
    """Return True when the element's territoryname begins with a capital 'S'."""
    name = element.territoryname
    return name.startswith('S')

territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

with beam.Pipeline() as p:
    territories = p | 'Read' >> ReadFromText(territoriesfilename)
    territories = territories | 'Parse' >> beam.ParDo(TerritoryParseClass())
    # Filter takes any predicate; here it is the named function startsWithS.
    territories = territories | 'Filter' >> beam.Filter(startsWithS)
    territories | 'Print' >> beam.Map(print)
#   territories | 'Write' >> WriteToText('regions.out')


### Use a <font color='green' size="+2">ParDo</font> class to accomplish filtering.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory; its fields are accessed by name."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """DoFn that parses a 'territoryid,territoryname,regionid' line into a Territory."""

    def process(self, element):
        raw_id, name, raw_region = element.split(',')
        yield Territory(territoryid=int(raw_id), territoryname=name, regionid=int(raw_region))

class StartsWithSFilter(beam.DoFn):
    """DoFn used as a filter: re-emits only elements whose territoryname starts with 'S'."""

    def process(self, element):
        if not element.territoryname.startswith('S'):
            # Emitting nothing drops the element.
            return
        yield element
            
territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

with beam.Pipeline() as p:
    # NOTE: the name 'regions' is kept from the original cell; it holds territories.
    regions = p | 'Read' >> ReadFromText(territoriesfilename)
    regions = regions | 'Parse' >> beam.ParDo(TerritoryParseClass())
    # The filtering is done by a ParDo instead of beam.Filter.
    regions = regions | 'Filter' >> beam.ParDo(StartsWithSFilter())
    regions | 'Print' >> beam.Map(print)
#   regions | 'Write' >> WriteToText('regions.out')


### Put the parsing and filtering all into one <font color='green' size="+2">ParDo</font>.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory; its fields are accessed by name."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """DoFn that parses a territory line and emits it only when its name starts with 'S'."""

    def process(self, element):
        raw_id, name, raw_region = element.split(',')
        if not name.startswith('S'):
            # Parsing and filtering in one step: other territories are dropped here.
            return
        yield Territory(territoryid=int(raw_id), territoryname=name, regionid=int(raw_region))

territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

with beam.Pipeline() as p:
    # NOTE: the name 'regions' is kept from the original cell; it holds territories.
    regions = p | 'Read' >> ReadFromText(territoriesfilename)
    regions = regions | 'Parse' >> beam.ParDo(TerritoryParseClass())
    regions | 'Print' >> beam.Map(print)
#   regions | 'Write' >> WriteToText('regions.out')


## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### Parse a CSV into a <font color='green' size="+2">Row</font> object and filter it using a <font color='green' size="+2">ParDo</font>.

In [None]:
%%java --verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.schemas.Schema;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ReadTerritories {

    // Reads the territories CSV, parses every line into a schema-backed Row and
    // prints each Row to standard output.
    public static void main(String[] args) {
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        // Not used by this pipeline: the rows are printed instead of written.
        final String outputsPrefix = "/tmp/outputs";

        Pipeline pipeline = Pipeline.create();

        PCollection<String> lines = pipeline.apply("Read", TextIO.read().from(territoriesInputFileName));

        // The schema is attached to the parsed collection with setRowSchema.
        PCollection<Row> territories = lines
            .apply("Parse", ParDo.of(new ParseTerritories()))
            .setRowSchema(TerritorySchema.territorySchema);

        territories.apply(
            "Print",
            MapElements.into(TypeDescriptors.strings()).via(row -> {
                System.out.println(row);
                return "";
            }));

        pipeline.run().waitUntilFinish();
    }

    // Holds the Row schema shared by the parser and the pipeline.
    static class TerritorySchema {
        static Schema territorySchema = Schema.of(
            Schema.Field.of("territoryid", Schema.FieldType.INT64),
            Schema.Field.of("territoryname", Schema.FieldType.STRING),
            Schema.Field.of("regionid", Schema.FieldType.INT64));
    }

    // Splits a CSV line on commas and emits it as a Row. Lines with too few
    // columns or non-numeric ids are logged and skipped.
    static class ParseTerritories extends DoFn<String, Row> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());

                // Values are added in the order the schema declares its fields.
                Row row = Row.withSchema(TerritorySchema.territorySchema)
                    .addValue(territoryID)
                    .addValue(territoryName)
                    .addValue(regionID)
                    .build();
                c.output(row);
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

}


### Parse a CSV into a class and filter it using a <font color='green' size="+2">ParDo</font>.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ReadTerritories {
    // Reads the territories CSV, parses every line into a Territory object,
    // filters the objects with a second ParDo and writes the survivors as text
    // using the prefix /tmp/outputs.
    public static void main(String[] args) {
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline pipeline = Pipeline.create();

        PCollection<String> lines = pipeline.apply("Read", TextIO.read().from(territoriesInputFileName));
        PCollection<Territory> parsed = lines.apply("Parse", ParDo.of(new ParseTerritories()));
        PCollection<Territory> territories = parsed.apply("Filter", ParDo.of(new FilterTerritories()));

        // writeCustomType needs a function that formats each element as one line of text.
        territories.apply(
            TextIO.<Territory>writeCustomType()
                .to(outputsPrefix)
                .withFormatFunction(new SerializeTerritory()));

        pipeline.run().waitUntilFinish();
    }

    // Model for one territory; AvroCoder is declared as its default coder.
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            return String.format(
                "(territoryID = %d, territoryName = %s, regionID = %d)",
                territoryID, territoryName, regionID);
        }

    }

    // Format function for the text sink: one Territory becomes its toString() form.
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory territory) {
            return territory.toString();
        }
    }

    // Splits a CSV line on commas and emits it as a Territory. Lines with too
    // few columns or non-numeric ids are logged and skipped.
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());

                Territory territory = new Territory(territoryID, territoryName, regionID);
                c.output(territory);
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

    // Keeps only territories with an even territoryID whose name starts with "S".
    static class FilterTerritories extends DoFn<Territory, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(FilterTerritories.class);

        @ProcessElement
        public void process(@Element Territory t, OutputReceiver<Territory> o) {
            if (t.territoryID % 2 != 0 || !t.territoryName.startsWith("S")) {
                // Emitting nothing drops the element.
                return;
            }
            o.output(t);
        }
    }
}


### Parse a CSV into a class and filter it using an anonymous class to create the condition.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads the territories CSV, parses each line into a {@code Territory}, keeps the
 * territories with an even ID whose name starts with "S", and writes them as text.
 * The filter condition is supplied as an anonymous {@code SerializableFunction}.
 */
public class ReadTerritories {
    public static void main(String[] args) {
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        final String outputsPrefix = "/tmp/outputs";

        // Anonymous predicate handed to Filter.by below.
        final SerializableFunction<Territory, Boolean> evenIdStartingWithS =
            new SerializableFunction<Territory, Boolean>() {
                @Override
                public Boolean apply(Territory t) {
                    if (t.territoryID % 2 != 0) {
                        return false;
                    }
                    return t.territoryName.startsWith("S");
                }
            };

        Pipeline p = Pipeline.create();
        PCollection<String> lines = p.apply("Read", TextIO.read().from(territoriesInputFileName));
        PCollection<Territory> parsed = lines.apply("Parse", ParDo.of(new ParseTerritories()));
        PCollection<Territory> territories = parsed.apply("Filter", Filter.by(evenIdStartingWithS));

        territories.apply(
            TextIO.<Territory>writeCustomType()
                .to(outputsPrefix)
                .withFormatFunction(new SerializeTerritory()));
        p.run().waitUntilFinish();
    }

    /** Plain data holder for one territory row; encoded with AvroCoder. */
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        // No-argument constructor kept alongside the value constructor.
        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            final String layout = "(territoryID = %d, territoryName = %s, regionID = %d)";
            return String.format(layout, territoryID, territoryName, regionID);
        }
    }

    /** Format function for TextIO: one output line per territory, using toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            final String rendered = input.toString();
            return rendered;
        }
    }

    /**
     * Parses "territoryID,territoryName,regionID" lines. Lines with too few columns
     * or non-numeric IDs are logged at INFO level and dropped.
     */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            final String line = c.element();
            final String[] fields = line.split(",");
            try {
                final Long id = Long.parseLong(fields[0].trim());
                final String name = fields[1].trim();
                final Long region = Long.parseLong(fields[2].trim());
                c.output(new Territory(id, name, region));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + line + "': " + e.getMessage());
            }
        }
    }

    /** DoFn form of the same condition; not referenced by main() in this cell. */
    static class FilterTerritories extends DoFn<Territory, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(FilterTerritories.class);

        @ProcessElement
        public void process(@Element Territory t, OutputReceiver<Territory> o) {
            final boolean evenId = t.territoryID % 2 == 0;
            if (!evenId || !t.territoryName.startsWith("S")) {
                return;
            }
            o.output(t);
        }
    }
}


### Parse a CSV into a class and filter it in one step.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads the territories CSV and parses and filters it in a single ParDo:
 * ParseTerritories only emits territories whose name starts with "S".
 * The surviving territories are written as text.
 */
public class ReadTerritories {
    public static void main(String[] args) {
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline p = Pipeline.create();
        PCollection<String> lines = p.apply("Read", TextIO.read().from(territoriesInputFileName));
        PCollection<Territory> territories = lines.apply("Parse", ParDo.of(new ParseTerritories()));

        territories.apply(
            TextIO.<Territory>writeCustomType()
                .to(outputsPrefix)
                .withFormatFunction(new SerializeTerritory()));
        p.run().waitUntilFinish();
    }

    /** Plain data holder for one territory row; encoded with AvroCoder. */
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        // No-argument constructor kept alongside the value constructor.
        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            final String layout = "(territoryID = %d, territoryName = %s, regionID = %d)";
            return String.format(layout, territoryID, territoryName, regionID);
        }
    }

    /** Format function for TextIO: one output line per territory, using toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            final String rendered = input.toString();
            return rendered;
        }
    }

    /**
     * Parses "territoryID,territoryName,regionID" lines and filters in the same step:
     * a territory is emitted only when its name starts with "S". Lines with too few
     * columns or non-numeric IDs are logged at INFO level and dropped.
     */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            final String line = c.element();
            final String[] fields = line.split(",");
            try {
                final Long id = Long.parseLong(fields[0].trim());
                final String name = fields[1].trim();
                final Long region = Long.parseLong(fields[2].trim());
                if (!name.startsWith("S")) {
                    return;
                }
                c.output(new Territory(id, name, region));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + line + "': " + e.getMessage());
            }
        }
    }

    /** Stand-alone filter DoFn; not referenced by main() in this cell. */
    static class FilterTerritories extends DoFn<Territory, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(FilterTerritories.class);

        @ProcessElement
        public void process(@Element Territory t, OutputReceiver<Territory> o) {
            final boolean evenId = t.territoryID % 2 == 0;
            if (!evenId || !t.territoryName.startsWith("S")) {
                return;
            }
            o.output(t);
        }
    }
}


### There are special methods like <font color='blue' face="Fixedsys, monospace" size="+2">whereFieldName</font> but they don't do anything differently than just using a regular <font color='green' size="+2">ParDo</font>. This code doesn't actually run, but shows what it would look like.

In [None]:
%%java verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
//import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.schemas.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Illustrates the schema-based Filter transform (org.apache.beam.sdk.schemas.transforms.Filter)
// with whereFieldName, keeping territories whose regionID equals 1.
// NOTE(review): the notebook text states this cell does not actually run. Presumably that is
// because Territory only declares an AvroCoder and no Beam schema - confirm before relying on it.
public class ReadTerritories {
    public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        String outputsPrefix = "/tmp/outputs";

        // Read lines -> parse into Territory -> keep rows whose "regionID" field is 1.
        PCollection<Territory> territories = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse", ParDo.of(new ParseTerritories()))
            .apply("Filter", Filter.<Territory>create().whereFieldName("regionID", (Long regionID) -> regionID == 1))
        ;                   
        
        // Each territory is written as one text line produced by SerializeTerritory.
        territories.apply(TextIO.<Territory>writeCustomType().to(outputsPrefix).withFormatFunction(new SerializeTerritory()));
        p.run().waitUntilFinish();
    }
    
    // Plain data holder for one territory row; encoded with AvroCoder.
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;
        
        // No-argument constructor kept alongside the value constructor.
        Territory() {}
        
        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }
        
        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryID, territoryName, regionID);
        }

    }
    
    // Format function for TextIO: renders a territory with its toString().
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
          return input.toString();
        }
    }

    // Parses "territoryID,territoryName,regionID" lines into Territory objects.
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryID, territoryName, regionID));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                // Too few columns or a non-numeric ID: log the line and emit nothing.
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }
}


#

***

# 5. Create multiple outputs from a single read.

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### Send the same data down multiple paths, such as to group it on two different keys with one read from the source. Also show how to read AVRO.

In [None]:
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory: its id, its name and the id of its region."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder Beam uses for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """Maps a record indexed by field name onto a Territory."""

    def process(self, element):
        # The territory name is taken from the 'territorydescription' field.
        territory = Territory(
            territoryid=int(element['territoryid']),
            territoryname=element['territorydescription'],
            regionid=int(element['regionid']),
        )
        yield territory

territoriesfilename = 'datasets/northwind/AVRO/territories/territories.avro'
with beam.Pipeline() as p:
    territories = (p | 'Read' >> beam.io.ReadFromAvro(territoriesfilename)
                     | 'Parse' >> beam.ParDo(TerritoryParseClass())
                  )

    # Branch 1
    (territories 
         | 'Lowercase' >> beam.Map(lambda x : (x.territoryid, x.territoryname.lower(), x.regionid))
         | 'Write Lower' >> WriteToText('/tmp/territories_lower.out')
    )
    
    # Branch 2
    (territories 
         | 'Uppercase' >> beam.Map(lambda x : (x.territoryid, x.territoryname.upper(), x.regionid))
         | 'Write Upper' >> WriteToText('/tmp/territories_upper.out')
    )

! echo "Lower" && cat /tmp/territories_lower.out* && echo "Upper" && cat /tmp/territories_upper.out*
    

### Branching uses <font color='blue' face="Fixedsys, monospace" size="+2">TaggedOutput</font> in the <font color='green' size="+2">ParDo</font> to split data into two different paths with different data on each. Also show how to read Parquet.

In [None]:
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.pvalue import TaggedOutput
from apache_beam.io import ReadFromText, WriteToText
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory: its id, its name and the id of its region."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder Beam uses for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class OddEvenTerritoryParseClass(beam.DoFn):
    """Parse a record into a Territory and route it to the 'Even' or 'Odd' output.

    The split is made on the parity of ``regionid`` (not ``territoryid``).
    Each element is emitted exactly once, wrapped in a ``TaggedOutput`` whose tag
    matches one of the names passed to ``with_outputs`` by the pipeline.
    """

    def process(self, element):
        # Convert each numeric field once; the record is indexed by field name.
        territory = Territory(
            int(element['territoryid']),
            element['territoryname'],
            int(element['regionid']),
        )
        tag = 'Even' if territory.regionid % 2 == 0 else 'Odd'
        yield TaggedOutput(tag, territory)

territoriesfilename = 'datasets/northwind/PARQUET/territories/territories.parquet'

with beam.Pipeline() as p:
    territories = p | 'Read' >> beam.io.ReadFromParquet(territoriesfilename) 
    # territories would return a tuple of the two tagged outputs
    # unpack the two outputs to two separate variables to process differently
    # x = territories | 'Parse' >> beam.ParDo(OddEvenTerritoryParseClass()).with_outputs("Even", "Odd")
    # print(x, type(x))
    
    evens, odds = territories | 'Parse' >> beam.ParDo(OddEvenTerritoryParseClass()).with_outputs("Even", "Odd")
    
    evens | 'Write Even' >> WriteToText('/tmp/territories_even.out')
    
    odds | 'Write Odd' >> WriteToText('/tmp/territories_odd.out')

! echo "Evens" && cat /tmp/territories_even.out* && echo "Odds" && cat /tmp/territories_odd.out*

## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### Send the same output down two different paths.

In [None]:
! rm /tmp/territories*

In [None]:
%%java nooutput
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTagList;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads and parses the territories CSV once, then sends the same PCollection down
 * two independent branches: one upper-cases the territory name, the other
 * lower-cases it. Each branch writes its own text output under /tmp.
 */
public class ReadTerritories {

    public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";

        // Shared source for both branches.
        PCollection<Territory> territories = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse Territory", ParDo.of(new ParseTerritories()))
        ;

        // Branch 1: upper-case names. A new Territory is built instead of changing the input.
        territories
            .apply("Upper", ParDo.of(new DoFn<Territory, Territory>() {
                @ProcessElement
                public void process(ProcessContext c) {
                    Territory t = c.element();
                    c.output(new Territory(t.territoryID, t.territoryName.toUpperCase(), t.regionID));
                }
            }))
            .apply("Write Upper", TextIO.<Territory>writeCustomType().to("/tmp/territories_upper").withFormatFunction(new SerializeTerritory()));

        // Branch 2: lower-case names.
        territories
            .apply("Lower", ParDo.of(new DoFn<Territory, Territory>() {
                @ProcessElement
                public void process(ProcessContext c) {
                    Territory t = c.element();
                    c.output(new Territory(t.territoryID, t.territoryName.toLowerCase(), t.regionID));
                }
            }))
            .apply("Write Lower", TextIO.<Territory>writeCustomType().to("/tmp/territories_lower").withFormatFunction(new SerializeTerritory()));

        p.run().waitUntilFinish();
    }

    /** Plain data holder for one territory row; encoded with AvroCoder. */
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        // No-argument constructor kept alongside the value constructor.
        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryID, territoryName, regionID);
        }

    }

    /** Format function for TextIO: one output line per territory, using toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
          return input.toString();
        }
    }

    /**
     * Parses "territoryID,territoryName,regionID" lines. Lines with too few columns
     * or non-numeric IDs are logged at INFO level and dropped.
     */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryID, territoryName, regionID));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                // The message names this class so the log points at the right DoFn.
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

}


In [None]:
! echo "Upper" && cat /tmp/territories_upper* && echo "Lower" && cat /tmp/territories_lower*


### Branching uses <font color='blue' face="Fixedsys, monospace" size="+2">TupleTag</font> to split the output into two separate paths.

In [None]:
! rm /tmp/territories*

In [None]:
%%java nooutput
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTagList;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Splits the parsed territories into two PCollections with a multi-output ParDo:
 * territories with an even regionID go to evenTag, all others to oddTag. Each
 * output is written to its own text file under /tmp.
 */
public class ReadTerritories {

    // Tags identifying the two outputs of the split. evenTag is passed as the main
    // output tag and oddTag as the additional output tag in withOutputTags below.
    final static TupleTag<Territory> evenTag = new TupleTag<Territory>() {};
    final static TupleTag<Territory> oddTag = new TupleTag<Territory>() {};

    public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        // Same dataset location as the other territory examples in this notebook.
        String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        String outputsPrefix = "/tmp/outputs";

        PCollectionTuple territories = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("OddEvenSplit", ParDo.of(new ParseTerritoriesOddEvenSplit()).withOutputTags(evenTag, TupleTagList.of(oddTag)))
        ;

        // Pull each tagged PCollection out of the tuple and write it separately.
        PCollection<Territory> evenTerritories = territories.get(evenTag);
        evenTerritories.apply("Write Even", TextIO.<Territory>writeCustomType().to(outputsPrefix + "_even").withFormatFunction(new SerializeTerritory()));

        PCollection<Territory> oddTerritories = territories.get(oddTag);
        oddTerritories.apply("Write Odd", TextIO.<Territory>writeCustomType().to(outputsPrefix + "_odd").withFormatFunction(new SerializeTerritory()));
        p.run().waitUntilFinish();
    }

    /** Plain data holder for one territory row; encoded with AvroCoder. */
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        // No-argument constructor kept alongside the value constructor.
        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryID, territoryName, regionID);
        }

    }

    /** Format function for TextIO: one output line per territory, using toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
          return input.toString();
        }
    }

    /**
     * Parses "territoryID,territoryName,regionID" lines and routes each territory by
     * the parity of its regionID. Lines with too few columns or non-numeric IDs are
     * logged at INFO level and dropped.
     */
    static class ParseTerritoriesOddEvenSplit extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritoriesOddEvenSplit.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());
                Territory territory = new Territory(territoryID, territoryName, regionID);
                // Every parsed element is sent to exactly one of the two tagged outputs.
                c.output(regionID % 2 == 0 ? evenTag : oddTag, territory);
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritoriesOddEvenSplit: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

}


In [None]:
! echo "Odd" && cat /tmp/outputs_odd* && echo "Even" && cat /tmp/outputs_even*


#

***

# 6. Group and Join

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### <font color='blue' face="Fixedsys, monospace" size="+2">WithKeys</font> will reshape your data first, then <font color='blue' face="Fixedsys, monospace" size="+2">GroupByKey</font> will cluster the elements as a list under each unique key. The data must be in a <font color='green' size="+2">KV</font> tuple pair first. Also note how to read a JSON file.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText
import json
import typing

class Territory(typing.NamedTuple):
    """Typed record for one territory: its id, its name and the id of its region."""
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder as the coder Beam uses for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    """Maps a record indexed by field name onto a Territory."""

    def process(self, element):
        # Ids are converted to int; the name is passed through unchanged.
        yield Territory(
            territoryid=int(element['territoryid']),
            territoryname=element['territoryname'],
            regionid=int(element['regionid']),
        )

territoriesfilename = 'datasets/northwind/JSON/territories/territories.json'

# As written, only the read stage is active, so the pipeline prints each raw text
# line of the JSON file. The commented-out stages below show the remaining steps
# (decode JSON, parse to Territory, key by regionid, group by key); uncomment them
# from the top down to build the pipeline up one stage at a time.
with beam.Pipeline() as p:
    territories = (
                  p | 'Read Territories' >> ReadFromText(territoriesfilename)
                    # | 'From json' >> beam.Map(json.loads)
                    # | 'Parse Territories' >> beam.ParDo(TerritoryParseClass())
                    # | 'Territories With Keys' >> beam.util.WithKeys(lambda x : x.regionid)
                    # | 'Group Territories' >> beam.GroupByKey() 
                  )
    territories | 'Print Territories' >> beam.Map(print)


### <font color='blue' face="Fixedsys, monospace" size="+2">Combine</font> is equivalent to a SQL <font color='blue' face="Fixedsys, monospace" size="+2">GROUP BY</font> query
### <font color='blue' face="Fixedsys, monospace" size="+2">SELECT key, sum(value) as total FROM source GROUP BY key</font>

In [None]:
import apache_beam as beam

# Equivalent of: SELECT key, sum(value) AS total FROM source GROUP BY key
with beam.Pipeline() as p:
    data = (
        p | 'Create' >> beam.Create([('a', 10), ('a', 20), ('b', 30), ('b', 40), ('c', 50), ('a', 60)])
          # CombinePerKey applies sum to the values of each key, yielding (key, total).
          | 'Combine' >> beam.CombinePerKey(sum)
          # For comparison, swap in GroupByKey to see the values collected per key
          # without being combined:
          # | 'Group' >> beam.GroupByKey()
          | 'Print' >> beam.Map(print)
    )


### Custom <font color='blue' face="Fixedsys, monospace" size="+2">CombineFn</font>

In [None]:
import apache_beam as beam

class CustomCombine(beam.CombineFn):
    """
    Combine ``(key, (a, b))`` inputs into per-key statistics.

    For every key this tracks the max of the first element, the sum of the second
    element and a count of the elements seen. The final output additionally reports
    the average of the second element.

    The accumulator is a dict mapping ``key -> (max, sum, count)``. A key is seeded
    from its first input rather than from zeros, so the max is also correct when
    every first element is negative.
    """

    def create_accumulator(self):
        # An empty accumulator has seen no keys yet.
        return {}

    def add_input(self, accumulator, input):
        """Fold one ``(key, (a, b))`` input into the accumulator and return it."""
        k, v = input
        if k in accumulator:
            m, s, c = accumulator[k]
            accumulator[k] = (max(m, v[0]), s + v[1], c + 1)
        else:
            # First time this key is seen: its own values are the starting point.
            accumulator[k] = (v[0], v[1], 1)
        return accumulator

    def merge_accumulators(self, accumulators):
        """Merge several partial accumulators into a single one."""
        merged = {}
        for accum in accumulators:
            for k, v in accum.items():
                if k in merged:
                    m, s, c = merged[k]
                    merged[k] = (max(m, v[0]), s + v[1], c + v[2])
                else:
                    merged[k] = (v[0], v[1], v[2])
        return merged

    def extract_output(self, accumulator):
        """Return ``{key: (max, sum, count, average)}`` for every key seen."""
        # count is at least 1 for any key present, so the division is safe.
        return {k: (v[0], v[1], v[2], v[1] / v[2]) for k, v in accumulator.items()}

with beam.Pipeline() as p:
    # (key, (first, second)) pairs fed to the custom combiner.
    samples = [
        ('a', (1, 10)), ('a', (2, 20)),
        ('b', (3, 30)), ('c', (5, 50)),
        ('b', (4, 40)), ('a', (6, 60)),
    ]
    created = p | 'Create' >> beam.Create(samples)
    # CombineGlobally reduces the whole collection through CustomCombine.
    combined = created | 'Combine' >> beam.CombineGlobally(CustomCombine())
    data = combined | 'Print' >> beam.Map(print)


### Create a nested repeating output
### First create a dataset. Here is Python code for the equivalent bq command of <font color='blue' face="Fixedsys, monospace" size="+2">bq mk dataflow</font>.
<br>or the SQL command<br>
<font color='blue' face="Fixedsys, monospace" size="+2">create table dataflow.region_territory
(regionid int
, regionname string
, territories array<struct<territoryid int, territoryname string>>
)
</font>
<br>Make sure to use the proper project ID.

In [None]:
# same as doing bq mk dataflow

from google.cloud import bigquery
from google.cloud.exceptions import NotFound

# Construct a BigQuery client object.
client = bigquery.Client()

# Set this to the project that should own the dataset and table.
PROJECT_ID = 'qwiklabs-gcp-04-b1b7cded1c4b'
dataset_id = f"{PROJECT_ID}.dataflow"

# Create the dataset only when the lookup reports that it does not exist. Any other
# failure (credentials, permissions, network) is left to surface as an error
# instead of being mistaken for "not found".
try:
    client.get_dataset(dataset_id)  # Make an API request.
except NotFound:
    print("Dataset {} is not found".format(dataset_id))
    dataset = bigquery.Dataset(dataset_id)
    dataset.location = "US"
    dataset = client.create_dataset(dataset, timeout=30)  # Make an API request.
    # Report the project the dataset was created in, not the client's default one.
    print("Created dataset {}.{}".format(dataset.project, dataset.dataset_id))
else:
    print("Dataset {} already exists".format(dataset_id))

# One row per region, with its territories nested as a repeated record.
# NOTE(review): territoryid is declared STRING here while the prose and the SQL
# sketch below describe it as numeric - confirm which type is intended.
schema = [
    bigquery.SchemaField("regionid", "INTEGER", mode="REQUIRED"),
    bigquery.SchemaField("regionname", "STRING", mode="REQUIRED"),
    bigquery.SchemaField(
        "territories",
        "RECORD",
        mode="REPEATED",
        fields=[
            bigquery.SchemaField("territoryid", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("territoryname", "STRING", mode="REQUIRED"),
        ],
    ),
]

# create table dataflow.region_territory
# (regionid NUMERIC
# ,regionname STRING
# ,territories ARRAY<STRUCT<territoryid NUMERIC, territoryname STRING>>)

table_id = f"{PROJECT_ID}.dataflow.region_territory"

# Same pattern for the table: create it only when it is reported missing.
try:
    table = client.get_table(table_id)  # Make an API request.
except NotFound:
    table = bigquery.Table(table_id, schema=schema)
    table = client.create_table(table)  # Make an API request.
    print("Table {} created.".format(table_id))
else:
    print("Table {} already exists.".format(table_id))
    print(table)



### The code here is tricky: 
### First parse the two tables into <font color='green' size="+2">tuples</font>, <font color='black' face="Fixedsys, monospace" size="+1">(regionid, regionname)</font> & <font color='black' face="Fixedsys, monospace" size="+1">(regionid, {'territoryid':territoryid, 'territoryname':territoryname})</font>
### <font color='blue' face="Fixedsys, monospace" size="+2">CoGroupByKey</font> yields a shape like <font color='black' face="Fixedsys, monospace" size="+1">(regionid, {'regions':['regionname'], 'territories':[{}]})</font> so we need to reshape it to <font color='green' size="+2">dicts</font> to write it to BQ


In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText

class RegionParseTuple(beam.DoFn):
    """Splits a 'regionid,regionname' line into a (regionid, regionname) pair."""

    def process(self, element):
        fields = element.split(',')
        # Exactly two comma-separated fields are expected; the id becomes an int.
        regionid, regionname = fields
        pair = (int(regionid), regionname)
        yield pair

class TerritoryParseTuple(beam.DoFn):
    """Splits a territory line into a KV pair keyed by regionid.

    The value is a dict: {'territoryid': <int>, 'territoryname': <str>}.
    """

    def process(self, element):
        territoryid, territoryname, regionid = element.split(',')
        key = int(regionid)
        value = {'territoryid': int(territoryid), 'territoryname': territoryname}
        yield (key, value)

class SortTerritories(beam.DoFn):
    """Order each region's nested territories by territoryid.

    Input and output look like:
    {'regionid': 1, 'regionname': 'Eastern',
     'territories': [{'territoryid': 1730, 'territoryname': 'Bedford'}, ...]}

    A new dict is yielded instead of changing the incoming element, because Beam
    requires that a DoFn does not modify its input element.
    """

    def process(self, element):
        ordered = sorted(element['territories'], key=lambda t: t['territoryid'])
        # Shallow copy with only the 'territories' entry replaced; key order is kept.
        yield {**element, 'territories': ordered}

regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'
territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

# Set this to the project that owns the dataflow.region_territory table.
PROJECT_ID = 'qwiklabs-gcp-03-66d8bf72e256'

with beam.Pipeline() as p:
    # (regionid, regionname) pairs.
    regions = (
              p | 'Read Regions' >> ReadFromText(regionsfilename)
                | 'Parse Regions' >> beam.ParDo(RegionParseTuple())
              )
    regions | 'Print Regions' >> beam.Map(print)

    # (regionid, {'territoryid': ..., 'territoryname': ...}) pairs, read from the
    # dataset path defined above rather than a file in the working directory.
    territories = (
                  p | 'Read Territories' >> ReadFromText(territoriesfilename)
                    | 'Parse Territories' >> beam.ParDo(TerritoryParseTuple())
                  )
    territories | 'Print Territories' >> beam.Map(print)

    # CoGroupByKey yields (regionid, {'regions': [...], 'territories': [...]});
    # reshape that into one dict per region with its territories nested inside.
    nested = (
        {'regions':regions, 'territories':territories}
              | 'Nest territories into regions' >> beam.CoGroupByKey()
              | 'Reshape to dict' >> beam.Map(lambda x : {'regionid': x[0], 'regionname': x[1]['regions'][0],
                                                         'territories': x[1]['territories']})
              | 'Sort by territoryid' >> beam.ParDo(SortTerritories())
    )
    nested | 'Print' >> beam.Map(print)
    nested | 'Write nested region_territory to BQ' >> beam.io.WriteToBigQuery('region_territory', dataset = 'dataflow'
                                                                             , project = PROJECT_ID
                                                                             , method = 'STREAMING_INSERTS'
                                                                             )
             
#help(beam.io.WriteToBigQuery)    
#(1, {'regions': ['Eastern'], 'territories': [{'territoryid': 1730, 'territoryname': 'Bedford'}, {'territoryid': 1581, 'territoryname': 'Westboro'}, {'territoryid': 1833, 'territoryname': 'Georgetow'}, {'territoryid': 2116, 'territoryname': 'Bosto
#{'regionid': 1, 'regionname':'Eastern', 'territories' : [{'territoryid':1, 'territoryname':'name1'}, {}, {}]}

### Query the table to show it was populated.

In [None]:
from google.cloud import bigquery

# Construct a BigQuery client object.
client = bigquery.Client()

# PROJECT_ID is not defined in this cell; it must already exist in the notebook
# session (it is assigned in earlier cells).
table_id = f"{PROJECT_ID}.dataflow.region_territory"

# Select every row of the nested table.
query_job = client.query(f"""SELECT * FROM {table_id}""")

results = query_job.result()  # Waits for job to complete.
# Materialize the rows into a list so the notebook renders them.
display(list(results))

### Helper functions to make a generic transform to nest children

In [None]:
import apache_beam as beam

class NestJoin(beam.PTransform):
    '''
    Nest the rows of a child PCollection inside the matching row of a parent PCollection.

    Apply it to a dictionary (the value to the left of the |) that maps a name to each of the
    two PCollections to combine. The elements of both PCollections must be dictionaries.

    Steps performed:
      1. Each dict is reshaped into a (key, dict-without-key) tuple.
      2. The two keyed collections are combined with CoGroupByKey.
      3. Every group is reshaped into a single dict holding the key, the fields of the
         first parent row, and the list of child rows stored under child_pipeline_name.
      4. The sort callable is applied to each resulting dict.

    Keys that have child rows but no parent row produce no output.

    Args:
        parent_pipeline_name: name of the parent PCollection in the input dictionary.
        parent_key: field of the parent dicts to join on; also the name the key gets in the output.
        child_pipeline_name: name of the child PCollection in the input dictionary; also the
            output field that receives the nested child rows.
        child_key: field of the child dicts to join on.
        sort: callable applied to every nested dict through beam.Map; defaults to the identity.
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key, sort = lambda x : x):
        # Let the PTransform base class run its own initialization.
        super().__init__()
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key
        self.sort = sort

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # pipeline object should be a dictionary; the key field moves out of the dict
            remainder = {name: value for name, value in item.items() if name != key}
            return (item[key], remainder)

        def reshapeCoGroupToDict(item):
            key, grouped = item
            parents = grouped[self.parent_pipeline_name]
            if not parents:
                # Child rows whose key matches no parent have nothing to be nested into.
                return
            ret = {self.parent_key : key}
            ret.update(parents[0])
            ret[self.child_pipeline_name] = grouped[self.child_pipeline_name]
            yield ret

        return (
                {
                self.parent_pipeline_name : pcols[self.parent_pipeline_name] | f'Convert {self.parent_pipeline_name} to KV'
                    >> beam.Map(reshapeToKV, self.parent_key)
                ,self.child_pipeline_name : pcols[self.child_pipeline_name] | f'Convert {self.child_pipeline_name} to KV'
                    >> beam.Map(reshapeToKV, self.child_key)
                } | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}'
                    >> beam.CoGroupByKey()
                  # FlatMap because the reshape step yields zero or one dict per key.
                  | 'Reshape to dictionary'
                    >> beam.FlatMap(reshapeCoGroupToDict)
                  | 'Sort the nested data' >> beam.Map(self.sort)
        )

class RegionParseDict(beam.DoFn):
    '''Parse one 'regionid,regionname' CSV line into a region dictionary.'''
    def process(self, element):
        fields = element.split(',')
        # Unpacking raises ValueError unless the line holds exactly two fields.
        raw_id, raw_name = fields
        yield {'regionid': int(raw_id), 'regionname': raw_name.title()}
      
class TerritoryParseDict(beam.DoFn):
    '''Parse one 'territoryid,territoryname,regionid' CSV line into a territory dictionary.'''
    def process(self, element):
        fields = element.split(',')
        # Unpacking raises ValueError unless the line holds exactly three fields.
        raw_id, name, raw_region = fields
        yield {'territoryid': int(raw_id), 'territoryname': name, 'regionid': int(raw_region)}
    
# Relative paths of the Northwind regions and territories CSV files.
regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'
territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

def sort_territories(element):
    '''
    Return a copy of a nested dict whose 'territories' list is ordered by 'territoryid'.

    The input element is left untouched: Beam requires that user functions do not
    modify the elements they receive, so a new dict is built instead of sorting in place.

    Args:
        element: dict with a 'territories' entry holding an iterable of dicts, each of
            which has a 'territoryid' field.

    Returns:
        A new dict with the same entries as element, except that 'territories' is a new
        list sorted by ascending 'territoryid'.

    Raises:
        KeyError: if element has no 'territories' entry or a territory has no 'territoryid'.
    '''
    ordered = sorted(element['territories'], key = lambda territory : territory['territoryid'])
    return {**element, 'territories': ordered}

# Read both CSV files, parse each line into a dict, then nest the territories inside their region.
# ReadFromText is reached through beam.io so this cell only depends on its own import.
with beam.Pipeline() as p:
    regions = (
              p | 'Read Regions' >> beam.io.ReadFromText(regionsfilename)
                | 'Parse Regions' >> beam.ParDo(RegionParseDict())
                #| 'Print Regions' >> beam.Map(print)
              )

    # Read from the path declared in territoriesfilename, mirroring how regions are read.
    territories = (
                  p | 'Read Territories' >> beam.io.ReadFromText(territoriesfilename)
                    | 'Parse Territories' >> beam.ParDo(TerritoryParseDict())
                    #| 'Print Territories' >> beam.Map(print)
                  )

    # Both collections are joined on 'regionid'; sort_territories orders each nested list.
    nestjoin = {'regions':regions, 'territories':territories} | NestJoin('regions', 'regionid', 'territories', 'regionid', sort = sort_territories)
    nestjoin | 'Print Nest Join' >> beam.Map(print)
#     nestjoin | 'Write nested region_territory to BQ' >> beam.io.WriteToBigQuery('region_territory', dataset = 'dataflow'
#                                                                              , project = PROJECT_ID
#                                                                              , method = 'STREAMING_INSERTS'
#                                                                              )



### Simulate an Outer Join with <font color='blue' face="Fixedsys, monospace" size="+2">CoGroup</font>.

In [None]:
import apache_beam as beam

class LeftJoin(beam.PTransform):
    '''
    Left outer join of a parent PCollection with a child PCollection, producing flat dicts.

    Apply it to a dictionary (the value to the left of the |) that maps a name to each of the
    two PCollections to join. The elements of both PCollections must be dictionaries.

    Steps performed:
      1. Each dict is reshaped into a (key, dict-without-key) tuple.
      2. The two keyed collections are combined with CoGroupByKey.
      3. Every group is flattened into one dict per (parent row, child row) pair, made of
         the key, the parent fields and the child fields. Child fields win on a name clash.

    A parent row without any child row is still emitted once, carrying only the key and
    its own fields. Child rows whose key matches no parent row are dropped.

    Args:
        parent_pipeline_name: name of the parent (left) PCollection in the input dictionary.
        parent_key: field of the parent dicts to join on; also the name the key gets in the output.
        child_pipeline_name: name of the child (right) PCollection in the input dictionary.
        child_key: field of the child dicts to join on.
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key):
        # Let the PTransform base class run its own initialization.
        super().__init__()
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # pipeline object should be a dictionary; the key field moves out of the dict
            remainder = {name: value for name, value in item.items() if name != key}
            return (item[key], remainder)

        def reshapeCoGroupToFlatDict(item):
            key, grouped = item
            # Materialize once: the children are walked again for every parent row.
            children = list(grouped[self.child_pipeline_name])
            ret = []
            # No parent rows means no iterations, so unmatched children yield nothing.
            for parent_row in grouped[self.parent_pipeline_name]:
                parent = {self.parent_key : key}
                parent.update(parent_row)
                if not children:
                    # Left outer join: keep the parent even though nothing matched it.
                    ret.append(parent)
                    continue
                for child_row in children:
                    row = parent.copy()
                    row.update(child_row)
                    ret.append(row)
            return ret

        return (
                {
                self.parent_pipeline_name : pcols[self.parent_pipeline_name] | f'Convert {self.parent_pipeline_name} to KV'
                    >> beam.Map(reshapeToKV, self.parent_key)
                ,self.child_pipeline_name : pcols[self.child_pipeline_name] | f'Convert {self.child_pipeline_name} to KV'
                    >> beam.Map(reshapeToKV, self.child_key)
                } | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}'
                    >> beam.CoGroupByKey()
                  | 'Reshape to dictionary'
                    >> beam.FlatMap(reshapeCoGroupToFlatDict)
        )

class RegionParseDict(beam.DoFn):
    '''Parse one 'regionid,regionname' CSV line into a region dictionary.'''
    def process(self, element):
        fields = element.split(',')
        # Unpacking raises ValueError unless the line holds exactly two fields.
        raw_id, raw_name = fields
        yield {'regionid': int(raw_id), 'regionname': raw_name.title()}

class TerritoryParseDict(beam.DoFn):
    '''Parse one 'territoryid,territoryname,regionid' CSV line into a territory dictionary.'''
    def process(self, element):
        fields = element.split(',')
        # Unpacking raises ValueError unless the line holds exactly three fields.
        raw_id, name, raw_region = fields
        yield {'territoryid': int(raw_id), 'territoryname': name, 'regionid': int(raw_region)}
    
# Relative paths of the Northwind regions and territories CSV files.
regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'
territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

# Read both CSV files, parse each line into a dict, then join every region with its territories.
# ReadFromText is reached through beam.io so this cell only depends on its own import.
with beam.Pipeline() as p:
    regions = (
              p | 'Read Regions' >> beam.io.ReadFromText(regionsfilename)
                | 'Parse Regions' >> beam.ParDo(RegionParseDict())
              )
#    regions  | 'Print Regions' >> beam.Map(print)

    # Read from the path declared in territoriesfilename, mirroring how regions are read.
    territories = (
                  p | 'Read Territories' >> beam.io.ReadFromText(territoriesfilename)
                    | 'Parse Territories' >> beam.ParDo(TerritoryParseDict())
                  )
#    territories | 'Print Territories' >> beam.Map(print)

    # Both collections are joined on 'regionid'; the output is one flat dict per joined row.
    nestjoin = {'regions':regions, 'territories':territories} | LeftJoin('regions', 'regionid', 'territories', 'regionid')
    nestjoin | 'Print Nest Join' >> beam.Map(print)



## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### For Java you don't need to group into KV shape first, instead you could use the <font color='blue' face="Fixedsys, monospace" size="+2">Group</font> and <font color='blue' face="Fixedsys, monospace" size="+2">Select</font> methods.

In [None]:
%%java 
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.schemas.transforms.Select;
import org.apache.beam.sdk.transforms.*;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.schemas.transforms.Convert;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GroupTerritories {
    /**
     * Reads the territories CSV, counts the territories of every region with the schema
     * Group transform, and writes one formatted line per region.
     */
    public static void main(String[] args) {
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline p = Pipeline.create();

        // Stage 1: text lines -> Territory objects.
        PCollection<Territory> parsed = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse", ParDo.of(new ParseTerritories()));

        // Stage 2: group on regionID, count territoryID into "cnt", then pull both
        // fields out of the nested key/value rows produced by Group.
        PCollection<Row> countsByRegion = parsed
            .apply("GroupBy regionID", Group.<Territory>byFieldNames("regionID")
                                            .aggregateField("territoryID", Count.combineFn(), "cnt"))
            .apply("Select", Select.fieldNames("key.regionID", "value.cnt"));

        // Stage 3: rows -> Result objects, matched by field name.
        PCollection<Result> territories = countsByRegion.apply(Convert.fromRows(Result.class));

        territories.apply(
            TextIO.<Result>writeCustomType()
                .to(outputsPrefix)
                .withFormatFunction(new SerializeResult()));

        p.run().waitUntilFinish();
    }

    /** One parsed line of the territories file. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;

        /** No-argument constructor; leaves every field null. */
        Territory() {}

        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }

        @Override
        public String toString() {
            return String.format(
                "(territoryID = %d, territoryName = %s, regionID = %d)",
                territoryID, territoryName, regionID);
        }
    }

    /** Format function that renders a Territory through its toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            return input.toString();
        }
    }

    /** Splits a comma-separated line into a Territory; lines that fail to parse are logged and dropped. */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            final String line = c.element();
            final String[] columns = line.split(",");
            try {
                long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                long regionID = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryID, territoryName, regionID));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + line + "': " + e.getMessage());
            }
        }
    }

    /** Territory count for one region. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Result {
        Long regionID;
        Long cnt;

        /** No-argument constructor; leaves every field null. */
        Result() {}

        Result(Long regionID, Long cnt) {
            this.regionID = regionID;
            this.cnt = cnt;
        }

        @Override
        public String toString() {
            return String.format("(regionid = %d, cnt = %d)", regionID, cnt);
        }
    }

    /** Format function that renders a Result through its toString(). */
    static class SerializeResult implements SerializableFunction<Result, String> {
        @Override
        public String apply(Result input) {
            return input.toString();
        }
    }
}


### But for the <font color='blue' face="Fixedsys, monospace" size="+2">JOIN</font> extension function you still need to shape the data into a KV pair and then unnest it when done.

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.schemas.transforms.Select;
import org.apache.beam.sdk.transforms.*;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.extensions.joinlibrary.Join;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.transforms.WithKeys;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JoinTerritories {
    /**
     * Reads regions and territories from CSV, inner-joins them on the region id with the
     * join library, flattens every match into a Result and writes one line per match.
     */
    public static void main(String[] args) {
        final String regionsInputFileName = "datasets/northwind/CSV/regions/regions.csv";
        final String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        final String outputsPrefix = "/tmp/outputs";

        Pipeline p = Pipeline.create();

        // The join works on KV pairs, so each region is keyed by its own regionid.
        PCollection<KV<Long, Region>> regions = p
            .apply("Read Regions", TextIO.read().from(regionsInputFileName))
            .apply("Parse Regions", ParDo.of(new CSVToRegion()))
            .apply("Regions KV", WithKeys.of(new SerializableFunction<Region, Long>() {
                @Override
                public Long apply(Region region) {
                    return region.regionid;
                }
            }));

        // Territories are keyed by the regionid they reference.
        PCollection<KV<Long, Territory>> territories = p
            .apply("Read Territories", TextIO.read().from(territoriesInputFileName))
            .apply("Parse Territories", ParDo.of(new ParseTerritories()))
            .apply("Territories KV", WithKeys.of(new SerializableFunction<Territory, Long>() {
                @Override
                public Long apply(Territory territory) {
                    return territory.regionid;
                }
            }));

        PCollection<KV<Long, KV<Region, Territory>>> result = Join.innerJoin(regions, territories);

        // Each joined element looks like KV{key, KV{region, territory}}; unnest it into a flat Result.
        PCollection<Result> result2 = result.apply("Unnest KV", ParDo.of(
            new DoFn<KV<Long, KV<Region, Territory>>, Result>() {
                @ProcessElement
                public void process(ProcessContext c) {
                    KV<Long, KV<Region, Territory>> match = c.element();
                    Region region = match.getValue().getKey();
                    Territory territory = match.getValue().getValue();
                    c.output(new Result(
                        match.getKey(),
                        region.regionname,
                        territory.territoryid,
                        territory.territoryname));
                }
            }));

        result2.apply(
            TextIO.<Result>writeCustomType()
                .to(outputsPrefix)
                .withFormatFunction(new SerializeResult()));

        p.run().waitUntilFinish();
    }

    /** One parsed line of the regions file. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Region {
        Long regionid;
        String regionname;

        /** No-argument constructor; leaves every field null. */
        Region() {}

        Region(Long regionid, String regionname) {
            this.regionid = regionid;
            this.regionname = regionname;
        }

        @Override
        public String toString() {
            return String.format("(regionid = %d, regionname = %s)", regionid, regionname);
        }
    }

    /** Format function that renders a Region through its toString(). */
    static class SerializeRegion implements SerializableFunction<Region, String> {
        @Override
        public String apply(Region input) {
            return input.toString();
        }
    }

    /** Splits a comma-separated line into a Region; lines that fail to parse are logged and dropped. */
    static class CSVToRegion extends DoFn<String, Region> {
        private static final Logger LOG = LoggerFactory.getLogger(CSVToRegion.class);

        @ProcessElement
        public void process(ProcessContext c) {
            final String line = c.element();
            final String[] columns = line.split(",");
            try {
                Long regionid = Long.parseLong(columns[0].trim());
                String regionname = columns[1].trim();
                c.output(new Region(regionid, regionname));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("CSVToRegion: parse error on '" + line + "': " + e.getMessage());
            }
        }
    }

    /** One parsed line of the territories file. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Territory {
        Long territoryid;
        String territoryname;
        Long regionid;

        /** No-argument constructor; leaves every field null. */
        Territory() {}

        Territory(long territoryid, String territoryname, long regionid) {
            this.territoryid = territoryid;
            this.territoryname = territoryname;
            this.regionid = regionid;
        }

        @Override
        public String toString() {
            return String.format(
                "(territoryid = %d, territoryname = %s, regionid = %d)",
                territoryid, territoryname, regionid);
        }
    }

    /** Format function that renders a Territory through its toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            return input.toString();
        }
    }

    /** Splits a comma-separated line into a Territory; lines that fail to parse are logged and dropped. */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            final String line = c.element();
            final String[] columns = line.split(",");
            try {
                long territoryid = Long.parseLong(columns[0].trim());
                String territoryname = columns[1].trim();
                long regionid = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryid, territoryname, regionid));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + line + "': " + e.getMessage());
            }
        }
    }

    /** Flat region/territory pair produced by the join. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Result {
        Long regionid;
        String regionname;
        Long territoryid;
        String territoryname;

        /** No-argument constructor; leaves every field null. */
        Result() {}

        Result(Long regionid, String regionname, Long territoryid, String territoryname) {
            this.regionid = regionid;
            this.regionname = regionname;
            this.territoryid = territoryid;
            this.territoryname = territoryname;
        }

        @Override
        public String toString() {
            return String.format(
                "(regionid = %d, regionname = %s, territoryid = %d, territoryname = %s)",
                regionid, regionname, territoryid, territoryname);
        }
    }

    /** Format function that renders a Result through its toString(). */
    static class SerializeResult implements SerializableFunction<Result, String> {
        @Override
        public String apply(Result input) {
            return input.toString();
        }
    }
}
                   
                   
// KV{3, KV{(regionid = 3, regionname = Northern), (territoryid = 3801, territoryname = Portsmouth, regionid = 3)}}

#

***

# 7. BeamSQL

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### SQL Transform uses <font color='green' size="+2">PCOLLECTION</font> as the name of a single source passed into it.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform

import typing
import json

class Territory(typing.NamedTuple):
    '''Territory row with the fields territoryid, territoryname and regionid.'''
    territoryid: int
    territoryname: str
    regionid: int

    def __str__(self):
        # A NamedTuple unpacks in field order, so *self supplies the three values positionally.
        return 'territoryid = {} territoryname = {} regionid = {}'.format(*self)
# Register RowCoder in Beam's coder registry as the coder for Territory elements.
coders.registry.register_coder(Territory, coders.RowCoder)
        
@beam.typehints.with_output_types(Territory)
class TerritoryParseClass(beam.DoFn):
    '''Parse one 'territoryid,territoryname,regionid' CSV line into a Territory tuple.'''
    def process(self, element):
        # Unpacking raises ValueError unless the line holds exactly three fields.
        raw_id, raw_name, raw_region = element.split(',')
        yield Territory(territoryid = int(raw_id), territoryname = raw_name.title(), regionid = int(raw_region))
    
class RegionCount(typing.NamedTuple):
    '''Row pairing a regionid with a count.'''
    regionid: int
    count: int

    def __str__(self):
        # A NamedTuple unpacks in field order, so *self supplies both values positionally.
        return 'regionid = {} count = {}'.format(*self)
# Register RowCoder in Beam's coder registry as the coder for RegionCount elements.
coders.registry.register_coder(RegionCount, coders.RowCoder)
        
        
# Count the territories of every region with a SQL query over a single PCollection.
territoriesfilename = 'territories.csv'
with beam.Pipeline() as p:
    territories = (
                  # Read from the declared file name instead of repeating the literal.
                  p | 'Read Territories' >> ReadFromText(territoriesfilename)
#                    | 'Parse Territories' >> beam.ParDo(TerritoryParseClass()).with_output_types(Territory) # if we didn't have with_output_types decorator
                    | 'Parse Territories' >> beam.ParDo(TerritoryParseClass())
                    # Inside the query the single input is addressed as PCOLLECTION.
                    | 'SQL Territories' >> SqlTransform("""SELECT regionid, count(*) as `count` FROM PCOLLECTION GROUP BY regionid""")
#                    | 'Map Territories for Print' >> beam.Map(lambda x : f'regionid = {x.regionid}  count = {x.count}')
                    | 'Convert to RegionCount Class' >> beam.Map(lambda x : RegionCount(x.regionid, x.count))
                    )
    territories | 'Print SQL' >> beam.Map(print)
    


### For a SQL query that has more than one source, bundle the sources together in a dictionary; the keys become the table names inside the SQL string.

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform

import typing
import json

# Join two in-memory collections with SQL; the dictionary keys name the tables in the query.
with beam.Pipeline() as p:
    parent_rows = [(1, 'Vowel'), (2, 'Consonant'), (4, 'Unknown')]
    child_rows = [('Alpha', 1), ('Beta', 2), ('Gamma', 2), ('Delta', 2), ('Epsilon', 1), ('Pi', 3)]
    join_sql = """
             SELECT p.parent_id, p.parent_name, c.child_name 
             FROM parent as p 
             INNER JOIN child as c ON p.parent_id = c.parent_id
             """

    # Tuples are turned into beam.Row objects with named fields.
    parent = (
        p
        | 'Create Parent' >> beam.Create(parent_rows)
        | 'Map Parent' >> beam.Map(lambda pair : beam.Row(parent_id = pair[0], parent_name = pair[1]))
    )

    child = (
        p
        | 'Create Child' >> beam.Create(child_rows)
        | 'Map Child' >> beam.Map(lambda pair : beam.Row(child_name = pair[0], parent_id = pair[1]))
    )

    sources = {'parent': parent, 'child' : child}
    result = (
        sources
        | SqlTransform(join_sql)
        | 'Format Output' >> beam.Map(lambda row : f'{row.parent_id}, {row.parent_name}, {row.child_name}')
    )
    result | 'Print Join' >> beam.Map(print)


### Real example

In [None]:
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.io import ReadFromText, WriteToText
import typing

class Region(typing.NamedTuple):
    '''Region row with an integer regionid and a string regionname.'''
    regionid: int
    regionname: str
# Register RowCoder in Beam's coder registry as the coder for Region elements.
beam.coders.registry.register_coder(Region, beam.coders.RowCoder)
        
class RegionParseClass(beam.DoFn):
    '''Build a Region from a record exposing 'regionid' and 'regiondescription' entries.'''
    def process(self, element):
        region_id = int(element['regionid'])
        region_name = element['regiondescription']
        yield Region(regionid = region_id, regionname = region_name)

class Territory(typing.NamedTuple):
    '''Territory row with an integer territoryid, a string territoryname and an integer regionid.'''
    territoryid: int
    territoryname: str
    regionid: int
# Register RowCoder in Beam's coder registry as the coder for Territory elements.
beam.coders.registry.register_coder(Territory, beam.coders.RowCoder)
        
class TerritoryParseClass(beam.DoFn):
    '''Build a Territory from a record exposing 'territoryid', 'territorydescription' and 'regionid' entries.'''
    def process(self, element):
        territory_id = int(element['territoryid'])
        territory_name = element['territorydescription']
        region_id = int(element['regionid'])
        yield Territory(territoryid = territory_id, territoryname = territory_name, regionid = region_id)

class Result(typing.NamedTuple):
    '''Query result row with an integer regionid, a string regionname and an integer cnt.'''
    regionid: int
    regionname: str
    cnt: int
# Register RowCoder in Beam's coder registry as the coder for Result elements.
beam.coders.registry.register_coder(Result, beam.coders.RowCoder)
               
# SqlTransform is not among this cell's imports; import it here so the cell also runs
# on a fresh kernel where the earlier SQL cells have not been executed.
from apache_beam.transforms.sql import SqlTransform

# Avro inputs: a glob for the regions files and a single territories file.
regionsfilename = 'datasets/northwind/AVRO/regions/*.avro'
territoriesfilename = 'datasets/northwind/AVRO/territories/territories.avro'
with beam.Pipeline() as p:
    regions = (p | 'Read Regions' >> beam.io.ReadFromAvro(regionsfilename)
                     | 'Parse Regions' >> beam.ParDo(RegionParseClass())
                  )

    territories = (p | 'Read Territories' >> beam.io.ReadFromAvro(territoriesfilename)
                     | 'Parse Territories' >> beam.ParDo(TerritoryParseClass())
                  )

    # The dictionary keys are the table names used inside the query.
    result = ( {'regions': regions, 'territories' : territories} 
         | SqlTransform("""
SELECT r.regionid AS regionid, r.regionname AS regionname, SUM(1) AS cnt 
FROM regions AS r 
JOIN territories AS t on t.regionid = r.regionid 
GROUP BY r.regionid, r.regionname
""")
        | 'Convert to Result Class' >> beam.Map(lambda x : Result(x.regionid, x.regionname, x.cnt))
#        | 'Format Output' >> beam.Map(lambda x : f'{x.regionid}, {x.regionname}, {x.cnt}')
             )
    result | 'Print Join' >> beam.Map(print)


## <img src="java.png" width=40 height=40 /><font color='indigo' size="+2">Java</font>

### Beam SQL using Pojo with a simple query

In [None]:
%%java verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.transforms.Convert;
import com.google.gson.Gson;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class ReadTerritories {
    /**
     * Reads one JSON document per line from the territories file, filters and upper-cases
     * the rows with a Beam SQL query, converts the result rows back into Territory objects
     * and writes one formatted line per territory.
     */
    public static void main(String[] args) {
        // Same property as before, set through the dedicated System API.
        System.setProperty("org.apache.commons.logging.simplelog.defaultlog", "fatal");

        Pipeline p = Pipeline.create();
        p.getSchemaRegistry().registerPOJO(Territory.class);
 
        String territoriesInputFileName = "datasets/northwind/JSON/territories/territories.json";
        String outputsPrefix = "/tmp/outputs";

        // With a single input, the query addresses it as PCOLLECTION.
        PCollection<Territory> result = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse", ParDo.of(new JsonToTerritory()))
            .apply(SqlTransform.query("SELECT territoryid, upper(territoryname) as territoryname, regionid FROM PCOLLECTION WHERE regionid = 1"))
            .apply(Convert.fromRows(Territory.class))
        ;

        /*
        result.apply(MapElements.via(
            new SimpleFunction<Territory, Territory>() {
              @Override
              public Territory apply(Territory t) {
                System.out.println("** " + t);
                return t;
              }
            })); 
        */
        
        result.apply(TextIO.<Territory>writeCustomType().to(outputsPrefix).withFormatFunction(new SerializeTerritory()));
        
        p.run().waitUntilFinish();
    }
    

    /** One territory record; the field names are the column names used in the SQL query. */
    @DefaultSchema(JavaFieldSchema.class)
    static class Territory {
        Long territoryid;
        String territoryname;
        Long regionid;
        
        /** No-argument constructor; leaves every field null. */
        Territory() {}
        
        Territory(long territoryid, String territoryname, long regionid) {
            this.territoryid = territoryid;
            this.territoryname = territoryname;
            this.regionid = regionid;
        }
        
        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryid, territoryname, regionid);
        }

    }
    
    /** Format function that renders a Territory through its toString(). */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
          return input.toString();
        }
    }

    /** Deserializes one JSON document per input line into a Territory. */
    static class JsonToTerritory extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(JsonToTerritory.class);

        // One shared parser instead of a new one per element. Gson instances are
        // thread-safe, and a static field is not serialized together with the DoFn.
        private static final Gson GSON = new Gson();

        @ProcessElement
        public void process(@Element String json, OutputReceiver<Territory> r) throws Exception {
            Territory t = GSON.fromJson(json, Territory.class);
            if (t == null) {
                // fromJson returns null for empty input; there is no record to emit.
                LOG.warn("JsonToTerritory: skipping line without JSON content");
                return;
            }
            r.output(t);
        }
    }
}


### Beam SQL using multiple sources

In [None]:
%%java
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.transforms.Convert;
import com.google.gson.Gson;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import java.io.Serializable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Beam SQL over two sources: regions are read from CSV, territories from JSON, both are
 * exposed to SQL under the names "regions" and "territories", joined on {@code regionid},
 * and the per-region territory count is written as text under {@code /tmp/outputs}.
 */
public class ReadTerritories {
    public static void main(String[] args) {
        System.getProperties().put("org.apache.commons.logging.simplelog.defaultlog","fatal");

        Pipeline p = Pipeline.create();
        // Register every POJO used in the pipeline with the schema registry.
        p.getSchemaRegistry().registerPOJO(Region.class);
        p.getSchemaRegistry().registerPOJO(Territory.class);
        p.getSchemaRegistry().registerPOJO(Result.class);

        String regionsInputFileName = "datasets/northwind/CSV/regions/regions.csv";
        String territoriesInputFileName = "datasets/northwind/JSON/territories/territories.json";
        String outputsPrefix = "/tmp/outputs";

        PCollection<Region> regions = p
            .apply("Read Regions", TextIO.read().from(regionsInputFileName))
            .apply("Parse Regions", ParDo.of(new CSVToRegion()));

        PCollection<Territory> territories = p
            .apply("Read Territories", TextIO.read().from(territoriesInputFileName))
            .apply("Parse Territories", ParDo.of(new JsonToTerritory()));

        // The tag ids are the table names referenced by the SQL statement below.
        PCollectionTuple joinSources = PCollectionTuple
            .of(new TupleTag<>("regions"), regions)
            .and(new TupleTag<>("territories"), territories);

        // The selected column aliases (regionid, regionname, cnt) match the field names of Result.
        PCollection<Result> result = joinSources
            .apply(SqlTransform.query("SELECT r.regionid AS regionid, r.regionname AS regionname, SUM(1) AS cnt FROM regions AS r JOIN territories AS t on t.regionid = r.regionid group by r.regionid, r.regionname"))
            .apply(Convert.fromRows(Result.class));

        result.apply(TextIO.<Result>writeCustomType().to(outputsPrefix).withFormatFunction(new SerializeResult()));

        p.run().waitUntilFinish();
    }

    /** A region record: numeric id plus display name. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Region {
        Long regionid;
        String regionname;

        Region() {}

        Region(Long regionid, String regionname) {
            this.regionid = regionid;
            this.regionname = regionname;
        }

        @Override
        public String toString() {
            return String.format("(regionid = %d, regionname = %s)", regionid, regionname);
        }
    }

    /** Format function that renders a {@link Region} through its {@code toString()}. */
    static class SerializeRegion implements SerializableFunction<Region, String> {
        @Override
        public String apply(Region input) {
            return input.toString();
        }
    }

    /**
     * Parses a comma-separated line into a {@link Region}. Lines with too few columns or a
     * non-numeric id are logged and dropped instead of failing the pipeline.
     */
    static class CSVToRegion extends DoFn<String, Region> {
        private static final Logger LOG = LoggerFactory.getLogger(CSVToRegion.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long regionid = Long.parseLong(columns[0].trim());
                String regionname = columns[1].trim();
                c.output(new Region(regionid, regionname));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("CSVToRegion: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

    /** A territory record; {@code regionid} is the column the SQL join matches on. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Territory {
        Long territoryid;
        String territoryname;
        Long regionid;

        Territory() {}

        Territory(Long territoryid, String territoryname, Long regionid) {
            this.territoryid = territoryid;
            this.territoryname = territoryname;
            this.regionid = regionid;
        }

        @Override
        public String toString() {
            return String.format("(territoryid = %d, territoryname = %s, regionID = %d)", territoryid, territoryname, regionid);
        }
    }

    /** Format function that renders a {@link Territory} through its {@code toString()}. */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            return input.toString();
        }
    }

    /**
     * Parses one line of JSON text into a {@link Territory}. Input that Gson parses to
     * {@code null} (for example an empty line) is logged and dropped, mirroring how
     * {@link CSVToRegion} treats lines it cannot parse.
     */
    static class JsonToTerritory extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(JsonToTerritory.class);

        // Gson instances are thread-safe, so one shared parser replaces the allocation that
        // used to happen for every element. Static, hence not serialized with the DoFn.
        private static final Gson GSON = new Gson();

        @ProcessElement
        public void process(@Element String json, OutputReceiver<Territory> r) throws Exception {
            Territory t = GSON.fromJson(json, Territory.class);
            if (t == null) {
                // A null element cannot be emitted; skip it rather than fail downstream.
                LOG.info("JsonToTerritory: no record parsed from '" + json + "'");
                return;
            }
            r.output(t);
        }
    }

    /** One row of the SQL result: a region and the aggregated count for that region. */
    @DefaultCoder(AvroCoder.class)
    @DefaultSchema(JavaFieldSchema.class)
    static class Result {
        Long regionid;
        String regionname;
        int cnt;

        Result() {}

        Result(Long regionid, String regionname, int cnt) {
            this.regionid = regionid;
            this.regionname = regionname;
            this.cnt = cnt;
        }

        @Override
        public String toString() {
            return String.format("(regionid = %d, regionname = %s, cnt = %d)", regionid, regionname, cnt);
        }
    }

    /** Format function that renders a {@link Result} through its {@code toString()}. */
    static class SerializeResult implements SerializableFunction<Result, String> {
        @Override
        public String apply(Result input) {
            return input.toString();
        }
    }
}


### Example from Beam documentation

In [None]:
%%java verbose
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;

/**
 * Minimal Beam SQL DSL walkthrough taken from the Beam documentation.
 *
 * <p>Three in-memory rows are filtered by a first SQL query, and the filtered rows are then
 * aggregated by a second query. Each result row is printed to standard output.
 *
 * <p>Inside the Beam source tree the example is started from the source root with
 *
 * <pre>
 *   ./gradlew :sdks:java:extensions:sql:runBasicExample
 * </pre>
 *
 * <p>which executes it locally on the direct runner. Other runners need additional setup that
 * is out of scope for the SQL examples; consult the Beam documentation for details.
 */
class BeamSqlExample {

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Row format of the input table (also assigned to the logged output of the first query).
    Schema inputSchema =
        Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build();

    // Row format assigned to the logged output of the second query.
    Schema summarySchema =
        Schema.builder().addStringField("stringField").addDoubleField("doubleField").build();

    // Source table: three rows created in memory.
    PCollection<Row> source =
        p.apply(
            Create.of(
                    Row.withSchema(inputSchema).addValues(1, "row", 1.0).build(),
                    Row.withSchema(inputSchema).addValues(2, "row", 2.0).build(),
                    Row.withSchema(inputSchema).addValues(3, "row", 3.0).build())
                .withRowSchema(inputSchema));

    // Case 1: a simple query over a single input, addressed in SQL as PCOLLECTION.
    PCollection<Row> filtered =
        source.apply(SqlTransform.query("select c1, c2, c3 from PCOLLECTION where c1 > 1"));

    // Print the rows of case 1. Expected output:
    //   PCOLLECTION: [3, row, 3.0]
    //   PCOLLECTION: [2, row, 2.0]
    filtered
        .apply(
            "log_result",
            MapElements.via(
                new SimpleFunction<Row, Row>() {
                  @Override
                  public Row apply(Row input) {
                    System.out.println("PCOLLECTION: " + input.getValues());
                    return input;
                  }
                }))
        .setRowSchema(inputSchema);

    // Case 2: query the result of case 1, exposed to SQL under the tag id CASE1_RESULT.
    PCollection<Row> aggregated =
        PCollectionTuple.of(new TupleTag<>("CASE1_RESULT"), filtered)
            .apply(SqlTransform.query("select c2, sum(c3) from CASE1_RESULT group by c2"));

    // Print the rows of case 2. Expected output:
    //   CASE1_RESULT: [row, 5.0]
    aggregated
        .apply(
            "log_result",
            MapElements.via(
                new SimpleFunction<Row, Row>() {
                  @Override
                  public Row apply(Row input) {
                    System.out.println("CASE1_RESULT: " + input.getValues());
                    return input;
                  }
                }))
        .setRowSchema(summarySchema);

    p.run().waitUntilFinish();
  }
}



### Beam SQL reading from a POJO into a Result POJO

In [None]:
%%java 
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.transforms.Convert;
import com.google.gson.Gson;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Beam SQL over a POJO collection: territories are read from JSON, counted per
 * {@code regionid} with a SQL query, converted into {@link Result} objects and written as
 * text under {@code /tmp/outputs}.
 */
public class ReadTerritories {
    public static void main(String[] args) {
        System.getProperties().put("org.apache.commons.logging.simplelog.defaultlog","fatal");

        Pipeline p = Pipeline.create();
        p.getSchemaRegistry().registerPOJO(Result.class);

        String territoriesInputFileName = "datasets/northwind/JSON/territories/territories.json";
        String outputsPrefix = "/tmp/outputs";

        // No explicit result Schema is built here: the SQL output rows are converted straight
        // into Result, whose field names (regionid, cnt) match the selected columns.
        PCollection<Result> result = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse", ParDo.of(new JsonToTerritory()))
            .apply(SqlTransform.query("SELECT regionid, COUNT(*) as cnt FROM PCOLLECTION GROUP BY regionid"))
            .apply(Convert.fromRows(Result.class));

        result.apply(TextIO.<Result>writeCustomType().to(outputsPrefix).withFormatFunction(new SerializeResult()));

        p.run().waitUntilFinish();
    }

    /** A territory record as parsed from the JSON input. */
    @DefaultSchema(JavaFieldSchema.class)
    static class Territory {
        Long territoryid;
        String territoryname;
        Long regionid;

        Territory() {}

        Territory(long territoryid, String territoryname, long regionid) {
            this.territoryid = territoryid;
            this.territoryname = territoryname;
            this.regionid = regionid;
        }

        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryid, territoryname, regionid);
        }
    }

    /** Format function that renders a {@link Territory} through its {@code toString()}. */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
            return input.toString();
        }
    }

    /**
     * Parses one line of JSON text into a {@link Territory}. Input that Gson parses to
     * {@code null} (for example an empty line) is logged and dropped.
     */
    static class JsonToTerritory extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(JsonToTerritory.class);

        // Gson instances are thread-safe, so one shared parser replaces the allocation that
        // used to happen for every element. Static, hence not serialized with the DoFn.
        private static final Gson GSON = new Gson();

        @ProcessElement
        public void process(@Element String json, OutputReceiver<Territory> r) throws Exception {
            Territory t = GSON.fromJson(json, Territory.class);
            if (t == null) {
                // A null element cannot be emitted; skip it rather than fail downstream.
                LOG.info("JsonToTerritory: no record parsed from '" + json + "'");
                return;
            }
            r.output(t);
        }
    }

    /** One row of the SQL result: a region id and the number of territories counted for it. */
    @DefaultSchema(JavaFieldSchema.class)
    static class Result {
        Long regionid;
        Long cnt;

        Result() {}

        Result(Long regionid, Long cnt) {
            this.regionid = regionid;
            this.cnt = cnt;
        }

        @Override
        public String toString() {
            return String.format("(regionid = %d, cnt = %d)", regionid, cnt);
        }

        // Identity-only equality, kept exactly as the notebook defines it: two Result
        // objects are equal only when they are the same instance.
        @Override
        public boolean equals (Object o) {
            return o == this;
        }
    }

    /** Format function that renders a {@link Result} through its {@code toString()}. */
    static class SerializeResult implements SerializableFunction<Result, String> {
        @Override
        public String apply(Result input) {
            return input.toString();
        }
    }
}


### Beam SQL in Java: the wrong way to work with schemas

In [None]:
%%java verbose nooutput
package samples.quickstart;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.transforms.Convert;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Demonstration kept on purpose (the notebook heading labels it the "wrong way" to work with
 * schemas): territories are parsed from CSV into a POJO, copied field by field into Rows
 * with a hand-built schema, queried with Beam SQL, and the resulting rows are printed.
 */
public class ReadTerritories {
    public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        String territoriesInputFileName = "datasets/northwind/CSV/territories/territories.csv";
        // NOTE(review): outputsPrefix is only referenced by the commented-out TextIO write below.
        String outputsPrefix = "/tmp/outputs";

        PCollection<Territory> territories = p
            .apply("Read", TextIO.read().from(territoriesInputFileName))
            .apply("Parse", ParDo.of(new ParseTerritories()))
        ;                   
        
        // Define the schema for the records.
        // Field names and order mirror the fields of the Territory POJO.
        Schema territorySchema = Schema
          .builder()
          .addInt64Field("territoryID")
          .addStringField("territoryName")
          .addInt64Field("regionID")
          .build();
        // Define the schema to hold the results.
        
        Schema resultSchema = Schema.of(
            Schema.Field.of("regionID", Schema.FieldType.INT64), 
            Schema.Field.of("cnt", Schema.FieldType.INT64));
        
        // Convert them to Rows with the same schema as defined above via a DoFn.
        PCollection<Row> territories2 = territories
          .apply(
          ParDo.of(new DoFn<Territory, Row>() {
            @ProcessElement
            public void process(ProcessContext c) {
              // Get the current POJO instance
              Territory t = c.element();

              // Create a Row with the territorySchema schema
              // and values from the current POJO
              Row territoryRow =
                    Row
                      .withSchema(territorySchema)
                      .addValues(
                        t.territoryID,
                        t.territoryName,
                        t.regionID)
                      .build();

              // Output the Row representing the current POJO
              c.output(territoryRow);
            }
          })).setRowSchema(territorySchema);
        
          // territories2 is already a PCollection<Row>; Convert.toRows() is still applied to it
          // before the SQL query, and resultSchema is then set on the query output.
          PCollection<Row> territories3 = territories2.apply(Convert.toRows()).apply(
             SqlTransform.query("SELECT regionID, COUNT(*) as cnt from PCOLLECTION GROUP BY regionID")).setRowSchema(resultSchema);
        
          // Print every result row to standard output and pass it through unchanged.
          territories3.apply(
              "Print", MapElements.via(new SimpleFunction<Row, Row>() {
                  @Override
                  public Row apply(Row input) {
                      System.out.println("SQL Result: " + input.getValues());
                      return input;
                  }
              }
          )).setRowSchema(resultSchema);
//        territories3.apply(TextIO.<Row>writeCustomType().to(outputsPrefix).withFormatFunction(new SerializeTerritory()));
        p.run().waitUntilFinish();
    }

    // NOTE(review): abandoned AutoValue-based variant of Territory, left disabled.
/*    
    @schemultSchema(AutoValueSchema.class)
    @AutoValue
    public static abstract class Territory {
      public abstract Long getTerritoryID();
      public abstract String getTerritoryName();
      public abstract Long getRegionID();

      @SchemaCreate
      public static Territory create(Long territoryID, String territoryName, Long regionID) {
        return new AutoValue_TerritoryClass(territoryID, territoryName, regionID);
      }
*/    
    
    /** Territory POJO; its fields are read directly when rows are built in {@code main}. */
    @DefaultCoder(AvroCoder.class)
    static class Territory {
        Long territoryID;
        String territoryName;
        Long regionID;
        
        Territory() {}
        
        Territory(long territoryID, String territoryName, long regionID) {
            this.territoryID = territoryID;
            this.territoryName = territoryName;
            this.regionID = regionID;
        }
        
        @Override
        public String toString() {
            return String.format("(territoryID = %d, territoryName = %s, regionID = %d)", territoryID, territoryName, regionID);
        }
    }
    
    /** Format function that renders a {@link Territory} through its {@code toString()}. */
    static class SerializeTerritory implements SerializableFunction<Territory, String> {
        @Override
        public String apply(Territory input) {
          return input.toString();
        }
    }

    /**
     * Parses a comma-separated line into a {@link Territory}. Lines with too few columns or
     * non-numeric ids are logged and dropped.
     */
    static class ParseTerritories extends DoFn<String, Territory> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryID, territoryName, regionID));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }
    

    
     
    // NOTE(review): the remainder of the class is disabled scratch code (a Region POJO, a
    // PTransform fragment and a ParseRegions DoFn); it is not compiled.
/*    
    
    
    @DefaultCoder(AvroCoder.class)
    static class Region {
        Long regionID;
        Long cnt regionName;
        
        Region() {}
        
        Region(long regionID, long cnt) {
            this.regionID = regionID;
            this.cnt = cnt;
        }
        
        @Override
        public String toString() {
            return String.format("(regionID = %d, cnt = %d)", regionID, cnt);
        }

    }
    
    static class SerializeTerritory implements SerializableFunction<Region, String> {
        @Override
        public String apply(Region input) {
          return input.toString();
        }
    }

    
    
private class Transform extends PTransform<pcollectionlist<row>, PCollection<row>> {
 
    @Override
    public PCollection<row> expand(PCollectionList<row> pinput) {
      checkArgument(
          pinput.size() == 1,
          "Wrong number of inputs for %s: %s",
          BeamUncollectRel.class.getSimpleName(),
          pinput);
      PCollection<row> upstream = pinput.get(0);
 
      // Each row of the input contains a single array of things to be emitted; Calcite knows
      // what the row looks like
      Schema outputSchema = CalciteUtils.toSchema(getRowType());
 
      PCollection<row> uncollected =
          upstream.apply(ParDo.of(new UncollectDoFn(outputSchema))).setRowSchema(outputSchema);
 
      return uncollected;
    }
  }    
    static class ParseRegions extends DoFn<Row, Region> {
        private static final Logger LOG = LoggerFactory.getLogger(ParseTerritories.class);

        @ProcessElement
        public void process(ProcessContext c) {
            
            String[] columns = c.element().split(",");
            try {
                Long territoryID = Long.parseLong(columns[0].trim());
                String territoryName = columns[1].trim();
                Long regionID = Long.parseLong(columns[2].trim());
                c.output(new Territory(territoryID, territoryName, regionID));
            } catch (ArrayIndexOutOfBoundsException | NumberFormatException e) {
                LOG.info("ParseTerritories: parse error on '" + c.element() + "': " + e.getMessage());
            }
        }
    }

    
    
*/
    
    
}


#

***

# 8. <font color='green' size="+2">DoFn</font> Lifecycle

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### <font color='green' size="+2">DoFn</font> Lifecycle

In [None]:
import apache_beam as beam
from apache_beam.pvalue import AsSingleton, AsDict
from apache_beam.io import ReadFromText

class TerritoryParseTuple(beam.DoFn):
    """Parse one CSV line into a flat ``(territoryid, territoryname, regionid)`` tuple."""

    def process(self, element):
        # Unpacking raises ValueError unless the line has exactly three fields.
        fields = element.split(',')
        territoryid, territoryname, regionid = fields
        yield (int(territoryid), territoryname, int(regionid))
        
                
class LookupRegion(beam.DoFn):
    """Append a region name to each ``(territoryid, territoryname, regionid)`` tuple.

    Every lifecycle hook prints its own name so the order in which the hooks
    are invoked shows up in the cell output.
    """

    def setup(self):
        # The lookup table is built here instead of inside process().
        self.lookup = dict([(1, 'North'), (2, 'South'), (3, 'East'), (4, 'West')])
        print('setup')

    def start_bundle(self):
        print('start bundle')

    def process(self, element, uppercase=0):
        """Yield the input tuple extended with its region name.

        Unknown region ids map to 'No Region'; the name is upper-cased when
        ``uppercase`` equals 1.
        """
        territoryid, territoryname, regionid = element
        name = self.lookup.get(regionid, 'No Region')
        yield (territoryid, territoryname, regionid, name.upper() if uppercase == 1 else name)

    def finish_bundle(self):
        print('finish bundle')

    def teardown(self):
        print('teardown')
        # Drop the table that setup() created.
        del self.lookup
    

# Read territories, parse them into tuples, attach the upper-cased region name, print each result.
with beam.Pipeline() as p:
    territory_lines = p | 'Read Territories' >> ReadFromText('territories.csv')
    territories = territory_lines | 'Parse Territories' >> beam.ParDo(TerritoryParseTuple())

    # uppercase is an ordinary keyword argument forwarded to LookupRegion.process.
    lookup = territories | beam.ParDo(LookupRegion(), uppercase=1)
    lookup | 'Print Loopup' >> beam.Map(print)
        


#

***

# 9. Side Inputs

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### Side inputs are about passing extra parameters to a function where the parameters are calculated in the pipeline itself.

In [None]:
import apache_beam as beam
from apache_beam.pvalue import AsSingleton, AsDict
from apache_beam.io import ReadFromText
from apache_beam.transforms.combiners import Sample

class TerritoryParseTuple(beam.DoFn):
    """Parse a CSV line into ``(territoryid, territoryname, regionid)``.

    The territory name is upper-cased unless the first item of ``uppercase``
    is the string '0'.
    """

    def process(self, element, uppercase='0'):
        # In the pipeline the side input arrives as a single-element list holding a
        # string, so only its first item is inspected.
        territoryid, territoryname, regionid = element.split(',')
        tid = int(territoryid)
        name = territoryname
        if uppercase[0] != '0':
            name = territoryname.upper()
        yield (tid, name, int(regionid))

        
with beam.Pipeline() as p:
    # The flag is read from sideinput.txt inside the pipeline; one line is sampled globally.
    flag_lines = p | 'Read sideinput.txt' >> ReadFromText('sideinput.txt')
    sideinput = flag_lines | Sample.FixedSizeGlobally(1)

    territory_lines = p | 'Read Territories' >> ReadFromText('territories.csv')
    # Alternatives, for comparison:
    #   uppercase = ["1"]     -> a fixed parameter, not a side input
    #   uppercase = sideinput -> fails, because sideinput is a PCollection rather than a value
    # Wrapping the PCollection in AsSingleton makes it a side input: the parameter is
    # calculated in the pipeline itself.
    territories = territory_lines | 'Parse Territories' >> beam.ParDo(
        TerritoryParseTuple(), uppercase=AsSingleton(sideinput))
    territories | 'Print Loopup' >> beam.Map(print)

#    maxregion | 'Print Min' >> beam.Map(print)


#

***

# 10. Windows (Not Complete)

## <img src="python.png" width=40 height=40 /><font color='cadetblue' size="+2">Python</font>

### Windowing assigns timestamped elements to fixed-size windows so that grouping happens per window.

In [None]:
import apache_beam as beam
from apache_beam.transforms.sql import SqlTransform

#dir(beam.io)
#help(SqlTransform)
from datetime import datetime
# Quick check of the date format used for order dates: parse, keep the date part, print it.
parsed = datetime.strptime('1997-03-12', '%Y-%m-%d')
print(parsed.date())
    

In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromAvro, ReadFromParquet
from apache_beam.transforms.combiners import Sample
from datetime import datetime
import time
import typing
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform
from apache_beam import window

class GetTimestampFn(beam.DoFn):
    """Turn each element into a dict and add the bounds of the window it belongs to."""

    def process(self, element, window=beam.DoFn.WindowParam):
        # Window bounds are emitted as datetime objects, not formatted strings.
        start = window.start.to_utc_datetime()
        end = window.end.to_utc_datetime()
        record = dict(element._asdict())
        record.update(window_start=start, window_end=end)
        yield record

class LeftJoin(beam.PTransform):
    '''
    Join two PCollections of dictionaries on a key and emit one flat dictionary per
    matching (parent, child) pair.

    The input, to the left of the |, is a dictionary holding the two PCollections.
    The constructor receives the name of each PCollection inside that dictionary and
    the key field to join on.

    Each dict is reshaped into a (key, dict-without-key) tuple, the two collections
    are combined with CoGroupByKey, and every child row is merged into a copy of the
    parent row.

    Notes on the behaviour:
    * Only the first parent row of a key is used.
    * A key with no child rows emits nothing.
    * A key with child rows but no parent row emits nothing (it used to raise
      IndexError).
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key):
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # item is a dictionary; the copy keeps the caller's dictionary untouched.
            rest = item.copy()
            key_value = rest.pop(key)
            return (key_value, rest)

        def reshapeCoGroupToFlatDict(item):
            key_value, grouped = item
            parents = list(grouped[self.parent_pipeline_name])
            if not parents:
                # Child rows without a parent have nothing to be joined to.
                return []
            parent = {self.parent_key: key_value}
            parent.update(parents[0])
            # Child fields override parent fields of the same name.
            return [{**parent, **child} for child in grouped[self.child_pipeline_name]]

        keyed = {
            self.parent_pipeline_name: (
                pcols[self.parent_pipeline_name]
                | f'Convert {self.parent_pipeline_name} to KV' >> beam.Map(reshapeToKV, self.parent_key)
            ),
            self.child_pipeline_name: (
                pcols[self.child_pipeline_name]
                | f'Convert {self.child_pipeline_name} to KV' >> beam.Map(reshapeToKV, self.child_key)
            ),
        }
        return (
            keyed
            | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}' >> beam.CoGroupByKey()
            # FlatMap emits every joined row as its own element.
            | 'Reshape to dictionary' >> beam.FlatMap(reshapeCoGroupToFlatDict)
        )


# calendar.timegm interprets the order date as UTC. time.mktime used the machine's local
# timezone, which shifted the event timestamps (and so the daily windows) depending on
# where the notebook ran.
import calendar

ordersfile = 'datasets/northwind/AVRO/orders/orders.avro'
orderdetailsfile = 'datasets/northwind/PARQUET/orderdetails/orderdetails.parquet'
with beam.Pipeline() as p:
    orders = p | 'Read Orders' >> beam.io.ReadFromAvro(ordersfile)
    orderdetails = p | 'Read OrderDetails' >> beam.io.ReadFromParquet(orderdetailsfile)

    leftjoin = (
        {'orders': orders, 'orderdetails': orderdetails}
        | 'Join' >> LeftJoin('orders', 'orderid', 'orderdetails', 'orderid')
        # Keep only the fields needed downstream; amount is the line total.
        | 'Select' >> beam.Map(lambda x: {'orderdate': datetime.strptime(x['orderdate'], '%Y-%m-%d').date(),
                                          'shipcountry': x['shipcountry'],
                                          'amount': x['unitprice'] * x['quantity'],
                                          })
        # Use the order date (midnight UTC) as the element's event timestamp.
        | 'Timestamp' >> beam.Map(lambda x: beam.window.TimestampedValue(x, calendar.timegm(x['orderdate'].timetuple())))
        # One window per day.
        | 'Window' >> beam.WindowInto(window.FixedWindows(60 * 60 * 24))
        # Sum the amounts per ship country within each window. To group by order date as
        # well, add `orderdate = lambda x: x['orderdate']` to GroupBy.
        | 'Group' >> beam.GroupBy(shipcountry = lambda x: x['shipcountry']).aggregate_field(lambda x: x['amount'], sum, 'totalamount')
        | 'AddWindowTimestamp' >> beam.ParDo(GetTimestampFn())
    )
    leftjoin | beam.Map(print)



In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromText

class LeftJoin(beam.PTransform):
    '''
    Nest-join two PCollections of dictionaries on a key: for every key, emit ONE list
    holding a flat dictionary per (parent, child) pair.

    The input, to the left of the |, is a dictionary holding the two PCollections.
    The constructor receives the name of each PCollection inside that dictionary and
    the key field to join on.

    Each dict is reshaped into a (key, dict-without-key) tuple, the two collections
    are combined with CoGroupByKey, and every child row is merged into a copy of the
    parent row.

    Notes on the behaviour:
    * Only the first parent row of a key is used.
    * A key with no child rows emits an empty list.
    * A key with child rows but no parent row also emits an empty list (it used to
      raise IndexError).
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key):
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # item is a dictionary; the copy keeps the caller's dictionary untouched.
            rest = item.copy()
            key_value = rest.pop(key)
            return (key_value, rest)

        def reshapeCoGroupToFlatDict(item):
            key_value, grouped = item
            parents = list(grouped[self.parent_pipeline_name])
            if not parents:
                # Child rows without a parent have nothing to be joined to.
                return []
            parent = {self.parent_key: key_value}
            parent.update(parents[0])
            # Child fields override parent fields of the same name.
            return [{**parent, **child} for child in grouped[self.child_pipeline_name]]

        keyed = {
            self.parent_pipeline_name: (
                pcols[self.parent_pipeline_name]
                | f'Convert {self.parent_pipeline_name} to KV' >> beam.Map(reshapeToKV, self.parent_key)
            ),
            self.child_pipeline_name: (
                pcols[self.child_pipeline_name]
                | f'Convert {self.child_pipeline_name} to KV' >> beam.Map(reshapeToKV, self.child_key)
            ),
        }
        return (
            keyed
            | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}' >> beam.CoGroupByKey()
            # Map (not FlatMap) keeps the rows of one key together in a single list element.
            | 'Reshape to dictionary' >> beam.Map(reshapeCoGroupToFlatDict)
        )

class RegionParseDict(beam.DoFn):
    """Parse a 'regionid,regionname' CSV line into a region dictionary."""

    def process(self, element):
        # Exactly two comma-separated fields are expected; the unpacking
        # raises ValueError for any other count.
        fields = element.split(',')
        raw_id, raw_name = fields
        region = {}
        region['regionid'] = int(raw_id)
        # The name is title-cased on the way in.
        region['regionname'] = raw_name.title()
        yield region

class TerritoryParseDict(beam.DoFn):
    """Parse a 'territoryid,territoryname,regionid' CSV line into a dictionary."""

    def process(self, element):
        # Exactly three comma-separated fields are expected; the unpacking
        # raises ValueError for any other count.
        raw_territory, name, raw_region = element.split(',')
        territory = dict(
            territoryid=int(raw_territory),
            territoryname=name,
            regionid=int(raw_region),
        )
        yield territory
    
# Input CSV files for the join example.
regionsfilename = 'datasets/northwind/CSV/regions/regions.csv'
territoriesfilename = 'datasets/northwind/CSV/territories/territories.csv'

with beam.Pipeline() as p:
    # Parent side of the join: one dictionary per region.
    regions = (
        p
        | 'Read Regions' >> ReadFromText(regionsfilename)
        | 'Parse Regions' >> beam.ParDo(RegionParseDict())
    )
    # regions | 'Print Regions' >> beam.Map(print)

    # Child side of the join: one dictionary per territory, read from the
    # same dataset directory tree as the regions.
    territories = (
        p
        | 'Read Territories' >> ReadFromText(territoriesfilename)
        | 'Parse Territories' >> beam.ParDo(TerritoryParseDict())
    )
    # territories | 'Print Territories' >> beam.Map(print)

    # Both sides carry the join key in their 'regionid' field.
    nestjoin = (
        {'regions': regions, 'territories': territories}
        | LeftJoin('regions', 'regionid', 'territories', 'regionid')
    )
    nestjoin | 'Print Nest Join' >> beam.Map(print)



In [None]:
import apache_beam as beam
from apache_beam.io import ReadFromAvro, ReadFromParquet
from apache_beam.transforms.combiners import Sample
from datetime import datetime
import typing
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform

class Order(typing.NamedTuple):
    '''
    Typed record for one order, built by ParseOrder.

    Only a subset of the order columns is declared; the remaining columns are
    kept below as commented-out fields. Field order matters because ParseOrder
    builds instances positionally.
    '''
    orderid: int
    customerid: str
    employeeid: int
    orderdate: str  # kept as text; the SQL query further down casts it to a date
    # Columns that are currently not part of the record:
    # orderdate: datetime
    # requireddate: datetime
    # shippeddate: datetime
    # shipvia: int
    # freight: float
    # shipname: str
    # shipaddress: str
    # shipcity: str
    # shipregion: str
    # shippostalcode: str
    shipcountry: str
# Register RowCoder as the coder Beam uses for Order elements.
# NOTE(review): presumably required so SqlTransform sees a schema - confirm.
beam.coders.registry.register_coder(Order, coders.RowCoder)

class ParseOrder(beam.DoFn):
    """Convert an order record (column name -> value mapping) into an Order."""

    def process(self, element):
        # Only the columns declared on Order are read; every other column of
        # the record is ignored. orderdate is passed through unconverted.
        yield Order(
            orderid=int(element['orderid']),
            customerid=element['customerid'],
            employeeid=int(element['employeeid']),
            orderdate=element['orderdate'],
            shipcountry=element['shipcountry'],
        )

class OrderDetail(typing.NamedTuple):
    '''
    Typed record for one order line, built by ParseOrderDetail.

    Field order matters because ParseOrderDetail builds instances positionally.
    '''
    orderid: int  # order this line belongs to; the SQL join key
    productid: int
    unitprice: float
    quantity: int
    discount: float
# Register RowCoder as the coder Beam uses for OrderDetail elements.
beam.coders.registry.register_coder(OrderDetail, coders.RowCoder)

class ParseOrderDetail(beam.DoFn):
    """Convert an order-detail record (column name -> value mapping) into an OrderDetail."""

    def process(self, element):
        # Every column is converted to the type declared on OrderDetail.
        yield OrderDetail(
            orderid=int(element['orderid']),
            productid=int(element['productid']),
            unitprice=float(element['unitprice']),
            quantity=int(element['quantity']),
            discount=float(element['discount']),
        )

# Orders are stored as Avro, order details as Parquet.
ordersfile = 'datasets/northwind/AVRO/orders/orders.avro'
orderdetailsfile = 'datasets/northwind/PARQUET/orderdetails/orderdetails.parquet'

with beam.Pipeline() as p:
    # Typed Order records.
    orders = (
        p
        | 'Read Orders' >> beam.io.ReadFromAvro(ordersfile)
        | 'Parse Orders' >> beam.ParDo(ParseOrder())
    )

    # Typed OrderDetail records.
    orderdetails = (
        p
        | 'Read OrderDetails' >> beam.io.ReadFromParquet(orderdetailsfile)
        | 'Parse OrderDetails' >> beam.ParDo(ParseOrderDetail())
    )

    # The dictionary keys are the table names the SQL statement refers to.
    query = (
        {'orders': orders, 'orderdetails': orderdetails}
        | 'SQL Territories' >> SqlTransform(
"""
SELECT o.shipcountry, cast(orderdate as date) as orderdate, od.unitprice * od.quantity as amount
FROM orders as o
JOIN orderdetails as od on o.orderid = od.orderid
""",
            dialect='zetasql',
        )
    )
    query | beam.Map(print)
# SELECT o.orderdate, o.shipcountry, sum(od.unitprice * od.quantity) as amount
# FROM orders as o
# JOIN orderdetails as od on o.orderid = od.orderid
# GROUP BY o.orderdate, o.shipcountry
    
    
    # TUMBLE(f_timestamp, INTERVAL '1' HOUR)
    # query | beam.Map(print)
    
    
              #     SELECT o.orderdate, sum(od.unitprice * od.quantity) as amount
              # FROM orders as o
              # JOIN orderdetails as od on o.orderid = od.orderid
              # GROUP BY o.orderdate


### Side input that is a lookup list
### More realistic example where the entire lookup table is read in the pipeline then distributed to each worker as a side input

In [None]:
import apache_beam as beam
from apache_beam.pvalue import AsList
from apache_beam.io import ReadFromText

class RegionParseDict(beam.DoFn):
    """Turn a 'regionid,regionname' text line into a region dictionary."""

    def process(self, element):
        # Two comma-separated fields are required; otherwise the unpacking
        # raises ValueError.
        raw_id, raw_name = element.split(',')
        region_id = int(raw_id)
        region_name = raw_name.title()
        yield dict(regionid=region_id, regionname=region_name)

class TerritoryParseTuple(beam.DoFn):
    """Turn a 'territoryid,territoryname,regionid' text line into a tuple."""

    def process(self, element):
        # Three comma-separated fields are required; otherwise the unpacking
        # raises ValueError.
        parts = element.split(',')
        raw_territory, name, raw_region = parts
        yield (int(raw_territory), name, int(raw_region))
        
                
class LookupRegion(beam.DoFn):
    """Append the region name to a (territoryid, territoryname, regionid) tuple."""

    def process(self, element, lookuptable = [{'regionid':1, 'regionname':'North'}, {'regionid':2, 'regionname':'South'}]):
        # lookuptable is a list of region dictionaries; the default is a
        # two-region sample used when no table is supplied.
        territoryid, territoryname, regionid = element
        # Because the regions are dictionaries rather than a mapping, index
        # them by id first, e.g. {1: 'North', 2: 'South'}.
        names_by_id = {}
        for region in lookuptable:
            names_by_id[region['regionid']] = region['regionname']
        # Unknown region ids fall back to 'No Region'.
        region_name = names_by_id.get(regionid, 'No Region')
        yield (territoryid, territoryname, regionid, region_name)

with beam.Pipeline() as p:
    # Lookup table: the whole regions file, parsed into dictionaries.
    regions = (
        p
        | 'Read Regions' >> ReadFromText('regions.csv')
        | 'Parse Regions' >> beam.ParDo(RegionParseDict())
    )
    # regions | 'Print Regions' >> beam.Map(print)

    # Main input: territories as (territoryid, territoryname, regionid) tuples.
    territories = (
        p
        | 'Read Territories' >> ReadFromText('territories.csv')
        | 'Parse Territories' >> beam.ParDo(TerritoryParseTuple())
    )
    # territories | 'Print Territories' >> beam.Map(print)

    # The regions PCollection is passed to the DoFn as a list side input.
    regions_side_input = AsList(regions)
    lookup = territories | beam.ParDo(LookupRegion(), lookuptable=regions_side_input)
    lookup | 'Print Loopup' >> beam.Map(print)
        


In [30]:
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.io import ReadFromText, WriteToText
from datetime import datetime
import time
import typing
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform
from apache_beam import window

class GetTimestampFn(beam.DoFn):
    """Turn an element into a dict and attach its window's start and end."""

    def process(self, element, window=beam.DoFn.WindowParam):
        # The window bounds are converted with to_utc_datetime(); the element
        # must provide _asdict() (namedtuple-style).
        bounds = {
            'window_start': window.start.to_utc_datetime(),
            'window_end': window.end.to_utc_datetime(),
        }
        # Element fields come first, followed by the two window fields.
        yield {**element._asdict(), **bounds}

class LeftJoin(beam.PTransform):
    '''
    Join a parent and a child PCollection of dictionaries on a key.

    Usage: put a dictionary holding the two PCollections to the left of the |,
    for example {'orders': orders, 'orderdetails': orderdetails}, and pass the
    name of each PCollection together with the field to join on. Every element
    of both PCollections must be a dictionary.

    Each dict is reshaped into a (key, dict without the key) tuple, the two
    keyed PCollections are grouped with CoGroupByKey, and each group is
    flattened into one dictionary per child row, holding the join key (under
    the parent key name), the parent's fields and the child's fields. Child
    fields overwrite parent fields of the same name.

    The final step is a beam.FlatMap, so every output element is a single flat
    dictionary. Keys without child rows, and keys without a parent row,
    produce no output. When a key has several parent rows only the first one
    is used.

    Args:
        parent_pipeline_name: name of the parent PCollection in the input dict.
        parent_key: field of the parent dictionaries that holds the join key.
        child_pipeline_name: name of the child PCollection in the input dict.
        child_key: field of the child dictionaries that holds the join key.
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key):
        super().__init__()
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # item must be a dictionary; the key field moves out of the dict
            # and becomes the first element of the (key, rest) tuple.
            rest = item.copy()
            del rest[key]
            return (item[key], rest)

        def reshapeCoGroupToFlatDict(item):
            key, grouped = item
            parents = list(grouped[self.parent_pipeline_name])
            if not parents:
                # Child rows whose key has no parent row cannot be joined;
                # skip them instead of failing on parents[0].
                return []
            parent = {self.parent_key: key}
            parent.update(parents[0])
            # One flat row per child; child fields win on a name clash.
            return [{**parent, **child} for child in grouped[self.child_pipeline_name]]

        keyed = {
            self.parent_pipeline_name: (
                pcols[self.parent_pipeline_name]
                | f'Convert {self.parent_pipeline_name} to KV'
                >> beam.Map(reshapeToKV, self.parent_key)
            ),
            self.child_pipeline_name: (
                pcols[self.child_pipeline_name]
                | f'Convert {self.child_pipeline_name} to KV'
                >> beam.Map(reshapeToKV, self.child_key)
            ),
        }
        return (
            keyed
            | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}'
            >> beam.CoGroupByKey()
            | 'Reshape to dictionary'
            >> beam.FlatMap(reshapeCoGroupToFlatDict)
        )

ordersfilename = 'datasets/northwind/AVRO/orders/orders.avro'
orderdetailsfilename = 'datasets/northwind/AVRO/orderdetails/orderdetails.avro'


def utc_midnight_seconds(day):
    """Return the Unix timestamp (in seconds) of midnight UTC on the date `day`."""
    # Plain date arithmetic keeps the result independent of the machine's time
    # zone. time.mktime() would interpret the date as local time, which shifts
    # the event timestamp by the UTC offset and can move an order into the
    # neighbouring daily window.
    return (day - datetime(1970, 1, 1).date()).total_seconds()


with beam.Pipeline() as p:
    orders = p | 'Read Orders' >> beam.io.ReadFromAvro(ordersfilename)
    orderdetails = p | 'Read OrderDetails' >> beam.io.ReadFromAvro(orderdetailsfilename)

    # Join every order line to its order and keep only the fields needed for
    # the aggregation; orderdate is parsed from 'YYYY-MM-DD' text into a date.
    ordersjoin = (
        {'orders': orders, 'orderdetails': orderdetails}
        | 'Join' >> LeftJoin('orders', 'orderid', 'orderdetails', 'orderid')
        | 'Select' >> beam.Map(lambda x: dict(
            customerid=x['customerid'],
            orderdate=datetime.strptime(x['orderdate'], '%Y-%m-%d').date(),
            country=x['shipcountry'],
            amount=x['unitprice'] * x['quantity'],
        ))
    )

    # Stamp each row with its order date, cut the stream into one-day windows
    # (60 * 60 * 24 seconds) and total the amount per country and window.
    groupby = (
        ordersjoin
        | 'Timestamp' >> beam.Map(
            lambda x: beam.window.TimestampedValue(x, utc_midnight_seconds(x['orderdate'])))
        | 'Window' >> beam.WindowInto(beam.window.FixedWindows(60 * 60 * 24))
        | 'Group' >> beam.GroupBy(country=lambda x: x['country'])
                         .aggregate_field(lambda x: x['amount'], sum, 'totalamount')
        | 'AddWindowTimestamp' >> beam.ParDo(GetTimestampFn())
    )

    groupby | beam.Map(print)

{'country': 'France', 'totalamount': 440.0, 'window_start': datetime.datetime(1996, 7, 4, 0, 0), 'window_end': datetime.datetime(1996, 7, 5, 0, 0)}
{'country': 'France', 'totalamount': 670.8, 'window_start': datetime.datetime(1996, 7, 8, 0, 0), 'window_end': datetime.datetime(1996, 7, 9, 0, 0)}
{'country': 'France', 'totalamount': 1176.0, 'window_start': datetime.datetime(1996, 7, 25, 0, 0), 'window_end': datetime.datetime(1996, 7, 26, 0, 0)}
{'country': 'France', 'totalamount': 538.6, 'window_start': datetime.datetime(1996, 8, 6, 0, 0), 'window_end': datetime.datetime(1996, 8, 7, 0, 0)}
{'country': 'France', 'totalamount': 121.6, 'window_start': datetime.datetime(1996, 9, 2, 0, 0), 'window_end': datetime.datetime(1996, 9, 3, 0, 0)}
{'country': 'France', 'totalamount': 1420.0, 'window_start': datetime.datetime(1996, 9, 4, 0, 0), 'window_end': datetime.datetime(1996, 9, 5, 0, 0)}
{'country': 'France', 'totalamount': 268.79999999999995, 'window_start': datetime.datetime(1996, 9, 20, 0, 0

In [44]:
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.io import ReadFromText, WriteToText
from datetime import datetime
import time
import typing
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform
from apache_beam import window

class Order(typing.NamedTuple):
    '''
    Typed record for the order fields fed to the SQL join in this cell.
    '''
    orderid: int
    # NOTE(review): this cell imports `from datetime import datetime`, so
    # `datetime.date` below is the date() method of the datetime class, not
    # the `date` type. The traceback recorded after this cell shows the
    # orderdate field encoded as "beam:logical:pythonsdk_any:v1", which the
    # Java expansion service fails to decode; this annotation looks like the
    # cause - confirm, and pick a field type the schema translation supports.
    orderdate: datetime.date
    country: str
# Register RowCoder as the coder Beam uses for Order elements.
beam.coders.registry.register_coder(Order, coders.RowCoder)

class OrderDetail(typing.NamedTuple):
    '''
    Typed record for one order line fed to the SQL join in this cell.
    '''
    orderid: int  # order this line belongs to; the SQL join key
    amount: float  # filled by the pipeline below with unitprice * quantity
# Register RowCoder as the coder Beam uses for OrderDetail elements.
beam.coders.registry.register_coder(OrderDetail, coders.RowCoder)

class GetTimestampFn(beam.DoFn):
    """Emit the element as a dict extended with its window's start and end."""

    def process(self, element, window=beam.DoFn.WindowParam):
        # Both bounds are converted with to_utc_datetime() before the element
        # itself is touched; the element must provide _asdict().
        started = window.start.to_utc_datetime()
        ended = window.end.to_utc_datetime()
        record = dict(element._asdict())
        record['window_start'] = started
        record['window_end'] = ended
        yield record

class LeftJoin(beam.PTransform):
    '''
    Join a parent and a child PCollection of dictionaries on a key.

    Usage: put a dictionary holding the two PCollections to the left of the |,
    for example {'orders': orders, 'orderdetails': orderdetails}, and pass the
    name of each PCollection together with the field to join on. Every element
    of both PCollections must be a dictionary.

    Each dict is reshaped into a (key, dict without the key) tuple, the two
    keyed PCollections are grouped with CoGroupByKey, and each group is
    flattened into one dictionary per child row, holding the join key (under
    the parent key name), the parent's fields and the child's fields. Child
    fields overwrite parent fields of the same name.

    The final step is a beam.FlatMap, so every output element is a single flat
    dictionary. Keys without child rows, and keys without a parent row,
    produce no output. When a key has several parent rows only the first one
    is used.

    Args:
        parent_pipeline_name: name of the parent PCollection in the input dict.
        parent_key: field of the parent dictionaries that holds the join key.
        child_pipeline_name: name of the child PCollection in the input dict.
        child_key: field of the child dictionaries that holds the join key.
    '''
    def __init__(self, parent_pipeline_name, parent_key, child_pipeline_name, child_key):
        super().__init__()
        self.parent_pipeline_name = parent_pipeline_name
        self.parent_key = parent_key
        self.child_pipeline_name = child_pipeline_name
        self.child_key = child_key

    def expand(self, pcols):
        def reshapeToKV(item, key):
            # item must be a dictionary; the key field moves out of the dict
            # and becomes the first element of the (key, rest) tuple.
            rest = item.copy()
            del rest[key]
            return (item[key], rest)

        def reshapeCoGroupToFlatDict(item):
            key, grouped = item
            parents = list(grouped[self.parent_pipeline_name])
            if not parents:
                # Child rows whose key has no parent row cannot be joined;
                # skip them instead of failing on parents[0].
                return []
            parent = {self.parent_key: key}
            parent.update(parents[0])
            # One flat row per child; child fields win on a name clash.
            return [{**parent, **child} for child in grouped[self.child_pipeline_name]]

        keyed = {
            self.parent_pipeline_name: (
                pcols[self.parent_pipeline_name]
                | f'Convert {self.parent_pipeline_name} to KV'
                >> beam.Map(reshapeToKV, self.parent_key)
            ),
            self.child_pipeline_name: (
                pcols[self.child_pipeline_name]
                | f'Convert {self.child_pipeline_name} to KV'
                >> beam.Map(reshapeToKV, self.child_key)
            ),
        }
        return (
            keyed
            | f'CoGroupByKey {self.child_pipeline_name} into {self.parent_pipeline_name}'
            >> beam.CoGroupByKey()
            | 'Reshape to dictionary'
            >> beam.FlatMap(reshapeCoGroupToFlatDict)
        )

# Both inputs of this cell are Avro files.
ordersfilename = 'datasets/northwind/AVRO/orders/orders.avro'
orderdetailsfilename = 'datasets/northwind/AVRO/orderdetails/orderdetails.avro'

with beam.Pipeline() as p:
    # Orders: keep id, country and the order date parsed from 'YYYY-MM-DD'
    # text, then wrap each row in the typed Order record.
    orders = (
        p
        | 'Read Orders' >> beam.io.ReadFromAvro(ordersfilename)
        | 'Parse Date' >> beam.Map(lambda x: {
            'orderid': x['orderid'],
            'country': x['shipcountry'],
            'orderdate': datetime.strptime(x['orderdate'], '%Y-%m-%d').date(),
        })
        # | 'Order to Row' >> beam.Map(lambda x : beam.Row(**x))
        | 'Order to Class' >> beam.Map(lambda x: Order(**x)).with_output_types(Order)
    )

    # Order details: keep the order id and the line amount, then wrap each
    # row in the typed OrderDetail record.
    orderdetails = (
        p
        | 'Read OrderDetails' >> beam.io.ReadFromAvro(orderdetailsfilename)
        | 'Parse Amount' >> beam.Map(lambda x: {
            'orderid': x['orderid'],
            'amount': x['unitprice'] * x['quantity'],
        })
        # | 'OrderDetail to Row' >> beam.Map(lambda x : beam.Row(**x))
        | 'OrderDetail to Class' >> beam.Map(lambda x: OrderDetail(**x)).with_output_types(OrderDetail)
    )

    # Total amount per country. The dictionary keys are the table names used
    # in the SQL statement.
    # NOTE(review): the run recorded below this cell failed while the
    # expansion service decoded the "orderdate" field of the input schema.
    ordersjoin = (
        {'orders': orders, 'orderdetails': orderdetails}
        | 'SQL' >> SqlTransform('''
                    SELECT country
                        , sum(amount) as total_amount
                    FROM orders as o
                    JOIN orderdetails as od on o.orderid = od.orderid
                    GROUP BY country
                    '''
        )
    )
    # ordersjoin = ( {'orders': orders, 'orderdetails': orderdetails}
    #                 | 'SQL' >> SqlTransform('''
    #                 SELECT country, TUMBLE_START(orderdate, 60*60*24) as start_date
    #                     , sum(amount) as total_amount
    #                 FROM orders as o
    #                 JOIN orderdetails as od on o.orderid = od.orderid
    #                 GROUP BY COUNTRY, TUMBLE(orderdate, 60*60*24)
    #                 '''
    #              ))

    ordersjoin | beam.Map(print)
#     groupby = (ordersjoin
#                | 'Timestamp' >> beam.Map(lambda x : beam.window.TimestampedValue(x, time.mktime(x['orderdate'].timetuple())))
#                | 'Window' >> beam.WindowInto(beam.window.FixedWindows(60 * 60 * 24))
#                | 'Group' >> beam.GroupBy(country = lambda x: x['country']).aggregate_field(lambda x: x['amount'], sum, 'totalamount')
#                | 'AddWindowTimestamp' >> (beam.ParDo(GetTimestampFn()))
#               )

#     groupby | beam.Map(print)

RuntimeError: org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.UncheckedExecutionException: org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.UncheckedExecutionException: java.lang.IllegalArgumentException: Failed to decode Schema due to an error decoding Field proto:

name: "orderdate"
type {
  nullable: true
  logical_type {
    urn: "beam:logical:pythonsdk_any:v1"
  }
}

	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2050)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache.get(LocalCache.java:3952)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:3974)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4958)
	at org.apache.beam.runners.core.construction.RehydratedComponents.getPCollection(RehydratedComponents.java:139)
	at org.apache.beam.sdk.expansion.service.ExpansionService.lambda$expand$0(ExpansionService.java:497)
	at java.util.stream.Collectors.lambda$toMap$58(Collectors.java:1321)
	at java.util.stream.ReduceOps$3ReducingSink.accept(ReduceOps.java:169)
	at java.util.Collections$UnmodifiableMap$UnmodifiableEntrySet.lambda$entryConsumer$0(Collections.java:1575)
	at java.util.Iterator.forEachRemaining(Iterator.java:116)
	at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801)
	at java.util.Collections$UnmodifiableMap$UnmodifiableEntrySet$UnmodifiableEntrySetSpliterator.forEachRemaining(Collections.java:1600)
	at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:481)
	at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:471)
	at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(ReduceOps.java:708)
	at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
	at java.util.stream.ReferencePipeline.collect(ReferencePipeline.java:499)
	at org.apache.beam.sdk.expansion.service.ExpansionService.expand(ExpansionService.java:492)
	at org.apache.beam.sdk.expansion.service.ExpansionService.expand(ExpansionService.java:606)
	at org.apache.beam.model.expansion.v1.ExpansionServiceGrpc$MethodHandlers.invoke(ExpansionServiceGrpc.java:305)
	at org.apache.beam.vendor.grpc.v1p48p1.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
	at org.apache.beam.vendor.grpc.v1p48p1.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:354)
	at org.apache.beam.vendor.grpc.v1p48p1.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:866)
	at org.apache.beam.vendor.grpc.v1p48p1.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
	at org.apache.beam.vendor.grpc.v1p48p1.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.UncheckedExecutionException: java.lang.IllegalArgumentException: Failed to decode Schema due to an error decoding Field proto:

name: "orderdate"
type {
  nullable: true
  logical_type {
    urn: "beam:logical:pythonsdk_any:v1"
  }
}

	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2050)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache.get(LocalCache.java:3952)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:3974)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4958)
	at org.apache.beam.runners.core.construction.RehydratedComponents.getCoder(RehydratedComponents.java:168)
	at org.apache.beam.runners.core.construction.PCollectionTranslation.fromProto(PCollectionTranslation.java:51)
	at org.apache.beam.runners.core.construction.RehydratedComponents$3.load(RehydratedComponents.java:108)
	at org.apache.beam.runners.core.construction.RehydratedComponents$3.load(RehydratedComponents.java:98)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3528)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2277)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2154)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2044)
	... 27 more
Caused by: java.lang.IllegalArgumentException: Failed to decode Schema due to an error decoding Field proto:

name: "orderdate"
type {
  nullable: true
  logical_type {
    urn: "beam:logical:pythonsdk_any:v1"
  }
}

	at org.apache.beam.sdk.schemas.SchemaTranslation.schemaFromProto(SchemaTranslation.java:271)
	at org.apache.beam.runners.core.construction.CoderTranslators$8.fromComponents(CoderTranslators.java:171)
	at org.apache.beam.runners.core.construction.CoderTranslators$8.fromComponents(CoderTranslators.java:153)
	at org.apache.beam.runners.core.construction.CoderTranslation.fromKnownCoder(CoderTranslation.java:170)
	at org.apache.beam.runners.core.construction.CoderTranslation.fromProto(CoderTranslation.java:145)
	at org.apache.beam.runners.core.construction.RehydratedComponents$2.load(RehydratedComponents.java:87)
	at org.apache.beam.runners.core.construction.RehydratedComponents$2.load(RehydratedComponents.java:82)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3528)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2277)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2154)
	at org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2044)
	... 38 more
Caused by: java.lang.IllegalArgumentException: Unexpected type_info: TYPEINFO_NOT_SET
	at org.apache.beam.sdk.schemas.SchemaTranslation.fieldTypeFromProtoWithoutNullable(SchemaTranslation.java:462)
	at org.apache.beam.sdk.schemas.SchemaTranslation.fieldTypeFromProto(SchemaTranslation.java:306)
	at org.apache.beam.sdk.schemas.SchemaTranslation.fieldTypeFromProtoWithoutNullable(SchemaTranslation.java:459)
	at org.apache.beam.sdk.schemas.SchemaTranslation.fieldTypeFromProto(SchemaTranslation.java:306)
	at org.apache.beam.sdk.schemas.SchemaTranslation.fieldFromProto(SchemaTranslation.java:299)
	at org.apache.beam.sdk.schemas.SchemaTranslation.schemaFromProto(SchemaTranslation.java:269)
	... 48 more


#

#     orders = (p | 'Read Orders' >> beam.io.ReadFromAvro(ordersfilename)
# #                | 'Orders To Row' >> beam.Map(lambda x : beam.Row(**x))
# #                | 'Parse' >> beam.ParDo(OrdersParseRow())
#              )
# #    orders | 'Print Orders' >> beam.Map(print)

#     orderdetails = (p | 'Read OrderDetails' >> beam.io.ReadFromAvro(orderdetailsfilename)
# #                | 'OrderDetailsTo Row' >> beam.Map(lambda x : beam.Row(**x))
#             )
# #    orderdetails | 'Print OrderDetailss' >> beam.Map(print)
