In [41]:
import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

In [3]:
# Start (or reuse, via getOrCreate) the Spark session used by every later cell.
spark = (
    SparkSession.builder
    .appName("Test")
    .getOrCreate()
)
spark

In [7]:
# Toy dataset: 3 regiments x 2 companies x 2 rows each = 12 rows.
# Repeating patterns are built with list arithmetic; the resulting element
# order is identical to writing every value out by hand.
raw_data = {
    'regiment': ['Nighthawks'] * 4 + ['Dragoons'] * 4 + ['Scouts'] * 4,
    'company': ['1st', '1st', '2nd', '2nd'] * 3,
    'deaths': [523, 52, 25, 616, 43, 234, 523, 62, 62, 73, 37, 35],
    'battles': [5, 42, 2, 2, 4, 7, 8, 3, 4, 7, 8, 9],
    'size': [1045, 957, 1099, 1400, 1592, 1006, 987, 849, 973, 1005, 1099, 1523],
    'veterans': [1, 5, 62, 26, 73, 37, 949, 48, 48, 435, 63, 345],
    'readiness': [1, 2, 3, 3, 2, 1, 2, 3, 2, 1, 2, 3],
    'armored': [1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1],
    'deserters': [4, 24, 31, 2, 3, 4, 24, 31, 2, 3, 2, 3],
    # 'Louisana' (sic) is kept verbatim so the rendered outputs stay reproducible.
    'origin': ['Arizona', 'California', 'Texas', 'Florida', 'Maine', 'Iowa',
               'Alaska', 'Washington', 'Oregon', 'Wyoming', 'Louisana', 'Georgia'],
}

# Materialise in pandas and write to CSV without the index column,
# so the file can be read back by Spark in the next cell.
data = pd.DataFrame(raw_data)
data.to_csv("Army.csv", index=False)

In [9]:
# Read the CSV written by the previous cell into Spark.
# Without `inferSchema=True` Spark types every CSV column as a string, so
# numeric comparisons only work through implicit casts and any sort/order on
# e.g. "deaths" would be lexicographic ("52" > "500"). Inferring the schema
# gives the numeric columns proper numeric types.
df = spark.read.csv("Army.csv", header=True, inferSchema=True)
df.show()

+----------+-------+------+-------+----+--------+---------+-------+---------+----------+
|  regiment|company|deaths|battles|size|veterans|readiness|armored|deserters|    origin|
+----------+-------+------+-------+----+--------+---------+-------+---------+----------+
|Nighthawks|    1st|   523|      5|1045|       1|        1|      1|        4|   Arizona|
|Nighthawks|    1st|    52|     42| 957|       5|        2|      0|       24|California|
|Nighthawks|    2nd|    25|      2|1099|      62|        3|      1|       31|     Texas|
|Nighthawks|    2nd|   616|      2|1400|      26|        3|      1|        2|   Florida|
|  Dragoons|    1st|    43|      4|1592|      73|        2|      0|        3|     Maine|
|  Dragoons|    1st|   234|      7|1006|      37|        1|      1|        4|      Iowa|
|  Dragoons|    2nd|   523|      8| 987|     949|        2|      0|       24|    Alaska|
|  Dragoons|    2nd|    62|      3| 849|      48|        3|      1|       31|Washington|
|    Scouts|    1st| 

In [53]:
# Keep only the extremes: more than 500 or fewer than 50 deaths.
# `between` is inclusive on both ends, so its negation is exactly
# (deaths < 50) OR (deaths > 500).
filtered_df = df.where(~col("deaths").between(50, 500))
filtered_df.show()

+----------+-------+------+-------+----+--------+---------+-------+---------+--------+
|  regiment|company|deaths|battles|size|veterans|readiness|armored|deserters|  origin|
+----------+-------+------+-------+----+--------+---------+-------+---------+--------+
|Nighthawks|    1st|   523|      5|1045|       1|        1|      1|        4| Arizona|
|Nighthawks|    2nd|    25|      2|1099|      62|        3|      1|       31|   Texas|
|Nighthawks|    2nd|   616|      2|1400|      26|        3|      1|        2| Florida|
|  Dragoons|    1st|    43|      4|1592|      73|        2|      0|        3|   Maine|
|  Dragoons|    2nd|   523|      8| 987|     949|        2|      0|       24|  Alaska|
|    Scouts|    2nd|    37|      8|1099|      63|        2|      1|        2|Louisana|
|    Scouts|    2nd|    35|      9|1523|     345|        3|      1|        3| Georgia|
+----------+-------+------+-------+----+--------+---------+-------+---------+--------+



