@@ -459,6 +459,26 @@ def test_corrwith_mixed_dtypes(self):
459459 expected = pd .Series (data = corrs , index = ['a' , 'b' ])
460460 tm .assert_series_equal (result , expected )
461461
462+ def test_corrwith_index_intersection (self ):
463+ df1 = pd .DataFrame (np .random .random (size = (10 , 2 )),
464+ columns = ["a" , "b" ])
465+ df2 = pd .DataFrame (np .random .random (size = (10 , 3 )),
466+ columns = ["a" , "b" , "c" ])
467+
468+ result = df1 .corrwith (df2 , drop = True ).index .sort_values ()
469+ expected = df1 .columns .intersection (df2 .columns ).sort_values ()
470+ tm .assert_index_equal (result , expected )
471+
472+ def test_corrwith_index_union (self ):
473+ df1 = pd .DataFrame (np .random .random (size = (10 , 2 )),
474+ columns = ["a" , "b" ])
475+ df2 = pd .DataFrame (np .random .random (size = (10 , 3 )),
476+ columns = ["a" , "b" , "c" ])
477+
478+ result = df1 .corrwith (df2 , drop = False ).index .sort_values ()
479+ expected = df1 .columns .union (df2 .columns ).sort_values ()
480+ tm .assert_index_equal (result , expected )
481+
462482 def test_corrwith_dup_cols (self ):
463483 # GH 21925
464484 df1 = pd .DataFrame (np .vstack ([np .arange (10 )] * 3 ).T )
0 commit comments