Skip to content

Commit

Permalink
Fix potential segfault when reaching maximum clients. Simplify connec…
Browse files Browse the repository at this point in the history
…tion buffer pool. Formatting.
  • Loading branch information
seemk committed Jul 19, 2017
1 parent 6f5dbac commit b17d2a6
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 79 deletions.
2 changes: 1 addition & 1 deletion CRC32.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <stdint.h>
#include <stddef.h>
#include <stdint.h>

uint32_t StunCRC32(const void* data, size_t len);
uint32_t SctpCRC32(const void* data, size_t len);
63 changes: 27 additions & 36 deletions Wu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <vector>
#include "WuArena.h"
#include "WuCert.h"
#include "WuCert.h"
#include "WuClock.h"
#include "WuDataChannel.h"
#include "WuHttp.h"
#include "WuMath.h"
#include "WuNetwork.h"
#include "WuPool.h"
#include "WuQueue.h"
Expand Down Expand Up @@ -62,30 +61,21 @@ struct WuConnectionBuffer {
};

struct WuConnectionBufferPool {
WuConnectionBufferPool(size_t n) : buffers(n) {
for (size_t i = 0; i < n; i++) {
freeBuffers.push_back(&buffers[i]);
}
}
WuConnectionBufferPool(size_t n)
: pool(WuPoolCreate(sizeof(WuConnectionBuffer), n)) {}

WuConnectionBuffer* GetBuffer() {
if (freeBuffers.size() > 0) {
WuConnectionBuffer* buf = freeBuffers.back();
freeBuffers.pop_back();
return buf;
}

return nullptr;
WuConnectionBuffer* buffer = (WuConnectionBuffer*)WuPoolAcquire(pool);
return buffer;
}

void Reclaim(WuConnectionBuffer* buf) {
buf->fd = -1;
buf->size = 0;
freeBuffers.push_back(buf);
WuPoolRelease(pool, buf);
}

std::vector<WuConnectionBuffer> buffers;
std::vector<WuConnectionBuffer*> freeBuffers;
WuPool* pool;
};

