from fractions import Fraction
from itertools import filterfalse, zip_longest
from types import GeneratorType
-from typing import Any, cast
+from typing import Any, Generator, cast

import orjson
import sentry_sdk
@@ -599,7 +599,6 @@ def parse_chunks(chunks: str) -> tuple[list[str], ReportHeader]:


class Report(object):
-    file_class = ReportFile
    _files: dict[str, ReportFileSummary]
    _header: ReportHeader
@@ -618,22 +617,57 @@ def __init__(
        self.sessions = get_sessions(sessions) if sessions else {}

        # ["<json>", ...]
+        self._chunks: list[str | ReportFile]
        self._chunks, self._header = (
            parse_chunks(chunks)
            if chunks and isinstance(chunks, str)
            else (chunks or [], ReportHeader())
        )

        # <ReportTotals>
+        self._totals: ReportTotals | None = None
        if isinstance(totals, ReportTotals):
            self._totals = totals
        elif totals:
            self._totals = ReportTotals(*migrate_totals(totals))
-        else:
-            self._totals = None

        self.diff_totals = diff_totals

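+    # Drop cached aggregate state (the report-wide totals) so it is recomputed on next access.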
+    def _invalidate_caches(self):
+        self._totals = None
+
+    @property
+    def totals(self):
+        if not self._totals:
+            self._totals = self._process_totals()
+        return self._totals
+
+    def _process_totals(self):
+        """Runs through the file network to aggregate totals
+        returns <ReportTotals>
+        """
+
+        def _iter_totals():
+            for filename, data in self._files.items():
+                if data.file_totals is None:
+                    yield self.get(filename).totals
+                else:
+                    yield data.file_totals
+
+        totals = agg_totals(_iter_totals())
+        totals.sessions = len(self.sessions)
+        return totals
+
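+    # Lazily parse raw chunk strings into `ReportFile` objects, caching the parsed
+    # files back into `self._chunks` so later iterations reuse them.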
+    def _iter_parsed_files(self) -> Generator[ReportFile, None, None]:
+        for name, summary in self._files.items():
+            idx = summary.file_index
+            file = self._chunks[idx]
+            if not isinstance(file, ReportFile):
+                file = self._chunks[idx] = ReportFile(
+                    name=name, totals=summary.file_totals, lines=file
+                )
+            yield file
+
    @property
    def header(self) -> ReportHeader:
        return self._header
@@ -796,7 +830,7 @@ def get(self, filename, _else=None, bind=False):
                lines = None
            if isinstance(lines, ReportFile):
                return lines
-            report_file = self.file_class(
+            report_file = ReportFile(
                name=filename,
                totals=_file.file_totals,
                lines=lines,
@@ -866,29 +900,6 @@ def get_file_totals(self, path: str) -> ReportTotals | None:
        else:
            return ReportTotals(*totals)

-    @property
-    def totals(self):
-        if not self._totals:
-            # reprocess totals
-            self._totals = self._process_totals()
-        return self._totals
-
-    def _process_totals(self):
-        """Runs through the file network to aggregate totals
-        returns <ReportTotals>
-        """
-
-        def _iter_totals():
-            for filename, data in self._files.items():
-                if data.file_totals is None:
-                    yield self.get(filename).totals
-                else:
-                    yield data.file_totals
-
-        totals = agg_totals(_iter_totals())
-        totals.sessions = len(self.sessions)
-        return totals
-
    def next_session_number(self):
        start_number = len(self.sessions)
        while start_number in self.sessions or str(start_number) in self.sessions:
@@ -921,7 +932,7 @@ def __iter__(self):
            if isinstance(report, ReportFile):
                yield report
            else:
-                yield self.file_class(
+                yield ReportFile(
                    name=filename,
                    totals=_file.file_totals,
                    lines=report,
@@ -1239,6 +1250,76 @@ def _passes_integrity_analysis(self):
            return False
        return True

+    def delete_labels(
+        self, sessionids: list[int] | set[int], labels_to_delete: list[int] | set[int]
+    ):
+        for file in self._iter_parsed_files():
+            file.delete_labels(sessionids, labels_to_delete)
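+            # Refresh the stored per-file totals if the file still has data; otherwise drop it.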
+            if file:
+                self._files[file.name] = dataclasses.replace(
+                    self._files[file.name],
+                    file_totals=file.totals,
+                )
+            else:
+                del self[file.name]
+
+        self._invalidate_caches()
+        return sessionids
+
+    def delete_multiple_sessions(self, session_ids_to_delete: list[int] | set[int]):
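+        # Normalize to a set for cheap membership checks during per-file deletion.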
+        session_ids_to_delete = set(session_ids_to_delete)
+        for sessionid in session_ids_to_delete:
+            self.sessions.pop(sessionid)
+
+        for file in self._iter_parsed_files():
+            file.delete_multiple_sessions(session_ids_to_delete)
+            if file:
+                self._files[file.name] = dataclasses.replace(
+                    self._files[file.name],
+                    file_totals=file.totals,
+                )
+            else:
+                del self[file.name]
+
+        self._invalidate_caches()
+
+    @sentry_sdk.trace
+    def change_sessionid(self, old_id: int, new_id: int):
+        """
+        This changes the session with `old_id` to have `new_id` instead.
+        It patches up all the references to that session across all files and line records.
+
+        In particular, it changes the id in all the `LineSession`s and `CoverageDatapoint`s,
+        and does the equivalent of `calculate_present_sessions`.
+        """
+        session = self.sessions[new_id] = self.sessions.pop(old_id)
+        session.id = new_id
+
+        for file in self._iter_parsed_files():
+            all_sessions = set()
+
+            for idx, _line in enumerate(file._lines):
+                if not _line:
+                    continue
+
+                # this turns the line into an actual `ReportLine`
+                line = file._lines[idx] = file._line(_line)
+
+                for session in line.sessions:
+                    if session.id == old_id:
+                        session.id = new_id
+                    all_sessions.add(session.id)
+
+                if line.datapoints:
+                    for point in line.datapoints:
+                        if point.sessionid == old_id:
+                            point.sessionid = new_id
+
+            file._invalidate_caches()
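+            # Replace the file's cached set of present sessions (the equivalent of
+            # `calculate_present_sessions` mentioned in the docstring above).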
+            file.__present_sessions = all_sessions
+
+        self._invalidate_caches()
+

def _ignore_to_func(ignore):
    """Returns a function to determine whether a line should be saved to the ReportFile