from mpi4py import MPI
import netCDF4
import numpy as np

def create(path, form, dtype="f8", parallel=False):
    """Create a NetCDF4 file and fill every variable in *form* with random data.

    Each variable gets its own dimensions, named with consecutive integers
    ("0", "1", ...) counted across the whole file.

    Args:
        path: Output file path; opened with mode "w", so an existing file is
            overwritten.
        form: Mapping of variable name -> ``[shape, chunks]``. ``shape`` is a
            sequence of dimension sizes; ``chunks`` is a sequence of chunk
            sizes, or empty/None to let netCDF pick the chunking.
        dtype: NetCDF datatype used for every variable (default "f8").
        parallel: If true, the file is opened with ``parallel=True`` and each
            rank of ``MPI.COMM_WORLD`` writes its own contiguous slab along
            the first dimension. Scalar (0-d) shapes are not supported in
            this mode.
    """
    root = netCDF4.Dataset(path, "w", format="NETCDF4", parallel=parallel)  # type: ignore
    try:
        used = 0
        for variable, element in form.items():
            shape, chunks = element[0], element[1]

            dimensions = []
            for size in shape:
                name = f"{used}"
                root.createDimension(name, size)
                dimensions.append(name)
                used += 1

            if chunks:
                x = root.createVariable(variable, dtype, dimensions, chunksizes=chunks)
            else:
                x = root.createVariable(variable, dtype, dimensions)

            if not parallel:
                x[:] = np.random.random_sample(shape)
                continue

            comm = MPI.COMM_WORLD  # type: ignore
            rank, nranks = comm.rank, comm.size
            total = shape[0]

            # Split the first dimension as evenly as possible. The first
            # `extra` ranks take one extra row, so no rows are left unwritten
            # when `total` is not a multiple of `nranks`.
            base, extra = divmod(total, nranks)
            count = base + (1 if rank < extra else 0)
            rstart = rank * base + min(rank, extra)
            rend = rstart + count

            print(
                f"shape: {shape}, chunks: {chunks}, dimensions: {dimensions}, "
                f"total size: {total}, size on this rank: {count}, rank: {rank}, "
                f"rsize: {nranks}, rstart: {rstart}, rend: {rend}"
            )

            # Ranks that own no rows (more ranks than rows) skip the write.
            if count > 0:
                # Keep the trailing dimensions so multi-dimensional variables
                # receive a correctly shaped slab.
                x[rstart:rend] = np.random.random_sample((count, *shape[1:]))
            comm.Barrier()
            print(f"var: {x}, ncattrs after fill: {x.ncattrs()}, as dict: {x.__dict__}")
    finally:
        # Flush and release the file. In parallel mode this is a collective
        # call that every rank reaches.
        root.close()
            

def main():
    """Write one 1-D variable "X" of 10 * 134217728 values, with no explicit chunking, to test.nc using parallel I/O."""
    n_values = 10 * 134217728
    layout = {"X": [[n_values], []]}  # variable -> [shape, chunks]
    create(path="test.nc", form=layout, parallel=True)


if __name__ == "__main__":
    main()