Skip to content

Commit

Permalink
Implements handle_request for Handler.
Browse files Browse the repository at this point in the history
This adds the `do_handle_request` function, which peeks at the query
name and runs the right request handler. The implemented request
handlers are:

- `do_handle_request_myip` returns the IP address of the requester,
- `do_handle_request_counter` returns the counter as a TXT record,
- `do_handle_request_hello` returns a greeting as a TXT record,
- `do_handle_request_default` returns a NXDomain response.
  • Loading branch information
xfbs committed Oct 2, 2022
1 parent a65a1d8 commit 5e8a720
Showing 1 changed file with 169 additions and 7 deletions.
176 changes: 169 additions & 7 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,186 @@
use crate::Options;
use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use std::{
borrow::Borrow,
net::IpAddr,
str::FromStr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use tracing::*;
use trust_dns_server::{
authority::MessageResponseBuilder,
client::rr::{rdata::TXT, LowerName, Name, RData, Record},
proto::op::{Header, MessageType, OpCode, ResponseCode},
server::{Request, RequestHandler, ResponseHandler, ResponseInfo},
};

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Invalid OpCode {0:}")]
InvalidOpCode(OpCode),
#[error("Invalid MessageType {0:}")]
InvalidMessageType(MessageType),
#[error("Invalid Zone {0:}")]
InvalidZone(LowerName),
#[error("IO error: {0:}")]
Io(#[from] std::io::Error),
}

/// DNS Request Handler
#[derive(Clone, Debug)]
pub struct Handler {}
pub struct Handler {
/// Request counter, incremented on every successful request.
pub counter: Arc<AtomicU64>,
/// Domain to serve DNS responses for (requests for other domains are silently ignored).
pub root_zone: LowerName,
/// Zone name for counter (counter.dnsfun.dev)
pub counter_zone: LowerName,
/// Zone name for myip (myip.dnsfun.dev)
pub myip_zone: LowerName,
/// Zone name for hello (hello.dnsfun.dev)
pub hello_zone: LowerName,
}

impl Handler {
/// Create new handler from command-line options.
pub fn from_options(_options: &Options) -> Self {
Handler {}
pub fn from_options(options: &Options) -> Self {
let domain = &options.domain;
Handler {
root_zone: LowerName::from(Name::from_str(domain).unwrap()),
counter: Arc::new(AtomicU64::new(0)),
counter_zone: LowerName::from(Name::from_str(&format!("counter.{domain}")).unwrap()),
myip_zone: LowerName::from(Name::from_str(&format!("myip.{domain}")).unwrap()),
hello_zone: LowerName::from(Name::from_str(&format!("hello.{domain}")).unwrap()),
}
}

/// Handle requests for myip.{domain}.
async fn do_handle_request_myip<R: ResponseHandler>(
&self,
request: &Request,
mut responder: R,
) -> Result<ResponseInfo, Error> {
self.counter.fetch_add(1, Ordering::SeqCst);
let builder = MessageResponseBuilder::from_message_request(request);
let mut header = Header::response_from_request(request.header());
header.set_authoritative(true);
let rdata = match request.src().ip() {
IpAddr::V4(ipv4) => RData::A(ipv4),
IpAddr::V6(ipv6) => RData::AAAA(ipv6),
};
let records = vec![Record::from_rdata(request.query().name().into(), 60, rdata)];
let response = builder.build(header, records.iter(), &[], &[], &[]);
Ok(responder.send_response(response).await?)
}

/// Handle requests for counter.{domain}.
async fn do_handle_request_counter<R: ResponseHandler>(
&self,
request: &Request,
mut responder: R,
) -> Result<ResponseInfo, Error> {
let counter = self.counter.fetch_add(1, Ordering::SeqCst);
let builder = MessageResponseBuilder::from_message_request(request);
let mut header = Header::response_from_request(request.header());
header.set_authoritative(true);
let rdata = RData::TXT(TXT::new(vec![counter.to_string()]));
let records = vec![Record::from_rdata(request.query().name().into(), 60, rdata)];
let response = builder.build(header, records.iter(), &[], &[], &[]);
Ok(responder.send_response(response).await?)
}

/// Handle requests for *.hello.{domain}.
async fn do_handle_request_hello<R: ResponseHandler>(
&self,
request: &Request,
mut responder: R,
) -> Result<ResponseInfo, Error> {
self.counter.fetch_add(1, Ordering::SeqCst);
let builder = MessageResponseBuilder::from_message_request(request);
let mut header = Header::response_from_request(request.header());
header.set_authoritative(true);
let name: &Name = request.query().name().borrow();
let zone_parts = (name.num_labels() - self.hello_zone.num_labels() - 1) as usize;
let name = name
.iter()
.enumerate()
.filter(|(i, _)| i <= &zone_parts)
.fold(String::from("hello,"), |a, (_, b)| {
a + " " + &String::from_utf8_lossy(b)
});
let rdata = RData::TXT(TXT::new(vec![name]));
let records = vec![Record::from_rdata(request.query().name().into(), 60, rdata)];
let response = builder.build(header, records.iter(), &[], &[], &[]);
Ok(responder.send_response(response).await?)
}

/// Handle requests for anything else (NXDOMAIN)
async fn do_handle_request_default<R: ResponseHandler>(
&self,
request: &Request,
mut responder: R,
) -> Result<ResponseInfo, Error> {
self.counter.fetch_add(1, Ordering::SeqCst);
let builder = MessageResponseBuilder::from_message_request(request);
let mut header = Header::response_from_request(request.header());
header.set_authoritative(true);
header.set_response_code(ResponseCode::NXDomain);
let response = builder.build_no_records(header);
Ok(responder.send_response(response).await?)
}

/// Handle request, returning ResponseInfo if response was successfully sent, or an error.
async fn do_handle_request<R: ResponseHandler>(
&self,
request: &Request,
response: R,
) -> Result<ResponseInfo, Error> {
// make sure the request is a query
if request.op_code() != OpCode::Query {
return Err(Error::InvalidOpCode(request.op_code()));
}

// make sure the message type is a query
if request.message_type() != MessageType::Query {
return Err(Error::InvalidMessageType(request.message_type()));
}

match request.query().name() {
name if self.myip_zone.zone_of(name) => {
self.do_handle_request_myip(request, response).await
}
name if self.counter_zone.zone_of(name) => {
self.do_handle_request_counter(request, response).await
}
name if self.hello_zone.zone_of(name) => {
self.do_handle_request_hello(request, response).await
}
name if self.root_zone.zone_of(name) => {
self.do_handle_request_default(request, response).await
}
name => Err(Error::InvalidZone(name.clone())),
}
}
}

#[async_trait::async_trait]
impl RequestHandler for Handler {
async fn handle_request<R: ResponseHandler>(
&self,
_request: &Request,
_response: R,
request: &Request,
response: R,
) -> ResponseInfo {
todo!()
// try to handle request
match self.do_handle_request(request, response).await {
Ok(info) => info,
Err(error) => {
error!("Error in RequestHandler: {error}");
let mut header = Header::new();
header.set_response_code(ResponseCode::ServFail);
header.into()
}
}
}
}

0 comments on commit 5e8a720

Please sign in to comment.