In [8]:
import os
# Restrict CUDA to the GPU with index 3. This is set here, before `torch` is
# imported in the following cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [9]:
from glob import glob
from tqdm import tqdm
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from shutil import copy
from sklearn.model_selection import StratifiedKFold
import torch
from tempfile import TemporaryDirectory

from albk.data.utils import idx_to_locate
# When True, label files that exist in only one of the two annotator
# directories of a pair are added to the dataset as well (see the assembly
# loop that calls get_disjoint_files).
use_disjoint_files = False


import torch
import torch.nn as nn

from glob import glob
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
from joblib import Parallel, delayed
from itertools import product

In [10]:
def get_common_label_files(path1, path2):
    """Return the label files of `path1` whose labels match the same file in `path2`.

    Both directories are scanned for ``*.nc`` files. For every base name that
    exists in both, the ``label`` variable of the two datasets is compared;
    the file is kept only when the two label arrays have the same shape and
    the same values.

    Parameters
    ----------
    path1, path2 : str
        Directories containing ``*.nc`` label files.

    Returns
    -------
    list of str
        Paths (inside `path1`) of the agreeing files, sorted by base name so
        the result is reproducible across runs.
    """
    names1 = {basename(f) for f in glob(join(path1, "*.nc"))}
    names2 = {basename(f) for f in glob(join(path2, "*.nc"))}

    common_label_files = []
    # Sorted so the output does not depend on (hash-randomised) set order.
    for name in sorted(names1 & names2):
        # Context managers close both files; otherwise one handle per file
        # stays open for the lifetime of the kernel.
        with xr.open_dataset(join(path1, name)) as ds1, \
                xr.open_dataset(join(path2, name)) as ds2:
            # array_equal returns False on a shape mismatch instead of
            # raising or silently broadcasting like `np.all(a == b)`.
            same_labels = np.array_equal(ds1.label.values, ds2.label.values)
        if same_labels:
            common_label_files.append(join(path1, name))

    return common_label_files

def get_disjoint_files(path1, path2):
    """Return the ``*.nc`` files that exist in exactly one of the two directories.

    Parameters
    ----------
    path1, path2 : str
        Directories containing ``*.nc`` label files.

    Returns
    -------
    list of str
        Full paths of the files found only in `path1`, followed by those found
        only in `path2`. Each group is sorted by base name so the result is
        reproducible across runs.
    """
    names1 = {basename(f) for f in glob(join(path1, "*.nc"))}
    names2 = {basename(f) for f in glob(join(path2, "*.nc"))}

    # Set differences give the two halves of the symmetric difference directly,
    # avoiding repeated linear membership scans over lists of base names.
    only_in_1 = [join(path1, name) for name in sorted(names1 - names2)]
    only_in_2 = [join(path2, name) for name in sorted(names2 - names1)]

    return only_in_1 + only_in_2

In [11]:
# Root of the label tree: one sub-directory per annotator, plus
# "moderated/<moderator>" sub-directories.
base_path = expanduser("~/bangladesh_labels/bkdb/bangladesh_labels")
# Maps each moderator to the pair of annotators whose directories are compared below.
paths = {"zeel": ("vannsh", "rishabh"), "rishabh": ("suraj", "dhruv"), "suraj": ("aditi", "madhav")}

# Accumulates the selected label file paths across all moderator groups.
all_labeled_files = []
for moderator, annotators in paths.items():
    # Get moderator files
    moderator_path = join(base_path, "moderated", moderator)
    moderator_files = glob(join(moderator_path, "*.nc"))
    
    # Get annotator common label files
    annotator1_path = join(base_path, annotators[0])
    annotator2_path = join(base_path, annotators[1])
    
    # Files present for both annotators with identical labels (paths point into annotator 1's directory).
    common_base_files = get_common_label_files(annotator1_path, annotator2_path)
    
    # Get disjoint files
    disjoint_files = get_disjoint_files(annotator1_path, annotator2_path)
    
    all_files = moderator_files + common_base_files
    # Files that only one annotator has are included only when the flag is enabled.
    if use_disjoint_files:
        all_files.extend(disjoint_files)
    # Checks for repeated full paths within this group only.
    # NOTE(review): the same base name appearing in different directories or
    # groups is not detected here — confirm this cannot happen, since later
    # cells key on the base name.
    assert len(all_files) == len(set(all_files))
    all_labeled_files.extend(all_files)
    
    # Per-group summary; the last line is the running total across groups.
    print("Moderator", moderator)
    print(" "*5, "Moderator files", len(moderator_files))
    print(" "*5, "Common label files", len(common_base_files))
    print(" "*5, "Disjoint files", len(disjoint_files))
    print(" "*5, f"Total files from {moderator} and {annotators}", len(all_files))
    print(" "*5, "Total annotatated files", len(all_labeled_files))
    
