In [1]:
import os
# Restrict this process to GPU 0. This runs in the first cell, ahead of the
# torch import in the next cell, so the setting is in place before torch loads.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [14]:
from glob import glob
from tqdm import tqdm
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from shutil import copy
from sklearn.model_selection import StratifiedKFold
import torch
from tempfile import TemporaryDirectory

from albk.data.utils import idx_to_locate
# Config flag: when True, files that exist for only one annotator of a pair
# (see get_disjoint_files) are added to the collected label files as well.
use_disjoint_files = False
from torch.utils.data import TensorDataset, DataLoader


import torch
import torch.nn as nn

from PIL import Image
from glob import glob
from os.path import expanduser, join, basename, dirname
import xarray as xr
import numpy as np
from tqdm import tqdm
import pandas as pd
from joblib import Parallel, delayed
from itertools import product

In [3]:
def get_common_label_files(path1, path2):
    """Return the files in ``path1`` whose labels match the same-named file in ``path2``.

    Both directories are scanned for ``*.nc`` files. For every base name that
    occurs in both, the ``label`` variable of the two datasets is compared and
    the file is kept only when the two label arrays are identical.

    Parameters
    ----------
    path1, path2 : str
        Directories that each hold one set of ``*.nc`` label files.

    Returns
    -------
    list of str
        Paths under ``path1`` of the files whose labels agree, ordered by base
        name so the result is the same on every run.
    """
    names1 = {basename(f) for f in glob(join(path1, "*.nc"))}
    names2 = {basename(f) for f in glob(join(path2, "*.nc"))}

    common_label_files = []
    # Iterate in sorted order: plain set order depends on string hashing and
    # therefore changes from one kernel session to the next.
    for name in sorted(names1 & names2):
        # Context managers close both files as soon as the labels are compared.
        with xr.open_dataset(join(path1, name)) as ds1:
            with xr.open_dataset(join(path2, name)) as ds2:
                # array_equal returns False on a shape mismatch, where an
                # element-wise `==` would broadcast or raise.
                if np.array_equal(ds1.label.values, ds2.label.values):
                    common_label_files.append(join(path1, name))

    return common_label_files

def get_disjoint_files(path1, path2):
    """Return the ``*.nc`` files that exist in exactly one of the two directories.

    Files are matched by base name. A file counts as disjoint when its base
    name appears under ``path1`` or under ``path2`` but not under both.

    Parameters
    ----------
    path1, path2 : str
        Directories to scan for ``*.nc`` files.

    Returns
    -------
    list of str
        Full paths of the files found only under ``path1`` (sorted by base
        name), followed by those found only under ``path2`` (sorted by base
        name).
    """
    names1 = {basename(f) for f in glob(join(path1, "*.nc"))}
    names2 = {basename(f) for f in glob(join(path2, "*.nc"))}

    # Set differences replace the repeated list-membership scans, and sorting
    # makes the order independent of set iteration order.
    only_in_path1 = [join(path1, name) for name in sorted(names1 - names2)]
    only_in_path2 = [join(path2, name) for name in sorted(names2 - names1)]

    return only_in_path1 + only_in_path2

In [4]:
# Root folder: one sub-folder of *.nc label files per annotator, plus a
# "moderated" folder with one sub-folder per moderator.
base_path = expanduser("/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels")
# moderator folder name -> the two annotator folder names processed with it
paths = {
    "zeel": ("vannsh", "rishabh"),
    "rishabh": ("suraj", "dhruv"),
    "suraj": ("aditi", "madhav"),
}

all_labeled_files = []
for moderator, annotators in paths.items():
    # Label files stored under moderated/<moderator>
    moderator_path = join(base_path, "moderated", moderator)
    moderator_files = glob(join(moderator_path, "*.nc"))

    # Folders of the two annotators paired with this moderator
    annotator1_path, annotator2_path = (join(base_path, name) for name in annotators)

    # Files both annotators labeled identically, and files only one of them has
    common_base_files = get_common_label_files(annotator1_path, annotator2_path)
    disjoint_files = get_disjoint_files(annotator1_path, annotator2_path)

    all_files = [*moderator_files, *common_base_files]
    if use_disjoint_files:
        all_files += disjoint_files
    # Guard against the same path being collected twice within one group
    assert len(set(all_files)) == len(all_files)
    all_labeled_files += all_files

    # Per-group summary followed by the running total
    print("Moderator", moderator)
    print(" "*5, "Moderator files", len(moderator_files))
    print(" "*5, "Common label files", len(common_base_files))
    print(" "*5, "Disjoint files", len(disjoint_files))
    print(" "*5, f"Total files from {moderator} and {annotators}", len(all_files))
    print(" "*5, "Total annotatated files", len(all_labeled_files))

# 25 samples per file: one per cell of the 5 x 5 (lat_lag, lon_lag) grid
# that process_file iterates over.
print("Total dataset size", len(all_labeled_files) * 25)

