@@ -46,7 +46,7 @@ def process(record):
4646
4747        # Convert 
4848        images  =  tf .map_fn (tf .image .decode_jpeg , tf .reshape (images , [- 1 ]), ** kwargs )
49-         images  =  tf .reshape (images , (- 1 , SEQ_DIM , 3 , IMG_DIM , IMG_DIM ))
49+         images  =  tf .reshape (images , (- 1 , SEQ_DIM , IMG_DIM , IMG_DIM , 3 ))
5050        poses   =  tf .reshape (poses ,  (- 1 , SEQ_DIM , POSE_DIM ))
5151
5252        # Numpy conversion 
@@ -64,8 +64,8 @@ def convert(record, batch_size):
6464    batch_process  =  lambda  r : chunk (process (r ), batch_size )
6565
6666    for  i , batch  in  enumerate (batch_process (record )):
67-         path  =  os .path .join (path , "{0:}-{1:02}.pt.gz" .format (basename , i ))
68-         with  gzip .open (path , 'wb' ) as  f :
67+         p  =  os .path .join (path , "{0:}-{1:02}.pt.gz" .format (basename , i ))
68+         with  gzip .open (p , 'wb' ) as  f :
6969            torch .save (list (batch ), f )
7070
7171if  __name__  ==  '__main__' :
@@ -91,4 +91,4 @@ def convert(record, batch_size):
9191
9292    with  mp .Pool (processes = mp .cpu_count ()) as  pool :
9393        f  =  partial (convert , batch_size = args .batch_size )
94-         pool .map (f , records )
94+         pool .map (f , records )
0 commit comments