In [3]:
import os
import tempfile
from torchgeo.datamodules import NAIPChesapeakeDataModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import pytorch_lightning as pl
from torchgeo.trainers import SemanticSegmentationTask
from multiprocessing import freeze_support
from pytorch_lightning.callbacks import TQDMProgressBar

In [4]:
# Working directories, all created as siblings under the system temp directory.
imagery_data, LC_data, data_dir = (
    os.path.join(tempfile.gettempdir(), subdir)
    for subdir in ("imagery", "LC", "training")
)

In [7]:
# Datamodule that reads from the two root directories configured earlier:
# `imagery_data` for the NAIP side and `LC_data` for the Chesapeake side.
datamodule = NAIPChesapeakeDataModule(
    naip_root_dir=imagery_data,
    chesapeake_root_dir=LC_data,
    batch_size=64,
    num_workers=6,
    patch_size=1024,
)

In [None]:
!pip show torchgeo

Name: torchgeo
Version: 0.3.1
Summary: TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
Home-page: https://github.com/microsoft/torchgeo
Author: Adam J. Stewart
Author-email: ajstewart426@gmail.com
License: 
Location: /usr/local/lib/python3.8/dist-packages
Requires: einops, fiona, kornia, matplotlib, numpy, omegaconf, packaging, pillow, pyproj, pytorch-lightning, rasterio, rtree, scikit-learn, segmentation-models-pytorch, shapely, timm, torch, torchmetrics, torchvision
Required-by: 


In [8]:
# Segmentation task: U-Net with a ResNet-34 encoder, 4 input channels,
# 13 output classes, cross-entropy loss.
# NOTE(review): both `ignore_index` and `ignore_zeros` are passed here; confirm
# against the installed torchgeo version which of the two the task honours.
task = SemanticSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    pretrained=True,
    in_channels=4,
    num_classes=13,
    ignore_index=-1000,
    loss="ce",
    learning_rate=0.1,
    learning_rate_schedule_patience=5,
    ignore_zeros=True,
)

In [9]:
# Checkpoints and CSV logs are both written under the training directory.
experiment_dir = data_dir

# Keep the single best checkpoint by validation loss, plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_dir,
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)

# Early-stopping callback watching val_loss with a patience of 10.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)

csv_logger = CSVLogger(name="segmentation_unet", save_dir=experiment_dir)

In [10]:
# True when the PYTEST_CURRENT_TEST environment variable is set; used below to
# switch the trainer into fast_dev_run mode.
in_tests = os.getenv("PYTEST_CURRENT_TEST") is not None

trainer = pl.Trainer(
    accelerator="gpu",
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests,
    default_root_dir=experiment_dir,
    logger=[csv_logger],
    # Alternative kept for reference: callbacks=[TQDMProgressBar(refresh_rate=10)]
    callbacks=[checkpoint_callback, early_stopping_callback],
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [12]:
# Fit the segmentation task on the batches served by the datamodule.
trainer.fit(task, datamodule=datamodule)

FileNotFoundError: ignored

In [None]:
import torchgeo.datamodules

In [None]:
# Reload the already-imported torchgeo.datamodules module inside this kernel
# (the module must have been imported beforehand for the attribute lookup to work).
import importlib
importlib.reload(torchgeo.datamodules)

In [11]:
# Show which entries are present directly inside the LC_data directory.
arr = os.listdir(LC_data)
arr

['DE_STATEWIDE.tif.ovr',
 'DE_STATEWIDE.tif.vat.cpg',
 'DE_STATEWIDE.tif.vat.dbf',
 'DE_STATEWIDE.tfw',
 'DE_STATEWIDE.tif.aux.xml',
 '_DE_STATEWIDE.zip',
 'DE_STATEWIDE.tif.xml',
 'DE_STATEWIDE.tif']

In [16]:
# Glob pattern "<LC_data>/**/<filename_glob>"; the "**" component only expands
# recursively when the pattern is used with recursive=True.
filename_glob = "*"
pathname = os.path.join(LC_data, "**", filename_glob)
pathname

'/tmp/LC/**/*'

In [17]:
 import re
 # Match-anything filename pattern, compiled in verbose mode.
 filename_regex = re.compile(".*", re.VERBOSE)
 filename_regex

re.compile(r'.*', re.UNICODE|re.VERBOSE)

In [34]:
import glob
import rasterio

# Walk every path matched by `pathname`, try to open each one with rasterio,
# and count how many open successfully. For each path the cell prints the path,
# a separator line, and then either the dataset CRS or a '^^^' failure marker.
i = 0
for filepath in glob.iglob(pathname, recursive=True):
    print(filepath)
    match = re.match(filename_regex, os.path.basename(filepath))
    if match is None:
        continue
    print('-' * 100)
    try:
        with rasterio.open(filepath) as src:
            print(src.crs)
    except Exception as exc:
        # Keep scanning the remaining files, but report *why* this one could
        # not be opened. Catching Exception (not a bare except) also lets
        # KeyboardInterrupt / SystemExit stop the cell as expected.
        print('^^^', type(exc).__name__, exc)
    else:
        i += 1
print('-----------', i)



/tmp/LC/DE_STATEWIDE.tif.ovr
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tif.vat.cpg
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tif.vat.dbf
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tfw
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tif.aux.xml
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/_DE_STATEWIDE.zip
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tif.xml
----------------------------------------------------------------------------------------------------
^^^
/tmp/LC/DE_STATEWIDE.tif
--------------------------------

In [38]:
# Open the DE_STATEWIDE.tif raster on its own and print its CRS. The path is
# built from LC_data instead of a hard-coded '/tmp/LC' so it stays correct when
# tempfile.gettempdir() is not '/tmp'.
with rasterio.open(os.path.join(LC_data, 'DE_STATEWIDE.tif'), 'r') as src:
    print(src.crs)

RasterioIOError: ignored