Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
#define CHECK_RESULTS
#define _CRT_SECURE_NO_WARNINGS
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <math.h>
#include <time.h>
#include <vector>
// All probabilities handed to the coder are fixed-point fractions with
// kProbBits fractional bits, so kProbMax is the representation of 1.0.
// A modest resolution is plenty here.
static const int kProbBits = 15;
static const uint32_t kProbMax = uint32_t(1) << kProbBits;
// Byte buffer type: the encoder appends to one, the decoder reads from one.
typedef std::vector<uint8_t> ByteVec;
// Binary arithmetic encoder (Ilya Muravyov's variant).
// Turns a sequence of binary (0/1) events, each with its own probability,
// into bytes appended to a caller-supplied buffer.
//
// Clarity is the goal of this implementation, not speed.
class BinArithEncoder
{
    uint32_t lo, hi;  // current coding interval [lo, hi], both ends inclusive
    ByteVec &bytes;   // output buffer; finished bytes are appended here

    // Not copyable: declared, never defined.
    BinArithEncoder(BinArithEncoder const &);
    BinArithEncoder &operator =(BinArithEncoder const &);

    // Appends the top byte of lo to the output and shifts it out of lo.
    void emit_top_byte()
    {
        bytes.push_back(uint8_t(lo >> 24));
        lo <<= 8;
    }

public:
    // Starts encoding into "target" with the full 32-bit interval.
    explicit BinArithEncoder(ByteVec &target) : lo(0), hi(~0u), bytes(target) { }

    // Finishes encoding: writes out whatever is still needed to identify
    // a value inside the final interval.
    ~BinArithEncoder()
    {
        // Look for the shortest tail. The decoder substitutes zero bytes once
        // it runs past the end of the buffer, so any value in [lo, hi] whose
        // low bytes are zero lets us drop those bytes. Try clearing 3 low
        // bytes first, then 2, then 1.
        for (uint32_t low_mask = 0xffffffu; low_mask != 0; low_mask >>= 8)
        {
            // All bits above the mask already set: rounding up could wrap
            // around 32 bits, so this granularity is not usable.
            if ((lo | low_mask) == ~0u)
                continue;

            uint32_t const candidate = (lo + low_mask) & ~low_mask;
            if (candidate <= hi) // still inside the interval, take it
            {
                lo = candidate;
                break;
            }
        }

        // Emit bytes until only (implicit) zeros remain.
        while (lo != 0)
            emit_top_byte();
    }

    // Encodes one binary symbol "bit", where "prob" is the probability of a 1
    // in kProbBits fixed point.
    // prob=0 means a 1 cannot occur and prob=1<<kProbBits means a 0 cannot
    // occur; do not code the impossible symbol.
    void encode(int bit, uint32_t prob)
    {
        // Split point of the interval: [lo, split] codes a 1,
        // [split + 1, hi] codes a 0.
        uint64_t const scaled = (uint64_t(hi - lo) * prob) >> kProbBits;
        uint32_t const split = uint32_t(lo + scaled);

        if (bit != 0)
            hi = split;
        else
            lo = split + 1;

        // Renormalize: as long as lo and hi agree in their top byte, that
        // byte is final. Send it and widen the interval by 8 bits.
        while (((lo ^ hi) >> 24) == 0)
        {
            emit_top_byte();
            hi = (hi << 8) | 0xffu;
        }
    }
};
// Decoder matching BinArithEncoder. Must be driven with exactly the same
// sequence of probabilities that the encoder was given.
class BinArithDecoder
{
    uint32_t code, lo, hi;  // code: current 32-bit window of the stream; [lo, hi]: interval
    ByteVec const &bytes;   // input buffer
    size_t read_pos, size;  // next byte to read / total number of input bytes

    // Not copyable: declared, never defined.
    BinArithDecoder(BinArithDecoder const &);
    BinArithDecoder &operator =(BinArithDecoder const &);

    // Next input byte; reads past the end of the buffer yield 0.
    uint8_t getb()
    {
        return (read_pos < size) ? bytes[read_pos++] : uint8_t(0);
    }

    // Shifts one more input byte into the bottom of the code window.
    void pull_byte()
    {
        code = (code << 8) | getb();
    }

public:
    // Starts decoding from "source": primes the code window with the
    // first four bytes of the stream.
    explicit BinArithDecoder(ByteVec const &source)
        : code(0), lo(0), hi(~0u), bytes(source), read_pos(0), size(source.size())
    {
        int remaining = 4;
        while (remaining-- > 0)
            pull_byte();
    }

