In [63]:
!pip install gdown



In [64]:
import gdown

In [65]:
import os

url = 'https://drive.google.com/file/d/10hcRPFP3J4p47ouiT3iOY1sX6UtqD-Ff/view?usp=sharing'
output_path = '/content/model.pt'
# The checkpoint is ~294 MB; skip the download when it is already on disk (e.g. on re-run)
# so Restart & Run All does not fetch it again.
if not os.path.exists(output_path):
    gdown.download(url, output_path, quiet=False, fuzzy=True)
output_path

Downloading...
From (original): https://drive.google.com/uc?id=10hcRPFP3J4p47ouiT3iOY1sX6UtqD-Ff
From (redirected): https://drive.google.com/uc?id=10hcRPFP3J4p47ouiT3iOY1sX6UtqD-Ff&confirm=t&uuid=bbc2ec70-47bf-4ab1-90b4-9518d7461a27
To: /content/model.pt
100%|██████████| 294M/294M [00:07<00:00, 41.2MB/s]


'/content/model.pt'

In [66]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

In [67]:
!pip install flask-ngrok



In [68]:
!pip install ngrok



In [69]:
!pip install rasterio



In [70]:
from google.colab.output import eval_js

# URL through which Colab proxies port 5000, the port the Flask dev server listens on.
colab_proxy_url = eval_js("google.colab.kernel.proxyPort(5000)")
print(colab_proxy_url)

https://tptuc21yrq-496ff2e9c6d22116-5000-colab.googleusercontent.com/


In [71]:
!mkdir templates

mkdir: cannot create directory ‘templates’: File exists


In [72]:
!mkdir uploads

mkdir: cannot create directory ‘uploads’: File exists


In [73]:
%%writefile /content/templates/index.html
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Upload Image</title>
</head>
<body>
    <h1>Upload an Image</h1>
    <p>Upload an image to be segmented: <em>only 128x128 .tif files are allowed</em></p>
    {# index() passes error=None; without this guard Jinja renders the literal text "None". #}
    {% if error %}
    <p>{{ error }}</p>
    {% endif %}
    <form action="{{ url_for('upload_image') }}" method="POST" enctype="multipart/form-data">
        <input type="file" name="image" accept=".tif,.tiff">
        <input type="submit" value="Upload">
    </form>
</body>
</html>

Overwriting /content/templates/index.html


In [97]:
%%writefile /content/templates/result.html
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Model Results</title>
</head>
<body>
    <h1>Model Results</h1>
    <p>The original image:</p>
    <img src="{{ url_for('serve_image', display_filename=display_filename) }}" alt="Uploaded image">
    <p>The segmented image:</p>
    <img src="{{ url_for('serve_model', model_filename=model_filename) }}" alt="Segmented image">
    <p><a href="{{ url_for('index') }}">Upload another image</a></p>
</body>
</html>

Overwriting /content/templates/result.html


In [102]:
%%writefile app.py
from flask import Flask, render_template, request, send_from_directory
import os
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import io
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset

class DatasetOptim(Dataset):
    """Map-style dataset over an indexable collection of numpy arrays.

    Each item is converted with ``torch.from_numpy`` and cast to float32 when
    accessed; no other transformation is applied.
    """

    def __init__(self, images):
        # Stored untouched; conversion happens lazily per item.
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        as_tensor = torch.from_numpy(sample)
        return as_tensor.float()

def input_processor(path):
    image_array_holder = []
    with rasterio.open(path) as image:
        image_array = image.read()
    normalized_image_array = []
    for j in range(12):
        band_range = image_array[j].max() - image_array[j].min()
        normalized_band = (image_array[j]-image_array[j].min())/[band_range if band_range!=0 else 1]
        normalized_image_array.append(normalized_band)
    image_array = np.array(normalized_image_array)
    image_array_holder.append(image_array)
    image_array_holder.append(image_array)
    image_array_holder = np.array(image_array_holder)
    model_input = DatasetOptim(image_array_holder)
    model_input = DataLoader(model_input, 2)
    return model_input

app = Flask(__name__)

# Templates live in a "templates" directory next to this file
# (the notebook presumably runs app.py from /content, where it writes them — confirm).
app.template_folder = os.path.join(os.path.dirname(__file__), 'templates')

# Directory for uploaded rasters and rendered JPEGs; override with the UPLOAD_FOLDER env var.
UPLOAD_FOLDER = os.environ.get('UPLOAD_FOLDER', '/content/uploads')
# Create it up front so file.save() cannot fail on a fresh runtime where it does not exist yet.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

@app.route('/')
def index():
    """Render the upload form with no error message."""
    context = {'error': None}
    return render_template('index.html', **context)

# Trained checkpoint downloaded by the notebook.
MODEL_PATH = '/content/model.pt'
# Upload extensions accepted (compared case-insensitively).
ALLOWED_EXTENSIONS = ('.tif', '.tiff')

# Model loaded lazily on the first request and reused afterwards.
_model_cache = None


def _get_model():
    """Load the segmentation model once per process and return it in eval mode.

    The checkpoint is a pickled model object (the loaded object exposes
    ``.module``, i.e. it was saved inside a wrapper, presumably
    ``nn.DataParallel``), so ``weights_only=False`` is required on PyTorch
    versions that default to weights-only loading.
    NOTE(security): this unpickles MODEL_PATH; only point it at a trusted file.
    """
    global _model_cache
    if _model_cache is None:
        loaded = torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)
        # Unwrap the DataParallel-style wrapper when present.
        model = getattr(loaded, 'module', loaded)
        # eval(): BatchNorm uses running statistics and dropout is disabled at inference time.
        _model_cache = model.to('cpu').eval()
    return _model_cache


