diff --git a/.github/workflows/ross-test.yml b/.github/workflows/ross-test.yml new file mode 100644 index 0000000..d548a1d --- /dev/null +++ b/.github/workflows/ross-test.yml @@ -0,0 +1,31 @@ +name: ROSS Test + +on: + push: + branches-ignore: + - master + pull_request: + branches: + - v2 + +jobs: + build-linux: + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.9 + uses: actions/setup-python@v3 + with: + python-version: '3.9' + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Install dependencies + run: | + conda env update --file environment.yml --name base + - name: Test with pytest + run: | + python -m pytest -q -rf --tb=short --cov-report term-missing --cov=./ diff --git a/.gitignore b/.gitignore index 8ac4db9..0570184 100644 --- a/.gitignore +++ b/.gitignore @@ -407,3 +407,7 @@ compile_commands.json data /ross_backend/data.db /ross_ui/ross_data/ +ross_backend/ross_data/ +ross_ui/.tmp +*.mat +*.pickle \ No newline at end of file diff --git a/README.md b/README.md index 8dc8058..b351040 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,28 @@ # ROSS v2 + ![image](./images/Ross_Color.png) -ROSS v2 (beta) is the Python version of offline spike sorting software implemented based on the methods described in the paper entitled [An automatic spike sorting algorithm based on adaptive spike detection and a mixture of skew-t distributions](https://www.nature.com/articles/s41598-021-93088-w). (Official Python Implementation) +![ROSS Test](https://github.com/ramintoosi/ROSS/actions/workflows/ross-test.yml/badge.svg) + +ROSS v2 (alpha) is the Python version of offline spike sorting software implemented based on the methods described in the +paper +entitled [An automatic spike sorting algorithm based on adaptive spike detection and a mixture of skew-t distributions](https://www.nature.com/articles/s41598-021-93088-w). ( +Official Python Implementation) ### Important Note on ROSS v2 -ROSS v2 is implemented based on the client-server architecture. In the beta version, the GUI and processing units are completely separated and their connection is based on Restful APIs. However, at this moment, it only works on one machine and we try to find a good way to optimize the data transfer between the client and the server. In our final release, you would be able to run the light GUI on a simple machine while the data and algorithms would be executed on a separate server in your lab. + +ROSS v2 is implemented based on the client-server architecture. In the alpha version, the GUI and processing units are +completely separated and their connection is based on Restful APIs. Now, you are able to run the light GUI on a simple machine while the data and algorithms would be executed on a +separate server in your lab. Please carefully read the docs and check our tutorial videos. ## Requirements -- All the requirement packages are listed at enviroments.yml file in root path + +- All the requirement packages are listed at environment.yml file in root path ## How to install + 1. Git Clone this repository to your local path ```git clone https://github.com/ramintoosi/ROSS``` -2. Checkout to v2 ```git checkout v2``` +2. Checkout to v2 ```git checkout v2``` 3. Create a conda enviroment by command : ```conda env create -f environment.yml``` 4. Activate conda environment ```conda activate ross``` @@ -19,54 +30,78 @@ ROSS v2 is implemented based on the client-server architecture. In the beta vers 1. Run the backend by typing ```python ./ross_backend/app.py``` in the terminal. 2. Run the UI by typing ```python ./ross_ui/main.py``` in the terminal. + +**Note:** If you have a separate server, run ```step 1``` in your server and ```step 2``` in your personal computer. 3. The first time you want to use the software, you must define a user as follows: - - In opened window, click on ```Options``` ---> ```Sign In/Up``` , enter the desired username and password, click on ```Sign Up```. +- In opened window, click on ```Options``` ---> ```Sign In/Up``` , enter the desired username and password, click + on ```Sign Up```. -4. The next time you want to use the software, just click on ```Options``` ---> ```Sign In/Up``` and enter your username and password, click on ```Sign In``` . +4. The next time you want to use the software, just click on ```Options``` ---> ```Sign In/Up``` and enter your username + and password, click on ```Sign In``` . -5. Import your "Raw Data" as follows : +5. Import your "Raw Data" as follows : - - In opened window, click on ```File``` ---> ```Import``` ---> ```Raw Data``` , select the data file from your system, then, select the variable containing raw-data ```(Spike)``` and click on ```OK```. +- In opened window, click on ```File``` ---> ```Import``` ---> ```Raw Data``` , select the data file from your system, + then, select the variable containing raw-data ```(Spike)``` and click on ```OK```. -6. Now you can Go on and enjoy the Software. +6. Enjoy the Software! For more instructions and samples please visit ROSS documentation at (link), or demo video at (link). + ## Usage -ROSS v2, same as v1, provides useful tools for spike detection, automatic and manual sorting. +ROSS v2, same as v1, provides useful tools for spike detection, automatic and manual sorting. - Detection - You can load raw extracellular data and adjust the provided settings for filtering and thresholding. Then by pushing **Start Detection** button the detection results appear in a PCA plot: - + You can load raw extracellular data and adjust the provided settings for filtering and thresholding. Then by pushing * + *Start Detection** button the detection results appear in a PCA plot: + ![image](./images/detection.png) - Automatic Sorting - Automatic sorting is implemented based on five different methods: skew-t and t distributions, GMM, k-means and template matching. Several options are provided for configurations in the algorithm. Automatic sorting results will appear in PCA and waveform plots: - - ![image](./images/sort.png) + Automatic sorting is implemented based on five different methods: skew-t and t distributions, GMM, k-means and + template matching. Several options are provided for configurations in the algorithm. Automatic sorting results will + appear in PCA and waveform plots: + +![image](./images/sort.png) - Manual Sorting - Manual sorting tool is used for manual modifications on automatic results by the researcher. These tools include: Merge, Delete, Resort and Manual grouping or deleting samples in PCA domain: - + Manual sorting tool is used for manual modifications on automatic results by the researcher. These tools include: + Merge, Delete, Resort and Manual grouping or deleting samples in PCA domain: + ![image](./images/sort2.png) - Visualization - - - Several visualization tools are provided such as: 3D plot - + + - Several visualization tools are provided such as: 3D plot + ![image](./images/vis1.png) - - - Also, inter spike interval, neuron live time and Cluster Waveforms - - ![image](./images/vis2.png) - + - Also, inter spike interval, neuron live time and Cluster Waveforms + ![image](./images/vis2.png) +# Citation +If ROSS helps your research, please cite our paper in your publications. + +``` +@article{Toosi_2021, + doi = {10.1038/s41598-021-93088-w}, + url = {https://doi.org/10.1038%2Fs41598-021-93088-w}, + year = 2021, + month = {jul}, + publisher = {Springer Science and Business Media {LLC}}, + volume = {11}, + number = {1}, + author = {Ramin Toosi and Mohammad Ali Akhaee and Mohammad-Reza A. Dehaqani}, + title = {An automatic spike sorting algorithm based on adaptive spike detection and a mixture of skew-t distributions}, + journal = {Scientific Reports} +} +``` diff --git a/environment.yml b/environment.yml index 7f25da1..f9136b0 100644 --- a/environment.yml +++ b/environment.yml @@ -46,42 +46,58 @@ dependencies: - xz=5.2.5=h7b6447c_0 - zlib=1.2.11=h7f8727e_4 - pip: - - aniso8601==9.0.1 - - cffi==1.15.1 - - charset-normalizer==2.0.10 - - colour==0.1.5 - - cryptography==39.0.0 - - cycler==0.11.0 - - flask-jwt==0.3.2 - - flask-jwt-extended==3.0.0 - - flask-restful==0.3.8 - - flask-sqlalchemy==2.5.1 - - greenlet==1.1.2 - - h5py==3.6.0 - - idna==3.3 - - joblib==1.1.0 - - kiwisolver==1.3.2 - - matplotlib==3.4.3 - - nptdms==1.4.0 - - numpy==1.22.1 - - opencv-python==4.6.0.66 - - pillow==9.0.0 - - pycparser==2.21 - - pyjwt==1.4.2 - - pyopengl==3.1.5 - - pyopenssl==23.0.0 - - pyparsing==3.0.6 - - pyqt5==5.15.6 - - pyqt5-qt5==5.15.2 - - pyqt5-sip==12.9.0 - - pyqtgraph==0.13.1 - - pytz==2021.3 - - pywavelets==1.2.0 - - pyyawt==0.1.1 - - requests==2.26.0 - - scikit-learn==1.0.1 - - scipy==1.7.1 - - sip==6.5.0 - - sqlalchemy==1.4.29 - - threadpoolctl==3.0.0 - - urllib3==1.26.8 \ No newline at end of file + - aniso8601==9.0.1 + - astroid==2.15.3 + - cffi==1.15.1 + - charset-normalizer==2.0.10 + - colour==0.1.5 + - coverage==7.2.7 + - cryptography==39.0.0 + - cycler==0.11.0 + - dill==0.3.6 + - exceptiongroup==1.1.1 + - flask-jwt==0.3.2 + - flask-jwt-extended==3.0.0 + - flask-restful==0.3.8 + - flask-sqlalchemy==2.5.1 + - greenlet==1.1.2 + - h5py==3.6.0 + - idna==3.3 + - iniconfig==2.0.0 + - isort==5.12.0 + - joblib==1.1.0 + - kiwisolver==1.3.2 + - lazy-object-proxy==1.9.0 + - matplotlib==3.4.3 + - mccabe==0.7.0 + - nptdms==1.4.0 + - numpy==1.22.1 + - opencv-python==4.6.0.66 + - pillow==9.0.0 + - platformdirs==3.2.0 + - pluggy==1.0.0 + - pycparser==2.21 + - pyjwt==1.4.2 + - pyopengl==3.1.5 + - pyopenssl==23.0.0 + - pyparsing==3.0.6 + - pyqt5==5.15.6 + - pyqt5-qt5==5.15.2 + - pyqt5-sip==12.9.0 + - pyqtgraph==0.13.1 + - pytest==7.3.2 + - pytest-cov==4.1.0 + - pytz==2021.3 + - pywavelets==1.2.0 + - pyyawt==0.1.1 + - requests==2.26.0 + - scikit-learn==1.0.1 + - scipy==1.7.1 + - sip==6.5.0 + - sqlalchemy==1.4.29 + - threadpoolctl==3.0.0 + - tomli==2.0.1 + - tomlkit==0.11.7 + - typing-extensions==4.5.0 + - urllib3==1.26.8 + - wrapt==1.15.0 diff --git a/ross_backend/app.py b/ross_backend/app.py index f057d59..4fd45f7 100644 --- a/ross_backend/app.py +++ b/ross_backend/app.py @@ -7,13 +7,14 @@ from blacklist import BLACKLIST from flask import Flask, jsonify from flask_jwt_extended import JWTManager -from resources.sort import SortDefault, Sort -from resources.project import Project, Projects -from resources.data import RawData, RawDataDefault -from resources.detect import Detect, DetectDefault +from resources.sort import SortDefault +from resources.project import Projects +from resources.data import RawDataDefault +from resources.detect import DetectDefault from resources.sorting_result import SortingResultDefault -from resources.detection_result import DetectionResultDefault +from resources.detection_result import DetectionResult, DetectionResultSpikeMat from resources.user import UserRegister, UserLogin, User, TokenRefresh, UserLogout +from resources.browse import Browse app = Flask(__name__) app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///data.db' # 'sqlite:///:memory:' # 'sqlite:///data.db' @@ -87,19 +88,19 @@ def revoked_token_callback(): api.add_resource(UserLogout, '/logout') api.add_resource(User, '/user/') -# api.add_resource(RawData, '/raw/') api.add_resource(RawDataDefault, '/raw') api.add_resource(DetectDefault, '/detect') -# api.add_resource(Detect, '/detect/') api.add_resource(SortDefault, '/sort') -# api.add_resource(Sort, '/sort/') -api.add_resource(DetectionResultDefault, '/detection_result') +api.add_resource(DetectionResult, '/detection_result') +api.add_resource(DetectionResultSpikeMat, '/detection_result_waveform') api.add_resource(SortingResultDefault, '/sorting_result') api.add_resource(Projects, '/projects') # api.add_resource(Project, '/project/') +api.add_resource(Browse, '/browse') + if __name__ == '__main__': db.init_app(app) - app.run(port=5000, debug=False) + app.run(host='0.0.0.0', port=5000, debug=False) diff --git a/ross_backend/models/data.py b/ross_backend/models/data.py index 4dc5acf..c868799 100644 --- a/ross_backend/models/data.py +++ b/ross_backend/models/data.py @@ -1,7 +1,6 @@ -from db import db - +from __future__ import annotations -# from sqlalchemy.dialects import postgresql +from db import db class RawModel(db.Model): @@ -9,14 +8,16 @@ class RawModel(db.Model): id = db.Column(db.Integer, primary_key=True) data = db.Column(db.String) + mode = db.Column(db.Integer, default=0) # 0: inplace, 1: server user_id = db.Column(db.Integer, db.ForeignKey('users.id')) project_id = db.Column(db.Integer, db.ForeignKey('projects.id', ondelete="CASCADE")) # user = db.relationship('UserModel') # project = db.relationship('ProjectModel', backref="raw", lazy=True) - def __init__(self, user_id, data, project_id): + def __init__(self, user_id, data, project_id, mode): self.data = data + self.mode = mode self.user_id = user_id self.project_id = project_id @@ -33,11 +34,11 @@ def get(cls): return cls.query.first() @classmethod - def find_by_user_id(cls, _id): + def find_by_user_id(cls, _id) -> RawModel: return cls.query.filter_by(user_id=_id, project_id=0).first() @classmethod - def find_by_project_id(cls, project_id): + def find_by_project_id(cls, project_id) -> RawModel: return cls.query.filter_by(project_id=project_id).first() def delete_from_db(self): @@ -53,13 +54,16 @@ class DetectResultModel(db.Model): user_id = db.Column(db.Integer, db.ForeignKey('users.id')) project_id = db.Column(db.Integer, db.ForeignKey('projects.id', ondelete="CASCADE")) + mode = db.Column(db.Integer, default=0) # 0: inplace, 1: server + # user = db.relationship('UserModel') # project = db.relationship('ProjectModel', backref="raw", lazy=True) - def __init__(self, user_id, data, project_id): + def __init__(self, user_id, data, project_id, mode=0): self.data = data self.user_id = user_id self.project_id = project_id + self.mode = mode def json(self): return {'data': self.data} @@ -78,7 +82,7 @@ def get(cls): # return cls.query.filter_by(user_id=_id, project_id=0).first() @classmethod - def find_by_project_id(cls, project_id): + def find_by_project_id(cls, project_id) -> DetectResultModel: return cls.query.filter_by(project_id=project_id).first() def delete_from_db(self): @@ -93,13 +97,16 @@ class SortResultModel(db.Model): user_id = db.Column(db.Integer, db.ForeignKey('users.id')) project_id = db.Column(db.Integer, db.ForeignKey('projects.id', ondelete="CASCADE")) + mode = db.Column(db.Integer, default=0) # 0: inplace, 1: server + # user = db.relationship('UserModel') # project = db.relationship('ProjectModel', backref="raw", lazy=True) - def __init__(self, user_id, data, project_id): + def __init__(self, user_id, data, project_id, mode=0): self.data = data self.user_id = user_id self.project_id = project_id + self.mode = mode def json(self): return {'data': self.data} @@ -118,7 +125,7 @@ def get(cls): # return cls.query.filter_by(user_id=_id, project_id=0).first() @classmethod - def find_by_project_id(cls, project_id): + def find_by_project_id(cls, project_id) -> SortResultModel: return cls.query.filter_by(project_id=project_id).first() def delete_from_db(self): diff --git a/ross_backend/models/project.py b/ross_backend/models/project.py index 0227ab4..c75942a 100644 --- a/ross_backend/models/project.py +++ b/ross_backend/models/project.py @@ -7,15 +7,16 @@ class ProjectModel(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String) user_id = db.Column(db.Integer, db.ForeignKey('users.id', ondelete="CASCADE")) - config_detect = db.relationship('ConfigDetectionModel', backref='project', uselist=False, - cascade="all,delete,delete-orphan") - config_sort = db.relationship('ConfigSortModel', backref='project', uselist=False, - cascade="all,delete,delete-orphan") - raw = db.relationship('RawModel', backref='project', uselist=False, cascade="all,delete,delete-orphan") - detection_result = db.relationship('DetectResultModel', backref='project', uselist=False, - cascade="all,delete,delete-orphan") - sorting_result = db.relationship('SortResultModel', backref='project', uselist=False, - cascade="all,delete,delete-orphan") + + # config_detect = db.relationship('ConfigDetectionModel', backref='project', uselist=False, + # cascade="all,delete,delete-orphan") + # config_sort = db.relationship('ConfigSortModel', backref='project', uselist=False, + # cascade="all,delete,delete-orphan") + # raw = db.relationship('RawModel', backref='project', uselist=False, cascade="all,delete,delete-orphan") + # detection_result = db.relationship('DetectResultModel', backref='project', uselist=False, + # cascade="all,delete,delete-orphan") + # sorting_result = db.relationship('SortResultModel', backref='project', uselist=False, + # cascade="all,delete,delete-orphan") # user = db.relationship('UserModel', backref="projects", lazy=True) # raw = db.relationship('RawModel', back_populates="project") diff --git a/ross_backend/resources/browse.py b/ross_backend/resources/browse.py new file mode 100644 index 0000000..7d33cbd --- /dev/null +++ b/ross_backend/resources/browse.py @@ -0,0 +1,24 @@ +import os +from pathlib import Path + +from flask_jwt_extended import ( + jwt_required, +) +from flask_restful import Resource, reqparse + +Raw_data_path = os.path.join(Path(__file__).parent, '../ross_data/Raw_Data') +Path(Raw_data_path).mkdir(parents=True, exist_ok=True) + + +class Browse(Resource): + parser = reqparse.RequestParser(bundle_errors=True) + parser.add_argument('root', type=str, required=False, default=str(Path.home())) + + @jwt_required + def get(self): + root = Browse.parser.parse_args()['root'] + list_of_folders = [x for x in os.listdir(root) if + os.path.isdir(os.path.join(root, x)) and not x.startswith('.')] + list_of_files = [x for x in os.listdir(root) if os.path.isfile(os.path.join(root, x)) and not x.startswith('.') + and x.endswith(('.mat', '.pkl'))] + return {'folders': list_of_folders, 'files': list_of_files, 'root': root} diff --git a/ross_backend/resources/data.py b/ross_backend/resources/data.py index 14c9dd6..accfb92 100644 --- a/ross_backend/resources/data.py +++ b/ross_backend/resources/data.py @@ -1,166 +1,93 @@ +import io +import pickle + import flask +import numpy as np from flask import request from flask_jwt_extended import jwt_required, get_jwt_identity -from flask_restful import Resource, reqparse -from models.data import RawModel -from models.project import ProjectModel - - -class RawData(Resource): - parser = reqparse.RequestParser() - parser.add_argument('raw', type=str, required=True, help="This field cannot be left blank!") - - @jwt_required - def get(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - raw = proj.raw - - # raw = RawModel.find_by_user_id(user_id) - if raw: - # b = io.BytesIO() - # b.write(raw.raw) - # b.seek(0) - - # d = np.load(b, allow_pickle=True) - # print(d['raw'].shape) - # b.close() - # print(user_id, raw.project_id) - return {'message': "Raw Data Exists."}, 201 +from flask_restful import Resource - return {'message': 'Raw Data does not exist'}, 404 +from models.data import RawModel, DetectResultModel, SortResultModel +from rutils.io import read_file_in_server - @jwt_required - def post(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - raw = proj.raw - - if raw: - return {'message': "Raw Data already exists."}, 400 - filestr = request.data - # data = RawData.parser.parse_args() - - # print(eval(data['raw']).shape) - raw = RawModel(user_id=user_id, project_id=proj.id, data=filestr) # data['raw']) - - try: - raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 - - return "Success", 201 - - @jwt_required - def delete(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - raw = proj.raw - if raw: - raw.delete_from_db() - return {'message': 'Raw Data deleted.'} - return {'message': 'Raw Data does not exist.'}, 404 - - @jwt_required - def put(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - raw = proj.raw - filestr = request.data - if raw: - print('here') - raw.data = filestr - try: - raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 - return "Success", 201 - - else: - raw = RawModel(user_id, data=filestr, project_id=proj.id) - try: - print('now here') - raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 - - return "Success", 201 +SESSION = dict() class RawDataDefault(Resource): - parser = reqparse.RequestParser() - parser.add_argument('raw', type=str, required=True, help="This field cannot be left blank!") @jwt_required def get(self): - user_id = get_jwt_identity() - # user = UserModel.find_by_id(user_id) - project_id = request.form['project_id'] + project_id = request.json['project_id'] raw = RawModel.find_by_project_id(project_id) if raw: - response = flask.make_response(raw.data) - response.headers.set('Content-Type', 'application/octet-stream') - return response + if raw.mode == 0: + response = flask.make_response(raw.data) + response.headers.set('Content-Type', 'application/octet-stream') + response.status_code = 210 + return response + else: + if request.json['start'] is None: + return {'message': 'SERVER MODE'}, 212 + if project_id in SESSION: + raw_data = SESSION[project_id] + else: + with open(raw.data, 'rb') as f: + raw_data = pickle.load(f) + start = request.json['start'] + stop = request.json['stop'] + limit = request.json['limit'] + stop = min(len(raw_data), stop) + + ds = int((stop - start) / limit) + 1 + visible = raw_data[start:stop:ds] + + buffer = io.BytesIO() + np.savez_compressed(buffer, visible=visible, stop=stop, ds=ds) + buffer.seek(0) + data = buffer.read() + buffer.close() + response = flask.make_response(data) + response.headers.set('Content-Type', 'application/octet-stream') + response.status_code = 211 + return response return {'message': 'Raw Data does not exist'}, 404 @jwt_required def post(self): - filestr = request.data - user_id = get_jwt_identity() - if RawModel.find_by_user_id(user_id): - return {'message': "Raw Data already exists."}, 400 - - # data = RawData.parser.parse_args() - - # print(eval(data['raw']).shape) - raw = RawModel(user_id=user_id, data=filestr) # data['raw'] - - try: - raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 - - return "Success", 201 - - @jwt_required - def delete(self): + raw_data = request.json['raw_data'] user_id = get_jwt_identity() - raw = RawModel.find_by_user_id(user_id) - if raw: - raw.delete_from_db() - return {'message': 'Raw Data deleted.'} - return {'message': 'Raw Data does not exist.'}, 404 - - @jwt_required - def put(self): - filestr = request.form['raw_bytes'] - user_id = get_jwt_identity() - project_id = request.form['project_id'] + project_id = request.json['project_id'] + mode = request.json['mode'] raw = RawModel.find_by_project_id(project_id) - if raw: - raw.data = filestr - print("raw data in if ", raw) + if mode == 0: + raw_data_path = raw_data + elif mode == 1: try: - print("raw data in try ", raw) - raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 - return "Success", 201 + raw_data_path, SESSION[project_id] = read_file_in_server(request.json) + except TypeError as e: + return {"message": str(e)}, 400 + except ValueError as e: + return {"message": str(e)}, 400 + except KeyError: + return {"message": 'Provided variable name is incorrect'}, 400 + else: + return {"message": f"Mode {mode} not supported"}, 400 + if raw: + raw.data = raw_data_path + raw.mode = mode else: - raw = RawModel(user_id, data=filestr, project_id=project_id) + raw = RawModel(user_id, data=raw_data_path, project_id=project_id, mode=mode) try: raw.save_to_db() - except: - return {"message": "An error occurred inserting raw data."}, 500 + detect_result = DetectResultModel.find_by_project_id(project_id) + if detect_result: + detect_result.delete_from_db() + sort_result = SortResultModel.find_by_project_id(project_id) + if sort_result: + sort_result.delete_from_db() + except Exception as e: + return {"message": str(e)}, 500 return "Success", 201 diff --git a/ross_backend/resources/detect.py b/ross_backend/resources/detect.py index 75cdc90..5af8d36 100644 --- a/ross_backend/resources/detect.py +++ b/ross_backend/resources/detect.py @@ -1,111 +1,18 @@ import pickle import traceback -from uuid import uuid4 from pathlib import Path +from uuid import uuid4 from flask_jwt_extended import jwt_required, get_jwt_identity from flask_restful import Resource, reqparse, request + from models.config import ConfigDetectionModel from models.data import DetectResultModel -from models.project import ProjectModel from models.data import RawModel +from resources.detection_result import SESSION from resources.funcs.detection import startDetection -class Detect(Resource): - parser = reqparse.RequestParser(bundle_errors=True) - parser.add_argument('filter_type', type=str, required=True, choices=('butter')) - parser.add_argument('filter_order', type=int, required=True) - parser.add_argument('pass_freq', type=int, required=True) - parser.add_argument('stop_freq', type=int, required=True) - parser.add_argument('sampling_rate', type=int, required=True) - parser.add_argument('thr_method', type=str, required=True, choices=('median', 'wavelet', 'plexon')) - parser.add_argument('side_thr', type=str, required=True, choices=('positive', 'negative', 'two')) - parser.add_argument('pre_thr', type=int, required=True) - parser.add_argument('post_thr', type=int, required=True) - parser.add_argument('dead_time', type=int, required=True) - parser.add_argument('run_detection', type=bool, default=False) - - @jwt_required - def get(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - return config.json() - return {'message': 'Detection config does not exist'}, 404 - - @jwt_required - def post(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - return {'message': "Config Detection already exists."}, 400 - - data = Detect.parser.parse_args() - - config = ConfigDetectionModel(user_id, **data, project_id=proj.id) - - try: - config.save_to_db() - except: - return {"message": "An error occurred inserting detection config."}, 500 - - # if data['run_detection']: - # try: - # print('starting Detection ...') - # startDetection() - # except: - # return {"message": "An error occurred in detection."}, 500 - - return config.json(), 201 - - @jwt_required - def delete(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - config.delete_from_db() - return {'message': 'Detection config deleted.'} - return {'message': 'Detection config does not exist.'}, 404 - - @jwt_required - def put(self, name): - data = Detect.parser.parse_args() - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - for key in data: - config.key = data[key] - try: - config.save_to_db() - except: - return {"message": "An error occurred inserting detection config."}, 500 - - return config.json(), 201 - - else: - config = ConfigDetectionModel(user_id, **data, project_id=proj.id) - - try: - config.save_to_db() - except: - return {"message": "An error occurred inserting detection config."}, 500 - - return config.json(), 201 - - class DetectDefault(Resource): parser = reqparse.RequestParser(bundle_errors=True) parser.add_argument('filter_type', type=str, required=True, choices='butterworth') @@ -119,13 +26,10 @@ class DetectDefault(Resource): parser.add_argument('post_thr', type=int, required=True) parser.add_argument('dead_time', type=int, required=True) parser.add_argument('run_detection', type=bool, default=False) - # -------------------------------------------- parser.add_argument('project_id', type=int, default=False) @jwt_required def get(self): - user_id = get_jwt_identity() - # user = UserModel.find_by_id(user_id) project_id = request.form['project_id'] config = ConfigDetectionModel.find_by_project_id(project_id) if config: @@ -134,38 +38,6 @@ def get(self): @jwt_required def post(self): - user_id = get_jwt_identity() - if ConfigDetectionModel.find_by_user_id(user_id): - return {'message': "Detection config already exists."}, 400 - - data = Detect.parser.parse_args() - config = ConfigDetectionModel(user_id, **data) - - try: - config.save_to_db() - except: - return {"message": "An error occurred inserting detection config."}, 500 - - if data['run_detection']: - try: - print('starting Detection ...') - startDetection(user_id) - except: - return {"message": "An error occurred in detection."}, 500 - - return config.json(), 201 - - @jwt_required - def delete(self): - user_id = get_jwt_identity() - config = ConfigDetectionModel.find_by_user_id(user_id) - if config: - config.delete_from_db() - return {'message': 'Detection config deleted.'} - return {'message': 'Detection config does not exist.'}, 404 - - @jwt_required - def put(self): data = DetectDefault.parser.parse_args() project_id = data['project_id'] user_id = get_jwt_identity() @@ -195,34 +67,52 @@ def put(self): except: print(traceback.format_exc()) return {"message": "An error occurred inserting detection config."}, 500 - if data['run_detection']: - try: - spikeMat, spikeTime = startDetection(project_id) - data_file = {'spikeMat': spikeMat, 'spikeTime': spikeTime} - # ------------------------------------------------------- - print("inserting detection result to database") + # if data['run_detection']: + try: - detection_result_path = str(Path(RawModel.find_by_project_id(project_id).data).parent / \ - (str(uuid4()) + '.pkl')) + raw = RawModel.find_by_project_id(project_id) + + if not raw: + return {'message': 'No raw file'}, 404 + config = ConfigDetectionModel.find_by_project_id(project_id) + if not config: + return {'message': 'Detection config does not exist'}, 404 + if not raw.data: + return {'message': 'raw file has no data'}, 404 + + with open(raw.data, 'rb') as f: + data = pickle.load(f) + + spikeMat, spikeTime, pca_spikes, inds = startDetection(data, config) - with open(detection_result_path, 'wb') as f: - pickle.dump(data_file, f) + data_file = {'spikeMat': spikeMat, 'spikeTime': spikeTime, + 'config': config.json(), 'pca_spikes': pca_spikes, + 'inds': inds} + # ------------------------------------------------------- + print("inserting detection result to database") - detectResult = DetectResultModel.find_by_project_id(project_id) + detection_result_path = str(Path(RawModel.find_by_project_id(project_id).data).parent / + (str(uuid4()) + '.pkl')) + + with open(detection_result_path, 'wb') as f: + pickle.dump(data_file, f) - if detectResult: - detectResult.data = detection_result_path - else: - detectResult = DetectResultModel(user_id, detection_result_path, project_id) + detectResult = DetectResultModel.find_by_project_id(project_id) - try: - detectResult.save_to_db() - except: - print(traceback.format_exc()) - return {"message": "An error occurred inserting detection result."}, 500 + if detectResult: + detectResult.data = detection_result_path + else: + detectResult = DetectResultModel(user_id, detection_result_path, project_id) + try: + detectResult.save_to_db() + SESSION[project_id] = data_file except: print(traceback.format_exc()) - return {"message": "An error occurred in detection."}, 500 + return {"message": "An error occurred inserting detection result."}, 500 + + except: + print(traceback.format_exc()) + return {"message": "An error occurred in detection."}, 500 - return config.json(), 201 + return "Success", 201 diff --git a/ross_backend/resources/detection_result.py b/ross_backend/resources/detection_result.py index 099eab5..f782fc8 100644 --- a/ross_backend/resources/detection_result.py +++ b/ross_backend/resources/detection_result.py @@ -4,28 +4,39 @@ import flask import numpy as np from flask import request -from flask_jwt_extended import jwt_required, get_jwt_identity -from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required +from flask_restful import Resource + from models.data import DetectResultModel +SESSION = dict() +DATA_NUM_TO_SEND = 1000 + -class DetectionResultDefault(Resource): - parser = reqparse.RequestParser() - parser.add_argument('raw', type=str, required=True, help="This field cannot be left blank!") +class DetectionResult(Resource): @jwt_required def get(self): - user_id = get_jwt_identity() - # user = UserModel.find_by_id(user_id) - project_id = request.form["project_id"] - detect_result_model = DetectResultModel.find_by_project_id(project_id) - - if detect_result_model: - with open(detect_result_model.data, 'rb') as f: - detect_result = pickle.load(f) - + project_id = int(request.form["project_id"]) + detect_result = None + if project_id in SESSION: + detect_result = SESSION[project_id] + else: + detect_result_model = DetectResultModel.find_by_project_id(project_id) + if detect_result_model: + with open(detect_result_model.data, 'rb') as f: + detect_result = pickle.load(f) + SESSION[project_id] = detect_result + + if detect_result is not None: + inds = detect_result['inds'] buffer = io.BytesIO() - np.savez_compressed(buffer, spike_mat=detect_result['spikeMat'], spike_time=detect_result['spikeTime']) + np.savez_compressed(buffer, + spike_mat=detect_result['spikeMat'][inds[:DATA_NUM_TO_SEND], :], + spike_time=detect_result['spikeTime'], + config=detect_result['config'], + pca_spikes=detect_result['pca_spikes'], + inds=inds[:DATA_NUM_TO_SEND]) buffer.seek(0) raw_bytes = buffer.read() buffer.close() @@ -36,52 +47,31 @@ def get(self): return {'message': 'Detection Result Data does not exist'}, 404 - @jwt_required - def post(self): - filestr = request.data - user_id = get_jwt_identity() - if DetectResultModel.find_by_user_id(user_id): - return {'message': "Detection Result already exists."}, 400 - # data = RawData.parser.parse_args() - # print(eval(data['raw']).shape) - data = DetectResultModel(user_id=user_id, data=filestr) # data['raw']) - try: - data.save_to_db() - except: - return {"message": "An error occurred inserting sort result data."}, 500 - - return "Success", 201 - - @jwt_required - def delete(self): - user_id = get_jwt_identity() - data = DetectResultModel.find_by_user_id(user_id) - if data: - data.delete_from_db() - return {'message': 'Detection Result Data deleted.'} - return {'message': 'Detection Result Data does not exist.'}, 404 +class DetectionResultSpikeMat(Resource): @jwt_required - def put(self): - filestr = request.data - user_id = get_jwt_identity() - data = DetectResultModel.find_by_user_id(user_id) - if data: - data.data = filestr - try: - data.save_to_db() - - except: - return {"message": "An error occurred inserting sort result data."}, 500 - return "Success", 201 - + def get(self): + project_id = int(request.json["project_id"]) + detect_result = None + if project_id in SESSION: + detect_result = SESSION[project_id] else: - data = DetectResultModel(user_id, data=filestr) + detect_result_model = DetectResultModel.find_by_project_id(project_id) + if detect_result_model: + with open(detect_result_model.data, 'rb') as f: + detect_result = pickle.load(f) + SESSION[project_id] = detect_result + + if detect_result is not None: + buffer = io.BytesIO() + np.savez_compressed(buffer, spike_mat=detect_result['spikeMat']) + buffer.seek(0) + raw_bytes = buffer.read() + buffer.close() - try: - data.save_to_db() - except: - return {"message": "An error occurred inserting sort result data."}, 500 + response = flask.make_response(raw_bytes) + response.headers.set('Content-Type', 'application/octet-stream') + return response - return "Success", 201 + return {'message': 'Detection Result Data does not exist'}, 404 diff --git a/ross_backend/resources/funcs/detection.py b/ross_backend/resources/funcs/detection.py index 1c80a05..32ee5ed 100644 --- a/ross_backend/resources/funcs/detection.py +++ b/ross_backend/resources/funcs/detection.py @@ -1,36 +1,12 @@ -import pickle +import random import numpy as np import pyyawt import scipy.signal -from models.config import ConfigDetectionModel -from models.data import RawModel +import sklearn.decomposition as decom -def startDetection(project_id): - raw = RawModel.find_by_project_id(project_id) - - if not raw: - raise Exception - config = ConfigDetectionModel.find_by_project_id(project_id) - if not config: - raise Exception - if not raw.data: - raise Exception - - # print("raw.data", raw.data) - # b = io.BytesIO() - # b.write(raw.data) - # b.seek(0) - # d = np.load(b, allow_pickle=True) - # data_address = d['raw'] - - with open(raw.data, 'rb') as f: - new_data = pickle.load(f) - data = new_data - - # b.close() - +def startDetection(data, config): thr_method = config.thr_method fRp = config.pass_freq fRs = config.stop_freq @@ -51,7 +27,13 @@ def startDetection(project_id): # spike detection SpikeMat, SpikeTime = spike_detector(data_filtered, thr, pre_thr, post_thr, dead_time, thr_side) # print("SpikeMat shape and SpikeTime shape : ", SpikeMat.shape, SpikeTime.shape) - return SpikeMat, SpikeTime + + pca = decom.PCA(n_components=3) + pca_spikes = pca.fit_transform(SpikeMat) + inds = list(range(pca_spikes.shape[0])) + random.shuffle(inds) + + return SpikeMat, SpikeTime, pca_spikes, tuple(inds) # Threshold @@ -74,11 +56,8 @@ def threshold_calculator(method, thr_factor, data): # Filtering def filtering(data, forder, fRp, fRs, sr): - # print('inside filtering!') b, a = scipy.signal.butter(forder, (fRp / (sr / 2), fRs / (sr / 2)), btype='bandpass') - # print('after b, a') data_filtered = scipy.signal.filtfilt(b, a, data) - # print('after filtfilt!') return data_filtered @@ -121,7 +100,7 @@ def spike_detector(data, thr, pre_thresh, post_thresh, dead_time, side): indx_spikes = np.nonzero(spike_detected)[0] SpikeMat = np.zeros((len(indx_spikes), n_points_per_spike)) - SpikeTime = np.zeros((len(indx_spikes), 1)) + SpikeTime = np.zeros((len(indx_spikes),), dtype=np.uint32) # assigning SpikeMat and SpikeTime matrices for i, curr_indx in enumerate(indx_spikes): diff --git a/ross_backend/resources/funcs/gmm.py b/ross_backend/resources/funcs/gmm.py index f1fae28..2d5c9be 100644 --- a/ross_backend/resources/funcs/gmm.py +++ b/ross_backend/resources/funcs/gmm.py @@ -3,27 +3,23 @@ from sklearn.mixture import GaussianMixture as GMM -def gmm_sorter(alignedSpikeMat, ss): - out = dict() - print('GMM') - +def gmm_sorter(aligned_spikemat, ss): g_max = ss.g_max g_min = ss.g_min max_iter = ss.max_iter n_cluster_range = np.arange(g_min + 1, g_max + 1) - scores = [] error = ss.error + scores = [] + for n_cluster in n_cluster_range: - clusterer = GMM(n_components=n_cluster, random_state=5, tol=error, max_iter=max_iter) - cluster_labels = clusterer.fit_predict(alignedSpikeMat) - silhouette_avg = silhouette_score(alignedSpikeMat, cluster_labels) - print('For n_cluster={0} ,score={1}'.format(n_cluster, silhouette_avg)) + cluster = GMM(n_components=n_cluster, random_state=5, tol=error, max_iter=max_iter) + cluster_labels = cluster.fit_predict(aligned_spikemat) + silhouette_avg = silhouette_score(aligned_spikemat, cluster_labels) scores.append(silhouette_avg) k = n_cluster_range[np.argmax(scores)] - clusterer = GMM(n_components=k, random_state=5, tol=error, max_iter=max_iter) - out['cluster_index'] = clusterer.fit_predict(alignedSpikeMat) - print('clusters : ', out['cluster_index']) + cluster = GMM(n_components=k, random_state=5, tol=error, max_iter=max_iter) + cluster_index = cluster.fit_predict(aligned_spikemat) - return out['cluster_index'] + return cluster_index diff --git a/ross_backend/resources/funcs/skew_t_sorter.py b/ross_backend/resources/funcs/skew_t_sorter.py index 07381ee..1fd049a 100644 --- a/ross_backend/resources/funcs/skew_t_sorter.py +++ b/ross_backend/resources/funcs/skew_t_sorter.py @@ -1,15 +1,16 @@ import numpy as np import sklearn.decomposition as decomp +from scipy import optimize +from scipy import stats +from scipy.spatial.distance import cdist +from scipy.special import gamma + from resources.funcs.fcm import FCM from resources.funcs.sort_utils import ( matrix_sqrt, dmvt_ls, d_mixedmvST, ) -from scipy import optimize -from scipy import stats -from scipy.spatial.distance import cdist -from scipy.special import gamma def skew_t_sorter(alignedSpikeMat, sdd, REM=np.array([]), INPCA=False): diff --git a/ross_backend/resources/funcs/sort_utils.py b/ross_backend/resources/funcs/sort_utils.py index 0a6da0d..080bdac 100644 --- a/ross_backend/resources/funcs/sort_utils.py +++ b/ross_backend/resources/funcs/sort_utils.py @@ -36,7 +36,7 @@ def dmvt_ls(y, mu, Sigma, landa, nu): t = scipy.stats.t(nu + p) dens = 2 * (denst) * t.cdf(np.sqrt((p + nu) / (mahalanobis_d + nu)) * np.expand_dims(np.sum( np.tile(np.expand_dims(np.linalg.lstsq(matrix_sqrt(Sigma).T, landa.T, rcond=None)[0], axis=0), (n, 1)) * ( - y - mu), axis=1), axis=1)) + y - mu), axis=1), axis=1)) return dens @@ -141,7 +141,7 @@ def spike_alignment(spike_mat, spike_time, ss): newSpikeMat = spike_mat[:, max_shift:-max_shift] # shifting time with the amount of "max_shift" because first "max_shift" samples are ignored. - time_shift = np.zeros((n_spike, 1)) + max_shift + time_shift = np.zeros((n_spike,)) + max_shift # indices of neg group ind_neg_find = np.nonzero(ind_neg)[0] diff --git a/ross_backend/resources/funcs/sorting.py b/ross_backend/resources/funcs/sorting.py index 0829d6d..de93412 100644 --- a/ross_backend/resources/funcs/sorting.py +++ b/ross_backend/resources/funcs/sorting.py @@ -1,6 +1,5 @@ import pickle -import numpy as np from models.config import ConfigSortModel from models.data import DetectResultModel from resources.funcs.gmm import * @@ -10,6 +9,15 @@ from resources.funcs.t_sorting import * +def create_cluster_time_vec(spike_time: np.ndarray, clusters, config: dict): + cluster_time_vec = np.zeros(spike_time[-1] + config['post_thr'], dtype=np.int8) + for i, t in enumerate(spike_time): + cluster_time_vec[t - config['pre_thr']: t + config['post_thr']] = clusters[i] + 1 + return cluster_time_vec + + +# TODO: combine two sorting functions into one + def startSorting(project_id): detect_result = DetectResultModel.find_by_project_id(project_id) if not detect_result: @@ -27,6 +35,7 @@ def startSorting(project_id): spike_mat = d['spikeMat'] spike_time = d['spikeTime'] + config_det = d['config'] if config.alignment: spike_mat, spike_time = spike_alignment(spike_mat, spike_time, config) @@ -46,8 +55,10 @@ def startSorting(project_id): optimal_set = kmeans(spike_mat, config) elif config.sorting_type == 'GMM': optimal_set = gmm_sorter(spike_mat, config) + else: + raise NotImplementedError(f'{config.sorting_type} not implemented') - return optimal_set + return optimal_set, create_cluster_time_vec(spike_time=d['spikeTime'], clusters=optimal_set, config=config_det) def startReSorting(project_id, clusters, selected_clusters): @@ -67,11 +78,12 @@ def startReSorting(project_id, clusters, selected_clusters): spike_mat = d['spikeMat'] spike_time = d['spikeTime'] + config_det = d['config'] + # TODO : CHECK AND CORRECT # if config.alignment: # spike_mat, spike_time = spike_alignment(spike_mat, spike_time, config) - # TODO : CHECK AND CORRECT if config.filtering: pass # REM = spikeFiltering(spike_mat, config) @@ -88,7 +100,9 @@ def startReSorting(project_id, clusters, selected_clusters): optimal_set = kmeans(spike_mat, config) elif config.sorting_type == 'GMM': optimal_set = gmm_sorter(spike_mat, config) + else: + raise NotImplementedError(f'{config.sorting_type} not implemented') clusters = np.array(clusters) clusters[np.isin(clusters, selected_clusters)] = optimal_set + np.max(clusters) + 1 - return clusters.tolist() + return clusters.tolist(), create_cluster_time_vec(spike_time=d['spikeTime'], clusters=clusters, config=config_det) diff --git a/ross_backend/resources/funcs/t_sorting.py b/ross_backend/resources/funcs/t_sorting.py index fae81fc..287ae00 100644 --- a/ross_backend/resources/funcs/t_sorting.py +++ b/ross_backend/resources/funcs/t_sorting.py @@ -1,13 +1,13 @@ import numpy as np import scipy.linalg import sklearn.decomposition as decomp -from resources.funcs.fcm import FCM from scipy.special import gamma +from resources.funcs.fcm import FCM + def t_dist_sorter(alignedSpikeMat, sdd): out = dict() - print('t-sorting started!') g_max = sdd.g_max g_min = sdd.g_min @@ -20,8 +20,6 @@ def t_dist_sorter(alignedSpikeMat, sdd): u_limit = sdd.u_lim # Initialization - print('-*-' * 20) - print('Initialization Started...') nrow = lambda x: x.shape[0] ncol = lambda x: x.shape[1] n_feat = ncol(SpikeMat) @@ -38,18 +36,12 @@ def t_dist_sorter(alignedSpikeMat, sdd): delta_L = 100 delta_v = 100 max_iter = sdd.max_iter - print('...Initialization Done!') # FCM - print('-*-' * 20) - print('FCM started...') mu, U, _ = FCM(SpikeMat, g, [2, 20, 1, 0]) - print('...FCM Done!') # Estimate starting point for Sigma and Pi from simple clustering method # performed before - print('-*-' * 20) - print('Estimating Sigma and Pi...') rep = np.reshape(np.tile(np.expand_dims(np.arange(g), 1), (1, n_spike)), (g * n_spike, 1)) rep = np.squeeze(rep) rep_data = np.tile(SpikeMat, (g, 1)) @@ -68,21 +60,11 @@ def t_dist_sorter(alignedSpikeMat, sdd): v_old = v Ltt = [] - print('...Estimation Done!') - # Running clustering algorithm for g in [g_max, ...,g_min] - print('-*-' * 20) - print('Clustering Started...') while g >= g_min: itr = 0 # EM - print('#' * 10) - print('EM Alg. Started...') - while ((delta_L > delta_L_limit) or (delta_v > delta_v_limit)) and itr < max_iter: - print('iteration number = ', itr) - # print('g = ', g) - # print('Pi = ', Pi) - # print('#' * 5) + while ((delta_L > delta_L_limit) or (delta_v > delta_v_limit)) and itr < max_iter and len(Pi) > 1: n_sigma = Sigma.shape[2] detSigma = np.zeros((1, n_sigma)) rep = np.reshape(np.tile(np.expand_dims(np.arange(g), 1), (1, n_spike)), (g * n_spike, 1)) @@ -176,7 +158,6 @@ def t_dist_sorter(alignedSpikeMat, sdd): if n_sigma > 1: P = c * np.exp(-(v + n_feat) * np.log(1 + M / v) / 2) @ np.diag(np.squeeze(detSigma)) else: - print("we are here") P = c * np.exp(-(v + n_feat) * np.log(1 + M / v) / 2) @ np.diag(detSigma[0]) # P = c * np.exp(-(v + n_feat) * np.log(1 + M / v) / 2) @ np.diag(np.squeeze(detSigma)) @@ -202,12 +183,14 @@ def t_dist_sorter(alignedSpikeMat, sdd): indx_remove = np.squeeze((Pi == 0)) mu = mu[np.logical_not(indx_remove)] + if len(mu.shape) == 3: + mu = mu[0] Sigma = Sigma[:, :, indx_remove == False] if len(Sigma.shape) > 3: Sigma = Sigma[:, :, :, 0] Pi = Pi[indx_remove == False] - if len(Pi.shape) > 2: - Pi = Pi[:, :, 0] + if len(Pi.shape) > 1: + Pi = Pi[:, 0] # Pi = np.array([d for (d, remove) in zip(Pi, indx_remove) if not remove]) z = z[:, indx_remove == False] if len(z.shape) > 2: @@ -235,13 +218,9 @@ def t_dist_sorter(alignedSpikeMat, sdd): # print('P shape : ', P.shape) # print('delta distance shape : ', delta_distance.shape) - print('...Em Done!') - if L > L_max: L_max = L out['cluster_index'] = np.argmax(z, axis=1) - - out['set.u'] = u else: break diff --git a/ross_backend/resources/project.py b/ross_backend/resources/project.py index a5c0cfa..28eeddd 100644 --- a/ross_backend/resources/project.py +++ b/ross_backend/resources/project.py @@ -1,5 +1,6 @@ from flask_jwt_extended import jwt_required, get_jwt_identity from flask_restful import Resource + from models.config import ConfigDetectionModel, ConfigSortModel from models.data import RawModel from models.project import ProjectModel diff --git a/ross_backend/resources/sort.py b/ross_backend/resources/sort.py index d885be3..7a782ab 100644 --- a/ross_backend/resources/sort.py +++ b/ross_backend/resources/sort.py @@ -1,132 +1,15 @@ import pickle import traceback -from uuid import uuid4 from pathlib import Path +from uuid import uuid4 from flask_jwt_extended import jwt_required, get_jwt_identity from flask_restful import Resource, reqparse, request + from models.config import ConfigSortModel from models.data import SortResultModel, RawModel -from models.project import ProjectModel -from models.user import UserModel from resources.funcs.sorting import startSorting, startReSorting - - -class Sort(Resource): - parser = reqparse.RequestParser(bundle_errors=True) - - # alignment settings - parser.add_argument('max_shift', type=int, required=True) - parser.add_argument('histogram_bins', type=int, required=True) - parser.add_argument('num_peaks', type=int, required=True) - parser.add_argument('compare_mode', type=str, required=True, choices=('magnitude', 'index')) - - # filtering settings - parser.add_argument('max_std', type=float, required=True) - parser.add_argument('max_mean', type=float, required=True) - parser.add_argument('max_outliers', type=float, required=True) - - # sort settings - parser.add_argument('nu', type=float, required=True) - parser.add_argument('max_iter', type=int, required=True) - parser.add_argument('PCA_num', type=int, required=True) - parser.add_argument('g_max', type=int, required=True) - parser.add_argument('g_min', type=int, required=True) - parser.add_argument('u_lim', type=float, required=True) - parser.add_argument('error', type=float, required=True) - parser.add_argument('tol', type=float, required=True) - parser.add_argument('N', type=int, required=True) - parser.add_argument('matching_mode', type=str, required=True, choices=('Euclidean', 'Chi_squared', 'Correlation')) - parser.add_argument('alpha', type=float, required=True) - parser.add_argument('combination', type=bool, required=True) - parser.add_argument('custom_template', type=bool, required=True) - parser.add_argument('sorting_type', type=str, - choices=('t dist', 'skew-t dist', 'GMM', 'K-means', 'template matching'), required=True) - parser.add_argument('max_iter', type=int, required=True) - parser.add_argument('alignment', type=bool, required=True) - parser.add_argument('filtering', type=bool, required=True) - parser.add_argument('run_sorting', type=bool, default=False) - - @jwt_required - def get(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - return config.json() - return {'message': 'Detection config does not exist'}, 404 - - @jwt_required - def post(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - return {'message': "Config Detection already exists."}, 400 - - data = Sort.parser.parse_args() - - config = ConfigSortModel(user_id, **data, project_id=proj.id) - - try: - config.save_to_db() - except: - return {"message": "An error occurred inserting detection config."}, 500 - - # if data['run_detection']: - # try: - # print('starting Detection ...') - # startDetection() - # except: - # return {"message": "An error occurred in detection."}, 500 - - return config.json(), 201 - - @jwt_required - def delete(self, name): - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - config.delete_from_db() - return {'message': 'Detection config deleted.'} - return {'message': 'Detection config does not exist.'}, 404 - - @jwt_required - def put(self, name): - data = Sort.parser.parse_args() - user_id = get_jwt_identity() - proj = ProjectModel.find_by_project_name(user_id, name) - if not proj: - return {'message': 'Project does not exist'}, 404 - config = proj.config - if config: - for key in data: - config.key = data[key] - try: - config.save_to_db() - except: - print(traceback.format_exc()) - return {"message": "An error occurred inserting detection config."}, 500 - - return config.json(), 201 - - else: - config = ConfigSortModel(user_id, **data, project_id=proj.id) - - try: - config.save_to_db() - except: - print(traceback.format_exc()) - return {"message": "An error occurred inserting detection config."}, 500 - - return config.json(), 201 +from resources.sorting_result import SESSION class SortDefault(Resource): @@ -172,23 +55,12 @@ class SortDefault(Resource): @jwt_required def get(self): - user_id = get_jwt_identity() project_id = request.form['project_id'] - user = UserModel.find_by_id(user_id) config = ConfigSortModel.find_by_project_id(project_id) if config: return config.json() return {'message': 'Sort config does not exist'}, 404 - @jwt_required - def delete(self): - user_id = get_jwt_identity() - config = ConfigSortModel.find_by_user_id(user_id) - if config: - config.delete_from_db() - return {'message': 'Sorting config deleted.'} - return {'message': 'Sorting config does not exist.'}, 404 - @jwt_required def put(self): data = SortDefault.parser.parse_args() @@ -248,12 +120,12 @@ def put(self): print('Starting Sorting ...') if clusters is not None: - clusters_index = startReSorting(project_id, clusters, selected_clusters) + clusters_index, cluster_time_vec = startReSorting(project_id, clusters, selected_clusters) return {'clusters': clusters_index}, 201 else: - clusters_index = startSorting(project_id) + clusters_index, cluster_time_vec = startSorting(project_id) - data = {"clusters": clusters_index} + data = {"clusters": clusters_index, "cluster_time_vec": cluster_time_vec} sort_result_path = str(Path(RawModel.find_by_project_id(project_id).data).parent / (str(uuid4()) + '.pkl')) @@ -270,6 +142,7 @@ def put(self): try: sortResult.save_to_db() + SESSION[project_id] = data except: print(traceback.format_exc()) return {"message": "An error occurred inserting sorting result."}, 500 diff --git a/ross_backend/resources/sorting_result.py b/ross_backend/resources/sorting_result.py index 7ab744f..19c9b78 100644 --- a/ross_backend/resources/sorting_result.py +++ b/ross_backend/resources/sorting_result.py @@ -1,6 +1,7 @@ import io import os.path import pickle +from pathlib import Path from uuid import uuid4 import flask @@ -8,7 +9,12 @@ from flask import request from flask_jwt_extended import jwt_required, get_jwt_identity from flask_restful import Resource, reqparse -from models.data import SortResultModel + +from models.data import SortResultModel, RawModel, DetectResultModel +from resources.detection_result import SESSION as DET_SESSION +from resources.funcs.sorting import create_cluster_time_vec + +SESSION = dict() class SortingResultDefault(Resource): @@ -17,16 +23,20 @@ class SortingResultDefault(Resource): @jwt_required def get(self): - user_id = get_jwt_identity() project_id = request.form["project_id"] - sort_result = SortResultModel.find_by_project_id(project_id) + sort_dict = None + if project_id in SESSION: + sort_dict = SESSION[project_id] + else: - if sort_result: - with open(sort_result.data, 'rb') as f: - sort_dict = pickle.load(f) + sort_result = SortResultModel.find_by_project_id(project_id) + + if sort_result: + with open(sort_result.data, 'rb') as f: + sort_dict = pickle.load(f) + if sort_dict is not None: buffer = io.BytesIO() - print("sort_dict['clusters']", sort_dict['clusters']) - np.savez_compressed(buffer, clusters=sort_dict['clusters']) + np.savez_compressed(buffer, clusters=sort_dict['clusters'], cluster_time_vec=sort_dict["cluster_time_vec"]) buffer.seek(0) raw_bytes = buffer.read() buffer.close() @@ -37,36 +47,6 @@ def get(self): return {'message': 'Sort Result Data does not exist'}, 404 - @jwt_required - def post(self): - filestr = request.data - user_id = get_jwt_identity() - project_id = request.form["project_id"] - if SortResultModel.find_by_project_id(project_id): - return {'message': "Detection Result already exists."}, 400 - - # data = RawData.parser.parse_args() - - # print(eval(data['raw']).shape) - data = SortResultModel(user_id=user_id, data=filestr, project_id=project_id) # data['raw']) - - try: - data.save_to_db() - except: - return {"message": "An error occurred inserting sort result data."}, 500 - - return "Success", 201 - - @jwt_required - def delete(self): - user_id = get_jwt_identity() - project_id = request.form["project_id"] - data = SortResultModel.find_by_project_id(project_id) - if data: - data.delete_from_db() - return {'message': 'Sort Result Data deleted.'} - return {'message': 'Sort Result Data does not exist.'}, 404 - @jwt_required def put(self): # filestr = request.data @@ -77,13 +57,29 @@ def put(self): b.seek(0) d = np.load(b, allow_pickle=True) - project_id = d["project_id"] + project_id = int(d["project_id"]) clusters = d["clusters"] b.close() - save_sort_result_path = '../ross_data/Sort_Result/' + str(uuid4()) + '.pkl' + detect_result = None + if project_id in DET_SESSION: + detect_result = DET_SESSION[project_id] + else: + detect_result_model = DetectResultModel.find_by_project_id(project_id) + if detect_result_model: + with open(detect_result_model.data, 'rb') as f: + detect_result = pickle.load(f) + DET_SESSION[project_id] = detect_result + if detect_result is None: + return {"message": "No Detection"}, 400 + + save_sort_result_path = str(Path(RawModel.find_by_project_id(project_id).data).parent / + (str(uuid4()) + '.pkl')) + + cluster_time_vec = create_cluster_time_vec(detect_result['spikeTime'], clusters, detect_result['config']) + with open(save_sort_result_path, 'wb') as f: - pickle.dump({"clusters": clusters}, f) + pickle.dump({"clusters": clusters, "cluster_time_vec": cluster_time_vec}, f) data = SortResultModel.find_by_project_id(int(project_id)) if data: if os.path.isfile(data.data): diff --git a/ross_backend/resources/user.py b/ross_backend/resources/user.py index 04f8934..deff117 100644 --- a/ross_backend/resources/user.py +++ b/ross_backend/resources/user.py @@ -1,4 +1,3 @@ -from blacklist import BLACKLIST from flask_jwt_extended import ( create_access_token, create_refresh_token, @@ -8,10 +7,12 @@ jwt_required ) from flask_restful import Resource, reqparse +from werkzeug.security import safe_str_cmp + +from blacklist import BLACKLIST from models.config import ConfigDetectionModel, ConfigSortModel from models.project import ProjectModel from models.user import UserModel -from werkzeug.security import safe_str_cmp _user_parser = reqparse.RequestParser() _user_parser.add_argument('username', @@ -40,8 +41,7 @@ def post(self): proj.save_to_db() user.project_default = proj.id user.save_to_db() - config_detect = ConfigDetectionModel(user_id, - project_id=proj.id) # create a default detection config for the default project + config_detect = ConfigDetectionModel(user_id, project_id=proj.id) config_detect.save_to_db() config_sort = ConfigSortModel(user_id, project_id=proj.id) config_sort.save_to_db() diff --git a/ross_backend/rutils/io.py b/ross_backend/rutils/io.py new file mode 100644 index 0000000..ede3fc0 --- /dev/null +++ b/ross_backend/rutils/io.py @@ -0,0 +1,62 @@ +import os +import pickle +from pathlib import Path +from uuid import uuid4 + +from scipy.io import loadmat +import numpy as np + +Raw_data_path = os.path.join(Path(__file__).parent, '../ross_data/Raw_Data') +Path(Raw_data_path).mkdir(parents=True, exist_ok=True) + + +def read_file_in_server(request_data: dict): + print(request_data) + if 'raw_data' in request_data and 'project_id' in request_data: + filename = request_data['raw_data'] + file_extension = os.path.splitext(filename)[-1] + if file_extension == '.mat': + file_raw = loadmat(filename) + variables = list(file_raw.keys()) + if '__version__' in variables: variables.remove('__version__') + if '__header__' in variables: variables.remove('__header__') + if '__globals__' in variables: variables.remove('__globals__') + + if len(variables) > 1: + if 'varname' in request_data: + variable = request_data['varname'] + else: + raise ValueError("More than one variable exists ") + else: + variable = variables[0] + + temp = file_raw[variable].flatten() + + elif file_extension == '.pkl': + with open(filename, 'rb') as f: + file_raw = pickle.load(f) + variables = list(file_raw.keys()) + + if len(variables) > 1: + if 'varname' in request_data: + variable = request_data['varname'] + else: + raise ValueError("More than one variable exists") + else: + variable = variables[0] + + temp = np.array(file_raw[variable]).flatten() + + else: + raise TypeError("Type not supported") + + # ------------------ save raw data as pkl file in data_set folder ----------------------- + address = os.path.join(Raw_data_path, str(uuid4()) + '.pkl') + + with open(address, 'wb') as f: + pickle.dump(temp, f) + # ---------------------------------------------------------------------------------------- + return address, temp + + else: + raise ValueError("request data is incorrect") diff --git a/ross_ui/controller/exportResults.py b/ross_ui/controller/exportResults.py new file mode 100644 index 0000000..a4fc4ae --- /dev/null +++ b/ross_ui/controller/exportResults.py @@ -0,0 +1,34 @@ +from PyQt5 import QtWidgets + +from model.api import API +from view.exportResults import ExportResults + + +class ExportResultsApp(ExportResults): + def __init__(self, api: API): + super().__init__() + self.api = api + self.data_dict = {} + self.type = 'mat' + + self.pushExport.clicked.connect(self.pushExportClicked) + self.pushClose.clicked.connect(self.reject) + + def pushExportClicked(self): + if self.checkSpikeMat.isChecked(): + self.labelDownload.setText('Download spike waveforms from server ...') + QtWidgets.QApplication.processEvents() + res = self.api.get_spike_mat() + if res['stat']: + self.labelDownload.setText('Download Done!') + QtWidgets.QApplication.processEvents() + self.data_dict['SpikeWaveform'] = res['spike_mat'] + else: + self.labelDownload.setText('Download Error!') + QtWidgets.QMessageBox.critical(self, 'Error in Download', res['message']) + if self.radioMat.isChecked(): + self.type = 'mat' + elif self.radioPickle.isChecked(): + self.type = 'pickle' + + self.accept() diff --git a/ross_ui/controller/hdf5.py b/ross_ui/controller/hdf5.py index e78ac69..7ae3277 100644 --- a/ross_ui/controller/hdf5.py +++ b/ross_ui/controller/hdf5.py @@ -1,149 +1,76 @@ -# -*- coding: utf-8 -*- -""" -In this example we create a subclass of PlotCurveItem for displaying a very large -data set from an HDF5 file that does not fit in memory. - -The basic approach is to override PlotCurveItem.viewRangeChanged such that it -reads only the portion of the HDF5 data that is necessary to display the visible -portion of the data. This is further downsampled to reduce the number of samples -being displayed. - -A more clever implementation of this class would employ some kind of caching -to avoid re-reading the entire visible waveform at every update. -""" - -# import initExample ## Add path to library (just for examples; you do not need this) - import numpy as np import pyqtgraph as pg - -# pg.mkQApp() +from model.api import API -# plt = pg.plot() -# plt.setWindowTitle('pyqtgraph example: HDF5 big data') -# plt.enableAutoRange(False, False) -# plt.setXRange(0, 500) - class HDF5Plot(pg.PlotCurveItem): + res = None + SS = None + def __init__(self, *args, **kwds): + self.cluster = None + # self.pen = None + self.api = None self.hdf5 = None - self.limit = 10000 # maximum number of samples to be plotted + self.limit = 10000 pg.PlotCurveItem.__init__(self, *args, **kwds) def setHDF5(self, data, pen=None): self.hdf5 = data self.pen = pen - self.updateHDF5Plot() + if self.pen is not None: + self.setPen(self.pen) + # self.updateHDF5Plot() + + def setAPI(self, api: API): + self.api = api def viewRangeChanged(self): self.updateHDF5Plot() + def setCluster(self, cluster): + self.cluster = cluster + def updateHDF5Plot(self): - if self.hdf5 is None: - self.setData([]) - return vb = self.getViewBox() if vb is None: - return # no ViewBox yet + return - # Determine what data range must be read from HDF5 xrange = vb.viewRange()[0] - start = max(0, int(xrange[0]) - 1) - stop = min(len(self.hdf5), int(xrange[1] + 2)) + if xrange[1] - xrange[0] < 10: + return - # Decide by how much we should downsample - ds = int((stop - start) / self.limit) + 1 + start = max(0, int(xrange[0]) - 1) - if ds == 1: - # Small enough to display with no intervention. - visible = self.hdf5[start:stop] - scale = 1 + if self.hdf5 is None: + stop = int(xrange[1] + 2) + if (HDF5Plot.SS is None) or ([start, stop] != HDF5Plot.SS): + res = self.api.get_raw_data(start, stop, self.limit) + HDF5Plot.SS = [start, stop] + HDF5Plot.res = res + else: + res = HDF5Plot.res + if not res['stat']: + self.setData([]) + return + stop = res['stop'] + visible = res['visible'].copy() + ds = res['ds'] else: - # Here convert data into a down-sampled array suitable for visualizing. - # Must do this piecewise to limit memory usage. - samples = 1 + ((stop - start) // ds) - visible = np.zeros(samples * 2, dtype=self.hdf5.dtype) - sourcePtr = start - targetPtr = 0 - - # read data in chunks of ~1M samples - chunkSize = (1000000 // ds) * ds - while sourcePtr < stop - 1: - chunk = self.hdf5[sourcePtr:min(stop, sourcePtr + chunkSize)] - sourcePtr += len(chunk) - - # reshape chunk to be integral multiple of ds - chunk = chunk[:(len(chunk) // ds) * ds].reshape(len(chunk) // ds, ds) - - # compute max and min - chunkMax = chunk.max(axis=1) - chunkMin = chunk.min(axis=1) - - # interleave min and max into plot data to preserve envelope shape - visible[targetPtr:targetPtr + chunk.shape[0] * 2:2] = chunkMin - visible[1 + targetPtr:1 + targetPtr + chunk.shape[0] * 2:2] = chunkMax - targetPtr += chunk.shape[0] * 2 - - visible = visible[:targetPtr] - scale = ds * 0.5 - - self.setData(visible) # update the plot - self.setPos(start, 0) # shift to match starting index - if self.pen is not None: - self.setPen(self.pen) + stop = min(len(self.hdf5), int(xrange[1] + 2)) + ds = int((stop - start) / self.limit) + 1 + visible = self.hdf5[start:stop:ds].copy() + + x = np.arange(start, stop, ds) + + if self.cluster is not None: + # visible = visible[self.cluster[x]] + # x = x[self.cluster[x]] + visible[np.logical_not(self.cluster[x])] = np.nan + # x[not self.cluster[x]] = np.nan + self.setData(x, visible, connect='finite') + # if self.pen is not None: + # self.setPen(self.pen) self.resetTransform() - # self.scale(scale, 1) # scale to match downsampling - -# def createFile(finalSize=2000000000): -# """Create a large HDF5 data file for testing. -# Data consists of 1M random samples tiled through the end of the array. -# """ - -# chunk = np.random.normal(size=1000000).astype(np.float32) - -# f = h5py.File('test.hdf5', 'w') -# f.create_dataset('data', data=chunk, chunks=True, maxshape=(None,)) -# data = f['data'] - -# nChunks = finalSize // (chunk.size * chunk.itemsize) -# with pg.ProgressDialog("Generating test.hdf5...", 0, nChunks) as dlg: -# for i in range(nChunks): -# newshape = [data.shape[0] + chunk.shape[0]] -# data.resize(newshape) -# data[-chunk.shape[0]:] = chunk -# dlg += 1 -# if dlg.wasCanceled(): -# f.close() -# os.remove('test.hdf5') -# sys.exit() -# dlg += 1 -# f.close() - -# if len(sys.argv) > 1: -# fileName = sys.argv[1] -# else: -# fileName = 'test.hdf5' -# if not os.path.isfile(fileName): -# size, ok = QtGui.QInputDialog.getDouble(None, "Create HDF5 Dataset?", "This demo requires a large HDF5 array. To generate a file, enter the array size (in GB) and press OK.", 2.0) -# if not ok: -# sys.exit(0) -# else: -# createFile(int(size*1e9)) -# #raise Exception("No suitable HDF5 file found. Use createFile() to generate an example file.") - -# f = h5py.File(fileName, 'r') -# curve = HDF5Plot() -# curve.setHDF5(f['data']) -# plt.addItem(curve) - - -# ## Start Qt event loop unless running in interactive mode or using pyside. -# if __name__ == '__main__': - - -# import sys -# if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'): -# QtGui.QApplication.instance().exec_() diff --git a/ross_ui/controller/mainWindow.py b/ross_ui/controller/mainWindow.py index 5c441f9..67fbbbc 100644 --- a/ross_ui/controller/mainWindow.py +++ b/ross_ui/controller/mainWindow.py @@ -1,48 +1,41 @@ import os +import pathlib import pickle import random -import time import traceback from uuid import uuid4 -import enum -import pathlib import matplotlib.pyplot as plt import numpy as np +import pandas as pd import pyqtgraph import pyqtgraph.exporters import pyqtgraph.opengl as gl import scipy.io as sio import scipy.stats as stats -import sklearn.decomposition as decom from PyQt5 import QtCore, QtWidgets, QtGui from PyQt5.QtGui import QPixmap, QTransform, QColor, QIcon from colour import Color from nptdms import TdmsFile from shapely.geometry import Point, Polygon from sklearn.neighbors import NearestNeighbors -import pandas as pd from controller.detectedMatSelect import DetectedMatSelectApp as detected_mat_form from controller.detectedTimeSelect import DetectedTimeSelectApp as detected_time_form +from controller.exportResults import ExportResultsApp as export_results from controller.hdf5 import HDF5Plot from controller.matplot_figures import MatPlotFigures -from controller.multicolor_curve import MultiColoredCurve from controller.projectSelect import projectSelectApp as project_form from controller.rawSelect import RawSelectApp as raw_form from controller.saveAs import SaveAsApp as save_as_form -from controller.segmented_time import SegmentedTime from controller.serverAddress import ServerApp as server_form +from controller.serverFileDialog import ServerFileDialogApp as sever_dialog from controller.signin import SigninApp as signin_form from view.mainWindow import MainWindow icon_path = './view/icons/' - -class softwareMode(enum.Enum): - SAME_PALACE = 0 - CLIENT_SIDE = 1 - SERVER_SIDE = 2 +os.makedirs('.tmp', exist_ok=True) class MainApp(MainWindow): @@ -52,7 +45,9 @@ def __init__(self): self.setWindowIcon(QtGui.QIcon('view/icons/ross.png')) # initial values for software options - self.plotHistFlag = False + # self.plotHistFlag = False + self.plotFlagHist = False + self.plotFlagClusterBased = False self.pca_manual = None self.image = None self.pca_spikes = None @@ -60,13 +55,15 @@ def __init__(self): self.clusters_tmp = None self.clusters_init = None self.clusters = None - self.colors = None + self.inds = None + self.colors = self.distinctColors(127) - self.url = 'http://127.0.0.1:5000' + self.url = 'http://localhost:5000' self.raw = None self.spike_mat = None self.spike_time = None + self.cluster_time_vec = None self.Raw_data_path = os.path.join(pathlib.Path(__file__).parent, '../ross_data/Raw_Data') self.pca_path = os.path.join(pathlib.Path(__file__).parent, '../ross_data/pca_images') @@ -81,27 +78,77 @@ def __init__(self): self.plotManualFlag = False self.resetManualFlag = False self.undoManualFlag = False - self.mode = softwareMode.SAME_PALACE self.tempList = [] + self.processEvents = QtWidgets.QApplication.processEvents + self.startDetection.pressed.connect(self.onDetect) self.startSorting.pressed.connect(self.onSort) self.plotButton.pressed.connect(self.Plot3d) self.actManual.pressed.connect(self.onActManualSorting) - self.plotManual.pressed.connect(self.onPlotManualSorting) + self.plotManual.pressed.connect(self.updateFigures) self.undoManual.pressed.connect(self.onUndoManualSorting) self.resetManual.pressed.connect(self.onResetManualSorting) self.saveManual.pressed.connect(self.onSaveManualSorting) + self.saveManual.setEnabled(False) self.closeButton3d.pressed.connect(self.close3D) self.closeButton3dDet.pressed.connect(self.closeDetect3D) self.assign_close_button.pressed.connect(self.closeAssign) self.assign_button.pressed.connect(self.onAssignManualSorting) + self.exportAct.triggered.connect(self.open_export_dialog) # PCA MANUAL self.resetBottonPCAManual.clicked.connect(self.PCAManualResetButton) self.closeBottonPCAManual.clicked.connect(self.PCAManualCloseButton) self.doneBottonPCAManual.clicked.connect(self.PCAManualDoneButton) + def resetOnSignOutVars(self): + + self.raw = None + self.spike_mat = None + self.spike_time = None + self.user_name = None + self.user = None + self.image = None + self.pca_spikes = None + self.number_of_clusters = None + self.clusters_tmp = None + self.clusters_init = None + self.clusters = None + self.cluster_time_vec = None + + self.user = None + self.user_name = None + self.current_project = None + + self.saveManualFlag = False + self.plotManualFlag = False + self.resetManualFlag = False + self.undoManualFlag = False + + self.tempList = [] + + def resetOnImportVars(self): + + # initial values for software options + self.pca_manual = None + self.image = None + self.pca_spikes = None + self.number_of_clusters = None + self.clusters_tmp = None + self.clusters_init = None + self.clusters = None + + self.spike_mat = None + self.spike_time = None + self.cluster_time_vec = None + + self.saveManualFlag = False + self.plotManualFlag = False + self.resetManualFlag = False + self.undoManualFlag = False + self.tempList = [] + def onUserAccount(self): if self.user is None: self.open_signin_dialog() @@ -110,17 +157,17 @@ def onImportRaw(self): filename, filetype = QtWidgets.QFileDialog.getOpenFileName(self, self.tr("Open file"), os.getcwd(), - self.tr("Raw Files(*.mat *.csv *.tdms)") + self.tr("Raw Files(*.mat *.pkl)") ) if not filename: - return FileNotFoundError('you should select a file') + raise FileNotFoundError('you should select a file') if not os.path.isfile(filename): raise FileNotFoundError(filename) self.statusBar().showMessage(self.tr("Loading...")) - self.wait() + self.processEvents() file_extension = os.path.splitext(filename)[-1] if file_extension == '.mat': file_raw = sio.loadmat(filename) @@ -137,145 +184,169 @@ def onImportRaw(self): else: variable = variables[0] + # nd.array with shape (N,) temp = file_raw[variable].flatten() - self.raw = temp - - # ------------------ save raw data as pkl file in data_set folder --------------------------------------- - address = os.path.join(self.Raw_data_path, str(uuid4()) + '.pkl') - - with open(address, 'wb') as f: - pickle.dump(temp, f) - # ----------------------------------------------------------------------------------------------------- - - elif file_extension == '.csv': - df = pd.read_csv(filename, skiprows=1) - temp = df.to_numpy() - address = os.path.join(self.Raw_data_path, str(uuid4()) + '.pkl') - self.raw = temp - with open(address, 'wb') as f: - pickle.dump(temp, f) - elif file_extension == '.tdms': - tdms_file = TdmsFile.read(filename) - i = 0 - for group in tdms_file.groups(): - df = tdms_file.object(group).as_dataframe() - variables = list(df.keys()) - i = i + 1 + elif file_extension == '.pkl': + with open(filename, 'rb') as f: + file_raw = pickle.load(f) + + variables = list(file_raw.keys()) if len(variables) > 1: variable = self.open_raw_dialog(variables) if not variable: - self.statusBar().showMessage('') + self.statusBar().showMessage(self.tr("Nothing selected")) return else: variable = variables[0] - # group = tdms_file['group name'] - # channel = group['channel name'] - # channel_data = channel[:] - # channel_properties = channel.properties - temp = np.array(df[variable]).flatten() - self.raw = temp - address = os.path.join(self.Raw_data_path, os.path.split(filename)[-1][:-5] + '.pkl') - with open(address, 'wb') as f: - pickle.dump(temp, f) + # nd.array with shape (N,) + temp = np.array(file_raw[variable]).flatten() + + # elif file_extension == '.csv': + # df = pd.read_csv(filename, skiprows=1) + # temp = df.to_numpy() + # address = os.path.join(self.Raw_data_path, str(uuid4()) + '.pkl') + # self.raw = temp + # with open(address, 'wb') as f: + # pickle.dump(temp, f) + # elif file_extension == '.tdms': + # tdms_file = TdmsFile.read(filename) + # i = 0 + # for group in tdms_file.groups(): + # df = tdms_file.object(group).as_dataframe() + # variables = list(df.keys()) + # i = i + 1 + # + # if len(variables) > 1: + # variable = self.open_raw_dialog(variables) + # if not variable: + # self.statusBar().showMessage('') + # return + # else: + # variable = variables[0] + # # group = tdms_file['group name'] + # # channel = group['channel name'] + # # channel_data = channel[:] + # # channel_properties = channel.properties + # temp = np.array(df[variable]).flatten() + # self.raw = temp + # + # address = os.path.join(self.Raw_data_path, os.path.split(filename)[-1][:-5] + '.pkl') + # with open(address, 'wb') as f: + # pickle.dump(temp, f) else: - raise TypeError(f'File type {file_extension} is not supported!') + QtWidgets.QMessageBox.critical(self, 'Error', 'Type is not supported') + return + + # check tmp + if temp.ndim != 1: + QtWidgets.QMessageBox.critical(self, 'Error', 'Variable must be a vector') + return + + self.raw = temp + address = os.path.join(self.Raw_data_path, str(uuid4()) + '.pkl') + + with open(address, 'wb') as f: + pickle.dump(temp, f) self.refreshAct.setEnabled(True) self.statusBar().showMessage(self.tr("Successfully loaded file"), 2500) - self.wait() + self.processEvents() self.statusBar().showMessage(self.tr("Plotting..."), 2500) - self.wait() + self.processEvents() self.plotRaw() + self.resetOnImportVars() + self.plot_histogram_pca.clear() self.plot_clusters_pca.clear() self.widget_waveform.clear() if self.user: self.statusBar().showMessage(self.tr("Uploading to server...")) - self.wait() + self.processEvents() res = self.user.post_raw_data(address) if res['stat']: self.statusBar().showMessage(self.tr("Uploaded"), 2500) - self.wait() + self.processEvents() else: - self.wait() + self.processEvents() def onImportDetected(self): - filename, filetype = QtWidgets.QFileDialog.getOpenFileName(self, self.tr("Open file"), os.getcwd(), - self.tr("Detected Spikes Files(*.mat *.csv *.tdms)")) - - if not filename: - return FileNotFoundError('you should select a file') - - if not os.path.isfile(filename): - raise FileNotFoundError(filename) - - self.statusBar().showMessage(self.tr("Loading...")) - self.wait() - file_extension = os.path.splitext(filename)[-1] - if file_extension == '.mat': - file_raw = sio.loadmat(filename) - variables = list(file_raw.keys()) - if '__version__' in variables: - variables.remove('__version__') - if '__header__' in variables: - variables.remove('__header__') - if '__globals__' in variables: - variables.remove('__globals__') - - if len(variables) > 1: - variable1 = self.open_detected_mat_dialog(variables) - # self.wait() - if not variable1: - self.statusBar().showMessage(self.tr(" ")) - return - variable2 = self.open_detected_time_dialog(variables) - if not variable2: - self.statusBar().showMessage(self.tr(" ")) - return - else: - return - - temp = file_raw[variable1].flatten() - self.spike_mat = temp - - temp = file_raw[variable2].flatten() - self.spike_time = temp - - elif file_extension == '.csv': - pass - - else: - pass - - self.refreshAct.setEnabled(True) - self.statusBar().showMessage(self.tr("Successfully loaded file"), 2500) - self.wait() - - self.statusBar().showMessage(self.tr("Plotting..."), 2500) - self.wait() - self.plotWaveForms() - self.plotDetectionResult() - self.plotPcaResult() - - if self.user: - self.statusBar().showMessage(self.tr("Uploading to server...")) - self.wait() - - res = self.user.post_detected_data(self.spike_mat, self.spike_time) - - if res['stat']: - self.statusBar().showMessage(self.tr("Uploaded"), 2500) - self.wait() - else: - self.wait() + pass + # filename, filetype = QtWidgets.QFileDialog.getOpenFileName(self, self.tr("Open file"), os.getcwd(), + # self.tr("Detected Spikes Files(*.mat *.csv *.tdms)")) + # + # if not filename: + # return FileNotFoundError('you should select a file') + # + # if not os.path.isfile(filename): + # raise FileNotFoundError(filename) + # + # self.statusBar().showMessage(self.tr("Loading...")) + # self.wait() + # file_extension = os.path.splitext(filename)[-1] + # if file_extension == '.mat': + # file_raw = sio.loadmat(filename) + # variables = list(file_raw.keys()) + # if '__version__' in variables: + # variables.remove('__version__') + # if '__header__' in variables: + # variables.remove('__header__') + # if '__globals__' in variables: + # variables.remove('__globals__') + # + # if len(variables) > 1: + # variable1 = self.open_detected_mat_dialog(variables) + # # self.wait() + # if not variable1: + # self.statusBar().showMessage(self.tr(" ")) + # return + # variable2 = self.open_detected_time_dialog(variables) + # if not variable2: + # self.statusBar().showMessage(self.tr(" ")) + # return + # else: + # return + # + # temp = file_raw[variable1].flatten() + # self.spike_mat = temp + # + # temp = file_raw[variable2].flatten() + # self.spike_time = temp + # + # elif file_extension == '.csv': + # pass + # + # else: + # pass + # + # self.refreshAct.setEnabled(True) + # self.statusBar().showMessage(self.tr("Successfully loaded file"), 2500) + # self.wait() + # + # self.statusBar().showMessage(self.tr("Plotting..."), 2500) + # self.wait() + # self.plotWaveForms() + # self.plotDetectionResult() + # self.plotPcaResult() + # + # if self.user: + # self.statusBar().showMessage(self.tr("Uploading to server...")) + # self.wait() + # + # res = self.user.post_detected_data(self.spike_mat, self.spike_time) + # + # if res['stat']: + # self.statusBar().showMessage(self.tr("Uploaded"), 2500) + # self.wait() + # else: + # self.wait() def onImportSorted(self): pass @@ -299,14 +370,57 @@ def open_signin_dialog(self): self.accountButton.setStatusTip("Signed In") self.logInAct.setEnabled(False) self.logOutAct.setEnabled(True) - self.saveAsAct.setEnabled(True) + # self.saveAct.setEnabled(True) + # self.saveAsAct.setEnabled(True) + self.importMenu.setEnabled(True) self.openAct.setEnabled(True) + self.exportAct.setEnabled(True) + self.runMenu.setEnabled(True) + self.visMenu.setEnabled(True) self.statusBar().showMessage(self.tr("Loading Data...")) # self.wait() self.loadDefaultProject() self.statusBar().showMessage(self.tr("Loaded."), 2500) - self.saveAct.setEnabled(True) + + @QtCore.pyqtSlot() + def open_file_dialog_server(self): + dialog = sever_dialog(self.user) + if dialog.exec_() == QtWidgets.QDialog.Accepted: + self.refreshAct.setEnabled(True) + self.statusBar().showMessage(self.tr("Successfully loaded file"), 2500) + self.processEvents() + + self.statusBar().showMessage(self.tr("Plotting..."), 2500) + self.processEvents() + self.plotRaw() + + self.resetOnImportVars() + + self.plot_histogram_pca.clear() + self.plot_clusters_pca.clear() + self.widget_waveform.clear() + + @QtCore.pyqtSlot() + def open_export_dialog(self): + dialog = export_results(self.user) + if dialog.exec_() == QtWidgets.QDialog.Accepted: + data_dict = dialog.data_dict + if dialog.checkSpikeTime: + data_dict['SpikeTime'] = self.spike_time + if dialog.checkClusters: + data_dict['Cluster'] = self.clusters + + filename = QtWidgets.QFileDialog.getSaveFileName(self, "Select Directory", f'./ross_result.{dialog.type}')[ + 0] + if dialog.type == 'mat': + sio.savemat(filename, data_dict) + elif dialog.type == 'pickle': + with open(filename, 'wb') as f: + pickle.dump(data_dict, f) + else: + QtWidgets.QMessageBox.critical(self, 'Error', 'Type not supported!') + QtWidgets.QMessageBox.information(self, 'Save', 'Saved Successfully!') def open_server_dialog(self): dialog = server_form(server_text=self.url) @@ -336,7 +450,7 @@ def open_project_dialog(self, projects): if dialog.exec_() == QtWidgets.QDialog.Accepted: self.current_project = dialog.comboBox.currentText() self.statusBar().showMessage(self.tr("Loading project...")) - self.wait() + self.processEvents() self.user.load_project(self.current_project) self.loadDefaultProject() self.statusBar().showMessage(self.tr("Loaded."), 2500) @@ -357,19 +471,20 @@ def open_save_as_dialog(self): def onSignOut(self): res = self.user.sign_out() if res['stat']: - self.raw = None - self.spike_mat = None - self.spike_time = None - self.user_name = None + self.resetOnSignOutVars() + self.accountButton.setIcon(QtGui.QIcon(icon_path + "unverified.png")) self.accountButton.setMenu(None) self.accountButton.setText('') self.accountButton.setStatusTip("Sign In/Up") self.logInAct.setEnabled(True) self.logOutAct.setEnabled(False) - self.user = None - self.saveAct.setEnabled(False) - self.saveAsAct.setEnabled(False) + # self.saveAct.setEnabled(False) + # self.saveAsAct.setEnabled(False) + self.importMenu.setEnabled(False) + self.exportAct.setEnabled(False) + self.runMenu.setEnabled(False) + self.visMenu.setEnabled(False) self.widget_raw.clear() self.plot_histogram_pca.clear() self.plot_clusters_pca.clear() @@ -387,7 +502,7 @@ def onSaveAs(self): def onSave(self): self.statusBar().showMessage(self.tr("Saving...")) - self.wait() + self.processEvents() res = self.user.save_project(self.current_project) if res['stat']: self.statusBar().showMessage(self.tr("Saved."), 2500) @@ -401,7 +516,7 @@ def onDetect(self): return config_detect = self.read_config_detect() self.statusBar().showMessage(self.tr("Detection Started...")) - self.wait() + self.processEvents() res = self.user.start_detection(config_detect) if res['stat']: @@ -410,30 +525,23 @@ def onDetect(self): if res['stat']: self.spike_mat = res['spike_mat'] self.spike_time = res['spike_time'] + self.pca_spikes = res['pca_spikes'] + self.inds = res['inds'] self.statusBar().showMessage(self.tr("Plotting..."), 2500) - self.wait() - self.plotDetectionResult() - self.plotPcaResult() + self.processEvents() + self.plotHistogramPCA() + self.plotPCA2D() self.plotWaveForms() def manualResort(self, selected_clusters): config_sort = self.read_config_sort() self.statusBar().showMessage(self.tr("Manual ReSorting Started...")) - self.wait() + self.processEvents() res = self.user.start_Resorting(config_sort, self.clusters_tmp, selected_clusters) if res['stat']: self.statusBar().showMessage(self.tr("Manual ReSorting Done."), 2500) self.clusters_tmp = np.array(res['clusters']) - - self.UpdatedClusterIndex() - self.statusBar().showMessage(self.tr("Clusters Waveforms Updated..."), 2500) - self.wait() - self.updateplotWaveForms(self.clusters_tmp) - self.wait() - self.update_plotRaw(self.clusters_tmp) - self.wait() - self.updateManualClusterList(self.clusters_tmp) else: self.statusBar().showMessage(self.tr("Manual ReSorting got error")) @@ -443,30 +551,32 @@ def onSort(self): return config_sort = self.read_config_sort() self.statusBar().showMessage(self.tr("Sorting Started...")) - self.wait() + self.processEvents() res = self.user.start_sorting(config_sort) if res['stat']: self.statusBar().showMessage(self.tr("Sorting Done."), 2500) res = self.user.get_sorting_result() if res['stat']: self.clusters = res['clusters'] + self.cluster_time_vec = res["cluster_time_vec"] self.clusters_init = self.clusters.copy() self.clusters_tmp = self.clusters.copy() self.statusBar().showMessage(self.tr("Clusters Waveforms..."), 2500) - self.wait() + self.processEvents() self.updateplotWaveForms(self.clusters) - self.wait() - self.update_plotRaw(self.clusters) - self.wait() + self.processEvents() + self.updatePlotRaw() + self.processEvents() self.updateManualClusterList(self.clusters_tmp) - self.plotDetectionResult() - self.plotPcaResult() + self.plotHistogramPCA() + self.plotPCA2D() else: self.statusBar().showMessage(self.tr("Sorting got error")) def plotRaw(self): curve = HDF5Plot() + curve.setAPI(self.user) curve.setHDF5(self.raw) self.widget_raw.clear() self.widget_raw.addItem(curve) @@ -474,27 +584,27 @@ def plotRaw(self): self.widget_raw.showGrid(x=True, y=True) self.widget_raw.setMouseEnabled(y=False) - def update_plotRaw(self, clusters): - data = self.raw - colors = self.colors - num_of_clusters = self.number_of_clusters - spike_time = self.spike_time - time = SegmentedTime(spike_time, clusters) - multi_curve = MultiColoredCurve(data, time, num_of_clusters, len(data)) + def updatePlotRaw(self): + curve = HDF5Plot() + curve.setAPI(self.user) + curve.setHDF5(self.raw) self.widget_raw.clear() - for i in range(num_of_clusters + 1): - curve = HDF5Plot() - if i == num_of_clusters: - color = (255, 255, 255) - pen = pyqtgraph.mkPen(color=color) - else: - color = colors[i] + self.widget_raw.showGrid(x=True, y=True) + self.widget_raw.setMouseEnabled(y=False) + self.widget_raw.addItem(curve) + + if self.cluster_time_vec is not None: + for i, i_cluster in enumerate(np.unique(self.cluster_time_vec)): + if i_cluster == 0: + continue + color = self.colors[i - 1] pen = pyqtgraph.mkPen(color=color) - curve.setHDF5(multi_curve.curves[str(i)], pen) - self.widget_raw.addItem(curve) - self.widget_raw.setXRange(0, 10000) - self.widget_raw.showGrid(x=True, y=True) - self.widget_raw.setMouseEnabled(y=False) + curve = HDF5Plot() + curve.setAPI(self.user) + curve.setHDF5(self.raw, pen) + curve.setCluster(self.cluster_time_vec == i_cluster) + self.widget_raw.addItem(curve) + self.widget_raw.setXRange(0, 10000) def onVisPlot(self): self.widget_visualizations() @@ -537,7 +647,7 @@ def hsv_to_rgb(self, h, s, v): if i == 4: return (t, p, v) if i == 5: return (v, p, q) - def distin_color(self, number_of_colors): + def distinctColors(self, number_of_colors): colors = [] golden_ratio_conjugate = 0.618033988749895 # h = np.random.rand(1)[0] @@ -547,26 +657,25 @@ def distin_color(self, number_of_colors): h = h % 1 colors.append(self.hsv_to_rgb(h, 0.99, 0.99)) - return np.array(colors) + return np.array(colors, dtype=int) def updateplotWaveForms(self, clusters_=None): if clusters_ is None: clusters_ = self.clusters_tmp.copy() if self.clusters_tmp is not None: - spike_mat = self.spike_mat[self.clusters_tmp != -1, :] + spike_mat = self.spike_mat[self.clusters_tmp[self.inds,] != -1, :] else: spike_mat = self.spike_mat - clusters = clusters_[clusters_ != -1] + clusters = clusters_[self.inds][clusters_[self.inds] != -1] un = np.unique(clusters) self.number_of_clusters = len(un[un >= 0]) - self.colors = self.distin_color(self.number_of_clusters) spike_clustered = dict() for i in range(self.number_of_clusters): - spike_clustered[i] = spike_mat[clusters == i] + spike_clustered[i] = spike_mat[clusters == i, :] self.widget_waveform.clear() self.widget_waveform.showGrid(x=True, y=True) @@ -575,25 +684,15 @@ def updateplotWaveForms(self, clusters_=None): for i in range(self.number_of_clusters): avg = np.average(spike_clustered[i], axis=0) color = self.colors[i] - # selected_spike = spike_clustered[i][np.sum(np.power(spike_clustered[i] - avg, 2), axis=1) < - # np.sum(np.power(avg, 2)) * 0.2] - # if self.saveManualFlag or self.plotManualFlag: if len(spike_clustered[i]) > 100: ind = np.arange(spike_clustered[i].shape[0]) np.random.shuffle(ind) spike = spike_clustered[i][ind[:100], :] else: spike = spike_clustered[i] - - # else: - # if len(selected_spike) > 100: - # ind = np.arange(selected_spike.shape[0]) - # np.random.shuffle(ind) - # spike = selected_spike[ind[:100], :] - # else: - # spike = selected_spike - + if spike.shape[0] == 0: + continue x = np.empty(np.shape(spike)) x[:] = np.arange(np.shape(spike)[1])[np.newaxis] @@ -607,55 +706,38 @@ def updateplotWaveForms(self, clusters_=None): self.widget_waveform.autoRange() - def create_sub_base(self): - self.number_of_clusters = np.shape(np.unique(self.clusters))[0] - # self.number_of_clusters = len(np.unique(self.clusters)) - if self.number_of_clusters % 3 != 0: - nrow = int(self.number_of_clusters / 3) + 1 - else: - nrow = int(self.number_of_clusters / 3) - - self.sub_base = str(nrow) + str(3) - def onPlotClusterWave(self): try: - self.create_sub_base() + # self.create_sub_base() number_of_clusters = self.number_of_clusters colors = self.colors spike_clustered = dict() for i in range(number_of_clusters): - spike_clustered[i] = self.spike_mat[self.clusters_tmp == i] + spike_clustered[i] = self.spike_mat[self.clusters_tmp[self.inds] == i] - figure = MatPlotFigures('Clusters Waveform', number_of_clusters, width=10, height=6, dpi=100, - subplot_base=self.sub_base) + figure = MatPlotFigures('Clusters Waveform', number_of_clusters, width=10, height=6, dpi=100) for i, ax in enumerate(figure.axes): + for spike in spike_clustered[i]: + ax.plot(spike, color=tuple(colors[i] / 255), linewidth=1, alpha=0.25) avg = np.average(spike_clustered[i], axis=0) - selected_spike = spike_clustered[i][np.sum(np.power(spike_clustered[i] - avg, 2), axis=1) < - np.sum(np.power(avg, 2)) * 0.2] - - # selected_spike = spike_clustered[i][:5] - ax.plot(avg, color='red', linewidth=3) - for spike in selected_spike[:100]: - ax.plot(spike, color=tuple(colors[i] / 255), linewidth=1, alpha=0.125) + ax.plot(avg, color='red', linewidth=2) ax.set_title('Cluster {}'.format(i + 1)) plt.tight_layout() plt.show() except: - pass + print(traceback.format_exc()) def onPlotLiveTime(self): try: - self.create_sub_base() number_of_clusters = self.number_of_clusters colors = self.colors spike_clustered_time = dict() for i in range(number_of_clusters): spike_clustered_time[i] = self.spike_time[self.clusters == i] - figure = MatPlotFigures('LiveTime', number_of_clusters, width=10, height=6, dpi=100, - subplot_base=self.sub_base) + figure = MatPlotFigures('LiveTime', number_of_clusters, width=10, height=6, dpi=100) for i, ax in enumerate(figure.axes): ax.hist(spike_clustered_time[i], bins=100, color=tuple(colors[i] / 255)) ax.set_title('Cluster {}'.format(i + 1)) @@ -663,11 +745,10 @@ def onPlotLiveTime(self): plt.show() except: - pass + print(traceback.format_exc()) def onPlotIsi(self): try: - self.create_sub_base() number_of_clusters = self.number_of_clusters colors = self.colors spike_clustered_time = dict() @@ -678,7 +759,7 @@ def onPlotIsi(self): tmp1 = spike_clustered_time[i][1:].copy() spike_clustered_delta[i] = tmp1 - tmp2 - figure = MatPlotFigures('ISI', number_of_clusters, width=10, height=6, dpi=100, subplot_base=self.sub_base) + figure = MatPlotFigures('ISI', number_of_clusters, width=10, height=6, dpi=100) for i, ax in enumerate(figure.axes): gamma = stats.gamma @@ -704,84 +785,48 @@ def onPlot3d(self): def Plot3d(self): try: - self.plot_3d.setCameraPosition(distance=30) axis1 = self.axis1ComboBox.currentIndex() axis2 = self.axis2ComboBox.currentIndex() axis3 = self.axis3ComboBox.currentIndex() number_of_clusters = len(np.unique(self.clusters)) # Prepration - pca = decom.PCA(n_components=3) - pca_spikes = pca.fit_transform(self.spike_mat) - pca1 = pca_spikes[:, 0] - pca2 = pca_spikes[:, 1] - pca3 = pca_spikes[:, 2] - spike_time = np.squeeze(self.spike_time / 100000) + + pca_spikes = self.pca_spikes + pca1 = pca_spikes[self.inds, 0] + pca2 = pca_spikes[self.inds, 1] + pca3 = pca_spikes[self.inds, 2] + spike_time = np.squeeze(self.spike_time[self.inds]) p2p = np.squeeze(np.abs(np.amax(self.spike_mat, axis=1) - np.amin(self.spike_mat, axis=1))) duty = np.squeeze(np.abs(np.argmax(self.spike_mat, axis=1) - np.argmin(self.spike_mat, axis=1)) / 5) - gx = gl.GLGridItem() - gx.rotate(90, 0, 1, 0) - gx.setSize(15, 15) - gx.translate(-7.5, 0, 0) - self.plot_3d.addItem(gx) - gy = gl.GLGridItem() - gy.rotate(90, 1, 0, 0) - gy.setSize(15, 15) - gy.translate(0, -7.5, 0) - self.plot_3d.addItem(gy) - gz = gl.GLGridItem() - gz.setSize(15, 15) - gz.translate(0, 0, -7.5) - self.plot_3d.addItem(gz) - - mode_flag = False mode_list = [pca1, pca2, pca3, spike_time, p2p, duty] - if (axis1 != axis2) and (axis1 != axis3) and (axis2 != axis3): - pos = np.array((mode_list[axis1], mode_list[axis2], mode_list[axis3])).T - mode_flag = True - - if mode_flag: - colors = self.colors - items = self.plot_3d.items.copy() - for it in items: - if type(it) == pyqtgraph.opengl.items.GLScatterPlotItem.GLScatterPlotItem: - self.plot_3d.removeItem(it) - - for i in range(number_of_clusters): - pos_cluster = pos[self.clusters == i, :] - avg = np.average(pos_cluster, axis=0) - selected_pos = pos_cluster[np.sum(np.power((pos[self.clusters == i, :] - avg), 2), axis=1) < - 0.05 * np.amax(np.sum(np.power((pos[self.clusters == i, :] - avg), 2), - axis=1)), :] - ind = np.arange(selected_pos.shape[0]) - np.random.shuffle(ind) - scattered_pos = selected_pos[ind[:300], :] - color = np.zeros([np.shape(scattered_pos)[0], 4]) - color[:, 0] = colors[i][0] / 255 - color[:, 1] = colors[i][1] / 255 - color[:, 2] = colors[i][2] / 255 - color[:, 3] = 1 - self.plot_3d.addItem(gl.GLScatterPlotItem(pos=scattered_pos, size=3, color=color)) - else: - print('Same Axis Error') + a1, a2, a3 = mode_list[axis1], mode_list[axis2], mode_list[axis3] + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + for i in range(number_of_clusters): + ax.scatter(a1[self.clusters_tmp[self.inds] == i], a2[self.clusters_tmp[self.inds] == i], + a3[self.clusters_tmp[self.inds] == i], c='#%02x%02x%02x' % tuple(self.colors[i])) + ax.set_xlabel(self.axis1ComboBox.currentText()) + ax.set_ylabel(self.axis2ComboBox.currentText()) + ax.set_zlabel(self.axis3ComboBox.currentText()) + plt.show() except: - pass + print(traceback.format_exc()) def close3D(self): self.subwindow_3d.setVisible(False) - def plotDetectionResult(self): + def plotHistogramPCA(self): try: - if self.clusters_tmp is not None: - spike_mat = self.spike_mat[self.clusters_tmp != -1, :] + self.plot_histogram_pca.clear() + if self.clusters_tmp is None: + pca_spikes = self.pca_spikes else: - spike_mat = self.spike_mat + pca_spikes = self.pca_spikes[self.clusters_tmp != -1] - self.plot_histogram_pca.clear() - pca = decom.PCA(n_components=2) - pca_spikes = pca.fit_transform(spike_mat) hist, xedges, yedges = np.histogram2d(pca_spikes[:, 0], pca_spikes[:, 1], bins=512) x_range = xedges[-1] - xedges[0] @@ -812,22 +857,19 @@ def plotDetectionResult(self): self.plot_histogram_pca.show() except: - pass + print(traceback.format_exc()) - def plotPcaResult(self): + def plotPCA2D(self): self.plot_clusters_pca.clear() if self.clusters_tmp is not None: - spike_mat = self.spike_mat[self.clusters_tmp != -1, :] clusters_ = self.clusters_tmp[self.clusters_tmp != -1] + x_data = self.pca_spikes[self.clusters_tmp != -1, 0] + y_data = self.pca_spikes[self.clusters_tmp != -1, 1] else: - spike_mat = self.spike_mat clusters_ = None - - pca = decom.PCA(n_components=2) - pca_spikes = pca.fit_transform(spike_mat) - x_data = pca_spikes[:, 0] - y_data = pca_spikes[:, 1] + x_data = self.pca_spikes[:, 0] + y_data = self.pca_spikes[:, 1] scatter = pyqtgraph.ScatterPlotItem() @@ -841,8 +883,7 @@ def plotPcaResult(self): else: un = np.unique(clusters_) n_clusters = len(un[un >= 0]) - print("clusters", n_clusters) - new_colors = self.distin_color(n_clusters) + new_colors = self.distinctColors(n_clusters) for i in range(n_clusters): xx = x_data[clusters_ == i] yy = y_data[clusters_ == i] @@ -876,8 +917,7 @@ def onDetect3D(self): self.subwindow_detect3d.setVisible(True) try: - pca = decom.PCA(n_components=2) - pca_spikes = pca.fit_transform(self.spike_mat) + pca_spikes = self.pca_spikes hist, xedges, yedges = np.histogram2d(pca_spikes[:, 0], pca_spikes[:, 1], bins=512) gx = gl.GLGridItem() @@ -910,14 +950,13 @@ def onDetect3D(self): self.plot_detect3d.addItem(bg) except: - pass + print(traceback.format_exc()) def closeDetect3D(self): self.subwindow_detect3d.setVisible(False) def onDetect3D1(self): - pca = decom.PCA(n_components=2) - pca_spikes = pca.fit_transform(self.spike_mat) + pca_spikes = self.pca_spikes hist, xedges, yedges = np.histogram2d(pca_spikes[:, 0], pca_spikes[:, 1], bins=512) xpos, ypos = np.meshgrid(xedges[:-1] + xedges[1:], yedges[:-1] + yedges[1:]) - (xedges[1] - xedges[0]) xpos = xpos.flatten() * 1. / 2 @@ -978,10 +1017,6 @@ def onWatchdogEvent(self): """Perform checks in regular intervals.""" self.mdiArea.checkTimestamps() - def wait(self, duration=2.0): - QtWidgets.QApplication.processEvents() - time.sleep(duration) - def UpdatedClusterIndex(self): cl_ind = np.unique(self.clusters_tmp) cnt = 0 @@ -989,9 +1024,10 @@ def UpdatedClusterIndex(self): if not ind == -1: self.clusters_tmp[self.clusters_tmp == ind] = cnt cnt += 1 + self.onSaveManualSorting() def manualPreparingSorting(self, temp): - if len(self.tempList) == 5: + if len(self.tempList) == 10: self.tempList.pop(0) self.tempList.append(temp) self.plotManualFlag = True @@ -1002,7 +1038,7 @@ def updateManualClusterList(self, cluster_temp): self.listWidget.clear() # n_clusters = np.shape(np.unique(temp))[0] n_clusters = len(np.unique(cluster_temp)[np.unique(cluster_temp) >= 0]) - colors = self.distin_color(n_clusters) + colors = self.distinctColors(n_clusters) for i in range(n_clusters): item = QtWidgets.QListWidgetItem("Cluster {} ({:4.2f} %)".format(i + 1, (cluster_temp == i).mean() * 100)) pixmap = QPixmap(50, 50) @@ -1016,81 +1052,79 @@ def onActManualSorting(self): act = self.manualActWidget.currentItem() clusters = self.listWidget.selectedIndexes() selected_clusters = [cl.row() for cl in clusters] + self.manualPreparingSorting(self.clusters_tmp.copy()) if act.text() == 'Merge': try: self.mergeManual(selected_clusters) - self.manualPreparingSorting(self.clusters_tmp.copy()) - self.updateManualClusterList(self.clusters_tmp.copy()) + self.plotFlagClusterBased = True except: - print("an error accrued in manual Merge") print(traceback.format_exc()) elif act.text() == 'Remove': try: self.removeManual(selected_clusters) - self.manualPreparingSorting(self.clusters_tmp.copy()) - self.updateManualClusterList(self.clusters_tmp.copy()) - self.plotHistFlag = True + self.plotFlagHist = True + self.plotFlagClusterBased = True except: print("an error accrued in manual Remove") print(traceback.format_exc()) + elif act.text() == "Resort": + try: + self.manualResort(selected_clusters) + self.plotFlagHist = True + self.plotFlagClusterBased = True + except: + print("an error accrued in manual resort") + print(traceback.format_exc()) + elif act.text() == 'Assign to nearest': try: self.assignManual() - self.manualPreparingSorting(self.clusters_tmp.copy()) - except: print("an error accrued in manual Assign to nearest") print(traceback.format_exc()) elif act.text() == "PCA Remove": try: self.pca_manual = "Remove" - self.OnPcaRemove() - self.manualPreparingSorting(self.clusters_tmp.copy()) - self.plotHistFlag = True - + self.OnPcaManualAct() except: - print("an error accrued in manual pcaRemove") print(traceback.format_exc()) - pass elif act.text() == "PCA Group": try: self.pca_manual = "Group" - self.OnPcaRemove() - self.manualPreparingSorting(self.clusters_tmp.copy()) - + self.OnPcaManualAct() except: - print("an error accrued in manual pcaGroup") print(traceback.format_exc()) - pass - elif act.text() == "Resort": - try: - self.manualResort(selected_clusters) - self.manualPreparingSorting(self.clusters_tmp.copy()) - self.updateManualClusterList(self.clusters_tmp.copy()) - except: - print("an error accrued in manual resort") - print(traceback.format_exc()) + self.UpdatedClusterIndex() + self.updateManualClusterList(self.clusters_tmp) if self.autoPlotManualCheck.isChecked(): - self.onPlotManualSorting() + self.updateFigures() except: print(traceback.format_exc()) self.statusBar().showMessage(self.tr("an error accrued in manual act !"), 2000) - def onPlotManualSorting(self): - # update 2d pca plot - self.plotPcaResult() - if self.plotHistFlag: - self.plotDetectionResult() - self.plotHistFlag = False - - if self.plotManualFlag: - self.updateplotWaveForms(self.clusters_tmp.copy()) - self.statusBar().showMessage(self.tr("Updating Spikes Waveforms...")) - self.wait() - self.update_plotRaw(self.clusters_tmp.copy()) - self.statusBar().showMessage(self.tr("Updating Raw Data Waveforms..."), 2000) - self.plotManualFlag = False + # def onPlotManualSorting(self): + # # update 2d pca plot + # self.plotPCA2D() + # if self.plotHistFlag: + # self.plotHistogramPCA() + # self.plotHistFlag = False + # + # if self.plotManualFlag: + # self.updateplotWaveForms(self.clusters_tmp.copy()) + # self.statusBar().showMessage(self.tr("Updating Spikes Waveforms...")) + # self.wait() + # self.update_plotRaw() + # self.statusBar().showMessage(self.tr("Updating Raw Data Waveforms..."), 2000) + # self.plotManualFlag = False + + def updateFigures(self): + if self.plotFlagHist: + self.plotHistogramPCA() + if self.plotFlagClusterBased: + self.plotPCA2D() + self.updatePlotRaw() + self.updateplotWaveForms() def onResetManualSorting(self): @@ -1098,16 +1132,16 @@ def onResetManualSorting(self): self.tempList = [] # update 2d pca plot - self.plotDetectionResult() - self.plotPcaResult() + self.plotHistogramPCA() + self.plotPCA2D() # update cluster list - self.updateManualClusterList(self.clusters_tmp.copy()) + self.updateManualClusterList(self.clusters_tmp) self.updateplotWaveForms(self.clusters_init.copy()) self.statusBar().showMessage(self.tr("Resetting Spikes Waveforms...")) - self.wait() + self.processEvents() - self.update_plotRaw(self.clusters_init.copy()) + self.updatePlotRaw() self.statusBar().showMessage(self.tr("Resetting Raw Data Waveforms..."), 2000) self.resetManualFlag = False @@ -1115,52 +1149,37 @@ def onResetManualSorting(self): self.saveManualFlag = False def onSaveManualSorting(self): - if self.saveManualFlag: - self.clusters = self.clusters_tmp.copy() - # self.number_of_clusters = np.shape(np.unique(self.clusters))[0] - self.statusBar().showMessage(self.tr("Save Clustering Results...")) - self.wait() - - res = self.user.save_sort_results(self.clusters) - if res['stat']: - self.wait() - self.statusBar().showMessage(self.tr("Updating Spikes Waveforms...")) - self.updateplotWaveForms(self.clusters) - self.wait() - self.statusBar().showMessage(self.tr("Updating Raw Data Waveforms..."), 2000) - self.update_plotRaw(self.clusters) - self.wait() - self.statusBar().showMessage(self.tr("Saving Done.")) - self.updateManualClusterList(self.clusters_tmp) - # self.updateManualSortingView() - else: - self.statusBar().showMessage(self.tr("An error occurred in saving!..."), 2000) + self.clusters = self.clusters_tmp.copy() + self.statusBar().showMessage(self.tr("Save Clustering Results...")) + self.processEvents() - self.saveManualFlag = False + res = self.user.save_sort_results(self.clusters) + if not res['stat']: + self.statusBar().showMessage(self.tr("An error occurred in saving!..."), 2000) def onUndoManualSorting(self): try: - self.tempList.pop() - if len(self.tempList) != 0: - self.clusters_tmp = self.tempList[-1] + + if len(self.tempList) > 0: + self.clusters_tmp = self.tempList.pop() else: - self.clusters_tmp = self.clusters_init.copy() + QtWidgets.QMessageBox.information(self, "ROSS", "No More Undo :(") # update cluster list self.updateManualClusterList(self.clusters_tmp) # update 2d pca plot if self.autoPlotManualCheck.isChecked(): - self.plotDetectionResult() - self.plotPcaResult() + self.plotHistogramPCA() + self.plotPCA2D() self.updateplotWaveForms(self.clusters_tmp.copy()) self.statusBar().showMessage(self.tr("Undoing Spikes Waveforms..."), 2000) - self.wait() + self.processEvents() self.statusBar().showMessage(self.tr("Undoing Raw Data Waveforms..."), 2000) - self.update_plotRaw(self.clusters_tmp.copy()) + self.updatePlotRaw(self.clusters_tmp.copy()) self.statusBar().showMessage(self.tr("Undoing Done!"), 2000) - self.wait() + self.processEvents() except: self.statusBar().showMessage(self.tr("There is no manual act for undoing!"), 2000) @@ -1168,58 +1187,42 @@ def onUndoManualSorting(self): def mergeManual(self, selected_clusters): if len(selected_clusters) >= 2: self.statusBar().showMessage(self.tr("Merging...")) - self.wait() + self.processEvents() sel_cl = selected_clusters[0] for ind in selected_clusters: self.clusters_tmp[self.clusters_tmp == ind] = sel_cl - self.UpdatedClusterIndex() - self.statusBar().showMessage(self.tr("...Merging Done!"), 2000) - self.wait() + self.processEvents() else: self.statusBar().showMessage(self.tr("For Merging you should select at least two clusters..."), 2000) def removeManual(self, selected_clusters): if len(selected_clusters) != 0: self.statusBar().showMessage(self.tr("Removing...")) - self.wait() + self.processEvents() for sel_cl in selected_clusters: self.clusters_tmp[self.clusters_tmp == sel_cl] = - 1 - self.UpdatedClusterIndex() self.statusBar().showMessage(self.tr("...Removing Done!"), 2000) else: self.statusBar().showMessage(self.tr("For Removing you should select at least one clusters..."), 2000) def assignManual(self): self.subwindow_assign.setVisible(True) - try: - n_clusters = self.number_of_clusters - except: - n_clusters = 0 - # try: - # self.listSourceWidget.close() - # self.listTargetsWidget.close() - # self.assign_button.close() - # self.assign_close_button.close() - # except: - # pass + n_clusters = len(np.unique(self.clusters_tmp)) + try: self.listSourceWidget.clear() - for i in range(n_clusters): - item = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) - self.listSourceWidget.addItem(item) - self.listSourceWidget.setCurrentItem(item) - self.listTargetsWidget.clear() for i in range(n_clusters): - item_target = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) - self.listTargetsWidget.addItem(item_target) - self.listTargetsWidget.setCurrentItem(item_target) + item1 = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) + item2 = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) + self.listSourceWidget.addItem(item1) + self.listTargetsWidget.addItem(item2) except: - pass + print(traceback.format_exc()) def onAssignManualSorting(self): source = self.listSourceWidget.currentItem() @@ -1232,54 +1235,48 @@ def onAssignManualSorting(self): try: if len(target_clusters) >= 1: self.statusBar().showMessage(self.tr("Assigning Source Cluster to Targets...")) - self.wait() + self.processEvents() - source_spikes = self.spike_mat[self.clusters_tmp == source_cluster] + source_spikes = self.pca_spikes[self.clusters_tmp == source_cluster] source_ind = np.nonzero(self.clusters_tmp == source_cluster) target_avg = np.zeros((len(target_clusters), source_spikes.shape[1])) for it, target in enumerate(target_clusters): - target_avg[it, :] = np.average(self.spike_mat[self.clusters_tmp == target], axis=0) + target_avg[it, :] = np.average(self.pca_spikes[self.clusters_tmp == target], axis=0) # TODO: check different nearest_neighbors algorithms nbrs = NearestNeighbors(n_neighbors=1).fit(target_avg) indices = nbrs.kneighbors(source_spikes, return_distance=False) self.clusters_tmp[source_ind] = np.array(target_clusters)[indices.squeeze()] - self.UpdatedClusterIndex() - self.statusBar().showMessage(self.tr("...Assigning to Nearest Clusters Done!"), 2000) - self.wait() + self.processEvents() self.listSourceWidget.clear() self.listTargetsWidget.clear() - - for i in range(len(np.unique(self.clusters_tmp))): - item = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) - self.listSourceWidget.addItem(item) - - for i in range(len(np.unique(self.clusters_tmp))): - item_target = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) - self.listTargetsWidget.addItem(item_target) - - self.manualPreparingSorting(self.clusters_tmp.copy()) + # TODO: check for all bugs like this + # for i in range(len(np.unique(self.clusters_tmp))): + cu = np.unique(self.clusters_tmp) + for i in range(len(cu[cu != -1])): + item1 = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) + item2 = QtWidgets.QListWidgetItem("Cluster %i" % (i + 1)) + self.listSourceWidget.addItem(item1) + self.listTargetsWidget.addItem(item2) + self.UpdatedClusterIndex() + self.updateManualClusterList(self.clusters_tmp) + if self.autoPlotManualCheck.isChecked(): + self.plotFlagClusterBased = True + self.updateFigures() else: self.statusBar().showMessage(self.tr("You Should Choose One Source and at least One Target..."), 2000) - - self.updateManualClusterList(self.clusters_tmp) except: - pass + print(traceback.format_exc()) def closeAssign(self): self.subwindow_assign.setVisible(False) - def OnPcaRemove(self): - if self.clusters_tmp is not None: - spike_mat = self.spike_mat[self.clusters_tmp != -1, :] - else: - spike_mat = self.spike_mat + def OnPcaManualAct(self): - pca = decom.PCA(n_components=2) - self.pca_spikes = pca.fit_transform(spike_mat) + self.pca_spikes = self.pca_spikes hist, xedges, yedges = np.histogram2d(self.pca_spikes[:, 0], self.pca_spikes[:, 1], bins=512) x_range = xedges[-1] - xedges[0] @@ -1296,7 +1293,7 @@ def OnPcaRemove(self): if image._renderRequired: image.render() image.qimage = image.qimage.transformed(QTransform().scale(1, -1)) - image.save('test.png') + image.save('.tmp/test.png') self.PCAManualResetButton() self.subwindow_pca_manual.setVisible(True) @@ -1305,13 +1302,8 @@ def OnPcaRemove(self): def PCAManualDoneButton(self): points = self.subwindow_pca_manual.widget().points.copy() - if self.clusters_tmp is not None: - spike_mat = self.spike_mat[self.clusters_tmp != -1, :] - else: - spike_mat = self.spike_mat - pca = decom.PCA(n_components=2) - self.pca_spikes = pca.fit_transform(spike_mat) + self.pca_spikes = self.pca_spikes hist, xedges, yedges = np.histogram2d(self.pca_spikes[:, 0], self.pca_spikes[:, 1], bins=512) x_range = xedges[-1] - xedges[0] y_range = yedges[-1] - yedges[0] @@ -1328,12 +1320,15 @@ def PCAManualDoneButton(self): poly = Polygon(norm_points) if self.pca_manual == "Remove": + self.plotFlagClusterBased = True + self.plotFlagHist = True for i in range(len(self.pca_spikes)): p1 = Point(self.pca_spikes[i]) if not (p1.within(poly)): self.clusters_tmp[i] = -1 elif self.pca_manual == "Group": + self.plotFlagClusterBased = True clusters = self.clusters_tmp.copy() un = np.unique(clusters[clusters != -1]) num_of_clusters = len(un[un >= 0]) @@ -1347,17 +1342,17 @@ def PCAManualDoneButton(self): else: print("pca manual flag is not define") - self.UpdatedClusterIndex() self.subwindow_pca_manual.widget().reset() self.subwindow_pca_manual.setVisible(False) - self.updateplotWaveForms(self.clusters_tmp.copy()) - self.updateManualClusterList(self.clusters_tmp.copy()) - self.plotDetectionResult() - self.plotPcaResult() + + self.UpdatedClusterIndex() + self.updateManualClusterList(self.clusters_tmp) + if self.autoPlotManualCheck.isChecked(): + self.updateFigures() def PCAManualResetButton(self): self.subwindow_pca_manual.widget().reset() - self.image = QPixmap('test.png') + self.image = QPixmap('.tmp/test.png') self.label_pca_manual.setMaximumWidth(self.image.width()) self.label_pca_manual.setMaximumHeight(self.image.height()) self.label_pca_manual.resize(self.image.width(), self.image.height()) @@ -1389,40 +1384,42 @@ def loadDefaultProject(self): res = self.user.get_raw_data() if res['stat']: - with open(res['raw'], 'rb') as f: - new_data = pickle.load(f) + if 'raw' in res: + with open(res['raw'], 'rb') as f: + new_data = pickle.load(f) - self.raw = new_data - self.statusBar().showMessage(self.tr("Plotting..."), 2500) - self.wait() + self.raw = new_data + self.statusBar().showMessage(self.tr("Plotting..."), 2500) + self.processEvents() flag_raw = True res = self.user.get_detection_result() if res['stat']: self.spike_mat = res['spike_mat'] self.spike_time = res['spike_time'] - self.plotWaveForms() - self.plotDetectionResult() - self.plotPcaResult() + self.pca_spikes = res['pca_spikes'] + self.inds = res['inds'] + self.plotFlagHist = True # done res = self.user.get_sorting_result() if res['stat']: self.clusters_init = res['clusters'] + self.cluster_time_vec = res['cluster_time_vec'] self.clusters = self.clusters_init.copy() self.clusters_tmp = self.clusters_init.copy() - self.updateplotWaveForms(self.clusters_init.copy()) - self.wait() flag_update_raw = True self.updateManualClusterList(self.clusters_tmp) - self.plotPcaResult() + self.plotFlagClusterBased = True if flag_raw: if flag_update_raw: - self.update_plotRaw(self.clusters) + self.updatePlotRaw() else: self.plotRaw() + self.updateFigures() + def read_config_detect(self): config = dict() config['filter_type'] = self.filterType.currentText() diff --git a/ross_ui/controller/matplot_figures.py b/ross_ui/controller/matplot_figures.py index b9d8c02..68a56f4 100644 --- a/ross_ui/controller/matplot_figures.py +++ b/ross_ui/controller/matplot_figures.py @@ -1,17 +1,13 @@ +import math + import matplotlib.pyplot as plt class MatPlotFigures: - def __init__(self, fig_title, number_of_clusters, width=10, height=6, dpi=100, subplot_base='13'): + def __init__(self, fig_title, number_of_clusters, width=10, height=6, dpi=100): fig = plt.figure(figsize=(width, height), dpi=dpi) - fig.canvas.set_window_title(fig_title) + fig.canvas.manager.set_window_title(fig_title) self.axes = [] - for i in range(int(subplot_base) // 10): - if i == (int(subplot_base) // 10) - 1: - for j in range(number_of_clusters % 3): - sub = int(subplot_base + str((i * 3) + j + 1)) - self.axes.append(fig.add_subplot(sub)) - else: - for j in range(3): - sub = int(subplot_base + str((i * 3) + j + 1)) - self.axes.append(fig.add_subplot(sub)) + nrows = math.ceil(number_of_clusters / 3) + for i in range(number_of_clusters): + self.axes.append(fig.add_subplot(nrows, 3, i + 1)) diff --git a/ross_ui/controller/multicolor_curve.py b/ross_ui/controller/multicolor_curve.py deleted file mode 100644 index 95ef681..0000000 --- a/ross_ui/controller/multicolor_curve.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np - - -class MultiColoredCurve: - def __init__(self, data, segmented_time, num_of_clusters, raw_len): - self.data = data - self.segmented = segmented_time - self.curves = dict() - - for i in range(num_of_clusters + 1): - self.curves['{}'.format(i)] = np.zeros(raw_len) - - for seg in self.segmented.segmented: - if seg.spike: - start = int(seg.start) - end = int(seg.end) - if int(seg.color) >= 0: - self.curves[str(int(seg.color))][start: end] = self.data[start: end].copy() - else: - self.curves[str(num_of_clusters)][start: end] = self.data[start: end].copy() - - else: - start = int(seg.start) - end = int(seg.end) - self.curves[str(num_of_clusters)][start: end] = self.data[start: end].copy() diff --git a/ross_ui/controller/plot_curve.py b/ross_ui/controller/plot_curve.py index 4e3d367..371443d 100644 --- a/ross_ui/controller/plot_curve.py +++ b/ross_ui/controller/plot_curve.py @@ -13,7 +13,6 @@ def viewRangeChanged(self): self.updatePlot() def updatePlot(self): - print('in update plot') vb = self.getViewBox() if vb is None: xstart = 0 @@ -23,33 +22,25 @@ def updatePlot(self): xend = int(vb.viewRange()[0][1]) + 1 if self.curve.end <= xend and self.curve.start >= xstart: - print('here 1') visible = self.curve.pos_y - print(len(visible)) self.setData(visible) self.setPos(self.curve.start, 0) self.setPen(self.curve.pen) elif self.curve.end <= xend and self.curve.start < xstart: - print('here 2') visible = self.curve.pos_y[xstart:] - print(len(visible)) self.setData(visible) self.setPos(xstart, 0) self.setPen(self.curve.pen) elif self.curve.end > xend and self.curve.start >= xstart: - print('here 3') visible = self.curve.pos_y[:xend] - print(len(visible)) self.setData(visible) self.setPos(self.curve.start, 0) self.setPen(self.curve.pen) elif self.curve.end > xend and self.curve.start < xstart: - print('here 4') visible = self.curve.pos_y[xstart: xend] - print(len(visible)) self.setData(visible) self.setPos(xstart, 0) self.setPen(self.curve.pen) diff --git a/ross_ui/controller/plot_time.py b/ross_ui/controller/plot_time.py deleted file mode 100644 index b6ed4ed..0000000 --- a/ross_ui/controller/plot_time.py +++ /dev/null @@ -1,6 +0,0 @@ -class PlotTime: - def __init__(self, start, end, spike, color): - self.start = start - self.end = end - self.spike = spike - self.color = color diff --git a/ross_ui/controller/segmented_time.py b/ross_ui/controller/segmented_time.py deleted file mode 100644 index 4e4a53e..0000000 --- a/ross_ui/controller/segmented_time.py +++ /dev/null @@ -1,17 +0,0 @@ -from controller.plot_time import PlotTime - - -class SegmentedTime: - def __init__(self, spike_time, clusters, sr=40000): - self.segmented = [] - self.segmented.append(PlotTime(0, spike_time[0] - sr * 0.001, spike=False, color=-1)) - for i in range(1, spike_time.shape[0]): - try: - self.segmented.append( - PlotTime(spike_time[i - 1] - sr * 0.001, spike_time[i - 1] + sr * 0.0015, spike=True, - color=clusters[i - 1] if clusters[i - 1] >= 0 else -1)) - - self.segmented.append(PlotTime(spike_time[i - 1] + sr * 0.0015, spike_time[i] - sr * 0.001, spike=False, - color=-1)) - except: - break diff --git a/ross_ui/controller/serverFileDialog.py b/ross_ui/controller/serverFileDialog.py new file mode 100644 index 0000000..cc80323 --- /dev/null +++ b/ross_ui/controller/serverFileDialog.py @@ -0,0 +1,62 @@ +import os + +from PyQt5 import QtWidgets + +from model.api import API +from view.serverFileDialog import ServerFileDialog + + +class ServerFileDialogApp(ServerFileDialog): + def __init__(self, api: API): + super(ServerFileDialogApp, self).__init__() + self.api = api + self.root = None + + self.list_folder.itemDoubleClicked.connect(self.itemDoubleClicked) + self.push_open.clicked.connect(self.pushOpenClicked) + self.push_cancel.clicked.connect(self.reject) + + self.request_dir() + + def request_dir(self): + self.list_folder.clear() + dir_dict = self.api.browse(self.root) + if dir_dict is not None: + self.line_address.setText(dir_dict['root']) + + item = QtWidgets.QListWidgetItem('..') + item.setIcon(QtWidgets.QApplication.style().standardIcon(QtWidgets.QStyle.SP_DirIcon)) + self.list_folder.addItem(item) + + for folder_name in dir_dict['folders']: + item = QtWidgets.QListWidgetItem(folder_name) + item.setIcon(QtWidgets.QApplication.style().standardIcon(QtWidgets.QStyle.SP_DirIcon)) + self.list_folder.addItem(item) + for filename in dir_dict['files']: + item = QtWidgets.QListWidgetItem(filename) + item.setIcon(QtWidgets.QApplication.style().standardIcon(QtWidgets.QStyle.SP_FileIcon)) + self.list_folder.addItem(item) + else: + QtWidgets.QMessageBox.critical(self, 'Error', 'Server Error') + + def itemDoubleClicked(self, item: QtWidgets.QListWidgetItem): + name = item.text() + isfolder = item.icon().name() == 'folder' + if isfolder: + self.root = os.path.join(self.line_address.text(), name) + self.request_dir() + else: + ret = self.api.post_raw_data(raw_data_path=os.path.join(self.line_address.text(), name), + mode=1, + varname=self.line_varname.text()) + if ret['stat']: + QtWidgets.QMessageBox.information(self, 'Info', 'File successfully added.') + self.accept() + else: + QtWidgets.QMessageBox.critical(self, 'Error', ret['message']) + + def pushOpenClicked(self): + try: + self.itemDoubleClicked(self.list_folder.selectedItems()[0]) + except IndexError: + pass diff --git a/ross_ui/controller/signin.py b/ross_ui/controller/signin.py index 293c80c..4a662e5 100644 --- a/ross_ui/controller/signin.py +++ b/ross_ui/controller/signin.py @@ -1,19 +1,20 @@ from PyQt5 import QtWidgets -from model.api import UserAccount +from model.api import API from view.signin import Signin_Dialog class SigninApp(Signin_Dialog): def __init__(self, server): super(SigninApp, self).__init__(server) + self.user = None self.pushButton_in.pressed.connect(self.accept_in) self.pushButton_up.pressed.connect(self.accept_up) def accept_in(self): username = self.textEdit_username.text() password = self.textEdit_password.text() - self.user = UserAccount(self.url) + self.user = API(self.url) res = self.user.sign_in(username, password) if res['stat']: super().accept() @@ -23,12 +24,9 @@ def accept_in(self): def accept_up(self): username = self.textEdit_username.text() password = self.textEdit_password.text() - self.user = UserAccount(self.url) + self.user = API(self.url) res = self.user.sign_up(username, password) if res['stat']: - self.label_res.setStyleSheet("color: green") QtWidgets.QMessageBox.information(self, "Account Created", res["message"]) else: - self.label_res.setStyleSheet("color: red") - # self.label_res.setText(res['message']) QtWidgets.QMessageBox.critical(self, "Error", res["message"]) diff --git a/ross_ui/model/api.py b/ross_ui/model/api.py index c102299..20ebc81 100644 --- a/ross_ui/model/api.py +++ b/ross_ui/model/api.py @@ -4,7 +4,7 @@ import requests -class UserAccount(): +class API(): def __init__(self, url): self.url = url self.refresh_token = None @@ -15,12 +15,11 @@ def refresh_jwt_token(self): response = requests.post(self.url + '/refresh', headers={'Authorization': 'Bearer ' + self.refresh_token}) - # response = requests.post(self.url + '/refresh', - # headers={'Authorization': 'Bearer ' + self.refresh_token}, - # data={'project_id': self.project_id}) - if response.ok: self.access_token = response.json()["access_token"] + return True + else: + return False def sign_up(self, username, password): data = {"username": username, "password": password} @@ -56,109 +55,65 @@ def sign_out(self): else: return {'stat': False, 'message': response.json()["message"]} - def post_raw_data(self, raw_data): - if not (self.access_token is None): + def post_raw_data(self, raw_data_path, mode=0, varname=''): + if self.access_token is not None: + # buffer = io.BytesIO() # np.savez_compressed(buffer, raw=raw_data) # buffer.seek(0) # raw_bytes = buffer.read() # buffer.close() - response = requests.put(self.url + '/raw', headers={'Authorization': 'Bearer ' + self.access_token}, - data={"raw_bytes": raw_data, "project_id": self.project_id}) + response = requests.post(self.url + '/raw', headers={'Authorization': 'Bearer ' + self.access_token}, + json={"raw_data": raw_data_path, + "project_id": self.project_id, + "mode": mode, + "varname": varname}) if response.ok: return {'stat': True, 'message': 'success'} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.put(self.url + '/raw', headers={'Authorization': 'Bearer ' + self.access_token}, - data={"raw_bytes": raw_data, "project_id": self.project_id}) - if response.ok: - return {'stat': True, 'message': 'success'} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.post_raw_data(raw_data_path, mode, varname) return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} - # def post_detected_data(self, spike_mat, spike_time): - # if not (self.access_token is None): - # buffer = io.BytesIO() - # np.savez_compressed(buffer, spike_mat=spike_mat, spike_time=spike_time) - # buffer.seek(0) - # detected_bytes = buffer.read() - # buffer.close() - # response = requests.put(self.url + '/detect', headers={'Authorization': 'Bearer ' + self.access_token}, - # data={"raw_bytes": detected_bytes, "project_id": self.project_id}) - # - # if response.ok: - # return {'stat': True, 'message': 'success'} - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.put(self.url + '/detect', headers={'Authorization': 'Bearer ' + self.access_token}, - # data=detected_bytes) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - - # def post_sorting_data(self, clusters): - # pass - - # def load_project(self, project_name): - # if not (self.access_token is None): - # response = requests.get(self.url + '/project/' + project_name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True} - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.get(self.url + '/project/' + project_name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - - # def get_projects(self): - # if not (self.access_token == None): - # response = requests.get(self.url + '/projects', headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'projects': response.json()['projects']} - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.get(self.url + '/projects', - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'projects': response.json()["projects"]} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - - def get_raw_data(self): - if not (self.access_token is None): + def get_raw_data(self, start=None, stop=None, limit=None): + if self.access_token is not None: response = requests.get(self.url + '/raw', headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) + json={'project_id': self.project_id, + 'start': start, + 'stop': stop, + 'limit': limit}) + if response.ok: - # b = io.BytesIO() - # b.write(response.content) - # b.seek(0) - # d = np.load(b, allow_pickle=True) - return {'stat': True, 'raw': response.content} + if response.status_code == 210: + return {'stat': True, 'raw': response.content} + elif response.status_code == 211: + iob = io.BytesIO() + iob.write(response.content) + iob.seek(0) + raw_data = np.load(iob, allow_pickle=True) + return {'stat': True, + 'visible': raw_data['visible'].flatten(), + 'stop': raw_data['stop'].flatten(), + 'ds': raw_data['ds'].flatten()} + elif response.status_code == 212: + return {'stat': True, 'message': 'SERVER MODE'} + else: + return {'stat': False, 'message': 'Status code not supported!'} elif response.status_code == 401: - self.refresh_jwt_token() - response = requests.get(self.url + '/raw', - headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) - if response.ok: - b = io.BytesIO() - b.write(response.content) - b.seek(0) - d = np.load(b, allow_pickle=True) - return {'stat': True, 'raw': d['raw'].flatten()} + ret = self.refresh_jwt_token() + if ret: + self.get_raw_data(start, stop, limit) return {'stat': False, 'message': response.json()['message']} return {'stat': False, 'message': 'Not Logged In!'} def get_detection_result(self): - if not (self.access_token is None): + if self.access_token is not None: response = requests.get(self.url + '/detection_result', headers={'Authorization': 'Bearer ' + self.access_token}, data={'project_id': self.project_id}) @@ -168,24 +123,41 @@ def get_detection_result(self): b.write(response.content) b.seek(0) d = np.load(b, allow_pickle=True) - return {'stat': True, 'spike_mat': d['spike_mat'], 'spike_time': d['spike_time']} + return {'stat': True, + 'spike_mat': d['spike_mat'], 'spike_time': d['spike_time'], + 'pca_spikes': d['pca_spikes'], 'inds': d['inds']} elif response.status_code == 401: - self.refresh_jwt_token() - response = requests.get(self.url + '/detection_result', - headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) - if response.ok: - b = io.BytesIO() - b.write(response.content) - b.seek(0) - d = np.load(b, allow_pickle=True) - return {'stat': True, 'spike_mat': d['spike_mat'], 'spike_time': d['spike_time']} - return {'stat': False, 'message': response.json()['message']} + ret = self.refresh_jwt_token() + if ret: + self.get_detection_result() + + return {'stat': False, 'message': response.content} + return {'stat': False, 'message': 'Not Logged In!'} + + def get_spike_mat(self): + if self.access_token is not None: + response = requests.get(self.url + '/detection_result_waveform', + headers={'Authorization': 'Bearer ' + self.access_token}, + json={'project_id': self.project_id}) + + if response.ok: + b = io.BytesIO() + b.write(response.content) + b.seek(0) + d = np.load(b, allow_pickle=True) + return {'stat': True, 'spike_mat': d['spike_mat']} + + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.get_spike_mat() + + return {'stat': False, 'message': response.content} return {'stat': False, 'message': 'Not Logged In!'} def get_sorting_result(self): - if not (self.access_token is None): + if self.access_token is not None: response = requests.get(self.url + '/sorting_result', headers={'Authorization': 'Bearer ' + self.access_token}, data={'project_id': self.project_id}) @@ -195,167 +167,103 @@ def get_sorting_result(self): b.seek(0) d = np.load(b, allow_pickle=True) - return {'stat': True, 'clusters': d['clusters']} + return {'stat': True, 'clusters': d['clusters'], "cluster_time_vec": d["cluster_time_vec"]} elif response.status_code == 401: - self.refresh_jwt_token() - response = requests.get(self.url + '/sorting_result', - headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) - if response.ok: - b = io.BytesIO() - b.write(response.content) - b.seek(0) - d = np.load(b, allow_pickle=True) - print("d", d) - return {'stat': True, 'clusters': d['clusters']} - return {'stat': False, 'message': response.json()['message']} + ret = self.refresh_jwt_token() + if ret: + self.get_sorting_result() + + return {'stat': False, 'message': response.content} return {'stat': False, 'message': 'Not Logged In!'} def get_config_detect(self): - if not (self.access_token is None): + if self.access_token is not None: response = requests.get(self.url + '/detect', headers={'Authorization': 'Bearer ' + self.access_token}, data={'project_id': self.project_id}) if response.ok: return {'stat': True, 'config': response.json()} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.get(self.url + '/detect', - headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) - if response.ok: - return {'stat': True, 'config': response.json()} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.get_config_detect() return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} def get_config_sort(self): - if not (self.access_token is None): + if self.access_token is not None: response = requests.get(self.url + '/sort', headers={'Authorization': 'Bearer ' + self.access_token}, data={'project_id': self.project_id}) if response.ok: return {'stat': True, 'config': response.json()} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.get(self.url + '/sort', - headers={'Authorization': 'Bearer ' + self.access_token}, - data={'project_id': self.project_id}) - if response.ok: - return {'stat': True, 'config': response.json()} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.get_config_sort() return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} - # def save_project_as(self, name): - # if not (self.access_token is None): - # response = requests.post(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}, - # data = {'project_id': self.project_id}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.post(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}, - # data = {'project_id': self.project_id}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - - # def save_project(self, name): - # if not (self.access_token is None): - # response = requests.put(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.post(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - # - # def delete_project(self, name): - # if not (self.access_token is None): - # response = requests.delete(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # - # elif response.json()["message"] == 'The token has expired.': - # self.refresh_jwt_token() - # response = requests.delete(self.url + '/project/' + name, - # headers={'Authorization': 'Bearer ' + self.access_token}) - # if response.ok: - # return {'stat': True, 'message': 'success'} - # return {'stat': False, 'message': response.json()["message"]} - # return {'stat': False, 'message': 'Not Logged In!'} - def start_detection(self, config): - data = config - data['run_detection'] = True - data['project_id'] = self.project_id - if not (self.access_token is None): - response = requests.put(self.url + '/detect', - headers={'Authorization': 'Bearer ' + self.access_token}, - json=data) + if self.access_token is not None: + + data = config + data['run_detection'] = True + data['project_id'] = self.project_id + + response = requests.post(self.url + '/detect', + headers={'Authorization': 'Bearer ' + self.access_token}, + json=data) if response.ok: return {'stat': True, 'message': 'success'} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.put(self.url + '/detect', - headers={'Authorization': 'Bearer ' + self.access_token}, - json=data) - if response.ok: - return {'stat': True, 'message': 'success'} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.start_detection(config) return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} def start_sorting(self, config): - data = config - data['run_sorting'] = True - data['project_id'] = self.project_id - if not (self.access_token is None): + if self.access_token is not None: + + data = config + data['run_sorting'] = True + data['project_id'] = self.project_id + response = requests.put(self.url + '/sort', headers={'Authorization': 'Bearer ' + self.access_token}, json=data) if response.ok: return {'stat': True, 'message': 'success'} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.put(self.url + '/sort', headers={'Authorization': 'Bearer ' + self.access_token}, - json=data) - if response.ok: - return {'stat': True, 'message': 'success'} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.start_sorting(config) return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} def start_Resorting(self, config, clusters, selected_clusters): - data = config - data['clusters'] = [clusters.tolist()] - data['selected_clusters'] = [selected_clusters] - data['run_sorting'] = True - data['project_id'] = self.project_id - if not (self.access_token is None): + if self.access_token is not None: + + data = config + data['clusters'] = [clusters.tolist()] + data['selected_clusters'] = [selected_clusters] + data['run_sorting'] = True + data['project_id'] = self.project_id + response = requests.put(self.url + '/sort', headers={'Authorization': 'Bearer ' + self.access_token}, json=data) if response.ok: return {'stat': True, 'message': 'success', 'clusters': response.json()["clusters"]} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - response = requests.put(self.url + '/sort', headers={'Authorization': 'Bearer ' + self.access_token}, - json=data) - if response.ok: - return {'stat': True, 'message': 'success', 'clusters': response.json()["clusters"]} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.start_Resorting(config, clusters, selected_clusters) return {'stat': False, 'message': response.json()["message"]} return {'stat': False, 'message': 'Not Logged In!'} def save_sort_results(self, clusters): - if not (self.access_token is None): + if self.access_token is not None: buffer = io.BytesIO() np.savez_compressed(buffer, clusters=clusters, project_id=self.project_id) @@ -368,14 +276,25 @@ def save_sort_results(self, clusters): if response.ok: return {'stat': True, 'message': 'success'} - elif response.json()["message"] == 'The token has expired.': - self.refresh_jwt_token() - - response = requests.put(self.url + '/sorting_result', - headers={'Authorization': 'Bearer ' + self.access_token}, - data=clusters_bytes) - - if response.ok: - return {'stat': True, 'message': 'success'} - return {'stat': False, 'message': response.json()["message"]} + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.save_sort_results(clusters) + return {'stat': False, 'message': response.content} return {'stat': False, 'message': 'Not Logged In!'} + + def browse(self, root: str): + if self.access_token is not None: + if root is None: + response = requests.get(self.url + '/browse', headers={'Authorization': 'Bearer ' + self.access_token}) + else: + response = requests.get(self.url + '/browse', headers={'Authorization': 'Bearer ' + self.access_token}, + json={'root': root}) + if response.ok: + return response.json() + elif response.status_code == 401: + ret = self.refresh_jwt_token() + if ret: + self.browse(root) + else: + return None diff --git a/ross_ui/view/exportResults.py b/ross_ui/view/exportResults.py new file mode 100644 index 0000000..dca3a65 --- /dev/null +++ b/ross_ui/view/exportResults.py @@ -0,0 +1,57 @@ +from PyQt5 import QtWidgets + + +class ExportResults(QtWidgets.QDialog): + def __init__(self): + super().__init__() + self.setFixedSize(600, 300) + + layout_out = QtWidgets.QVBoxLayout() + groupbox = QtWidgets.QGroupBox("Export variables") + + vbox = QtWidgets.QVBoxLayout() + groupbox.setLayout(vbox) + + self.checkSpikeMat = QtWidgets.QCheckBox("Spike Waveforms") + self.checkSpikeTime = QtWidgets.QCheckBox("Spike Times") + self.checkClusters = QtWidgets.QCheckBox("Sorting Results (Cluster Indexes)") + + self.checkSpikeTime.setChecked(True) + self.checkClusters.setChecked(True) + + vbox.addWidget(self.checkSpikeMat) + vbox.addWidget(self.checkSpikeTime) + vbox.addWidget(self.checkClusters) + + groupboxradio = QtWidgets.QGroupBox("Export type") + + hboxradio = QtWidgets.QHBoxLayout() + + groupboxradio.setLayout(hboxradio) + + self.radioPickle = QtWidgets.QRadioButton("pickle") + self.radioMat = QtWidgets.QRadioButton("mat") + + self.radioPickle.setChecked(True) + + hboxradio.addWidget(self.radioPickle) + hboxradio.addWidget(self.radioMat) + + hboxpush = QtWidgets.QHBoxLayout() + self.pushExport = QtWidgets.QPushButton("Export") + self.pushClose = QtWidgets.QPushButton("Close") + + hboxpush.addWidget(self.pushExport) + hboxpush.addWidget(self.pushClose) + self.labelDownload = QtWidgets.QLabel() + self.progbar = QtWidgets.QProgressBar() + + layout_out.addWidget(groupbox) + layout_out.addWidget(groupboxradio) + layout_out.addWidget(self.labelDownload) + # layout_out.addWidget(self.progbar) + layout_out.addLayout(hboxpush) + + self.setLayout(layout_out) + + self.setWindowTitle("Export Dialog") diff --git a/ross_ui/view/icons/ross.png b/ross_ui/view/icons/ross.png index 0d0afb7..457254d 100644 Binary files a/ross_ui/view/icons/ross.png and b/ross_ui/view/icons/ross.png differ diff --git a/ross_ui/view/mainWindow.py b/ross_ui/view/mainWindow.py index cf7a077..fc2e5d6 100644 --- a/ross_ui/view/mainWindow.py +++ b/ross_ui/view/mainWindow.py @@ -1,4 +1,3 @@ -from PyQt5 import QtCore import pyqtgraph as pg import pyqtgraph.opengl as gl from PyQt5 import QtCore @@ -12,7 +11,7 @@ icon_path = './view/icons/' -__version__ = "1.0.0" +__version__ = "2.0.0-alpha" # ----------------------------------------------------------------------------- @@ -72,11 +71,16 @@ def createActions(self): self.closeAct.setIcon(QtGui.QIcon(icon_path + "Close.png")) self.closeAct.setEnabled(False) + self.exportAct = QtWidgets.QAction(self.tr("&Export"), self) + self.exportAct.setStatusTip(self.tr("Export results")) + self.exportAct.setIcon(QtGui.QIcon(icon_path + "Export.png")) + self.exportAct.setEnabled(False) + self.saveAct = QtWidgets.QAction(self.tr("&Save"), self) + self.saveAct.setEnabled(False) self.saveAct.setShortcut(QtGui.QKeySequence.Save) self.saveAct.setStatusTip(self.tr("Save")) self.saveAct.setIcon(QtGui.QIcon(icon_path + "Save.png")) - self.saveAct.setEnabled(False) self.saveAct.triggered.connect(self.onSave) self.saveAsAct = QtWidgets.QAction(self.tr("&Save As..."), self) @@ -91,6 +95,11 @@ def createActions(self): self.importRawAct.setIcon(QtGui.QIcon(icon_path + "Import.png")) self.importRawAct.triggered.connect(self.onImportRaw) + self.importRawActServer = QtWidgets.QAction(self.tr("&Raw Data From Server"), self) + self.importRawActServer.setStatusTip(self.tr("Import Raw Data From Server")) + self.importRawActServer.setIcon(QtGui.QIcon(icon_path + "Import.png")) + self.importRawActServer.triggered.connect(self.open_file_dialog_server) + self.importDetectedAct = QtWidgets.QAction(self.tr("&Detection Result"), self) self.importDetectedAct.setStatusTip(self.tr("Import Detection Result")) self.importDetectedAct.setIcon(QtGui.QIcon(icon_path + "Import.png")) @@ -251,11 +260,11 @@ def createActions(self): self.plotAct.setStatusTip(self.tr("Plotting Clusters")) self.plotAct.triggered.connect(self.onVisPlot) - self.detect3dAct = QtWidgets.QAction(self.tr("&Detection Histogram 3D")) + # self.detect3dAct = QtWidgets.QAction(self.tr("&Detection Histogram 3D")) # self.detect3dAct.setCheckable(True) # self.detect3dAct.setChecked(False) - self.detect3dAct.setStatusTip(self.tr("Plotting 3D histogram of detection result")) - self.detect3dAct.triggered.connect(self.onDetect3D) + # self.detect3dAct.setStatusTip(self.tr("Plotting 3D histogram of detection result")) + # self.detect3dAct.triggered.connect(self.onDetect3D) self.clusterwaveAct = QtWidgets.QAction(self.tr("&Plotting Clusters Waveforms")) # self.clusterwaveAct.setCheckable(True) @@ -336,15 +345,23 @@ def createMenubar(self): # Menu entry for file actions. self.fileMenu = self.menuBar().addMenu(self.tr("&File")) # self.fileMenu.addAction(self.openAct) - self.fileMenu.addAction(self.closeAct) - self.fileMenu.addSeparator() - self.fileMenu.addAction(self.saveAct) - self.fileMenu.addAction(self.saveAsAct) - self.fileMenu.addSeparator() + + # self.fileMenu.addSeparator() + # self.fileMenu.addAction(self.saveAct) + # self.fileMenu.addAction(self.saveAsAct) + # self.fileMenu.addSeparator() self.importMenu = self.fileMenu.addMenu(self.tr("&Import")) - self.importMenu.addActions((self.importRawAct, self.importDetectedAct, self.importSortedAct)) - self.exportMenu = self.fileMenu.addMenu(self.tr("&Export")) - self.exportMenu.addActions((self.exportRawAct, self.exportDetectedAct, self.exportSortedAct)) + self.importMenu.setEnabled(False) + self.importMenu.addActions( + (self.importRawAct, self.importRawActServer, self.importDetectedAct, self.importSortedAct)) + + self.fileMenu.addSeparator() + self.fileMenu.addAction(self.exportAct) + self.fileMenu.addAction(self.closeAct) + + # self.exportMenu = self.fileMenu.addMenu(self.tr("&Export")) + # self.exportMenu.setEnabled(False) + # self.exportMenu.addActions((self.exportRawAct, self.exportDetectedAct, self.exportSortedAct)) # ------------------- tool from menu ------------------ # Menu entry for tool acions @@ -354,6 +371,7 @@ def createMenubar(self): # run menu self.runMenu = self.menuBar().addMenu(self.tr("&Run")) + self.runMenu.setEnabled(False) self.runMenu.addAction(self.detectAct) self.runMenu.addAction(self.sortAct) # self.runMenu.addAction(self.batchAct) @@ -373,7 +391,8 @@ def createMenubar(self): # Menu entry for visualization self.visMenu = self.menuBar().addMenu(self.tr("&Visualization")) - self.visMenu.addAction(self.detect3dAct) + self.visMenu.setEnabled(False) + # self.visMenu.addAction(self.detect3dAct) self.visMenu.addAction(self.clusterwaveAct) self.visMenu.addAction(self.livetimeAct) self.visMenu.addAction(self.isiAct) @@ -425,7 +444,7 @@ def align_subwindows(self): self.subwindow_detect3d.setGeometry(0, 0, int(self.w_mdi / 2), int(self.h_mdi * 2 / 3)) self.subwindow_detect3d.setVisible(False) - self.subwindow_3d.setGeometry(0, 0, int(self.w_mdi / 2), int(self.h_mdi * 2 / 3)) + # self.subwindow_3d.setGeometry(0, 0, int(self.w_mdi / 2), int(self.h_mdi * 2 / 3)) self.subwindow_3d.setVisible(False) self.subwindow_pca_manual.setVisible(False) @@ -474,20 +493,13 @@ def createSubWindows(self): layout_detect3d.addWidget(self.plot_detect3d) self.widget_detect3d.setLayout(layout_detect3d) - # Plot Clusters Waveform - # self.widget_clusterwaveform = pyqtgraph.widgets.MultiPlotWidget.MultiPlotWidget() - - # Plot Live Time - # self.widget_livetime = QtWidgets.QWidget() - # Plot ISI # self.widget_isi = QtWidgets.QWidget() # 3D Plot self.widget_3d = QtWidgets.QWidget() - self.plot_3d = gl.GLViewWidget() - self.plot_3d.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - layout_plot_3d = QtWidgets.QGridLayout() + self.widget_3d.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) + layout_widget_3d = QtWidgets.QGridLayout() layout_plot_pca_manual = QtWidgets.QGridLayout() self.axis1ComboBox = QtWidgets.QComboBox() @@ -521,23 +533,35 @@ def createSubWindows(self): self.plotButton = QtWidgets.QPushButton(text='Plot') - axisWidget = QtWidgets.QWidget() + layout_widget_3d.addWidget(QtWidgets.QLabel("Axis 1"), 0, 0) + layout_widget_3d.addWidget(QtWidgets.QLabel("Axis 2"), 0, 1) + layout_widget_3d.addWidget(QtWidgets.QLabel("Axis 3"), 0, 2) + + layout_widget_3d.addWidget(self.axis1ComboBox, 1, 0) + layout_widget_3d.addWidget(self.axis2ComboBox, 1, 1) + layout_widget_3d.addWidget(self.axis3ComboBox, 1, 2) + layout_widget_3d.addWidget(self.plotButton, 1, 3) + layout_widget_3d.addWidget(self.closeButton3d, 1, 4) + + self.widget_3d.setLayout(layout_widget_3d) + + # axisWidget = QtWidgets.QWidget() axisWidget_pca_manual = QtWidgets.QWidget() - axisLayout = QtWidgets.QHBoxLayout() - label = QtWidgets.QLabel('axis 1') - axisLayout.addWidget(label) - axisLayout.addWidget(self.axis1ComboBox) - label = QtWidgets.QLabel('axis 2') - axisLayout.addWidget(label) - axisLayout.addWidget(self.axis2ComboBox) - label = QtWidgets.QLabel('axis 3') - axisLayout.addWidget(label) - axisLayout.addWidget(self.axis3ComboBox) - axisLayout.addStretch(5) - axisLayout.addWidget(self.plotButton) - axisLayout.addWidget(self.closeButton3d) - axisWidget.setLayout(axisLayout) + # axisLayout = QtWidgets.QHBoxLayout() + # label = QtWidgets.QLabel('axis 1') + # axisLayout.addWidget(label) + # axisLayout.addWidget(self.axis1ComboBox) + # label = QtWidgets.QLabel('axis 2') + # axisLayout.addWidget(label) + # axisLayout.addWidget(self.axis2ComboBox) + # label = QtWidgets.QLabel('axis 3') + # axisLayout.addWidget(label) + # axisLayout.addWidget(self.axis3ComboBox) + # axisLayout.addStretch(5) + # axisLayout.addWidget(self.plotButton) + # axisLayout.addWidget(self.closeButton3d) + # axisWidget.setLayout(axisLayout) axisLayout_pca_button = QtWidgets.QHBoxLayout() axisLayout_pca_main = QtWidgets.QVBoxLayout() @@ -556,10 +580,7 @@ def createSubWindows(self): axisLayout_pca_main.addLayout(axisLayout_pca_button) axisWidget_pca_manual.setLayout(axisLayout_pca_main) - - layout_plot_3d.addWidget(axisWidget) - layout_plot_3d.addWidget(self.plot_3d) - self.widget_3d.setLayout(layout_plot_3d) + # layout_plot_3d.addWidget(self.plot_3d) self.painter = QPainter() layout_plot_pca_manual.addWidget(axisWidget_pca_manual) @@ -944,18 +965,11 @@ def createSubWindows(self): self.mdiArea.addSubWindow(self.widget_waveform).setWindowTitle("Waveforms") self.mdiArea.addSubWindow(self.plot_histogram_pca).setWindowTitle("2D PCA Histogram") - # self.mdiArea.addSubWindow(self.widget_clusters).setWindowTitle("Clusters") self.mdiArea.addSubWindow(self.widget_raw).setWindowTitle("Raw Data") self.mdiArea.addSubWindow(self.widget_settings).setWindowTitle("Settings") - # self.mdiArea.addSubWindow(self.widget_visualizations).setWindowTitle("Visualizations") self.mdiArea.addSubWindow(self.widget_detect3d).setWindowTitle("DetectionResult 3D Hist") - # self.mdiArea.addSubWindow(self.widget_clusters).setWindowTitle("Waveforms") - # self.mdiArea.addSubWindow(self.widget_clusterwaveform).setWindowTitle("Plot Clusters Waveforms") - # self.mdiArea.addSubWindow(self.widget_livetime).setWindowTitle("Live Time") - # self.mdiArea.addSubWindow(self.widget_isi).setWindowTitle("ISI") self.mdiArea.addSubWindow(self.widget_3d).setWindowTitle("3D Plot") self.mdiArea.addSubWindow(self.widget_assign_manual).setWindowTitle("Assign to Nearest") - self.mdiArea.addSubWindow(self.widget_pca_manual).setWindowTitle('PCA MANUAL') self.mdiArea.addSubWindow(self.plot_clusters_pca).setWindowTitle("2D PCA Clusters") @@ -964,15 +978,12 @@ def createSubWindows(self): self.subwindow_pca_histograms = subwindow_list[1] self.subwindow_raw = subwindow_list[2] self.subwindow_settings = subwindow_list[3] - # self.subwindow_visualization = subwindow_list[4] self.subwindow_detect3d = subwindow_list[4] self.subwindow_3d = subwindow_list[5] + self.subwindow_3d.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Minimum) self.subwindow_assign = subwindow_list[6] self.subwindow_pca_manual = subwindow_list[7] self.subwindow_pca_clusters = subwindow_list[8] - # self.subwindow_clusterwave = subwindow_list[5] - # self.subwindow_livetime = subwindow_list[6] - # self.subwindow_isi = subwindow_list[7] for subwindow in subwindow_list: subwindow.setWindowFlags( diff --git a/ross_ui/view/serverAddress.py b/ross_ui/view/serverAddress.py index 7b88739..042bcc6 100644 --- a/ross_ui/view/serverAddress.py +++ b/ross_ui/view/serverAddress.py @@ -4,18 +4,26 @@ class Server_Dialog(QtWidgets.QDialog): def __init__(self, server_text): super().__init__() - self.setFixedSize(400, 227) + self.setFixedSize(400, 150) self.buttonBox = QtWidgets.QDialogButtonBox(self) - self.buttonBox.setGeometry(QtCore.QRect(30, 160, 341, 32)) + layout_main = QtWidgets.QVBoxLayout() + # self.buttonBox.setGeometry(QtCore.QRect(30, 160, 341, 32)) self.buttonBox.setOrientation(QtCore.Qt.Horizontal) self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel | QtWidgets.QDialogButtonBox.Ok) self.label = QtWidgets.QLabel(self) - self.label.setGeometry(QtCore.QRect(80, 50, 141, 21)) + # self.label.setGeometry(QtCore.QRect(80, 50, 141, 21)) self.lineEdit = QtWidgets.QLineEdit(self) - self.lineEdit.setGeometry(QtCore.QRect(80, 100, 221, 31)) + # self.lineEdit.setGeometry(QtCore.QRect(80, 100, 221, 31)) self.setWindowTitle("Server Address") self.label.setText("Enter Server Address:") self.lineEdit.setText(server_text) + + layout_main.addWidget(self.label) + layout_main.addWidget(self.lineEdit) + layout_main.addWidget(QtWidgets.QLabel("")) + layout_main.addWidget(self.buttonBox) + + self.setLayout(layout_main) diff --git a/ross_ui/view/serverFileDialog.py b/ross_ui/view/serverFileDialog.py new file mode 100644 index 0000000..54e2199 --- /dev/null +++ b/ross_ui/view/serverFileDialog.py @@ -0,0 +1,33 @@ +from PyQt5 import QtWidgets + + +class ServerFileDialog(QtWidgets.QDialog): + def __init__(self): + super().__init__() + self.setFixedSize(600, 300) + + self.layout_out = QtWidgets.QVBoxLayout() + self.layout_file_but = QtWidgets.QHBoxLayout() + + self.line_address = QtWidgets.QLineEdit() + self.line_address.setEnabled(False) + self.layout_out.addWidget(self.line_address) + + self.list_folder = QtWidgets.QListWidget() + self.layout_out.addWidget(self.list_folder) + + self.layout_out.addWidget(QtWidgets.QLabel('For mat files enter the variable name (if more than one variable ' + 'is stored).')) + + self.line_varname = QtWidgets.QLineEdit() + self.layout_out.addWidget(self.line_varname) + + self.push_open = QtWidgets.QPushButton('Open') + self.push_cancel = QtWidgets.QPushButton('Cancel') + self.layout_file_but.addWidget(self.push_open) + self.layout_file_but.addWidget(self.push_cancel) + self.layout_out.addLayout(self.layout_file_but) + + self.setLayout(self.layout_out) + + self.setWindowTitle("Server File Dialog") diff --git a/ross_ui/view/signin.py b/ross_ui/view/signin.py index e355951..00d5bf0 100644 --- a/ross_ui/view/signin.py +++ b/ross_ui/view/signin.py @@ -1,35 +1,43 @@ -from PyQt5 import QtCore, QtWidgets +from PyQt5 import QtWidgets class Signin_Dialog(QtWidgets.QDialog): def __init__(self, server): super().__init__() self.url = server - self.setFixedSize(400, 300) - self.label = QtWidgets.QLabel(self) - self.label.setGeometry(QtCore.QRect(40, 60, 121, 21)) - self.label_2 = QtWidgets.QLabel(self) - self.label_2.setGeometry(QtCore.QRect(50, 130, 121, 20)) + self.setFixedSize(300, 150) - self.textEdit_username = QtWidgets.QLineEdit(self) - self.textEdit_username.setGeometry(QtCore.QRect(145, 55, 141, 31)) + l_label = QtWidgets.QVBoxLayout() + l_line = QtWidgets.QVBoxLayout() + l_push = QtWidgets.QHBoxLayout() + l1 = QtWidgets.QHBoxLayout() + l2 = QtWidgets.QVBoxLayout() - self.textEdit_password = QtWidgets.QLineEdit(self) - self.textEdit_password.setGeometry(QtCore.QRect(145, 125, 141, 31)) + self.label_u = QtWidgets.QLabel("Username") + self.label_p = QtWidgets.QLabel("Password") + + l_label.addWidget(self.label_u) + l_label.addWidget(self.label_p) + + self.textEdit_username = QtWidgets.QLineEdit() + self.textEdit_password = QtWidgets.QLineEdit() self.textEdit_password.setEchoMode(QtWidgets.QLineEdit.Password) - self.label_res = QtWidgets.QLabel(self) - self.label_res.setGeometry(QtCore.QRect(50, 180, 361, 16)) + l_line.addWidget(self.textEdit_username) + l_line.addWidget(self.textEdit_password) + + l1.addLayout(l_label) + l1.addLayout(l_line) + + self.pushButton_in = QtWidgets.QPushButton("Sign In") + self.pushButton_up = QtWidgets.QPushButton("Sign Up") + + l_push.addWidget(self.pushButton_in) + l_push.addWidget(self.pushButton_up) - self.pushButton_in = QtWidgets.QPushButton(self) - self.pushButton_in.setGeometry(QtCore.QRect(130, 220, 121, 31)) + l2.addLayout(l1) + l2.addLayout(l_push) - self.pushButton_up = QtWidgets.QPushButton(self) - self.pushButton_up.setGeometry(QtCore.QRect(260, 220, 111, 31)) + self.setLayout(l2) - self.setWindowTitle("Sign In/Up") - self.label.setText("Username") - self.label_2.setText("Password") - self.label_res.setText("") - self.pushButton_in.setText("Sign In") - self.pushButton_up.setText("Sign Up") + self.setWindowTitle('Authentication') diff --git a/test/backend/test_detection.py b/test/backend/test_detection.py new file mode 100644 index 0000000..af2d754 --- /dev/null +++ b/test/backend/test_detection.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass + +import numpy as np + +from ross_backend.resources.funcs.detection import startDetection + + +@dataclass +class Config: + filter_type = "butter" + filter_order = 4 + pass_freq = 300 + stop_freq = 3000 + sampling_rate = 40000 + thr_method = "median" + side_thr = "negative" + pre_thr = 40 + post_thr = 59 + dead_time = 20 + + +def get_data(num_spikes=3, config=None): + if config is None: + config = Config() + x = np.arange(0, 0.001, 1 / config.sampling_rate) + y = 2 * np.sin(2 * np.pi * 1000 * x) + z = np.zeros((config.pre_thr + config.post_thr + config.dead_time)) + return np.tile(np.concatenate((z, y, z)), num_spikes) + + +def test_start_detection(): + config = Config() + num_spikes = 4 + data = get_data(num_spikes=num_spikes) + SpikeMat, _, _, _ = startDetection(data, config) + + assert SpikeMat.shape[0] == num_spikes diff --git a/test/backend/test_sorting.py b/test/backend/test_sorting.py new file mode 100644 index 0000000..86a9bd3 --- /dev/null +++ b/test/backend/test_sorting.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +import numpy as np + +from ross_backend.resources.funcs.gmm import gmm_sorter + + +@dataclass +class Config: + g_max = 10 + g_min = 2 + max_iter = 1000 + error = 1e-5 + + +def generate_data(num_clusters=2, n_per_class=200): + data = np.zeros((num_clusters * n_per_class, 2)) + for i in range(num_clusters): + data[i * n_per_class: (i + 1) * n_per_class, :] = np.random.multivariate_normal( + [2 * i, 2 * i], [[0.01, 0], [0, 0.01]], size=(n_per_class,) + ) + return data + + +def test_gmm(): + num_clusters = 3 + config = Config() + data = generate_data(num_clusters) + cluster_index = gmm_sorter(data, config) + + assert len(np.unique(cluster_index)) == num_clusters