#### You are given a table `toll_log` that stores every vehicle crossing at a toll plaza. Each row represents one crossing of a vehicle through the toll gate.
Toll Rules <br>
Motorcycle → Free <br>
Car → 40 <br>
Bus → 70 <br>
Truck → 80 <br>
Return Discount Rule <br>
If the same vehicle crosses again within 4 hours of its previous crossing, the return toll is: <br>
Car → 20 <br>
Bus → 30 <br>
Truck → 40 <br>
Otherwise, charge the full toll again.

### Write another SQL query to calculate the total toll collected per day.

### Answer: 360


In [0]:
from pyspark.sql import functions as fs
from pyspark.sql import Window as wn

# Column names for the sample toll crossings below (one tuple per crossing).
columns = ["txn_id", "vehicle_no", "vehicle_type", "crossing_time"]

# Sample crossings: (txn_id, vehicle_no, vehicle_type, crossing_time as text).
data = [
    (1, 'DL01AA1111', 'car',        '2026-01-10 08:00:00'),
    (2, 'DL01AA1111', 'car',        '2026-01-10 11:00:00'),
    (3, 'DL02BB2222', 'truck',      '2026-01-10 09:00:00'),
    (4, 'DL02BB2222', 'truck',      '2026-01-10 16:00:00'),
    (5, 'DL03CC3333', 'bus',        '2026-01-10 10:00:00'),
    (6, 'DL03CC3333', 'bus',        '2026-01-10 12:00:00'),
    (7, 'DL04DD4444', 'motorcycle', '2026-01-10 11:00:00'),
    (8, 'DL05EE5555', 'car',        '2026-01-10 14:00:00')
]

# Build the DataFrame, then parse the crossing_time strings into timestamps.
df = (
    spark.createDataFrame(data, columns)
    .withColumn("crossing_time", fs.to_timestamp("crossing_time"))
)


In [0]:
# Render the sample crossings DataFrame as a notebook table.
display(df)

In [0]:
def get_log_history(table_name="toll_log"):
    """Load toll crossings and attach each vehicle's previous crossing time.

    Args:
        table_name: Table to read crossings from. Defaults to ``"toll_log"``.
            In this notebook that table is created by the ``%sql`` cell further
            down (``CREATE OR REPLACE TABLE toll_log``), so that cell -- or an
            equivalent table -- must exist before this function is called.

    Returns:
        DataFrame with ``vehicle_no``, ``vehicle_type``, ``crossing_time`` and
        ``pre_time``. ``pre_time`` is the same vehicle's previous
        ``crossing_time``; it is NULL on the vehicle's first crossing.
    """
    # One window per vehicle, in chronological order, so lag() yields the
    # immediately preceding crossing of that same vehicle.
    previous_crossing = wn.partitionBy("vehicle_no").orderBy("crossing_time")
    return (
        spark.table(table_name)
        .select("vehicle_no", "vehicle_type", "crossing_time")
        .withColumn("pre_time", fs.lag("crossing_time").over(previous_crossing))
    )
    
def get_discounted_price():
    """Attach the toll owed for every crossing as a ``total_amount`` column.

    Pricing, evaluated in this order:
      * motorcycles always pay 0;
      * a crossing at most 4 hours (4 * 3600 seconds) after the same
        vehicle's previous crossing pays the return toll;
      * every other crossing pays the full toll.
    Vehicle types missing from the tariffs below get a NULL ``total_amount``.

    Returns:
        The DataFrame produced by ``get_log_history()`` plus ``total_amount``.
    """
    crossings = get_log_history()

    # Tariffs per vehicle type; dict order fixes the order of the when() chain.
    full_toll = {'car': 40, 'bus': 70, 'truck': 80}
    return_toll = {'car': 20, 'bus': 30, 'truck': 40}

    def _tariff_expr(tariff):
        # Chain when(type == t, price) for each entry; no otherwise(), so
        # unlisted types fall through to NULL.
        entries = list(tariff.items())
        first_type, first_price = entries[0]
        expr = fs.when(fs.col('vehicle_type') == first_type, first_price)
        for v_type, price in entries[1:]:
            expr = expr.when(fs.col('vehicle_type') == v_type, price)
        return expr

    gap_seconds = fs.unix_timestamp('crossing_time') - fs.unix_timestamp('pre_time')
    is_quick_return = fs.col('pre_time').isNotNull() & (gap_seconds <= 4 * 3600)

    amount = (
        fs.when(fs.col('vehicle_type') == 'motorcycle', 0)
        .when(is_quick_return, _tariff_expr(return_toll))
        .otherwise(_tariff_expr(full_toll))
    )
    return crossings.withColumn('total_amount', amount)


