# In [None]:
# %load KNNn.py
from heapq import nsmallest
from collections import Counter
from operator import itemgetter, methodcaller
import csv
import math

class DataPoint:
    """A (latitude, longitude) sample whose ETR and METSTAT readings are bucketed.

    The raw numeric readings are not kept; only their categorical bucket
    labels are stored on ``self.ETR`` and ``self.METSTAT``. Each bucket's
    upper bound is exclusive, so e.g. an ETR of exactly 360 is labelled
    ``'>360'`` and a METSTAT of exactly 70 is labelled ``'[70, ∞)'``.
    """

    # (exclusive upper bound, label) pairs, checked in ascending order; a value
    # at or above every bound falls through to the matching *_TOP label.
    _ETR_BINS = ((320, '<320'), (360, '320 ~ 360'))
    _ETR_TOP = '>360'
    _METSTAT_BINS = ((60, '[0, 60)'), (70, '[60, 70)'))
    _METSTAT_TOP = '[70, ∞)'

    def __init__(self, latitude, longitude, ETR, METSTAT) -> None:
        self.location = (latitude, longitude)
        self.ETR = self._bucket(ETR, self._ETR_BINS, self._ETR_TOP)
        self.METSTAT = self._bucket(METSTAT, self._METSTAT_BINS, self._METSTAT_TOP)

    @staticmethod
    def _bucket(value, bins, top):
        """Return the label of the first bin whose bound exceeds *value*, else *top*."""
        for bound, name in bins:
            if value < bound:
                return name
        return top

    def dist(self, other) -> float:
        """Euclidean distance between this point's location and *other*'s."""
        here, there = self.location, other.location
        return math.dist(here, there)

class KNeighborsClassifier:
    """Minimal k-nearest-neighbours classifier over ``DataPoint``-like objects.

    Training points are stored unchanged. ``predict`` asks each stored point
    for its ``dist`` to the query, keeps the ``K`` closest, and returns the
    value of the requested label attribute that occurs most often among them.
    """

    def __init__(self, n_neighbors=3) -> None:
        # Number of neighbours consulted per prediction.
        self.K = n_neighbors
        # Every point passed to add_datapoint, in insertion order.
        self.trainingsets = []

    def add_datapoint(self, datapoint) -> None:
        """Store *datapoint* as a labelled training example."""
        self.trainingsets.append(datapoint)

    def predict(self, label, testingpoint) -> str:
        """Return the majority value of attribute *label* among the K nearest points.

        Ties are broken in favour of the label encountered first when the
        neighbours are ordered nearest-first: ``Counter`` keeps first-insertion
        order and ``most_common`` orders equal counts by first appearance.
        Raises ``IndexError`` when no neighbours are found (e.g. an empty
        training set).
        """
        nearest = nsmallest(
            self.K,
            self.trainingsets,
            key=lambda candidate: candidate.dist(testingpoint),
        )
        tally = Counter(getattr(neighbour, label) for neighbour in nearest)
        ranked = tally.most_common(1)
        return ranked[0][0]

def classify(trainingsets, testingsets, k=3, label='ETR'):
    """Fit a k-NN classifier on *trainingsets* and count errors on *testingsets*.

    Both arguments are iterables of ``[latitude, longitude, ETR, METSTAT]``
    rows (the layout produced by ``readcsvfile``). Each row is wrapped in a
    ``DataPoint``, which buckets ETR/METSTAT into categorical labels.

    Args:
        trainingsets: rows used as the labelled reference set.
        testingsets: rows to classify; each prediction is compared against
            that row's own bucketed label.
        k: number of nearest neighbours that vote on each prediction.
        label: ``DataPoint`` attribute to predict, ``'ETR'`` or ``'METSTAT'``.

    Returns:
        The number of testing rows whose predicted label differs from their
        actual label.
    """
    # Use the caller's k and data as given; never reload from disk here.
    knn = KNeighborsClassifier(k)
    for sample in trainingsets:
        knn.add_datapoint(DataPoint(*sample))

    testingpoints = (DataPoint(*sample) for sample in testingsets)
    # Summing booleans counts the mismatches (True == 1).
    return sum(
        knn.predict(label, point) != getattr(point, label)
        for point in testingpoints
    )

def readcsvfile(filepath):
    """Load the solar CSV at *filepath* and split it into training and testing rows.

    Every record must have the columns ``State``, ``Latitude``, ``Longitude``,
    ``Avg hourly ETR in 2005 (Wh/m^2)`` and
    ``Avg hourly METSTAT in 2005 (Wh/m^2)``. The four numeric columns are
    parsed with ``float``. Records are ordered by (State, ETR) before the
    periodic split, which spreads both sets across states and ETR values.

    Split rule: within each consecutive run of 37 sorted records, position 0
    and every position ``p`` with ``p % 3 == 2`` go to the testing set. All
    other positions go to the training set.

    Returns:
        ``(trainingsets, testingsets)``: two lists of
        ``[latitude, longitude, ETR, METSTAT]`` lists (State is dropped).
    """
    def parse(record):
        # Numeric fields are converted first, then State is prepended as a sort key.
        lat = float(record['Latitude'])
        lon = float(record['Longitude'])
        etr = float(record['Avg hourly ETR in 2005 (Wh/m^2)'])
        metstat = float(record['Avg hourly METSTAT in 2005 (Wh/m^2)'])
        return [record['State'], lat, lon, etr, metstat]

    with open(filepath, newline='') as handle:
        records = sorted(map(parse, csv.DictReader(handle)), key=itemgetter(0, 3))

    training, testing = [], []
    for index, record in enumerate(records):
        position = index % 37
        goes_to_training = position and position % 3 != 2
        (training if goes_to_training else testing).append(record[1:])

    return training, testing

def main(filepath='Solar_energy(1).csv', k=3):
    """Evaluate k-NN classification of the solar dataset for both labels.

    Loads and splits *filepath* once with ``readcsvfile``. Then, for each of
    the ``ETR`` and ``METSTAT`` labels, prints how many testing points
    ``classify`` got wrong.

    Args:
        filepath: CSV file to load; defaults to the original dataset name.
        k: number of neighbours passed to ``classify``.
    """
    trainingsets, testingsets = readcsvfile(filepath)

    for label in ('ETR', 'METSTAT'):
        error_cnt = classify(trainingsets, testingsets, k=k, label=label)
        print(f'Number of errors ({label}): {error_cnt} out of {len(testingsets)} testing points')
    
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()