Moderator zeel
      Moderator files 88
      Common label files 359
      Disjoint files 662
      Total files from zeel and ('vannsh', 'rishabh') 447
      Total annotatated files 447
Moderator rishabh
      Moderator files 98
      Common label files 115
      Disjoint files 736
      Total files from rishabh and ('suraj', 'dhruv') 213
      Total annotatated files 660
Moderator suraj
      Moderator files 195
      Common label files 165
      Disjoint files 746
      Total files from suraj and ('aditi', 'madhav') 360
      Total annotatated files 1020
Total dataset size 25500


In [5]:
# Sanity check: show the first five collected label-file paths.
print(all_labeled_files[:5])

['/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.90,90.77.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.58,91.69.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.44,90.83.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.63,88.25.nc', '/home/rishabh.mondal/bangladesh_labels/bkdb/bangladesh_labels/moderated/zeel/24.99,89.79.nc']


In [27]:
# Export every labeled patch as a PNG and collect the (image, label) pairs
# from which the torch dataset is built.

def process_file(file, raw_dir="/home/patel_zeel/bkdb/bangladesh", out_dir="/tmp/bk_new"):
    """Export the 5 x 5 grid of patches for one label file and return them with labels.

    The raw imagery is read from ``<raw_dir>/<stem>.zarr``, where ``<stem>`` is
    the label file's base name without its extension. For every
    ``(lat_lag, lon_lag)`` combination in ``[-2, 2] x [-2, 2]`` the ``data``
    array is converted to a PIL image, saved as
    ``<out_dir>/<stem>_<lat_lag>_<lon_lag>_<label>.png`` and collected together
    with its label.

    Parameters
    ----------
    file : str
        Path to a ``*.nc`` label file that has a ``label`` variable indexed by
        ``lat_lag`` and ``lon_lag``.
    raw_dir : str, optional
        Directory containing the matching ``*.zarr`` stores.
    out_dir : str, optional
        Directory the PNG files are written to; created if it does not exist.

    Returns
    -------
    list of tuple
        25 ``(PIL.Image.Image, label)`` pairs, ordered with ``lat_lag`` as the
        outer index and ``lon_lag`` as the inner one.
    """
    stem = basename(file).rsplit('.', 1)[0]
    # Create the folder here so the function also works on a fresh machine;
    # exist_ok makes the call safe when many workers reach it at once.
    os.makedirs(out_dir, exist_ok=True)

    image_label_pairs = []
    # Context managers close both stores once all patches have been read.
    with xr.open_zarr(join(raw_dir, f"{stem}.zarr"), consolidated=False) as ds:
        with xr.open_dataset(file) as label_ds:
            for lat_lag, lon_lag in product([-2, -1, 0, 1, 2], repeat=2):
                img = Image.fromarray(ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)['data'].values)
                label = label_ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)['label'].values.item()
                img.save(join(out_dir, f"{stem}_{lat_lag}_{lon_lag}_{label}.png"))
                image_label_pairs.append((img, label))
    return image_label_pairs

# Keep the per-file results: the cell that builds `all_pairs` iterates over
# `image_label_pairs`, which was otherwise undefined after a kernel restart.
# NOTE(review): this holds every image in memory; drop the assignment if only
# the PNG files on disk are needed.
image_label_pairs = Parallel(n_jobs=32)(delayed(process_file)(file) for file in tqdm(all_labeled_files))

100%|██████████| 1020/1020 [00:28<00:00, 35.76it/s]


In [22]:
# Flatten the per-file lists into one flat list of samples.
# Expects `image_label_pairs` to already exist in the kernel as an iterable
# holding one list of pairs per labeled file.
all_pairs = [pair for pairs in image_label_pairs for pair in pairs]
len(all_pairs)

25500

In [23]:
# Inspect the first sample: element 0 is the image, element 1 its label.
all_pairs[0][0], all_pairs[0][1]

(<PIL.Image.Image image mode=RGB size=224x224>, 'O')

In [24]:
class BKDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of ``(image, label)`` pairs."""

    def __init__(self, pairs):
        """Store ``pairs`` by reference; nothing is copied or validated."""
        self.pairs = pairs

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the pair at position ``idx`` as an ``(image, label)`` tuple."""
        image, target = self.pairs[idx]
        return (image, target)
    
# Wrap the flattened list of samples in the BKDataset class defined above.
dataset = BKDataset(all_pairs)

In [25]:
# NOTE(review): no cell shown here writes /tmp/bd_dataset.pt, so this load
# depends on a file created outside the visible notebook — confirm its origin
# or add the matching torch.save step.
# torch.load unpickles arbitrary Python objects; only load trusted files.
ds = torch.load("/tmp/bd_dataset.pt")
ds

<__main__.BKDataset at 0x7fa3fb3795d0>