# TTC Implementation

## Steps

* adapt the model to expect an xx-band input feature stack (currently expects 17 channels)
* change the final output layer to support multi-class segmentation with a softmax activation (currently sigmoid)
* adjust the input patch size (currently expects 512x512)
* change the loss function to categorical cross-entropy (or Dice loss for imbalanced classes); it is currently binary cross-entropy. Confirm whether to use a weighted loss function.
* model deployment: how large an area?
* model validation: is it necessary?
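The output-layer and loss changes in the steps above can be sketched in PyTorch as follows. This is a minimal sketch, not the project's code. It assumes the model emits raw per-pixel logits of shape `(N, C, H, W)` and that labels are integer class indices of shape `(N, H, W)`. The helpers `inverse_frequency_weights`, `weighted_cross_entropy`, `soft_dice_loss` and `predict` are hypothetical names, and inverse-frequency weighting is only one possible choice of class weights. `torch.nn.functional.cross_entropy` already applies log-softmax internally, so the softmax is only applied explicitly at inference time.

```python
import torch
import torch.nn.functional as F


def inverse_frequency_weights(target: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-class weights proportional to 1 / class frequency.

    Classes that never occur in `target` get weight 0.
    """
    counts = torch.bincount(target.flatten(), minlength=n_classes).float()
    weights = counts.sum() / (n_classes * counts.clamp(min=1))
    weights[counts == 0] = 0.0
    return weights


def weighted_cross_entropy(logits: torch.Tensor, target: torch.Tensor,
                           weights: torch.Tensor) -> torch.Tensor:
    # logits: (N, C, H, W) raw scores; target: (N, H, W) int64 class indices.
    # cross_entropy applies log-softmax itself, so the model should NOT end in softmax.
    return F.cross_entropy(logits, target, weight=weights)


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor,
                   n_classes: int, eps: float = 1e-6) -> torch.Tensor:
    """Multi-class soft Dice loss, averaged over classes (lies in [0, 1])."""
    probs = torch.softmax(logits, dim=1)                                 # (N, C, H, W)
    one_hot = F.one_hot(target, n_classes).permute(0, 3, 1, 2).float()  # (N, C, H, W)
    dims = (0, 2, 3)  # sum over batch and pixels, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()


def predict(logits: torch.Tensor) -> torch.Tensor:
    # Softmax over the class axis gives per-pixel probabilities; argmax picks the class.
    return torch.softmax(logits, dim=1).argmax(dim=1)


if __name__ == "__main__":
    n_classes = 4
    logits = torch.randn(2, n_classes, 14, 14)
    target = torch.randint(0, n_classes, (2, 14, 14))
    w = inverse_frequency_weights(target, n_classes)
    print("weighted CE:", weighted_cross_entropy(logits, target, w).item())
    print("soft Dice:", soft_dice_loss(logits, target, n_classes).item())
    print("prediction shape:", tuple(predict(logits).shape))
```

Weighted cross-entropy and Dice can also be summed into a single training loss. Which option, if any, suits the class balance here is the open question raised in the steps list.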

### Updates reflected in `modified_unet.py`
* Removed ConvGRU: classification will be done on a single composite image, not a monthly time series, so the RNN/temporal layers are likely unnecessary (to be confirmed).
* Removed SSE and DropBlock: starting with a simple model; attention and regularization can be added later if needed.
* Switched from binary to multi-class output: instead of "tree vs. no-tree", the model predicts 4 classes.
* Simplified the U-Net encoder-decoder: clean, readable U-Net blocks (Conv → BatchNorm → ReLU) with skip connections.
* Modified input channels: the original model expected 17 bands (Sentinel only); the modified model expects 94 bands (Sentinel + texture + tree features).
* Adjust the loss function: it is currently binary cross-entropy and needs to be categorical cross-entropy, using the `weight` argument of `torch.nn.functional.cross_entropy`.
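The simplified architecture described above might look roughly like the following. This is a hypothetical sketch, not the contents of `modified_unet.py`. It has two encoder levels built from Conv → BatchNorm → ReLU blocks, skip connections, and a 1x1 head that emits raw logits for 4 classes with no sigmoid or softmax. `in_channels` is left as a parameter, because the notes mention 94 bands while the data-prep call below uses `n_feats=29`. The class names `ConvBlock` and `SimpleUNet` and the `base` width are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Two Conv -> BatchNorm -> ReLU layers; padding=1 keeps H and W unchanged."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class SimpleUNet(nn.Module):
    """Hypothetical two-level U-Net producing per-pixel class logits."""

    def __init__(self, in_channels: int, n_classes: int = 4, base: int = 32):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, base)
        self.enc2 = ConvBlock(base, base * 2)
        self.bottleneck = ConvBlock(base * 2, base * 4)
        self.dec2 = ConvBlock(base * 4 + base * 2, base * 2)  # upsampled + skip channels
        self.dec1 = ConvBlock(base * 2 + base, base)
        self.head = nn.Conv2d(base, n_classes, kernel_size=1)  # raw logits

    @staticmethod
    def _up_and_concat(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        # Resize to the skip connection's exact size so odd sizes (e.g. 7 -> 3 -> 7) still align.
        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return torch.cat([x, skip], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s1 = self.enc1(x)                          # (N, base, H, W)
        s2 = self.enc2(F.max_pool2d(s1, 2))        # (N, 2*base, H/2, W/2)
        b = self.bottleneck(F.max_pool2d(s2, 2))   # (N, 4*base, H/4, W/4)
        d2 = self.dec2(self._up_and_concat(b, s2))
        d1 = self.dec1(self._up_and_concat(d2, s1))
        return self.head(d1)                       # (N, n_classes, H, W)


if __name__ == "__main__":
    model = SimpleUNet(in_channels=29, n_classes=4)
    logits = model(torch.randn(2, 29, 14, 14))
    print(logits.shape)  # torch.Size([2, 4, 14, 14])
```

The decoder upsamples with `F.interpolate` to the skip tensor's size instead of using a fixed-stride transposed convolution. This design choice lets the same network accept small or odd patch sizes without cropping or padding logic.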

```python
%load_ext autoreload
%autoreload 2

import os
import sys

import yaml

# Put ../../src on the import path before importing project modules
sys.path.append(os.path.abspath('../../src/'))

from features import create_xy as create
from utils.logs import get_logger
```

## Prep training data

```python
params_path = '../../params.yaml'
with open(params_path) as file:
    params = yaml.safe_load(file)

ceo_batch = params["data_load"]["ceo_survey"]
logger = get_logger("FEATURIZE", log_level=params["base"]["log_level"])

# Build the CNN training sample for the CEO survey batch
X, y = create.build_training_sample_CNN(
    ceo_batch,
    classes=params["data_condition"]["classes"],
    n_feats=29,
    params_path=params_path,
    logger=logger,
)
```

```text
2025-05-02 12:45:56,180 — FEATURIZE — INFO — Writing plot IDs to file...
2025-05-02 12:45:56,182 — FEATURIZE — INFO — SUMMARY
2025-05-02 12:45:56,183 — FEATURIZE — INFO — 242 plots labeled "unknown" were dropped.
2025-05-02 12:45:56,185 — FEATURIZE — INFO — 118 plots did not have ARD.
2025-05-02 12:45:56,189 — FEATURIZE — INFO — Training data batch includes: 976 plots.
2025-05-02 12:45:56,192 — FEATURIZE — INFO — 976 plots will be used in training.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 976/976 [01:15<00:00, 12.88it/s]
```

```python
import hickle as hkl

sample = hkl.load('../../data/train-pytorch/08003.hkl')
print(sample.shape)
```

```text
(14, 14, 29)
```
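The saved sample appears to be channels-last: 14 × 14 pixels with 29 features on the last axis, matching `n_feats=29`. PyTorch's `nn.Conv2d` expects channels-first `(N, C, H, W)`. Below is a minimal sketch of the conversion. `to_channels_first` is a hypothetical helper, and a random array of the same shape stands in for the `.hkl` file so the snippet runs on its own.

```python
import numpy as np
import torch


def to_channels_first(arr: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: (H, W, C) -> (1, C, H, W) or (N, H, W, C) -> (N, C, H, W)."""
    if arr.ndim == 3:
        arr = arr[np.newaxis]  # add a batch axis for a single patch
    if arr.ndim != 4:
        raise ValueError(f"expected 3 or 4 dims, got shape {arr.shape}")
    nchw = np.ascontiguousarray(np.transpose(arr, (0, 3, 1, 2)))
    return torch.from_numpy(nchw).float()


# Stand-in for hkl.load('../../data/train-pytorch/08003.hkl'): random values, same shape.
sample = np.random.rand(14, 14, 29).astype(np.float32)
x = to_channels_first(sample)
print(x.shape)  # torch.Size([1, 29, 14, 14])
```

This patch is 14 × 14, far smaller than the 512x512 input the original model expected. Each 2×2 max-pool halves the spatial size (14 → 7 → 3), which limits how many downsampling levels the encoder can use.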
