## Dataset test

In [30]:
import json
import os
from torch.utils.data import Dataset
from utils import load_wb_path
from dataclasses import dataclass
import torchvision
class LRS2_dataset(Dataset):
    def __init__(self,dataset_path,wb_path):
        self.dataset_path = dataset_path
        self.json_path,self.dict_path = load_wb_path(wb_path)
        self.wb_dict = json.load(open(self.dict_path))
    
    def __len__(self):
        return len(self.json_path)
    
    def __getitem__(self, idx):
        path = self.json_path[idx]
        with open(path,"r") as f:
            wb_data = json.load(f)
        vid_path = os.path.join(self.dataset_path,wb_data["id"])
        vid = torchvision.io.read_video(vid_path,output_format="THWC",pts_unit="sec")
        return vid,wb_data

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float

    def to_list(self):
        return [self.token_index, self.time_index, self.score]

# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path,transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

# Merge words
def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words

In [31]:
lrs2_dataset = LRS2_dataset('/home/zzzzz/poison/data/source_data/mvlrs_v1','/home/zzzzz/poison/data/wb')

In [32]:
test_json_path = '/home/zzzzz/auto_avsr/word_boundary/wb/main/5535415699068794046/00006.json'
import json
data = json.load(open(test_json_path))
type(data["path"])
print(data["path"])
transcript = data["transcript"]
print(transcript)

