-
Notifications
You must be signed in to change notification settings - Fork 16
/
main.rs
100 lines (81 loc) · 2.56 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use hyper::{body::Buf, header, Body, Client, Request};
use hyper_tls::HttpsConnector;
use serde_derive::{Deserialize, Serialize};
use std::{env, env::args};
/// One completion candidate from the OpenAI completions response.
///
/// Only the generated `text` is deserialized; serde skips any other fields in
/// the payload because this type does not use `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
struct OpenAIChoices {
    text: String,
}
/// The part of the completions response body this tool reads.
///
/// Deserialization fails if the body has no `choices` array.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoices>,
}
/// JSON body sent to the OpenAI completions endpoint.
///
/// Each field serializes under its Rust name, in declaration order.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    // Model identifier sent as `model`.
    model: String,
    // Complete prompt text, sent as `prompt`.
    prompt: String,
    // Sent as `max_tokens`; per the OpenAI API this caps the completion length.
    max_tokens: u32,
    // Sent as `stop`; per the OpenAI API generation halts at this sequence.
    stop: String,
}
/// Joins command-line words (program name already removed) into one
/// space-separated description of the wanted command.
///
/// Returns `None` when no words were given or they are all whitespace, so the
/// caller can print usage instead of sending an empty prompt.
fn join_user_input<I>(words: I) -> Option<String>
where
    I: IntoIterator<Item = String>,
{
    let joined = words.into_iter().collect::<Vec<_>>().join(" ");
    let trimmed = joined.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}

/// Prepares the model's completion for printing.
///
/// Every line is trimmed and blank lines are dropped.
fn tidy_completion(text: &str) -> String {
    text.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}

/// Entry point: turns the command-line words into a prompt, asks the OpenAI
/// completions endpoint for one matching bash command, and prints it.
///
/// Environment:
/// - `OPENAI_KEY` (required): API key, sent as a bearer token.
/// - `OPENAI_MODEL` (optional): completion model; defaults to `text-davinci-003`.
///
/// Usage and API problems are reported on stderr and exit with status 1. This
/// keeps stdout limited to the suggested command, which the caller may pipe or
/// `eval`.
///
/// # Errors
/// Returns an error in three cases:
/// - serializing the request fails;
/// - building the request fails, e.g. the key contains characters not allowed
///   in an HTTP header value;
/// - the HTTP round-trip fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Check for environment variable OPENAI_KEY
    let api_key = match env::var("OPENAI_KEY") {
        Ok(key) => key,
        Err(_) => {
            eprintln!("Error: please create an environment variable OPENAI_KEY");
            std::process::exit(1);
        }
    };

    // `args()` yields the program name first; skip it.
    let user_input = match join_user_input(args().skip(1)) {
        Some(input) => input,
        None => {
            eprintln!("Welcome to Rusty! Enter an argument to get started.");
            std::process::exit(1);
        }
    };

    let https = HttpsConnector::new();
    let client = Client::builder().build(https);
    let uri = "https://api.openai.com/v1/completions";
    let model = env::var("OPENAI_MODEL").unwrap_or_else(|_| String::from("text-davinci-003"));
    // Each example in the prompt begins with "Text:". Stopping on "Text" keeps
    // the model from starting a new example after its command.
    let stop = String::from("Text");
    let default_prompt =
        "Given text, return 1 bash command. Text:list contents of a directory. Command:ls";

    let openai_request = OpenAIRequest {
        model,
        // Same `Text:<words>. Command:` shape as the example in `default_prompt`.
        prompt: format!("{} Text:{}. Command:", default_prompt, user_input),
        max_tokens: 64,
        stop,
    };

    // The builder defers header errors to `body()`. An invalid OPENAI_KEY value
    // is user input, so propagate the error instead of panicking.
    let req = Request::post(uri)
        .header(header::CONTENT_TYPE, "application/json")
        .header(header::AUTHORIZATION, format!("Bearer {}", api_key))
        .body(Body::from(serde_json::to_vec(&openai_request)?))?;

    let res = client.request(req).await?;
    let status = res.status();
    let body = hyper::body::aggregate(res).await?;
    let json: OpenAIResponse = match serde_json::from_reader(body.reader()) {
        Ok(response) => response,
        Err(err) => {
            // A reply without a `choices` array (e.g. an API error object)
            // fails here. Show the status and parser message rather than hiding them.
            eprintln!("Error calling OpenAI. Check environment variable OPENAI_KEY");
            eprintln!("(HTTP status {}: {})", status, err);
            std::process::exit(1);
        }
    };

    // Indexing `choices[0]` would panic on an empty array; report it instead.
    let choice = match json.choices.first() {
        Some(choice) => choice,
        None => {
            eprintln!("Error calling OpenAI: the response contained no completions");
            std::process::exit(1);
        }
    };

    println!("{}", tidy_completion(&choice.text));
    Ok(())
}