## 0. 准备及导入数据

#### 1. 需要创建chicago_taxi_trips数据库，然后导入chicago_taxi_trips_2016_12.csv到表chicago_taxi_trips_2016_12

#### 2. 对chicago_taxi_trips数据库安装madlib

##### 也可以选用已经安装madlib的postgres库，下面需要修改库和表；

## 1. 连接及加载数据插件

In [1]:
%load_ext sql

In [2]:
# Greenplum 4.3.10.0
#%sql postgresql://gpdbchina@10.194.10.68:61000/madlib
        
# PostgreSQL local
%sql postgresql://postgres:postgres@localhost:5432/chicago_taxi_trips

'Connected: postgres@chicago_taxi_trips'

In [3]:
%sql select madlib.version();
#%sql select version();

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
1 rows affected.


version
"MADlib version: 1.16, git revision: unknown, cmake configuration time: Tue Jul 2 20:42:19 UTC 2019, build type: Release, build system: Linux-4.9.125-linuxkit, C compiler: gcc 7, C++ compiler: g++ 7"


## 2. 准备数据

In [4]:
%%sql
select taxi_id, pickup_latitude, pickup_longitude from chicago_taxi_trips_2016_12 limit 5;

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
5 rows affected.


taxi_id,pickup_latitude,pickup_longitude
5400,688.0,206.0
1257,618.0,407.0
5998,64.0,231.0
2538,170.0,351.0
5856,767.0,733.0


In [5]:
%%sql
alter table chicago_taxi_trips_2016_12 alter taxi_id type int USING taxi_id::integer;
ALTER TABLE chicago_taxi_trips_2016_12 ALTER COLUMN pickup_latitude TYPE decimal(10, 2) USING pickup_latitude::numeric(10,2);
ALTER TABLE chicago_taxi_trips_2016_12 ALTER COLUMN pickup_longitude TYPE decimal(10, 2) USING pickup_longitude::numeric(10,2);


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
Done.
Done.


[]

In [6]:
%%sql 
drop table if exists t_source;

select taxi_id, pickup_latitude, pickup_longitude into t_source from chicago_taxi_trips_2016_12;


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
440919 rows affected.


[]

In [7]:
%sql select taxi_id, pickup_latitude, pickup_longitude from t_source limit 5;

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
5 rows affected.


taxi_id,pickup_latitude,pickup_longitude
5400,688.0,206.0
1257,618.0,407.0
5998,64.0,231.0
2538,170.0,351.0
5856,767.0,733.0


In [8]:
%%sql
drop table if exists mat;    
create table mat (id integer,    
                  row_vec double precision[] );


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
Done.


[]

In [9]:
%%sql 
select *, (pickup_latitude, pickup_longitude) from t_source order by taxi_id limit 5; 


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
5 rows affected.


taxi_id,pickup_latitude,pickup_longitude,row
3,18.0,610.0,"(18.00,610.00)"
3,170.0,351.0,"(170.00,351.00)"
3,18.0,610.0,"(18.00,610.00)"
3,170.0,351.0,"(170.00,351.00)"
3,18.0,610.0,"(18.00,610.00)"


In [10]:
%%sql
drop table if exists t_source_change;

create table t_source_change
(row_id serial,
 taxi_id int,
 pickup_latitude decimal(10, 2),  
 pickup_longitude decimal(10, 2));  

insert into t_source_change (taxi_id,pickup_latitude,pickup_longitude)   
select taxi_id,   
       pickup_latitude,  
       pickup_longitude
from t_source 


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
Done.
440919 rows affected.


[]

In [11]:
%sql select * from t_source_change limit 5;

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
5 rows affected.


row_id,taxi_id,pickup_latitude,pickup_longitude
1,5400,688.0,206.0
2,1257,618.0,407.0
3,5998,64.0,231.0
4,2538,170.0,351.0
5,5856,767.0,733.0


In [12]:
%%sql 
drop table if exists km_sample;    
create table km_sample (id integer,    
                  row_vec double precision[]);

insert into km_sample select row_id, array_cat(array[pickup_latitude], array[pickup_longitude]) from t_source_change;

select * from km_sample limit 10;

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
Done.
440919 rows affected.
10 rows affected.


id,row_vec
1,"[688.0, 206.0]"
2,"[618.0, 407.0]"
3,"[64.0, 231.0]"
4,"[170.0, 351.0]"
5,"[767.0, 733.0]"
6,"[294.0, 113.0]"
7,"[225.0, 6.0]"
8,"[618.0, 407.0]"
9,"[411.0, 545.0]"
10,"[18.0, 610.0]"


In [13]:
%%sql 
DROP TABLE IF EXISTS km_result;

-- Run kmeans algorithm
CREATE TABLE km_result AS
SELECT * FROM madlib.kmeanspp( 'km_sample',   -- Table of source data
                               'row_vec',      -- Column containing point co-ordinates 
                               5,             -- Number of centroids to calculate
                               'madlib.squared_dist_norm2',   -- Distance function
                               'madlib.avg',  -- Aggregate function
                               20,            -- Number of iterations
                               0.001          -- Fraction of centroids reassigned to keep iterating 
                             );

SELECT * FROM km_result;


 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
Done.
1 rows affected.
1 rows affected.


centroids,cluster_variance,objective_fn,frac_reassigned,num_iterations
"[[422.46189511644, 551.183833085506], [208.003800704858, 118.628999378066], [96.3450231987061, 569.482355996836], [733.996321609409, 527.633057832474], [679.370322387762, 197.439782749514]]","[900624591.639007, 1014645122.92625, 2932201395.50422, 1164587445.87134, 332465291.891639]",6344523847.83245,0.0004105062381072,6


## 3. Calculate the simplified silhouette coefficient:

In [14]:
%%sql
SELECT * FROM madlib.simple_silhouette( 'km_sample',          -- Input points table
                                        'row_vec',             -- Column containing points
                                        (SELECT centroids FROM km_result),  -- Centroids
                                        'madlib.dist_norm2'   -- Distance function
                                      );

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
1 rows affected.


simple_silhouette
0.644737662781543


## 4. Find the cluster assignment for each point:

In [15]:
%%sql
SELECT data.*,  (madlib.closest_column(centroids, row_vec)).column_id as cluster_id
FROM km_sample as data, km_result
ORDER BY data.id limit 10;

 * postgresql://postgres:***@localhost:5432/chicago_taxi_trips
10 rows affected.


id,row_vec,cluster_id
1,"[688.0, 206.0]",4
2,"[618.0, 407.0]",3
3,"[64.0, 231.0]",1
4,"[170.0, 351.0]",2
5,"[767.0, 733.0]",3
6,"[294.0, 113.0]",1
7,"[225.0, 6.0]",1
8,"[618.0, 407.0]",3
9,"[411.0, 545.0]",0
10,"[18.0, 610.0]",2