In [0]:
# Show each crossing next to the same vehicle's previous crossing time.
display(get_log_history())

In [0]:
# Priced crossings consumed by the aggregation cells below.
discounted_df = get_discounted_price()
# Note: rebinds df to the log-history view (txn_id dropped, pre_time added).
df = get_log_history()

In [0]:
# Total toll per vehicle. groupBy() output order is not guaranteed, so sort by
# plate number to make the displayed result stable across runs.
(
    discounted_df
    .groupBy('vehicle_no')
    .agg(fs.sum('total_amount').alias('total_price'))
    .orderBy('vehicle_no')
    .display()
)

In [0]:
# Grand total toll over all crossings (DataFrame.agg == groupBy().agg).
discounted_df.agg(fs.sum('total_amount').alias('total_price')).display()

### In SQL

In [0]:
%sql
-- Sample toll crossings: one row per vehicle crossing the toll gate.
CREATE OR REPLACE TABLE toll_log (
    txn_id INT,
    vehicle_no STRING,
    vehicle_type STRING,
    crossing_time TIMESTAMP
);
-- Explicit column list so the insert does not depend on column order, and
-- typed TIMESTAMP literals instead of implicit string-to-timestamp casts.
INSERT INTO toll_log (txn_id, vehicle_no, vehicle_type, crossing_time) VALUES
(1, 'DL01AA1111', 'car',        TIMESTAMP '2026-01-10 08:00:00'),
(2, 'DL01AA1111', 'car',        TIMESTAMP '2026-01-10 11:00:00'),
(3, 'DL02BB2222', 'truck',      TIMESTAMP '2026-01-10 09:00:00'),
(4, 'DL02BB2222', 'truck',      TIMESTAMP '2026-01-10 16:00:00'),
(5, 'DL03CC3333', 'bus',        TIMESTAMP '2026-01-10 10:00:00'),
(6, 'DL03CC3333', 'bus',        TIMESTAMP '2026-01-10 12:00:00'),
(7, 'DL04DD4444', 'motorcycle', TIMESTAMP '2026-01-10 11:00:00'),
(8, 'DL05EE5555', 'car',        TIMESTAMP '2026-01-10 14:00:00');




In [0]:
%sql
-- Inspect the raw crossings, in txn_id order so the output is deterministic.
SELECT
    txn_id,
    vehicle_no,
    vehicle_type,
    crossing_time
FROM toll_log
ORDER BY txn_id;

In [0]:
%sql
-- Total toll collected per day from toll_log.
-- Rules: motorcycles are free; a vehicle crossing again within 4 hours of its
-- previous crossing pays the return toll (car 20, bus 30, truck 40); every
-- other crossing pays the full toll (car 40, bus 70, truck 80).
-- Unlisted vehicle types price to NULL and are ignored by SUM.
WITH crossings_with_previous AS (
    -- Pair each crossing with the same vehicle's previous crossing (NULL on its first).
    SELECT
        vehicle_no,
        vehicle_type,
        crossing_time,
        LAG(crossing_time) OVER (PARTITION BY vehicle_no ORDER BY crossing_time) AS pre_time
    FROM toll_log
),
priced_crossings AS (
    -- Price every crossing. The 4-hour window is checked against a real
    -- interval: TIMESTAMPDIFF(HOUR, ...) truncates to whole hours, so a
    -- 4h59m gap would count as 4 and be wrongly discounted.
    SELECT
        crossing_time,
        CASE
            WHEN vehicle_type = 'motorcycle' THEN 0
            WHEN pre_time IS NOT NULL
                 AND crossing_time <= pre_time + INTERVAL '4' HOUR THEN
                CASE vehicle_type
                    WHEN 'car'   THEN 20
                    WHEN 'bus'   THEN 30
                    WHEN 'truck' THEN 40
                END
            ELSE
                CASE vehicle_type
                    WHEN 'car'   THEN 40
                    WHEN 'bus'   THEN 70
                    WHEN 'truck' THEN 80
                END
        END AS price
    FROM crossings_with_previous
)
SELECT
    DATE(crossing_time) AS date,
    SUM(price) AS total_price
FROM priced_crossings
GROUP BY DATE(crossing_time)
ORDER BY date;
