-
Notifications
You must be signed in to change notification settings - Fork 50
/
center_client.rs
388 lines (365 loc) · 15.6 KB
/
center_client.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
// Copyright 2022 - 2023 Wenmeng See the COPYRIGHT
// file at the top-level directory of this distribution.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// Author: tickbh
// -----
// Created Date: 2023/09/25 10:08:56
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, io};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc::Receiver;
use tokio::{io::split, net::TcpStream, sync::mpsc::channel};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::Sender,
};
use tokio_rustls::{client::TlsStream, TlsConnector};
use webparse::{BinaryMut, Buf};
use crate::proxy::ProxyServer;
use crate::{
HealthCheck, Helper, MappingConfig, ProtClose, ProtCreate, ProtFrame, ProxyConfig, ProxyResult,
TransStream, VirtualStream,
};
/// Client side of the tunnel to the center server.
///
/// Holds the configuration needed to connect to the center server and, once
/// `serve` is called, hands the connection to a background task that
/// reconnects automatically after the connection drops.
pub struct CenterClient {
    /// Proxy configuration (credentials, `server_id`, proxy flags, UDP bind)
    /// used while serving the tunnel.
    option: ProxyConfig,
    /// TLS client configuration; when `Some`, the connection to the center
    /// server is made over TLS.
    tls_client: Option<Arc<rustls::ClientConfig>>,
    /// Server name used only for TLS certificate verification; when `None`
    /// the connect code falls back to `soft.wm-proxy.com`.
    domain: Option<String>,
    /// Address of the center server.
    server_addr: String,
    /// Intranet mapping entries: announced to the server after connecting and
    /// used to resolve the server's `Create` requests to a local target.
    mappings: Vec<MappingConfig>,
    /// Plain-TCP connection to the center server. At most one of `stream` /
    /// `tls_stream` is `Some`; this one is set for unencrypted connections.
    stream: Option<TcpStream>,
    /// TLS connection to the center server. At most one of `stream` /
    /// `tls_stream` is `Some`; this one is set for encrypted connections.
    tls_stream: Option<TlsStream<TcpStream>>,
    /// Next local id used to derive the `sock_map` of a client-initiated
    /// stream; starts at 1 and advances by 2, so it stays odd.
    next_id: u32,
    /// Sends a `ProtCreate` plus the per-stream `Sender` to the serving task,
    /// which binds the sender to the `sock_map` and forwards the create frame.
    sender_work: Sender<(ProtCreate, Sender<ProtFrame>)>,
    /// Receiving half of `sender_work`. It is moved into the worker task when
    /// serving starts, which is why `serve` cannot be called twice.
    receiver_work: Option<Receiver<(ProtCreate, Sender<ProtFrame>)>>,
    /// Sending half cloned into every local stream task; frames sent here are
    /// picked up by the serving task and written to the center server.
    sender: Sender<ProtFrame>,
    /// Receiving half of `sender`. It is moved into the worker task when
    /// serving starts; every frame received is encoded onto the server
    /// connection.
    receiver: Option<Receiver<ProtFrame>>,
}
impl CenterClient {
pub fn new(
option: ProxyConfig,
server_addr: String,
tls_client: Option<Arc<rustls::ClientConfig>>,
domain: Option<String>,
mappings: Vec<MappingConfig>,
) -> Self {
let (sender, receiver) = channel::<ProtFrame>(100);
let (sender_work, receiver_work) = channel::<(ProtCreate, Sender<ProtFrame>)>(10);
Self {
option,
tls_client,
domain,
server_addr,
mappings,
stream: None,
tls_stream: None,
next_id: 1,
sender_work,
receiver_work: Some(receiver_work),
sender,
receiver: Some(receiver),
}
}
async fn inner_connect(
tls_client: Option<Arc<rustls::ClientConfig>>,
server_addr: String,
domain: Option<String>,
) -> ProxyResult<(Option<TcpStream>, Option<TlsStream<TcpStream>>)> {
if tls_client.is_some() {
let connector = TlsConnector::from(tls_client.unwrap());
let stream = HealthCheck::connect(&server_addr).await?;
// 这里的域名只为认证设置
let domain =
rustls::ServerName::try_from(&*domain.unwrap_or("soft.wm-proxy.com".to_string()))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
let outbound = connector.connect(domain, stream).await?;
Ok((None, Some(outbound)))
} else {
let outbound = HealthCheck::connect(&server_addr).await?;
Ok((Some(outbound), None))
}
}
pub async fn connect(&mut self) -> ProxyResult<bool> {
let (stream, tls_stream) = Self::inner_connect(
self.tls_client.clone(),
self.server_addr.clone(),
self.domain.clone(),
)
.await?;
self.stream = stream;
self.tls_stream = tls_stream;
Ok(self.stream.is_some() || self.tls_stream.is_some())
}
async fn inner_serve<T>(
option: &ProxyConfig,
stream: T,
sender: &mut Sender<ProtFrame>,
receiver_work: &mut Receiver<(ProtCreate, Sender<ProtFrame>)>,
receiver: &mut Receiver<ProtFrame>,
mappings: &mut Vec<MappingConfig>,
) -> ProxyResult<()>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut map = HashMap::<u64, Sender<ProtFrame>>::new();
let mut read_buf = BinaryMut::new();
let mut write_buf = BinaryMut::new();
let (mut reader, mut writer) = split(stream);
let mut vec = Vec::with_capacity(4096);
vec.resize(4096, 0);
let is_closed;
if option.username.is_some() && option.password.is_some() {
ProtFrame::new_token(
option.username.clone().unwrap(),
option.password.clone().unwrap(),
)
.encode(&mut write_buf)?;
}
if mappings.len() > 0 {
ProtFrame::new_mapping(0, mappings.clone()).encode(&mut write_buf)?;
}
loop {
let _ = tokio::select! {
// 严格的顺序流
biased;
// 新的流建立,这里接收Create并进行绑定
r = receiver_work.recv() => {
if let Some((create, sender)) = r {
map.insert(create.sock_map(), sender);
let _ = create.encode(&mut write_buf);
}
}
// 数据的接收,并将数据写入给远程端
r = receiver.recv() => {
if let Some(p) = r {
let _ = p.encode(&mut write_buf);
}
}
// 数据的等待读取,一旦流可读则触发,读到0则关闭主动关闭所有连接
r = reader.read(&mut vec) => {
match r {
Ok(0)=>{
is_closed=true;
break;
}
Ok(n) => {
read_buf.put_slice(&vec[..n]);
}
Err(_err) => {
is_closed = true;
break;
},
}
}
// 一旦有写数据,则尝试写入数据,写入成功后扣除相应的数据
r = writer.write(write_buf.chunk()), if write_buf.has_remaining() => {
match r {
Ok(n) => {
write_buf.advance(n);
if !write_buf.has_remaining() {
write_buf.clear();
}
}
Err(e) => {
log::info!("写入的时候发生错误,错误内容为:{:?}", e);
},
}
}
};
loop {
// 将读出来的数据全部解析成ProtFrame并进行相应的处理,如果是0则是自身消息,其它进行转发
match Helper::decode_frame(&mut read_buf)? {
Some(p) => {
match p {
ProtFrame::Create(p) => {
let domain = p.domain().clone().unwrap_or(String::new());
let mut mapping = None;
for m in &*mappings {
if m.domain == domain || m.name == domain {
mapping = Some(m);
}
}
if mapping.is_none() {
log::info!("本地地址为空,无法做内网映射");
log::warn!("local addr is none, can't mapping");
let _ = sender.send(ProtFrame::new_close(p.sock_map())).await;
continue;
}
let (virtual_sender, virtual_receiver) = channel::<ProtFrame>(10);
map.insert(p.sock_map(), virtual_sender);
if mapping.as_ref().unwrap().is_proxy() {
let stream = VirtualStream::new(
p.sock_map(),
sender.clone(),
virtual_receiver,
);
let proxy_server = ProxyServer::new(
option.flag,
option.username.clone(),
option.password.clone(),
option.udp_bind.clone(),
Some(mapping.as_ref().unwrap().headers.clone()),
);
tokio::spawn(async move {
// 处理代理的能力
let _ = proxy_server.deal_proxy(stream).await;
});
} else {
if mapping.as_ref().unwrap().local_addr.is_none() {
log::info!("本地地址为空,无法做内网映射");
log::warn!("local addr is none, can't mapping");
continue;
}
let domain = mapping.as_ref().unwrap().local_addr.unwrap();
let sock_map = p.sock_map();
let sender = sender.clone();
tokio::spawn(async move {
match HealthCheck::connect(&domain).await {
Ok(tcp) => {
let trans = TransStream::new(
tcp,
sock_map,
sender,
virtual_receiver,
);
let _ = trans.copy_wait().await;
}
Err(e) => {
log::trace!(
"连接地址:{},发生错误:{:?}",
domain,
e
);
let _ = sender
.send(ProtFrame::new_close(sock_map))
.await;
}
}
});
}
}
ProtFrame::Data(_) => {
if let Some(sender) = map.get(&p.sock_map()) {
let _ = sender.try_send(p);
}
}
ProtFrame::Close(p) => {
if p.sock_map() == 0 {
log::warn!("客户端被服务端关闭:{}", p.reason());
} else if let Some(sender) = map.get(&p.sock_map()) {
let _ = sender.try_send(ProtFrame::Close(p));
}
}
ProtFrame::Mapping(_) => {}
ProtFrame::Token(_) => todo!(),
}
}
None => {
break;
}
}
}
if !read_buf.has_remaining() {
read_buf.clear();
}
}
if is_closed {
for v in map {
let _ = v.1.try_send(ProtFrame::Close(ProtClose::new(v.0)));
}
}
Ok(())
}
pub async fn serve(&mut self) -> ProxyResult<()> {
let tls_client = self.tls_client.clone();
let server = self.server_addr.clone();
let domain = self.domain.clone();
let option = self.option.clone();
let stream = self.stream.take();
let tls_stream = self.tls_stream.take();
let mut client_sender = self.sender.clone();
let mut client_receiver = self.receiver.take().unwrap();
let mut receiver_work = self.receiver_work.take().unwrap();
let mut mappings = self.mappings.clone();
tokio::spawn(async move {
let mut stream = stream;
let mut tls_stream = tls_stream;
loop {
if stream.is_some() {
let _ = Self::inner_serve(
&option,
stream.take().unwrap(),
&mut client_sender,
&mut receiver_work,
&mut client_receiver,
&mut mappings,
)
.await;
tokio::time::sleep(Duration::from_millis(1000)).await;
} else if tls_stream.is_some() {
let _ = Self::inner_serve(
&option,
tls_stream.take().unwrap(),
&mut client_sender,
&mut receiver_work,
&mut client_receiver,
&mut mappings,
)
.await;
tokio::time::sleep(Duration::from_millis(1000)).await;
};
match Self::inner_connect(tls_client.clone(), server.clone(), domain.clone()).await
{
Ok((s, tls)) => {
stream = s;
tls_stream = tls;
}
Err(_err) => {
tokio::time::sleep(Duration::from_millis(1000)).await;
}
}
}
});
Ok(())
}
fn calc_next_id(&mut self) -> u64 {
let id = self.next_id;
self.next_id += 2;
Helper::calc_sock_map(self.option.server_id, id)
}
pub async fn deal_new_stream<T>(&mut self, inbound: T) -> ProxyResult<()>
where
T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
let id = self.calc_next_id();
let sender = self.sender.clone();
let (stream_sender, stream_receiver) = channel::<ProtFrame>(10);
let _ = self
.sender_work
.send((ProtCreate::new(id, None), stream_sender))
.await;
tokio::spawn(async move {
let trans = TransStream::new(inbound, id, sender, stream_receiver);
let _ = trans.copy_wait().await;
});
Ok(())
}
}