import os
import pickle

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.shareable import Shareable, make_reply

class LinearRegressionTrainer(Executor):
    """NVFlare executor that fits a local scikit-learn linear regression.

    On each training task the site reads its own CSV file, fits ``y ~ x``
    with :class:`sklearn.linear_model.LinearRegression`, and returns the
    fitted coefficient and intercept as a WEIGHTS DXO.

    NOTE(review): the incoming ``shareable`` (e.g. the current global model)
    is not read; every round is an independent local fit. Confirm this is
    intended for the aggregation workflow in use.
    """

    def __init__(self, data_dir=".", train_task_name="train"):
        """Create the trainer.

        Args:
            data_dir: Directory containing ``<site_name>.csv``. The default
                (current working directory) matches the original
                ``./<site_name>.csv`` lookup.
            train_task_name: Task name this executor handles; any other task
                is answered with ``ReturnCode.TASK_UNKNOWN``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.train_task_name = train_task_name
        self.model = LinearRegression()

    def train_local_model(self, csv_path):
        """Fit ``self.model`` on the ``x``/``y`` columns of ``csv_path``.

        Args:
            csv_path: Path to a CSV file with numeric columns ``x`` and ``y``.

        Returns:
            Tuple ``(coef, intercept)`` from the fitted model. ``coef`` is a
            1-D array with one entry (single feature); ``intercept`` is a
            scalar.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            KeyError: if the ``x`` or ``y`` column is missing.
            ValueError: if the file cannot be parsed or the data cannot be fit.
        """
        df = pd.read_csv(csv_path)
        # Double brackets keep X 2-D (n_samples, 1), as scikit-learn requires.
        X = df[["x"]].to_numpy()
        y = df["y"].to_numpy()
        self.model.fit(X, y)
        return self.model.coef_, self.model.intercept_

    def execute(self, task_name, shareable: Shareable, fl_ctx, abort_signal=None):
        """Handle a task sent by the NVFlare server.

        Args:
            task_name: Name of the task; only ``self.train_task_name`` is handled.
            shareable: Incoming task data (currently unused).
            fl_ctx: FL context; used to look up this client's name.
            abort_signal: Abort signal. NVFlare's ``Executor.execute`` contract
                passes it as a fourth argument, so it must be accepted; it
                defaults to ``None`` for direct callers.

        Returns:
            On success, a Shareable wrapping a WEIGHTS DXO with keys ``"coef"``
            and ``"intercept"``; otherwise an error reply (``TASK_UNKNOWN``,
            ``TASK_ABORTED`` or ``EXECUTION_EXCEPTION``).
        """
        if task_name != self.train_task_name:
            return make_reply(ReturnCode.TASK_UNKNOWN)

        if abort_signal is not None and abort_signal.triggered:
            return make_reply(ReturnCode.TASK_ABORTED)

        site_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        if not site_name:
            # Without a client name the path would silently become "None.csv".
            self.log_error(fl_ctx, "Client name is not set in FL context; cannot locate training data.")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        # Each site reads <data_dir>/<site_name>.csv, so the raw data files
        # (e.g. data_part1.csv, data_part2.csv) must be renamed to match the
        # client names.
        csv_path = os.path.join(self.data_dir, f"{site_name}.csv")

        try:
            coef, intercept = self.train_local_model(csv_path)
        except (OSError, KeyError, ValueError) as e:
            self.log_exception(fl_ctx, f"Local training on {csv_path} failed: {e}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        weights = {
            "coef": coef,
            # Wrap the scalar intercept so every weights entry is an ndarray.
            "intercept": np.array([intercept]),
        }

        dxo = DXO(data_kind=DataKind.WEIGHTS, data=weights)
        return dxo.to_shareable()

