From 4bb422dba7b3ca68bd93ec6ddcf288ee26122673 Mon Sep 17 00:00:00 2001
From: Afonso <afonso.antunes@tecnico.ulisboa.pt>
Date: Wed, 2 Apr 2025 00:12:03 +0100
Subject: [PATCH] BUG: Preserve extension dtypes in MultiIndex during concat
 (#58421)

---
 doc/source/whatsnew/v3.0.0.rst                |  1 +
 pandas/core/reshape/concat.py                 | 43 ++++++++++++++--
 .../frame/methods/test_concat_arrow_index.py  | 51 +++++++++++++++++++
 3 files changed, 90 insertions(+), 5 deletions(-)
 create mode 100644 pandas/tests/frame/methods/test_concat_arrow_index.py

diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst
index 4a6cf117fd196..3a53de1704870 100644
--- a/doc/source/whatsnew/v3.0.0.rst
+++ b/doc/source/whatsnew/v3.0.0.rst
@@ -712,6 +712,7 @@ MultiIndex
 - :func:`MultiIndex.get_level_values` accessing a :class:`DatetimeIndex` does not carry the frequency attribute along (:issue:`58327`, :issue:`57949`)
 - Bug in :class:`DataFrame` arithmetic operations in case of unaligned MultiIndex columns (:issue:`60498`)
 - Bug in :class:`DataFrame` arithmetic operations with :class:`Series` in case of unaligned MultiIndex (:issue:`61009`)
+- Fixed a bug where extension dtypes like ``timestamp[pyarrow]`` were not preserved when building ``MultiIndex`` levels during ``pd.concat`` operations. (:issue:`58421`)
 -
 
 I/O
diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py
index e7cb7069bbc26..8ee149f886a17 100644
--- a/pandas/core/reshape/concat.py
+++ b/pandas/core/reshape/concat.py
@@ -22,6 +22,7 @@
 
 from pandas.core.dtypes.common import (
     is_bool,
+    is_extension_array_dtype,
     is_scalar,
 )
 from pandas.core.dtypes.concat import concat_compat
@@ -36,6 +37,7 @@
     factorize_from_iterables,
 )
 import pandas.core.common as com
+from pandas.core.construction import array
 from pandas.core.indexes.api import (
     Index,
     MultiIndex,
@@ -819,7 +821,20 @@ def _get_sample_object(
 
 
 def _concat_indexes(indexes) -> Index:
-    return indexes[0].append(indexes[1:])
+    # try to preserve extension types such as timestamp[pyarrow]
+    values = []
+    for idx in indexes:
+        values.extend(idx._values if hasattr(idx, "_values") else idx)
+
+    # use the first index as a sample to infer the desired dtype
+    sample = indexes[0]
+    try:
+        # this helps preserve extension types like timestamp[pyarrow]
+        arr = array(values, dtype=sample.dtype)
+    except Exception:
+        arr = array(values)  # fallback
+
+    return Index(arr)
 
 
 def validate_unique_levels(levels: list[Index]) -> None:
@@ -876,14 +891,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
 
         concat_index = _concat_indexes(indexes)
 
-        # these go at the end
         if isinstance(concat_index, MultiIndex):
             levels.extend(concat_index.levels)
             codes_list.extend(concat_index.codes)
         else:
-            codes, categories = factorize_from_iterable(concat_index)
-            levels.append(categories)
-            codes_list.append(codes)
+            # handle the case where the resulting index is a flat Index
+            # but contains tuples (i.e., a collapsed MultiIndex)
+            if isinstance(concat_index[0], tuple):
+                # retrieve the original dtypes
+                original_dtypes = [lvl.dtype for lvl in indexes[0].levels]
+
+                unzipped = list(zip(*concat_index))
+                for i, level_values in enumerate(unzipped):
+                    # reconstruct each level using original dtype
+                    arr = array(level_values, dtype=original_dtypes[i])
+                    level_codes, _ = factorize_from_iterable(arr)
+                    levels.append(ensure_index(arr))
+                    codes_list.append(level_codes)
+            else:
+                # simple indexes factorize directly
+                codes, categories = factorize_from_iterable(concat_index)
+                values = getattr(concat_index, "_values", concat_index)
+                if is_extension_array_dtype(values):
+                    levels.append(values)
+                else:
+                    levels.append(categories)
+                codes_list.append(codes)
 
         if len(names) == len(levels):
             names = list(names)
diff --git a/pandas/tests/frame/methods/test_concat_arrow_index.py b/pandas/tests/frame/methods/test_concat_arrow_index.py
new file mode 100644
index 0000000000000..6fcc5ee5119a6
--- /dev/null
+++ b/pandas/tests/frame/methods/test_concat_arrow_index.py
@@ -0,0 +1,51 @@
+import pytest
+
+import pandas as pd
+
+schema = {
+    "id": "int64[pyarrow]",
+    "time": "timestamp[s][pyarrow]",
+    "value": "float[pyarrow]",
+}
+
+
+@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
+def test_concat_preserves_pyarrow_timestamp(dtype):
+    dfA = (
+        pd.DataFrame(
+            [
+                (0, "2021-01-01 00:00:00", 5.3),
+                (1, "2021-01-01 00:01:00", 5.4),
+                (2, "2021-01-01 00:01:00", 5.4),
+                (3, "2021-01-01 00:02:00", 5.5),
+            ],
+            columns=schema,
+        )
+        .astype(schema)
+        .set_index(["id", "time"])
+    )
+
+    dfB = (
+        pd.DataFrame(
+            [
+                (1, "2022-01-01 08:00:00", 6.3),
+                (2, "2022-01-01 08:01:00", 6.4),
+                (3, "2022-01-01 08:02:00", 6.5),
+            ],
+            columns=schema,
+        )
+        .astype(schema)
+        .set_index(["id", "time"])
+    )
+
+    df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])
+
+    # check whether df.index is multiIndex
+    assert isinstance(df.index, pd.MultiIndex), (
+        f"Expected MultiIndex, but received {type(df.index)}"
+    )
+
+    # Verifying special dtype timestamp[s][pyarrow] stays intact after concat
+    assert df.index.levels[2].dtype == dtype, (
+        f"Expected {dtype}, but received {df.index.levels[2].dtype}"
+    )