In [1]:
import glob
import os

import numpy as np
import matplotlib.pyplot as plt

import wandb
import torch

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Allow float32 matrix multiplications to use faster, reduced-precision internal
# math on GPUs that support it (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

# Automatically reload the project modules in `src` when they are edited, so
# changes to utils/cnn are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from src import utils, cnn

### Define some parameters
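Importantly, set `classification` to `True` for cultivar classification or to `False` for regression towards yield, stomatal conductance, chlorophyll fluorescence and fertilizer amount. (NOTE: the parameter cell below defines no `classification` flag, and the model in this notebook is a single binary outlier classifier, so this sentence appears to be left over from an earlier version — confirm before relying on it.)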

In [2]:
# Set to True to continue the wandb run and the latest checkpoint of a previous session.
resume = False

epochs = 5000  # upper bound only; the EarlyStopping callback may end training earlier
lr = 1e-5

batch_size = 16
num_accumulated_batches = 1  # Number of gradient accumulation steps
# Effective batch size = batch_size * num_accumulated_batches

# wandb project name (plain strings: these literals have no placeholders to format)
project_name = 'challenge'

# Model checkpointing settings: checkpoints are written to `model_path`,
# with file names derived from `model_name`.
model_path = 'models/'
model_name = 'cnn'

### Create dataloaders

In [3]:
# Build the training/validation loaders; the third element holds the test
# inputs that are later passed one at a time through process_image.
(
    train_loader,
    val_loader,
    X_test,
) = utils.get_data_loaders(batch_size)

### Train model

In [None]:
# Set up WandB logger. The run id is fixed so that a resumed session reattaches
# to the same wandb run; resume='must' is only requested when resuming.
run_id = f'{project_name}_{model_name}'
resume_kwargs = {'resume': 'must'} if resume else {}
logger = WandbLogger(name=model_name,
                     project=project_name,
                     id=run_id,
                     log_model=False,
                     **resume_kwargs)

# Set up callbacks
# Keeps the checkpoint with the lowest val/loss as <model_path>/<model_name>_best.ckpt.
best_checkpoint_callback = ModelCheckpoint(dirpath=model_path,
                                           filename=f'{model_name}_best',
                                           monitor='val/loss',
                                           enable_version_counter=False)

# No monitored metric: keeps the most recent checkpoint as <model_path>/<model_name>.ckpt,
# which is the file training resumes from when resume=True.
last_checkpoint_callback = ModelCheckpoint(dirpath=model_path,
                                           filename=f'{model_name}',
                                           monitor=None,
                                           enable_version_counter=False)

# Stop once val/loss has not improved for 50 consecutive validation checks.
early_stopping_callback = EarlyStopping(monitor='val/loss',
                                        min_delta=0.00,
                                        patience=50,
                                        verbose=True,
                                        mode='min')

callbacks = [best_checkpoint_callback, last_checkpoint_callback, early_stopping_callback]

# Create model
model = cnn.CNN(lr=lr)

# Set up trainer
trainer = L.Trainer(max_epochs=epochs,
                    precision='16-mixed',
                    log_every_n_steps=1,
                    logger=logger,
                    callbacks=callbacks,
                    accelerator="gpu",
                    accumulate_grad_batches=num_accumulated_batches)

# Resume from the latest checkpoint only when requested; ckpt_path=None trains from scratch.
# os.path.join avoids the doubled separator of f'{model_path}/...' when model_path ends in '/'.
ckpt_path = os.path.join(model_path, f'{model_name}.ckpt') if resume else None

# Train model. wandb.finish() runs even if fit raises, so a failed run is not
# left open in this kernel.
try:
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
finally:
    wandb.finish()

# NOTE(review): after fit, `model` holds the weights of the final epoch, not the
# best val/loss weights; those are saved at best_checkpoint_callback.best_model_path.
# Consider loading them before running the inference cells below.

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mwillap[0m. Use [1m`wandb login --relogin`[0m to force relogin


C:\Users\Billy\miniconda3\envs\matrix\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:639: Checkpoint directory C:\Users\Billy\Downloads\challenge_data\challenge_data\models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 25.6 M
------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.232   Total estimated model params size (MB)


Sanity Checking: |                                       | 0/? [00:00<?, ?it/s]

Training: |                                              | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved. New best score: 0.696


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.018 >= min_delta = 0.0. New best score: 0.678


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.025 >= min_delta = 0.0. New best score: 0.652


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.030 >= min_delta = 0.0. New best score: 0.622


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.003 >= min_delta = 0.0. New best score: 0.619


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.084 >= min_delta = 0.0. New best score: 0.535


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.026 >= min_delta = 0.0. New best score: 0.509


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.030 >= min_delta = 0.0. New best score: 0.479


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.061 >= min_delta = 0.0. New best score: 0.418


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.010 >= min_delta = 0.0. New best score: 0.408


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.018 >= min_delta = 0.0. New best score: 0.391


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.050 >= min_delta = 0.0. New best score: 0.341


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.036 >= min_delta = 0.0. New best score: 0.305


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.005 >= min_delta = 0.0. New best score: 0.300


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.007 >= min_delta = 0.0. New best score: 0.293


Validation: |                                            | 0/? [00:00<?, ?it/s]

Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.004 >= min_delta = 0.0. New best score: 0.289


Validation: |                                            | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.061 >= min_delta = 0.0. New best score: 0.228


Validation: |                                            | 0/? [00:00<?, ?it/s]

In [None]:

from torchvision.transforms import v2
from torchvision import tv_tensors
import torch

def process_image(image, device='cuda'):
    """Turn one test image into a standardized, resized float32 tensor.

    Args:
        image: array with channels on the last axis (they are moved to the
            front below). The 3-element statistics imply 3 channels — TODO confirm.
        device: device for the returned tensor. Defaults to 'cuda', matching the
            previously hard-coded `.cuda()` call.

    Returns:
        float32 tensor of shape (C, 224, 224) on `device`, without a batch axis
        (callers add one with `[None, ...]`). This assumes utils.Standardize
        keeps the shape — verify.
    """
    # Per-channel statistics for utils.Standardize (presumably computed on the
    # training data — verify they match the training pipeline).
    channel_mean = np.array([1.32242872, 2.42520988, 2.5470593])
    channel_std = np.array([3.06781983, 3.93189598, 3.8842867])
    transforms = v2.Compose([
        utils.Standardize(channel_mean=channel_mean, channel_std=channel_std),
        v2.Resize(size=(224, 224)),
    ])

    # Channels-last -> channels-first, the layout torchvision transforms expect.
    image = np.moveaxis(image, -1, 0)

    # Wrap as tv_tensors.Image so the v2 transforms treat it as an image.
    # torch.tensor copies the data, so the caller's array is never modified.
    image = tv_tensors.Image(torch.tensor(image))

    # Apply transforms
    image = transforms(image)

    return image.to(device=device, dtype=torch.float32)

# Quick sanity check on the first test image.
# The model is moved to the GPU because process_image returns CUDA tensors.
# Lightning may move the model back to the CPU when fit finishes — NOTE(review): confirm.
# eval() disables training-only behaviour such as dropout, and no_grad() skips
# building an autograd graph.
model.cuda()
model.eval()
with torch.no_grad():
    prob = torch.sigmoid(model(process_image(X_test[0])[None, ...]))

In [None]:
import vtk
import numpy as np
import os
import argparse
import matplotlib.pyplot as plt
import pickle
import json

# All test label crops, sorted so their order is deterministic (presumably the
# same order as X_test — NOTE(review): confirm).
test_files = np.sort(glob.glob(pathname='test/crops/*label*'))

def _scan_id_from_path(path):
    """Return the scan id: the first two '_'-separated fields of the file name.

    os.path.basename is used instead of splitting on a backslash, so the same
    code works for both Windows and POSIX paths.
    """
    parts = os.path.basename(path).split('_')
    return parts[0] + '_' + parts[1]


def create_detection(model, X_test, test_files, only_200=True):
    """Predict an outlier flag for each test scan and write the results to JSON.

    Args:
        model: trained network applied to a (1, C, 224, 224) batch. Only the
            first output value is used, presumably a single logit. The model is
            moved to CUDA and evaluated in eval mode with gradients disabled;
            its previous train/eval mode is restored afterwards.
        X_test: test images, where X_test[i] is used for test_files[i].
            NOTE(review): this assumes both are in the same order — confirm.
        test_files: paths of the test label crops; the scan id comes from the
            file name (see _scan_id_from_path).
        only_200: if True, only scans listed in 'test_files_200.txt' are
            predicted and results go to 'test_results_200.json'. Otherwise
            every file is predicted and results go to 'test_results.json'.

    Each JSON entry is {"scan_id": str, "outlier": 0 or 1}, where 1 means the
    sigmoid probability exceeded 0.5. The fraction of predictions flagged as
    outliers is printed.
    """
    # Read the ID list only when it is needed. Flattening it into a set keeps
    # the "matches any entry" semantics of `in` on the array, with O(1) lookups.
    allowed_ids = None
    if only_200:
        allowed_ids = set(np.ravel(np.genfromtxt('test_files_200.txt', str)).tolist())

    model.cuda()  # process_image returns CUDA tensors
    was_training = model.training
    # eval() switches off training-only layer behaviour (e.g. dropout or
    # batch-norm statistic updates, if the CNN has any). no_grad() avoids
    # building an autograd graph for every test image.
    model.eval()

    test_results = []
    num_outlier = 0
    try:
        with torch.no_grad():
            for i, path in enumerate(test_files):
                scan_id = _scan_id_from_path(path)
                if allowed_ids is not None and scan_id not in allowed_ids:
                    continue

                logit = model(process_image(X_test[i])[None, ...])
                prob = torch.sigmoid(logit).cpu().numpy().ravel()[0]
                # Cast to int: the numpy bool produced by the comparison is not JSON serializable.
                pred = int(prob > 0.5)
                num_outlier += pred
                test_results.append({"scan_id": scan_id, "outlier": pred})
    finally:
        model.train(was_training)

    # Divide by the number of predictions actually made, not a fixed 200, so
    # the fraction is also correct when only_200=False.
    num_predictions = len(test_results)
    fraction = num_outlier / num_predictions if num_predictions else 0.0
    print(f"Predicted outlier fraction: {fraction:.3f} ({num_outlier}/{num_predictions})")

    # Write results to JSON file
    out_path = "test_results_200.json" if only_200 else "test_results.json"
    with open(out_path, 'w') as json_file:
        json.dump(test_results, json_file, indent=4)

# Predict the 200-scan subset and write test_results_200.json.
create_detection(model=model, X_test=X_test, test_files=test_files, only_200=True)