const double kMaxClientTtl = 8.0;
Expand Down Expand Up @@ -160,9 +150,9 @@ void WuSendSctp(const WuHost* wu, WuClient* client, const SctpPacket* packet,

WuClient* WuHostNewClient(WuHost* wu) {
WuClient* client = (WuClient*)WuPoolAcquire(wu->clientPool);
memset(client, 0, sizeof(WuClient));

if (client) {
memset(client, 0, sizeof(WuClient));
WuClientStart(wu, client);
wu->clients[wu->numClients++] = client;
return client;
Expand Down Expand Up @@ -262,14 +252,12 @@ void WuHandleHttpRequest(WuHost* wu, WuConnectionBuffer* conn) {
client->serverPassword.length = 24;
WuRandomString((char*)client->serverPassword.identifier,
client->serverPassword.length);
memcpy(
client->remoteUser.identifier, iceFields.ufrag.value,
std::min(iceFields.ufrag.length, kMaxStunIdentifierLength));
memcpy(client->remoteUser.identifier, iceFields.ufrag.value,
Min(iceFields.ufrag.length, kMaxStunIdentifierLength));
client->remoteUser.length = iceFields.ufrag.length;
memcpy(client->remoteUserPassword.identifier,
iceFields.password.value,
std::min(iceFields.password.length,
kMaxStunIdentifierLength));
Min(iceFields.password.length, kMaxStunIdentifierLength));

int bodyLength = 0;
const char* body = GenerateSDP(
Expand Down Expand Up @@ -398,7 +386,7 @@ void WuHostHandleSctp(WuHost* wu, WuClient* client, const uint8_t* buf,
const uint8_t* userDataBegin = dataChunk->userData;
const int32_t userDataLength = dataChunk->userDataLength;

client->remoteTsn = std::max(chunk->as.data.tsn, client->remoteTsn);
client->remoteTsn = Max(chunk->as.data.tsn, client->remoteTsn);
client->ttl = kMaxClientTtl;

if (dataChunk->protoId == DCProto_Control) {
Expand Down Expand Up @@ -672,7 +660,7 @@ int32_t WuCryptoInit(WuHost* wu, const WuConf* conf) {
}

int32_t WuInit(WuHost* wu, const WuConf* conf) {
wu->arena = (WuArena*)calloc(1, sizeof(WuArena));
wu->arena = (WuArena*)calloc(1, sizeof(WuArena));
WuArenaInit(wu->arena, 1 << 20);
wu->time = MsNow() * 0.001;
wu->dt = 0.0;
Expand Down Expand Up @@ -825,22 +813,25 @@ int32_t WuServe(WuHost* wu, WuEvent* evt) {
}

if (MakeNonBlocking(infd) == -1) {
abort();
close(infd);
continue;
}

WuConnectionBuffer* conn = pool->GetBuffer();
assert(conn);
conn->fd = infd;

struct epoll_event event;
event.events = EPOLLIN | EPOLLET;
event.data.ptr = conn;
if (epoll_ctl(wu->epfd, EPOLL_CTL_ADD, infd, &event) == -1) {
perror("epoll_ctl");
abort();

if (conn) {
conn->fd = infd;
struct epoll_event event;
event.events = EPOLLIN | EPOLLET;
event.data.ptr = conn;
if (epoll_ctl(wu->epfd, EPOLL_CTL_ADD, infd, &event) == -1) {
close(infd);
perror("epoll_ctl");
}
} else {
close(infd);
}
}
continue;
} else if (wu->udpfd == c->fd) {
struct sockaddr_in remote;
socklen_t remoteLen = sizeof(remote);
Expand Down
2 changes: 1 addition & 1 deletion WuCrypto.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <stdint.h>
#include <stddef.h>
#include <stdint.h>

const size_t kSHA1Length = 20;

Expand Down
6 changes: 0 additions & 6 deletions WuDataChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
int32_t ParseDataChannelControlPacket(const uint8_t* buf, size_t len,
DataChannelPacket* packet) {
ReadScalarSwapped(buf, &packet->messageType);

/*
printf("data channel message %s (%u)\n",
DataChannelMessageTypeName(packet->messageType), packet->messageType);
*/

return 0;
}

Expand Down
10 changes: 4 additions & 6 deletions WuDataChannel.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
#pragma once

#include <stdint.h>
#include <stddef.h>
#include <stdint.h>

enum DataChannelMessageType {
DCMessage_Ack = 0x02,
DCMessage_Open = 0x03
};
enum DataChannelMessageType { DCMessage_Ack = 0x02, DCMessage_Open = 0x03 };

enum DataChanProtoIdentifier {
DCProto_Control = 50,
Expand All @@ -28,5 +25,6 @@ struct DataChannelPacket {
} as;
};

int32_t ParseDataChannelControlPacket(const uint8_t* buf, size_t len, DataChannelPacket* packet);
int32_t ParseDataChannelControlPacket(const uint8_t* buf, size_t len,
DataChannelPacket* packet);
const char* DataChannelMessageTypeName(uint8_t type);
15 changes: 15 additions & 0 deletions WuMath.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

template <typename T>
const T& Min(const T& a, const T& b) {
if (a < b) return a;

return b;
}

template <typename T>
const T& Max(const T& a, const T& b) {
if (a > b) return a;

return b;
}
2 changes: 1 addition & 1 deletion WuNetwork.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/types.h>
#include <assert.h>

enum SocketType { ST_TCP, ST_UDP };

Expand Down
2 changes: 1 addition & 1 deletion WuPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void* WuPoolAcquire(WuPool* pool) {
const int32_t offset = index * pool->slotSize;

uint8_t* block = pool->memory + offset;
BlockHeader* header = (BlockHeader*)block;
BlockHeader* header = (BlockHeader*)block;
header->index = index;

uint8_t* userMem = block + sizeof(BlockHeader);
Expand Down
5 changes: 2 additions & 3 deletions WuSctp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#include <arpa/inet.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include "CRC32.h"
#include "WuBufferOp.h"
#include "WuMath.h"
#include "WuNetwork.h"

const char* SctpTypeName(SctpChunkType type) {
Expand Down Expand Up @@ -39,7 +39,6 @@ const char* SctpTypeName(SctpChunkType type) {
int32_t ParseSctpPacket(const uint8_t* buf, size_t len, SctpPacket* packet,
SctpChunk* chunks, size_t maxChunks, size_t* nChunk) {
if (len < 16) {
printf("SCTP packet: invalid\n");
return 0;
}

Expand Down Expand Up @@ -68,7 +67,7 @@ int32_t ParseSctpPacket(const uint8_t* buf, size_t len, SctpPacket* packet,
chunkOffset +=
ReadScalarSwapped(buf + offset + chunkOffset, &p->streamSeq);
chunkOffset += ReadScalarSwapped(buf + offset + chunkOffset, &p->protoId);
p->userDataLength = std::max(int32_t(chunk->length) - 16, 0);
p->userDataLength = Max(int32_t(chunk->length) - 16, 0);
p->userData = buf + offset + chunkOffset;
} else if (chunk->type == Sctp_Sack) {
auto* sack = &chunk->as.sack;
Expand Down
1 change: 1 addition & 0 deletions WuSctp.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,4 @@ size_t SerializeSctpPacket(const SctpPacket* packet, const SctpChunk* chunks,

int32_t SctpDataChunkLength(int32_t userDataLength);
int32_t SctpChunkLength(int32_t contentLength);
const char* SctpTypeName(SctpChunkType type);
38 changes: 18 additions & 20 deletions WuSdp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,26 +135,24 @@ const char* GenerateSDP(WuArena* arena, const char* certFingerprint,
* a=sctpmap:{port} webrtc-datachannel {max-size}
*/
char buf[2048];
int written =
snprintf(buf, 2048,
"v=0\r\n"
"o=- %u 1 IN IP4 %s\r\n"
"s=-\r\n"
"t=0 0\r\n"
"m=application %s DTLS/SCTP %s\r\n"
"c=IN IP4 %s\r\n"
"a=ice-lite\r\n"
"a=ice-ufrag:%.*s\r\n"
"a=ice-pwd:%.*s\r\n"
"a=fingerprint:sha-256 %s\r\n"
"a=ice-options:trickle\r\n"
"a=setup:passive\r\n"
"a=mid:%.*s\r\n"
"a=sctpmap:%s webrtc-datachannel 1024\r\n",
WuRandomU32(), serverPort, serverIp, serverPort, serverIp,
ufragLen, ufrag, passLen, pass, certFingerprint,
remote->mid.length, remote->mid.value, serverPort);
(void)written;
snprintf(buf, 2048,
"v=0\r\n"
"o=- %u 1 IN IP4 %s\r\n"
"s=-\r\n"
"t=0 0\r\n"
"m=application %s DTLS/SCTP %s\r\n"
"c=IN IP4 %s\r\n"
"a=ice-lite\r\n"
"a=ice-ufrag:%.*s\r\n"
"a=ice-pwd:%.*s\r\n"
"a=fingerprint:sha-256 %s\r\n"
"a=ice-options:trickle\r\n"
"a=setup:passive\r\n"
"a=mid:%.*s\r\n"
"a=sctpmap:%s webrtc-datachannel 1024\r\n",
WuRandomU32(), serverPort, serverIp, serverPort, serverIp, ufragLen,
ufrag, passLen, pass, certFingerprint, remote->mid.length,
remote->mid.value, serverPort);

rjs::Document doc;
doc.SetObject();
Expand Down
3 changes: 2 additions & 1 deletion WuString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ int32_t FindTokenIndex(const char* s, size_t len, char token) {
return -1;
}

bool MemEqual(const void* first, size_t firstLen, const void* second, size_t secondLen) {
bool MemEqual(const void* first, size_t firstLen, const void* second,
size_t secondLen) {
if (firstLen != secondLen) return false;

return memcmp(first, second, firstLen) == 0;
Expand Down
5 changes: 3 additions & 2 deletions WuString.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

#define STRLIT(s) (s), sizeof(s) - 1

uint32_t StringToUint(const char*s, size_t len);
uint32_t StringToUint(const char* s, size_t len);
bool CompareCaseInsensitive(const char* first, size_t lenFirst,
const char* second, size_t lenSecond);
int32_t FindTokenIndex(const char* s, size_t len, char token);
bool MemEqual(const void* first, size_t firstLen, const void* second, size_t secondLen);
bool MemEqual(const void* first, size_t firstLen, const void* second,
size_t secondLen);
3 changes: 2 additions & 1 deletion WuStun.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ struct StunAddress {
} address;
};

inline bool StunUserIdentifierEqual(const StunUserIdentifier* a, const StunUserIdentifier* b) {
inline bool StunUserIdentifierEqual(const StunUserIdentifier* a,
const StunUserIdentifier* b) {
return MemEqual(a->identifier, a->length, b->identifier, b->length);
}

Expand Down

0 comments on commit b17d2a6

Please sign in to comment.