In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
from glob import glob
from tqdm import tqdm
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from shutil import copy
from sklearn.model_selection import StratifiedKFold
import torch
from tempfile import TemporaryDirectory

from albk.data.utils import idx_to_locate
use_disjoint_files = False


import torch
import torch.nn as nn

from glob import glob
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
from joblib import Parallel, delayed
from itertools import product

In [3]:
def get_common_label_files(path1, path2):
    files1 = glob(join(path1, "*.nc"))
    files2 = glob(join(path2, "*.nc"))
    
    f1_base_files = [basename(f) for f in files1]
    f2_base_files = [basename(f) for f in files2]
    
    common_files = set(f1_base_files).intersection(f2_base_files)
    common_label_files = []
    for file in common_files:
        ds1 = xr.open_dataset(join(path1, file))
        ds2 = xr.open_dataset(join(path2, file))
        if np.all(ds1.label.values == ds2.label.values):
            common_label_files.append(file)
    
    return list(map(lambda f: join(path1, f), common_label_files))

def get_disjoint_files(path1, path2):
    files1 = glob(join(path1, "*.nc"))
    files2 = glob(join(path2, "*.nc"))
    
    f1_base_files = [basename(f) for f in files1]
    f2_base_files = [basename(f) for f in files2]
    
    disjoint_files = set(f1_base_files).symmetric_difference(f2_base_files)
    
    f1_disjoint = [f for f in disjoint_files if f in f1_base_files]
    f1_disjoint = list(map(lambda f: join(path1, f), f1_disjoint))

    f2_disjoint = [f for f in disjoint_files if f in f2_base_files]
    f2_disjoint = list(map(lambda f: join(path2, f), f2_disjoint))
    
    return f1_disjoint + f2_disjoint

In [4]:
base_path = expanduser("~/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data")
paths = {"rishabh": ("shataxi", "suraj"), "suraj": ("rishabh", "vannsh")}

all_labeled_files = []
for moderator, annotators in paths.items():
    # Get moderator files
    moderator_path = join(base_path, "moderated", moderator)
    moderator_files = glob(join(moderator_path, "*.nc"))
    
    # Get annotator common label files
    annotator1_path = join(base_path, annotators[0])
    annotator2_path = join(base_path, annotators[1])
    
    common_base_files = get_common_label_files(annotator1_path, annotator2_path)
    
    # Get disjoint files
    disjoint_files = get_disjoint_files(annotator1_path, annotator2_path)
    
    all_files = moderator_files + common_base_files
    if use_disjoint_files:
        all_files.extend(disjoint_files)
    assert len(all_files) == len(set(all_files))
    all_labeled_files.extend(all_files)
    
    print("Moderator", moderator)
    print(" "*5, "Moderator files", len(moderator_files))
    print(" "*5, "Common label files", len(common_base_files))
    print(" "*5, "Disjoint files", len(disjoint_files))
    print(" "*5, f"Total files from {moderator} and {annotators}", len(all_files))
    print(" "*5, "Total annotatated files", len(all_labeled_files))
    
print("Total dataset size", len(all_labeled_files) * 25)
def get_bk_stats(path):
    ds = xr.open_dataset(path)
    z = (ds.label.values == "Z").sum()
    f = (ds.label.values == "F").sum()
    o = (ds.label.values == "O").sum()
    return {"Z": z, "F": f, "O": o}

df = pd.DataFrame([get_bk_stats(path) for path in all_labeled_files])

df_sum = df.sum(axis=0)

print("All Brick Kilns", df_sum["Z"] + df_sum["F"])
print("All Non-brick Kilns", df_sum["O"])

Moderator rishabh
      Moderator files 151
      Common label files 98
      Disjoint files 0
      Total files from rishabh and ('shataxi', 'suraj') 249
      Total annotatated files 249
Moderator suraj
      Moderator files 88
      Common label files 64
      Disjoint files 0
      Total files from suraj and ('rishabh', 'vannsh') 152
      Total annotatated files 401
Total dataset size 10025
All Brick Kilns 1042
All Non-brick Kilns 8983


In [5]:
print(all_labeled_files[:5])


['/home/rishabh.mondal/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data/moderated/rishabh/28.90,77.25.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data/moderated/rishabh/28.80,77.45.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data/moderated/rishabh/28.77,77.60.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data/moderated/rishabh/28.86,77.17.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/india_labels/region/delhi/sarath_data/moderated/rishabh/28.84,77.53.nc']


In [6]:
images_path = expanduser("~/bkdb/statewise/sarath_data1/")
# load_path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/temporary"
files = all_labeled_files
# print(files)
print(len(files))

401


In [7]:
def get_index_and_image(file):
    index = []
    images = []
    labels = []
    base_name = basename(file)
    # print(base_name)
    image_path = join(images_path, base_name).replace(".nc", ".zarr")
    # print(image_path)
    label_ds = xr.open_dataset(file)
    # print (label_ds)
    image_ds = xr.open_zarr(image_path, consolidated=False)
    # image = image_ds.data.reshape(-1, 224, 224, 3)
    for lat_lag, lon_lag in product(range(-2, 3), repeat=2):
        index.append(base_name.replace(".nc", "")+f"_{lat_lag}_{lon_lag}")
        images.append(torch.tensor(image_ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)['data'].values)[np.newaxis, ...])
        labels.append(torch.tensor((label_ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)['label'].values != "O").astype(np.uint8)))
        
    return index, images, labels



def get_data():
    out = Parallel(n_jobs=32)(delayed(get_index_and_image)(file) for file in tqdm(files, total=len(files)))
    index = np.concatenate([np.array(idx) for idx, _, _ in out])
    images = torch.concat([torch.einsum("nhwc->nchw", torch.concat(imgs)) for _, imgs, _ in out])
    # scale
    # images = images / 255
    # mean normalize
    # images = (images - images.mean(dim=(0, 2, 3), keepdim=True)) / images.std(dim=(0, 2, 3), keepdim=True)
    
    labels = np.concatenate([np.array(lbl) for _, _, lbl in out])
    labels = torch.tensor(labels, dtype=torch.uint8)
    #check the all dytpes
    print(index.dtype, images.dtype, labels.dtype)
    return index, images, labels

index, images, labels = get_data()
print(images.dtype)
print(index.shape, images.shape, labels.shape)    

100%|██████████| 401/401 [00:07<00:00, 56.30it/s]


<U17 torch.uint8 torch.uint8
torch.uint8
(10025,) torch.Size([10025, 3, 224, 224]) torch.Size([10025])


In [8]:
print(images.dtype)
print(labels.dtype)

torch.uint8
torch.uint8


### test data path# save_path="/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data/test_data.pt"


In [10]:
# save the tensors data 
# print(images.dtype)
# save_path="/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data/test_data.pt"
# torch.save({
#     'index': index,
#     'images': images,
#     'labels': labels
# }, save_path)

torch.uint8