    // Decodes one binary symbol, where "prob" is the probability of a 1
    // in kProbBits fixed point. Returns the decoded bit (0 or 1).
    int decode(uint32_t prob)
    {
        // Same interval split as in the encoder.
        uint64_t const scaled = (uint64_t(hi - lo) * prob) >> kProbBits;
        uint32_t const split = uint32_t(lo + scaled);

        // The lower part of the interval, [lo, split], belongs to a 1.
        bool const is_one = (code <= split);
        if (is_one)
            hi = split;
        else
            lo = split + 1;

        // Renormalize in lockstep with the encoder, pulling in one new
        // stream byte per top byte that is shifted out.
        while (((lo ^ hi) >> 24) == 0)
        {
            pull_byte();
            lo <<= 8;
            hi = (hi << 8) | 0xffu;
        }

        return is_one ? 1 : 0;
    }
};
// ---- A few basic models
// NOTE: Again, this is written for clarity and ease of tinkering.
// In practice, you will write more direct code for these once you've
// figured out your coding structure.
// Adaptive binary model: a single probability estimate that is nudged
// towards each symbol seen. These work well in practice.
// A smaller Inertia adapts faster.
//
// More elaborate models are usually assembled from many of these,
// with the one to use picked from context.
template<int Inertia>
struct BinShiftModel
{
    uint16_t prob; // current estimate of P(bit == 1), kProbBits fixed point

    // Starts out at probability 1/2.
    BinShiftModel() : prob(kProbMax / 2) {}

    // Codes "bit" with the current estimate, then updates the estimate.
    void encode(BinArithEncoder &enc, int bit)
    {
        enc.encode(bit, prob);
        adapt(bit);
    }

    // Decodes a bit with the current estimate, then updates the estimate.
    int decode(BinArithDecoder &dec)
    {
        int const bit = dec.decode(prob);
        adapt(bit);
        return bit;
    }

    // Moves the estimate a 1/2^Inertia fraction of the way towards the
    // extreme matching "bit".
    void adapt(int bit)
    {
        // Because the step is a truncated fraction of the remaining distance,
        // prob never actually reaches 0 or kProbMax (for Inertia >= 1).
        if (bit != 0)
        {
            uint32_t const gap = kProbMax - prob;
            prob = uint16_t(prob + (gap >> Inertia));
        }
        else
        {
            prob = uint16_t(prob - (prob >> Inertia));
        }
    }
};
// Adaptive binary model built from two estimators that adapt at different
// rates (Inertia0 and Inertia1). Each estimator covers half of the
// probability range; the probability used for coding is their sum.
template<int Inertia0, int Inertia1>
struct TwoBinShiftModel
{
    uint16_t p0, p1; // the two half-range estimates; P(bit == 1) = p0 + p1

    // Both halves start at a quarter, i.e. a combined probability of 1/2.
    TwoBinShiftModel() : p0(kProbMax / 4), p1(kProbMax / 4) {}

    // Codes "bit" with the combined estimate, then updates both halves.
    void encode(BinArithEncoder &enc, int bit)
    {
        enc.encode(bit, p0 + p1);
        adapt(bit);
    }

    // Decodes a bit with the combined estimate, then updates both halves.
    int decode(BinArithDecoder &dec)
    {
        int const bit = dec.decode(p0 + p1);
        adapt(bit);
        return bit;
    }

    // Moves each half towards its own extreme (kProbMax/2 for a 1, 0 for
    // a 0) by a 1/2^Inertia fraction of the remaining distance.
    void adapt(int bit)
    {
        // The truncated steps mean neither half reaches 0 or kProbMax/2 (for
        // inertias >= 1), so the sum never hits 0 or kProbMax.
        uint32_t const half = kProbMax / 2;
        if (bit != 0)
        {
            p0 = uint16_t(p0 + ((half - p0) >> Inertia0));
            p1 = uint16_t(p1 + ((half - p1) >> Inertia1));
        }
        else
        {
            p0 = uint16_t(p0 - (p0 >> Inertia0));
            p1 = uint16_t(p1 - (p1 >> Inertia1));
        }
    }
};
// BitTree model: codes a NumBits-bit value one bit at a time, MSB first,
// picking the binary model for each bit from all the bits coded before it.
// This is the usual way of building a multi-symbol coder out of binary
// models.
//
// LZMA (as in 7zip/xz) codes its literals with this kind of model, on top
// of a BinShiftModel-style bit model.
template<typename BitModel, int NumBits>
struct BitTreeModel
{
    static size_t const kNumSyms = 1 << NumBits;
    static size_t const kMSB = kNumSyms / 2;

