# MIE524 - Assignment 1


## Setup

Let's set up Spark on your Colab environment.  Run the cell below!

In [1]:
!pip install pyspark
!pip install -U -q PyDrive
!apt install openjdk-8-jdk-headless -qq
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

Collecting pyspark
  Downloading pyspark-3.5.3.tar.gz (317.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.3/317.3 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.3-py2.py3-none-any.whl size=317840625 sha256=e6ffb38b649dd0d2187cc0aaee5b927bba1898c0985d854123a2e4d6c8652982
  Stored in directory: /root/.cache/pip/wheels/1b/3a/92/28b93e2fbfdbb07509ca4d6f50c5e407f48dce4ddbda69a4ab
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.3
The following additional packages will be installed:
  libxtst6 openjdk-8-jre-headless
Suggested packages:
  openjdk-8-demo openjdk-8-source libnss-mdns fonts-dejavu-extra fonts-nanum fonts-ipafont-gothic
  fonts-ipafont-mincho fonts-wqy-microhei fonts-wqy-zenhei fonts-indic

Now we authenticate a Google Drive client to download the file we will be processing in our Spark job.

**Make sure to follow the interactive instructions.**

In [2]:
# NOTE: the star import from pyspark.sql.functions shadows Python builtins such
# as `max` with the Spark column functions. Later cells rely on that (for
# example group_and_aggregate_data calls `max(struct(...))` and `count(...)`),
# so do not replace these imports without updating those cells.
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext
import pandas as pd

# create the Spark Session (reuses an existing one if the kernel already has it)
spark = SparkSession.builder.getOrCreate()

# create the Spark Context (used for the RDD-based part of Q1)
sc = spark.sparkContext

Put all your imports and path constants in the next cells.

In [3]:
# from itertools import combinations
# from operator import add


## Q1 - Message Count in Spark

### Load the dataset

In [4]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# The resulting `drive` object is used in PART 2 to download the OxCGRT file.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)



In [5]:
# first make sure to upload the werkmap_messages.csv file from the data folder of your starter repository to colab's files
werkmap_message_data = 'werkmap_messages.csv'

If you uploaded the file as described above, you should be able to see *werkmap_messages.csv* under the "Files" tab on the left panel.

### Write your function in the next cells

In [6]:
def get_month_year(row):
    """
    Extract the "YYYY-MM" period from one line of the messages file.

    INPUT:
        row : str, one ';'-separated record whose field at index 5 starts
              with a '-'-separated date (year first, month second),
              optionally followed by a space and the time of day
    OUTPUT:
        month_year : str, "<year>-<month>" copied verbatim from that field
    """
    timestamp = row.split(';')[5]
    # Everything before the first space is the date part
    date_part = timestamp.split(' ')[0]
    date_fields = date_part.split('-')
    return '-'.join((date_fields[0], date_fields[1]))

# You may have additional functions or modify the provided functions as necessary

## Run your function in the next cells to output required content.

In [7]:
# Read the file line by line; the first line is the header
lines = sc.textFile(werkmap_message_data)
header = lines.first()

# Drop every line equal to the header, then emit one ("YYYY-MM", 1) pair per message
rdd = (
    lines
    .filter(lambda line: line != header)
    .map(lambda line: (get_month_year(line), 1))
)

# Sum the ones per month and sort by the "YYYY-MM" key
rdd_count = rdd.reduceByKey(lambda left, right: left + right).sortByKey()
rdd_count.take(20)

[('2015-07', 2933),
 ('2015-08', 3954),
 ('2015-09', 7079),
 ('2015-10', 8483),
 ('2015-11', 9695),
 ('2015-12', 9572),
 ('2016-01', 12582),
 ('2016-02', 11760)]

In [8]:
columns = ["date","count"]
df_count = rdd_count.toDF(columns)
df_count.show()

+-------+-----+
|   date|count|
+-------+-----+
|2015-07| 2933|
|2015-08| 3954|
|2015-09| 7079|
|2015-10| 8483|
|2015-11| 9695|
|2015-12| 9572|
|2016-01|12582|
|2016-02|11760|
+-------+-----+



In [9]:
#Change to date to order by time and change back to string
# Parse "date" as a real date, sort chronologically, then format it back to "yyyy-MM"
df_count = (
    df_count
    .withColumn("date", to_date(col("date"), "yyyy-MM"))
    .orderBy(col("date").asc())
    .withColumn("date", date_format(col("date"), "yyyy-MM"))
)
df_count.show()

+-------+-----+
|   date|count|
+-------+-----+
|2015-07| 2933|
|2015-08| 3954|
|2015-09| 7079|
|2015-10| 8483|
|2015-11| 9695|
|2015-12| 9572|
|2016-01|12582|
|2016-02|11760|
+-------+-----+



In [10]:
df_count_pandas = df_count.toPandas()
df_count_pandas.to_csv("/content/Q1.txt", sep='\t', index=False)

## PART 2 - Oxford Covid-19 Government Response Tracker

### Load the dataset

In [11]:
# Google Drive id of the OxCGRT file.
# NOTE(review): the name `id` shadows the Python builtin of the same name for the rest of the session.
id='1J_2ido9_-LiasNi8xlzk5-DeHu-2_4Zs'
# Fetch that file through the `drive` client and save it next to the notebook
downloaded = drive.CreateFile({'id': id})
downloaded.GetContentFile('OxCGRT_USA_latest.csv')

### Q2 - Computing Index Score with Spark

In [12]:
def clean_data(df):
    """
    Select and normalise the columns needed for the index computation.

    INPUT:
        df: spark dataframe with the raw OxCGRT records
    OUTPUT:
        cleaned data: spark dataframe

    NOTE: the output keeps the region name, the date and every column named in
    `indicators` / `indicator_flags`. Rows without a region name or dated
    outside 2020-2021 are dropped, indicator and flag columns are cast to
    float, and "Date" is cut down to its "yyyyMM" prefix so rows can later be
    grouped by month.
    """
    value_columns = indicators + indicator_flags

    # Keep only the needed columns and drop rows with no region name
    cleaned = (
        df.select(["regionName"] + ["Date"] + value_columns)
        .filter(col("RegionName").isNotNull())
    )

    # Parse the date, keep years 2020 and 2021, then go back to a "yyyyMMdd" string
    cleaned = cleaned.withColumn("Date", to_date(col("Date"), "yyyyMMdd"))
    cleaned = cleaned.filter(year(col("Date")).isin(2020, 2021))
    cleaned = cleaned.withColumn("Date", date_format(col("Date"), "yyyyMMdd"))

    # Indicators and flags become floats so they can be used in arithmetic
    for name in value_columns:
        cleaned = cleaned.withColumn(name, col(name).cast("float"))

    # Only year and month are needed for the monthly grouping
    return cleaned.withColumn("Date", substring(col("Date"), 1, 6))





def impute_data(df):
    """
    Fill missing values with 0.

    INPUT:
        df: spark dataframe
    OUTPUT:
        imputed data: spark dataframe

    NOTE: missing values are replaced with the constant 0, not with a
    per-column minimum computed from the data (presumably 0 is the minimal
    value of every indicator and flag - confirm against the codebook).
    """
    return df.fillna(value=0)


def group_and_aggregate_data(df):
    """
    Reduce the daily records to one row per region and month.

    INPUT:
        df: spark dataframe (output of clean_data)
    OUTPUT:
        grouped and aggregated data: spark dataframe

    NOTE: for every column in `indicators` and `indicator_flags` the output
    holds the most frequent value (the mode) of that column within each
    (regionName, Date) group; when several values are equally frequent the
    largest one is kept.
    """
    monthly_modes = df.select("regionName", "Date").distinct()

    for name in indicators + indicator_flags:
        count_name = "Count " + name
        best_name = "Max" + name

        # Number of occurrences of each distinct value within a region/month
        value_counts = df.groupBy("regionName", "Date", name).agg(
            count(name).alias(count_name)
        )

        # Taking the max of (count, value) structs picks the highest count
        # first and the highest value second.
        # https://stackoverflow.com/questions/36654162/mode-of-grouped-data-in-pyspark
        best = value_counts.groupBy("regionName", "Date").agg(
            max(struct(col(count_name), col(name))).alias(best_name)
        )
        mode_of_column = best.select(
            col('regionName'), col('Date'), col(best_name + '.' + name)
        )

        monthly_modes = monthly_modes.join(
            mode_of_column, on=["regionName", "Date"], how="inner"
        )

    # One column per indicator/flag, holding its monthly mode for each region
    return monthly_modes

def compute_index_score(df):
    """
    Add one sub-index column per indicator to the monthly dataframe.

    INPUT:
        df: spark dataframe with one column per name in `indicators` and
            `indicator_flags`
    OUTPUT:
        spark dataframe: `df` plus a column "I<prefix>" for every indicator,
        where <prefix> is the part of the indicator name before the first "_"
        (for example "IC1M" for "C1M_School closing")

    NOTE: each score is I = (100 / N) * (v - 0.5 * (F - f)), where v is the
    policy value, f the recorded flag, F is 1 when the indicator has a flag
    variable (0 otherwise) and N is the indicator's maximum value. A policy
    value of 0 always scores 0, so an absent policy never gets a negative
    score from the flag term.
    """
    # Indicators without a flag variable (F = 0); all others have F = 1
    unflagged = {
        "C8EV_International travel controls",
        "E2_Debt/contract relief",
        "H2_Testing policy",
        "H3_Contact tracing",
    }
    # Maximum ordinal value N per indicator; indicators not listed here use 5
    max_value = {
        "C3M_Cancel public events": 2,
        "C5M_Close public transport": 2,
        "C7M_Restrictions on internal movement": 2,
        "E1_Income support": 2,
        "E2_Debt/contract relief": 2,
        "H1_Public information campaigns": 2,
        "H3_Contact tracing": 2,
        "C1M_School closing": 3,
        "C2M_Workplace closing": 3,
        "C6M_Stay at home requirements": 3,
        "H2_Testing policy": 3,
        "H8M_Protection of elderly people": 3,
        "C4M_Restrictions on gatherings": 4,
        "C8EV_International travel controls": 4,
        "H6M_Facial Coverings": 4,
    }

    for indicator in indicators:
        prefix = indicator.split("_")[0]
        flag = prefix + "_Flag"

        F_value = 0 if indicator in unflagged else 1
        # Use the recorded flag column when one exists, otherwise a constant 0
        ft = col(flag) if flag in indicator_flags else 0
        N = max_value.get(indicator, 5)

        raw_score = (100/N)*(col(indicator) - 0.5*(F_value-ft))
        # v = 0 means the policy is absent: the score is 0 whatever the flag says
        score = when(col(indicator) == 0, 0.0).otherwise(raw_score)
        df = df.withColumn(f"I{prefix}", score)
    return df

# You may have additional functions


In [13]:
# Column names of the policy indicators and of their flag columns in the
# OxCGRT file. Every function of PART 2 reads these two module-level lists.
indicators = ["C1M_School closing",
"C2M_Workplace closing",
"C3M_Cancel public events",
"C4M_Restrictions on gatherings",
"C5M_Close public transport",
"C6M_Stay at home requirements",
"C7M_Restrictions on internal movement",
"C8EV_International travel controls",
"E1_Income support",
"E2_Debt/contract relief",
"H1_Public information campaigns",
"H2_Testing policy",
"H3_Contact tracing",
"H6M_Facial Coverings",
"H7_Vaccination policy",
"H8M_Protection of elderly people"]

# Flag columns are named "<indicator prefix>_Flag". There is no entry for
# C8EV, E2, H2 and H3: compute_index_score treats those four as unflagged.
indicator_flags = ["C1M_Flag",
"C2M_Flag",
"C3M_Flag",
"C4M_Flag",
"C5M_Flag",
"C6M_Flag",
"C7M_Flag",
"E1_Flag",
"H1_Flag",
"H6M_Flag",
"H7_Flag",
"H8M_Flag"]

Run your function in the next cells to output required content.

In [14]:
# Restart the Spark session before loading the PART 2 data.
# NOTE(review): `sc` was obtained from the session stopped here - presumably it
# is no longer usable after this cell; confirm before reusing it.
spark.stop()
spark = SparkSession.builder.getOrCreate()

# Read the downloaded CSV, using its first line as the column names
OxCGRT_latest = spark.read.option("header", True).csv("OxCGRT_USA_latest.csv")


In [15]:
OxCGRT_latest.sample(fraction=0.1).show()

+-------------+-----------+----------+----------+------------+--------+------------------+--------+---------------------+--------+------------------------+--------+------------------------------+--------+--------------------------+--------+-----------------------------+--------+-------------------------------------+--------+----------------------------------+-----------------+-------+-----------------------+------------------+------------------------+-------------------------------+-------+-----------------+------------------+-------------------------------------+-------------------------+--------------------+--------+---------------------+-------+--------------------------------+--------+-----------+-----------------------------------+----------------------------------+-------------------------------------------------------------------------------+--------------------------------------------------------------------+--------------------------------------------------+-------------+-----

In [16]:
OxCGRT_latest.select(["RegionName"]+["Date"] + indicators + indicator_flags).where(col("RegionName").isNull()).show()

+----------+--------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|RegionName|    Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Fl

In [17]:
OxCGRT_latest.select(["RegionName"]+["Date"] + indicators + indicator_flags).where(col("Date")>="202201").show()

+----------+--------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|RegionName|    Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Fl

In [18]:
df = clean_data(OxCGRT_latest)
df.show()

+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Flag|C

In [19]:
df.select(["RegionName"]+["Date"] + indicators + indicator_flags).where(col("Date")>"202112").show()

+----------+----+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|RegionName|Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Flag|C4M_F

In [20]:
df = group_and_aggregate_data(df)

df.show()



+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Flag|C

In [21]:
num_columns = len(df.columns)
print(num_columns)

30


In [22]:
#Rows with NULL values are Rows with no data for the entire period; because group_and_aggregate_data will always count NULL as 0; so if there exists data, NULL will not be the mode
df = impute_data(df) #Set NULL to 0
df.show()

+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|C1M_Flag|C2M_Flag|C3M_Flag|C

In [23]:
df = compute_index_score(df)
df.show()

+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+-------------------+-------------------+-----+-----+-----+-------------------+-----+-----+-----+----+-----+------------------+-----+-----+-----+-------------------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public

In [24]:
# When the policy value is 0 the index must be 0, whatever the formula returned
for i in indicators:
  Ind = "I" + i.split("_")[0]
  policy_is_off = col(i) == 0
  df = df.withColumn(Ind, when(policy_is_off, 0).otherwise(col(Ind)))
df.show()


+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+------------------+------------------+-----+-----+----+------------------+----+-----+----+----+-----+------------------+-----+----+-----+------------------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public informa

In [25]:
# Names of the per-indicator index columns ("I" + indicator prefix)
I_columns = ["I" + name.split("_")[0] for name in indicators]

# "IC1M+IC2M+..." is parsed as a SQL expression; dividing by the number of
# columns gives the plain average of the sub-indices
total_score = expr('+'.join(I_columns))
df_with_average = df.withColumn("GovernmentResponseIndex", total_score / len(I_columns))
df_with_average.show()


+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+------------------+------------------+-----+-----+----+------------------+----+-----+----+----+-----+------------------+-----+----+-----+------------------+-----------------------+
|regionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract 

In [26]:
# Keep the three output columns, sort by region then chronologically,
# and finally print the period as "MM-yyyy"
df_to_show = (
    df_with_average
    .select("regionName", "Date", "GovernmentResponseIndex")
    .withColumn("Date", to_date(col("Date"), "yyyyMM"))
    .orderBy(col("regionName").asc(), col("Date").asc())
    .withColumn("Date", date_format(col("Date"), "MM-yyyy"))
)
df_to_show.show()

+----------+-------+-----------------------+
|regionName|   Date|GovernmentResponseIndex|
+----------+-------+-----------------------+
|   Alabama|01-2020|                    0.0|
|   Alabama|02-2020|      9.895833333333334|
|   Alabama|03-2020|      40.10416666666667|
|   Alabama|04-2020|      67.70833333333334|
|   Alabama|05-2020|     60.677083333333336|
|   Alabama|06-2020|      59.63541666666667|
|   Alabama|07-2020|                  56.25|
|   Alabama|08-2020|      46.35416666666667|
|   Alabama|09-2020|      52.60416666666667|
|   Alabama|10-2020|     47.395833333333336|
|   Alabama|11-2020|     47.395833333333336|
|   Alabama|12-2020|     47.395833333333336|
|   Alabama|01-2021|     49.895833333333336|
|   Alabama|02-2021|                47.8125|
|   Alabama|03-2021|                47.8125|
|   Alabama|04-2021|      48.17708333333333|
|   Alabama|05-2021|               46.09375|
|   Alabama|06-2021|     42.708333333333336|
|   Alabama|07-2021|      40.62500000000001|
|   Alabam

In [27]:
df_pandas = df_to_show.toPandas()
df_pandas.to_csv("/content/Q2.txt", sep='\t', index=False)

### Q3 - Association Rules

In [28]:
policies =  list(indicators)

In [29]:
def getpolicies(row):
  """
  Return the policies that are in force in one record.

  INPUT:
      row: mapping-like record indexed by the names in `policies`
  OUTPUT:
      list of the names in `policies`, in that order, whose value in `row`
      is present (not None) and greater than 0
  """
  return [name for name in policies if row[name] is not None and row[name] > 0]
def transform_to_items(df):
    """
    Build one basket of active policies per first-of-month record.

    INPUT:
        df: spark dataframe whose "Date" column is a "yyyyMMdd" string
    OUTPUT:
        RDD of lists: for every row dated on day 01, the policy names
        returned by getpolicies for that row

    NOTE: despite the "list itemsets" wording of the assignment, the result
    is an RDD, not a Python list.
    """
    # Characters 7-8 of "yyyyMMdd" are the day; keep the first day of each month
    first_of_month = df.filter(substring(col("Date"), 7, 2) == "01")

    # Float values so getpolicies can compare them with 0
    for name in policies:
        first_of_month = first_of_month.withColumn(name, col(name).cast("float"))

    # One basket per row
    return first_of_month.rdd.map(getpolicies)

In [30]:
def apriori(items, min_sup, itemset_size):
    """
    Find the frequent itemsets of a given size, level by level.

    INPUT:
        items: RDD of baskets (each basket is a list of policy names)
        min_sup: the min support, as an absolute number of baskets
        itemset_size: size of the itemsets to return
    OUTPUT:
        RDD of (frozenset, count) pairs for the itemsets of size
        `itemset_size` contained in at least `min_sup` baskets;
        None when `itemset_size` is smaller than 1

    NOTE: level 1 candidates are the single names in `policies`; the
    candidates of each following level are built from the frequent itemsets
    of the previous one.
    """
    candidates = {frozenset((policy,)) for policy in policies}
    for size in range(1, itemset_size + 1):
        # One (candidate, 1) pair per basket containing the candidate, summed per candidate
        support = items.flatMap(
            lambda basket: count_occurences(basket, candidates)
        ).reduceByKey(lambda left, right: left + right)
        frequent = support.filter(lambda entry: entry[1] >= min_sup)

        if size == itemset_size:
            return frequent
        candidates = create_combinations(frequent, size + 1)

# You may have additional functions
def count_occurences(basket, c):
  """
  Emit a (candidate, 1) pair for every candidate contained in the basket.

  INPUT:
      basket: iterable of items (one transaction)
      c: iterable of candidate itemsets (sets / frozensets)
  OUTPUT:
      list of (candidate, 1) tuples, in the iteration order of `c`, ready
      to be summed per candidate with reduceByKey
  """
  # Build the basket set once instead of once per candidate
  basket_items = set(basket)
  return [(candidate, 1) for candidate in c if candidate.issubset(basket_items)]
def create_combinations(items,size):
  """
  Build the candidate itemsets of the next level.

  INPUT:
      items: RDD of (itemset, count) pairs; it is collected to the driver
      size: number of elements every candidate must have
  OUTPUT:
      set of frozensets: every union of two of the given itemsets that has
      exactly `size` elements
  """
  frequent = [entry[0] for entry in items.collect()]
  candidates = set()
  for position, left in enumerate(frequent):
    # Pair each itemset only with the ones after it
    for right in frequent[position + 1:]:
      union = left | right
      if len(union) == size:
        candidates.add(frozenset(union))
  return candidates

Run your function in the next cells to output required content.

In [31]:
spark.stop()
spark = SparkSession.builder.getOrCreate()

OxCGRT_latest = spark.read.option("header", True).csv("OxCGRT_USA_latest.csv")

In [32]:
OxCGRT_latest.show()

+-------------+-----------+----------+----------+------------+--------+------------------+--------+---------------------+--------+------------------------+--------+------------------------------+--------+--------------------------+--------+-----------------------------+--------+-------------------------------------+--------+----------------------------------+-----------------+-------+-----------------------+------------------+------------------------+-------------------------------+-------+-----------------+------------------+-------------------------------------+-------------------------+--------------------+--------+---------------------+-------+--------------------------------+--------+-----------+-----------------------------------+----------------------------------+-------------------------------------------------------------------------------+--------------------------------------------------------------------+--------------------------------------------------+-------------+-----

In [33]:
# Drop rows without a region, keep years 2020 and 2021 only, and leave "Date"
# as a "yyyyMMdd" string
OxCGRT_latest = (
    OxCGRT_latest
    .filter(col("RegionName").isNotNull())
    .withColumn("Date", to_date(col("Date"), "yyyyMMdd"))
    .filter(year(col("Date")).isin(2020, 2021))
    .withColumn("Date", date_format(col("Date"), "yyyyMMdd"))
)
OxCGRT_latest.show()

+-------------+-----------+----------+----------+------------+--------+------------------+--------+---------------------+--------+------------------------+--------+------------------------------+--------+--------------------------+--------+-----------------------------+--------+-------------------------------------+--------+----------------------------------+-----------------+-------+-----------------------+------------------+------------------------+-------------------------------+-------+-----------------+------------------+-------------------------------------+-------------------------+--------------------+--------+---------------------+-------+--------------------------------+--------+-----------+-----------------------------------+----------------------------------+-------------------------------------------------------------------------------+--------------------------------------------------------------------+--------------------------------------------------+-------------+-----

Check that, for every month and region present in the data, a record exists for day 01 of that month.

In [34]:
# Expected keys: "<yyyyMM>01-<RegionCode>" for every month/region seen in the data
unique_month_year = (
    OxCGRT_latest
    .select("Date", "RegionCode")
    .withColumn(
        "DateRegionCode",
        concat(substring(col("Date"), 1, 6), lit('01'), lit("-"), col("RegionCode")),
    )
    .select("DateRegionCode")
    .distinct()
)

# Actual keys: "<yyyyMMdd>-<RegionCode>" for every record
all_date = (
    OxCGRT_latest
    .select("Date", "RegionCode")
    .withColumn("DateRegionCode", concat(col("Date"), lit("-"), col("RegionCode")))
)



In [35]:
unique_month_year.take(20)


[Row(DateRegionCode='20200201-US_CA'),
 Row(DateRegionCode='20201201-US_IA'),
 Row(DateRegionCode='20200601-US_LA'),
 Row(DateRegionCode='20200401-US_MI'),
 Row(DateRegionCode='20210301-US_NC'),
 Row(DateRegionCode='20210501-US_ND'),
 Row(DateRegionCode='20210701-US_NE'),
 Row(DateRegionCode='20210601-US_NH'),
 Row(DateRegionCode='20201101-US_IL'),
 Row(DateRegionCode='20200301-US_MD'),
 Row(DateRegionCode='20211101-US_AK'),
 Row(DateRegionCode='20210501-US_MD'),
 Row(DateRegionCode='20200201-US_CO'),
 Row(DateRegionCode='20210601-US_CO'),
 Row(DateRegionCode='20201101-US_NC'),
 Row(DateRegionCode='20201001-US_NE'),
 Row(DateRegionCode='20210501-US_FL'),
 Row(DateRegionCode='20200901-US_IN'),
 Row(DateRegionCode='20200501-US_CO'),
 Row(DateRegionCode='20200201-US_HI')]

In [36]:
all_date.take(20)

[Row(Date='20200101', RegionCode='US_AK', DateRegionCode='20200101-US_AK'),
 Row(Date='20200102', RegionCode='US_AK', DateRegionCode='20200102-US_AK'),
 Row(Date='20200103', RegionCode='US_AK', DateRegionCode='20200103-US_AK'),
 Row(Date='20200104', RegionCode='US_AK', DateRegionCode='20200104-US_AK'),
 Row(Date='20200105', RegionCode='US_AK', DateRegionCode='20200105-US_AK'),
 Row(Date='20200106', RegionCode='US_AK', DateRegionCode='20200106-US_AK'),
 Row(Date='20200107', RegionCode='US_AK', DateRegionCode='20200107-US_AK'),
 Row(Date='20200108', RegionCode='US_AK', DateRegionCode='20200108-US_AK'),
 Row(Date='20200109', RegionCode='US_AK', DateRegionCode='20200109-US_AK'),
 Row(Date='20200110', RegionCode='US_AK', DateRegionCode='20200110-US_AK'),
 Row(Date='20200111', RegionCode='US_AK', DateRegionCode='20200111-US_AK'),
 Row(Date='20200112', RegionCode='US_AK', DateRegionCode='20200112-US_AK'),
 Row(Date='20200113', RegionCode='US_AK', DateRegionCode='20200113-US_AK'),
 Row(Date='2

In [37]:
# Expected first-of-month keys that have no matching record
missing_values = unique_month_year.join(all_date, on="DateRegionCode", how="left_anti")
print("All unique month start at day 01" if missing_values.count() == 0 else "Error")
missing_values.show()


All unique month start at day 01
+--------------+
|DateRegionCode|
+--------------+
+--------------+



Basket = the list of policies in force (value greater than 0) in one region on the first day of one month

In [38]:
print(type(OxCGRT_latest))
df = transform_to_items(OxCGRT_latest)
print(type(df))
df.take(10)


<class 'pyspark.sql.dataframe.DataFrame'>
<class 'pyspark.rdd.PipelinedRDD'>


[[],
 ['C8EV_International travel controls',
  'H1_Public information campaigns',
  'H2_Testing policy',
  'H3_Contact tracing'],
 ['C8EV_International travel controls',
  'H1_Public information campaigns',
  'H2_Testing policy',
  'H3_Contact tracing'],
 ['C1M_School closing',
  'C2M_Workplace closing',
  'C3M_Cancel public events',
  'C4M_Restrictions on gatherings',
  'C5M_Close public transport',
  'C6M_Stay at home requirements',
  'C7M_Restrictions on internal movement',
  'C8EV_International travel controls',
  'E1_Income support',
  'E2_Debt/contract relief',
  'H1_Public information campaigns',
  'H2_Testing policy',
  'H3_Contact tracing',
  'H8M_Protection of elderly people'],
 ['C1M_School closing',
  'C2M_Workplace closing',
  'C3M_Cancel public events',
  'C4M_Restrictions on gatherings',
  'C5M_Close public transport',
  'C6M_Stay at home requirements',
  'C7M_Restrictions on internal movement',
  'C8EV_International travel controls',
  'E1_Income support',
  'E2_Debt/co

Pairs

In [39]:
pairs = apriori(df, 100, 2) #Number in each pair, represents the count of that pair
pairs.take(20)

[(frozenset({'C8EV_International travel controls',
             'H1_Public information campaigns'}),
  1128),
 (frozenset({'C8EV_International travel controls', 'H2_Testing policy'}),
  1173),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}), 1128),
 (frozenset({'H2_Testing policy', 'H3_Contact tracing'}), 1173),
 (frozenset({'H2_Testing policy', 'H8M_Protection of elderly people'}), 1075),
 (frozenset({'C2M_Workplace closing', 'H1_Public information campaigns'}),
  817),
 (frozenset({'E2_Debt/contract relief', 'H1_Public information campaigns'}),
  990),
 (frozenset({'C7M_Restrictions on internal movement',
             'H1_Public information campaigns'}),
  914),
 (frozenset({'C1M_School closing', 'H1_Public information campaigns'}), 1017),
 (frozenset({'E1_Income support', 'H2_Testing policy'}), 900),
 (frozenset({'C4M_Restrictions on gatherings',
             'H1_Public information campaigns'}),
  657),
 (frozenset({'C3M_Cancel public events', 'H1_Public infor

In [40]:
test = pairs.collect()
tuple(test[0][0])[0] #Accessing the frequent item

'C8EV_International travel controls'

In [41]:
# Frequent single policies as (policy name, count) instead of (frozenset, count)
single = (
    apriori(df, 100, 1)
    .map(lambda entry: (next(iter(entry[0])), entry[1]))
)
single.take(10)

[('C8EV_International travel controls', 1173),
 ('H3_Contact tracing', 1173),
 ('E1_Income support', 900),
 ('C3M_Cancel public events', 900),
 ('C4M_Restrictions on gatherings', 657),
 ('C7M_Restrictions on internal movement', 914),
 ('E2_Debt/contract relief', 990),
 ('C6M_Stay at home requirements', 724),
 ('C1M_School closing', 1017),
 ('H8M_Protection of elderly people', 1075)]

X=>Y

In [42]:
def _key_pair_by_first(pair):
    """Turn (frozenset({a, b}), count) into (a, (b, count)), with a, b in iteration order."""
    # Convert once so both positions are read from the same ordering.
    items = tuple(pair[0])
    return (items[0], (items[1], pair[1]))


# Key every frequent pair by its first member (X) so it can be joined with the
# single-item counts. One output per input, hence map() rather than flatMap()
# over a one-element list.
# NOTE(review): which member lands in the X position follows the frozenset's
# iteration order, which Python does not define — confirm this split is intended.
pairs_with_conf = pairs.map(_key_pair_by_first)
pairs_with_conf.take(5)

[('H1_Public information campaigns',
  ('C8EV_International travel controls', 1128)),
 ('H2_Testing policy', ('C8EV_International travel controls', 1173)),
 ('H3_Contact tracing', ('H1_Public information campaigns', 1128)),
 ('H3_Contact tracing', ('H2_Testing policy', 1173)),
 ('H2_Testing policy', ('H8M_Protection of elderly people', 1075))]

In [43]:
# Join the (X, (Y, pair_count)) records with `single` on X; as the output below
# shows, each result is (X, ((Y, pair_count), X_count)).
pairs_with_confidence = pairs_with_conf.join(single)  # join on X to attach X's count from single
pairs_with_confidence.take(5)

[('C6M_Stay at home requirements',
  (('C4M_Restrictions on gatherings', 614), 724)),
 ('C6M_Stay at home requirements',
  (('C7M_Restrictions on internal movement', 715), 724)),
 ('C6M_Stay at home requirements',
  (('C8EV_International travel controls', 724), 724)),
 ('C6M_Stay at home requirements', (('E1_Income support', 659), 724)),
 ('C6M_Stay at home requirements', (('C2M_Workplace closing', 701), 724))]

In [44]:
def _xy_confidence(joined):
    """Turn (X, ((Y, pair_count), X_count)) into (X, Y, pair_count / X_count)."""
    x, ((y, pair_count), x_count) = joined
    return (x, y, pair_count / x_count)


# Confidence of X => Y is count(X, Y) / count(X). One output per joined record,
# so map() replaces flatMap() over a one-element list.
# NOTE(review): this reassigns pairs_with_confidence from itself, so the cell
# only works once after the join cell has run; re-running it alone will fail.
pairs_with_confidence = pairs_with_confidence.map(_xy_confidence)
pairs_with_confidence.take(10)

[('C6M_Stay at home requirements',
  'C4M_Restrictions on gatherings',
  0.8480662983425414),
 ('C6M_Stay at home requirements',
  'C7M_Restrictions on internal movement',
  0.9875690607734806),
 ('C6M_Stay at home requirements', 'C8EV_International travel controls', 1.0),
 ('C6M_Stay at home requirements', 'E1_Income support', 0.9102209944751382),
 ('C6M_Stay at home requirements',
  'C2M_Workplace closing',
  0.9682320441988951),
 ('C6M_Stay at home requirements', 'C1M_School closing', 0.9875690607734806),
 ('C6M_Stay at home requirements', 'H8M_Protection of elderly people', 1.0),
 ('C6M_Stay at home requirements',
  'H7_Vaccination policy',
  0.3729281767955801),
 ('H7_Vaccination policy', 'E1_Income support', 0.8169934640522876),
 ('H7_Vaccination policy',
  'C4M_Restrictions on gatherings',
  0.3954248366013072)]

In [45]:
# Bring the (X, Y, confidence) rules to the driver, load them into a DataFrame
# and rank them: highest confidence first, ties broken by X and then Y.
pairs_to_show = pairs_with_confidence.collect()
columns = ["X", "Y", "Confidence"]
sorted_pairs = spark.createDataFrame(pairs_to_show, schema=columns).orderBy(
    ["Confidence", "X", "Y"], ascending=[False, True, True]
)
sorted_pairs.show(100)

+--------------------+--------------------+------------------+
|                   X|                   Y|        Confidence|
+--------------------+--------------------+------------------+
|C3M_Cancel public...|C8EV_Internationa...|               1.0|
|C3M_Cancel public...|H1_Public informa...|               1.0|
|C3M_Cancel public...|   H2_Testing policy|               1.0|
|C3M_Cancel public...|H8M_Protection of...|               1.0|
|C5M_Close public ...|C8EV_Internationa...|               1.0|
|C5M_Close public ...|H1_Public informa...|               1.0|
|C5M_Close public ...|   H2_Testing policy|               1.0|
|C6M_Stay at home ...|C8EV_Internationa...|               1.0|
|C6M_Stay at home ...|H8M_Protection of...|               1.0|
|   E1_Income support|C8EV_Internationa...|               1.0|
|   E1_Income support|H8M_Protection of...|               1.0|
|E2_Debt/contract ...|C8EV_Internationa...|               1.0|
|E2_Debt/contract ...|H1_Public informa...|            

In [46]:
# Keep the first five rows of the sorted X => Y rule table.
top5 = sorted_pairs.limit(5)
top5.show()

+--------------------+--------------------+----------+
|                   X|                   Y|Confidence|
+--------------------+--------------------+----------+
|C3M_Cancel public...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|H1_Public informa...|       1.0|
|C3M_Cancel public...|   H2_Testing policy|       1.0|
|C3M_Cancel public...|H8M_Protection of...|       1.0|
|C5M_Close public ...|C8EV_Internationa...|       1.0|
+--------------------+--------------------+----------+



Y=>X

In [47]:
def _key_pair_by_second(pair):
    """Turn (frozenset({a, b}), count) into (b, (a, count)), with a, b in iteration order."""
    # Convert once so both positions are read from the same ordering.
    items = tuple(pair[0])
    return (items[1], (items[0], pair[1]))


def _yx_confidence(joined):
    """Turn (Y, ((X, pair_count), Y_count)) into (Y, X, pair_count / Y_count)."""
    y, ((x, pair_count), y_count) = joined
    return (y, x, pair_count / y_count)


# Reverse direction: key each pair by its second member (Y), attach Y's count
# from `single`, then divide to get count(X, Y) / count(Y). Each step emits one
# record per input, so map() replaces flatMap() over a one-element list.
# NOTE(review): "second member" follows the frozenset's iteration order, which
# Python does not define — confirm this split is intended.
pairs_with_conf = pairs.map(_key_pair_by_second)
pairs_with_confidence = pairs_with_conf.join(single)  # join on Y to attach Y's count from single
pairs_with_confidence = pairs_with_confidence.map(_yx_confidence)


In [48]:
# Bring the (Y, X, confidence) rules to the driver, load them into a DataFrame
# and rank them: highest confidence first, ties broken by Y and then X.
# This rebinds sorted_pairs, replacing the table built for the X => Y direction.
pairs_to_show = pairs_with_confidence.collect()
columns = ["Y", "X", "Confidence"]
sorted_pairs = spark.createDataFrame(pairs_to_show, schema=columns).orderBy(
    ["Confidence", "Y", "X"], ascending=[False, True, True]
)
sorted_pairs.show(100)

+--------------------+--------------------+------------------+
|                   Y|                   X|        Confidence|
+--------------------+--------------------+------------------+
|  C1M_School closing|C8EV_Internationa...|               1.0|
|  C1M_School closing|H1_Public informa...|               1.0|
|  C1M_School closing|   H2_Testing policy|               1.0|
|  C1M_School closing|  H3_Contact tracing|               1.0|
|  C1M_School closing|H8M_Protection of...|               1.0|
|C2M_Workplace clo...|C8EV_Internationa...|               1.0|
|C2M_Workplace clo...|H1_Public informa...|               1.0|
|C2M_Workplace clo...|   H2_Testing policy|               1.0|
|C2M_Workplace clo...|  H3_Contact tracing|               1.0|
|C2M_Workplace clo...|H8M_Protection of...|               1.0|
|C3M_Cancel public...|  H3_Contact tracing|               1.0|
|C4M_Restrictions ...|  C1M_School closing|               1.0|
|C4M_Restrictions ...|C3M_Cancel public...|            

Formatting the DataFrame for output

In [49]:
# One-row separator whose cells spell out the header of the second section
# ("Y", "X", "Confidence"); appended beneath the X => Y top five.
# NOTE(review): top5 is reassigned from itself, so re-running this cell adds
# another separator row.
break_point_row = spark.createDataFrame(
    [("Y", "X", "Confidence")],
    schema=["X", "Y", "Confidence"],
)
top5 = top5.union(break_point_row)
top5.show()

+--------------------+--------------------+----------+
|                   X|                   Y|Confidence|
+--------------------+--------------------+----------+
|C3M_Cancel public...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|H1_Public informa...|       1.0|
|C3M_Cancel public...|   H2_Testing policy|       1.0|
|C3M_Cancel public...|H8M_Protection of...|       1.0|
|C5M_Close public ...|C8EV_Internationa...|       1.0|
|                   Y|                   X|Confidence|
+--------------------+--------------------+----------+



In [50]:
# Keep the first five rows of the sorted Y => X rule table (sorted_pairs was
# rebound to the Y => X results in the preceding cell).
top5_YX = sorted_pairs.limit(5)
top5_YX.show()

+------------------+--------------------+----------+
|                 Y|                   X|Confidence|
+------------------+--------------------+----------+
|C1M_School closing|C8EV_Internationa...|       1.0|
|C1M_School closing|H1_Public informa...|       1.0|
|C1M_School closing|   H2_Testing policy|       1.0|
|C1M_School closing|  H3_Contact tracing|       1.0|
|C1M_School closing|H8M_Protection of...|       1.0|
+------------------+--------------------+----------+



In [51]:
# Relabel the Y => X columns as X, Y, Confidence so the frame has the same
# header as the X => Y table it is unioned with.
top5_YX_renamed = top5_YX.select(
    col("Y").alias("X"),
    col("X").alias("Y"),
    col("Confidence"),
)
top5_YX_renamed.show()

+------------------+--------------------+----------+
|                 X|                   Y|Confidence|
+------------------+--------------------+----------+
|C1M_School closing|C8EV_Internationa...|       1.0|
|C1M_School closing|H1_Public informa...|       1.0|
|C1M_School closing|   H2_Testing policy|       1.0|
|C1M_School closing|  H3_Contact tracing|       1.0|
|C1M_School closing|H8M_Protection of...|       1.0|
+------------------+--------------------+----------+



In [52]:
# Append the relabelled Y => X top five beneath the separator row.
# NOTE(review): top5 is reassigned from itself, so re-running this cell appends
# the same rows again.
top5 = top5.union(top5_YX_renamed)
top5.show()

+--------------------+--------------------+----------+
|                   X|                   Y|Confidence|
+--------------------+--------------------+----------+
|C3M_Cancel public...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|H1_Public informa...|       1.0|
|C3M_Cancel public...|   H2_Testing policy|       1.0|
|C3M_Cancel public...|H8M_Protection of...|       1.0|
|C5M_Close public ...|C8EV_Internationa...|       1.0|
|                   Y|                   X|Confidence|
|  C1M_School closing|C8EV_Internationa...|       1.0|
|  C1M_School closing|H1_Public informa...|       1.0|
|  C1M_School closing|   H2_Testing policy|       1.0|
|  C1M_School closing|  H3_Contact tracing|       1.0|
|  C1M_School closing|H8M_Protection of...|       1.0|
+--------------------+--------------------+----------+



In [53]:
# Convert the combined pair-rule table to pandas and write it as a
# tab-separated file without the index column.
# NOTE(review): "/content/..." is an absolute path — presumably the Colab
# working directory; confirm before running elsewhere.
df_pandas = top5.toPandas()
df_pandas.to_csv("/content/Q3_b.txt", sep='\t', index=False)

Triplets

In [54]:
# Frequent 3-item sets. Each record displayed below is a
# (frozenset of three item names, integer) tuple.
# NOTE(review): 100 is presumably the minimum-support threshold and 3 the
# itemset size — confirm against apriori's definition.
triplets = apriori(df, 100, 3)
triplets.take(10)

[(frozenset({'C8EV_International travel controls',
             'H1_Public information campaigns',
             'H3_Contact tracing'}),
  1128),
 (frozenset({'C8EV_International travel controls',
             'H2_Testing policy',
             'H3_Contact tracing'}),
  1173),
 (frozenset({'H1_Public information campaigns',
             'H3_Contact tracing',
             'H8M_Protection of elderly people'}),
  1074),
 (frozenset({'C3M_Cancel public events',
             'C4M_Restrictions on gatherings',
             'H1_Public information campaigns'}),
  657),
 (frozenset({'C7M_Restrictions on internal movement',
             'H1_Public information campaigns',
             'H3_Contact tracing'}),
  914),
 (frozenset({'C1M_School closing',
             'C2M_Workplace closing',
             'H1_Public information campaigns'}),
  805),
 (frozenset({'C1M_School closing', 'H2_Testing policy', 'H3_Contact tracing'}),
  1017),
 (frozenset({'C4M_Restrictions on gatherings',
             'C6M_Sta

(X,Y)=>Z

In [55]:
def _key_triplet_by_first_two(triplet):
    """Turn (frozenset({a, b, c}), count) into (frozenset({a, b}), (c, count))."""
    # Convert once so all three positions are read from the same ordering.
    items = tuple(triplet[0])
    return (frozenset((items[0], items[1])), (items[2], triplet[1]))


# Key every frequent triplet by the pair formed from its first two members so
# it can be joined with the frequent-pair counts; the third member is the
# consequent Z. One output per input, hence map() rather than flatMap() over a
# one-element list.
# NOTE(review): "first two" follows the frozenset's iteration order, which
# Python does not define — confirm this split is intended.
triplets_with_conf = triplets.map(_key_triplet_by_first_two)
triplets_with_conf.take(5)

[(frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  ('C8EV_International travel controls', 1128)),
 (frozenset({'H2_Testing policy', 'H3_Contact tracing'}),
  ('C8EV_International travel controls', 1173)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  ('H8M_Protection of elderly people', 1074)),
 (frozenset({'C3M_Cancel public events', 'H1_Public information campaigns'}),
  ('C4M_Restrictions on gatherings', 657)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  ('C7M_Restrictions on internal movement', 914))]

In [56]:
# Join the (pair, (Z, triplet_count)) records with `pairs` on the frozenset key;
# as the output below shows, each result is (pair, ((Z, triplet_count), pair_count)).
triplets_with_confidence = triplets_with_conf.join(pairs)  # join on (X, Y) to attach the pair's count from pairs
triplets_with_confidence.take(5)

[(frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  (('C8EV_International travel controls', 1128), 1128)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  (('H8M_Protection of elderly people', 1074), 1128)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  (('C7M_Restrictions on internal movement', 914), 1128)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  (('C1M_School closing', 1017), 1128)),
 (frozenset({'H1_Public information campaigns', 'H3_Contact tracing'}),
  (('C6M_Stay at home requirements', 724), 1128))]

In [57]:
def _xy_z_confidence(joined):
    """Turn (pair, ((Z, triplet_count), pair_count)) into (X, Y, Z, triplet_count / pair_count)."""
    antecedent, ((consequent, triplet_count), pair_count) = joined
    # Convert once so X and Y are read from the same ordering.
    items = tuple(antecedent)
    return (items[0], items[1], consequent, triplet_count / pair_count)


# Confidence of (X, Y) => Z is count(X, Y, Z) / count(X, Y). One output per
# joined record, so map() replaces flatMap() over a one-element list.
# NOTE(review): this reassigns triplets_with_confidence from itself, so the cell
# only works once after the join cell has run; re-running it alone will fail.
triplets_with_confidence = triplets_with_confidence.map(_xy_z_confidence)
triplets_with_confidence.take(5)

[('H3_Contact tracing',
  'H1_Public information campaigns',
  'C8EV_International travel controls',
  1.0),
 ('H3_Contact tracing',
  'H1_Public information campaigns',
  'H8M_Protection of elderly people',
  0.9521276595744681),
 ('H3_Contact tracing',
  'H1_Public information campaigns',
  'C7M_Restrictions on internal movement',
  0.8102836879432624),
 ('H3_Contact tracing',
  'H1_Public information campaigns',
  'C1M_School closing',
  0.901595744680851),
 ('H3_Contact tracing',
  'H1_Public information campaigns',
  'C6M_Stay at home requirements',
  0.6418439716312057)]

In [58]:
# Bring the (X, Y, Z, confidence) rules to the driver, load them into a
# DataFrame and rank them: highest confidence first, ties broken by X, Y, Z.
triplets_to_show = triplets_with_confidence.collect()
columns = ["X", "Y", "Z", "Confidence"]
sorted_triplets = spark.createDataFrame(triplets_to_show, schema=columns).orderBy(
    ["Confidence", "X", "Y", "Z"], ascending=[False, True, True, True]
)
sorted_triplets.show(50)

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|H1_Public informa...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|   H2_Testing policy|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|H1_Public informa...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|H1_Public informa...|H8M_Protection of...|       1.0|
|C3M_Cancel 

In [59]:
# Keep the first five rows of the sorted (X, Y) => Z rule table.
top5_triplets = sorted_triplets.limit(5)
top5_triplets.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
+--------------------+--------------------+--------------------+----------+



(X,Z) => Y

In [60]:
def _key_triplet_by_first_and_third(triplet):
    """Turn (frozenset({a, b, c}), count) into (frozenset({a, c}), (b, count))."""
    # Convert once so all three positions are read from the same ordering.
    items = tuple(triplet[0])
    return (frozenset((items[0], items[2])), (items[1], triplet[1]))


def _xz_y_confidence(joined):
    """Turn (pair, ((Y, triplet_count), pair_count)) into (X, Z, Y, triplet_count / pair_count)."""
    antecedent, ((consequent, triplet_count), pair_count) = joined
    items = tuple(antecedent)
    return (items[0], items[1], consequent, triplet_count / pair_count)


# (X, Z) => Y: key each triplet by its first and third members, attach that
# pair's count from `pairs`, then divide to get count(X, Y, Z) / count(X, Z).
# Each step emits one record per input, so map() replaces flatMap() over a
# one-element list.
# NOTE(review): member positions follow the frozenset's iteration order, which
# Python does not define — confirm this split is intended.
triplets_with_conf = triplets.map(_key_triplet_by_first_and_third)
triplets_with_confidence = triplets_with_conf.join(pairs)
triplets_with_confidence = triplets_with_confidence.map(_xz_y_confidence)


In [61]:
# Bring the (X, Z, Y, confidence) rules to the driver, load them into a
# DataFrame and rank them: highest confidence first, ties broken by X, Z, Y.
# This rebinds sorted_triplets, replacing the (X, Y) => Z table.
triplets_to_show = triplets_with_confidence.collect()
columns = ["X", "Z", "Y", "Confidence"]
sorted_triplets = spark.createDataFrame(triplets_to_show, schema=columns).orderBy(
    ["Confidence", "X", "Z", "Y"], ascending=[False, True, True, True]
)
sorted_triplets.show(50)

+--------------------+--------------------+--------------------+----------+
|                   X|                   Z|                   Y|Confidence|
+--------------------+--------------------+--------------------+----------+
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|  C1M_School closing|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|  C1M_School closing|H1_Public informa...|       1.0|
|C3M_Cancel public...|  C1M_School closing|   H2_Testing policy|       1.0|
|C3M_Cancel public...|  C1M_School closing|H8M_Protection of...|       1.0|
|C3M_Cancel public...|C2M_Workplace clo...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C2M_Workplace clo...|H1_Public informa...|       1.0|
|C3M_Cancel 

Formatting the DataFrame for output

In [62]:
# One-row separator whose cells spell out the header of the (X, Z) => Y section
# ("X", "Z", "Y", "Confidence"); appended beneath the (X, Y) => Z top five.
# NOTE(review): top5_triplets is reassigned from itself, so re-running this
# cell adds another separator row.
break_point_row = spark.createDataFrame(
    [("X", "Z", "Y", "Confidence")],
    schema=["X", "Y", "Z", "Confidence"],
)
top5_triplets = top5_triplets.union(break_point_row)
top5_triplets.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
|                   X|                   Z|                   Y|Confidence|
+--------------------+--------------------+--------------------+----------+



In [63]:
# Keep the first five rows of the sorted (X, Z) => Y rule table (sorted_triplets
# was rebound to the (X, Z) => Y results in the preceding cell).
top5_XZY = sorted_triplets.limit(5)
top5_XZY.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Z|                   Y|Confidence|
+--------------------+--------------------+--------------------+----------+
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|  C1M_School closing|C8EV_Internationa...|       1.0|
+--------------------+--------------------+--------------------+----------+



In [64]:
# Relabel the (X, Z, Y) columns as X, Y, Z, Confidence so the frame has the
# same header as the table it is unioned with.
top5_XZY_renamed = top5_XZY.select(
    col("X"),
    col("Z").alias("Y"),
    col("Y").alias("Z"),
    col("Confidence"),
)
top5_XZY_renamed.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|  C1M_School closing|C8EV_Internationa...|       1.0|
+--------------------+--------------------+--------------------+----------+



In [65]:
# Append the relabelled (X, Z) => Y top five beneath the separator row.
# NOTE(review): top5_triplets is reassigned from itself, so re-running this cell
# appends the same rows again.
top5_triplets = top5_triplets.union(top5_XZY_renamed)
top5_triplets.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
|                   X|                   Z|                   Y|Confidence|
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel 

(Y,Z) => X

In [66]:

def _key_triplet_by_last_two(triplet):
    """Turn (frozenset({a, b, c}), count) into (frozenset({b, c}), (a, count))."""
    # Convert once so all three positions are read from the same ordering.
    items = tuple(triplet[0])
    return (frozenset((items[1], items[2])), (items[0], triplet[1]))


def _yz_x_confidence(joined):
    """Turn (pair, ((X, triplet_count), pair_count)) into (Y, Z, X, triplet_count / pair_count)."""
    antecedent, ((consequent, triplet_count), pair_count) = joined
    items = tuple(antecedent)
    return (items[0], items[1], consequent, triplet_count / pair_count)


# (Y, Z) => X: key each triplet by its second and third members, attach that
# pair's count from `pairs`, then divide to get count(X, Y, Z) / count(Y, Z).
# Each step emits one record per input, so map() replaces flatMap() over a
# one-element list.
# NOTE(review): member positions follow the frozenset's iteration order, which
# Python does not define — confirm this split is intended.
triplets_with_conf = triplets.map(_key_triplet_by_last_two)
triplets_with_confidence = triplets_with_conf.join(pairs)
triplets_with_confidence = triplets_with_confidence.map(_yz_x_confidence)


In [67]:
# Bring the (Y, Z, X, confidence) rules to the driver, load them into a
# DataFrame and rank them: highest confidence first, ties broken by Y, Z, X.
# This rebinds sorted_triplets, replacing the (X, Z) => Y table.
triplets_to_show = triplets_with_confidence.collect()
columns = ["Y", "Z", "X", "Confidence"]
sorted_triplets = spark.createDataFrame(triplets_to_show, schema=columns).orderBy(
    ["Confidence", "Y", "Z", "X"], ascending=[False, True, True, True]
)
sorted_triplets.show(50)

+--------------------+--------------------+--------------------+----------+
|                   Y|                   Z|                   X|Confidence|
+--------------------+--------------------+--------------------+----------+
|  C1M_School closing|C2M_Workplace clo...|C8EV_Internationa...|       1.0|
|  C1M_School closing|C2M_Workplace clo...|H1_Public informa...|       1.0|
|  C1M_School closing|C2M_Workplace clo...|   H2_Testing policy|       1.0|
|  C1M_School closing|C2M_Workplace clo...|  H3_Contact tracing|       1.0|
|  C1M_School closing|C2M_Workplace clo...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H1_Public informa...|       1.0|
|  C1M_School closing|C8EV_Internationa...|   H2_Testing policy|       1.0|
|  C1M_School closing|C8EV_Internationa...|  H3_Contact tracing|       1.0|
|C3M_Cancel public...|  C1M_School closing|  H3_Contact tracing|       1.0|
|C3M_Cancel public...|C2M_Workplace clo...|  H3_Contact tracing|       1.0|
|C3M_Cancel 

Formatting the DataFrame for output

In [68]:
# One-row separator whose cells spell out the header of the (Y, Z) => X section
# ("Y", "Z", "X", "Confidence"); appended beneath the rows gathered so far.
# NOTE(review): top5_triplets is reassigned from itself, so re-running this
# cell adds another separator row.
break_point_row = spark.createDataFrame(
    [("Y", "Z", "X", "Confidence")],
    schema=["X", "Y", "Z", "Confidence"],
)
top5_triplets = top5_triplets.union(break_point_row)
top5_triplets.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
|                   X|                   Z|                   Y|Confidence|
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel 

In [69]:
# Keep the first five rows of the sorted (Y, Z) => X rule table (sorted_triplets
# was rebound to the (Y, Z) => X results in an earlier cell).
top5_YZX = sorted_triplets.limit(5)
top5_YZX.show()

+------------------+--------------------+--------------------+----------+
|                 Y|                   Z|                   X|Confidence|
+------------------+--------------------+--------------------+----------+
|C1M_School closing|C2M_Workplace clo...|C8EV_Internationa...|       1.0|
|C1M_School closing|C2M_Workplace clo...|H1_Public informa...|       1.0|
|C1M_School closing|C2M_Workplace clo...|   H2_Testing policy|       1.0|
|C1M_School closing|C2M_Workplace clo...|  H3_Contact tracing|       1.0|
|C1M_School closing|C2M_Workplace clo...|H8M_Protection of...|       1.0|
+------------------+--------------------+--------------------+----------+



In [70]:
# Relabel the (Y, Z, X) columns as X, Y, Z, Confidence so the frame has the
# same header as the table it is unioned with.
top5_YZX_renamed = top5_YZX.select(
    col("Y").alias("X"),
    col("Z").alias("Y"),
    col("X").alias("Z"),
    col("Confidence"),
)
top5_YZX_renamed.show()

+------------------+--------------------+--------------------+----------+
|                 X|                   Y|                   Z|Confidence|
+------------------+--------------------+--------------------+----------+
|C1M_School closing|C2M_Workplace clo...|C8EV_Internationa...|       1.0|
|C1M_School closing|C2M_Workplace clo...|H1_Public informa...|       1.0|
|C1M_School closing|C2M_Workplace clo...|   H2_Testing policy|       1.0|
|C1M_School closing|C2M_Workplace clo...|  H3_Contact tracing|       1.0|
|C1M_School closing|C2M_Workplace clo...|H8M_Protection of...|       1.0|
+------------------+--------------------+--------------------+----------+



In [71]:
# Append the relabelled (Y, Z) => X top five beneath the second separator row.
# NOTE(review): top5_triplets is reassigned from itself, so re-running this cell
# appends the same rows again.
top5_triplets = top5_triplets.union(top5_YZX_renamed)
top5_triplets.show()

+--------------------+--------------------+--------------------+----------+
|                   X|                   Y|                   Z|Confidence|
+--------------------+--------------------+--------------------+----------+
|C3M_Cancel public...|C6M_Stay at home ...|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|C6M_Stay at home ...|H8M_Protection of...|       1.0|
|C3M_Cancel public...|   E1_Income support|C8EV_Internationa...|       1.0|
|C3M_Cancel public...|   E1_Income support|H8M_Protection of...|       1.0|
|C3M_Cancel public...|E2_Debt/contract ...|C8EV_Internationa...|       1.0|
|                   X|                   Z|                   Y|Confidence|
|  C1M_School closing|C7M_Restrictions ...|H1_Public informa...|       1.0|
|  C1M_School closing|C7M_Restrictions ...|   H2_Testing policy|       1.0|
|  C1M_School closing|C7M_Restrictions ...|H8M_Protection of...|       1.0|
|  C1M_School closing|C8EV_Internationa...|H8M_Protection of...|       1.0|
|C3M_Cancel 

In [72]:
# Convert the combined triplet-rule table to pandas and write it as a
# tab-separated file without the index column.
# NOTE(review): "/content/..." is an absolute path — presumably the Colab
# working directory; confirm before running elsewhere.
df_pandas_triplets = top5_triplets.toPandas()
df_pandas_triplets.to_csv("/content/Q3_c.txt", sep='\t', index=False)