# In [None]:
import sys

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QPushButton, QLabel, QVBoxLayout, QWidget

class StockPredictionApp(QMainWindow):
    """Main window that trains a next-day closing-price model from a CSV file.

    Workflow:
    1. Pick a CSV file that has a ``Close`` column.
    2. Train a linear regression that maps each row's close to the next
       row's close.
    3. Predict the close that follows the latest one, and plot it after the
       price history.

    Slots report failures in ``result_label`` instead of raising. Under
    PyQt5, an exception that escapes a slot aborts the whole application.
    """

    def __init__(self):
        super().__init__()

        self.file_path = None  # path of the selected CSV file, or None
        self.model = None  # fitted LinearRegression, or None until trained

        self.setWindowTitle("Stock Price Prediction")

        # Stock Heading Label
        self.heading_label = QLabel(self)
        self.heading_label.setText("Stock Heading: ")

        # File Selection Button
        self.file_button = QPushButton("Select Stock Data", self)
        self.file_button.clicked.connect(self.load_data)

        # Train Model Button (enabled once a file has been selected)
        self.train_button = QPushButton("Train Model", self)
        self.train_button.clicked.connect(self.train_model)
        self.train_button.setEnabled(False)

        # Prediction Button (enabled once a model has been trained)
        self.predict_button = QPushButton("Make Prediction", self)
        self.predict_button.clicked.connect(self.make_prediction)
        self.predict_button.setEnabled(False)

        # Prediction Result Label (also used for status and error messages)
        self.result_label = QLabel(self)
        self.result_label.setText("")

        # Matplotlib figure for plotting. Use a bare Figure rather than
        # plt.subplots(), so pyplot does not also register and manage a
        # figure that is embedded in this Qt window.
        self.figure = Figure()
        self.ax = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self.figure)
        self._style_axes()

        # Layout Setup
        central_widget = QWidget(self)
        layout = QVBoxLayout(central_widget)
        layout.addWidget(self.heading_label)
        layout.addWidget(self.file_button)
        layout.addWidget(self.train_button)
        layout.addWidget(self.predict_button)
        layout.addWidget(self.result_label)
        layout.addWidget(self.canvas)
        self.setCentralWidget(central_widget)

    def _style_axes(self):
        """Apply the axis labels and title (needed again after ``ax.clear()``)."""
        self.ax.set_xlabel('Days')
        self.ax.set_ylabel('Stock Price')
        self.ax.set_title('Stock Price Prediction')

    def _read_close_prices(self):
        """Read the selected CSV file and return its ``Close`` column.

        Raises:
            ValueError: no file is selected, the file cannot be parsed or
                decoded, or it has no ``Close`` column.
            OSError: the file cannot be opened.
        """
        if not self.file_path:
            raise ValueError("No stock data file selected.")
        data = pd.read_csv(self.file_path)
        if 'Close' not in data.columns:
            raise ValueError("CSV file has no 'Close' column.")
        return data['Close']

    def load_data(self):
        """Ask for a CSV file and show its first column name as the heading.

        Cancelling the dialog keeps the previously selected file. Choosing a
        new file discards any trained model, because that model was fit on
        other data.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select Stock Data", "", "CSV files (*.csv)")
        if not path:
            return  # dialog cancelled: keep the current file and model

        self.file_path = path
        self.model = None
        self.predict_button.setEnabled(False)
        self.result_label.setText("")

        try:
            # nrows=0 parses only the header line.
            columns = list(pd.read_csv(path, nrows=0).columns)
        except (OSError, ValueError):  # EmptyDataError/ParserError subclass ValueError
            columns = []
        heading = columns[0] if columns else "No Heading"

        self.heading_label.setText(f"Stock Heading: {heading}")
        self.train_button.setEnabled(True)

    def train_model(self):
        """Fit a linear model that predicts the next row's close from the current one.

        Rows are paired as (``Close``, next ``Close``). Only pairs with a
        missing close are dropped; NaNs in other columns are ignored.
        Evaluation and fitting:
        - The last 20% of pairs, in file order (no shuffling for time
          series), are held out to report a test MSE.
        - The model is then refit on all pairs, so the most recent prices
          also inform prediction.
        """
        try:
            close = self._read_close_prices()
            pairs = pd.DataFrame({'Close': close, 'Close_shifted': close.shift(-1)}).dropna()
            if len(pairs) < 2:
                raise ValueError("Not enough 'Close' data to train the model.")

            X = pairs[['Close']].to_numpy()
            y = pairs['Close_shifted'].to_numpy()

            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
            model = LinearRegression().fit(X_train, y_train)
            mse = mean_squared_error(y_test, model.predict(X_test))
            model.fit(X, y)
        except (OSError, ValueError) as err:
            self.model = None
            self.predict_button.setEnabled(False)
            self.result_label.setText(f"Training failed: {err}")
            return

        self.model = model
        self.predict_button.setEnabled(True)
        self.result_label.setText(f"Model trained (hold-out MSE: {mse:.4f}).")

    def make_prediction(self):
        """Predict the close following the latest one in the file, then plot it.

        The model input is the last non-missing ``Close`` value. Trailing
        rows without a close are left out of the plotted history.
        """
        if self.model is None:
            self.result_label.setText("Model not trained yet.")
            return

        try:
            close = self._read_close_prices()
            last_valid = close.last_valid_index()
            if last_valid is None:
                raise ValueError("CSV file has no 'Close' values.")
            history = close.loc[:last_valid]
            # (1, 1) feature matrix, the same layout the model was trained on.
            latest = history.to_numpy()[-1:].reshape(1, -1)
            prediction = self.model.predict(latest)
        except (OSError, ValueError) as err:
            self.result_label.setText(f"Prediction failed: {err}")
            return

        self.result_label.setText(f"Predicted Price: {prediction[0]:.2f}")
        self.plot_predictions(history, prediction)

    def plot_predictions(self, actual_prices, predicted_prices):
        """Redraw the chart with the price history and the predicted next day(s).

        Args:
            actual_prices: Historical closing prices, oldest first. They are
                plotted at positions 0..n-1.
            predicted_prices: Predicted close(s) for the day(s) right after
                the last actual price. They are plotted from position n
                onward, not on top of the last actual day.
        """
        n_actual = len(actual_prices)
        actual_days = range(n_actual)
        predicted_days = range(n_actual, n_actual + len(predicted_prices))

        self.ax.clear()
        self.ax.plot(actual_days, list(actual_prices), label='Actual Prices')
        self.ax.plot(predicted_days, list(predicted_prices), 'ro', label='Predicted Price')
        self._style_axes()  # clear() removed the labels and title
        self.ax.legend()
        self.canvas.draw()


if __name__ == "__main__":
    # Script entry point: create the Qt application and show the main window
    # at 800x600. Then run the Qt event loop, and exit the process with the
    # loop's return code.
    qt_app = QApplication(sys.argv)
    main_window = StockPredictionApp()
    main_window.resize(800, 600)
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