    // One model per inner node of the binary tree:
    // 1 + 2 + 4 + ... = 2^NumBits - 1 of them.
    BitModel model[kNumSyms - 1];

    // Encodes "value", which must be below kNumSyms.
    void encode(BinArithEncoder &enc, size_t value)
    {
        assert(value < kNumSyms);

        // Nodes are numbered from 1 (the root, used for the MSB, which has no
        // context). Descending appends the coded bit at the bottom of the
        // node number, so node "n" lives at model[n - 1].
        size_t node = 1;
        for (int shift = NumBits - 1; shift >= 0; --shift)
        {
            int const bit = int((value >> shift) & 1);
            model[node - 1].encode(enc, bit);
            node = 2 * node + bit;
        }
    }

    // Decodes one value in [0, kNumSyms).
    size_t decode(BinArithDecoder &dec)
    {
        // Walk down from the root; once the node number reaches kNumSyms it
        // is the decoded value with an extra leading 1 bit in front.
        size_t node = 1;
        while (node < kNumSyms)
        {
            size_t const bit = size_t(model[node - 1].decode(dec));
            node = (node << 1) + bit;
        }
        return node - kNumSyms;
    }
};
// Unsigned exponential Golomb-style model.
//
// A value v is coded as v + 1, split into:
//   - the magnitude m = position of the highest set bit of v + 1 (0..31),
//     coded with the 5-bit bit-tree model "mag";
//   - the m bits below that leading 1, MSB first. The first of them uses
//     the adaptive model top[min(m, kMaxTop)]; the rest are sent with a
//     fixed probability of 1/2.
//
// Every value except 0xffffffff can be coded (v + 1 has to fit in 32 bits).
template<typename MagModel>
struct UExpGolombModel
{
    static size_t const kMaxTop = 7;
    BitTreeModel<MagModel, 5> mag;
    MagModel top[kMaxTop + 1];

    // Encodes "value" (any uint32_t except 0xffffffff).
    void encode(BinArithEncoder &enc, uint32_t value)
    {
        // value + 1 would wrap to 0 and silently be coded as a 0.
        assert(value != 0xffffffffu);
        ++value; // shift the range to start at 1 so there is always a leading 1 bit

        // Determine the magnitude (position of the highest 1 bit).
        // The m < 31 bound matters: at m == 31 the threshold 2u << m wraps
        // to 0, which would push m to 32 and into an out-of-range shift.
        // bitscan is the better way to do this.
        uint32_t m = 0;
        while (m < 31 && value >= (2u << m))
            ++m;
        mag.encode(enc, m);

        // Send the bits below the leading 1, MSB->LSB.
        uint32_t mask = m ? 1u << (m - 1) : 0;
        if (mask)
        {
            // First bit below the leading 1: adaptive, one model per
            // magnitude (magnitudes above kMaxTop share the last model).
            uint32_t mtop = (m < kMaxTop) ? m : kMaxTop;
            top[mtop].encode(enc, (value & mask) != 0);
            mask >>= 1;

            // Remaining bits: flat, probability 1/2 each.
            while (mask)
            {
                uint32_t bit = (value & mask) != 0;
                enc.encode(bit, kProbMax / 2);
                mask >>= 1;
            }
        }
    }

    // Decodes one value; exact mirror of encode().
    uint32_t decode(BinArithDecoder &dec)
    {
        // decode magnitude code
        uint32_t m = (uint32_t) mag.decode(dec);

        // decode value bits, starting from the implicit leading 1
        uint32_t v = 1;
        if (m)
        {
            uint32_t mtop = (m < kMaxTop) ? m : kMaxTop;
            v += v + top[mtop].decode(dec);
            for (uint32_t i = 1; i < m; ++i)
                v += v + dec.decode(kProbMax / 2);
        }
        return v - 1; // undo the +1 applied by the encoder
    }
};
// Signed exponential Golomb-style model.
//
// Codes the absolute value with a UExpGolombModel and, for non-zero values
// only, a sign bit (1 = negative). The sign bit uses one of two models,
// selected by the caller-supplied prediction "sign_pred".
template<typename MagModel, typename SignModel>
struct SExpGolombModel
{
    UExpGolombModel<MagModel> abs_coder;
    SignModel sign[2]; // indexed by sign_pred