@app.route('/upload', methods=['POST'])
def upload_image():
    """Handle a raster upload, run segmentation and render the results page.

    Saves the upload into UPLOAD_FOLDER and writes two files there:
    ``result.jpeg``, which shows band indices 3, 2, 1 as RGB, and
    ``model_output.jpeg``, which is the same composite with the predicted
    mask blended in. It then renders result.html. Validation failures
    re-render index.html with an error message.
    """
    if 'image' not in request.files:
        return render_template('index.html', error='No image file')

    file = request.files['image']
    if file.filename == '':
        return render_template('index.html', error='No selected file')

    # Strip client-supplied directory components so the upload cannot be written outside
    # UPLOAD_FOLDER (e.g. a filename of "../../app.py").
    filename = os.path.basename(file.filename.replace('\\', '/'))
    if os.path.splitext(filename)[1].lower() not in ALLOWED_EXTENSIONS:
        return render_template('index.html', error='Invalid file format. Only .tif files are allowed.')

    upload_folder = app.config['UPLOAD_FOLDER']
    upload_path = os.path.join(upload_folder, filename)
    file.save(upload_path)

    try:
        # Reads and normalizes the raster once; the same tensor feeds both the preview and the model.
        model_input = input_processor(upload_path)
    except (IndexError, ValueError, rasterio.errors.RasterioIOError):
        # Not a readable raster, or too few bands for the model.
        return render_template('index.html', error='Could not read the image as a 12-band .tif file.')

    batch = next(iter(model_input))
    with torch.no_grad():
        predicted = _get_model()(batch).round()

    # First sample in channels-last layout for plotting.
    image_hwc = batch[0].permute(1, 2, 0).numpy()
    display_image = image_hwc[:, :, [3, 2, 1]]
    display_filename = 'result.jpeg'
    plt.imsave(os.path.join(upload_folder, display_filename), display_image)

    # Single-channel mask of the first sample: the blend below only broadcasts for one channel.
    mask = predicted[0, 0].numpy()
    masked_pic = display_image + np.dstack([mask * 0.1, mask * 0.1, mask * 0.6])
    # Rescale into [0, 1] for imsave; skip when the image is all zeros to avoid dividing by 0.
    peak = masked_pic.max()
    if peak > 0:
        masked_pic = masked_pic / peak
    model_filename = 'model_output.jpeg'
    plt.imsave(os.path.join(upload_folder, model_filename), masked_pic)

    return render_template('result.html', display_filename=display_filename, model_filename=model_filename)

@app.route('/uploads/<display_filename>')
def serve_image(display_filename):
    """Serve a file from the configured upload folder (the original-image preview).

    Uses app.config['UPLOAD_FOLDER'], the same directory upload_image writes to,
    rather than a path relative to the app root.
    """
    return send_from_directory(app.config['UPLOAD_FOLDER'], display_filename)

@app.route('/uploads/model/<model_filename>')
def serve_model(model_filename):
    """Serve a rendered model output from the configured upload folder.

    Registered under its own URL: the previous rule '/uploads/<model_filename>'
    was identical to serve_image's rule, so only one of the two views could
    ever be dispatched. Templates build this URL with url_for('serve_model').
    """
    return send_from_directory(app.config['UPLOAD_FOLDER'], model_filename)

if __name__ == '__main__':
    # Port 5000 is the one the notebook exposes via google.colab.kernel.proxyPort(5000).
    # Debug mode (reloader + interactive Werkzeug debugger, which can execute arbitrary
    # code) stays on by default as before; set FLASK_DEBUG=0 to turn it off.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no')
    app.run(port=5000, debug=debug_enabled)

Overwriting app.py


In [None]:
!python app.py

 * Serving Flask app 'app'
 * Debug mode: on
 * Running on http://127.0.0.1:5000
[33mPress CTRL+C to quit[0m
 * Restarting with stat
 * Debugger is active!
 * Debugger PIN: 528-776-979
127.0.0.1 - - [19/Sep/2024 20:42:44] "GET /?authuser=0 HTTP/1.1" 200 -
127.0.0.1 - - [19/Sep/2024 20:42:44] "[33mGET /favicon.ico?authuser=0 HTTP/1.1[0m" 404 -
