In [None]:
import os
from os.path import join

# Expose only GPU index 1 to this process. The variable is set here, ahead of the
# `import torch` further down, so the restriction is already in place when CUDA is
# first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from shapely.geometry import Polygon, Point
import numpy as np
import xarray as xr
from joblib import Parallel, delayed

from tqdm.notebook import tqdm

from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn

from torchvision.models import efficientnet_b0
import geopandas as gpd

import matplotlib.pyplot as plt

# Run on the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Checkpoints holding the state dicts of the feature extractor and of the classifier head.
f_path = "/home/vannsh.jani/brick_kilns/githubrepo/ML/model_50_no_ssl_features_imagenet.pth"
c_path = "/home/vannsh.jani/brick_kilns/githubrepo/ML/model_50_no_ssl_classifier_imagenet.pth"

# Build EfficientNet-B0 without pretrained weights and swap its head for a
# 2-class linear layer on the 1280-dimensional features, matching the checkpoints.
model = efficientnet_b0(pretrained=False)
model.classifier = nn.Linear(1280,2)
# Deserialize onto the CPU first: a checkpoint saved from e.g. "cuda:1" cannot be
# restored directly when only a single GPU (seen as "cuda:0") is visible here.
# The weights are moved to the target device by model.to(device) below.
model.features.load_state_dict(torch.load(f_path, map_location="cpu"))
model.classifier.load_state_dict(torch.load(c_path, map_location="cpu"))
model.to(device)
# Inference mode for dropout / batch-norm; the trailing ';' hides the module repr.
model.eval();

In [None]:
# Load the district boundary shapefile.
gdf = gpd.read_file("/home/rishabh.mondal/Brick-Kilns-project/albk/experiments/data_preperation/shapefiles/statewise/DISTRICT_BOUNDARY.shp")

# In both name columns, replace the stray '>' and '<' characters with 'A'.
for name_column in ('District', 'STATE'):
    for stray_char in ('>', '<'):
        gdf[name_column] = gdf[name_column].str.replace(stray_char, 'A')

In [None]:
# gdf.head(50)

In [None]:
# Keep only the rows of the state under study
# (other values used previously: 'BIHAR', 'PUNJAB').
state_gdf = gdf.loc[gdf['STATE'].eq('UTTAR PRADESH')]
state_gdf.head(50)

In [None]:
# District to scan; upper-cased because the shapefile stores names in capitals.
district = "AZAMGARH".upper()

# Work in WGS84 (EPSG:4326) so geometry coordinates are plain lon/lat degrees.
state_gdf = state_gdf.to_crs(epsg=4326)

# Select the district from the chosen state only. Looking it up in the full
# country-wide frame would merge same-named districts of different states into
# one selection and produce a bounding box spanning both.
custom_gdf = state_gdf[state_gdf["District"] == district]
if custom_gdf.empty:
    raise ValueError(f"District {district!r} not found in the selected state")

custom_gdf.plot()
# state_gdf.plot()

In [None]:
# To scan the whole state instead of one district, set custom_gdf = state_gdf first.

# Merge the selected geometries into one shape and read its bounding box.
union = custom_gdf.geometry.unary_union
lon_min, lat_min, lon_max, lat_max = union.bounds

# Candidate grid: 0.01-degree steps over the bounding box padded by 0.02 degrees.
lat_grid = np.arange(lat_min-0.02, lat_max+0.02, 0.01)
lon_grid = np.arange(lon_min-0.02, lon_max+0.02, 0.01)

# Keep the (lon, lat) grid nodes that fall inside the merged shape.
pairs = []
for lat in tqdm(lat_grid, desc="Latitude progress"):
    pairs.extend((lon, lat) for lon in lon_grid if union.contains(Point(lon, lat)))

len(pairs)

In [None]:
# Convert every (lon, lat) pair into its store name "<lat>,<lon>.zarr", with both
# coordinates rounded to two decimals and printed with exactly two decimals.
proessed_pairs = [
    f"{round(lat, 2):.2f},{round(lon, 2):.2f}.zarr"
    for lon, lat in pairs
]

In [None]:
# Number of candidate store names generated for the selected area.
len(proessed_pairs)

In [None]:
# os.listdir("/home/jaiswalsuraj/bkdb/india/bihar/")