# 25 samples per file — presumably the 5x5 (lat_lag, lon_lag) grid iterated in
# get_index_and_image; TODO confirm and consider naming this constant.
print("Total dataset size", len(all_labeled_files) * 25)

Moderator zeel
      Moderator files 88
      Common label files 359
      Disjoint files 662
      Total files from zeel and ('vannsh', 'rishabh') 447
      Total annotatated files 447
Moderator rishabh
      Moderator files 98
      Common label files 115
      Disjoint files 736
      Total files from rishabh and ('suraj', 'dhruv') 213
      Total annotatated files 660
Moderator suraj
      Moderator files 195
      Common label files 165
      Disjoint files 746
      Total files from suraj and ('aditi', 'madhav') 360
      Total annotatated files 1020
Total dataset size 25500


In [21]:
# Peek at the first five selected label paths as a quick sanity check.
print(all_labeled_files[:5])

['/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.90,90.77.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.58,91.69.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.44,90.83.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.63,88.25.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.99,89.79.nc']


In [12]:
def get_bk_stats(path):
    """Count how many labels in one label file equal "Z", "F" and "O".

    Parameters
    ----------
    path : str
        Path to a label file readable by ``xr.open_dataset`` that has a
        ``label`` variable.

    Returns
    -------
    dict
        ``{"Z": count, "F": count, "O": count}``.
    """
    # Read the labels once and close the file right away instead of leaving
    # one open handle per file.
    with xr.open_dataset(path) as ds:
        labels = ds.label.values
    return {name: (labels == name).sum() for name in ("Z", "F", "O")}

# One row per label file holding its "Z" / "F" / "O" counts.
df = pd.DataFrame([get_bk_stats(path) for path in all_labeled_files])

# Column-wise totals over all files.
df_sum = df.sum(axis=0)

# "Z" and "F" are reported together as brick kilns, "O" as non-brick-kilns.
print("All Brick Kilns", df_sum["Z"] + df_sum["F"])
print("All Non-brick Kilns", df_sum["O"])

All Brick Kilns 1697
All Non-brick Kilns 23803


In [22]:
# Flat directory that receives a copy of every selected label file.
save_path = expanduser("/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/bangladesh_labels/")

# Start from an empty directory. shutil.rmtree replaces the former
# `os.system("rm -rf ...")`: no shell is spawned and the path is never
# word-split or glob-expanded. ignore_errors=True keeps the "missing directory
# is fine" behaviour; a removal that really failed surfaces in makedirs below.
from shutil import rmtree
rmtree(save_path, ignore_errors=True)
os.makedirs(save_path)

def copy_file(path):
    """Copy one label file into `save_path`, keeping its base name."""
    # NOTE(review): files that share a base name overwrite each other here —
    # confirm base names are unique across the source directories.
    copy(path, save_path)

# Copy in parallel; the return values are discarded.
_ = Parallel(n_jobs=20)(delayed(copy_file)(path) for path in tqdm(all_labeled_files))



100%|██████████| 1020/1020 [00:00<00:00, 2585.91it/s]


In [39]:
# Directory in which get_index_and_image looks up the ".zarr" image store
# matching each label file's base name.
images_path = expanduser("/home/patel_zeel/bkdb/bangladesh/")
# NOTE(review): not referenced by any other visible cell — confirm it is still needed.
load_path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/temporary"
# Module-level list read by get_data.
files = all_labeled_files
# NOTE(review): this prints every path in the list; consider `files[:5]` to
# keep the notebook output small.
print(files)
print(len(files))