    // Encodes "value"; "sign_pred" must match what the decoder will pass.
    void encode(BinArithEncoder &enc, int32_t value, bool sign_pred)
    {
        bool const negative = value < 0;
        // Take the absolute value in unsigned arithmetic: negating the most
        // negative int32_t as a signed number would overflow.
        uint32_t const absv = negative ? (0u - uint32_t(value)) : uint32_t(value);
        abs_coder.encode(enc, absv);
        if (absv)
            sign[sign_pred].encode(enc, negative);
    }

    // Decodes one value; "sign_pred" must match what the encoder passed.
    int32_t decode(BinArithDecoder &dec, bool sign_pred)
    {
        uint32_t absv = abs_coder.decode(dec);
        // Zero has no sign bit in the stream. As in encode(), negate in
        // unsigned arithmetic to stay clear of signed overflow.
        if (absv && sign[sign_pred].decode(dec))
            absv = 0u - absv;
        return int32_t(absv);
    }
};
// ---- Data format
// Number of cubes in every frame; sizes the per-frame cube and prediction arrays.
static const int kNumCubes = 901;
// Each frame is coded relative to the frame this many frames before it.
static const int kRefDist = 6; // distance to reference frame
// Frames per second; only used to turn bytes/frame into a kbps figure.
static const int kFrameRate = 60; // just used to calc kbps
// State of one cube in one frame. Arrays of these are read from and written
// to the data files as raw memory, so the member order and types are the
// file format.
struct CubeState
{
// Index (0..3) of the quaternion component that is not stored; it is coded
// as a 2-bit value.
int orientation_largest;
// The three stored orientation components. The coder treats 256 as their
// midpoint (NOTE(review): presumably quantized with 256 meaning zero — confirm
// against the data producer).
int orientation[3];
// Integer position, one value per axis.
int position[3];
// Interaction flag; used as part of an array index, so expected to be 0 or 1.
int interacting;
};
// Prediction state. Not sent in the stream; inferred from the data
// sent to aid coding.
// One of these is kept per cube and frame; both encoder and decoder derive
// it identically so that a later frame can use it to pick coding contexts.
struct PredState
{
// 1 if the cube's orientation or position differed from its reference frame, else 0.
int changing;
// Per stored component: orientation value minus the reference frame's value.
int orient_delta[3];
// Per axis: position minus the reference frame's position.
int vel[3];
};
// All adaptive models used to code a frame. The trailing comment on each
// array describes how it is indexed. A frame starts from a copy of its
// reference frame's ModelSet and keeps adapting from there.
struct ModelSet
{
// Number of magnitude contexts produced by mag_context() (0..kNumMagCtx-1).
static const int kNumMagCtx = 10;
// Bit model used throughout: two-rate adaptive model with inertias 3 and 7.
typedef TwoBinShiftModel<3, 7> DefaultBit;
typedef SExpGolombModel<DefaultBit, DefaultBit> SExpGolomb;
// Flag: did anything about the orientation change?
DefaultBit orientation_different[4]; // [refp.changing + 2*close_to_cube0]
// 2-bit index of the largest component; 16 contexts from orient_context().
BitTreeModel<DefaultBit, 2> orientation_largest[4*4]; // [orient_context]
// Orientation residuals. The extra slot at index kNumMagCtx is used for the
// component that was the largest one in the reference frame.
SExpGolomb orientation_delta[kNumMagCtx + 1]; // [mag_ctx]
// Flag: prediction mirrored? Indexed by whether the predicted value of the
// new largest component is below 256.
DefaultBit orientation_signflip[2]; // [second_largest_sign]
// Flag: did the position change?
DefaultBit pos_different[8]; // [orientation_differs + 2*orient_diff_ctx]
// Position residuals (change in per-frame-pair position delta).
SExpGolomb pos_delta[kNumMagCtx]; // [mag_ctx]
// The interacting flag itself.
DefaultBit interacting[4]; // [ref.interacting + 2*any_diff]
};
// Everything belonging to one frame: the cube data itself plus the state
// the coder leaves behind for frames that use this one as their reference.
struct Frame
{
    CubeState cubes[kNumCubes]; // the frame's actual data
    ModelSet models;            // coding state (adaptive models)
    PredState pred[kNumCubes];  // prediction state, one entry per cube

    Frame();
};