In [None]:
# Show what is stored under the India data root (presumably one sub-folder per
# state -- verify against the directory listing).
os.listdir("/home/jaiswalsuraj/bkdb/india/")
# os.listdir("/home/rishabh.mondal/bkdb/statewise/up")

In [None]:
# Root folder holding the zarr stores of the state under study.
# Roots used previously: "/home/jaiswalsuraj/bkdb/india/haryana/",
# "/home/jaiswalsuraj/bkdb/india/bihar/".
data_path = '/home/rishabh.mondal/bkdb/statewise/up'

# Sort every expected store path into "exists on disk" / "missing".
available_files = []
non_available_files = []
for store_name in tqdm(proessed_pairs):
    store_path = join(data_path, store_name)
    bucket = available_files if os.path.exists(store_path) else non_available_files
    bucket.append(store_path)

print(f"Available: {len(available_files)}")
print(f"Non-Available: {len(non_available_files)}")

In [None]:
# Open one store to inspect its layout. The index is clamped to the last
# available file so the cell also works when fewer than 1392 stores exist.
sample_file = available_files[min(1391, len(available_files) - 1)]
print(sample_file)
xr.open_zarr(sample_file, consolidated=False)

In [None]:
def process_file(file):
    """Load the 5x5 grid of image patches stored in one zarr store.

    Parameters
    ----------
    file : str
        Path of a zarr store containing a "data" variable selectable by
        ``lat_lag`` and ``lon_lag`` (each in -2..2) and scalar "lat" / "lon"
        variables.

    Returns
    -------
    tuple or None
        ``(images, ids)`` where ``images`` stacks the 25 patches, each divided by
        255 and moved from (H, W, C) to (C, H, W), and ``ids`` holds the matching
        ``"<lat>,<lon>_<lat_lag>_<lon_lag>"`` strings (coordinates with two
        decimals). ``None`` is returned, after printing a message, when a
        ``KeyError`` is raised while reading the store.
    """
    try:
        # The context manager closes the store once all patches have been read;
        # `.values` materialises the data, so nothing lazy escapes the block.
        with xr.open_zarr(file, consolidated=False) as data:
            # The centre coordinates are the same for every patch of the store,
            # so they are read once instead of once per lag combination.
            lat = data["lat"].values.item()
            lon = data["lon"].values.item()
            patches = data["data"]

            img_list = []
            idx_list = []
            for lat_lag in range(-2, 3):
                for lon_lag in range(-2, 3):
                    img = patches.sel(lat_lag=lat_lag, lon_lag=lon_lag).values
                    img = torch.tensor(img) / 255.0
                    img_list.append(torch.einsum("hwc -> chw", img))
                    idx_list.append(f"{lat:.2f},{lon:.2f}_{lat_lag}_{lon_lag}")
        return torch.stack(img_list), idx_list
    except KeyError as e:
        print(f"Skipping file {file} due to KeyError: {e}")
        return None

# Load all available stores in parallel with joblib. process_file returns None
# for stores it had to skip, so those entries are dropped after the parallel run.
# Filtering inside the generator (`... if process_file(file) is not None`) would
# execute process_file serially in this process for every file before handing
# the same file to a worker a second time.
raw_results = Parallel(n_jobs=48)(delayed(process_file)(file) for file in tqdm(available_files))
results = [result for result in raw_results if result is not None]

In [None]:
# Join the per-store stacks into a single tensor along the first axis.
all_images = torch.cat([images for images, _ in results], dim=0)

# Standardise with statistics taken over every axis except axis 1,
# kept broadcastable against all_images.
reduce_axes = (0, 2, 3)
mean = all_images.mean(dim=reduce_axes, keepdims=True)
std = all_images.std(dim=reduce_axes, keepdims=True)
all_images = (all_images - mean) / std

# Flatten the per-store identifier lists in the same order as the images.
all_idx = []
for _, ids in results:
    all_idx.extend(ids)

print(all_images.shape, len(all_idx))

In [None]:
batch_size = 512
all_preds = []

# No gradients are needed for prediction, so the whole loop runs under no_grad.
with torch.no_grad():
    for start in tqdm(range(0, len(all_images), batch_size)):
        logits = model(all_images[start:start+batch_size].to(device))
        # Predicted class per image, moved back to the CPU.
        all_preds.append(logits.argmax(dim=1).cpu())

all_preds = torch.cat(all_preds, dim=0)