['/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.90,90.77.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.58,91.69.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.44,90.83.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.63,88.25.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.99,89.79.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/25.28,89.39.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.93,90.37.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.77,89.90.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/25.03,90.01.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.85,91.73.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bang

In [40]:
def get_index_and_image(file):
    """Load the 5x5 grid of image patches and binary labels for one label file.

    The image store is looked up in the module-level `images_path` under the
    label file's base name with its trailing ".nc" swapped for ".zarr".

    Parameters
    ----------
    file : str
        Path to a ".nc" label file.

    Returns
    -------
    index : list of str
        One id per patch, formatted ``"<stem>_<lat_lag>_<lon_lag>"``.
    images : list of torch.Tensor
        float32 patches, each with an extra leading axis of length 1.
    labels : list of torch.Tensor
        uint8 scalars: 1 where the label differs from "O", else 0.
    """
    suffix = ".nc"
    base_name = basename(file)
    # Swap only the trailing extension of the base name. Replacing ".nc" in the
    # whole joined path would also rewrite any ".nc" inside `images_path`.
    stem = base_name[: -len(suffix)] if base_name.endswith(suffix) else base_name
    image_path = join(images_path, stem + ".zarr")

    index = []
    images = []
    labels = []
    # Close both stores when done; this function runs once per file, so
    # unclosed handles would otherwise pile up in each worker.
    with xr.open_dataset(file) as label_ds, \
            xr.open_zarr(image_path, consolidated=False) as image_ds:
        # lat_lag / lon_lag each range over -2..2, giving 25 patches per file.
        for lat_lag, lon_lag in product(range(-2, 3), repeat=2):
            index.append(f"{stem}_{lat_lag}_{lon_lag}")

            patch = image_ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)["data"].values
            images.append(torch.tensor(patch, dtype=torch.float32)[np.newaxis, ...])

            label = label_ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)["label"].values
            # Binary target: anything other than "O" counts as positive.
            labels.append(torch.tensor((label != "O").astype(np.uint8)))

    return index, images, labels



def get_data():
    """Load every file in the module-level `files` list and assemble the dataset.

    Each file is processed by `get_index_and_image` in parallel. Images are
    rearranged from NHWC to NCHW, divided by 255 and then standardised with
    the per-channel mean and standard deviation of the whole image set.

    Returns
    -------
    index : numpy.ndarray
        Patch ids of all files, concatenated in file order.
    images : torch.Tensor
        Standardised images in NCHW layout.
    labels : torch.Tensor
        uint8 labels aligned with `index`.
    """
    per_file = Parallel(n_jobs=32)(
        delayed(get_index_and_image)(file) for file in tqdm(files, total=len(files))
    )

    index_parts = []
    image_parts = []
    label_parts = []
    for file_index, file_images, file_labels in per_file:
        index_parts.append(np.array(file_index))
        image_parts.append(torch.einsum("nhwc->nchw", torch.concat(file_images)))
        label_parts.append(np.array(file_labels))

    index = np.concatenate(index_parts)

    # Bring pixel values into [0, 1] before standardising.
    images = torch.concat(image_parts) / 255
    # Per-channel statistics taken over batch, height and width.
    channel_mean = images.mean(dim=(0, 2, 3), keepdim=True)
    channel_std = images.std(dim=(0, 2, 3), keepdim=True)
    images = (images - channel_mean) / channel_std

    labels = torch.tensor(np.concatenate(label_parts), dtype=torch.uint8)
    return index, images, labels

# Build the full dataset in memory and report the shapes of the three outputs.
index, images, labels = get_data()
print(index.shape, images.shape, labels.shape)    

  6%|▌         | 63/1020 [24:42<6:15:23, 23.54s/it]
100%|██████████| 1020/1020 [00:31<00:00, 32.25it/s]


(25500,) torch.Size([25500, 3, 224, 224]) torch.Size([25500])


In [42]:
# Save the tensor data (run once; the next cell loads this same data.pt file).
# save_path="/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data/data.pt"
# torch.save({
#     'index': index,
#     'images': images,
#     'labels': labels
# }, save_path)

In [43]:
# Load the saved tensors
# Load the saved tensors
# NOTE(review): torch.load unpickles the file, so only load files you created
# yourself. Recent PyTorch releases may also default to `weights_only=True`,
# which could reject the non-tensor `index` entry — confirm against the
# installed version.
# NOTE(review): in the visible cells, data.pt is written only by the
# commented-out save cell above, so on a fresh run the file must already exist.
loaded_data = torch.load("/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data/data.pt")

# Access the tensors
# These rebind the names produced earlier by get_data().
index = loaded_data['index']
images = loaded_data['images']
labels = loaded_data['labels']


In [44]:
# Report the shapes of the loaded objects for comparison with the get_data() output.
print(index.shape, images.shape, labels.shape)  

(25500,) torch.Size([25500, 3, 224, 224]) torch.Size([25500])