// A new frame has all cube and prediction data zeroed; the models are
// default-constructed.
Frame::Frame()
{
    memset(&pred[0], 0, sizeof pred);
    memset(&cubes[0], 0, sizeof cubes);
}
// Maps an index into the three stored components (0..2) to the index of the
// same component among all four, given which one ("largest") is left out.
static int xyzw_from_abc(int abc_ind, int largest)
{
    // Components at or after the omitted one sit one slot further up.
    if (abc_ind >= largest)
        return abc_ind + 1;
    return abc_ind;
}
// Inverse of xyzw_from_abc: maps an index among all four components to its
// index among the three stored ones. Must not be called for the omitted
// component itself.
static int abc_from_xyzw(int xyzw_ind, int largest)
{
    assert(xyzw_ind != largest);
    // Components after the omitted one move down by one slot.
    return (xyzw_ind < largest) ? xyzw_ind : xyzw_ind - 1;
}
// Context (0..15) for coding a cube's largest-component index, derived from
// the reference state "cube": the current largest index, plus which stored
// component is the runner-up if that runner-up is already big.
static int orient_context(CubeState const *cube)
{
    // The largest component is not stored, so look for the biggest of the
    // stored ones, measured as distance from the midpoint 256. On ties the
    // lowest index wins.
    int dist[3];
    int runner_up = 0;
    for (int i = 0; i < 3; ++i)
    {
        dist[i] = abs(cube->orientation[i] - 256);
        if (dist[i] > dist[runner_up])
            runner_up = i;
    }

    // Only distinguish by runner-up once it is getting close to taking over
    // as the largest component; otherwise use the plain context.
    int const bucket = (dist[runner_up] >= 128) ? runner_up + 1 : 0;
    return cube->orientation_largest + 4 * bucket;
}
// Magnitude context for a previously seen delta: the number of times |dv|
// can be halved before it drops to 1 or below, capped at kNumMagCtx - 1.
// Deltas of 0 and +-1 give context 0.
static int mag_context(int dv)
{
    int ctx = 0;
    for (int v = abs(dv); v > 1 && ctx < ModelSet::kNumMagCtx - 1; v /= 2)
        ++ctx;
    return ctx;
}
// Picks the residual context and the sign prediction for component "axis"
// (index among all four) when the largest component has changed.
//   old_largest: largest-component index in the reference frame
//   refp:        prediction state of the reference frame
//   sign_bit:    receives the sign prediction
//   flipped:     whether the prediction has been mirrored
// Returns the index into ModelSet::orientation_delta.
static int orient_newmag_context(int axis, int old_largest, PredState const *refp, bool *sign_bit, bool flipped)
{
    // The component that used to be the largest has no delta history;
    // it gets the dedicated extra context.
    if (axis == old_largest)
    {
        *sign_bit = flipped;
        return ModelSet::kNumMagCtx;
    }

    // Otherwise go by the delta this component showed in the reference
    // frame; mirroring inverts the predicted sign.
    int const prev_delta = refp->orient_delta[abc_from_xyzw(axis, old_largest)];
    *sign_bit = (prev_delta < 0) != flipped;
    return mag_context(prev_delta);
}
// Expands three stored components "src" into a four-component prediction
// "dest", where "largest" is the index of the component that was left out.
static void unpack_quat_prediction(int dest[4], int const src[3], int largest)
{
    for (int abc = 0; abc != 3; ++abc)
    {
        int const slot = xyzw_from_abc(abc, largest);
        dest[slot] = src[abc];
    }
    // The omitted component gets a fixed stand-in value.
    // NOTE(review): 450 is presumably a typical value for the largest
    // component in this quantization — confirm.
    dest[largest] = 450;
}
// True if the two cubes are less than 2048 position units apart on every
// axis.
static bool are_close(CubeState const *a, CubeState const *b)
{
    bool close = true;
    for (int axis = 0; close && axis < 3; ++axis)
        close = abs(a->position[axis] - b->position[axis]) < 2048;
    return close;
}
// Encodes frame "cur" relative to reference frame "ref" and appends the
// resulting bytes to "dest".
//
// Besides producing output this updates "cur": cur->models becomes a copy of
// ref->models that has adapted to this frame, and cur->pred is filled in for
// every cube, so that "cur" can serve as the reference for a later frame.
// The last bytes of the stream are written when "coder" is destroyed at the
// end of the function.
//
// Everything here must stay in exact lockstep with decode_frame().
static void encode_frame(ByteVec &dest, Frame *cur, Frame const *ref)
{
BinArithEncoder coder(dest);
ModelSet &m = cur->models;
// Start with ref frame models
m = ref->models;
for (int cube_id = 0; cube_id < kNumCubes; ++cube_id)
{
CubeState *cube = &cur->cubes[cube_id];
PredState *pred = &cur->pred[cube_id];
CubeState const *refc = &ref->cubes[cube_id];
PredState const *refp = &ref->pred[cube_id];
// Compute the deltas against the reference. diff_orient / diff_pos end up
// non-zero exactly if anything about the orientation / position changed.
int diff_orient = (cube->orientation_largest != refc->orientation_largest), diff_pos = 0;
for (int i = 0; i < 3; ++i)
{
pred->orient_delta[i] = cube->orientation[i] - refc->orientation[i];
pred->vel[i] = cube->position[i] - refc->position[i];
diff_orient |= pred->orient_delta[i];
diff_pos |= pred->vel[i];
}
// Context for the "orientation changed" flag: bit 0 = the cube was already
// changing in the reference frame, bit 1 = the cube was near cube 0 in the
// reference frame.
int orient_diff_ctx = refp->changing;
if (are_close(refc, &ref->cubes[0]))
orient_diff_ctx |= 2;
m.orientation_different[orient_diff_ctx].encode(coder, diff_orient);
if (diff_orient)
{
int orient_ctx = orient_context(refc);
m.orientation_largest[orient_ctx].encode(coder, cube->orientation_largest);
if (cube->orientation_largest == refc->orientation_largest)
{
// Same component omitted as in the reference: send the three deltas
// directly. Context and sign prediction come from the delta the same
// component had in the reference frame.
for (int i = 0; i < 3; ++i)
{
int ctx = mag_context(refp->orient_delta[i]);
m.orientation_delta[ctx].encode(coder, pred->orient_delta[i], refp->orient_delta[i] < 0);
}
}
else
{
// The omitted component changed: rebuild a four-component prediction
// from the reference and code the new components against that.
int old_largest = refc->orientation_largest;
int new_largest = cube->orientation_largest;
int old[4];
bool flipped = false;
unpack_quat_prediction(old, refc->orientation, old_largest);
int sign_context = old[new_largest] < 256;
// If the previously largest component is now stored below the midpoint
// 256, mirror the whole prediction around 256 and tell the decoder.
if (cube->orientation[abc_from_xyzw(old_largest, new_largest)] < 256)
{
flipped = true;
m.orientation_signflip[sign_context].encode(coder, 1);
for (int i = 0; i < 4; ++i)
old[i] = 512 - old[i];
}
else
m.orientation_signflip[sign_context].encode(coder, 0);
// Residual of each stored component against the (possibly mirrored)
// prediction for the same axis.
for (int i = 0; i < 3; ++i)
{
int axis = xyzw_from_abc(i, new_largest);
bool sign_bit;
int ctx = orient_newmag_context(axis, old_largest, refp, &sign_bit, flipped);
m.orientation_delta[ctx].encode(coder, cube->orientation[i] - old[axis], sign_bit);
}
}
}
// "Position changed" flag, with the orientation flag as extra context.
m.pos_different[(diff_orient != 0) + orient_diff_ctx*2].encode(coder, diff_pos != 0);
if (diff_pos)
{
// Positions are coded as the change of the position delta relative to the
// delta recorded for the reference frame.
for (int i = 0; i < 3; ++i)
{
int ctx = mag_context(refp->vel[i]);
m.pos_delta[ctx].encode(coder, pred->vel[i] - refp->vel[i], refp->vel[i] < 0);
}
}
// The interacting flag is always sent; context is the reference's flag and
// whether anything changed for this cube.
m.interacting[refc->interacting + ((diff_orient | diff_pos) ? 2 : 0)].encode(coder, cube->interacting);
// NOTE: in general, we would need to account for variable frame
// spacing here. But in this testbed we always predict from 6 frames
// ago, so no problem.
pred->changing = (diff_orient | diff_pos) != 0;
}
}
// Decodes one frame from "src" into "cur", using "ref" as reference frame.
//
// Fills in cur->cubes, and also cur->models (a copy of ref->models adapted
// to this frame) and cur->pred, mirroring what encode_frame() leaves behind.
// "ref" must be the same reference, with the same models and prediction
// state, that the encoder used.
static void decode_frame(ByteVec const &src, Frame *cur, Frame const *ref)
{
BinArithDecoder coder(src);
ModelSet &m = cur->models;
// Start with ref frame models
m = ref->models;
for (int cube_id = 0; cube_id < kNumCubes; ++cube_id)
{
CubeState *cube = &cur->cubes[cube_id];
PredState *pred = &cur->pred[cube_id];
CubeState const *refc = &ref->cubes[cube_id];
PredState const *refp = &ref->pred[cube_id];
bool diff_orient = false;
// Same context as in the encoder: bit 0 = cube was changing in the
// reference frame, bit 1 = cube was near cube 0 in the reference frame.
int orient_diff_ctx = refp->changing;
if (are_close(refc, &ref->cubes[0]))
orient_diff_ctx |= 2;
if (m.orientation_different[orient_diff_ctx].decode(coder))
{
diff_orient = true;
int orient_ctx = orient_context(refc);
cube->orientation_largest = (int) m.orientation_largest[orient_ctx].decode(coder);
if (cube->orientation_largest == refc->orientation_largest)
{
// Same component omitted: three deltas on top of the reference values.
for (int i = 0; i < 3; ++i)
{
int ctx = mag_context(refp->orient_delta[i]);
cube->orientation[i] = refc->orientation[i] + m.orientation_delta[ctx].decode(coder, refp->orient_delta[i] < 0);
}
}
else
{
// Omitted component changed: rebuild the four-component prediction,
// mirror it around 256 if the stream says so, then add the residuals.
int old_largest = refc->orientation_largest;
int new_largest = cube->orientation_largest;
int old[4];
unpack_quat_prediction(old, refc->orientation, old_largest);
bool flipped = false;
if (m.orientation_signflip[old[new_largest] < 256].decode(coder))
{
flipped = true;
for (int i = 0; i < 4; ++i)
old[i] = 512 - old[i];
}
for (int i = 0; i < 3; ++i)
{
int axis = xyzw_from_abc(i, new_largest);
bool sign_bit;
int ctx = orient_newmag_context(axis, old_largest, refp, &sign_bit, flipped);
cube->orientation[i] = m.orientation_delta[ctx].decode(coder, sign_bit) + old[axis];
}
}
}
else
{
// Orientation unchanged: copy it from the reference.
cube->orientation_largest = refc->orientation_largest;
for (int i = 0; i < 3; ++i)
cube->orientation[i] = refc->orientation[i];
}
// Record the orientation deltas, and default the position deltas to zero
// (they stay zero if the position turns out to be unchanged).
for (int i = 0; i < 3; ++i)
{
pred->orient_delta[i] = cube->orientation[i] - refc->orientation[i];
pred->vel[i] = 0;
}
bool diff_pos = false;
if (m.pos_different[diff_orient + orient_diff_ctx*2].decode(coder))
{
diff_pos = true;
// The stream holds the change of the position delta relative to the delta
// recorded for the reference frame.
for (int i = 0; i < 3; ++i)
{
int ctx = mag_context(refp->vel[i]);
pred->vel[i] = refp->vel[i] + m.pos_delta[ctx].decode(coder, refp->vel[i] < 0);
}
}
for (int i = 0; i < 3; ++i)
cube->position[i] = refc->position[i] + pred->vel[i];
// The interacting flag is always present in the stream.
cube->interacting = m.interacting[refc->interacting + ((diff_orient | diff_pos) ? 2 : 0)].decode(coder);
pred->changing = (int(diff_orient) | int(diff_pos)) != 0;
}
}
// ---- I/O and main
// Loads a raw data file: a sequence of frames, each consisting of kNumCubes
// CubeState records.
//
//   filename:   file to read
//   num_frames: receives the number of frames following the first one
//   initial:    receives the cubes of the first frame in the file
//
// Returns a new[]-allocated array of num_frames frames holding the remaining
// frames; the caller releases it with delete[]. Any trailing partial frame
// in the file is ignored. On any error a message is printed and the process
// exits with status 1.
static Frame *read_data(char const *filename, int &num_frames, Frame *initial)
{
    FILE *f = fopen(filename, "rb");
    if (!f)
    {
        printf("data missing!\n");
        exit(1);
    }

    // Work out the file size by seeking to the end and back.
    long file_size = -1;
    if (fseek(f, 0, SEEK_END) == 0)
        file_size = ftell(f);
    if (file_size < 0 || fseek(f, 0, SEEK_SET) != 0)
    {
        printf("error determining data size!\n");
        exit(1);
    }

    // The file has to hold at least the initial frame; without this check a
    // short file would produce a negative frame count.
    size_t const cubes_per_frame = kNumCubes;
    size_t const frame_bytes = cubes_per_frame * sizeof(CubeState);
    size_t const total_frames = size_t(file_size) / frame_bytes;
    if (total_frames < 1)
    {
        printf("error reading initial frame!\n");
        exit(1);
    }
    num_frames = int(total_frames - 1);

    Frame *frames = new Frame[num_frames];

    // read initial frame cubes
    if (fread(initial->cubes, sizeof(CubeState), cubes_per_frame, f) != cubes_per_frame)
    {
        printf("error reading initial frame!\n");
        exit(1);
    }

    for (int frame = 0; frame < num_frames; ++frame)
    {
        if (fread(frames[frame].cubes, sizeof(CubeState), cubes_per_frame, f) != cubes_per_frame)
        {
            printf("error reading frame %d!\n", frame);
            exit(1);
        }
    }

    fclose(f);
    return frames;
}
// Writes a raw data file in the layout read_data() expects: the cubes of
// "initial" followed by the cubes of each of the "num_frames" entries of
// "frames". If the file cannot be created or fully written, a message is
// printed and the process exits with status 1.
static void write_data(char const *filename, Frame *frames, int num_frames, Frame const *initial)
{
    FILE *f = fopen(filename, "wb");
    if (!f)
    {
        printf("error writing output!\n");
        exit(1);
    }

    size_t const cubes_per_frame = kNumCubes;
    bool ok = fwrite(initial->cubes, sizeof(CubeState), cubes_per_frame, f) == cubes_per_frame;
    for (int frame = 0; ok && frame < num_frames; ++frame)
        ok = fwrite(frames[frame].cubes, sizeof(CubeState), cubes_per_frame, f) == cubes_per_frame;

    // Buffered data is only flushed on close, so a failing fclose means the
    // file is incomplete as well.
    if (fclose(f) != 0)
        ok = false;

    if (!ok)
    {
        printf("error writing output!\n");
        exit(1);
    }
}
// Testbed driver: loads the data set, encodes every frame from frame 6 on
// against the frame kRefDist frames earlier, decodes it again right away
// (and, with CHECK_RESULTS defined, verifies the round trip), prints timing
// and size statistics, and finally writes the frames back out.
// Returns 0 on success and 1 on a decode mismatch.
int main()
{
    Frame initial_frame;

    // Read the data
    printf("reading...\n");
    int num_frames;
    Frame *frames = read_data("delta_data_realnew.bin", num_frames, &initial_frame);
    printf("done.\n");

    // Coding loop
    ByteVec packet_buf;
    size_t packet_size_sum = 0;
    size_t packet_count = 0;
    Frame out;
    clock_t enc_start = clock();

    // Gaffer says skip the first 6 frames. Okay.
    for (int frame = 6; frame < num_frames; ++frame)
    {
        Frame *cur = &frames[frame];
        Frame *ref = (frame >= kRefDist) ? &frames[frame - kRefDist] : &initial_frame;

        packet_buf.clear();
        encode_frame(packet_buf, cur, ref);
        decode_frame(packet_buf, &out, ref);

#ifdef CHECK_RESULTS
        if (memcmp(out.cubes, cur->cubes, sizeof(out.cubes)) != 0)
        {
            printf("decode mismatch on frame %d\n", frame);
            delete[] frames; // don't leak the data set on the error path
            return 1;
        }
#endif

        packet_size_sum += packet_buf.size();
        ++packet_count;
    }

    double enc_time = double(clock() - enc_start) / CLOCKS_PER_SEC;

    if (packet_count == 0)
    {
        // Nothing was coded, so per-frame figures would be divisions by zero.
        printf("no frames encoded (data has only %d frames)\n", num_frames);
    }
    else
    {
        printf("processing took %.2fs (%.2fus/frame)\n", enc_time, 1e6*enc_time / (double)packet_count);
        printf("total packed size %llu\n", (unsigned long long)packet_size_sum);

        double bytes_per_frame = (double)packet_size_sum / (double)packet_count;
        double kbps = bytes_per_frame * kFrameRate * 8.0 / 1000.0;
        printf("%.2f bytes/frame\n", bytes_per_frame);
        printf("%.2f kbps\n", kbps);
    }

    // Write output
    write_data("output.bin", frames, num_frames, &initial_frame);

    // Clean up
    delete[] frames;
    return 0;
}
// vim:et:sts=4:sw=4