forked from tom-doerr/chatgpt_commit_message_hook
/
prepare-commit-msg
executable file
·188 lines (152 loc) · 6.07 KB
/
prepare-commit-msg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#!/usr/bin/env python3
import os
import sys
import configparser
import requests
import git
import re
# Do nothing when any of git's rebase/merge state paths exist, so the hook
# does not rewrite messages in the middle of those operations.
# NOTE(review): the paths are relative, so this assumes the hook runs with the
# repository root as its working directory — confirm.
if (
    os.path.exists(".git/rebase-merge")
    or os.path.exists(".git/rebase-apply")
    or os.path.exists(".git/MERGE_HEAD")
):
    sys.exit(0)

# First command-line argument: path of the file holding the commit message.
COMMIT_MSG_FILE = sys.argv[1]
# Further arguments are currently unused:
#COMMIT_SOURCE = sys.argv[2]
#SHA1 = sys.argv[3]

# Remember whatever is already in the message file so the generated text can
# be written in front of it later. The context manager closes the handle
# right away instead of leaving it to the garbage collector.
with open(COMMIT_MSG_FILE, 'r') as commit_msg_handle:
    current_commit_file_content = commit_msg_handle.read()

# Location of the ini file with the OpenAI settings
# ($XDG_CONFIG_HOME/openaiapirc, defaulting to ~/.config/openaiapirc).
CONFIG_DIR = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
API_KEYS_LOCATION = os.path.join(CONFIG_DIR, "openaiapirc")
def count_staged_changes(repo_path="."):
    """Collect line counts and patch text for every staged file.

    Args:
        repo_path: Path passed to ``git.Repo``; defaults to the current
            directory.

    Returns:
        A dict mapping a file path to ``(added_lines, deleted_lines,
        diff_text)``, where the counts are the number of patch lines that
        start with ``+`` and ``-`` respectively and ``diff_text`` is the
        decoded patch for that file.
    """
    repo = git.Repo(repo_path)

    # Diff the index against HEAD as a patch. R=True reverses the direction
    # of the comparison.
    # NOTE(review): presumably this makes staged additions show up as "+"
    # lines — confirm against the GitPython documentation.
    staged_diff = repo.index.diff("HEAD", create_patch=True, R=True)

    changes_dict = {}
    for diff in staged_diff:
        # Decode leniently: a staged file that is not valid UTF-8 must not
        # abort the hook with a UnicodeDecodeError.
        diff_text = diff.diff.decode("utf-8", errors="replace")

        # Count both kinds of lines in a single pass over the patch.
        added_lines = 0
        deleted_lines = 0
        for line in diff_text.splitlines():
            if line.startswith("+"):
                added_lines += 1
            elif line.startswith("-"):
                deleted_lines += 1

        # NOTE(review): fall back to a_path in case b_path is unset for some
        # change types, so the dict never gets a None key — confirm when
        # GitPython leaves b_path empty.
        path = diff.b_path or diff.a_path
        changes_dict[path] = (added_lines, deleted_lines, diff_text)

    return changes_dict
def get_staged_changes_summary(changes_dict, n):
    """Describe the staged changes, in full or as a per-file histogram.

    Args:
        changes_dict: Mapping of file path to ``(added_lines, deleted_lines,
            diff_text)``.
        n: Threshold on the total number of changed lines.

    Returns:
        The output of ``get_full_diff`` when fewer than ``n`` lines changed
        in total; otherwise one line per file of the form
        ``"<path> | <count> <plus signs><minus signs>"``.
    """
    line_total = 0
    for added, deleted, _text in changes_dict.values():
        line_total += added + deleted

    # Small change sets are cheap enough to pass along verbatim.
    if line_total < n:
        return get_full_diff(changes_dict)

    # One row per file: the path, the changed-line count, then one "+" per
    # added line and one "-" per deleted line.
    rows = [
        f"{path} | {added + deleted} " + "+" * added + "-" * deleted + "\n"
        for path, (added, deleted, _text) in changes_dict.items()
    ]
    return "".join(rows)
def get_full_diff(changes_dict):
    """Render the changed lines of every file as one text blob.

    Args:
        changes_dict: Mapping of file path to ``(added_lines, deleted_lines,
            diff_text)``; only the path and ``diff_text`` are used.

    Returns:
        For each file, a ``diff --git a/<path> b/<path>:`` header line
        followed by the lines of its patch that begin with ``+`` or ``-``,
        each section terminated by a newline.
    """
    sections = []
    for diff_path, (_added, _deleted, patch_text) in changes_dict.items():
        # Keep only the added/removed lines; context and hunk headers are
        # dropped.
        changed_lines = [
            line for line in patch_text.split('\n')
            if re.match(r'^[+\-]', line)
        ]
        header = f"diff --git a/{diff_path} b/{diff_path}:\n"
        sections.append(header + '\n'.join(changed_lines) + "\n")
    return "".join(sections)