In [None]:
# Boolean mask marking the images whose predicted class is 1.
pred_positive_idx = all_preds.eq(1)
print(pred_positive_idx.sum())

# Identifiers of the masked images, in mask order.
locs = np.array(all_idx)[pred_positive_idx]

In [None]:
# Show the boolean mask itself (True where the predicted class is 1).
print(pred_positive_idx)

In [None]:
# Preview one predicted-positive patch. The index is clamped to the last entry
# so the cell also works when fewer than four positives were found.
preview_idx = min(3, len(locs) - 1)

# An identifier looks like "<lat>,<lon>_<lat_lag>_<lon_lag>": the part before the
# first '_' names the store, the remainder gives the two lags.
file_name, lags = locs[preview_idx].split("_", 1)
lat_lag, lon_lag = lags.split("_")
preview_store = xr.open_zarr(join(data_path, file_name+".zarr"), consolidated=False)
plt.imshow(preview_store.sel(lat_lag=int(lat_lag), lon_lag=int(lon_lag))['data'].values)

In [None]:
from skimage import img_as_ubyte


In [None]:
# Display both lengths together: a notebook cell only shows its last expression,
# so a bare `len(all_images)` on its own line would never be displayed.
len(all_images), len(all_idx)

In [None]:
# Undo the standardisation applied earlier so pixel values return to their
# pre-standardisation scale. This definition has to be live: without it the next
# line raises NameError on a fresh kernel.
rescaled_images = (all_images * std) + mean
all_images_numpy = rescaled_images.cpu().numpy()

save_path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/predicted_positive/UTTAR_Pradesh/AZAMGARH/"

os.makedirs(save_path, exist_ok=True)

# Index the NumPy array with a NumPy mask rather than with the torch tensor.
positive_images = all_images_numpy[pred_positive_idx.numpy()]

# `locs` was built with the same mask, so it lines up one-to-one with
# positive_images; passing `total` gives tqdm a real progress bar.
for loc, img in tqdm(zip(locs, positive_images), total=len(positive_images)):
    peak = img.max()
    # Stretch to the image's own maximum; an all-zero image is left untouched
    # instead of being divided by zero.
    img_normalized = img / peak if peak > 0 else img
    # Clip tiny negative values left by floating-point round-off before the
    # conversion to uint8.
    img = img_as_ubyte(np.clip(img_normalized, 0, 1))
    # (C, H, W) -> (H, W, C), the layout plt.imsave expects.
    plt.imsave(join(save_path, f"{loc}.png"), np.moveaxis(img, 0, -1))

In [None]:
path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/predicted_positive/UTTAR_Pradesh/lucknow"

# Directory entries that are regular files (sub-directories are ignored).
files = os.listdir(path)
file_names = [file for file in files if os.path.isfile(os.path.join(path, file))]
print(file_names)

# Split each name on '_' and drop the last two fields, leaving the leading part(s).
file_list = [file.split('_')[:-2] for file in file_names]
print(file_list)


In [None]:
# Parse "<lat>,<lon>" from the first field of every entry in file_list. Entries
# that are empty (file names with fewer than three '_'-separated fields) carry no
# coordinate and are skipped instead of raising IndexError.
points = []
for coord in file_list:
    if not coord:
        continue
    lat_text, lon_text = coord[0].split(',')[:2]
    points.append((float(lat_text), float(lon_text)))
print(points)

# Point geometries take x = longitude, y = latitude. The CRS is declared as
# EPSG:4326, the CRS custom_gdf was converted to.
points_gdf = gpd.GeoDataFrame(
    geometry=gpd.points_from_xy([lon for _, lon in points], [lat for lat, _ in points]),
    crs="EPSG:4326",
)
print(points_gdf)

# NOTE(review): custom_gdf is whichever district was selected earlier in the
# notebook, while the points, label and title below refer to Lucknow -- confirm
# that `district` was set to Lucknow before relying on this figure.
ax = custom_gdf.plot(color='lightblue', edgecolor='black', figsize=(8, 8))

# Overlay the detected locations on the district outline.
points_gdf.plot(ax=ax, color='red', marker='o',label='Brick Kilns')
ax.text(80.95, 26.82, 'Lucknow', color='black', fontsize=13, ha='center')
plt.legend()
plt.title('Brick Kilns in Lucknow')

# Display the plot
plt.show()