[[0, 0, 0.0009422105504199862], [1, 1, 0.0013278962578624487], [2, 2, 0.0007371663232333958], [3, 3, 0.9897135496139526], [3, 4, 0.0008294470026157796], [4, 5, 0.001762603409588337], [5, 6, 0.0041459696367383], [6, 7, 0.9922027587890625], [6, 8, 0.9916253685951233], [6, 9, 0.990632176399231], [6, 10, 0.9913316965103149], [6, 11, 0.9886521697044373], [6, 12, 0.9888148903846741], [6, 13, 0.9909303784370422], [6, 14, 0.9916812777519226], [6, 15, 0.991908848285675], [6, 16, 0.991607666015625], [6, 17, 0.9913801550865173], [6, 18, 0.9908061623573303], [6, 19, 0.9883666634559631], [6, 20, 0.9865463972091675], [6, 21, 0.9889179468154907], [6, 22, 0.9893532991409302], [6, 23, 0.9867656230926514], [6, 24, 0.000673270202241838], [7, 25, 0.0007840362377464771], [8, 26, 0.00031947853858582675], [9, 27, 0.9885320663452148], [9, 28, 0.9889367818832397], [9, 29, 0.9834524989128113], [9, 30, 0.0007027140236459672], [10, 31, 0.003918760921806097], [11, 32, 0.0014136952813714743], [12, 33, 0.00069528399

In [33]:
for c in map(lambda x: Point(x[0], x[1], x[2]), data["path"]):
    print(c)
data = map(lambda x: Point(x[0], x[1], x[2]), data["path"])
data = list(data)
print(data)

Point(token_index=0, time_index=0, score=0.0009422105504199862)
Point(token_index=1, time_index=1, score=0.0013278962578624487)
Point(token_index=2, time_index=2, score=0.0007371663232333958)
Point(token_index=3, time_index=3, score=0.9897135496139526)
Point(token_index=3, time_index=4, score=0.0008294470026157796)
Point(token_index=4, time_index=5, score=0.001762603409588337)
Point(token_index=5, time_index=6, score=0.0041459696367383)
Point(token_index=6, time_index=7, score=0.9922027587890625)
Point(token_index=6, time_index=8, score=0.9916253685951233)
Point(token_index=6, time_index=9, score=0.990632176399231)
Point(token_index=6, time_index=10, score=0.9913316965103149)
Point(token_index=6, time_index=11, score=0.9886521697044373)
Point(token_index=6, time_index=12, score=0.9888148903846741)
Point(token_index=6, time_index=13, score=0.9909303784370422)
Point(token_index=6, time_index=14, score=0.9916812777519226)
Point(token_index=6, time_index=15, score=0.991908848285675)
Point(

### Test segments

In [34]:
segments = merge_repeats(data,transcript)
for seg in segments:
    print(seg)

|	(0.00): [    0,     1)
A	(0.00): [    1,     2)
P	(0.00): [    2,     3)
A	(0.50): [    3,     5)
R	(0.00): [    5,     6)
T	(0.00): [    6,     7)
|	(0.94): [    7,    25)
F	(0.00): [   25,    26)
R	(0.00): [   26,    27)
O	(0.74): [   27,    31)
M	(0.00): [   31,    32)
|	(0.00): [   32,    33)
T	(0.00): [   33,    34)
H	(0.74): [   34,    38)
E	(0.00): [   38,    39)
|	(0.00): [   39,    40)
G	(0.00): [   40,    41)
O	(0.00): [   41,    42)
L	(0.00): [   42,    43)
D	(0.00): [   43,    44)
E	(0.00): [   44,    45)
N	(0.00): [   45,    46)
|	(0.66): [   46,    49)
C	(0.00): [   49,    50)
O	(0.79): [   50,    55)
L	(0.00): [   55,    56)
O	(0.00): [   56,    57)
U	(0.74): [   57,    61)
R	(0.00): [   61,    62)
|	(0.94): [   62,    81)
A	(0.00): [   81,    82)
N	(0.82): [   82,    88)
D	(0.00): [   88,    89)
|	(0.95): [   89,   113)
T	(0.49): [  113,   115)
H	(0.82): [  115,   121)
E	(0.00): [  121,   122)
|	(0.00): [  122,   123)
D	(0.00): [  123,   124)
E	(0.00): [  124,   125)


### Test merge_words

In [35]:
word_segments = merge_words(segments)
for word in word_segments:
    print(word)
    print(type(word))

APART	(0.17): [    1,     7)
<class '__main__.Segment'>
FROM	(0.42): [   25,    32)
<class '__main__.Segment'>
THE	(0.50): [   33,    39)
<class '__main__.Segment'>
GOLDEN	(0.00): [   40,    46)
<class '__main__.Segment'>
COLOUR	(0.53): [   49,    62)
<class '__main__.Segment'>
AND	(0.62): [   81,    89)
<class '__main__.Segment'>
THE	(0.66): [  113,   122)
<class '__main__.Segment'>
DELICIOUS	(0.25): [  123,   135)
<class '__main__.Segment'>
FLAVOUR	(0.46): [  189,   202)
<class '__main__.Segment'>


### Test collate_fn

In [36]:
from torch.utils.data import DataLoader
def collate_fn(samples):
    """
    Collate function for the dataloader
    """
    video_data = [sample[0] for sample in samples]
    wb_data = [sample[1] for sample in samples]
    return video_data, wb_data

dataloader = DataLoader(lrs2_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

for batch in dataloader:
    video_data, wb_data = batch
    # print(video_data)
    print(len(video_data))
    print(wb_data)
    break

1
[{'path': [[0, 0, 0.9999905824661255], [0, 1, 0.9999599456787109], [0, 2, 0.9998718500137329], [0, 3, 0.9995676875114441], [0, 4, 0.9994589686393738], [0, 5, 0.9890396595001221], [1, 6, 0.9965063333511353], [1, 7, 0.9925447702407837], [2, 8, 0.6816789507865906], [2, 9, 0.9999152421951294], [3, 10, 0.5931492447853088], [3, 11, 0.9702309370040894], [3, 12, 0.9999557733535767], [4, 13, 0.9994369149208069], [5, 14, 0.9999990463256836], [5, 15, 0.9998999834060669], [6, 16, 0.9621546864509583], [6, 17, 0.9997120499610901], [7, 18, 0.016336830332875252], [7, 19, 0.997967541217804], [8, 20, 0.9955659508705139], [9, 21, 0.008307437412440777], [9, 22, 0.569094717502594], [10, 23, 0.29925480484962463], [11, 24, 0.9936229586601257], [12, 25, 0.9993336796760559], [13, 26, 0.5080338716506958], [13, 27, 0.9953086972236633], [14, 28, 0.9871134161949158], [14, 29, 0.9988658428192139], [15, 30, 0.9996452331542969], [15, 31, 0.9980605244636536], [16, 32, 0.8881946206092834], [16, 33, 0.9972147345542908