# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import gzip
import re
from scipy.spatial import cKDTree


class StressCoarseGrainer:
    """
    Goldhirsch-Weinhart stress coarse-graining for granular materials
    Includes kinetic and virial (contact) contributions
    """
    
    def __init__(self, folder_path, pattern='Dump.shear_fixed_Load_1wall.*', 
                 max_frames=np.inf):
        self.folder_path = Path(folder_path)
        self.pattern = pattern
        self.max_frames = max_frames
        
        # Particle type densities (kg/m³)
        self.type_density = {1: 2600.0, 2: 2600.0, 3: 2600.0}
        
        # Coarse-graining parameters
        self.w = None  # Auto-estimated
        self.support_fac = 3.0  # For Gaussian, this is ~3 standard deviations
        self.kernel_type = 'gaussian'  # Default to Gaussian
        
        print(f"Initializing Stress Coarse-Grainer (Goldhirsch-Weinhart)")
        print(f"  Kernel type: Gaussian")
        print(f"  Support factor: {self.support_fac}")
        
    def gaussian_kernel(self, r, w):
        """
        Evaluate the Gaussian coarse-graining kernel at distance(s) ``r``.

        φ(r) = (1/(√(2π)w)³) exp(-r²/(2w²))

        The prefactor is the 3D normalisation. ``r`` may be a scalar or an
        array; the result has the same shape.
        """
        scaled_width = np.sqrt(2 * np.pi) * w
        prefactor = 1.0 / (scaled_width**3)
        exponent = -r**2 / (2 * w**2)
        return prefactor * np.exp(exponent)
    
    def gaussian_kernel_gradient(self, r_vec, w):
        """
        Gradient of Gaussian kernel: ∇φ(r)
        ∇φ = -(r_vec/w²) * φ(r)
        Returns: gradient vector for each point
        """
        r = np.linalg.norm(r_vec, axis=-1, keepdims=True)
        
        # Compute kernel value
        kernel_val = self.gaussian_kernel(r.squeeze(), w)
        
        # ∇φ = -(r_vec/w²) * φ(r)
        grad = -(r_vec / (w**2)) * kernel_val[..., None]
        
        return grad
    
    def find_and_sort_files(self):
        """Find and sort LAMMPS dump files"""
        print(f"Looking for files in: {self.folder_path}")
        
        if not self.folder_path.exists():
            raise FileNotFoundError(f"Folder does not exist: {self.folder_path}")
            
        files = list(self.folder_path.glob(self.pattern))
        if not files:
            raise FileNotFoundError(f"No files found matching pattern: {self.pattern}")
            
        print(f"Found {len(files)} files matching pattern")
        
        # Extract and sort by timestep
        timesteps = []
        valid_files = []
        patterns = [r'\.(\d+)$', r'\.(\d+)\.gz$', r'\.(\d+)\..*$', 
                   r'wall\.(\d+)', r'_(\d+)$', r'(\d+)$']
        
        for file in files:
            timestep = None
            for pattern in patterns:
                match = re.search(pattern, file.name)
                if match:
                    timestep = int(match.group(1))
                    break
            if timestep is not None:
                timesteps.append(timestep)
                valid_files.append(file)
                
        if not valid_files:
            raise ValueError("No files with extractable timesteps found")
            
        sorted_indices = np.argsort(timesteps)
        self.files = [valid_files[i] for i in sorted_indices]
        self.timesteps = [timesteps[i] for i in sorted_indices]
        
        print(f"Processing {len(self.files)} files")
        return self.files[:int(self.max_frames)]
    
    def read_lammps_dump_with_contacts(self, filename):
        """
        Read LAMMPS dump file with contact/pair interaction data
        Expected format includes particle data and pair force data
        """
        filepath = Path(filename)
        opener = gzip.open if filepath.suffix == '.gz' else open
        
        with opener(filepath, 'rt') as f:
            # Read timestep
            line = f.readline()
            if not line.startswith('ITEM: TIMESTEP'):
                raise ValueError(f"Expected TIMESTEP header")
            timestep = int(f.readline().strip())
            
            # Read number of atoms
            line = f.readline()
            if not line.startswith('ITEM: NUMBER OF ATOMS'):
                raise ValueError(f"Expected NUMBER OF ATOMS header")
            N = int(f.readline().strip())
            
            # Read box bounds
            line = f.readline()
            if not line.startswith('ITEM: BOX BOUNDS'):
                raise ValueError(f"Expected BOX BOUNDS header")
            
            boundary_types = line.split()[3:6] if len(line.split()) > 3 else ['pp', 'pp', 'pp']
            
            bounds = []
            for _ in range(3):
                line = f.readline().strip()
                if line.startswith('ITEM:'):
                    break
                vals = list(map(float, line.split()))
                if len(vals) >= 2:
                    bounds.append(vals[:2])
                    
            bounds = np.array(bounds)
            dim = len(bounds)
            
            # Read atoms section
            if not line.startswith('ITEM: ATOMS'):
                line = f.readline()
            
            cols = line.split()[2:]
            
            # Read particle data
            data = []
            for _ in range(N):
                line = f.readline().strip()
                if line:
                    data.append(list(map(float, line.split())))
                    
            data = np.array(data)
        
        # Parse particle data
        frame = {
            'timestep': timestep,
            'N': N,
            'box_bounds': bounds,
            'boundary_types': boundary_types,
            'dim': dim,
            'x': np.zeros((N, max(2, dim))),
            'v': np.zeros((N, max(2, dim))),
            'type': np.ones(N, dtype=int),
            'radius': np.ones(N) * 0.005,
            'mass': None,
            'id': None,
            'contacts': []  # List of contact interactions
        }
        
        # Map columns
        id_col = None
        for i, col in enumerate(cols):
            col = col.lower()
            if col == 'id':
                frame['id'] = data[:, i].astype(int)
                id_col = i
            elif col == 'type':
                frame['type'] = data[:, i].astype(int)
            elif col in ['x', 'xu']:
                frame['x'][:, 0] = data[:, i]
            elif col in ['y', 'yu']:
                frame['x'][:, 1] = data[:, i]
            elif col in ['z', 'zu'] and dim >= 3:
                frame['x'][:, 2] = data[:, i]
            elif col == 'vx':
                frame['v'][:, 0] = data[:, i]
            elif col == 'vy':
                frame['v'][:, 1] = data[:, i]
            elif col == 'vz' and dim >= 3:
                frame['v'][:, 2] = data[:, i]
            elif col in ['radius', 'r']:
                frame['radius'] = data[:, i]
            elif col in ['mass', 'm']:
                frame['mass'] = data[:, i]
        
        # Check if we need to read a separate contact file
        # For now, we'll generate synthetic contacts from neighbor data
        # In practice, you would read from a separate dump.contact file
        
        print(f"    Timestep {timestep}: {N} particles")
        print(f"    Box: [{bounds[0,0]:.3f}, {bounds[0,1]:.3f}] x "
              f"[{bounds[1,0]:.3f}, {bounds[1,1]:.3f}] x [{bounds[2,0]:.3f}, {bounds[2,1]:.3f}]")
        
        return frame
    
    def detect_contacts(self, frame, contact_cutoff=None):
        """
        Detect particle contacts and estimate contact forces
        In practice, these should come from LAMMPS compute pair/local
        
        Parameters:
        -----------
        frame : dict
            Frame data
        contact_cutoff : float
            Distance threshold for contact (default: sum of radii)
        
        Returns:
        --------
        contacts : list of dict
            Each contact: {'i': id_i, 'j': id_j, 'r_ij': vector, 'f_ij': force_vector}
        """
        positions = frame['x'][:, :frame['dim']]
        radii = frame['radius']
        
        if contact_cutoff is None:
            contact_cutoff = 2.0 * np.mean(radii) * 1.01  # Small gap tolerance
        
        tree = cKDTree(positions)
        contacts = []
        
        # Find all pairs within contact distance
        pairs = tree.query_pairs(contact_cutoff, output_type='ndarray')
        
        print(f"    Detecting contacts (cutoff={contact_cutoff:.6f})...")
        print(f"    Found {len(pairs)} potential contacts")
        
        for pair in pairs:
            i, j = pair
            r_ij = positions[j] - positions[i]
            dist = np.linalg.norm(r_ij)
            overlap = (radii[i] + radii[j]) - dist
            
            if overlap > 0:  # Actual contact
                # Estimate normal force (simplified Hertzian contact)
                # In practice, read from LAMMPS output
                k_n = 1e5  # Normal stiffness (adjust based on simulation)
                f_mag = k_n * overlap**1.5
                f_ij = f_mag * r_ij / (dist + 1e-12)  # Normal force
                
                contacts.append({
                    'i': i,
                    'j': j,
                    'r_ij': r_ij,
                    'f_ij': f_ij
                })
        
        print(f"    Contacts with overlap: {len(contacts)}")
        frame['contacts'] = contacts
        return contacts
    
    def read_contact_file(self, contact_filename, frame):
        """
        Read contact data from LAMMPS pair/local output
        Expected format: atom_i atom_j fx fy fz x y z (contact point)
        
        This is a placeholder - adapt to your actual contact file format
        """
        # TODO: Implement based on your specific contact file format
        # For now, use detect_contacts as fallback
        pass
    
    def compute_particle_mass(self, frame):
        """Compute particle masses"""
        if frame['mass'] is not None:
            return frame['mass']
        
        volume = (4.0/3.0) * np.pi * frame['radius']**3
        mass = np.array([volume[i] * self.type_density.get(frame['type'][i], 2600.0) 
                        for i in range(frame['N'])])
        return mass
    
    def estimate_coarse_graining_width(self, frame):
        """Estimate coarse-graining width from particle spacing"""
        positions = frame['x'][:, :frame['dim']]
        n_sample = min(1000, len(positions))
        sample_idx = np.random.choice(len(positions), n_sample, replace=False)
        
        tree = cKDTree(positions)
        distances, _ = tree.query(positions[sample_idx], k=2)
        avg_spacing = np.mean(distances[:, 1])
        
        w = 2.5 * avg_spacing
        
        print(f"Estimated coarse-graining parameters:")
        print(f"  Average spacing: {avg_spacing:.6f}")
        print(f"  Width (w): {w:.6f}")
        print(f"  Support: {self.support_fac * w:.6f}")
        
        return w
    
    def coarse_grain_stress(self, frame, grid_points, w=None):
        """
        Coarse-grain stress tensor using Goldhirsch-Weinhart method with Gaussian kernel
        
        σ(r) = σ^kin(r) + σ^vir(r)
        
        Kinetic part: σ^kin = Σ_i m_i (v_i - v(r)) ⊗ (v_i - v(r)) φ(|r - r_i|)
        Virial part: σ^vir = -1/2 Σ_contacts f_ij ⊗ r_ij ∫_0^1 φ(|r - r_i - s*r_ij|) ds
        
        Parameters:
        -----------
        frame : dict
            Frame data with particles and contacts
        grid_points : ndarray
            Grid points for evaluation
        w : float
            Coarse-graining width
        
        Returns:
        --------
        stress_field : ndarray (N_grid x 3 x 3)
            Stress tensor at each grid point
        stress_kinetic : ndarray (N_grid x 3 x 3)
            Kinetic contribution
        stress_virial : ndarray (N_grid x 3 x 3)
            Virial (contact) contribution
        """
        if w is None:
            w = self.estimate_coarse_graining_width(frame)
        
        support = self.support_fac * w
        
        positions = frame['x'][:, :frame['dim']]
        velocities = frame['v'][:, :frame['dim']]
        mass = self.compute_particle_mass(frame)
        contacts = frame.get('contacts', [])
        
        # First compute velocity field for fluctuation calculation
        print("Computing velocity field for stress calculation...")
        velocity_field = self._compute_velocity_field(frame, grid_points, w)
        
        n_grid = len(grid_points)
        dim = frame['dim']
        
        stress_kinetic = np.zeros((n_grid, 3, 3))
        stress_virial = np.zeros((n_grid, 3, 3))
        
        tree = cKDTree(positions)
        
        print(f"Computing stress at {n_grid} grid points...")
        
        # Kinetic stress contribution
        for i, grid_pt in enumerate(grid_points):
            if i % 1000 == 0:
                print(f"  Progress: {i}/{n_grid} points")
            
            # Find particles within support
            indices = tree.query_ball_point(grid_pt[:dim], support)
            
            if not indices:
                continue
            
            # Compute kinetic stress
            for idx in indices:
                r_vec = positions[idx] - grid_pt[:dim]
                r = np.linalg.norm(r_vec)
                
                if r < support:
                    kernel_val = self.gaussian_kernel(np.array([r]), w)[0]
                    
                    # Velocity fluctuation
                    v_fluct = velocities[idx] - velocity_field[i, :dim]
                    
                    # σ^kin = m * v_fluct ⊗ v_fluct * φ
                    stress_kinetic[i, :dim, :dim] += mass[idx] * np.outer(v_fluct, v_fluct) * kernel_val
        
        # Virial stress contribution from contacts
        print(f"Computing virial stress from {len(contacts)} contacts...")
        
        n_quad = 10  # Quadrature points for line integral
        quad_points = np.linspace(0, 1, n_quad)
        
        for contact in contacts:
            i_p = contact['i']
            j_p = contact['j']
            r_ij = contact['r_ij']
            f_ij = contact['f_ij']
            
            # Line integral along contact
            for s in quad_points:
                r_contact = positions[i_p] + s * r_ij
                
                # Find grid points within support of this contact point
                dists = np.linalg.norm(grid_points[:, :dim] - r_contact, axis=1)
                mask = dists < support
                
                if np.any(mask):
                    kernel_vals = self.gaussian_kernel(dists[mask], w)
                    
                    # σ^vir = -1/2 * f_ij ⊗ r_ij * φ * ds
                    virial_contrib = -0.5 * np.outer(f_ij[:dim], r_ij[:dim]) * (1.0 / n_quad)
                    
                    for grid_idx in np.where(mask)[0]:
                        stress_virial[grid_idx, :dim, :dim] += virial_contrib * kernel_vals[grid_idx - np.where(mask)[0][0]]
        
        # Total stress
        stress_field = stress_kinetic + stress_virial
        
        print("  ✓ Stress computation complete")
        
        return stress_field, stress_kinetic, stress_virial
    
    def _compute_velocity_field(self, frame, grid_points, w):
        """Helper: compute velocity field for kinetic stress"""
        positions = frame['x'][:, :frame['dim']]
        velocities = frame['v'][:, :frame['dim']]
        mass = self.compute_particle_mass(frame)
        support = self.support_fac * w
        
        tree = cKDTree(positions)
        n_grid = len(grid_points)
        dim = frame['dim']
        velocity_field = np.zeros((n_grid, dim))
        
        for i, grid_pt in enumerate(grid_points):
            indices = tree.query_ball_point(grid_pt[:dim], support)
            
            if not indices:
                continue
            
            r_vec = positions[indices] - grid_pt[:dim]
            r = np.linalg.norm(r_vec, axis=1)
            weights = self.gaussian_kernel(r, w) * mass[indices]
            
            total_weight = np.sum(weights)
            if total_weight > 1e-12:
                velocity_field[i] = np.sum(weights[:, None] * velocities[indices], axis=0) / total_weight
        
        return velocity_field
    
    def create_grid(self, frame, n_points=None, dx=None):
        """Create grid for coarse-graining"""
        bounds = frame['box_bounds']
        dim = frame['dim']
        
        if n_points is None and dx is None:
            w = self.estimate_coarse_graining_width(frame)
            dx = w / 2.0
            print(f"Using grid spacing: {dx:.6f}")
        
        if n_points is None:
            if np.isscalar(dx):
                dx = np.array([dx] * dim)
            n_points = tuple(int((bounds[i, 1] - bounds[i, 0]) / dx[i]) for i in range(dim))
        
        if isinstance(n_points, int):
            n_points = tuple([n_points] * dim)
        
        if dim == 2:
            x = np.linspace(bounds[0, 0], bounds[0, 1], n_points[0])
            y = np.linspace(bounds[1, 0], bounds[1, 1], n_points[1])
            xx, yy = np.meshgrid(x, y)
            grid_points = np.column_stack([xx.ravel(), yy.ravel()])
            grid_shape = (n_points[1], n_points[0])
        else:
            x = np.linspace(bounds[0, 0], bounds[0, 1], n_points[0])
            y = np.linspace(bounds[1, 0], bounds[1, 1], n_points[1])
            z = np.linspace(bounds[2, 0], bounds[2, 1], n_points[2])
            xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
            grid_points = np.column_stack([xx.ravel(), yy.ravel(), zz.ravel()])
            grid_shape = n_points
        
        print(f"Created grid: {n_points} points")
        return grid_points, grid_shape
    
    def plot_stress_field(self, frame, stress_field, grid_points, grid_shape,
                         component='xx', slice_dim=None, slice_val=None):
        """
        Plot one component of a stress field.

        In 2D the component is drawn as an image with a colour scale
        symmetric about zero. In 3D the grid points are drawn as a coloured
        scatter: either all points projected onto x-y, or (when both
        ``slice_dim`` and ``slice_val`` are given) only the points within
        one coarse-graining width of the slice plane, drawn in the two
        in-plane axes.

        Parameters:
        -----------
        frame : dict
            Frame data; reads 'dim', 'box_bounds' and 'timestep'.
        stress_field : ndarray (N_grid x 3 x 3)
            Stress tensor at each grid point.
        grid_points : ndarray
            Grid point coordinates.
        grid_shape : tuple
            Shape used to reshape the component into an image (2D only).
        component : str
            One of 'xx', 'yy', 'zz', 'xy', 'xz', 'yz' or a transposed name
            (case-insensitive). Unknown names raise KeyError.
        slice_dim : int
            Axis normal to the slice plane (3D only).
        slice_val : float
            Position of the slice plane along ``slice_dim`` (3D only).

        Returns:
        --------
        fig : matplotlib Figure
        """
        comp_map = {
            'xx': (0, 0), 'yy': (1, 1), 'zz': (2, 2),
            'xy': (0, 1), 'yx': (1, 0),
            'xz': (0, 2), 'zx': (2, 0),
            'yz': (1, 2), 'zy': (2, 1)
        }

        i, j = comp_map[component.lower()]
        stress_comp = stress_field[:, i, j]

        dim = frame['dim']

        fig, ax = plt.subplots(figsize=(10, 8))

        if dim == 2:
            stress_grid = stress_comp.reshape(grid_shape)
            # 95th percentile of the non-zero magnitudes; an all-zero
            # component has none, so fall back to a unit colour range.
            nonzero = np.abs(stress_comp[stress_comp != 0])
            vmax = np.percentile(nonzero, 95) if nonzero.size else 1.0
            im = ax.imshow(stress_grid,
                          extent=[frame['box_bounds'][0, 0], frame['box_bounds'][0, 1],
                                 frame['box_bounds'][1, 0], frame['box_bounds'][1, 1]],
                          origin='lower', cmap='RdBu_r', aspect='auto',
                          vmin=-vmax, vmax=vmax)
            plt.colorbar(im, ax=ax, label=f'σ_{component}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
        elif slice_dim is not None and slice_val is not None:
            # 3D slice: keep points within one width of the plane and draw
            # them in the two axes that lie IN the plane.
            w = self.estimate_coarse_graining_width(frame)
            normal = slice_dim % 3
            horiz, vert = [axis for axis in range(3) if axis != normal]
            mask = np.abs(grid_points[:, normal] - slice_val) < w
            scatter = ax.scatter(grid_points[mask, horiz], grid_points[mask, vert],
                               c=stress_comp[mask], s=30, cmap='RdBu_r')
            plt.colorbar(scatter, ax=ax, label=f'σ_{component}')
            ax.set_xlabel('xyz'[horiz])
            ax.set_ylabel('xyz'[vert])
        else:
            # 3D without a slice: every grid point projected onto x-y
            scatter = ax.scatter(grid_points[:, 0], grid_points[:, 1],
                               c=stress_comp, s=20, cmap='RdBu_r', alpha=0.6)
            plt.colorbar(scatter, ax=ax, label=f'σ_{component}')
            ax.set_xlabel('x')
            ax.set_ylabel('y')

        ax.set_title(f'Stress Field Component σ_{component} (Timestep {frame["timestep"]})')
        plt.tight_layout()

        return fig
    
    def plot_stress_comparison(self, frame, stress_total, stress_kinetic, stress_virial,
                              grid_points, grid_shape, component='xx'):
        """
        Compare kinetic and virial stress contributions side by side.

        Draws three panels (kinetic, virial, total) for one tensor
        component. In 2D each panel is an image with its own colour scale,
        symmetric about zero; otherwise each panel is a scatter of all grid
        points projected onto x-y.

        Parameters:
        -----------
        frame : dict
            Frame data; reads 'dim' and 'box_bounds'.
        stress_total, stress_kinetic, stress_virial : ndarray (N_grid x 3 x 3)
            Stress tensors at each grid point.
        grid_points : ndarray
            Grid point coordinates.
        grid_shape : tuple
            Shape used to reshape a component into an image (2D only).
        component : str
            One of 'xx', 'yy', 'zz', 'xy', 'xz', 'yz' or a transposed name
            (case-insensitive). Unknown names raise KeyError.

        Returns:
        --------
        fig : matplotlib Figure
        """
        # Same component names as accepted by plot_stress_field
        comp_map = {
            'xx': (0, 0), 'yy': (1, 1), 'zz': (2, 2),
            'xy': (0, 1), 'yx': (1, 0),
            'xz': (0, 2), 'zx': (2, 0),
            'yz': (1, 2), 'zy': (2, 1)
        }

        i, j = comp_map[component.lower()]

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        components = [
            (stress_kinetic[:, i, j], 'Kinetic'),
            (stress_virial[:, i, j], 'Virial'),
            (stress_total[:, i, j], 'Total')
        ]

        for ax, (stress_comp, title) in zip(axes, components):
            if frame['dim'] == 2:
                stress_grid = stress_comp.reshape(grid_shape)
                # An all-zero contribution has no non-zero magnitudes to
                # take a percentile of; use a unit colour range instead.
                nonzero = np.abs(stress_comp[stress_comp != 0])
                vmax = np.percentile(nonzero, 95) if nonzero.size else 1.0
                im = ax.imshow(stress_grid,
                             extent=[frame['box_bounds'][0, 0], frame['box_bounds'][0, 1],
                                    frame['box_bounds'][1, 0], frame['box_bounds'][1, 1]],
                             origin='lower', cmap='RdBu_r', aspect='auto',
                             vmin=-vmax, vmax=vmax)
                plt.colorbar(im, ax=ax, label=f'σ_{component}')
            else:
                # 3D: project every grid point onto x-y so the panel is not
                # left empty
                scatter = ax.scatter(grid_points[:, 0], grid_points[:, 1],
                                   c=stress_comp, s=20, cmap='RdBu_r', alpha=0.6)
                plt.colorbar(scatter, ax=ax, label=f'σ_{component}')

            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_title(f'{title} Stress σ_{component}')

        plt.tight_layout()
        return fig
    
    def read_all_frames(self):
        """Read all dump files"""
        files = self.find_and_sort_files()
        frames = []
        
        for i, file in enumerate(files):
            print(f"\nReading file {i+1}/{len(files)}: {file.name}")
            try:
                frame = self.read_lammps_dump_with_contacts(file)
                
                # Detect contacts (or read from separate file)
                self.detect_contacts(frame)
                
                frames.append(frame)
                print(f"  ✓ Success")
            except Exception as e:
                print(f"  ✗ Error: {str(e)}")
                import traceback
                traceback.print_exc()
        
        if not frames:
            raise RuntimeError("No frames successfully read")
        
        print(f"\n✓ Successfully read {len(frames)} frames")
        self.frames = frames
        return frames


# Example usage: read a few frames, coarse-grain the stress of one of them
# and plot the result.
if __name__ == "__main__":
    # Set up the coarse-grainer for one simulation folder
    cg = StressCoarseGrainer(
        folder_path=r"F:\DATA_constant_volume_DEM\DATA_constant_volume\vf578_vt0.23",
        pattern="Dump.shear_fixed_Load_1wall.*",
        max_frames=3  # Start with a few frames for testing
    )

    # Load the frames
    frames = cg.read_all_frames()

    # Pick the first frame whose largest velocity component exceeds 0.01;
    # fall back to the last frame when none does
    frame_idx = next(
        (k for k, candidate in enumerate(frames)
         if np.max(np.abs(candidate['v'])) > 0.01),
        len(frames) - 1,
    )

    frame = frames[frame_idx]
    print(f"\n=== Analyzing frame {frame_idx} (timestep {frame['timestep']}) ===")

    # Evaluation grid (coarse, for speed)
    print("\n=== Creating grid ===")
    grid_points, grid_shape = cg.create_grid(frame, n_points=30)

    # Stress field and its two contributions
    print("\n=== Computing stress field ===")
    stress_total, stress_kin, stress_vir = cg.coarse_grain_stress(
        frame, grid_points
    )

    print(f"\nStress statistics:")
    for label, field in (("Kinetic", stress_kin),
                         ("Virial", stress_vir),
                         ("Total", stress_total)):
        sigma_xx = field[:, 0, 0]
        print(f"  {label} σ_xx: [{np.min(sigma_xx):.3e}, {np.max(sigma_xx):.3e}]")

    print("\n=== Plotting stress fields ===")

    # Individual components of the total stress
    fig1 = cg.plot_stress_field(frame, stress_total, grid_points, grid_shape,
                                component='xx')
    fig2 = cg.plot_stress_field(frame, stress_total, grid_points, grid_shape,
                                component='xy')

    # Kinetic vs. virial vs. total
    fig3 = cg.plot_stress_comparison(frame, stress_total, stress_kin, stress_vir,
                                     grid_points, grid_shape, component='xx')

    plt.show()