Skip to content

Commit

Permalink
mp4: Improved stsz handling
Browse files Browse the repository at this point in the history
Track size/count instead of just sizes which should decrease memory usage
making count sanity check unnecessary.

Fixes issue with huge mp4 files (80gb+) with lots of samples.
  • Loading branch information
wader committed Jan 15, 2022
1 parent 8092151 commit 4a1e859
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 38 deletions.
17 changes: 9 additions & 8 deletions format/mp4/boxes.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,18 +536,19 @@ func init() {
d.FieldArrayLoop("entries", func() bool { return i < entryCount }, func(d *decode.D) {
size := uint32(d.FieldU32("size"))
if ctx.currentTrack != nil {
ctx.currentTrack.stsz = append(ctx.currentTrack.stsz, size)
ctx.currentTrack.stsz = append(ctx.currentTrack.stsz, stsz{
size: size,
count: 1,
})
}
i++
})
} else {
if ctx.currentTrack != nil {
if entryCount > maxSampleEntryCount {
d.Errorf("too many constant stsz entries %d > %d", entryCount, maxSampleEntryCount)
}
for i := uint64(0); i < entryCount; i++ {
ctx.currentTrack.stsz = append(ctx.currentTrack.stsz, uint32(sampleSize))
}
ctx.currentTrack.stsz = append(ctx.currentTrack.stsz, stsz{
size: uint32(sampleSize),
count: uint32(entryCount),
})
}
}
},
Expand Down Expand Up @@ -767,7 +768,7 @@ func init() {
}

if sampleCount > maxSampleEntryCount {
d.Errorf("too many constant trun entries %d > %d", sampleCount, maxSampleEntryCount)
d.Errorf("too many sample trun entries %d > %d", sampleCount, maxSampleEntryCount)
}

d.FieldArray("samples", func(d *decode.D) {
Expand Down
70 changes: 40 additions & 30 deletions format/mp4/mp4.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,18 @@ type sampleDescription struct {
originalFormat string
}

type stsz struct {
size uint32
count uint32
}

type track struct {
id uint32
sampleDescriptions []sampleDescription
subType string
stco []uint64 //
stsc []stsc
stsz []uint32
stsz []stsz
formatInArg interface{}
objectType int // if data format is "mp4a"

Expand Down Expand Up @@ -205,48 +210,53 @@ func mp4Decode(d *decode.D, in interface{}) interface{} {
}
}

// TODO: warning if unused stsc/stco entries?
d.FieldArray("samples", func(d *decode.D) {
stscIndex := 0
chunkNr := uint32(0)
sampleNr := uint64(0)
stszIndex := 0
stcoIndex := 0
sizeNr := 0

for stszIndex < len(t.stsz) {
stszEntry := t.stsz[stszIndex]

for sampleNr < uint64(len(t.stsz)) {
if stscIndex >= len(t.stsc) {
// TODO: add warning
break
// TODO: outside sample-to-chunk table, add warning
return
}
stscEntry := t.stsc[stscIndex]
if int(chunkNr) >= len(t.stco) {
// TODO: add warning
break
if stcoIndex >= len(t.stco) {
// TODO: outside sample-chunk-offset table, add warning
return
}
sampleOffset := t.stco[chunkNr]
sampleOffset := t.stco[stcoIndex]

for i := uint32(0); i < stscEntry.samplesPerChunk; i++ {
if int(sampleNr) >= len(t.stsz) {
// TODO: add warning
break
if sizeNr >= int(stszEntry.count) {
stszIndex++
if stszIndex >= len(t.stsz) {
// TODO: outside sample-size table, add warning
return
}
stszEntry = t.stsz[stszIndex]
sizeNr = 0
}

sampleSize := t.stsz[sampleNr]
decodeSampleRange(d, t, trackSdDataFormat, "sample", int64(sampleOffset)*8, int64(sampleSize)*8, t.formatInArg)

// log.Printf("%s %d/%d %d/%d sample=%d/%d chunk=%d size=%d %d-%d\n",
// trackSdDataFormat, stscIndex, len(t.stsc),
// i, stscEntry.samplesPerChunk,
// sampleNr, len(t.stsz),
// chunkNr,
// sampleSize,
// sampleOffset,
// sampleOffset+uint64(sampleSize))

sampleOffset += uint64(sampleSize)
sampleNr++

// log.Printf("%s stsc[%d/%d]=%#v stco[%d/%d]=%d stsz[%d/%d]=%#v i=%d\n",
// trackSdDataFormat,
// stscIndex, len(t.stsc), stscEntry,
// stcoIndex, len(t.stco), sampleOffset,
// stszIndex, len(t.stsz), stszEntry,
// i,
// )

decodeSampleRange(d, t, trackSdDataFormat, "sample", int64(sampleOffset)*8, int64(stszEntry.size)*8, t.formatInArg)
sampleOffset += uint64(stszEntry.size)
sizeNr++
}

chunkNr++
if stscIndex < len(t.stsc)-1 && chunkNr >= t.stsc[stscIndex+1].firstChunk-1 {
stcoIndex++
if stscIndex < len(t.stsc)-1 && stcoIndex >= int(t.stsc[stscIndex+1].firstChunk-1) {
stscIndex++
}
}
Expand Down

0 comments on commit 4a1e859

Please sign in to comment.