def get_openai_chat_response(prompt, model, api_key, proxy_server=None, timeout=60):
    """Send a chat-completion request to the OpenAI API.

    Args:
        prompt: Value sent as the ``messages`` field of the request body.
        model: Value sent as the ``model`` field of the request body.
        api_key: Key sent as a bearer token in the ``Authorization`` header.
        proxy_server: Optional proxy URL used for both http and https; any
            falsy value disables the proxy.
        timeout: Seconds handed to ``requests.post`` so a stalled connection
            cannot block the hook (and therefore the commit) indefinitely.

    Returns:
        The content of the first choice's message on an HTTP 200 response.
        ``None`` when the request fails or returns another status; the error
        is printed in that case.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    proxies = {"http": proxy_server, "https": proxy_server} if proxy_server else None
    data = {
        "model": model,
        "messages": prompt,
    }

    try:
        response = requests.post(
            url, headers=headers, json=data, proxies=proxies, timeout=timeout
        )
    except requests.RequestException as err:
        # Treat transport problems (timeout, DNS, proxy, ...) like any other
        # failed request: report and signal failure through the return value.
        print("Error:", err)
        return None

    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]

    print("Error:", response.text)
    return None
def create_template_ini_file():
    """Create a template config file and exit if none exists yet.

    When ``API_KEYS_LOCATION`` is already a file this does nothing.
    Otherwise it writes an ``[openai]`` section with empty ``secret_key``,
    ``proxy`` and ``max_changed_lines`` entries, prints instructions and
    terminates the process with exit status 1 so the user can fill in the
    file first.
    """
    if os.path.isfile(API_KEYS_LOCATION):
        return

    # The config directory may not exist yet; without it the open() below
    # fails with FileNotFoundError.
    config_dir = os.path.dirname(API_KEYS_LOCATION)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)

    with open(API_KEYS_LOCATION, "w") as config_file:
        config_file.write(
            "[openai]\n"
            "secret_key=\n"
            "proxy=\n"
            "max_changed_lines=\n"
        )

    # The messages mention only the secret key: the template has no
    # organization_id entry.
    print("OpenAI API config file created at {}".format(API_KEYS_LOCATION))
    print("Please edit it and add your secret key")
    print(
        "If you do not yet have a secret key, you\n"
        "need to register for OpenAI Codex: \n"
        "https://openai.com/blog/openai-codex/"
    )
    sys.exit(1)
def initialize_openai_api():
    """Load the OpenAI settings from the config file.

    Calls ``create_template_ini_file`` first, then reads the ``[openai]``
    section of ``API_KEYS_LOCATION``.

    Returns:
        A tuple ``(api_key, proxy, max_changed_lines)``. ``proxy`` is an
        empty string when not configured. ``max_changed_lines`` is an int
        and falls back to 80 when the option is missing, empty or not
        positive.

    Raises:
        ValueError: If ``secret_key`` is missing or empty, or if
            ``max_changed_lines`` is set to something that is not an integer.
    """
    default_max_changed_lines = 80

    create_template_ini_file()

    config = configparser.ConfigParser()
    config.read(API_KEYS_LOCATION)

    # fallback="" covers a missing section or option, so an incomplete file
    # is reported by the checks below instead of a bare KeyError.
    api_key = config.get("openai", "secret_key", fallback="").strip('"').strip("'")
    # An unset key arrives as an empty string, never as None.
    if not api_key:
        raise ValueError("OpenAI API key is not set in the configuration")

    proxy = config.get("openai", "proxy", fallback="").strip('"').strip("'")

    raw_limit = config.get("openai", "max_changed_lines", fallback="").strip()
    # The generated template leaves this option empty; use the default then
    # rather than comparing a string with an int.
    if not raw_limit:
        return api_key, proxy, default_max_changed_lines

    try:
        max_changed_lines = int(raw_limit)
    except ValueError as err:
        raise ValueError(
            f"max_changed_lines must be an integer, got {raw_limit!r}"
        ) from err

    if max_changed_lines <= 0:
        max_changed_lines = default_max_changed_lines

    return api_key, proxy, max_changed_lines
# Load the settings, describe the staged changes and ask the model for a
# commit message.
api_key, proxy, max_changed_lines = initialize_openai_api()
changes_dict = count_staged_changes()
summary = get_staged_changes_summary(changes_dict, max_changed_lines)

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant writes short git commit messages based on the git diff.'},
    {'role': 'user', 'content': f'{summary}\n\nWrite the commit message.'},
]

response_text = get_openai_chat_response(
    prompt=messages,
    model="gpt-3.5-turbo",
    api_key=api_key,
    proxy_server=proxy,
)

# get_openai_chat_response yields None when the API call did not succeed.
# Concatenating None below would raise a TypeError and abort the commit, so
# leave the message file untouched and let the user write the message.
if response_text is None:
    print(
        "No commit message was generated; leaving the commit message unchanged.",
        file=sys.stderr,
    )
    sys.exit(0)

# Put the generated message in front of whatever was already in the file.
content_whole_file = response_text + current_commit_file_content
with open(COMMIT_MSG_FILE, 'w') as f:
    f.write(content_whole_file)