In [55]:
#################
def show_col(df, colm):
    """Project ``df`` onto the requested column(s).

    Note: despite the name, nothing is displayed here; callers chain
    ``.show()`` on the result.

    Args:
        df: DataFrame to project (a Spark DataFrame in this notebook).
        colm: A single column name or a list of names, passed unchanged to
            ``df.select``.

    Returns:
        The result of ``df.select(colm)``.
    """
    selection = df.select(colm)
    return selection

def get_col_names(df):
    """Return ``df.columns`` (the column names of the DataFrame) as-is."""
    column_names = df.columns
    return column_names

def get_above(df, colm, value):
    """Return the rows of ``df`` whose ``colm`` value is strictly greater than ``value``.

    Args:
        df: Spark DataFrame to filter.
        colm: Name of the column to compare.
        value: Exclusive lower threshold.
    """
    is_above = col(colm) > value
    return df.filter(is_above)

def get_below(df, colm, value):
    """Return the rows of ``df`` whose ``colm`` value is strictly less than ``value``.

    Args:
        df: Spark DataFrame to filter.
        colm: Name of the column to compare.
        value: Exclusive upper threshold.
    """
    is_below = col(colm) < value
    return df.filter(is_below)

def get_between(df, colm, upper, lower):
    """Return the rows of ``df`` whose ``colm`` value lies OUTSIDE ``[lower, upper]``.

    Despite the name, this keeps the extremes: rows with ``colm > upper`` OR
    ``colm < lower`` (both bounds exclusive). The existing call
    ``get_between(df, "deaths", 500, 50)`` depends on this "outside" behaviour.
    If ``upper < lower`` the two conditions overlap and every row that can be
    compared matches.

    Args:
        df: Spark DataFrame to filter.
        colm: Name of the column BOTH bounds are applied to.
        upper: Rows strictly above this value are kept.
        lower: Rows strictly below this value are kept.

    Returns:
        The filtered DataFrame.
    """
    # Apply both bounds to `colm`; the lower bound used to be hard-wired to
    # the "deaths" column, which silently gave wrong rows for any other column.
    target = col(colm)
    return df.filter((target > upper) | (target < lower))

In [57]:
# --- Single and multi-column projections -----------------------------------
print("The veterans col is as follows : ")
show_col(df, "veterans").show()

print("The veterans and Deaths columns is as follows : ")
show_col(df, ['veterans', 'deaths']).show()

# --- Schema inspection -------------------------------------------------------
col_names = get_col_names(df)
print(f"The column names are : {col_names}")

# --- Rows for two specific origins, narrowed to a few columns ---------------
maine_alaska = show_col(
    df.filter(df.origin.isin("Maine", "Alaska")),
    ['origin', 'deaths', 'size', 'deserters'],
)
maine_alaska.show()

# --- Threshold filters -------------------------------------------------------
deaths_above_50 = get_above(df, 'deaths', 50)
print("The data for deaths above 50 is : ")
deaths_above_50.show()

print("The rows for deaths above 500 and less than 50 are : ")
get_between(df, "deaths", 500, 50).show()

The veterans col is as follows : 
+--------+
|veterans|
+--------+
|       1|
|       5|
|      62|
|      26|
|      73|
|      37|
|     949|
|      48|
|      48|
|     435|
|      63|
|     345|
+--------+

The veterans and Deaths columns is as follows : 
+--------+------+
|veterans|deaths|
+--------+------+
|       1|   523|
|       5|    52|
|      62|    25|
|      26|   616|
|      73|    43|
|      37|   234|
|     949|   523|
|      48|    62|
|      48|    62|
|     435|    73|
|      63|    37|
|     345|    35|
+--------+------+

The column names are : ['regiment', 'company', 'deaths', 'battles', 'size', 'veterans', 'readiness', 'armored', 'deserters', 'origin']
+------+------+----+---------+
|origin|deaths|size|deserters|
+------+------+----+---------+
| Maine|    43|1592|        3|
|Alaska|   523| 987|       24|
+------+------+----+---------+

The data for deaths above 50 is : 
+----------+-------+------+-------+----+--------+---------+-------+---------+----------+
|  re