@@ -46,30 +46,51 @@ def test_task_id(self, task_id, expected):
4646
4747 self .assertEqual (i , expected )
4848
49- def test_tpu_env_from_gce_metadata (self ):
50- tpu_env_yaml = textwrap .dedent ("""
51- ACCELERATOR_TYPE: 'v4-16'
52- CHIPS_PER_HOST_BOUNDS: '2,2,1'
53- HOST_BOUNDS: '1,1,2'
54- TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
55- TPU_PROCESS_BOUNDS: '1,1,2'
56- ZONE: 'us-central2-b'
57- WORKER_ID: '0'
58- """ )
59-
49+ @parameterized .named_parameters (
50+ ('v4' ,
51+ textwrap .dedent ("""
52+ ACCELERATOR_TYPE: 'v4-16'
53+ CHIPS_PER_HOST_BOUNDS: '2,2,1'
54+ HOST_BOUNDS: '1,1,2'
55+ TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
56+ TPU_PROCESS_BOUNDS: '1,1,2'
57+ ZONE: 'us-central2-b'
58+ WORKER_ID: '0'
59+ """ ), {
60+ 'ACCELERATOR_TYPE' : 'v4-16' ,
61+ 'CHIPS_PER_HOST_BOUNDS' : '2,2,1' ,
62+ 'HOST_BOUNDS' : '1,1,2' ,
63+ 'TPU_CHIPS_PER_PROCESS_BOUNDS' : '2,2,1' ,
64+ 'TPU_PROCESS_BOUNDS' : '1,1,2' ,
65+ 'ZONE' : 'us-central2-b' ,
66+ 'WORKER_ID' : '0'
67+ }, 4 ),
68+ ('v5' ,
69+ textwrap .dedent ("""
70+ ACCELERATOR_TYPE: 'v5abcdefg-16'
71+ CHIPS_PER_HOST_BOUNDS: '2,2,1'
72+ HOST_BOUNDS: '1,1,2'
73+ TPU_CHIPS_PER_PROCESS_BOUNDS: '2,2,1'
74+ TPU_PROCESS_BOUNDS: '1,1,2'
75+ ZONE: 'us-central2-b'
76+ WORKER_ID: '0'
77+ """ ), {
78+ 'ACCELERATOR_TYPE' : 'v5abcdefg-16' ,
79+ 'CHIPS_PER_HOST_BOUNDS' : '2,2,1' ,
80+ 'HOST_BOUNDS' : '1,1,2' ,
81+ 'TPU_CHIPS_PER_PROCESS_BOUNDS' : '2,2,1' ,
82+ 'TPU_PROCESS_BOUNDS' : '1,1,2' ,
83+ 'ZONE' : 'us-central2-b' ,
84+ 'WORKER_ID' : '0'
85+ }, 5 ),
86+ )
87+ def test_tpu_env_from_gce_metadata (self , tpu_env_yaml , expected_env ,
88+ expected_version ):
6089 with mock .patch .object (tpu , '_get_metadata' , return_value = tpu_env_yaml ):
6190 tpu_env = tpu .get_tpu_env ()
62-
63- self .assertDictEqual (
64- tpu_env , {
65- 'ACCELERATOR_TYPE' : 'v4-16' ,
66- 'CHIPS_PER_HOST_BOUNDS' : '2,2,1' ,
67- 'HOST_BOUNDS' : '1,1,2' ,
68- 'TPU_CHIPS_PER_PROCESS_BOUNDS' : '2,2,1' ,
69- 'TPU_PROCESS_BOUNDS' : '1,1,2' ,
70- 'ZONE' : 'us-central2-b' ,
71- 'WORKER_ID' : '0'
72- })
91+ version = tpu .version ()
92+ self .assertDictEqual (tpu_env , expected_env )
93+ self .assertEqual (version , expected_version )
7394
7495 @parameterized .named_parameters (
7596 ('all-vars-set' , {
0 commit comments